# training-scripts / train_pentest_v3.py
# Uploaded by fawazo with huggingface_hub (commit a0a68fc, verified; 6.72 kB).
# /// script
# dependencies = [
# "trl>=0.12.0",
# "peft>=0.7.0",
# "transformers>=4.36.0",
# "accelerate>=0.24.0",
# "trackio",
# "datasets>=2.14.0",
# ]
# ///
import traceback
from datasets import load_dataset, concatenate_datasets
# System prompt used as the "system" message of every conversation built by
# the dataset loaders in this script (load_trendyol, load_fenrir, load_pentest).
# It tells the model to answer with JSON only, in one of four shapes:
# report, request, command, complete.
# NOTE(review): load_trendyol/load_fenrir pair this prompt with each dataset's
# original assistant answers; if those answers are free-form prose rather than
# JSON, training teaches the model to ignore the "JSON only" instruction --
# confirm the target format of those datasets.
PENTEST_SYSTEM_PROMPT = """You are an expert penetration testing AI. Analyze web traffic and respond with JSON only.
Formats:
1. {"action": "report", "vulnerability": {"type": "SQLi|XSS|SSRF|IDOR|RCE", "severity": "critical|high|medium|low", "description": "...", "evidence": "..."}}
2. {"action": "request", "method": "GET|POST", "path": "/...", "reasoning": "..."}
3. {"action": "command", "cmd": "...", "reasoning": "..."}
4. {"action": "complete", "summary": "..."}
Respond with ONLY valid JSON."""
def validate_messages(messages, min_content_len=5):
    """Return True if *messages* is a usable chat-format training example.

    A conversation is accepted only when all of the following hold:

    * it is a list with at least two entries;
    * every entry is a dict whose ``role`` is "system", "user" or
      "assistant" and whose ``content`` is a string that is non-empty after
      stripping and at least *min_content_len* characters long;
    * it contains at least one "assistant" turn, so there is a response for
      the model to learn from.

    Args:
        messages: Candidate list of ``{"role": ..., "content": ...}`` dicts.
        min_content_len: Minimum stripped length of every message's content.
            Defaults to 5, the previously hard-coded threshold.

    Returns:
        bool: True when every check passes, False otherwise.
    """
    if not isinstance(messages, list) or len(messages) < 2:
        return False
    for msg in messages:
        if not isinstance(msg, dict):
            return False
        # Tuple membership compares with ==, so an unhashable role value
        # (e.g. a list) is rejected instead of raising TypeError.
        if msg.get("role") not in ("system", "user", "assistant"):
            return False
        content = msg.get("content")
        if not isinstance(content, str):
            return False
        stripped = content.strip()
        # Whitespace-only content is always rejected, whatever the threshold.
        if not stripped or len(stripped) < min_content_len:
            return False
    # Without any assistant turn (e.g. [system, user] left over after empty
    # replies were dropped) the example has no target response.
    return any(msg["role"] == "assistant" for msg in messages)
def _load_user_assistant_dataset(label, repo_id):
    """Load a single-turn ``user``/``assistant`` Hub dataset as chat messages.

    Each row of the ``train`` split becomes ``[system, user, assistant]``,
    with PENTEST_SYSTEM_PROMPT as the system message. Rows that fail
    validate_messages are dropped.

    Args:
        label: Human-readable name used in progress output.
        repo_id: Hugging Face Hub dataset id. Its ``train`` split is read
            through its ``user`` and ``assistant`` columns.

    Returns:
        A ``datasets.Dataset`` with a single ``messages`` column.
    """
    print(f"Loading {label}...")
    ds = load_dataset(repo_id, split="train")
    print(f" Raw: {len(ds)}")

    def _clean(value):
        # str(None) would produce the literal text "None"; treat missing
        # values as empty so validation rejects them explicitly.
        return "" if value is None else str(value).strip()

    def convert(ex):
        # NOTE(review): the dataset's own assistant answer is kept verbatim
        # under a "respond with JSON only" system prompt -- confirm these
        # answers are actually JSON, otherwise the prompt and targets disagree.
        msgs = [
            {"role": "system", "content": PENTEST_SYSTEM_PROMPT},
            {"role": "user", "content": _clean(ex["user"])},
            {"role": "assistant", "content": _clean(ex["assistant"])},
        ]
        return {"messages": msgs, "valid": validate_messages(msgs)}

    ds = ds.map(convert, remove_columns=ds.column_names)
    ds = ds.filter(lambda x: x["valid"])
    ds = ds.remove_columns(["valid"])
    print(f" Valid: {len(ds)}")
    return ds


def load_trendyol():
    """Load the Trendyol cybersecurity instruction dataset as chat messages."""
    return _load_user_assistant_dataset(
        "Trendyol", "Trendyol/Trendyol-Cybersecurity-Instruction-Tuning-Dataset"
    )


def load_fenrir():
    """Load the Fenrir v2.0 cybersecurity dataset as chat messages."""
    return _load_user_assistant_dataset(
        "Fenrir", "AlicanKiraz0/Cybersecurity-Dataset-Fenrir-v2.0"
    )
def load_pentest():
    """Load the pentest-agent ChatML split as normalized chat conversations.

    Each conversation is rebuilt as follows:

    * one leading PENTEST_SYSTEM_PROMPT message;
    * then the source's non-empty "user"/"assistant" turns, in their original
      order (roles are compared after ``str().strip().lower()``).

    Source "system" messages are discarded, since the single leading prompt
    replaces them. Entries that are not dicts with ``role`` and ``content``
    keys are skipped, as are turns with any other role. Conversations that
    then fail validate_messages are dropped.

    Returns:
        A ``datasets.Dataset`` with a single ``messages`` column, or None if
        anything during loading or processing raises. In that case the error
        and its traceback are printed.
    """
    print("Loading pentest-agent...")
    try:
        ds = load_dataset(
            "jason-oneal/pentest-agent-dataset",
            data_files="chatml_train.jsonl",
            split="train",
        )
        print(f" Raw: {len(ds)}")

        def fix_messages(ex):
            # Exactly one system prompt, always first. Rewriting each source
            # system message in place could duplicate the prompt or leave it
            # in the middle of the conversation.
            new_msgs = [{"role": "system", "content": PENTEST_SYSTEM_PROMPT}]
            for m in ex.get("messages") or []:
                if not isinstance(m, dict) or "role" not in m or "content" not in m:
                    continue
                role = str(m["role"]).strip().lower()
                content = str(m["content"]).strip() if m["content"] else ""
                if role in ("user", "assistant") and content:
                    new_msgs.append({"role": role, "content": content})
            # Every row yields at least the system message, never an empty list.
            return {"messages": new_msgs, "valid": validate_messages(new_msgs)}

        ds = ds.map(fix_messages, remove_columns=ds.column_names)
        ds = ds.filter(lambda x: x["valid"])
        ds = ds.remove_columns(["valid"])
        print(f" Valid: {len(ds)}")
        return ds
    except Exception as e:
        # Best-effort source: report the failure and return None so the
        # caller can continue without this dataset.
        print(f" Error: {e}")
        traceback.print_exc()
        return None
# Load datasets
def _collect_datasets():
    """Run each dataset loader and return the non-empty datasets produced.

    Loaders are best-effort. When a loader raises, its message and traceback
    are printed and the remaining sources are still tried. A None result
    (load_pentest's failure signal) is skipped, and so is a zero-row result.
    """
    collected = []
    for label, loader in (
        ("Trendyol", load_trendyol),
        ("Fenrir", load_fenrir),
        ("Pentest", load_pentest),
    ):
        try:
            ds = loader()
        except Exception as e:
            print(f"{label} error: {e}")
            traceback.print_exc()
            continue
        if ds is not None and len(ds) > 0:
            collected.append(ds)
    return collected


print("=" * 50)
print("LOADING AND VALIDATING DATASETS")
print("=" * 50)
all_ds = _collect_datasets()
if not all_ds:
    # Fail with a clear message instead of letting concatenate_datasets([])
    # fail inside the library.
    raise RuntimeError(
        "No training data: every dataset failed to load or had no valid examples"
    )
print(f"\nCombining {len(all_ds)} datasets...")
combined = concatenate_datasets(all_ds)
combined = combined.shuffle(seed=42)
print(f"Total valid examples: {len(combined)}")

# Split: hold out 2% of the examples for evaluation (seeded for reproducibility).
split = combined.train_test_split(test_size=0.02, seed=42)
train_ds = split["train"]
eval_ds = split["test"]
print(f"Train: {len(train_ds)}, Eval: {len(eval_ds)}")

# Verify a sample: print each message's role and first 50 characters.
print("\nSample message structure:")
sample = train_ds[0]["messages"]
for m in sample:
    print(f" {m['role']}: {m['content'][:50]}...")
# Training
print("\n" + "=" * 50)
print("TRAINING")
print("=" * 50)
# Training dependencies, imported once data preparation has finished.
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig

BASE_MODEL = "Qwen/Qwen2.5-Coder-3B"
HUB_REPO = "fawazo/qwen2.5-coder-3b-pentest"

# LoRA adapters on the q/k/v/o and gate/up/down projection modules.
peft_config = LoraConfig(
    r=32,
    lora_alpha=64,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules="q_proj k_proj v_proj o_proj gate_proj up_proj down_proj".split(),
)

# NOTE(review): report_to="trackio", project= and max_length= are only
# accepted by fairly recent transformers/TRL releases; the lower bounds in
# the script header (transformers>=4.36.0, trl>=0.12.0) look too old for
# them -- confirm and bump if needed.
config = SFTConfig(
    # Where checkpoints go locally and on the Hub.
    output_dir="qwen2.5-coder-3b-pentest",
    push_to_hub=True,
    hub_model_id=HUB_REPO,
    hub_strategy="every_save",
    # Optimization: 2 sequences x 8 accumulation steps = 16 per device per
    # optimizer step.
    num_train_epochs=2,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,
    learning_rate=1e-4,
    warmup_ratio=0.03,
    lr_scheduler_type="cosine",
    # Sequence length and memory/precision settings.
    max_length=2048,
    gradient_checkpointing=True,
    bf16=True,
    # Logging, checkpointing and evaluation cadence.
    logging_steps=25,
    save_strategy="steps",
    save_steps=500,
    save_total_limit=2,
    eval_strategy="steps",
    eval_steps=500,
    # Experiment tracking.
    report_to="trackio",
    project="pentest-agent",
    run_name="qwen-3b-cybersec-v3",
)

print("Initializing trainer...")
trainer = SFTTrainer(
    model=BASE_MODEL,
    args=config,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    peft_config=peft_config,
)
print("Starting training...")
trainer.train()
# Upload the final trained model to the Hub repo named by hub_model_id.
trainer.push_to_hub()

print("\n" + "=" * 50)
print("COMPLETE!")
print(f"https://huggingface.co/{HUB_REPO}")
print("=" * 50)