# /// script
# dependencies = [
# "trl>=0.12.0",
# "peft>=0.7.0",
# "transformers>=4.36.0",
# "accelerate>=0.24.0",
# "trackio",
# "datasets>=2.14.0",
# "bitsandbytes",
# ]
# ///
from datasets import load_dataset, concatenate_datasets
# System prompt placed as the first ("system") message of every converted
# training conversation. It tells the model to answer with exactly one JSON
# object in one of four action shapes: report, request, command or complete.
PENTEST_SYSTEM_PROMPT = """You are an expert penetration testing AI. Analyze web traffic and respond with JSON only.
Formats:
1. {"action": "report", "vulnerability": {"type": "SQLi|XSS|SSRF|IDOR|RCE", "severity": "critical|high|medium|low", "description": "...", "evidence": "..."}}
2. {"action": "request", "method": "GET|POST", "path": "/...", "reasoning": "..."}
3. {"action": "command", "cmd": "...", "reasoning": "..."}
4. {"action": "complete", "summary": "..."}
Respond with ONLY valid JSON."""
def validate_messages(
    messages,
    *,
    min_content_chars: int = 5,
    allowed_roles=("system", "user", "assistant"),
) -> bool:
    """Return True when *messages* is a usable chat transcript.

    A transcript is accepted only if all of the following hold:

    * it is a ``list`` holding at least two entries;
    * every entry is a ``dict`` with both a ``"role"`` and a ``"content"`` key;
    * every ``"content"`` is a non-empty ``str`` whose stripped length is at
      least *min_content_chars*;
    * every ``"role"`` is contained in *allowed_roles*.

    Args:
        messages: Candidate transcript; any object is tolerated and anything
            that is not a list simply yields ``False``.
        min_content_chars: Minimum number of characters a message must have
            after surrounding whitespace is stripped.
        allowed_roles: Container of accepted role values (anything that
            supports ``in``).

    Returns:
        ``True`` if every check passes, ``False`` otherwise.
    """
    # The isinstance test comes first so that arbitrary objects (including
    # ones whose truth value is ambiguous) are rejected without being
    # evaluated in a boolean context; the length test also covers "empty".
    if not isinstance(messages, list) or len(messages) < 2:
        return False

    for msg in messages:
        if not isinstance(msg, dict) or "role" not in msg or "content" not in msg:
            return False

        content = msg["content"]
        # Falsy content (None, "", 0, ...) and non-string content are both
        # invalid, regardless of the configured minimum length.
        if not content or not isinstance(content, str):
            return False
        if len(content.strip()) < min_content_chars:
            return False

        if msg["role"] not in allowed_roles:
            return False

    return True
def _clean_text(value):
    """Return *value* as stripped text, mapping ``None`` to an empty string.

    Without the explicit ``None`` case, ``str(None)`` would yield the literal
    text ``"None"`` and the row would only be rejected because that text
    happens to be shorter than the validator's minimum length.
    """
    return "" if value is None else str(value).strip()


def _load_chat_dataset(label, repo_id):
    """Load the ``train`` split of *repo_id* as validated chat transcripts.

    Each row's ``"user"`` and ``"assistant"`` fields are wrapped, together
    with ``PENTEST_SYSTEM_PROMPT``, into a three-message ``"messages"`` list.
    Rows whose transcript fails ``validate_messages`` are dropped.

    Args:
        label: Short human-readable name used in the progress output.
        repo_id: Dataset identifier passed to ``load_dataset``.

    Returns:
        The converted dataset with a single ``"messages"`` column.
    """
    print(f"Loading {label}...")
    ds = load_dataset(repo_id, split="train")
    print(f" Raw: {len(ds)}")

    def convert(ex):
        msgs = [
            {"role": "system", "content": PENTEST_SYSTEM_PROMPT},
            {"role": "user", "content": _clean_text(ex["user"])},
            {"role": "assistant", "content": _clean_text(ex["assistant"])},
        ]
        # "valid" is a temporary flag column, removed again after filtering.
        return {"messages": msgs, "valid": validate_messages(msgs)}

    ds = ds.map(convert, remove_columns=ds.column_names)
    ds = ds.filter(lambda x: x["valid"]).remove_columns(["valid"])
    print(f" Valid: {len(ds)}")
    return ds


def load_trendyol():
    """Load and convert the Trendyol cybersecurity instruction dataset."""
    return _load_chat_dataset(
        "Trendyol",
        "Trendyol/Trendyol-Cybersecurity-Instruction-Tuning-Dataset",
    )


def load_fenrir():
    """Load and convert the Fenrir v2.0 cybersecurity dataset."""
    return _load_chat_dataset(
        "Fenrir",
        "AlicanKiraz0/Cybersecurity-Dataset-Fenrir-v2.0",
    )
# ---------------------------------------------------------------------------
# Dataset assembly: load every source best-effort, merge, shuffle, and carve
# off a small evaluation split.
# ---------------------------------------------------------------------------
print("=" * 50)
print("LOADING DATASETS")
print("=" * 50)

all_ds = []
for source_name, loader in (("Trendyol", load_trendyol), ("Fenrir", load_fenrir)):
    # Best-effort on purpose: one unavailable source should not prevent
    # training on the other, so the failure is reported and skipped.
    try:
        all_ds.append(loader())
    except Exception as e:
        print(f"{source_name} error: {e}")

# With no source loaded there is nothing to concatenate or train on; stop
# here with an explicit message instead of failing further down.
if not all_ds:
    raise SystemExit("No datasets could be loaded; nothing to train on.")

combined = concatenate_datasets(all_ds).shuffle(seed=42)
print(f"\nTotal: {len(combined)}")

# Hold out 2% of the shuffled data for evaluation (fixed seed).
split = combined.train_test_split(test_size=0.02, seed=42)
train_ds, eval_ds = split["train"], split["test"]
print(f"Train: {len(train_ds)}, Eval: {len(eval_ds)}")
# Announce the training phase: blank line, rule, title, rule.
print("\n" + "=" * 50, "TRAINING WITH MEMORY OPTIMIZATION", "=" * 50, sep="\n")

# Training libraries are imported only once the data has been prepared.
from peft import LoraConfig
from trl import SFTConfig, SFTTrainer

config = SFTConfig(
    # --- Output location and Hub publishing ---
    output_dir="qwen2.5-coder-3b-pentest",
    hub_model_id="fawazo/qwen2.5-coder-3b-pentest",
    push_to_hub=True,
    hub_strategy="every_save",
    # --- Optimisation schedule ---
    num_train_epochs=2,
    learning_rate=1e-4,
    lr_scheduler_type="cosine",
    warmup_ratio=0.03,
    # --- MEMORY FIX: batch of 1 with 16 accumulation steps ---
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
    # --- MEMORY FIX: shorter sequences ---
    max_length=1024,
    # --- MEMORY FIX: checkpointing, bf16 and 8-bit AdamW all enabled ---
    gradient_checkpointing=True,
    bf16=True,
    optim="adamw_8bit",
    # NOTE(review): per_device_eval_batch_size is left at the library
    # default; confirm evaluation fits in memory on the target GPU.
    # --- Logging, checkpointing and evaluation cadence ---
    logging_steps=25,
    save_strategy="steps",
    save_steps=1000,
    save_total_limit=2,
    eval_strategy="steps",
    eval_steps=1000,
    # --- Experiment tracking ---
    report_to="trackio",
    project="pentest-agent",
    run_name="qwen-3b-cybersec-v4-memfix",
)
# LoRA adapter settings, kept small for memory: rank 16 applied to the
# q/k/v/o projection modules only, no bias terms trained.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
)
print("Loading model...")
# The base model is given as an identifier string rather than a loaded model
# object; the trainer receives it together with the LoRA and training configs.
trainer = SFTTrainer(
    model="Qwen/Qwen2.5-Coder-3B",
    args=config,
    peft_config=peft_config,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
)
print("Training...")
trainer.train()

# Final upload once training has finished.
trainer.push_to_hub()

# Completion banner: blank line, rule, status, model URL, rule.
print(
    "\n" + "=" * 50,
    "COMPLETE!",
    "https://huggingface.co/fawazo/qwen2.5-coder-3b-pentest",
    "=" * 50,
    sep="\n",
)