# train_pentest_v4_memfix.py
# Uploaded to the Hugging Face Hub (fawazo/training-scripts) via huggingface_hub,
# commit 82b7c62 (verified). File size at upload: 4.71 kB.
# /// script
# dependencies = [
# "trl>=0.12.0",
# "peft>=0.7.0",
# "transformers>=4.36.0",
# "accelerate>=0.24.0",
# "trackio",
# "datasets>=2.14.0",
# "bitsandbytes",
# ]
# ///
from datasets import load_dataset, concatenate_datasets
# System prompt injected as the first message of every training example.
# It constrains the model to answer with exactly one of four JSON action
# formats (report / request / command / complete) and nothing else.
PENTEST_SYSTEM_PROMPT = """You are an expert penetration testing AI. Analyze web traffic and respond with JSON only.
Formats:
1. {"action": "report", "vulnerability": {"type": "SQLi|XSS|SSRF|IDOR|RCE", "severity": "critical|high|medium|low", "description": "...", "evidence": "..."}}
2. {"action": "request", "method": "GET|POST", "path": "/...", "reasoning": "..."}
3. {"action": "command", "cmd": "...", "reasoning": "..."}
4. {"action": "complete", "summary": "..."}
Respond with ONLY valid JSON."""
def validate_messages(messages):
    """Return True when *messages* is a well-formed chat transcript.

    A valid transcript is a list of at least two dicts, each with a
    recognised "role" and a string "content" of at least 5 non-whitespace-
    trimmed characters; anything else is rejected.
    """
    if not isinstance(messages, list) or len(messages) < 2:
        return False
    allowed_roles = ("system", "user", "assistant")
    for entry in messages:
        if not isinstance(entry, dict):
            return False
        # .get() folds the "key missing" and "bad value" checks together.
        content = entry.get("content")
        if not isinstance(content, str) or len(content.strip()) < 5:
            return False
        if entry.get("role") not in allowed_roles:
            return False
    return True
def load_trendyol():
    """Download the Trendyol cybersecurity instruction dataset and reshape
    each row into a validated chat transcript under the pentest system
    prompt. Rows that fail validation are dropped."""
    print("Loading Trendyol...")
    raw = load_dataset("Trendyol/Trendyol-Cybersecurity-Instruction-Tuning-Dataset", split="train")
    print(f" Raw: {len(raw)}")

    def to_chat(row):
        # Build the three-turn transcript and tag it with a validity flag
        # so invalid rows can be filtered out afterwards.
        transcript = [
            {"role": "system", "content": PENTEST_SYSTEM_PROMPT},
            {"role": "user", "content": str(row["user"]).strip()},
            {"role": "assistant", "content": str(row["assistant"]).strip()},
        ]
        return {"messages": transcript, "valid": validate_messages(transcript)}

    shaped = raw.map(to_chat, remove_columns=raw.column_names)
    shaped = shaped.filter(lambda row: row["valid"]).remove_columns(["valid"])
    print(f" Valid: {len(shaped)}")
    return shaped
def load_fenrir():
    """Download the Fenrir v2.0 cybersecurity dataset and reshape each row
    into a validated chat transcript under the pentest system prompt.
    Rows that fail validation are dropped."""
    print("Loading Fenrir...")
    source = load_dataset("AlicanKiraz0/Cybersecurity-Dataset-Fenrir-v2.0", split="train")
    print(f" Raw: {len(source)}")

    def as_transcript(example):
        # Same three-turn shape as the other loaders; the "valid" flag
        # drives the post-map filter below.
        turns = [
            {"role": "system", "content": PENTEST_SYSTEM_PROMPT},
            {"role": "user", "content": str(example["user"]).strip()},
            {"role": "assistant", "content": str(example["assistant"]).strip()},
        ]
        return {"messages": turns, "valid": validate_messages(turns)}

    cleaned = source.map(as_transcript, remove_columns=source.column_names)
    cleaned = cleaned.filter(lambda ex: ex["valid"]).remove_columns(["valid"])
    print(f" Valid: {len(cleaned)}")
    return cleaned
print("=" * 50)
print("LOADING DATASETS")
print("=" * 50)
# Each loader is attempted independently: a failure in one source (e.g. a
# network or schema error) should not abort the run if the other loads.
all_ds = []
try:
    all_ds.append(load_trendyol())
except Exception as e:
    print(f"Trendyol error: {e}")
try:
    all_ds.append(load_fenrir())
except Exception as e:
    print(f"Fenrir error: {e}")
# FIX: if every loader failed, concatenate_datasets([]) raises an unhelpful
# internal error — fail fast with a clear message instead.
if not all_ds:
    raise SystemExit("No datasets loaded successfully; nothing to train on.")
combined = concatenate_datasets(all_ds).shuffle(seed=42)
print(f"\nTotal: {len(combined)}")
# Hold out 2% for evaluation; the fixed seed keeps the split reproducible.
split = combined.train_test_split(test_size=0.02, seed=42)
train_ds, eval_ds = split["train"], split["test"]
print(f"Train: {len(train_ds)}, Eval: {len(eval_ds)}")
print("\n" + "=" * 50)
print("TRAINING WITH MEMORY OPTIMIZATION")
print("=" * 50)
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
# SFT configuration tuned to fit a 3B model into limited GPU memory:
# per-device batch of 1 with 16 accumulation steps (effective batch 16),
# 1024-token sequences, gradient checkpointing, bf16, 8-bit optimizer.
config = SFTConfig(
    output_dir="qwen2.5-coder-3b-pentest",
    # Push every checkpoint to the Hub so progress survives a crash.
    push_to_hub=True,
    hub_model_id="fawazo/qwen2.5-coder-3b-pentest",
    hub_strategy="every_save",
    num_train_epochs=2,
    # MEMORY FIX: batch=1, accumulation=16
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
    learning_rate=1e-4,
    # MEMORY FIX: shorter sequences
    # NOTE(review): `max_length` is the newer TRL parameter name; some TRL
    # releases in the >=0.12 range used `max_seq_length` — confirm against
    # the version actually resolved by the script header's trl pin.
    max_length=1024,
    # MEMORY FIX: all optimizations enabled
    gradient_checkpointing=True,
    bf16=True,
    optim="adamw_8bit",  # bitsandbytes 8-bit AdamW (see deps header)
    logging_steps=25,
    save_strategy="steps",
    save_steps=1000,
    save_total_limit=2,  # keep disk use bounded: only the 2 newest checkpoints
    eval_strategy="steps",
    eval_steps=1000,
    warmup_ratio=0.03,
    lr_scheduler_type="cosine",
    # Metrics are reported to trackio under this project/run name.
    report_to="trackio",
    project="pentest-agent",
    run_name="qwen-3b-cybersec-v4-memfix",
)
# Smaller LoRA for memory
# Rank-16 adapters on the attention projections only (q/k/v/o, no MLP
# layers), keeping the trainable-parameter count low.
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,  # scaling factor; alpha/r = 2
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)
print("Loading model...")
# SFTTrainer accepts a model id string and loads the base model itself;
# the LoRA adapters from peft_config are applied on top of it.
trainer = SFTTrainer(
    model="Qwen/Qwen2.5-Coder-3B",
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    args=config,
    peft_config=peft_config,
)
print("Training...")
trainer.train()
# Final push of the trained adapter and tokenizer to the Hub repo
# (intermediate checkpoints were already pushed via hub_strategy above).
trainer.push_to_hub()
print("\n" + "=" * 50)
print("COMPLETE!")
print("https://huggingface.co/fawazo/qwen2.5-coder-3b-pentest")
print("=" * 50)