fawazo commited on
Commit
78085ca
·
verified ·
1 Parent(s): 82b7c62

Upload train_pentest_a100.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_pentest_a100.py +156 -0
train_pentest_a100.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # /// script
2
+ # dependencies = [
3
+ # "trl>=0.12.0",
4
+ # "peft>=0.7.0",
5
+ # "transformers>=4.36.0",
6
+ # "accelerate>=0.24.0",
7
+ # "trackio",
8
+ # "datasets>=2.14.0",
9
+ # "bitsandbytes",
10
+ # ]
11
+ # ///
12
+
13
+ from datasets import load_dataset, concatenate_datasets
14
+
15
# System prompt injected as the first message of every training example below.
# It constrains the fine-tuned model to answer with exactly one of four JSON
# action formats (report / request / command / complete) and nothing else.
PENTEST_SYSTEM_PROMPT = """You are an expert penetration testing AI. Analyze web traffic and respond with JSON only.

Formats:
1. {"action": "report", "vulnerability": {"type": "SQLi|XSS|SSRF|IDOR|RCE", "severity": "critical|high|medium|low", "description": "...", "evidence": "..."}}
2. {"action": "request", "method": "GET|POST", "path": "/...", "reasoning": "..."}
3. {"action": "command", "cmd": "...", "reasoning": "..."}
4. {"action": "complete", "summary": "..."}

Respond with ONLY valid JSON."""
24
+
25
def validate_messages(messages):
    """Return True when *messages* is a usable chat-format training example.

    A valid example is a list of at least two dicts, where every entry has a
    "role" of system/user/assistant and a string "content" that is at least
    5 characters long after stripping surrounding whitespace.
    """
    if not isinstance(messages, list) or len(messages) < 2:
        return False

    allowed_roles = ("system", "user", "assistant")
    for entry in messages:
        if not isinstance(entry, dict):
            return False
        content = entry.get("content")
        # Missing / non-string / too-short content all disqualify the example.
        if not isinstance(content, str) or len(content.strip()) < 5:
            return False
        # Missing role comes back as None and fails the membership test.
        if entry.get("role") not in allowed_roles:
            return False
    return True
36
+
37
def load_trendyol():
    """Load the Trendyol cybersecurity instruction dataset as chat messages.

    Each raw (user, assistant) pair is wrapped into a three-turn conversation
    that starts with PENTEST_SYSTEM_PROMPT; rows failing validate_messages
    are dropped before returning.
    """
    print("Loading Trendyol...")
    raw = load_dataset("Trendyol/Trendyol-Cybersecurity-Instruction-Tuning-Dataset", split="train")
    print(f" Raw: {len(raw)}")

    def to_chat(row):
        conversation = [
            {"role": "system", "content": PENTEST_SYSTEM_PROMPT},
            {"role": "user", "content": str(row["user"]).strip()},
            {"role": "assistant", "content": str(row["assistant"]).strip()},
        ]
        # Validity is computed here so the filter step below stays trivial.
        return {"messages": conversation, "valid": validate_messages(conversation)}

    chat_ds = raw.map(to_chat, remove_columns=raw.column_names)
    chat_ds = chat_ds.filter(lambda row: row["valid"]).remove_columns(["valid"])
    print(f" Valid: {len(chat_ds)}")
    return chat_ds
52
+
53
def load_fenrir():
    """Load the Fenrir v2.0 cybersecurity dataset as chat messages.

    Mirrors load_trendyol: every (user, assistant) pair becomes a three-turn
    conversation headed by PENTEST_SYSTEM_PROMPT, and rows failing
    validate_messages are filtered out.
    """
    print("Loading Fenrir...")
    raw = load_dataset("AlicanKiraz0/Cybersecurity-Dataset-Fenrir-v2.0", split="train")
    print(f" Raw: {len(raw)}")

    def to_chat(row):
        conversation = [
            {"role": "system", "content": PENTEST_SYSTEM_PROMPT},
            {"role": "user", "content": str(row["user"]).strip()},
            {"role": "assistant", "content": str(row["assistant"]).strip()},
        ]
        # Validity is computed here so the filter step below stays trivial.
        return {"messages": conversation, "valid": validate_messages(conversation)}

    chat_ds = raw.map(to_chat, remove_columns=raw.column_names)
    chat_ds = chat_ds.filter(lambda row: row["valid"]).remove_columns(["valid"])
    print(f" Valid: {len(chat_ds)}")
    return chat_ds
68
+
69
print("=" * 50)
print("LOADING DATASETS")
print("=" * 50)

# Load each source independently so that one failing hub download does not
# abort the whole run; failures are reported and the run continues with
# whatever did load.
all_ds = []
try:
    all_ds.append(load_trendyol())
except Exception as e:
    print(f"Trendyol error: {e}")
try:
    all_ds.append(load_fenrir())
except Exception as e:
    print(f"Fenrir error: {e}")

# Fail loudly if *every* source failed: concatenate_datasets([]) would
# otherwise crash later with a confusing low-level error.
if not all_ds:
    raise RuntimeError("No datasets loaded successfully; cannot train.")

combined = concatenate_datasets(all_ds).shuffle(seed=42)
print(f"\nTotal: {len(combined)}")

# Hold out 2% for evaluation; fixed seeds keep shuffle and split reproducible.
split = combined.train_test_split(test_size=0.02, seed=42)
train_ds, eval_ds = split["train"], split["test"]
print(f"Train: {len(train_ds)}, Eval: {len(eval_ds)}")
89
+
90
print("\n" + "=" * 50)
print("TRAINING ON A100 (80GB)")
print("=" * 50)

# NOTE(review): the training stack is imported only after the data pipeline
# succeeds — presumably deliberate so data errors fail fast; confirm.
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig

# A100 can handle larger batches
config = SFTConfig(
    # Local checkpoint dir; also the default repo folder pushed to the Hub.
    output_dir="qwen2.5-coder-3b-pentest",
    push_to_hub=True,
    hub_model_id="fawazo/qwen2.5-coder-3b-pentest",
    hub_strategy="every_save",

    num_train_epochs=2,
    # A100 80GB - can do larger batches
    # Effective batch size = 8 * 2 (grad accumulation) = 16 sequences/step.
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,
    learning_rate=1e-4,
    # NOTE(review): `max_length` is the newer TRL name for the sequence cap;
    # the script header pins trl>=0.12.0, where older releases used
    # `max_seq_length` — confirm the installed trl version accepts this.
    max_length=1536,

    # Trade compute for memory headroom; bf16 is native on A100.
    gradient_checkpointing=True,
    bf16=True,

    logging_steps=50,
    # Save frequently for resume capability
    save_strategy="steps",
    save_steps=500,
    save_total_limit=3,

    # Evaluate on the held-out 2% split at the same cadence as saving.
    eval_strategy="steps",
    eval_steps=500,

    warmup_ratio=0.03,
    lr_scheduler_type="cosine",

    # Metrics are reported to trackio under this project/run name.
    report_to="trackio",
    project="pentest-agent",
    run_name="qwen-3b-cybersec-a100",
)
130
+
131
# LoRA adapter configuration: rank-32 adapters on every attention and MLP
# projection, so only the adapters train while the base model stays frozen.
peft_config = LoraConfig(
    r=32,  # adapter rank
    lora_alpha=64,  # effective scaling = alpha / r = 2
    lora_dropout=0.05,
    bias="none",  # bias terms are left untouched
    task_type="CAUSAL_LM",
    # Projection names match the Qwen2-family transformer layers.
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
)
139
+
140
print("Loading model...")
# Passing the model id as a string lets SFTTrainer download and instantiate
# the base model itself; peft_config wraps it with the LoRA adapters.
trainer = SFTTrainer(
    model="Qwen/Qwen2.5-Coder-3B",
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    args=config,
    peft_config=peft_config,
)

print("Training...")
trainer.train()
# Final explicit push, in addition to the per-checkpoint pushes configured
# via hub_strategy="every_save" above.
trainer.push_to_hub()

print("\n" + "=" * 50)
print("COMPLETE!")
print("https://huggingface.co/fawazo/qwen2.5-coder-3b-pentest")
print("=" * 50)