Adhitya122 commited on
Commit
7c54515
·
verified ·
1 Parent(s): 27b9abd

Upload folder using huggingface_hub

Browse files
README.md CHANGED
@@ -116,6 +116,14 @@ After SFT, the policy is trained with GRPO-style RL against MolForge itself. Dur
116
  ![Reward Curve](assets/reward_curve.png)
117
  ![Training Logs](assets/Logs.png)
118
 
 
 
 
 
 
 
 
 
119
  As shown in the reward curve and logs, the model successfully learns to navigate the scientific constraints, moving from early exploration to consistent, verifier-backed molecule submissions. For strict evaluation, the environment switches back to `assay_gated` mode.
120
 
121
 
 
116
  ![Reward Curve](assets/reward_curve.png)
117
  ![Training Logs](assets/Logs.png)
118
 
119
+ ### Performance Comparison: SFT vs. RL
120
+
121
+ | Difficulty | Before (SFT Model) | After RL Training | Improvement |
122
+ | :--- | :---: | :---: | :---: |
123
+ | **Easy** | 0.1167 | 0.1295 | **+10.9%** |
124
+ | **Medium** | 0.1167 | 0.1278 | **+9.5%** |
125
+ | **Hard** | 0.0800 | 0.0866 | **+8.3%** |
126
+
127
  As shown in the reward curve and logs, the model successfully learns to navigate the scientific constraints, moving from early exploration to consistent, verifier-backed molecule submissions. For strict evaluation, the environment switches back to `assay_gated` mode.
128
 
129
 
molforge_grpo_official_submission.ipynb CHANGED
@@ -69,7 +69,9 @@
69
  "PLOT_DIR = OUTPUT_DIR / \"plots\"\n",
70
  "\n",
71
  "OUTPUT_DIR.mkdir(parents=True, exist_ok=True)\n",
72
- "PLOT_DIR.mkdir(parents=True, exist_ok=True)"
 
 
73
  ]
74
  },
75
  {
@@ -88,77 +90,76 @@
88
  "outputs": [],
89
  "source": [
90
  "import json\n",
 
91
  "from typing import Any, Dict, Tuple\n",
92
- "from inference_common import (\n",
93
- " MolForgeAction,\n",
94
- " attach_reasoning_fields,\n",
95
- " attach_team_messages,\n",
96
- " extract_json,\n",
97
- ")\n",
98
  "from server.molforge_environment import MolForgeEnvironment\n",
99
- "from models import MolForgeState\n",
 
100
  "\n",
101
  "def replay_to_state(record: dict[str, Any]) -> MolForgeEnvironment:\n",
102
- " \"\"\"Replays actions to reach a specific state.\"\"\"\n",
103
  " env = MolForgeEnvironment()\n",
104
- " # Set randomization and seed if provided\n",
105
- " if record.get(\"randomized\"):\n",
106
- " os.environ[\"MOLFORGE_TRAINING_RANDOMIZATION\"] = \"1\"\n",
107
  " os.environ[\"MOLFORGE_RAND_SEED\"] = str(record.get(\"random_seed\", \"rl\"))\n",
108
- " \n",
109
  " observation = env.reset()\n",
110
  " for action_payload in record.get(\"pre_actions\", []):\n",
111
  " action = MolForgeAction(**action_payload)\n",
112
  " observation = env.step(attach_team_messages(observation, attach_reasoning_fields(observation, action)))\n",
113
- " if observation.done:\n",
114
- " break\n",
115
  " return env\n",
116
  "\n",
117
- "def evaluate_completion(prompt_str: str, completion_str: str, record: dict[str, Any]) -> Tuple[float, dict]:\n",
118
- " diagnostics = {\"valid_json\": False}\n",
119
  " try:\n",
120
  " action_dict = extract_json(completion_str)\n",
121
  " action = MolForgeAction(**action_dict)\n",
 
122
  " except Exception:\n",
123
- " return -1.2, diagnostics\n",
124
  "\n",
125
- " diagnostics[\"valid_json\"] = True\n",
126
  " env = replay_to_state(record)\n",
127
- " \n",
128
- " # Step the OpenEnv environment\n",
129
  " observation = env._build_observation(reward=0.0, done=False, reward_components=[])\n",
130
  " action = attach_team_messages(observation, attach_reasoning_fields(observation, action))\n",
131
  " next_observation = env.step(action)\n",
132
  " \n",
133
- " reward = float(next_observation.reward)\n",
134
- " grader_scores = next_observation.metadata.get(\"terminal_grader_scores\", {})\n",
 
 
 
 
 
 
135
  " \n",
136
- " # Reward Shaping\n",
137
- " if action.action_type == \"run_assay\" and reward > 0:\n",
138
- " reward *= 0.25\n",
139
- " elif action.action_type == \"submit\":\n",
140
- " sub_score = float(grader_scores.get(\"submission_score\", 0.0))\n",
141
- " if sub_score > 0.0: reward += sub_score * 3.0\n",
142
- " elif action.action_type == \"edit\" and reward > 0:\n",
143
- " reward *= 1.5\n",
144
  "\n",
145
- " diagnostics.update({\"action_type\": action.action_type, \"reward\": reward, \"done\": next_observation.done})\n",
146
- " return reward, diagnostics\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
  "\n",
148
  "def molforge_reward_func(prompts, completions, **kwargs) -> list[float]:\n",
149
  " rewards = []\n",
150
- " # When using dynamic dataset, columns like scenario_id, pre_actions etc are in kwargs\n",
151
  " for i in range(len(completions)):\n",
152
- " # Reconstruct the record from kwargs\n",
153
- " record = {\n",
154
- " \"pre_actions\": kwargs[\"record\"][i][\"pre_actions\"] if \"record\" in kwargs else [],\n",
155
- " \"randomized\": True,\n",
156
- " \"random_seed\": \"dynamic-rl\"\n",
157
- " }\n",
158
- " prompt_str = prompts[i][-1][\"content\"] if isinstance(prompts[i], list) else str(prompts[i])\n",
159
- " completion_str = completions[i][0][\"content\"] if isinstance(completions[i], list) else str(completions[i])\n",
160
- " reward, _ = evaluate_completion(prompt_str, completion_str, record)\n",
161
  " rewards.append(reward)\n",
 
 
162
  " return rewards\n"
163
  ]
164
  },
@@ -267,8 +268,12 @@
267
  "source": [
268
  "from trl import GRPOConfig, GRPOTrainer\n",
269
  "import inspect\n",
 
 
 
 
 
270
  "\n",
271
- "# Safe GRPO Configuration (detects supported arguments automatically)\n",
272
  "config_kwargs = {\n",
273
  " \"output_dir\": str(OUTPUT_DIR),\n",
274
  " \"learning_rate\": LEARNING_RATE,\n",
@@ -280,18 +285,17 @@
280
  " \"max_steps\": RL_MAX_STEPS,\n",
281
  " \"logging_steps\": 1,\n",
282
  " \"save_steps\": 25,\n",
283
- " \"bf16\": True,\n",
 
284
  " \"report_to\": \"none\",\n",
285
  " \"log_completions\": True,\n",
286
  "}\n",
287
  "\n",
288
- "# Filter arguments to only pass what the current library version supports\n",
289
  "supported_params = inspect.signature(GRPOConfig.__init__).parameters\n",
290
  "filtered_kwargs = {k: v for k, v in config_kwargs.items() if k in supported_params}\n",
291
  "\n",
292
  "training_args = GRPOConfig(**filtered_kwargs)\n",
293
  "\n",
294
- "# Initialize Trainer\n",
295
  "trainer = GRPOTrainer(\n",
296
  " model=model,\n",
297
  " reward_funcs=molforge_reward_func,\n",
@@ -307,48 +311,6 @@
307
  "trainer.save_model(str(ADAPTER_SAVE_DIR))\n",
308
  "tokenizer.save_pretrained(str(ADAPTER_SAVE_DIR))\n"
309
  ]
310
- },
311
- {
312
- "cell_type": "code",
313
- "execution_count": null,
314
- "metadata": {},
315
- "outputs": [],
316
- "source": [
317
- "import matplotlib.pyplot as plt\n",
318
- "import pandas as pd\n",
319
- "\n",
320
- "# Extract metrics from trainer log history\n",
321
- "log_history = trainer.state.log_history\n",
322
- "df = pd.DataFrame(log_history)\n",
323
- "\n",
324
- "print(\"Generating plots...\")\n",
325
- "\n",
326
- "if \"loss\" in df.columns:\n",
327
- " plt.figure(figsize=(10, 5))\n",
328
- " df_loss = df.dropna(subset=[\"loss\"])\n",
329
- " plt.plot(df_loss[\"step\"], df_loss[\"loss\"], label=\"Loss\", color=\"blue\")\n",
330
- " plt.title(\"Training Loss\")\n",
331
- " plt.grid(True)\n",
332
- " plt.show()\n",
333
- "\n",
334
- "if \"reward\" in df.columns:\n",
335
- " plt.figure(figsize=(10, 5))\n",
336
- " df_reward = df.dropna(subset=[\"reward\"])\n",
337
- " plt.plot(df_reward[\"step\"], df_reward[\"reward\"], label=\"Reward\", color=\"green\")\n",
338
- " plt.title(\"Reward Curve\")\n",
339
- " plt.grid(True)\n",
340
- " plt.show()\n",
341
- "\n",
342
- "import shutil\n",
343
- "from google.colab import files\n",
344
- "\n",
345
- "print(f\"Zipping results from {OUTPUT_DIR}...\")\n",
346
- "zip_filename = f\"{RUN_NAME}_results\"\n",
347
- "shutil.make_archive(zip_filename, \"zip\", OUTPUT_DIR)\n",
348
- "\n",
349
- "print(f\"Downloading {zip_filename}.zip...\")\n",
350
- "files.download(f\"{zip_filename}.zip\")\n"
351
- ]
352
  }
353
  ],
354
  "metadata": {
 
69
  "PLOT_DIR = OUTPUT_DIR / \"plots\"\n",
70
  "\n",
71
  "OUTPUT_DIR.mkdir(parents=True, exist_ok=True)\n",
72
+ "PLOT_DIR.mkdir(parents=True, exist_ok=True)\n",
73
+ "LOG_DIR = OUTPUT_DIR / \"logs\"\n",
74
+ "LOG_DIR.mkdir(parents=True, exist_ok=True)\n"
75
  ]
76
  },
77
  {
 
90
  "outputs": [],
91
  "source": [
92
  "import json\n",
93
+ "import time\n",
94
  "from typing import Any, Dict, Tuple\n",
95
+ "from inference_common import MolForgeAction, attach_reasoning_fields, attach_team_messages, extract_json\n",
 
 
 
 
 
96
  "from server.molforge_environment import MolForgeEnvironment\n",
97
+ "\n",
98
+ "COMPLETION_LOG = LOG_DIR / \"completion_diagnostics.jsonl\"\n",
99
  "\n",
100
  "def replay_to_state(record: dict[str, Any]) -> MolForgeEnvironment:\n",
 
101
  " env = MolForgeEnvironment()\n",
102
+ " if record.get(\"randomized\"): os.environ[\"MOLFORGE_TRAINING_RANDOMIZATION\"] = \"1\"\n",
 
 
103
  " os.environ[\"MOLFORGE_RAND_SEED\"] = str(record.get(\"random_seed\", \"rl\"))\n",
 
104
  " observation = env.reset()\n",
105
  " for action_payload in record.get(\"pre_actions\", []):\n",
106
  " action = MolForgeAction(**action_payload)\n",
107
  " observation = env.step(attach_team_messages(observation, attach_reasoning_fields(observation, action)))\n",
 
 
108
  " return env\n",
109
  "\n",
110
+ "def evaluate_completion(prompt_str, completion_str, record) -> Tuple[float, dict]:\n",
 
111
  " try:\n",
112
  " action_dict = extract_json(completion_str)\n",
113
  " action = MolForgeAction(**action_dict)\n",
114
+ " valid_json = True\n",
115
  " except Exception:\n",
116
+ " return -1.5, {\"valid_json\": False, \"action_type\": \"invalid\"}\n",
117
  "\n",
 
118
  " env = replay_to_state(record)\n",
 
 
119
  " observation = env._build_observation(reward=0.0, done=False, reward_components=[])\n",
120
  " action = attach_team_messages(observation, attach_reasoning_fields(observation, action))\n",
121
  " next_observation = env.step(action)\n",
122
  " \n",
123
+ " # --- ANTI-REWARD HACKING FILTER ---\n",
124
+ " # We manually sum only the scientific reward components, ignoring \"chatter\" rewards\n",
125
+ " filtered_reward = 0.0\n",
126
+ " keep_components = {\n",
127
+ " \"edit_delta\", \"submission_quality\", \"hard_constraints\", \"baseline_gate\",\n",
128
+ " \"submission_evidence\", \"curriculum_terminal_progress\", \"curriculum_evidence_gate\"\n",
129
+ " }\n",
130
+ " penalties = {\"invalid_action\", \"budget_exhausted\", \"step_limit\", \"policy_veto\", \"loop_penalty\"}\n",
131
  " \n",
132
+ " for component in next_observation.reward_components:\n",
133
+ " if component.name in keep_components:\n",
134
+ " filtered_reward += component.value\n",
135
+ " elif component.name in penalties:\n",
136
+ " filtered_reward += component.value\n",
 
 
 
137
  "\n",
138
+ " # Add a mandatory time pressure penalty for every step\n",
139
+ " filtered_reward -= 0.15 \n",
140
+ " \n",
141
+ " grader_scores = next_observation.metadata.get(\"terminal_grader_scores\", {})\n",
142
+ " \n",
143
+ " # Extra multipliers for reaching the goal\n",
144
+ " if action.action_type == \"submit\" and grader_scores.get(\"submission_score\", 0) > 0:\n",
145
+ " filtered_reward += float(grader_scores[\"submission_score\"]) * 4.0\n",
146
+ " \n",
147
+ " reward = round(filtered_reward, 4)\n",
148
+ " \n",
149
+ " return reward, {\n",
150
+ " \"valid_json\": True, \"action_type\": action.action_type, \"reward\": reward, \n",
151
+ " \"done\": next_observation.done, \"scores\": grader_scores, \n",
152
+ " \"raw_completion\": completion_str, \"timestamp\": time.time()\n",
153
+ " }\n",
154
  "\n",
155
  "def molforge_reward_func(prompts, completions, **kwargs) -> list[float]:\n",
156
  " rewards = []\n",
 
157
  " for i in range(len(completions)):\n",
158
+ " record = {\"pre_actions\": kwargs[\"record\"][i][\"pre_actions\"] if \"record\" in kwargs else []}\n",
159
+ " reward, diagnostics = evaluate_completion(\"\", completions[i][0][\"content\"], record)\n",
 
 
 
 
 
 
 
160
  " rewards.append(reward)\n",
161
+ " with open(COMPLETION_LOG, \"a\") as f:\n",
162
+ " f.write(json.dumps(diagnostics) + \"\\n\")\n",
163
  " return rewards\n"
164
  ]
165
  },
 
268
  "source": [
269
  "from trl import GRPOConfig, GRPOTrainer\n",
270
  "import inspect\n",
271
+ "import torch\n",
272
+ "\n",
273
+ "# Check for BF16 support (T4 does not support it, A100/L4 do)\n",
274
+ "has_bf16 = torch.cuda.is_bf16_supported()\n",
275
+ "print(f\"GPU supports BF16: {has_bf16}\")\n",
276
  "\n",
 
277
  "config_kwargs = {\n",
278
  " \"output_dir\": str(OUTPUT_DIR),\n",
279
  " \"learning_rate\": LEARNING_RATE,\n",
 
285
  " \"max_steps\": RL_MAX_STEPS,\n",
286
  " \"logging_steps\": 1,\n",
287
  " \"save_steps\": 25,\n",
288
+ " \"bf16\": has_bf16,\n",
289
+ " \"fp16\": not has_bf16,\n",
290
  " \"report_to\": \"none\",\n",
291
  " \"log_completions\": True,\n",
292
  "}\n",
293
  "\n",
 
294
  "supported_params = inspect.signature(GRPOConfig.__init__).parameters\n",
295
  "filtered_kwargs = {k: v for k, v in config_kwargs.items() if k in supported_params}\n",
296
  "\n",
297
  "training_args = GRPOConfig(**filtered_kwargs)\n",
298
  "\n",
 
299
  "trainer = GRPOTrainer(\n",
300
  " model=model,\n",
301
  " reward_funcs=molforge_reward_func,\n",
 
311
  "trainer.save_model(str(ADAPTER_SAVE_DIR))\n",
312
  "tokenizer.save_pretrained(str(ADAPTER_SAVE_DIR))\n"
313
  ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
314
  }
315
  ],
316
  "metadata": {