fawazo commited on
Commit
82b7c62
·
verified ·
1 Parent(s): a0a68fc

Upload train_pentest_v4_memfix.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_pentest_v4_memfix.py +163 -0
train_pentest_v4_memfix.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # /// script
2
+ # dependencies = [
3
+ # "trl>=0.12.0",
4
+ # "peft>=0.7.0",
5
+ # "transformers>=4.36.0",
6
+ # "accelerate>=0.24.0",
7
+ # "trackio",
8
+ # "datasets>=2.14.0",
9
+ # "bitsandbytes",
10
+ # ]
11
+ # ///
12
+
13
+ from datasets import load_dataset, concatenate_datasets
14
+
15
+ PENTEST_SYSTEM_PROMPT = """You are an expert penetration testing AI. Analyze web traffic and respond with JSON only.
16
+
17
+ Formats:
18
+ 1. {"action": "report", "vulnerability": {"type": "SQLi|XSS|SSRF|IDOR|RCE", "severity": "critical|high|medium|low", "description": "...", "evidence": "..."}}
19
+ 2. {"action": "request", "method": "GET|POST", "path": "/...", "reasoning": "..."}
20
+ 3. {"action": "command", "cmd": "...", "reasoning": "..."}
21
+ 4. {"action": "complete", "summary": "..."}
22
+
23
+ Respond with ONLY valid JSON."""
24
+
25
+ def validate_messages(messages):
26
+ if not messages or not isinstance(messages, list) or len(messages) < 2:
27
+ return False
28
+ for msg in messages:
29
+ if not isinstance(msg, dict) or "role" not in msg or "content" not in msg:
30
+ return False
31
+ if not msg["content"] or not isinstance(msg["content"], str) or len(msg["content"].strip()) < 5:
32
+ return False
33
+ if msg["role"] not in ["system", "user", "assistant"]:
34
+ return False
35
+ return True
36
+
37
+ def load_trendyol():
38
+ print("Loading Trendyol...")
39
+ ds = load_dataset("Trendyol/Trendyol-Cybersecurity-Instruction-Tuning-Dataset", split="train")
40
+ print(f" Raw: {len(ds)}")
41
+
42
+ def convert(ex):
43
+ msgs = [
44
+ {"role": "system", "content": PENTEST_SYSTEM_PROMPT},
45
+ {"role": "user", "content": str(ex["user"]).strip()},
46
+ {"role": "assistant", "content": str(ex["assistant"]).strip()}
47
+ ]
48
+ return {"messages": msgs, "valid": validate_messages(msgs)}
49
+
50
+ ds = ds.map(convert, remove_columns=ds.column_names)
51
+ ds = ds.filter(lambda x: x["valid"]).remove_columns(["valid"])
52
+ print(f" Valid: {len(ds)}")
53
+ return ds
54
+
55
+ def load_fenrir():
56
+ print("Loading Fenrir...")
57
+ ds = load_dataset("AlicanKiraz0/Cybersecurity-Dataset-Fenrir-v2.0", split="train")
58
+ print(f" Raw: {len(ds)}")
59
+
60
+ def convert(ex):
61
+ msgs = [
62
+ {"role": "system", "content": PENTEST_SYSTEM_PROMPT},
63
+ {"role": "user", "content": str(ex["user"]).strip()},
64
+ {"role": "assistant", "content": str(ex["assistant"]).strip()}
65
+ ]
66
+ return {"messages": msgs, "valid": validate_messages(msgs)}
67
+
68
+ ds = ds.map(convert, remove_columns=ds.column_names)
69
+ ds = ds.filter(lambda x: x["valid"]).remove_columns(["valid"])
70
+ print(f" Valid: {len(ds)}")
71
+ return ds
72
+
73
+ print("=" * 50)
74
+ print("LOADING DATASETS")
75
+ print("=" * 50)
76
+
77
+ all_ds = []
78
+ try:
79
+ all_ds.append(load_trendyol())
80
+ except Exception as e:
81
+ print(f"Trendyol error: {e}")
82
+
83
+ try:
84
+ all_ds.append(load_fenrir())
85
+ except Exception as e:
86
+ print(f"Fenrir error: {e}")
87
+
88
+ combined = concatenate_datasets(all_ds).shuffle(seed=42)
89
+ print(f"\nTotal: {len(combined)}")
90
+
91
+ split = combined.train_test_split(test_size=0.02, seed=42)
92
+ train_ds, eval_ds = split["train"], split["test"]
93
+ print(f"Train: {len(train_ds)}, Eval: {len(eval_ds)}")
94
+
95
+ print("\n" + "=" * 50)
96
+ print("TRAINING WITH MEMORY OPTIMIZATION")
97
+ print("=" * 50)
98
+
99
+ from peft import LoraConfig
100
+ from trl import SFTTrainer, SFTConfig
101
+
102
+ config = SFTConfig(
103
+ output_dir="qwen2.5-coder-3b-pentest",
104
+ push_to_hub=True,
105
+ hub_model_id="fawazo/qwen2.5-coder-3b-pentest",
106
+ hub_strategy="every_save",
107
+
108
+ num_train_epochs=2,
109
+ # MEMORY FIX: batch=1, accumulation=16
110
+ per_device_train_batch_size=1,
111
+ gradient_accumulation_steps=16,
112
+ learning_rate=1e-4,
113
+ # MEMORY FIX: shorter sequences
114
+ max_length=1024,
115
+
116
+ # MEMORY FIX: all optimizations enabled
117
+ gradient_checkpointing=True,
118
+ bf16=True,
119
+ optim="adamw_8bit",
120
+
121
+ logging_steps=25,
122
+ save_strategy="steps",
123
+ save_steps=1000,
124
+ save_total_limit=2,
125
+
126
+ eval_strategy="steps",
127
+ eval_steps=1000,
128
+
129
+ warmup_ratio=0.03,
130
+ lr_scheduler_type="cosine",
131
+
132
+ report_to="trackio",
133
+ project="pentest-agent",
134
+ run_name="qwen-3b-cybersec-v4-memfix",
135
+ )
136
+
137
+ # Smaller LoRA for memory
138
+ peft_config = LoraConfig(
139
+ r=16,
140
+ lora_alpha=32,
141
+ lora_dropout=0.05,
142
+ bias="none",
143
+ task_type="CAUSAL_LM",
144
+ target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
145
+ )
146
+
147
+ print("Loading model...")
148
+ trainer = SFTTrainer(
149
+ model="Qwen/Qwen2.5-Coder-3B",
150
+ train_dataset=train_ds,
151
+ eval_dataset=eval_ds,
152
+ args=config,
153
+ peft_config=peft_config,
154
+ )
155
+
156
+ print("Training...")
157
+ trainer.train()
158
+ trainer.push_to_hub()
159
+
160
+ print("\n" + "=" * 50)
161
+ print("COMPLETE!")
162
+ print("https://huggingface.co/fawazo/qwen2.5-coder-3b-pentest")
163
+ print("=" * 50)