training-scripts / train_pentest_v2.py
fawazo's picture
Upload train_pentest_v2.py with huggingface_hub
b3589b2 verified
raw
history blame
5.41 kB
# /// script
# dependencies = [
# "trl>=0.12.0",
# "peft>=0.7.0",
# "transformers>=4.36.0",
# "accelerate>=0.24.0",
# "trackio",
# "datasets>=2.14.0",
# ]
# ///
import json
import traceback
from datasets import load_dataset, concatenate_datasets, Dataset
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
# System prompt used as the "system" turn of the training conversations:
# load_trendyol and load_fenrir prepend it to every row, and load_pentest
# replaces (or prepends) the system turn of each non-empty conversation.
# It instructs the model to answer with exactly one of four JSON "action"
# objects: report / request / command / complete.
# NOTE(review): the Trendyol and Fenrir assistant turns are copied verbatim
# from their source datasets; unless those answers are already JSON in this
# format, training pairs this JSON-only instruction with free-text targets,
# which may weaken adherence to it — confirm against the data.
PENTEST_SYSTEM_PROMPT = """You are an expert penetration testing AI assistant. Analyze web traffic and respond with JSON only.
Response formats:
1. Vulnerability found:
{"action": "report", "vulnerability": {"type": "SQLi|XSS|SSRF|IDOR|RCE|LFI|XXE", "severity": "critical|high|medium|low", "description": "...", "evidence": "..."}}
2. Send follow-up request:
{"action": "request", "method": "GET|POST", "path": "/...", "headers": {}, "body": "", "reasoning": "..."}
3. Run command:
{"action": "command", "cmd": "...", "reasoning": "..."}
4. Analysis complete:
{"action": "complete", "summary": "...", "tested": ["..."]}
Respond with ONLY valid JSON."""
def load_trendyol():
    """Load the Trendyol cybersecurity instruction set as chat ``messages``.

    Every row's ``user``/``assistant`` pair is wrapped in a three-turn
    conversation that opens with ``PENTEST_SYSTEM_PROMPT``. All of the
    dataset's original columns are dropped, so the result has a single
    ``messages`` column.
    """
    print("Loading Trendyol dataset...")
    raw = load_dataset("Trendyol/Trendyol-Cybersecurity-Instruction-Tuning-Dataset", split="train")
    print(f" Loaded {len(raw)} examples")

    def to_chat(row):
        system_turn = {"role": "system", "content": PENTEST_SYSTEM_PROMPT}
        user_turn = {"role": "user", "content": row["user"]}
        assistant_turn = {"role": "assistant", "content": row["assistant"]}
        return {"messages": [system_turn, user_turn, assistant_turn]}

    original_columns = raw.column_names
    return raw.map(to_chat, remove_columns=original_columns)
def load_fenrir():
    """Load the Fenrir v2.0 cybersecurity dataset as chat ``messages``.

    Each row's ``user`` and ``assistant`` fields become the user and
    assistant turns of a conversation that opens with
    ``PENTEST_SYSTEM_PROMPT``. The row's original columns are dropped, so
    the result has a single ``messages`` column.

    Returns:
        A ``datasets.Dataset`` with one ``messages`` column.

    Raises:
        Propagates errors from ``load_dataset`` and from the per-row
        conversion (e.g. a row missing ``user``/``assistant``).
    """
    print("Loading Fenrir v2.0 dataset...")
    ds = load_dataset("AlicanKiraz0/Cybersecurity-Dataset-Fenrir-v2.0", split="train")
    print(f" Loaded {len(ds)} examples")

    def convert(example):
        return {
            "messages": [
                {"role": "system", "content": PENTEST_SYSTEM_PROMPT},
                {"role": "user", "content": example["user"]},
                {"role": "assistant", "content": example["assistant"]},
            ]
        }

    # Pass the column list straight through (as load_trendyol does); the
    # element-by-element copy the original made was unnecessary.
    return ds.map(convert, remove_columns=ds.column_names)
def apply_system_prompt(messages, system_prompt):
    """Return a new conversation whose first turn is *system_prompt*.

    Each turn is rebuilt with only its ``role`` and ``content`` keys, so the
    result has the same per-turn shape as the conversations built by
    load_trendyol/load_fenrir. If the conversation already starts with a
    system turn, that turn's content is replaced. Otherwise a system turn is
    prepended. An empty (or ``None``) conversation is returned as an empty
    list without a system turn. The input list and its dicts are not
    modified.
    """
    turns = [{"role": m["role"], "content": m["content"]} for m in messages or []]
    if not turns:
        return turns
    system_turn = {"role": "system", "content": system_prompt}
    if turns[0]["role"] == "system":
        return [system_turn, *turns[1:]]
    return [system_turn, *turns]


def load_pentest():
    """Load ``jason-oneal/pentest-agent-dataset`` (chatml_train.jsonl).

    Each conversation's system turn is set to ``PENTEST_SYSTEM_PROMPT``.
    Only a ``messages`` column is kept.

    Returns:
        A ``datasets.Dataset`` with one ``messages`` column, or ``None`` if
        anything fails. Loading is best-effort: the failure is printed as a
        warning so the caller can continue with the other datasets.
    """
    print("Loading pentest-agent dataset...")
    try:
        ds = load_dataset("jason-oneal/pentest-agent-dataset", data_files="chatml_train.jsonl", split="train")
        print(f" Loaded {len(ds)} examples")

        def update_system(example):
            return {"messages": apply_system_prompt(example["messages"], PENTEST_SYSTEM_PROMPT)}

        # Drop every source column (the returned `messages` replaces the old
        # one). Leftover columns or extra per-turn keys would make this
        # dataset's schema differ from the other loaders', and
        # concatenate_datasets would reject the combination.
        return ds.map(update_system, remove_columns=ds.column_names)
    except Exception as e:
        print(f" Warning: Could not load pentest-agent: {e}")
        return None
# Main execution
print("=" * 50)
print("LOADING DATASETS")
print("=" * 50)

# Each source is attempted independently so one failing loader does not abort
# the run. The third field marks loaders whose result is kept only when truthy
# (load_pentest returns None after printing its own warning).
dataset_sources = [
    ("Trendyol", load_trendyol, False),
    ("Fenrir", load_fenrir, False),
    ("pentest-agent", load_pentest, True),
]
datasets_list = []
for source_name, loader, keep_only_if_truthy in dataset_sources:
    try:
        loaded = loader()
    except Exception as e:
        print(f"ERROR loading {source_name}: {e}")
        traceback.print_exc()
        continue
    if keep_only_if_truthy and not loaded:
        continue
    datasets_list.append(loaded)

if not datasets_list:
    raise RuntimeError("No datasets loaded!")

print(f"\nCombining {len(datasets_list)} datasets...")
combined = concatenate_datasets(datasets_list)
print(f"Total: {len(combined)} examples")

# Fixed seeds make both the shuffle and the 98/2 train/eval split reproducible.
combined = combined.shuffle(seed=42)
split_ds = combined.train_test_split(test_size=0.02, seed=42)
train_ds, eval_ds = split_ds["train"], split_ds["test"]
print(f"Train: {len(train_ds)}, Eval: {len(eval_ds)}")
# Training config
print("\n" + "=" * 50)
print("STARTING TRAINING")
print("=" * 50)

# Model identifiers are defined once and reused, so the log line, the trainer,
# the Hub push target and the final summary URL cannot drift apart.
BASE_MODEL = "Qwen/Qwen2.5-Coder-3B"
HUB_MODEL_ID = "fawazo/qwen2.5-coder-3b-pentest"

config = SFTConfig(
    output_dir="qwen2.5-coder-3b-pentest",
    # Push to the Hub on every checkpoint save, not only at the end.
    push_to_hub=True,
    hub_model_id=HUB_MODEL_ID,
    hub_strategy="every_save",
    num_train_epochs=2,
    # 2 examples per forward pass x 8 accumulation steps = 16 examples per
    # optimizer step on each device.
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,
    learning_rate=1e-4,
    max_length=2048,
    gradient_checkpointing=True,
    bf16=True,
    logging_steps=25,
    # Checkpointing and evaluation share the same 500-step cadence; at most
    # 2 checkpoints are kept.
    save_strategy="steps",
    save_steps=500,
    save_total_limit=2,
    eval_strategy="steps",
    eval_steps=500,
    warmup_ratio=0.03,
    lr_scheduler_type="cosine",
    report_to="trackio",
    project="pentest-agent",
    run_name="qwen-3b-cybersec-150k",
)

# LoRA adapters on the attention (q/k/v/o) and MLP (gate/up/down) projection
# modules; module names assume the Qwen2 architecture.
peft_config = LoraConfig(
    r=32,
    lora_alpha=64,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
)

print(f"Loading model {BASE_MODEL}...")
trainer = SFTTrainer(
    # A Hub id string: SFTTrainer loads the model itself.
    model=BASE_MODEL,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    args=config,
    peft_config=peft_config,
)
print("Training...")
trainer.train()

# Explicit final push once training has finished.
print("Pushing to Hub...")
trainer.push_to_hub()
print("\n" + "=" * 50)
print("TRAINING COMPLETE!")
print(f"Model: https://huggingface.co/{HUB_MODEL_ID}")
print("=" * 50)