100XZX001 committed on
Commit
530fe32
·
verified ·
1 Parent(s): 91038d2

Update training.py

Browse files
Files changed (1) hide show
  1. training.py +763 -745
training.py CHANGED
@@ -1,793 +1,811 @@
1
- # training.py
2
- import torch._dynamo
3
- torch._dynamo.config.disable = True
4
- import json
5
- import os
 
 
 
 
 
 
 
 
6
  import torch
7
  import torch.nn.functional as F
8
  from torch.optim import AdamW
9
- from dataclasses import dataclass
10
- from typing import List, Dict, Tuple, Optional
 
11
  import numpy as np
12
- import re
13
- import random
14
- import matplotlib.pyplot as plt
15
 
 
16
  from unsloth import FastLanguageModel
17
- from transformers import TrainingArguments
18
- from trl import SFTTrainer
19
- from datasets import Dataset
20
 
21
- # Import your environment and actions (unchanged)
22
  from environment import CodeReviewEnv
23
  from redteam import BUG_DB
24
- from models import (
25
- RunTests, RunLinter, Inspect,
26
- ProposeFix, WriteComment, AskQuestion,
27
- Done, Skip , QueryDocs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  )
29
 
30
- # ======================================================================
31
- # 1. ACTION PARSING (improved with fallback)
32
- # ======================================================================
 
 
 
33
@dataclass
class AgentAction:
    """One decision emitted by the policy model.

    ``action_type`` is the lower-cased action name; ``content`` is the optional
    payload (fix code, comment text, question, or documentation query).
    """

    action_type: str
    content: Optional[str] = None


def _action_from_mapping(data, require_key: bool = False) -> Optional[AgentAction]:
    """Build an action from decoded JSON, or return None if it is unusable."""
    if not isinstance(data, dict):
        return None
    if require_key and "action_type" not in data:
        return None
    action_type = data.get("action_type", "")
    if not isinstance(action_type, str):
        return None
    return AgentAction(action_type=action_type.lower(), content=data.get("content"))


def _action_from_json(text: str) -> Optional[AgentAction]:
    """Decode *text* as a complete JSON document and convert it to an action."""
    try:
        data = json.loads(text)
    except (TypeError, ValueError):  # json.JSONDecodeError subclasses ValueError
        return None
    return _action_from_mapping(data)


def _first_embedded_action(text: str) -> Optional[AgentAction]:
    """Find the first JSON object inside free text that carries ``action_type``."""
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            data, _ = decoder.raw_decode(text, start)
        except ValueError:
            data = None
        action = _action_from_mapping(data, require_key=True)
        if action is not None:
            return action
        start = text.find("{", start + 1)
    return None


def parse_action(output: str) -> AgentAction:
    """Turn raw model text into an ``AgentAction``.

    Strategies are tried from strictest to loosest:

    1. the whole text is a JSON object;
    2. a JSON object inside a Markdown code fence;
    3. a JSON object embedded in surrounding prose (keeps ``content``, which
       the plain regex fallback below would drop);
    4. a bare ``"action_type": "..."`` fragment;
    5. keyword detection (``test``, ``lint``, ``inspect``, ``doc``).

    Args:
        output: Text generated by the model.

    Returns:
        The parsed action, or ``AgentAction("invalid", output)`` when nothing
        matched.
    """
    action = _action_from_json(output)
    if action is not None:
        return action

    fenced = re.search(r'```(?:json)?\s*(\{.*?\})\s*```', output, re.DOTALL)
    if fenced:
        action = _action_from_json(fenced.group(1))
        if action is not None:
            return action

    action = _first_embedded_action(output)
    if action is not None:
        return action

    field_match = re.search(r'"action_type"\s*:\s*"(\w+)"', output)
    if field_match:
        return AgentAction(action_type=field_match.group(1).lower())

    # Keyword detection as a last resort; order matters ("test" wins first).
    lowered = output.lower()
    if "test" in lowered:
        return AgentAction("run_tests")
    if "lint" in lowered:
        return AgentAction("run_linter")
    if "inspect" in lowered:
        return AgentAction("inspect")
    if "doc" in lowered:  # also covers "documentation"
        # Bridge natural-language mentions to the retrieval action.
        return AgentAction("query_docs", "bug fix guidance")

    return AgentAction("invalid", output)
81
 
82
def map_to_env(action: AgentAction):
    """Convert a parsed ``AgentAction`` into the environment's action object.

    Payload-carrying actions receive ``action.content`` (or an empty string when
    it is missing). Any action type without an entry below maps to ``Skip()``.
    """
    payload = action.content or ""
    # Each entry is a zero-argument factory so only the selected action is built.
    factories = {
        "run_tests": RunTests,
        "run_linter": RunLinter,
        "inspect": Inspect,
        "fix": lambda: ProposeFix(fix_code=payload),
        "comment": lambda: WriteComment(comment_text=payload),
        "question": lambda: AskQuestion(question=payload),
        "query_docs": lambda: QueryDocs(query_topic=payload),
        "done": Done,
    }
    make_action = factories.get(action.action_type, Skip)
    return make_action()
101
-
102
- # ======================================================================
103
- # 2. MODEL SETUP (stabilised LoRA)
104
- # ======================================================================
105
def load_model():
    """Load the 4-bit Gemma-2 2B instruct checkpoint and attach LoRA adapters.

    Returns:
        ``(model, tokenizer)`` where ``model`` is the PEFT-wrapped network and
        ``tokenizer`` is guaranteed to have a chat template.
    """
    base_model, tokenizer = FastLanguageModel.from_pretrained(
        model_name="unsloth/gemma-2-2b-it-bnb-4bit",
        max_seq_length=768,
        load_in_4bit=True,
    )

    # Attention and MLP projections that receive LoRA adapters.
    lora_targets = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ]
    # Rank 16 / alpha 32 / no dropout: the author reports that rank 64 led to
    # collapse and that dropout can produce empty outputs.
    model = FastLanguageModel.get_peft_model(
        base_model,
        r=16,
        target_modules=lora_targets,
        lora_alpha=32,
        lora_dropout=0.0,
    )

    # Install a Gemma-2 style chat template only when the tokenizer has none.
    if tokenizer.chat_template is None:
        gemma_template = "{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}<start_of_turn>user\n{{ message['content'] }}<end_of_turn>\n<start_of_turn>model\n{% elif message['role'] == 'assistant' %}{{ message['content'] }}<end_of_turn>\n{% endif %}{% endfor %}"
        tokenizer.chat_template = gemma_template
    return model, tokenizer
126
 
127
- # ======================================================================
128
- # 3. MODEL SANITY CHECK (new – ensures model can generate text)
129
- # ======================================================================
130
def test_model_sanity(model, tokenizer) -> bool:
    """Check that the model can generate non-empty text before training starts.

    A short greeting is sent through the chat template and up to 30 tokens are
    sampled on the CUDA device.

    Args:
        model: Causal language model exposing ``generate``.
        tokenizer: Matching tokenizer with a chat template.

    Returns:
        ``True`` when the decoded response is non-empty, ``False`` otherwise.
    """
    print("\n" + "="*60)
    print("SANITY CHECK: Testing base model generation")
    print("="*60)

    test_prompt = "Hello, how are you?"
    messages = [{"role": "user", "content": test_prompt}]
    formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(formatted, return_tensors="pt", max_length=768, truncation=True).to("cuda")

    # Compare against None explicitly: a pad id of 0 is valid but falsy, so
    # ``pad_token_id or eos_token_id`` would silently replace it with EOS.
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=30,
            do_sample=True,
            temperature=0.7,
            min_new_tokens=1,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=pad_id,
        )

    # Drop the prompt tokens so only the continuation is decoded.
    generated_ids = outputs[0][inputs['input_ids'].shape[1]:]
    response = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
    print(f"Prompt: {test_prompt}")
    print(f"Response: {repr(response)}")
    if not response:
        print("❌ Model produces empty output – cannot train.")
        return False
    print("✓ Model sanity check PASSED\n")
    return True
157
-
158
- # ======================================================================
159
- # 4. SUPERVISED WARM-UP (teaches JSON output)
160
- # ======================================================================
161
def supervised_warmup(model, tokenizer, n_examples=500, epochs=8):
    """Fine-tune the model briefly so it answers with a JSON action.

    Synthetic (prompt, action) pairs are built and trained with ``SFTTrainer``.
    Prompts come from ``build_prompt`` so the warm-up sees exactly the prompt
    used during rollouts. Target actions are serialised with ``json.dumps`` so
    every target is valid JSON: a hand-written template containing a raw
    newline inside a string is rejected by ``json.loads``.

    Args:
        model: PEFT-wrapped language model to train in place.
        tokenizer: Tokenizer with a chat template.
        n_examples: Number of synthetic examples to generate.
        epochs: Number of passes over the synthetic dataset.
    """
    from types import SimpleNamespace  # local import keeps the block self-contained

    print("\n" + "="*60)
    print("SUPERVISED WARM-UP: Teaching JSON format")
    print("="*60)

    action_payloads = [
        {"action_type": "run_tests"},
        {"action_type": "run_linter"},
        {"action_type": "inspect"},
        {"action_type": "fix", "content": "def corrected():\n pass"},
        {"action_type": "comment", "content": "This looks good."},
        {"action_type": "question", "content": "Why is this variable used?"},
        {"action_type": "query_docs", "content": "KeyError"},
        {"action_type": "done"},
    ]
    action_templates = [json.dumps(payload) for payload in action_payloads]
    last_outputs = [
        "Tests passed: 2/3",
        "Linter found 1 error",
        "Inspection complete",
        "No previous action",
    ]

    examples = []
    for i in range(n_examples):
        code = f"def example_{i}():\n return {i % 10}"
        last_output = random.choice(last_outputs)
        # Minimal stand-in observation: build_prompt reads these attributes.
        obs = SimpleNamespace(
            code_snippet=code,
            author_response="",
            last_tool_output=last_output,
        )
        prompt = build_prompt(obs, [])

        # The target action is random; the goal is the output format only.
        action_json = random.choice(action_templates)
        messages = [
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": action_json},
        ]
        full_text = tokenizer.apply_chat_template(messages, tokenize=False)
        examples.append({"text": full_text})

    dataset = Dataset.from_list(examples)
    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=dataset,
        dataset_text_field="text",
        max_seq_length=512,
        args=TrainingArguments(
            output_dir="warmup_output",
            num_train_epochs=epochs,
            per_device_train_batch_size=4,
            gradient_accumulation_steps=2,
            learning_rate=2e-5,
            logging_steps=50,
            save_strategy="no",
            fp16=True,
        ),
    )
    print(f"Training on {n_examples} examples for {epochs} epochs...")
    trainer.train()
    print("✓ Warm-up complete\n")
    torch.cuda.empty_cache()
248
- # ======================================================================
249
- # 5. ACTION GENERATION WITH LOGPROB TRACKING (fixed)
250
- # ======================================================================
251
def generate_action_with_logprob(
    prompt: str,
    model,
    tokenizer,
    temperature: float = 0.0,
    max_retries: int = 2
) -> Tuple[str, float]:
    """Generate one action string and the log-probability of its tokens.

    The prompt is wrapped in the chat template and decoded greedily
    (``temperature == 0``) or by sampling. Output that is not valid JSON is
    retried; when every attempt fails a ``skip`` action is returned.

    Args:
        prompt: User-turn prompt text.
        model: Causal language model exposing ``generate``.
        tokenizer: Matching tokenizer with a chat template.
        temperature: Sampling temperature; ``0`` selects greedy decoding.
        max_retries: Maximum number of generation attempts.

    Returns:
        ``(action_text, total_logprob)``. Placeholder log-probabilities are
        ``-50.0`` for the empty-output fallback and ``-100.0`` for the final
        ``skip`` action.
    """
    messages = [{"role": "user", "content": prompt}]
    formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(formatted, return_tensors="pt").to("cuda")
    prompt_len = inputs['input_ids'].shape[1]

    sampling = temperature > 0
    # Greedy decoding is deterministic, so repeating it would only reproduce
    # the same invalid text; a single attempt is enough in that mode.
    attempts = max_retries if sampling else min(max_retries, 1)

    for _ in range(attempts):
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=128,
                do_sample=sampling,
                temperature=max(temperature, 0.01) if sampling else 1.0,
                min_new_tokens=1,
                return_dict_in_generate=True,
                output_scores=True,
            )

        generated_ids = outputs.sequences[0][prompt_len:]
        action_text = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

        if not action_text:
            # Nothing decodable was produced: substitute a harmless action.
            fallback_actions = [
                '{"action_type": "run_tests"}',
                '{"action_type": "run_linter"}',
                '{"action_type": "inspect"}',
                '{"action_type": "skip"}',
            ]
            action_text = random.choice(fallback_actions)
            print(f"[WARN] Empty generation → using fallback: {action_text}")
            return action_text, -50.0

        try:
            json.loads(action_text)
        except ValueError:  # json.JSONDecodeError subclasses ValueError
            continue

        # Sum the log-probability of each generated token under the score
        # distribution returned for its decoding step.
        total_logprob = 0.0
        scored_tokens = 0
        for step_scores, token_id in zip(outputs.scores, generated_ids):
            total_logprob += F.log_softmax(step_scores[0], dim=-1)[token_id].item()
            scored_tokens += 1
        return action_text, (total_logprob if scored_tokens else -100.0)

    return '{"action_type":"skip"}', -100.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
310
 
311
- # ======================================================================
312
- # 6. PROMPT BUILDER (unchanged – exactly as you wrote)
313
- # ======================================================================
314
def build_prompt(obs, history_lines: List[str]) -> str:
    """Render the instruction prompt for one environment observation.

    Args:
        obs: Observation object. ``code_snippet`` is required;
            ``author_response``, ``last_tool_output`` and ``author_personality``
            are optional and read with ``getattr``.
        history_lines: Transcript lines from earlier steps; only the last six
            are appended.

    Returns:
        The full prompt text ending with the JSON response instruction and, when
        history exists, a "Previous steps" section.
    """
    author_msg = getattr(obs, "author_response", "") or ""
    tool_output = getattr(obs, "last_tool_output", "") or ""

    # Personality hint; falls back to "defensive" when the observation has none.
    author_personality = getattr(obs, "author_personality", "defensive")

    author_text = author_msg if author_msg else "(no response yet – start with inspection)"
    tool_text = tool_output if tool_output else "(none)"

    prompt = f"""You are an AI code review agent. Your goal is to convince a simulated human developer to accept your proposed fix and name your proposed fix function fix.

The developer has a **{author_personality}** personality and will only accept if you provide solid evidence:
- Tests pass (high pass ratio)
- Lint is clean (zero errors)
- Documentation or references are provided
- Your reasoning is clear, uses words like "because" or "therefore", and is detailed (over 30 words if needed)

Workflow:
1. Use `inspect` to understand the code.
2. Use `run_tests` and `run_linter` to gather evidence.
3. Use `query_docs` when you need references or language-specific guidance.
4. Propose a fix (`fix`) and explain why it works (`comment` or `question`).
5. If the developer pushes back, read their response carefully and address their specific concern.
6. Once convinced, use `done` to finish.

Code:
{obs.code_snippet}

Author says:
{author_text}

Last tool output:
{tool_text}

Available actions:
run_tests, run_linter, inspect, query_docs, fix, comment, question, done

Respond ONLY in JSON:
{{"action_type": "...", "content": "..."}}"""

    if history_lines:
        # Keep the prompt bounded: only the six most recent transcript lines.
        history = "\n".join(history_lines[-6:])
        prompt += f"\n\nPrevious steps:\n{history}"
    return prompt
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
356
 
357
- # ======================================================================
358
- # 7. TRAJECTORY STORAGE (unchanged)
359
- # ======================================================================
360
@dataclass
class Trajectory:
    """Step-aligned record of one episode.

    The five lists are parallel: index ``t`` holds the prompt, the generated
    action text, the reward, the action log-probability and the done flag of
    step ``t``.
    """

    states: List[str]
    actions: List[str]
    rewards: List[float]
    logprobs: List[float]
    dones: List[bool]

    def __len__(self):
        """Return the number of recorded steps."""
        return len(self.states)

    def to_dict(self):
        """Return the fields as a plain dict that references the same lists."""
        field_names = ("states", "actions", "rewards", "logprobs", "dones")
        return {name: getattr(self, name) for name in field_names}
379
-
380
- # ======================================================================
381
- # 8. ROLLOUT COLLECTION (uses fixed generate)
382
- # ======================================================================
383
def collect_trajectory(
    env: CodeReviewEnv,
    model,
    tokenizer,
    max_steps: int = 10,
    temperature: float = 0.0
) -> Trajectory:
    """Roll out one episode and record it step by step.

    Each step builds a prompt from the current observation, generates an action,
    applies it to the environment and stores the outcome. The loop stops after
    ``max_steps`` steps or as soon as the environment reports ``done``.

    Returns:
        A ``Trajectory`` with one entry per executed step.
    """
    rollout = Trajectory(states=[], actions=[], rewards=[], logprobs=[], dones=[])
    transcript = []  # alternating "Agent:" / "Env:" lines fed back into the prompt

    obs = env.reset()
    steps_taken = 0
    while steps_taken < max_steps:
        steps_taken += 1

        prompt = build_prompt(obs, transcript)
        rollout.states.append(prompt)

        action_text, logprob = generate_action_with_logprob(
            prompt, model, tokenizer, temperature
        )
        rollout.actions.append(action_text)
        rollout.logprobs.append(logprob)

        env_action = map_to_env(parse_action(action_text))
        obs, reward, done, _ = env.step(env_action)

        rollout.rewards.append(reward.value)
        rollout.dones.append(done)

        transcript.append(f"Agent: {action_text}")
        transcript.append(f"Env: {obs.last_tool_output}")

        if done:
            break

    return rollout
424
-
425
def collect_trajectories(
    env: CodeReviewEnv,
    model,
    tokenizer,
    n_trajectories: int,
    max_steps: int = 10,
    task_levels: Optional[List[str]] = None,
    task_weights: Optional[List[float]] = None,
) -> List[Trajectory]:
    """Collect several episodes, sampling a task level for each one.

    Args:
        env: Environment; its task is switched with ``set_task`` per episode.
        model: Policy model.
        tokenizer: Matching tokenizer.
        n_trajectories: Number of episodes to run.
        max_steps: Step cap per episode.
        task_levels: Candidate task names; defaults to every key of ``BUG_DB``.
        task_weights: Optional sampling weights, one per task level.

    Returns:
        The collected trajectories in collection order.

    Raises:
        ValueError: If ``task_weights`` has the wrong length or a non-positive sum.
    """
    # Default to the whole bug database rather than a single difficulty.
    levels = list(BUG_DB.keys()) if task_levels is None else task_levels

    if task_weights is not None:
        if len(task_weights) != len(levels):
            raise ValueError("task_weights must match task_levels length")
        if sum(task_weights) <= 0:
            raise ValueError("task_weights must have a positive total")

    collected = []
    for number in range(1, n_trajectories + 1):
        # Weighted draw supports curriculum-style schedules.
        (sampled_task,) = random.choices(levels, weights=task_weights, k=1)
        env.set_task(sampled_task)
        traj = collect_trajectory(env, model, tokenizer, max_steps)
        total_reward = sum(traj.rewards)
        print(f"Trajectory {number}/{n_trajectories}: "
              f"task={sampled_task}, steps={len(traj)}, reward={total_reward:.3f}")
        collected.append(traj)
    return collected
454
-
455
- # ======================================================================
456
- # 9. ADVANTAGE ESTIMATION (unchanged)
457
- # ======================================================================
458
def compute_returns_and_advantages(
    rewards: List[float],
    dones: List[bool],
    gamma: float = 0.99,
    standardize: bool = True
) -> Tuple[List[float], List[float]]:
    """Compute discounted returns and critic-free advantages.

    Returns are accumulated backwards; a ``True`` entry in ``dones`` resets the
    running return so rewards do not leak across episode boundaries.

    Args:
        rewards: Per-step rewards.
        dones: Per-step episode-termination flags, same length as ``rewards``.
        gamma: Discount factor.
        standardize: When true, advantages are the returns minus their mean,
            divided by their standard deviation (plus ``1e-8``). Otherwise the
            advantages are a copy of the raw returns.

    Returns:
        ``(advantages, returns)`` in that order; both lists have one entry per
        step. Empty input yields two empty lists.

    Raises:
        ValueError: If ``rewards`` and ``dones`` differ in length.
    """
    if len(rewards) != len(dones):
        raise ValueError("rewards and dones must have the same length")

    n = len(rewards)
    if n == 0:
        # Avoid the NaN and RuntimeWarning numpy produces for an empty mean.
        return [], []

    returns = [0.0] * n
    running_return = 0.0
    for t in reversed(range(n)):
        if dones[t]:
            running_return = 0.0
        running_return = rewards[t] + gamma * running_return
        returns[t] = running_return

    if standardize:
        centered = np.array(returns) - np.mean(returns)
        # The epsilon keeps constant returns from dividing by zero.
        advantages = (centered / (np.std(centered) + 1e-8)).tolist()
    else:
        advantages = returns.copy()

    return advantages, returns
485
- # ======================================================================
486
- # 10. COMPUTE NEW LOGPROBS (unchanged)
487
- # ======================================================================
488
def compute_logprob(prompt: str, action: str, model, tokenizer) -> float:
    """Return the summed log-probability of ``action`` given ``prompt``.

    The prompt is tokenised with the same call the rollout uses, and the action
    tokens are appended as ids. Alignment is therefore exact by construction:
    action token ``i`` sits at input position ``prompt_len + i`` and is scored
    by the logits at ``prompt_len + i - 1``. Measuring the prompt length with a
    different special-token setting than the scored sequence, as before, can
    shift every position by one when the tokenizer prepends a BOS token.

    Args:
        prompt: User-turn prompt text.
        action: Action text whose likelihood is evaluated.
        model: Causal language model returning ``.logits``.
        tokenizer: Matching tokenizer with a chat template.

    Returns:
        The total log-probability, or ``-100.0`` when there is nothing to score.
    """
    action_ids = tokenizer.encode(action, add_special_tokens=False)
    if not action_ids:
        return -100.0

    messages = [{"role": "user", "content": prompt}]
    formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    prompt_ids = tokenizer(formatted, return_tensors="pt")["input_ids"]
    action_start = prompt_ids.shape[1]
    if action_start == 0:
        return -100.0

    action_tensor = torch.tensor([action_ids], dtype=prompt_ids.dtype)
    input_ids = torch.cat([prompt_ids, action_tensor], dim=1).to("cuda")

    with torch.no_grad():
        logits = model(input_ids=input_ids).logits

    # Logits at position p predict token p + 1, hence the one-step shift.
    step_log_probs = F.log_softmax(logits[0, action_start - 1:-1].float(), dim=-1)
    targets = input_ids[0, action_start:].unsqueeze(-1)
    return step_log_probs.gather(-1, targets).sum().item()
510
-
511
- # ======================================================================
512
- # 11. PPO UPDATE (unchanged except uses compute_logprob correctly)
513
- # ======================================================================
514
def ppo_update(
    trajectories: List[Trajectory],
    model,
    tokenizer,
    optimizer,
    n_epochs: int = 4,
    clip_epsilon: float = 0.2,
    entropy_coef: float = 0.01,
    gamma: float = 0.99,
    log_ratio_clamp: float = 5.0,
) -> Dict[str, float]:
    """Run clipped-surrogate PPO updates over the collected trajectories.

    Every (state, action) pair is re-scored under the current policy and one
    optimizer step is taken per sample, as before. Three defects are fixed:

    * Prompt and action ids are concatenated explicitly, so action positions are
      exact. Measuring the prefix without special tokens while tokenising the
      full text with them can shift every position by one.
    * Over-long inputs are shortened from the left of the prompt. Tail
      truncation could cut off the very action tokens being scored.
    * The log-ratio is clamped before ``exp``. Placeholder log-probabilities
      from the rollout (for example ``-100``) would otherwise overflow the ratio.

    Args:
        trajectories: Rollouts to learn from.
        model: Policy model; switched to train mode.
        tokenizer: Matching tokenizer with a chat template.
        optimizer: Optimizer over the model parameters.
        n_epochs: Passes over the pooled samples.
        clip_epsilon: PPO clipping range.
        entropy_coef: Weight of the entropy bonus.
        gamma: Discount factor for the returns.
        log_ratio_clamp: Symmetric bound applied to the log-ratio.

    Returns:
        Mean ``loss``, ``policy_loss`` and ``entropy`` over all updates (zeros
        when no update happened).
    """
    max_input_len = 768  # same cap the original tokenizer call used
    model.train()

    # Pool every step of every trajectory; advantages are standardised per trajectory.
    samples = []
    for traj in trajectories:
        advantages, _ = compute_returns_and_advantages(
            traj.rewards, traj.dones, gamma=gamma, standardize=True
        )
        samples.extend(zip(traj.states, traj.actions, traj.logprobs, advantages))

    total_loss = 0.0
    total_policy_loss = 0.0
    total_entropy = 0.0
    n_updates = 0

    for _ in range(n_epochs):
        for i in np.random.permutation(len(samples)):
            state, action, old_logprob, advantage = samples[i]

            action_ids = tokenizer.encode(action, add_special_tokens=False)
            if not action_ids:
                continue

            # Same prompt formatting and tokenisation as the rollout.
            messages = [{"role": "user", "content": state}]
            formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            prompt_ids = tokenizer(formatted, return_tensors="pt")["input_ids"]

            prompt_budget = max_input_len - len(action_ids)
            if prompt_budget < 1 or prompt_ids.shape[1] == 0:
                continue
            if prompt_ids.shape[1] > prompt_budget:
                # Drop the oldest prompt tokens so the action stays intact.
                prompt_ids = prompt_ids[:, -prompt_budget:]

            action_start = prompt_ids.shape[1]
            action_tensor = torch.tensor([action_ids], dtype=prompt_ids.dtype)
            input_ids = torch.cat([prompt_ids, action_tensor], dim=1).to("cuda")

            logits = model(input_ids=input_ids).logits

            # Logits at position p predict token p + 1.
            log_probs = F.log_softmax(logits[0, action_start - 1:-1].float(), dim=-1)
            targets = input_ids[0, action_start:].unsqueeze(-1)
            new_logprob = log_probs.gather(-1, targets).sum()
            avg_entropy = -(log_probs.exp() * log_probs).sum(dim=-1).mean()

            log_ratio = torch.clamp(
                new_logprob - old_logprob, -log_ratio_clamp, log_ratio_clamp
            )
            ratio = torch.exp(log_ratio)
            surr1 = ratio * advantage
            surr2 = torch.clamp(ratio, 1 - clip_epsilon, 1 + clip_epsilon) * advantage
            policy_loss = -torch.min(surr1, surr2)
            loss = policy_loss - entropy_coef * avg_entropy

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            total_loss += loss.item()
            total_policy_loss += policy_loss.item()
            total_entropy += avg_entropy.item()
            n_updates += 1

    return {
        "loss": total_loss / n_updates if n_updates > 0 else 0.0,
        "policy_loss": total_policy_loss / n_updates if n_updates > 0 else 0.0,
        "entropy": total_entropy / n_updates if n_updates > 0 else 0.0,
    }
609
- # ======================================================================
610
- # 12. EVALUATION (unchanged)
611
- # ======================================================================
612
def evaluate_policy(
    env: CodeReviewEnv,
    model,
    tokenizer,
    n_episodes: int = 10,
    max_steps: int = 10
) -> Dict[str, float]:
    """Run greedy evaluation episodes and summarise them.

    The model is switched to eval mode and left there. An episode counts as a
    success when its total reward exceeds ``0.5``.

    Args:
        env: Environment to roll out in; its current task is used as-is.
        model: Policy model.
        tokenizer: Matching tokenizer.
        n_episodes: Number of episodes; must be at least 1.
        max_steps: Step cap per episode.

    Returns:
        ``avg_reward``, ``std_reward``, ``avg_length`` and ``success_rate``.

    Raises:
        ValueError: If ``n_episodes`` is below 1. Without this check the
            statistics would be NaN and the success rate would divide by zero.
    """
    if n_episodes < 1:
        raise ValueError("n_episodes must be at least 1")

    model.eval()
    total_rewards = []
    episode_lengths = []

    for _ in range(n_episodes):
        traj = collect_trajectory(env, model, tokenizer, max_steps, temperature=0.0)
        total_rewards.append(sum(traj.rewards))
        episode_lengths.append(len(traj))

    success_count = sum(1 for episode_reward in total_rewards if episode_reward > 0.5)

    return {
        "avg_reward": np.mean(total_rewards),
        "std_reward": np.std(total_rewards),
        "avg_length": np.mean(episode_lengths),
        "success_rate": success_count / n_episodes,
    }
638
-
639
- # ======================================================================
640
- # 13. MAIN TRAINING LOOP (added sanity check and warm-up)
641
- # ======================================================================
642
def _curriculum_task_weights(iteration: int, n_iterations: int, task_levels: List[str]) -> List[float]:
    """Return per-level sampling weights that shift toward harder tasks over time."""
    progress = (iteration + 1) / max(n_iterations, 1)
    weight_by_level = {
        "easy": max(0.15, 0.55 - 0.40 * progress),
        "medium": max(0.15, 0.25 - 0.10 * progress),
        "hard": 0.10 + 0.05 * progress,
        "harder": 0.05 + 0.20 * progress,
        "hardest": 0.05 + 0.25 * progress,
    }
    # Levels without a schedule entry get a neutral weight of 1.0.
    return [weight_by_level.get(level, 1.0) for level in task_levels]


def _save_curve(values: List[float], title: str, ylabel: str, path: str, **plot_kwargs) -> None:
    """Plot one per-iteration series and write it to *path* as a PNG."""
    plt.figure(figsize=(8, 4))
    plt.plot(range(1, len(values) + 1), values, marker="o", **plot_kwargs)
    plt.title(title)
    plt.xlabel("Iteration")
    plt.ylabel(ylabel)
    plt.grid(alpha=0.3)
    plt.tight_layout()
    plt.savefig(path, dpi=150)
    plt.close()


def train_ppo(
    n_iterations: int = 50,
    trajectories_per_iter: int = 10,
    n_epochs: int = 2,
    max_steps: int = 10,
    learning_rate: float = 3e-5,
    clip_epsilon: float = 0.2,
    entropy_coef: float = 0.01,
    gamma: float = 0.99,
    eval_every: int = 5,
    task_levels: Optional[List[str]] = None,
    curriculum_weighted_sampling: bool = True,
    reward_profile: str = "full",
):
    """Run the full pipeline: load, sanity-check, warm up, PPO-train, save, plot.

    Args:
        n_iterations: Number of collect/update iterations.
        trajectories_per_iter: Episodes collected per iteration.
        n_epochs: PPO epochs per update.
        max_steps: Step cap per episode.
        learning_rate: AdamW learning rate.
        clip_epsilon: PPO clipping range.
        entropy_coef: Entropy bonus weight.
        gamma: Discount factor.
        eval_every: Evaluate every this many iterations; values below 1 disable
            evaluation (previously ``0`` raised ``ZeroDivisionError``).
        task_levels: Task names to sample from; defaults to all ``BUG_DB`` keys.
        curriculum_weighted_sampling: Use the easy-to-hard weight schedule.
        reward_profile: Only echoed in the banner; not otherwise used here.

    Returns:
        ``None``. Returns early, without training, if the sanity check fails.
    """
    print("Loading model...")
    model, tokenizer = load_model()

    # Refuse to train a model that cannot produce any text.
    if not test_model_sanity(model, tokenizer):
        print("\n❌ Model sanity check failed – cannot proceed.")
        return

    # Supervised warm-up teaches the JSON response format before RL starts.
    supervised_warmup(model, tokenizer, n_examples=500, epochs=8)

    optimizer = AdamW(model.parameters(), lr=learning_rate)
    env = CodeReviewEnv()
    if task_levels is None:
        task_levels = list(BUG_DB.keys())

    print(f"\n{'='*60}")
    print(f"Starting PPO Training")
    print(f"Iterations: {n_iterations}")
    print(f"Trajectories per iteration: {trajectories_per_iter}")
    print(f"PPO epochs: {n_epochs}")
    print(f"Reward profile: {reward_profile}")
    print(f"{'='*60}\n")
    reward_history: List[float] = []
    loss_history: List[float] = []

    for iteration in range(n_iterations):
        print(f"\n--- Iteration {iteration + 1}/{n_iterations} ---")

        # Optional curriculum: start easy and ramp difficulty over training.
        if curriculum_weighted_sampling:
            task_weights = _curriculum_task_weights(iteration, n_iterations, task_levels)
        else:
            task_weights = None

        print("Collecting trajectories...")
        trajectories = collect_trajectories(
            env,
            model,
            tokenizer,
            trajectories_per_iter,
            max_steps,
            task_levels=task_levels,
            task_weights=task_weights,
        )

        avg_reward = np.mean([sum(t.rewards) for t in trajectories])
        avg_length = np.mean([len(t) for t in trajectories])
        reward_history.append(float(avg_reward))

        print(f"Avg reward: {avg_reward:.3f}")
        print(f"Avg length: {avg_length:.1f}")

        print("Updating policy...")
        metrics = ppo_update(
            trajectories,
            model,
            tokenizer,
            optimizer,
            n_epochs=n_epochs,
            clip_epsilon=clip_epsilon,
            entropy_coef=entropy_coef,
            gamma=gamma,
        )

        print(f"Loss: {metrics['loss']:.4f}")
        print(f"Policy loss: {metrics['policy_loss']:.4f}")
        print(f"Entropy: {metrics['entropy']:.4f}")
        loss_history.append(float(metrics["loss"]))

        if eval_every > 0 and (iteration + 1) % eval_every == 0:
            print("\nEvaluating policy...")
            eval_metrics = evaluate_policy(env, model, tokenizer, n_episodes=10)
            print(f"Eval avg reward: {eval_metrics['avg_reward']:.3f} ± {eval_metrics['std_reward']:.3f}")
            print(f"Eval success rate: {eval_metrics['success_rate']:.2%}")
            print(f"Eval avg length: {eval_metrics['avg_length']:.1f}")

    print("\n" + "="*60)
    print("Training complete. Saving model...")
    model.save_pretrained("ppo_final_model")
    tokenizer.save_pretrained("ppo_final_model")
    print("Model saved to ppo_final_model/")

    # Save training curves for quick before/after comparisons.
    if reward_history:
        _save_curve(reward_history, "Average Reward per Iteration",
                    "Average Reward", "reward_curve.png")
    if loss_history:
        _save_curve(loss_history, "Training Loss per Iteration",
                    "Loss", "loss_curve.png", color="tab:red")

    if os.path.exists("reward_curve.png") and os.path.exists("loss_curve.png"):
        print("Saved reward_curve.png and loss_curve.png")
    print("="*60)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
778
 
779
- # ======================================================================
780
- # 14. ENTRY POINT (unchanged)
781
- # ======================================================================
782
if __name__ == "__main__":
    # Settings for a direct script run; every key is a train_ppo keyword.
    launch_settings = {
        "n_iterations": 30,
        "trajectories_per_iter": 10,
        "n_epochs": 4,
        "max_steps": 10,
        "learning_rate": 3e-5,
        "clip_epsilon": 0.2,
        "entropy_coef": 0.01,
        "gamma": 0.99,
        "eval_every": 5,
    }
    train_ppo(**launch_settings)
 
1
+ # training.py – PPO + QLoRA + Supervised Warm-up
2
+ # Model : Qwen/Qwen2.5-1.5B-Instruct (via Unsloth – 2× faster, fits Colab T4)
3
+ # Fixed : label-masking, BPE-boundary alignment, log-ratio clamping, OOM guards
4
+ # Evidence: reward curves, before/after traces, per-difficulty breakdown, KL, entropy
5
+ # ============================================================
6
+ import os, json, random, re
7
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
8
+
9
+ import matplotlib
10
+ matplotlib.use("Agg")
11
+ import matplotlib.pyplot as plt
12
+ import matplotlib.gridspec as gridspec
13
+
14
  import torch
15
  import torch.nn.functional as F
16
  from torch.optim import AdamW
17
+ from dataclasses import dataclass, field
18
+ from typing import List, Optional, Dict
19
+ from collections import Counter, defaultdict
20
  import numpy as np
 
 
 
21
 
22
+ # ── Unsloth gives 2Γ— throughput with identical outputs ────────────────────────
23
  from unsloth import FastLanguageModel
 
 
 
24
 
 
25
  from environment import CodeReviewEnv
26
  from redteam import BUG_DB
27
+
28
# Graceful import: use the project's map_to_env if available, else an inline fallback.
try:
    from models import map_to_env as model_map_to_env
    _HAVE_MODEL_MAP = True
except (ImportError, AttributeError):
    _HAVE_MODEL_MAP = False

if not _HAVE_MODEL_MAP:
    try:
        from models import (RunTests, RunLinter, Inspect, ProposeFix,
                            WriteComment, AskQuestion, Done, Skip, QueryDocs)
    except ImportError:
        # Last resort: a duck-typed object the environment can introspect.
        class _EnvAction:
            """Attribute bag standing in for the project's action classes."""

            def __init__(self, **kw):
                self.__dict__.update(kw)

        def model_map_to_env(action_type: str, content=None):
            """Wrap the action name and payload in a duck-typed action object."""
            return _EnvAction(action_type=action_type, content=content)
    else:
        def model_map_to_env(action_type: str, content=None):
            """Build the environment action for ``action_type``.

            Only the selected action is constructed (the table holds factories,
            not instances). Payload fields are passed by keyword, matching the
            earlier ``map_to_env`` in this file. Unknown action types map to
            ``Skip()``.
            """
            factories = {
                "run_tests": RunTests,
                "run_linter": RunLinter,
                "inspect": Inspect,
                "query_docs": lambda: QueryDocs(query_topic=content or "python bug fix"),
                "fix": lambda: ProposeFix(fix_code=content or ""),
                "comment": lambda: WriteComment(comment_text=content or ""),
                "question": lambda: AskQuestion(question=content or ""),
                "done": Done,
            }
            return factories.get(action_type, Skip)()
56
+
57
+ # ══════════════════════════════════════════════════════════════════════════════
58
+ # CONFIG
59
+ # ══════════════════════════════════════════════════════════════════════════════
60
# Single source of hyper-parameters for every phase of the run.
CFG = {
    # --- model / LoRA adapter ---
    "model_name": "unsloth/Qwen2.5-1.5B-Instruct-bnb-4bit",
    "max_seq_len": 512,          # hard cap on tokenised sequence length
    "lora_r": 16,
    "lora_alpha": 32,

    # --- supervised warm-up ---
    "warmup_data": "training_data.json",
    "warmup_epochs": 2,
    "warmup_lr": 2e-5,
    "warmup_grad_acc": 4,        # examples accumulated per optimiser step

    # --- PPO ---
    "ppo_iters": 15,
    "trajs_per_iter": 6,
    "max_steps": 7,
    "ppo_lr": 3e-5,
    "clip_eps": 0.2,
    "entropy_coef": 0.01,
    "gamma": 0.99,
    "log_ratio_clamp": 5.0,      # bound on |new_lp - old_lp| before exp()
    "temp_start": 0.8,
    "temp_end": 0.1,

    # --- evaluation ---
    "eval_episodes": 10,         # episodes per evaluation snapshot
}
87
 
88
# Use the GPU when PyTorch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Task names are the keys of BUG_DB, in that mapping's iteration order
# (presumably easy, medium, hard, harder, hardest — confirm against redteam.py).
TASK_LEVELS = list(BUG_DB)
90
+
91
+ # ══════════════════════════════════════════════════════════════════════════════
92
+ # DATA STRUCTURES
93
+ # ══════════════════════════════════════════════════════════════════════════════
94
@dataclass
class AgentAction:
    """A parsed agent decision: which action to take and its optional payload."""

    # Lower-cased action name as produced by the parser (e.g. "fix", "done").
    action_type: str
    # Free-text payload for actions that carry one; None otherwise.
    content: Optional[str] = None
98
 
99
@dataclass
class Trajectory:
    """Per-step record of one episode; the five lists are index-aligned."""

    states: List[str]        # prompt shown to the model at each step
    actions: List[str]       # raw text the model produced at each step
    rewards: List[float]     # reward recorded for each step
    logprobs: List[float]    # log-probability of each action at sampling time
    dones: List[bool]        # whether the step ended the episode
    task: str = ""           # task name the episode was run on
107
+
108
@dataclass
class EvalSnapshot:
    """Summary of one evaluation pass, kept for before/after comparison."""

    # Mean episode reward over all evaluated episodes.
    avg_reward: float
    # Mean episode reward keyed by task name.
    per_task: Dict[str, float] = field(default_factory=dict)
    # Fraction of all emitted actions, keyed by action type.
    action_dist: Dict[str, float] = field(default_factory=dict)
    # Fraction of episodes counted as successful.
    success_rate: float = 0.0
    # Mean number of steps per episode.
    avg_steps: float = 0.0
    # One summary dict per evaluated episode.
    traces: List[dict] = field(default_factory=list)
117
+
118
+ # ══════════════════════════════════════════════════════════════════════════════
119
+ # ACTION PARSER
120
+ # ══════════════════════════════════════════════════════════════════════════════
121
def parse_action(text: str) -> AgentAction:
    """Turn raw model output into an `AgentAction`.

    Three strategies are tried in order:

    1. Strict JSON: the whole text is a JSON object with an ``action_type``.
    2. Regex: an ``"action_type": "<word>"`` pair (and optionally a
       ``"content": "<string>"`` pair) embedded in surrounding text.
    3. Keyword: the first known action name found anywhere in the text.

    Falls back to ``AgentAction("skip")`` when nothing matches. Never raises
    on malformed model output.
    """
    text = text.strip()

    # 1) Strict JSON. json.loads also accepts scalars and arrays, so the
    #    result must be checked before it is treated as a mapping.
    try:
        payload = json.loads(text)
    except json.JSONDecodeError:
        payload = None

    if isinstance(payload, dict):
        kind = payload.get("action_type", "skip")
        if isinstance(kind, str):
            return AgentAction(kind.lower(), payload.get("content"))
        # e.g. {"action_type": null}: a JSON object without a usable type.
        return AgentAction("skip", payload.get("content"))

    # 2) Regex over JSON-looking fragments (markdown fences, trailing prose…).
    m = re.search(r'"action_type"\s*:\s*"(\w+)"', text)
    if m:
        # The pattern steps over backslash escapes so an embedded \" does not
        # terminate the captured string early.
        cm = re.search(r'"content"\s*:\s*"((?:[^"\\]|\\.)*)"', text, re.DOTALL)
        content = None
        if cm:
            raw = cm.group(1)
            try:
                # Decode JSON escapes (\n, \", \\ …); strict=False tolerates
                # literal control characters inside the string.
                content = json.loads(f'"{raw}"', strict=False)
            except json.JSONDecodeError:
                content = raw
        return AgentAction(m.group(1).lower(), content)

    # 3) Keyword heuristic: plain substring match, first hit in this order wins.
    tl = text.lower()
    for kw in ("run_tests", "run_linter", "inspect", "query_docs", "fix",
               "comment", "question", "done"):
        if kw in tl:
            return AgentAction(kw)
    return AgentAction("skip")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
 
140
def map_to_env(action: AgentAction):
    """Convert a parsed `AgentAction` into the object passed to `env.step`."""
    kind, payload = action.action_type, action.content
    return model_map_to_env(kind, payload)
142
+
143
+ # ══════════════════════════════════════════════════════════════════════════════
144
+ # MODEL (Qwen2.5-1.5B via Unsloth)
145
+ # ══════════════════════════════════════════════════════════════════════════════
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
def load_model():
    """Load the 4-bit base model named in CFG and attach LoRA adapters.

    Returns:
        (model, tokenizer) as produced by Unsloth's `FastLanguageModel`, with
        the tokenizer's pad token set to its EOS token.
    """
    print(f"Loading {CFG['model_name']} …")
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=CFG["model_name"],
        max_seq_length=CFG["max_seq_len"],
        load_in_4bit=True,
    )
    model = FastLanguageModel.get_peft_model(
        model,
        r=CFG["lora_r"],
        lora_alpha=CFG["lora_alpha"],
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                        "gate_proj", "up_proj", "down_proj"],
        lora_dropout=0.0,
    )
    # NOTE(review): this overwrites any pad token the tokenizer already
    # defines; generation and scoring in this file pass EOS as the pad id.
    tokenizer.pad_token = tokenizer.eos_token

    n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f" trainable params: {n_trainable/1e6:.1f}M")
    return model, tokenizer
165
 
166
+ # ══════════════════════════════════════════════════════════════════════════════
167
+ # PROMPT BUILDER
168
+ # ══════════════════════════════════════════════════════════════════════════════
169
def build_prompt(obs, history_lines: List[str]) -> str:
    """Render the instruction prompt for the current observation.

    Args:
        obs: Observation object. `code_snippet` is required; the optional
            `author_response`, `last_tool_output` and `author_personality`
            attributes are read with fallbacks.
        history_lines: Transcript lines so far; only the last four are
            appended to the prompt.

    Returns:
        The full prompt text asking the model for a JSON action.
    """
    author_msg = getattr(obs, "author_response", "") or ""
    tool_output = getattr(obs, "last_tool_output", "") or ""
    personality = getattr(obs, "author_personality", "defensive")

    # Trim tool output to avoid context explosion.
    if len(tool_output) > 600:
        tool_output = tool_output[:600] + " …[truncated]"

    author_line = author_msg or "(no response yet – start with inspect)"
    tool_line = tool_output or "(none)"

    p = (
        f"You are an AI code review agent. Convince the developer (personality: "
        f"**{personality}**) to accept your fix. Name your fix function `fix`.\n\n"
        "Evidence required: tests pass, lint clean, docs cited, reasoning uses "
        "'because'/'therefore' (>30 words).\n\n"
        "Workflow: inspect → run_tests → run_linter → query_docs → fix → "
        "comment/question → done.\n\n"
        f"Code:\n{obs.code_snippet}\n\n"
        f"Author: {author_line}\n\n"
        f"Last tool: {tool_line}\n\n"
        "Actions: run_tests, run_linter, inspect, query_docs, fix, comment, question, done\n\n"
        'Respond ONLY in JSON: {"action_type": "...", "content": "..."}'
    )
    if history_lines:
        p += "\n\nRecent steps:\n" + "\n".join(history_lines[-4:])
    return p
194
+
195
+ # ══════════════════════════════════════════════════════════════════════════════
196
+ # BUG FIX 1 – label masking in supervised warmup
197
+ # (original: labels=inputs["input_ids"] trains on ALL tokens, including prompt)
198
+ # ══════════════════════════════════════════════════════════════════════════════
199
+ def _masked_labels(input_ids: torch.Tensor, prompt_len: int) -> torch.Tensor:
200
+ """Return labels with prompt positions set to -100 (ignored by CE loss)."""
201
+ labels = input_ids.clone()
202
+ labels[0, :prompt_len] = -100
203
+ return labels
204
+
205
+ # ══════════════════════════════════════════════════════════════════════════════
206
+ # BUG FIX 2 – BPE-boundary-safe logprob computation
207
+ # (original: tokenize(prompt) + tokenize(action) ≠ tokenize(prompt+action))
208
+ # ══════════════════════════════════════════════════════════════════════════════
209
+ def _compute_action_logprob(
210
+ logits: torch.Tensor, # [1, seq_len, vocab]
211
+ input_ids: torch.Tensor, # [1, seq_len]
212
+ prompt_len: int, # #tokens in the prompt part of the joint sequence
213
+ ) -> tuple:
214
+ """
215
+ Compute sum of log-probs for *action* tokens only, using the jointly
216
+ tokenised sequence so BPE boundaries are respected.
217
+
218
+ Returns (total_logprob, avg_entropy, n_tokens).
219
+ """
220
+ action_len = input_ids.shape[1] - prompt_len
221
+ if action_len <= 0:
222
+ return torch.tensor(0.0, device=DEVICE), torch.tensor(0.0, device=DEVICE), 0
223
+
224
+ total_lp = torch.tensor(0.0, device=DEVICE)
225
+ total_ent = torch.tensor(0.0, device=DEVICE)
226
+
227
+ for k in range(action_len):
228
+ pos = prompt_len + k # position of the k-th action token
229
+ pred_pos = pos - 1 # logit at pred_pos predicts token at pos
230
+ if pred_pos < 0 or pred_pos >= logits.shape[1]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
231
  continue
232
+ token_id = input_ids[0, pos]
233
+ lp_dist = F.log_softmax(logits[0, pred_pos], dim=-1)
234
+ total_lp = total_lp + lp_dist[token_id]
235
+ probs = torch.exp(lp_dist)
236
+ total_ent = total_ent + (-(probs * lp_dist).sum()).detach()
237
+
238
+ n = action_len
239
+ return total_lp, total_ent / max(n, 1), n
240
+
241
+ # ══════════════════════════════════════════════════════════════════════════════
242
+ # GENERATION (returns text + joint-sequence logprob)
243
+ # ══════════════════════════════════════════════════════════════════════════════
244
@torch.no_grad()
def generate_action(prompt: str, model, tokenizer,
                    temperature: float) -> tuple:
    """Sample one action from the model and score it.

    Args:
        prompt: User-turn text; it is wrapped with the tokenizer's chat
            template before generation.
        model: Causal LM exposing `generate` and a forward pass with `.logits`.
        tokenizer: Matching tokenizer.
        temperature: Sampling temperature; 0 (or below) selects greedy decoding.

    Returns:
        ``(text, logprob)``: the decoded action text and the summed
        log-probability of the generated tokens under the model. If the
        model produces no text, a random tool action is returned with a
        placeholder log-probability of -10.0.
    """
    messages = [{"role": "user", "content": prompt}]
    formatted = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    inputs = tokenizer(
        formatted, return_tensors="pt",
        max_length=CFG["max_seq_len"] - 128,   # leave room for the response
        truncation=True
    ).to(DEVICE)
    prompt_len = inputs["input_ids"].shape[1]

    sampling = temperature > 0
    gen_kwargs = dict(
        max_new_tokens=128,
        do_sample=sampling,
        return_dict_in_generate=True,
        output_scores=True,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    if sampling:
        gen_kwargs["temperature"] = temperature

    out = model.generate(**inputs, **gen_kwargs)
    gen_ids = out.sequences[0][prompt_len:]
    text = tokenizer.decode(gen_ids, skip_special_tokens=True).strip()

    if not text:
        fallback = random.choice([
            '{"action_type":"inspect"}',
            '{"action_type":"run_tests"}',
            '{"action_type":"run_linter"}',
        ])
        print(f" [WARN] empty generation → fallback {fallback}")
        # A mildly negative placeholder (rather than a -100 sentinel) keeps
        # the PPO ratio exp(new - old) finite when the action is re-scored.
        return fallback, -10.0

    # Score the generated tokens on the exact id sequence that was produced,
    # capped at the configured maximum length.
    joint_ids = torch.cat(
        [inputs["input_ids"], gen_ids.unsqueeze(0).to(DEVICE)], dim=1
    )
    joint_ids = joint_ids[:, :CFG["max_seq_len"]]

    logits = model(input_ids=joint_ids).logits
    lp, _, _ = _compute_action_logprob(logits, joint_ids, prompt_len)

    return text, lp.item()
295
+
296
+ # ══════════════════════════════════════════════════════════════════════════════
297
+ # TRAJECTORY COLLECTION
298
+ # ══════════════════════════════════════════════════════════════════════════════
299
+ # Per-action shaped rewards. These create reward variance so that
300
+ # trajectories with meaningful tool use beat inspect-only episodes.
301
+ _STEP_REWARD = {
302
+ "run_tests": +0.08,
303
+ "run_linter": +0.05,
304
+ "fix": +0.15,
305
+ "comment": +0.08,
306
+ "query_docs": +0.05,
307
+ "question": +0.04,
308
+ "inspect": 0.00, # neutral – observe before acting
309
+ "done": 0.00, # env handles the terminal reward
310
+ "skip": -0.10, # penalise doing nothing
311
+ }
312
+
313
def collect_trajectory(env, model, tokenizer,
                       max_steps: int, temperature: float,
                       task: str) -> tuple:
    """Roll out one episode on *task* and record it for PPO.

    The environment's own done flag is ignored: the episode ends only when
    the agent emits the "done" action or `max_steps` is reached. For "done",
    the environment reward is recorded as-is. For every other action a
    shaped reward from `_STEP_REWARD` is used, plus 30% of the environment
    reward when that reward exceeds 0.1. Action types that are not in
    `_STEP_REWARD` receive the "skip" penalty. All recorded rewards are
    clipped to [-1, 1].

    Args:
        env: Environment exposing `set_task`, `reset` and `step`.
        model, tokenizer: Passed through to `generate_action`.
        max_steps: Upper bound on the number of steps.
        temperature: Sampling temperature for generation.
        task: Task name handed to `env.set_task`.

    Returns:
        ``(trajectory, action_types)``: the filled `Trajectory` and the list
        of parsed action type names, one per step.
    """
    env.set_task(task)
    obs = env.reset()
    history: List[str] = []
    traj = Trajectory([], [], [], [], [], task=task)
    action_seq = []

    for _ in range(max_steps):
        prompt = build_prompt(obs, history)
        traj.states.append(prompt)

        text, lp = generate_action(prompt, model, tokenizer, temperature)
        traj.actions.append(text)
        traj.logprobs.append(lp)

        action = parse_action(text)
        action_seq.append(action.action_type)

        # The env's done flag is deliberately discarded (see docstring).
        obs, reward, _env_done, _ = env.step(map_to_env(action))
        raw_r = float(reward.value)

        if action.action_type == "done":
            # Agent explicitly chose to terminate: honour the env reward.
            shaped_r = raw_r
            effective_done = True
        else:
            # Unrecognised action types are penalised exactly like "skip".
            shaped_r = _STEP_REWARD.get(action.action_type, _STEP_REWARD["skip"])
            if raw_r > 0.1:          # env signalling meaningful progress
                shaped_r += raw_r * 0.3
            effective_done = False

        traj.rewards.append(float(np.clip(shaped_r, -1.0, 1.0)))
        traj.dones.append(effective_done)

        # Same tolerant attribute access as build_prompt uses.
        tool_out = getattr(obs, "last_tool_output", "") or ""
        history.append(f"Agent: {text[:120]}")
        history.append(f"Env: {tool_out[:120]}")

        if effective_done:
            break

    return traj, action_seq
370
+
371
+ # ══════════════════════════════════════════════════════════════════════════════
372
+ # SUPERVISED WARM-UP (BUG FIX 1: action-only label masking)
373
+ # ══════════════════════════════════════════════════════════════════════════════
374
def supervised_warmup(model, tokenizer):
    """Fine-tune the model on (prompt, action) pairs before PPO.

    Reads a JSON list of examples from ``CFG["warmup_data"]``; each example
    provides ``"prompt"`` and ``"action"``. The loss is computed on the
    assistant tokens only (prompt positions are masked out). Gradients are
    accumulated over ``CFG["warmup_grad_acc"]`` usable examples per optimiser
    step, and any partial group left at the end of an epoch is applied
    rather than dropped.

    Returns:
        List with the average loss of each epoch.
    """
    print("\n" + "="*60)
    print("SUPERVISED WARM-UP")
    print("="*60)

    with open(CFG["warmup_data"], encoding="utf-8") as f:
        data = json.load(f)

    opt = AdamW(model.parameters(), lr=CFG["warmup_lr"])
    model.train()
    loss_history = []
    grad_acc = CFG["warmup_grad_acc"]
    max_len = CFG["max_seq_len"]

    def _optimizer_step():
        # Clip, apply and clear the gradients accumulated so far.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()
        opt.zero_grad()

    for epoch in range(CFG["warmup_epochs"]):
        random.shuffle(data)
        epoch_loss, n_valid = 0.0, 0
        pending = 0                      # backward passes since the last step
        opt.zero_grad()

        for step, ex in enumerate(data):
            # Render the prompt alone and the prompt+answer conversation.
            prompt_chat = tokenizer.apply_chat_template(
                [{"role": "user", "content": ex["prompt"]}],
                tokenize=False, add_generation_prompt=True
            )
            full_chat = tokenizer.apply_chat_template(
                [{"role": "user", "content": ex["prompt"]},
                 {"role": "assistant", "content": ex["action"]}],
                tokenize=False
            )

            prompt_ids = tokenizer(
                prompt_chat, return_tensors="pt",
                max_length=max_len, truncation=True
            )["input_ids"]
            full_inputs = tokenizer(
                full_chat, return_tensors="pt",
                max_length=max_len, truncation=True
            ).to(DEVICE)

            prompt_len = prompt_ids.shape[1]
            if prompt_len >= full_inputs["input_ids"].shape[1]:
                continue  # action got truncated away

            # Mask prompt tokens so the loss is action-only.
            labels = _masked_labels(full_inputs["input_ids"], prompt_len)

            out = model(**full_inputs, labels=labels)
            loss = out.loss / grad_acc
            loss.backward()
            pending += 1

            # Count usable examples, not loop iterations, so skipped
            # examples do not shrink the effective batch.
            if pending == grad_acc:
                _optimizer_step()
                pending = 0

            epoch_loss += loss.item() * grad_acc
            n_valid += 1

            if (step + 1) % 50 == 0:
                print(f" epoch {epoch+1} step {step+1}/{len(data)}"
                      f" loss={epoch_loss/n_valid:.4f}")

        # Apply the trailing partial group instead of discarding it. Its
        # gradient is scaled by pending/grad_acc relative to a full group.
        if pending:
            _optimizer_step()

        avg = epoch_loss / max(n_valid, 1)
        loss_history.append(avg)
        print(f" Epoch {epoch+1} complete: avg_loss={avg:.4f}")

    torch.cuda.empty_cache()
    trend = " → ".join(f"{l:.4f}" for l in loss_history)
    print(f"✓ Warm-up done. Loss: {trend}\n")
    return loss_history
442
+
443
+ # ══════════════════════════════════════════════════════════════════════════════
444
+ # EVALUATION (produces rich EvalSnapshot for comparison plots)
445
+ # ══════════════════════════════════════════════════════════════════════════════
446
@torch.no_grad()
def evaluate(env, model, tokenizer, label: str = "") -> EvalSnapshot:
    """Run greedy evaluation episodes and summarise them.

    ``CFG["eval_episodes"]`` episodes are run with temperature 0, cycling
    through `TASK_LEVELS`. An episode counts as a success when the agent
    emitted the "done" action. The model is put in eval mode for the
    rollouts and switched back to train mode afterwards, even if a rollout
    raises.

    Args:
        env, model, tokenizer: Passed through to `collect_trajectory`.
        label: When non-empty, a short report is printed under this heading.

    Returns:
        An `EvalSnapshot`; its averages are 0.0 when no episode was run.
    """
    per_task: Dict[str, List[float]] = defaultdict(list)
    action_counter: Counter = Counter()
    all_steps, all_success = [], []
    traces = []

    model.eval()
    try:
        for ep in range(CFG["eval_episodes"]):
            task = TASK_LEVELS[ep % len(TASK_LEVELS)]
            traj, actions = collect_trajectory(
                env, model, tokenizer, CFG["max_steps"], 0.0, task
            )
            ep_r = sum(traj.rewards)
            per_task[task].append(ep_r)
            action_counter.update(actions)
            all_steps.append(len(traj.actions))
            # Success means the agent explicitly called "done"; a positive
            # episode reward alone is not treated as success.
            all_success.append(1 if "done" in actions else 0)
            traces.append({"task": task, "reward": round(ep_r, 4),
                           "steps": len(traj.actions), "actions": actions})
    finally:
        model.train()

    def _mean(values) -> float:
        # np.mean of an empty sequence is NaN (with a warning); report 0.0.
        return float(np.mean(values)) if len(values) else 0.0

    episode_rewards = [r for rs in per_task.values() for r in rs]
    total_actions = max(sum(action_counter.values()), 1)
    snap = EvalSnapshot(
        avg_reward=_mean(episode_rewards),
        per_task={t: _mean(rs) for t, rs in per_task.items()},
        action_dist={a: c / total_actions for a, c in action_counter.most_common()},
        success_rate=_mean(all_success),
        avg_steps=_mean(all_steps),
        traces=traces,
    )
    if label:
        print(f"\n── {label} ──")
        print(f" avg_reward={snap.avg_reward:+.4f} "
              f"success={snap.success_rate:.0%} steps={snap.avg_steps:.1f}")
        print(" per-task: " +
              " ".join(f"{t}={v:+.3f}" for t, v in snap.per_task.items()))
        print(" top actions: " +
              " ".join(f"{a}={p:.0%}" for a, p in list(snap.action_dist.items())[:5]))
    return snap
488
+
489
+ # ══════════════════════════════════════════════════════════════════════════════
490
+ # PPO UPDATE (BUG FIX 2 + 3: BPE-safe logprob + log-ratio clamping)
491
+ # ══════════════════════════════════════════════════════════════════════════════
492
def ppo_update(trajectories: List[Trajectory],
               model, tokenizer, optimizer) -> dict:
    """Run one clipped-PPO pass over the collected trajectories.

    Discounted returns are computed per trajectory, then normalised across
    all steps of all trajectories (zero mean, unit std) and used as
    advantages. Each step is re-scored under the current model and gets its
    own optimiser step. If the returns have (near-)zero variance the update
    is skipped.

    Args:
        trajectories: Episodes produced by `collect_trajectory`.
        model, tokenizer: Policy being trained and its tokenizer.
        optimizer: Optimiser over the model's parameters.

    Returns:
        Dict with the mean ``loss``, ``kl`` (mean of old minus new
        log-probability) and ``entropy`` over the steps that were updated;
        all 0.0 when nothing was updated.
    """
    model.train()
    losses, kls, entropies = [], [], []

    # ── Discounted returns, accumulated back-to-front ────────────────────────
    traj_returns = []
    for traj in trajectories:
        ret, running = [], 0.0
        for r, done in zip(reversed(traj.rewards), reversed(traj.dones)):
            running = r + CFG["gamma"] * (0.0 if done else running)
            ret.append(running)
        ret.reverse()
        traj_returns.append(ret)
    all_returns = [r for ret in traj_returns for r in ret]

    # ── Normalise returns into advantages ────────────────────────────────────
    ret_arr = np.array(all_returns) if all_returns else np.array([0.0])
    ret_mean = float(ret_arr.mean())
    ret_std = float(ret_arr.std())

    if ret_std < 1e-6:
        # All returns identical: every advantage would be zero.
        print(" [PPO] Zero return variance – skipping gradient update.")
        return dict(loss=0.0, kl=0.0, entropy=0.0)

    norm_returns: List[List[float]] = [
        [(r - ret_mean) / (ret_std + 1e-8) for r in ret_list]
        for ret_list in traj_returns
    ]

    for traj, advantages in zip(trajectories, norm_returns):
        for i in range(len(traj.states)):
            state = traj.states[i]
            action = traj.actions[i]
            old_lp = traj.logprobs[i]
            adv = advantages[i]

            # Tokenise prompt and action as one string so token boundaries
            # at the seam are those of the joint text.
            prompt_chat = tokenizer.apply_chat_template(
                [{"role": "user", "content": state}],
                tokenize=False, add_generation_prompt=True
            )
            full_text = prompt_chat + action

            full_ids = tokenizer(
                full_text, return_tensors="pt",
                max_length=CFG["max_seq_len"], truncation=True
            ).to(DEVICE)

            # The prompt length comes from tokenising the prompt on its own,
            # capped so at least one position is left for the action.
            # NOTE(review): the stored old log-prob was computed on the
            # generated ids (which may include an EOS token and a shorter
            # prompt cap), so old and new scores are not over exactly the
            # same tokens — confirm whether this bias matters.
            prompt_ids = tokenizer(
                prompt_chat, return_tensors="pt",
                max_length=CFG["max_seq_len"] - 10, truncation=True
            )["input_ids"]
            prompt_len = min(prompt_ids.shape[1], full_ids["input_ids"].shape[1] - 1)

            logits = model(**full_ids).logits

            new_lp, avg_ent, n_tokens = _compute_action_logprob(
                logits, full_ids["input_ids"], prompt_len
            )
            if n_tokens == 0:
                continue

            # Clamp the log-ratio before exp() so the ratio stays finite.
            old_lp_t = torch.tensor(old_lp, dtype=torch.float32, device=DEVICE)
            log_ratio = torch.clamp(new_lp - old_lp_t,
                                    -CFG["log_ratio_clamp"],
                                    CFG["log_ratio_clamp"])
            ratio = torch.exp(log_ratio)

            adv_t = torch.tensor(adv, dtype=torch.float32, device=DEVICE)
            s1 = ratio * adv_t
            s2 = torch.clamp(ratio,
                             1.0 - CFG["clip_eps"],
                             1.0 + CFG["clip_eps"]) * adv_t

            policy_loss = -torch.min(s1, s2)
            loss = policy_loss - CFG["entropy_coef"] * avg_ent

            if torch.isnan(loss) or torch.isinf(loss):
                continue

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            losses.append(loss.item())
            kls.append((old_lp_t - new_lp).detach().cpu().item())
            entropies.append(avg_ent.item())

    torch.cuda.empty_cache()
    return dict(
        loss=float(np.mean(losses)) if losses else 0.0,
        kl=float(np.mean(kls)) if kls else 0.0,
        entropy=float(np.mean(entropies)) if entropies else 0.0,
    )
598
+
599
+ # ══════════════════════════════════════════════════════════════════════════════
600
+ # PLOTTING (rich evidence panel)
601
+ # ══════════════════════════════════════════════════════════════════════════════
602
def plot_all(warmup_losses, reward_hist, success_hist, kl_hist, entropy_hist,
             baseline_snap: EvalSnapshot,
             postwarmup_snap: EvalSnapshot,
             final_snap: EvalSnapshot):
    """Write the two summary figures of a training run.

    ``training_summary.png`` is a 2x3 grid: warm-up loss, PPO reward (raw
    and smoothed, with the three evaluation levels as horizontal lines),
    success rate, KL, entropy, and per-task reward for baseline vs final.
    ``action_distribution.png`` shows the action mix of the three snapshots.

    Args:
        warmup_losses: Average loss per warm-up epoch.
        reward_hist, success_hist, kl_hist, entropy_hist: One value per PPO
            iteration; all four must have the same length.
        baseline_snap, postwarmup_snap, final_snap: Evaluation snapshots.
    """
    iters = list(range(1, len(reward_hist) + 1))

    # ── Figure 1: training curves (2×3 grid) ─────────────────────────────────
    fig = plt.figure(figsize=(18, 10))
    gs = gridspec.GridSpec(2, 3, figure=fig, hspace=0.45, wspace=0.35)

    # (0,0) Warm-up loss
    ax = fig.add_subplot(gs[0, 0])
    ax.plot(range(1, len(warmup_losses) + 1), warmup_losses,
            marker="o", color="mediumpurple", linewidth=2)
    ax.set_title("A. Warm-up CE Loss ↓", fontweight="bold")
    ax.set_xlabel("Epoch"); ax.set_ylabel("Loss"); ax.grid(alpha=0.3)

    # (0,1) PPO reward
    ax = fig.add_subplot(gs[0, 1])
    # 3-point moving average. The series is edge-padded first so the two
    # end points are not dragged toward zero, and an empty history does not
    # reach np.convolve (which rejects empty input).
    if len(reward_hist):
        padded = np.pad(np.asarray(reward_hist, dtype=float), 1, mode="edge")
        smooth = np.convolve(padded, np.ones(3) / 3, mode="valid")
    else:
        smooth = np.array([])
    ax.plot(iters, reward_hist, alpha=0.35, color="steelblue", linewidth=1)
    ax.plot(iters, smooth, color="steelblue", linewidth=2.5, label="reward (smoothed)")
    ax.axhline(baseline_snap.avg_reward, color="gray", linestyle=":",
               label=f"pre-warmup ({baseline_snap.avg_reward:+.3f})")
    ax.axhline(postwarmup_snap.avg_reward, color="mediumpurple", linestyle="--",
               label=f"post-warmup ({postwarmup_snap.avg_reward:+.3f})")
    ax.axhline(final_snap.avg_reward, color="forestgreen", linestyle="-.",
               label=f"final ({final_snap.avg_reward:+.3f})")
    ax.set_title("B. PPO Reward ↑", fontweight="bold")
    ax.set_xlabel("Iteration"); ax.set_ylabel("Avg Reward")
    ax.legend(fontsize=7); ax.grid(alpha=0.3)

    # (0,2) Success rate
    ax = fig.add_subplot(gs[0, 2])
    ax.plot(iters, success_hist, marker="s", color="seagreen", linewidth=2)
    ax.set_ylim(0, 1)
    ax.set_title("C. Episode Success Rate ↑", fontweight="bold")
    ax.set_xlabel("Iteration"); ax.set_ylabel("Fraction")
    ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: f"{y:.0%}"))
    ax.grid(alpha=0.3)

    # (1,0) KL divergence
    ax = fig.add_subplot(gs[1, 0])
    ax.plot(iters, kl_hist, marker="^", color="tomato", linewidth=2)
    ax.axhline(0, color="gray", linewidth=0.8)
    ax.set_title("D. KL Divergence", fontweight="bold")
    ax.set_xlabel("Iteration"); ax.set_ylabel("KL"); ax.grid(alpha=0.3)

    # (1,1) Entropy
    ax = fig.add_subplot(gs[1, 1])
    ax.plot(iters, entropy_hist, marker="D", color="darkorange", linewidth=2)
    ax.set_title("E. Policy Entropy", fontweight="bold")
    ax.set_xlabel("Iteration"); ax.set_ylabel("Entropy"); ax.grid(alpha=0.3)

    # (1,2) Per-difficulty reward, baseline vs final
    ax = fig.add_subplot(gs[1, 2])
    tasks = TASK_LEVELS
    vals_base = [baseline_snap.per_task.get(t, 0) for t in tasks]
    vals_final = [final_snap.per_task.get(t, 0) for t in tasks]
    x = np.arange(len(tasks))
    ax.bar(x - 0.2, vals_base, 0.35, label="baseline", color="lightcoral", alpha=0.8)
    ax.bar(x + 0.2, vals_final, 0.35, label="final", color="steelblue", alpha=0.8)
    ax.set_xticks(x); ax.set_xticklabels(tasks, fontsize=8)
    ax.set_title("F. Per-Difficulty Reward", fontweight="bold")
    ax.set_ylabel("Avg Reward"); ax.legend(fontsize=8); ax.grid(alpha=0.3, axis="y")
    ax.axhline(0, color="gray", linewidth=0.8)

    fig.suptitle("Code-Review Agent – Full Training Evidence "
                 "(Qwen2.5-1.5B, PPO + QLoRA)",
                 fontsize=13, fontweight="bold")
    fig.savefig("training_summary.png", dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(" Saved: training_summary.png")

    # ── Figure 2: before / after action distribution ─────────────────────────
    fig, axes = plt.subplots(1, 3, figsize=(16, 4), sharey=False)
    panels = zip(
        axes,
        [baseline_snap, postwarmup_snap, final_snap],
        ["Before (baseline)", "After warm-up", "After PPO (final)"],
    )
    for ax, snap, title in panels:
        if snap.action_dist:
            labels = list(snap.action_dist.keys())
            vals = [snap.action_dist[l] * 100 for l in labels]
            bars = ax.barh(labels, vals,
                           color=plt.cm.tab10(np.linspace(0, 0.8, len(labels))))
            ax.bar_label(bars, fmt="%.0f%%", padding=3, fontsize=8)
        ax.set_xlim(0, 105)
        ax.set_title(title, fontweight="bold")
        ax.set_xlabel("% of actions")
        ax.grid(alpha=0.3, axis="x")

    fig.suptitle("Action Distribution: Before vs After Training",
                 fontsize=12, fontweight="bold")
    plt.tight_layout()
    fig.savefig("action_distribution.png", dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(" Saved: action_distribution.png")
701
+
702
+ # ══════════════════════════════════════════════════════════════════════════════
703
+ # MAIN
704
+ # ══════════════════════════════════════════════════════════════════════════════
705
def train():
    """Run the full pipeline: baseline eval, warm-up, PPO, final eval, plots.

    Phases:
        0. Evaluate the freshly loaded model.
        1. Supervised warm-up, followed by a second evaluation.
        2. ``CFG["ppo_iters"]`` PPO iterations of ``CFG["trajs_per_iter"]``
           rollouts each, with a decaying sampling temperature.
        3. Final evaluation, a printed summary table, and the two figures
           written by `plot_all`.
    """
    model, tokenizer = load_model()
    env = CodeReviewEnv()

    # ── PHASE 0: pre-warmup baseline ────────────────────────────────────────
    print("\n" + "="*60)
    print("PHASE 0 – BASELINE (untrained)")
    print("="*60)
    baseline_snap = evaluate(env, model, tokenizer, "Baseline")

    # ── PHASE 1: supervised warm-up ─────────────────────────────────────────
    warmup_losses = supervised_warmup(model, tokenizer)

    postwarmup_snap = evaluate(env, model, tokenizer, "Post-Warmup")

    # ── PHASE 2: PPO ────────────────────────────────────────────────────────
    optimizer = AdamW(model.parameters(), lr=CFG["ppo_lr"])
    reward_hist, success_hist, kl_hist, entropy_hist = [], [], [], []

    print("\n" + "="*60)
    print(f"PHASE 2 – PPO ({CFG['ppo_iters']} iterations × "
          f"{CFG['trajs_per_iter']} trajectories)")
    print("="*60)

    for it in range(CFG["ppo_iters"]):
        # Exponential temperature decay (×0.93 per iteration) with a floor
        # of 0.35, so sampling keeps producing varied trajectories.
        # NOTE(review): CFG["temp_end"] is not consulted here.
        t = max(CFG["temp_start"] * (0.93 ** it), 0.35)

        print(f"\n── Iteration {it+1}/{CFG['ppo_iters']} temp={t:.2f} ──")
        trajectories, action_counts = [], Counter()
        successes = 0

        for j in range(CFG["trajs_per_iter"]):
            task = TASK_LEVELS[j % len(TASK_LEVELS)]
            traj, actions = collect_trajectory(
                env, model, tokenizer, CFG["max_steps"], t, task
            )
            trajectories.append(traj)
            action_counts.update(actions)
            ep_r = sum(traj.rewards)
            # Same success criterion as evaluate(): an explicit "done".
            successes += int("done" in actions)
            print(f" traj {j+1}/{CFG['trajs_per_iter']} task={task}"
                  f" steps={len(traj.actions)} reward={ep_r:+.3f}")

        avg_r = float(np.mean([sum(tr.rewards) for tr in trajectories]))
        success_r = successes / CFG["trajs_per_iter"]

        m = ppo_update(trajectories, model, tokenizer, optimizer)

        reward_hist.append(avg_r)
        success_hist.append(success_r)
        kl_hist.append(m["kl"])
        entropy_hist.append(m["entropy"])

        delta = avg_r - baseline_snap.avg_reward
        print(f" → avg_reward={avg_r:+.4f} Δbaseline={delta:+.4f}"
              f" success={success_r:.0%}"
              f" loss={m['loss']:.4f} kl={m['kl']:.4f} ent={m['entropy']:.4f}")
        print(f" actions: {dict(action_counts.most_common(5))}")

    # ── PHASE 3: final evaluation ───────────────────────────────────────────
    print("\n" + "="*60)
    print("PHASE 3 – FINAL EVALUATION")
    print("="*60)
    final_snap = evaluate(env, model, tokenizer, "Final")

    # ── Summary table ───────────────────────────────────────────────────────
    print("\n" + "="*60)
    print("TRAINING SUMMARY")
    print("="*60)
    print(f" {'Stage':<20} {'Reward':>10} {'Success':>10} {'Δ baseline':>12}")
    print(f" {'-'*54}")
    for label, snap in [("Baseline", baseline_snap),
                        ("Post-warmup", postwarmup_snap),
                        ("Final (PPO)", final_snap)]:
        delta = snap.avg_reward - baseline_snap.avg_reward
        print(f" {label:<20} {snap.avg_reward:>+10.4f}"
              f" {snap.success_rate:>10.0%} {delta:>+11.4f}")

    improve = final_snap.avg_reward - baseline_snap.avg_reward
    verdict = "✓ LEARNED" if improve > 0 else "✗ NO IMPROVEMENT"
    print(f"\n {verdict} (total Δ = {improve:+.4f})")

    print("\nBefore → After traces (one per difficulty):")
    # When a task was evaluated more than once, the last trace for it wins.
    btask = {tr["task"]: tr for tr in baseline_snap.traces}
    ftask = {tr["task"]: tr for tr in final_snap.traces}
    for task in TASK_LEVELS:
        b = btask.get(task, {})
        f = ftask.get(task, {})
        print(f" {task:8s} baseline actions={b.get('actions',[])} "
              f"reward={b.get('reward',0):+.3f}"
              f" │ final actions={f.get('actions',[])} "
              f"reward={f.get('reward',0):+.3f}")

    # ── Plots ───────────────────────────────────────────────────────────────
    plot_all(warmup_losses, reward_hist, success_hist, kl_hist, entropy_hist,
             baseline_snap, postwarmup_snap, final_snap)

    print("\nAll done. Saved: training_summary.png action_distribution.png")


if __name__ == "__main__":
    train()