| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import traceback |
| from datasets import load_dataset, concatenate_datasets |
|
|
# System prompt injected as the first message of every training conversation.
# It constrains the model to answer with one of four JSON "action" formats
# (report / request / command / complete) and nothing else.
PENTEST_SYSTEM_PROMPT = """You are an expert penetration testing AI. Analyze web traffic and respond with JSON only.

Formats:
1. {"action": "report", "vulnerability": {"type": "SQLi|XSS|SSRF|IDOR|RCE", "severity": "critical|high|medium|low", "description": "...", "evidence": "..."}}
2. {"action": "request", "method": "GET|POST", "path": "/...", "reasoning": "..."}
3. {"action": "command", "cmd": "...", "reasoning": "..."}
4. {"action": "complete", "summary": "..."}

Respond with ONLY valid JSON."""
|
|
def validate_messages(messages):
    """Return True when *messages* is a well-formed chat conversation.

    A valid conversation is a list of at least two dicts, each with a
    ``role`` in {system, user, assistant} and a non-empty string
    ``content`` of at least 5 characters after stripping whitespace.
    """
    allowed_roles = ("system", "user", "assistant")

    # The conversation itself must be a list with at least two turns.
    if not isinstance(messages, list) or len(messages) < 2:
        return False

    def _message_ok(msg):
        # Each turn must be a dict carrying both required keys.
        if not isinstance(msg, dict):
            return False
        content = msg.get("content")
        if not isinstance(content, str) or not content:
            return False
        if msg.get("role") not in allowed_roles:
            return False
        # Reject near-empty content (fewer than 5 meaningful characters).
        return len(content.strip()) >= 5

    return all(_message_ok(m) for m in messages)
|
|
def load_trendyol():
    """Load the Trendyol cybersecurity dataset as validated chat messages."""
    print("Loading Trendyol...")
    raw = load_dataset("Trendyol/Trendyol-Cybersecurity-Instruction-Tuning-Dataset", split="train")
    print(f" Raw: {len(raw)}")

    def to_chat(row):
        # Wrap each user/assistant pair with the shared system prompt and
        # record whether the resulting conversation passes validation.
        conversation = [
            {"role": "system", "content": PENTEST_SYSTEM_PROMPT},
            {"role": "user", "content": str(row["user"]).strip()},
            {"role": "assistant", "content": str(row["assistant"]).strip()},
        ]
        return {"messages": conversation, "valid": validate_messages(conversation)}

    converted = raw.map(to_chat, remove_columns=raw.column_names)
    cleaned = converted.filter(lambda row: row["valid"]).remove_columns(["valid"])
    print(f" Valid: {len(cleaned)}")
    return cleaned
|
|
def load_fenrir():
    """Load the Fenrir cybersecurity dataset as validated chat messages."""
    print("Loading Fenrir...")
    raw = load_dataset("AlicanKiraz0/Cybersecurity-Dataset-Fenrir-v2.0", split="train")
    print(f" Raw: {len(raw)}")

    def to_chat(row):
        # Wrap each user/assistant pair with the shared system prompt and
        # record whether the resulting conversation passes validation.
        conversation = [
            {"role": "system", "content": PENTEST_SYSTEM_PROMPT},
            {"role": "user", "content": str(row["user"]).strip()},
            {"role": "assistant", "content": str(row["assistant"]).strip()},
        ]
        return {"messages": conversation, "valid": validate_messages(conversation)}

    converted = raw.map(to_chat, remove_columns=raw.column_names)
    cleaned = converted.filter(lambda row: row["valid"]).remove_columns(["valid"])
    print(f" Valid: {len(cleaned)}")
    return cleaned
|
|
def load_pentest():
    """Load the pentest-agent ChatML dataset, normalizing each conversation.

    Every conversation gets the shared PENTEST_SYSTEM_PROMPT as its system
    message (dataset-provided system messages are replaced; one is inserted
    at the front when missing), roles are lower-cased, and empty turns are
    dropped. Rows failing validate_messages are filtered out.

    Returns:
        The cleaned dataset, or None when loading/processing fails.
    """
    print("Loading pentest-agent...")
    try:
        ds = load_dataset("jason-oneal/pentest-agent-dataset", data_files="chatml_train.jsonl", split="train")
        print(f" Raw: {len(ds)}")

        def fix_messages(ex):
            msgs = ex.get("messages", [])
            if not msgs:
                return {"messages": [], "valid": False}

            new_msgs = []
            has_system = False
            for m in msgs:
                # Skip malformed turns that lack either required key.
                if isinstance(m, dict) and "role" in m and "content" in m:
                    role = str(m["role"]).strip().lower()
                    content = str(m["content"]).strip() if m["content"] else ""
                    if role == "system":
                        # Replace any dataset-provided system prompt with ours
                        # so training is consistent across all sources.
                        has_system = True
                        new_msgs.append({"role": "system", "content": PENTEST_SYSTEM_PROMPT})
                    elif role in ["user", "assistant"] and content:
                        new_msgs.append({"role": role, "content": content})

            if not has_system:
                new_msgs.insert(0, {"role": "system", "content": PENTEST_SYSTEM_PROMPT})

            return {"messages": new_msgs, "valid": validate_messages(new_msgs)}

        ds = ds.map(fix_messages, remove_columns=ds.column_names)
        ds = ds.filter(lambda x: x["valid"])
        ds = ds.remove_columns(["valid"])
        print(f" Valid: {len(ds)}")
        return ds
    except Exception as e:
        # Best-effort source: keep the pipeline alive on failure, but surface
        # the full stack trace (the file imports traceback for exactly this).
        print(f" Error: {e}")
        traceback.print_exc()
        return None
|
|
| |
print("=" * 50)
print("LOADING AND VALIDATING DATASETS")
print("=" * 50)

# Collect whichever sources load successfully; each loader is best-effort.
all_ds = []

try:
    all_ds.append(load_trendyol())
except Exception as e:
    print(f"Trendyol error: {e}")

try:
    all_ds.append(load_fenrir())
except Exception as e:
    print(f"Fenrir error: {e}")

try:
    pds = load_pentest()
    # load_pentest may return None (it handles its own errors internally).
    if pds and len(pds) > 0:
        all_ds.append(pds)
except Exception as e:
    print(f"Pentest error: {e}")

# concatenate_datasets raises an unhelpful error on an empty list; fail fast
# with a clear message if every source failed to load.
if not all_ds:
    raise SystemExit("No datasets loaded successfully; aborting.")

print(f"\nCombining {len(all_ds)} datasets...")
combined = concatenate_datasets(all_ds)
combined = combined.shuffle(seed=42)
print(f"Total valid examples: {len(combined)}")

# Hold out 2% for evaluation; fixed seed keeps the split reproducible.
split = combined.train_test_split(test_size=0.02, seed=42)
train_ds = split["train"]
eval_ds = split["test"]
print(f"Train: {len(train_ds)}, Eval: {len(eval_ds)}")

# Sanity-check: print the first conversation's structure (truncated content).
print("\nSample message structure:")
sample = train_ds[0]["messages"]
for m in sample:
    print(f" {m['role']}: {m['content'][:50]}...")
|
|
| |
print("\n" + "=" * 50)
print("TRAINING")
print("=" * 50)

# NOTE: heavyweight training imports happen here, after the dataset stage,
# so data-loading failures abort before these libraries are pulled in.
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig

# SFT hyperparameters. Effective batch size = 2 (per device) * 8 (grad accum)
# = 16 sequences per optimizer step; checkpoints push to the Hub on each save.
config = SFTConfig(
    output_dir="qwen2.5-coder-3b-pentest",
    push_to_hub=True,
    hub_model_id="fawazo/qwen2.5-coder-3b-pentest",
    hub_strategy="every_save",

    num_train_epochs=2,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,
    learning_rate=1e-4,
    max_length=2048,  # truncation length for packed/tokenized sequences

    # Trade compute for memory: recompute activations, train in bfloat16.
    gradient_checkpointing=True,
    bf16=True,

    logging_steps=25,
    save_strategy="steps",
    save_steps=500,
    save_total_limit=2,  # keep only the two most recent checkpoints

    eval_strategy="steps",
    eval_steps=500,

    warmup_ratio=0.03,
    lr_scheduler_type="cosine",

    # Experiment tracking via trackio.
    report_to="trackio",
    project="pentest-agent",
    run_name="qwen-3b-cybersec-v3",
)

# LoRA adapter config: rank-32 adapters (alpha 64) on all attention and MLP
# projection matrices of the Qwen2.5 architecture.
peft_config = LoraConfig(
    r=32,
    lora_alpha=64,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
)

# SFTTrainer accepts a model id string and loads the base model itself.
print("Initializing trainer...")
trainer = SFTTrainer(
    model="Qwen/Qwen2.5-Coder-3B",
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    args=config,
    peft_config=peft_config,
)

print("Starting training...")
trainer.train()
# Push the final adapter/checkpoint to the Hub (in addition to every_save).
trainer.push_to_hub()

print("\n" + "=" * 50)
print("COMPLETE!")
print("https://huggingface.co/fawazo/qwen2.5-coder-3b-pentest")
print("=" * 50)
|