Ajitg25 commited on
Commit
cbbeb6e
·
verified ·
1 Parent(s): 56caa44

Remove train.py: serving Space only

Browse files
Files changed (1) hide show
  1. train.py +0 -397
train.py DELETED
@@ -1,397 +0,0 @@
1
- #!/usr/bin/env python3
2
- """
3
- Ambulance Green Corridor — GRPO Training on HF Space (A10G GPU).
4
-
5
- This script:
6
- 1. Starts the ambulance server as a background process
7
- 2. Loads Qwen2.5-0.5B-Instruct via Unsloth (4-bit LoRA)
8
- 3. Runs GRPO training for 60 iterations
9
- 4. Saves plots to /app/plots/
10
- 5. Exits — the Dockerfile then starts the serving mode
11
- """
12
-
13
- import asyncio
14
- import json
15
- import os
16
- import re
17
- import subprocess
18
- import sys
19
- import time
20
- import warnings
21
- from pathlib import Path
22
-
23
- warnings.filterwarnings("ignore", message=".*max_new_tokens.*")
24
- warnings.filterwarnings("ignore", category=FutureWarning)
25
-
26
- import matplotlib
27
-
28
- matplotlib.use("Agg")
29
- import matplotlib.pyplot as plt
30
- import nest_asyncio
31
- import numpy as np
32
- import torch
33
- import torch.nn.functional as F
34
- from torch.optim import AdamW
35
- from transformers import AutoModelForCausalLM, AutoTokenizer
36
- from peft import get_peft_model, LoraConfig, TaskType
37
-
38
- nest_asyncio.apply()
39
-
40
- # Support both Docker (/app) and Kaggle (/kaggle/working/repo) paths
41
- _REPO_ROOT = os.environ.get("REPO_ROOT", "/app")
42
- _ENVS_PATH = os.path.join(_REPO_ROOT, "envs")
43
- sys.path.insert(0, _ENVS_PATH)
44
- from ambulance_env import AmbulanceEnv
45
- from ambulance_env.models import AmbulanceAction, SignalControl
46
-
47
- # ── Config ──────────────────────────────────────────────────────────────────
48
- ENV_URL = "http://localhost:8000"
49
- DIFFICULTY = os.getenv("AMBULANCE_DIFFICULTY", "easy")
50
- MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct"
51
- MAX_SEQ_LEN = 1024
52
- NUM_ITERATIONS = 10
53
- GROUP_SIZE = 4
54
- BETA_KL = 0.01
55
- LR = 5e-5
56
- PLOT_DIR = Path(os.environ.get("PLOT_DIR", "/app/plots"))
57
- PLOT_DIR.mkdir(parents=True, exist_ok=True)
58
-
59
- # ── 1. Start server ────────────────────────────────────────────────────────
60
- print("Starting ambulance_env server...")
61
- server_proc = subprocess.Popen(
62
- [sys.executable, "-m", "uvicorn", "ambulance_env.server.app:app",
63
- "--host", "0.0.0.0", "--port", "8000", "--log-level", "error"],
64
- env={**os.environ, "PYTHONPATH": _ENVS_PATH, "AMBULANCE_DIFFICULTY": DIFFICULTY},
65
- )
66
- time.sleep(4)
67
- print("Server ready.")
68
-
69
-
70
- # ── 2. Load model ──────────────────────────────────────────────────────────
71
- print(f"Loading {MODEL_NAME}...")
72
- base_model = AutoModelForCausalLM.from_pretrained(
73
- MODEL_NAME, dtype=torch.float16, device_map="auto",
74
- )
75
- print("Base model loaded.")
76
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
77
- tokenizer.padding_side = "left"
78
- if tokenizer.pad_token is None:
79
- tokenizer.pad_token = tokenizer.eos_token
80
- print("Tokenizer loaded.")
81
-
82
- lora_config = LoraConfig(
83
- r=16, lora_alpha=16, lora_dropout=0,
84
- target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
85
- task_type=TaskType.CAUSAL_LM,
86
- )
87
- model = get_peft_model(base_model, lora_config)
88
- model.gradient_checkpointing_enable()
89
- print(f"LoRA ready. Trainable params: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
90
-
91
-
92
- # ── 3. Prompt formatters ──────────────────────────────────────────────────
93
- SYSTEM_PROMPT = (
94
- "You are an emergency services AI managing an ambulance in a real city.\n"
95
- "You must:\n"
96
- " 1. Choose the best hospital (consider specialization, distance, traffic, road quality)\n"
97
- " 2. Clear traffic signals ahead — but ONLY signals in the WRONG phase\n"
98
- " 3. Re-route if traffic spikes, accidents, or road closures appear\n"
99
- " 4. Switch hospitals mid-journey if an alternative becomes significantly faster\n\n"
100
- "Heavy traffic slows the ambulance even on green. Potholed roads force slow speeds.\n"
101
- "Be precise. Follow the output format exactly."
102
- )
103
-
104
-
105
- def _quality_label(q):
106
- if q >= 0.9: return "highway"
107
- if q >= 0.65: return "good"
108
- if q >= 0.4: return "moderate"
109
- return "POTHOLED"
110
-
111
-
112
- def format_prompt(obs):
113
- lines = [
114
- "=== EMERGENCY DISPATCH ===",
115
- f"Patient : {obs.patient_location} | condition: {obs.patient_condition}",
116
- f"Ambulance: {obs.ambulance_location} | time: {obs.time_elapsed_seconds:.0f}s / {obs.time_limit_seconds:.0f}s",
117
- "",
118
- ]
119
- if obs.active_events:
120
- lines.append("DYNAMIC EVENTS:")
121
- for e in obs.active_events:
122
- lines.append(f" [{e.event_type.upper()}] at {e.position} — {e.description}")
123
- lines.append("")
124
- if obs.target_hospital_id:
125
- r = obs.current_route
126
- lines.append(f"CURRENT ROUTE → {obs.target_hospital_id}")
127
- lines.append(f" ETA={r.estimated_time:.0f}s | segs={len(r.segments)} | damaged={r.num_damaged_segments} | heavy={r.num_heavy_traffic_segments}")
128
- for seg in r.segments[:4]:
129
- lines.append(f" {seg.from_pos}→{seg.to_pos} | {seg.road_type} | quality={_quality_label(seg.road_quality)} | traffic={seg.traffic_volume:.0%} | est={seg.estimated_transit_time:.0f}s" + (" [BLOCKED]" if seg.blocked else ""))
130
- lines.append("")
131
- if obs.alternative_routes:
132
- lines.append("ALTERNATIVES:")
133
- for alt in obs.alternative_routes:
134
- hosp = next((h for h in obs.hospitals if h.hospital_id == alt.hospital_id), None)
135
- spec = hosp.specialization if hosp else "?"
136
- match = " <- specialist" if hosp and hosp.specialization == obs.patient_condition else ""
137
- lines.append(f" {alt.hospital_id} ({spec}){match}: ETA={alt.estimated_time:.0f}s | damaged={alt.num_damaged_segments}")
138
- lines.append("")
139
- lines.append("HOSPITALS:")
140
- for h in obs.hospitals:
141
- cap = " [FULL]" if h.at_capacity else ""
142
- match = " <- specialist" if h.specialization == obs.patient_condition else ""
143
- lines.append(f" {h.hospital_id}: {h.name} | spec={h.specialization} | est={h.travel_time_estimate:.0f}s{cap}{match}")
144
- lines.append("")
145
- if obs.lookahead_signals:
146
- lines.append("SIGNALS (only change WRONG):")
147
- for s in obs.lookahead_signals:
148
- needed = "ns_green" if s.ambulance_direction in ("north", "south") else "ew_green"
149
- status = "OK" if s.phase == needed else f"WRONG — needs {needed}"
150
- lines.append(f" ({s.row},{s.col}): {s.phase} | dir={s.ambulance_direction} | {status}")
151
- lines.append("")
152
- if obs.current_segment:
153
- lines.append(f"ROAD: {obs.current_segment.road_type} | quality={_quality_label(obs.current_segment.road_quality)} | traffic={obs.current_segment.traffic_volume:.0%} | speed={obs.last_speed_factor:.0%}")
154
- lines.append("")
155
- lines.append(f"STATS: stops={obs.stops_at_red} | efficiency={obs.signal_efficiency:.0%} | wasted={obs.unnecessary_toggles}")
156
- lines.append("")
157
- if not obs.target_hospital_id:
158
- lines.append('ACTION: {"hospital_id": "hosp_X", "signal_controls": [], "preferred_direction": null}')
159
- else:
160
- lines.append('ACTION: {"hospital_id": null, "signal_controls": [{"row": R, "col": C, "phase": "..."}], "preferred_direction": null}')
161
- return "\n".join(lines)
162
-
163
-
164
- def build_chat(obs):
165
- msgs = [{"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": format_prompt(obs)}]
166
- return tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
167
-
168
-
169
- # ── 4. Action parser ─────────────────────────────────────────────────────
170
- def parse_action(response_text, obs):
171
- text = response_text.strip()
172
- try:
173
- m = re.search(r"\{.*\}", text, re.DOTALL)
174
- if m:
175
- data = json.loads(m.group())
176
- hid = data.get("hospital_id")
177
- if hid:
178
- valid = {h.hospital_id for h in obs.hospitals if not h.at_capacity}
179
- if hid not in valid:
180
- hid = None
181
- controls = [
182
- SignalControl(row=int(c["row"]), col=int(c["col"]), phase=c["phase"])
183
- for c in data.get("signal_controls", [])
184
- if isinstance(c, dict) and c.get("phase") in ("ns_green", "ew_green")
185
- ]
186
- d = data.get("preferred_direction")
187
- if d not in ("north", "south", "east", "west"):
188
- d = None
189
- return AmbulanceAction(hospital_id=hid, signal_controls=controls, preferred_direction=d)
190
- except (json.JSONDecodeError, KeyError, ValueError, TypeError):
191
- pass
192
- if not obs.target_hospital_id:
193
- available = [h for h in obs.hospitals if not h.at_capacity]
194
- specs = [h for h in available if h.specialization == obs.patient_condition]
195
- pool = specs if specs else available
196
- if pool:
197
- return AmbulanceAction(hospital_id=min(pool, key=lambda h: h.travel_time_estimate).hospital_id)
198
- controls = [
199
- SignalControl(row=s.row, col=s.col, phase="ns_green" if s.ambulance_direction in ("north", "south") else "ew_green")
200
- for s in obs.lookahead_signals
201
- if s.phase != ("ns_green" if s.ambulance_direction in ("north", "south") else "ew_green")
202
- ]
203
- return AmbulanceAction(signal_controls=controls)
204
-
205
-
206
- # ── 5. Episode rollout ───────────────────────────────────────────────────
207
- @torch.no_grad()
208
- async def collect_episode_async(temperature=0.8, max_new_tokens=256):
209
- env = AmbulanceEnv(base_url=ENV_URL)
210
- steps = []
211
- try:
212
- result = await env.reset()
213
- obs = result.observation
214
- while not result.done:
215
- prompt = build_chat(obs)
216
- inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
217
- output_ids = model.generate(
218
- **inputs, max_new_tokens=max_new_tokens,
219
- temperature=temperature, do_sample=True,
220
- pad_token_id=tokenizer.eos_token_id,
221
- )
222
- new_tokens = output_ids[0, inputs["input_ids"].shape[1]:]
223
- response_text = tokenizer.decode(new_tokens, skip_special_tokens=True)
224
- action = parse_action(response_text, obs)
225
- result = await env.step(action)
226
- obs = result.observation
227
- steps.append({"prompt": prompt, "response": response_text, "step_reward": float(result.reward or 0.0)})
228
- total = sum(s["step_reward"] for s in steps)
229
- for s in steps:
230
- s["episode_reward"] = total
231
- state = await env.state()
232
- return steps, state
233
- finally:
234
- await env.close()
235
-
236
-
237
- def collect_episode(temperature=0.8, max_new_tokens=256):
238
- loop = asyncio.get_event_loop()
239
- return loop.run_until_complete(collect_episode_async(temperature, max_new_tokens))
240
-
241
-
242
- # ── 6. Evaluate ──────────────────────────────────────────────────────────
243
- def evaluate(num_episodes=8):
244
- rewards, arrivals, effs, times, reroutes = [], [], [], [], []
245
- for _ in range(num_episodes):
246
- steps, state = collect_episode(temperature=0.1)
247
- rewards.append(steps[-1]["episode_reward"] if steps else 0.0)
248
- arrivals.append(float(state.success))
249
- effs.append(state.signal_efficiency)
250
- times.append(state.arrival_time or 999.0)
251
- reroutes.append(state.successful_reroutes)
252
- return {
253
- "mean_reward": float(np.mean(rewards)),
254
- "arrival_rate": float(np.mean(arrivals)),
255
- "mean_efficiency": float(np.mean(effs)),
256
- "mean_time": float(np.mean(times)),
257
- "mean_reroutes": float(np.mean(reroutes)),
258
- }
259
-
260
-
261
- # ── 7. Baseline ──────────────────────────────────────────────────────────
262
- print("\n=== Baseline evaluation ===")
263
- print("Running episode 1 of 2...")
264
- baseline = evaluate(num_episodes=2)
265
- print(f"BASELINE reward={baseline['mean_reward']:.1f} arrival={baseline['arrival_rate']:.0%} "
266
- f"efficiency={baseline['mean_efficiency']:.0%} reroutes={baseline['mean_reroutes']:.1f} "
267
- f"time={baseline['mean_time']:.0f}s")
268
-
269
-
270
- # ── 8. GRPO Training ────────────────────────────────────────────────────
271
- print(f"\n=== GRPO training: {NUM_ITERATIONS} iterations x {GROUP_SIZE} episodes ===\n")
272
-
273
- optimizer = AdamW([p for p in model.parameters() if p.requires_grad], lr=LR, weight_decay=0.01)
274
- history = {"iteration": [], "mean_reward": [], "arrival_rate": [], "signal_efficiency": [], "mean_time": [], "mean_reroutes": []}
275
-
276
- for iteration in range(NUM_ITERATIONS):
277
- model.eval()
278
- group_steps, group_states = [], []
279
- for _ in range(GROUP_SIZE):
280
- steps, state = collect_episode(temperature=0.8)
281
- group_steps.append(steps)
282
- group_states.append(state)
283
-
284
- episode_rewards = [s[-1]["episode_reward"] if s else 0.0 for s in group_steps]
285
- r_tensor = torch.tensor(episode_rewards)
286
- advantages = (r_tensor - r_tensor.mean()) / (r_tensor.std() + 1e-8)
287
-
288
- model.train()
289
- iter_loss, num_updates = 0.0, 0
290
- for steps, adv in zip(group_steps, advantages.tolist()):
291
- for step in steps:
292
- prompt_ids = tokenizer(step["prompt"], return_tensors="pt", truncation=True, max_length=MAX_SEQ_LEN - 256).input_ids.to(model.device)
293
- response_ids = tokenizer(step["response"], return_tensors="pt", truncation=True, max_length=256).input_ids.to(model.device)
294
- if response_ids.shape[1] == 0:
295
- continue
296
- full_ids = torch.cat([prompt_ids, response_ids], dim=1)
297
- with torch.amp.autocast("cuda", dtype=torch.float16):
298
- logits = model(full_ids).logits
299
- resp_logits = logits[:, prompt_ids.shape[1] - 1 : -1, :]
300
- log_probs = F.log_softmax(resp_logits, dim=-1)
301
- token_lp = log_probs.gather(2, response_ids.unsqueeze(-1)).squeeze(-1)
302
- mean_lp = token_lp.mean()
303
- loss = -adv * mean_lp + BETA_KL * (mean_lp ** 2)
304
- optimizer.zero_grad()
305
- loss.backward()
306
- torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
307
- optimizer.step()
308
- iter_loss += loss.item()
309
- num_updates += 1
310
-
311
- mr = float(np.mean(episode_rewards))
312
- ar = float(np.mean([s.success for s in group_states]))
313
- me = float(np.mean([s.signal_efficiency for s in group_states]))
314
- mt = float(np.mean([s.arrival_time or 999.0 for s in group_states]))
315
- mrr = float(np.mean([s.successful_reroutes for s in group_states]))
316
-
317
- history["iteration"].append(iteration + 1)
318
- history["mean_reward"].append(mr)
319
- history["arrival_rate"].append(ar)
320
- history["signal_efficiency"].append(me)
321
- history["mean_time"].append(mt)
322
- history["mean_reroutes"].append(mrr)
323
-
324
- print(f"[{iteration+1:3d}/{NUM_ITERATIONS}] reward={mr:7.1f} arrival={ar:.0%} "
325
- f"efficiency={me:.0%} reroutes={mrr:.1f} time={mt:5.0f}s "
326
- f"loss={iter_loss / max(1, num_updates):.4f}")
327
-
328
-
329
- # ── 9. Final eval ───────────────────────────────────────────────────────
330
- print("\n=== Final evaluation ===")
331
- final = evaluate(num_episodes=8)
332
- print(f"FINAL reward={final['mean_reward']:.1f} arrival={final['arrival_rate']:.0%} "
333
- f"efficiency={final['mean_efficiency']:.0%} reroutes={final['mean_reroutes']:.1f} "
334
- f"time={final['mean_time']:.0f}s")
335
-
336
- print("\n── Improvement ──────────────────────────────────────────")
337
- print(f" Reward : {baseline['mean_reward']:6.1f} → {final['mean_reward']:6.1f} ({final['mean_reward']-baseline['mean_reward']:+.1f})")
338
- print(f" Arrival : {baseline['arrival_rate']:.0%} → {final['arrival_rate']:.0%}")
339
- print(f" Efficiency : {baseline['mean_efficiency']:.0%} → {final['mean_efficiency']:.0%}")
340
- print(f" Reroutes : {baseline['mean_reroutes']:.1f} → {final['mean_reroutes']:.1f}")
341
- print(f" Travel time : {baseline['mean_time']:.0f}s → {final['mean_time']:.0f}s ({final['mean_time']-baseline['mean_time']:+.0f}s)")
342
-
343
-
344
- # ── 10. Plots ────────────────────────────────────────────────────────────
345
- def smooth(values, window=5):
346
- if len(values) < window:
347
- return np.array(values)
348
- return np.convolve(values, np.ones(window) / window, mode="valid")
349
-
350
- fig, axes = plt.subplots(1, 4, figsize=(20, 4))
351
- fig.suptitle("Ambulance Green Corridor — GRPO Training", fontsize=14, fontweight="bold")
352
- iters = history["iteration"]
353
- sm = 4
354
-
355
- ax = axes[0]
356
- ax.plot(iters, history["mean_reward"], alpha=0.25, color="royalblue")
357
- ax.plot(iters[sm:], smooth(history["mean_reward"]), color="royalblue", lw=2, label="Trained")
358
- ax.axhline(baseline["mean_reward"], color="red", ls="--", lw=1.5, label=f"Baseline ({baseline['mean_reward']:.0f})")
359
- ax.axhline(final["mean_reward"], color="green", ls="--", lw=1.5, label=f"Final ({final['mean_reward']:.0f})")
360
- ax.set_xlabel("Episode"); ax.set_ylabel("Reward"); ax.set_title("Episode Reward"); ax.legend(fontsize=8)
361
-
362
- ax = axes[1]
363
- ax.plot(iters, [v * 100 for v in history["arrival_rate"]], alpha=0.25, color="darkorange")
364
- ax.plot(iters[sm:], smooth([v * 100 for v in history["arrival_rate"]]), color="darkorange", lw=2)
365
- ax.axhline(baseline["arrival_rate"] * 100, color="red", ls="--", lw=1.5, label=f"Before ({baseline['arrival_rate']:.0%})")
366
- ax.axhline(final["arrival_rate"] * 100, color="green", ls="--", lw=1.5, label=f"After ({final['arrival_rate']:.0%})")
367
- ax.set_xlabel("Episode"); ax.set_ylabel("Arrival (%)"); ax.set_title("Hospital Arrival Rate"); ax.set_ylim(0, 105); ax.legend(fontsize=8)
368
-
369
- ax = axes[2]
370
- ax.plot(iters, [v * 100 for v in history["signal_efficiency"]], alpha=0.25, color="seagreen")
371
- ax.plot(iters[sm:], smooth([v * 100 for v in history["signal_efficiency"]]), color="seagreen", lw=2)
372
- ax.axhline(baseline["mean_efficiency"] * 100, color="red", ls="--", lw=1.5, label=f"Before ({baseline['mean_efficiency']:.0%})")
373
- ax.axhline(final["mean_efficiency"] * 100, color="green", ls="--", lw=1.5, label=f"After ({final['mean_efficiency']:.0%})")
374
- ax.set_xlabel("Episode"); ax.set_ylabel("Efficiency (%)"); ax.set_title("Signal Efficiency"); ax.set_ylim(0, 105); ax.legend(fontsize=8)
375
-
376
- ax = axes[3]
377
- ax.plot(iters, history["mean_reroutes"], alpha=0.25, color="purple")
378
- ax.plot(iters[sm:], smooth(history["mean_reroutes"]), color="purple", lw=2)
379
- ax.axhline(baseline["mean_reroutes"], color="red", ls="--", lw=1.5, label=f"Before ({baseline['mean_reroutes']:.1f})")
380
- ax.axhline(final["mean_reroutes"], color="green", ls="--", lw=1.5, label=f"After ({final['mean_reroutes']:.1f})")
381
- ax.set_xlabel("Episode"); ax.set_ylabel("Reroutes"); ax.set_title("Adaptive Re-routing"); ax.legend(fontsize=8)
382
-
383
- plt.tight_layout()
384
- out = PLOT_DIR / "ambulance_training_results.png"
385
- plt.savefig(out, dpi=150, bbox_inches="tight")
386
- print(f"\nPlot saved → {out}")
387
-
388
- # Save results as JSON for the web UI
389
- import json as _json
390
- results = {"baseline": baseline, "final": final, "history": history}
391
- with open(PLOT_DIR / "results.json", "w") as f:
392
- _json.dump(results, f, indent=2)
393
- print(f"Results saved → {PLOT_DIR / 'results.json'}")
394
-
395
- # ── Cleanup ──────────────────────────────────────────────────────────────
396
- server_proc.terminate()
397
- print("\nTraining complete. Server stopped.")