fawazo committed on
Commit
8186f7a
·
verified ·
1 Parent(s): a71f7c1

Upload train_pentest_combined.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_pentest_combined.py +208 -0
train_pentest_combined.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # /// script
2
+ # dependencies = [
3
+ # "trl>=0.12.0",
4
+ # "peft>=0.7.0",
5
+ # "transformers>=4.36.0",
6
+ # "accelerate>=0.24.0",
7
+ # "trackio",
8
+ # "datasets>=2.14.0",
9
+ # ]
10
+ # ///
11
+
12
+ import json
13
+ from datasets import load_dataset, concatenate_datasets, Dataset
14
+ from peft import LoraConfig
15
+ from trl import SFTTrainer, SFTConfig
16
+
17
# System prompt prepended to every training example. It instructs the model to
# analyze HTTP request/response traffic and answer with exactly one strict-JSON
# action object (report / request / command / complete) and nothing else.
PENTEST_SYSTEM_PROMPT = """You are an expert penetration testing AI assistant. Your role is to analyze web application traffic (HTTP requests and responses) and identify security vulnerabilities.

When given web traffic data, you MUST respond with a valid JSON object in one of these formats:

1. When you find a vulnerability:
{"action": "report", "vulnerability": {"type": "SQLi|XSS|SSRF|IDOR|RCE|LFI|XXE|CSRF|Auth_Bypass|Info_Disclosure", "severity": "critical|high|medium|low", "description": "detailed explanation", "evidence": "specific evidence from the request/response", "remediation": "how to fix"}}

2. When you want to send a follow-up request to test further:
{"action": "request", "method": "GET|POST|PUT|DELETE", "path": "/endpoint", "headers": {}, "body": "if applicable", "reasoning": "why this test"}

3. When you want to run a command (for authorized testing):
{"action": "command", "cmd": "command to execute", "reasoning": "why this command helps"}

4. When analysis is complete with no findings:
{"action": "complete", "summary": "analysis summary", "tested": ["list of tests performed"]}

Always respond with ONLY the JSON object, no additional text. Analyze thoroughly for OWASP Top 10, MITRE ATT&CK techniques, and common web vulnerabilities."""
35
+
36
def load_and_convert_trendyol():
    """Load the Trendyol cybersecurity dataset and rewrap each row as ChatML messages."""
    print("Loading Trendyol dataset...")
    ds = load_dataset("Trendyol/Trendyol-Cybersecurity-Instruction-Tuning-Dataset", split="train")
    print(f" Loaded {len(ds)} examples")

    def to_chatml(row):
        # Build a three-turn conversation: fixed system prompt + original user/assistant pair.
        chat = [
            {"role": "system", "content": PENTEST_SYSTEM_PROMPT},
            {"role": "user", "content": row["user"]},
            {"role": "assistant", "content": row["assistant"]},
        ]
        return {"messages": chat}

    # Drop every original column so only the new "messages" field remains.
    return ds.map(to_chatml, remove_columns=ds.column_names)
52
+
53
def load_and_convert_fenrir():
    """Load the Fenrir v2.0 dataset and rewrap each row as ChatML messages."""
    print("Loading Fenrir v2.0 dataset...")
    ds = load_dataset("AlicanKiraz0/Cybersecurity-Dataset-Fenrir-v2.0", split="train")
    print(f" Loaded {len(ds)} examples")

    def to_chatml(row):
        # Same conversion as the Trendyol loader: system prompt + user/assistant turns.
        return {
            "messages": [
                {"role": "system", "content": PENTEST_SYSTEM_PROMPT},
                {"role": "user", "content": row["user"]},
                {"role": "assistant", "content": row["assistant"]},
            ]
        }

    # Remove everything except a possible pre-existing "messages" column.
    drop = [name for name in ds.column_names if name != "messages"]
    return ds.map(to_chatml, remove_columns=drop)
70
+
71
def load_pentest_agent():
    """Load the pentest-agent dataset (already ChatML) with our system prompt swapped in.

    Returns the converted dataset, or None if the download/conversion fails
    (the caller treats this source as optional).
    """
    print("Loading pentest-agent dataset...")
    try:
        ds = load_dataset(
            "jason-oneal/pentest-agent-dataset",
            data_files="chatml_train.jsonl",
            split="train",
        )
        print(f" Loaded {len(ds)} examples")

        def force_system_prompt(row):
            # Replace an existing leading system turn, or prepend one if absent.
            chat = row["messages"]
            if chat and chat[0]["role"] == "system":
                chat[0]["content"] = PENTEST_SYSTEM_PROMPT
            else:
                chat = [{"role": "system", "content": PENTEST_SYSTEM_PROMPT}] + chat
            return {"messages": chat}

        return ds.map(force_system_prompt)
    except Exception as e:
        # Best-effort source: warn and let the caller skip it.
        print(f" Warning: Could not load pentest-agent dataset: {e}")
        return None
95
+
96
# Load all datasets
print("=" * 50)
print("LOADING AND COMBINING DATASETS")
print("=" * 50)

datasets_to_combine = []

# Load each dataset; the pentest-agent source is optional (its loader
# returns None on failure), the other two are required.
trendyol_ds = load_and_convert_trendyol()
datasets_to_combine.append(trendyol_ds)

fenrir_ds = load_and_convert_fenrir()
datasets_to_combine.append(fenrir_ds)

pentest_ds = load_pentest_agent()
if pentest_ds is not None:
    datasets_to_combine.append(pentest_ds)

# Combine all datasets into a single "messages"-schema dataset
print("\nCombining datasets...")
combined_dataset = concatenate_datasets(datasets_to_combine)
print(f"Total combined examples: {len(combined_dataset)}")

# Shuffle the combined dataset (fixed seed for reproducibility)
combined_dataset = combined_dataset.shuffle(seed=42)

# Create train/eval split (95% train / 5% eval, same seed)
print("\nCreating train/eval split...")
split_dataset = combined_dataset.train_test_split(test_size=0.05, seed=42)
train_dataset = split_dataset["train"]
eval_dataset = split_dataset["test"]
print(f" Train: {len(train_dataset)} examples")
print(f" Eval: {len(eval_dataset)} examples")
129
+
130
# Training configuration for 3B model
print("\n" + "=" * 50)
print("CONFIGURING TRAINING")
print("=" * 50)

config = SFTConfig(
    output_dir="qwen2.5-coder-3b-pentest",
    push_to_hub=True,
    hub_model_id="fawazo/qwen2.5-coder-3b-pentest",
    hub_strategy="every_save",  # upload each saved checkpoint to the Hub
    hub_private_repo=False,

    # Training parameters optimized for 3B model
    num_train_epochs=2,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,  # Effective batch size = 16
    learning_rate=1e-4,
    max_length=2048,  # max sequence length in tokens

    # Memory optimization
    gradient_checkpointing=True,  # recompute activations to cut memory use
    bf16=True,

    # Logging & checkpointing
    logging_steps=50,
    save_strategy="steps",
    save_steps=500,
    save_total_limit=3,  # keep only the 3 most recent checkpoints

    # Evaluation
    eval_strategy="steps",
    eval_steps=500,

    # Optimization
    warmup_ratio=0.05,
    lr_scheduler_type="cosine",
    weight_decay=0.01,

    # Monitoring
    report_to="trackio",
    project="pentest-agent",
    run_name="qwen2.5-coder-3b-combined-150k",
)
173
+
174
# LoRA configuration - higher rank for better adaptation
peft_config = LoraConfig(
    r=32,  # adapter rank
    lora_alpha=64,  # scaling factor (alpha/r = 2)
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    # Adapt every attention and MLP projection layer
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
)
183
+
184
# Initialize trainer (TRL resolves the model id and applies the LoRA adapters)
print("\nInitializing trainer with Qwen2.5-Coder-3B...")
trainer = SFTTrainer(
    model="Qwen/Qwen2.5-Coder-3B",
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    args=config,
    peft_config=peft_config,
)

# Train
print("\n" + "=" * 50)
print("STARTING TRAINING")
print("=" * 50)
trainer.train()

# Push final model to the Hub (checkpoints were already pushed via hub_strategy)
print("\nPushing final model to Hub...")
trainer.push_to_hub()

print("\n" + "=" * 50)
print("TRAINING COMPLETE!")
print("=" * 50)
print("Model saved to: https://huggingface.co/fawazo/qwen2.5-coder-3b-pentest")
print("Trackio dashboard: https://huggingface.co/spaces/fawazo/trackio")