# /// script
# dependencies = [
#     "trl>=0.12.0",
#     "peft>=0.7.0",
#     "transformers>=4.36.0",
#     "accelerate>=0.24.0",
#     "trackio",
#     "datasets>=2.14.0",
# ]
# ///

import traceback

from datasets import load_dataset, concatenate_datasets

PENTEST_SYSTEM_PROMPT = """You are an expert penetration testing AI. Analyze web traffic and respond with JSON only.

Formats:
1. {"action": "report", "vulnerability": {"type": "SQLi|XSS|SSRF|IDOR|RCE", "severity": "critical|high|medium|low", "description": "...", "evidence": "..."}}
2. {"action": "request", "method": "GET|POST", "path": "/...", "reasoning": "..."}
3. {"action": "command", "cmd": "...", "reasoning": "..."}
4. {"action": "complete", "summary": "..."}

Respond with ONLY valid JSON."""


def validate_messages(messages):
    """Check that a conversation is well-formed for the chat template."""
    if not messages or not isinstance(messages, list):
        return False
    if len(messages) < 2:
        return False
    for msg in messages:
        if not isinstance(msg, dict):
            return False
        if "role" not in msg or "content" not in msg:
            return False
        if not msg["content"] or not isinstance(msg["content"], str):
            return False
        if msg["role"] not in ["system", "user", "assistant"]:
            return False
        if len(msg["content"].strip()) < 5:
            return False
    return True


def load_trendyol():
    print("Loading Trendyol...")
    ds = load_dataset("Trendyol/Trendyol-Cybersecurity-Instruction-Tuning-Dataset", split="train")
    print(f"  Raw: {len(ds)}")

    def convert(ex):
        msgs = [
            {"role": "system", "content": PENTEST_SYSTEM_PROMPT},
            {"role": "user", "content": str(ex["user"]).strip()},
            {"role": "assistant", "content": str(ex["assistant"]).strip()},
        ]
        return {"messages": msgs, "valid": validate_messages(msgs)}

    ds = ds.map(convert, remove_columns=ds.column_names)
    ds = ds.filter(lambda x: x["valid"])
    ds = ds.remove_columns(["valid"])
    print(f"  Valid: {len(ds)}")
    return ds


def load_fenrir():
    print("Loading Fenrir...")
    ds = load_dataset("AlicanKiraz0/Cybersecurity-Dataset-Fenrir-v2.0", split="train")
    print(f"  Raw: {len(ds)}")

    def convert(ex):
        msgs = [
            {"role": "system", "content": PENTEST_SYSTEM_PROMPT},
            {"role": "user", "content": str(ex["user"]).strip()},
            {"role": "assistant", "content": str(ex["assistant"]).strip()},
        ]
        return {"messages": msgs, "valid": validate_messages(msgs)}

    ds = ds.map(convert, remove_columns=ds.column_names)
    ds = ds.filter(lambda x: x["valid"])
    ds = ds.remove_columns(["valid"])
    print(f"  Valid: {len(ds)}")
    return ds


def load_pentest():
    print("Loading pentest-agent...")
    try:
        ds = load_dataset("jason-oneal/pentest-agent-dataset", data_files="chatml_train.jsonl", split="train")
        print(f"  Raw: {len(ds)}")

        def fix_messages(ex):
            msgs = ex.get("messages", [])
            if not msgs:
                return {"messages": [], "valid": False}
            # Replace any existing system prompt with ours and keep only
            # well-formed user/assistant turns.
            new_msgs = []
            has_system = False
            for m in msgs:
                if isinstance(m, dict) and "role" in m and "content" in m:
                    role = str(m["role"]).strip().lower()
                    content = str(m["content"]).strip() if m["content"] else ""
                    if role == "system":
                        has_system = True
                        new_msgs.append({"role": "system", "content": PENTEST_SYSTEM_PROMPT})
                    elif role in ["user", "assistant"] and content:
                        new_msgs.append({"role": role, "content": content})
            if not has_system:
                new_msgs.insert(0, {"role": "system", "content": PENTEST_SYSTEM_PROMPT})
            return {"messages": new_msgs, "valid": validate_messages(new_msgs)}

        ds = ds.map(fix_messages, remove_columns=ds.column_names)
        ds = ds.filter(lambda x: x["valid"])
        ds = ds.remove_columns(["valid"])
        print(f"  Valid: {len(ds)}")
        return ds
    except Exception as e:
        print(f"  Error: {e}")
        traceback.print_exc()
        return None
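
# Quick self-check of the validator on a hand-written conversation before
# touching the real data. The strings below are illustrative placeholders,
# not rows from any of the datasets above.
_example = [
    {"role": "system", "content": PENTEST_SYSTEM_PROMPT},
    {"role": "user", "content": "GET /search?q=' OR '1'='1 returned a database error page."},
    {"role": "assistant", "content": '{"action": "report", "vulnerability": {"type": "SQLi", "severity": "high", "description": "...", "evidence": "..."}}'},
]
assert validate_messages(_example), "validator rejected a well-formed conversation"
assert not validate_messages([{"role": "user", "content": "hi"}]), "validator accepted a single-turn conversation"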
print("=" * 50) print("LOADING AND VALIDATING DATASETS") print("=" * 50) all_ds = [] try: all_ds.append(load_trendyol()) except Exception as e: print(f"Trendyol error: {e}") try: all_ds.append(load_fenrir()) except Exception as e: print(f"Fenrir error: {e}") try: pds = load_pentest() if pds and len(pds) > 0: all_ds.append(pds) except Exception as e: print(f"Pentest error: {e}") print(f"\nCombining {len(all_ds)} datasets...") combined = concatenate_datasets(all_ds) combined = combined.shuffle(seed=42) print(f"Total valid examples: {len(combined)}") # Split split = combined.train_test_split(test_size=0.02, seed=42) train_ds = split["train"] eval_ds = split["test"] print(f"Train: {len(train_ds)}, Eval: {len(eval_ds)}") # Verify a sample print("\nSample message structure:") sample = train_ds[0]["messages"] for m in sample: print(f" {m['role']}: {m['content'][:50]}...") # Training print("\n" + "=" * 50) print("TRAINING") print("=" * 50) from peft import LoraConfig from trl import SFTTrainer, SFTConfig config = SFTConfig( output_dir="qwen2.5-coder-3b-pentest", push_to_hub=True, hub_model_id="fawazo/qwen2.5-coder-3b-pentest", hub_strategy="every_save", num_train_epochs=2, per_device_train_batch_size=2, gradient_accumulation_steps=8, learning_rate=1e-4, max_length=2048, gradient_checkpointing=True, bf16=True, logging_steps=25, save_strategy="steps", save_steps=500, save_total_limit=2, eval_strategy="steps", eval_steps=500, warmup_ratio=0.03, lr_scheduler_type="cosine", report_to="trackio", project="pentest-agent", run_name="qwen-3b-cybersec-v3", ) peft_config = LoraConfig( r=32, lora_alpha=64, lora_dropout=0.05, bias="none", task_type="CAUSAL_LM", target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], ) print("Initializing trainer...") trainer = SFTTrainer( model="Qwen/Qwen2.5-Coder-3B", train_dataset=train_ds, eval_dataset=eval_ds, args=config, peft_config=peft_config, ) print("Starting training...") trainer.train() trainer.push_to_hub() print("\n" + "=" * 50) print("COMPLETE!") print("https://huggingface.co/fawazo/qwen2.5-coder-3b-pentest") print("=" * 50)