Upload folder using huggingface_hub
Browse files
- README.md +8 -0
- molforge_grpo_official_submission.ipynb +51 -89
README.md
CHANGED
|
@@ -116,6 +116,14 @@ After SFT, the policy is trained with GRPO-style RL against MolForge itself. Dur
|
|
| 116 |

|
| 117 |

|
| 118 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
As shown in the reward curve and logs, the model successfully learns to navigate the scientific constraints, moving from early exploration to consistent, verifier-backed molecule submissions. For strict evaluation, the environment switches back to `assay_gated` mode.
|
| 120 |
|
| 121 |
|
|
|
|
| 116 |

|
| 117 |

|
| 118 |
|
| 119 |
+
### Performance Comparison: SFT vs. RL
|
| 120 |
+
|
| 121 |
+
| Difficulty | Before (SFT Model) | After RL Training | Improvement |
|
| 122 |
+
| :--- | :---: | :---: | :---: |
|
| 123 |
+
| **Easy** | 0.1167 | 0.1295 | **+10.9%** |
|
| 124 |
+
| **Medium** | 0.1167 | 0.1278 | **+9.5%** |
|
| 125 |
+
| **Hard** | 0.0800 | 0.0866 | **+8.3%** |
|
| 126 |
+
|
| 127 |
As shown in the reward curve and logs, the model successfully learns to navigate the scientific constraints, moving from early exploration to consistent, verifier-backed molecule submissions. For strict evaluation, the environment switches back to `assay_gated` mode.
|
| 128 |
|
| 129 |
|
molforge_grpo_official_submission.ipynb
CHANGED
|
@@ -69,7 +69,9 @@
|
|
| 69 |
"PLOT_DIR = OUTPUT_DIR / \"plots\"\n",
|
| 70 |
"\n",
|
| 71 |
"OUTPUT_DIR.mkdir(parents=True, exist_ok=True)\n",
|
| 72 |
-
"PLOT_DIR.mkdir(parents=True, exist_ok=True)"
|
|
|
|
|
|
|
| 73 |
]
|
| 74 |
},
|
| 75 |
{
|
|
@@ -88,77 +90,76 @@
|
|
| 88 |
"outputs": [],
|
| 89 |
"source": [
|
| 90 |
"import json\n",
|
|
|
|
| 91 |
"from typing import Any, Dict, Tuple\n",
|
| 92 |
-
"from inference_common import
|
| 93 |
-
" MolForgeAction,\n",
|
| 94 |
-
" attach_reasoning_fields,\n",
|
| 95 |
-
" attach_team_messages,\n",
|
| 96 |
-
" extract_json,\n",
|
| 97 |
-
")\n",
|
| 98 |
"from server.molforge_environment import MolForgeEnvironment\n",
|
| 99 |
-
"
|
|
|
|
| 100 |
"\n",
|
| 101 |
"def replay_to_state(record: dict[str, Any]) -> MolForgeEnvironment:\n",
|
| 102 |
-
" \"\"\"Replays actions to reach a specific state.\"\"\"\n",
|
| 103 |
" env = MolForgeEnvironment()\n",
|
| 104 |
-
"
|
| 105 |
-
" if record.get(\"randomized\"):\n",
|
| 106 |
-
" os.environ[\"MOLFORGE_TRAINING_RANDOMIZATION\"] = \"1\"\n",
|
| 107 |
" os.environ[\"MOLFORGE_RAND_SEED\"] = str(record.get(\"random_seed\", \"rl\"))\n",
|
| 108 |
-
" \n",
|
| 109 |
" observation = env.reset()\n",
|
| 110 |
" for action_payload in record.get(\"pre_actions\", []):\n",
|
| 111 |
" action = MolForgeAction(**action_payload)\n",
|
| 112 |
" observation = env.step(attach_team_messages(observation, attach_reasoning_fields(observation, action)))\n",
|
| 113 |
-
" if observation.done:\n",
|
| 114 |
-
" break\n",
|
| 115 |
" return env\n",
|
| 116 |
"\n",
|
| 117 |
-
"def evaluate_completion(prompt_str
|
| 118 |
-
" diagnostics = {\"valid_json\": False}\n",
|
| 119 |
" try:\n",
|
| 120 |
" action_dict = extract_json(completion_str)\n",
|
| 121 |
" action = MolForgeAction(**action_dict)\n",
|
|
|
|
| 122 |
" except Exception:\n",
|
| 123 |
-
" return -1.
|
| 124 |
"\n",
|
| 125 |
-
" diagnostics[\"valid_json\"] = True\n",
|
| 126 |
" env = replay_to_state(record)\n",
|
| 127 |
-
" \n",
|
| 128 |
-
" # Step the OpenEnv environment\n",
|
| 129 |
" observation = env._build_observation(reward=0.0, done=False, reward_components=[])\n",
|
| 130 |
" action = attach_team_messages(observation, attach_reasoning_fields(observation, action))\n",
|
| 131 |
" next_observation = env.step(action)\n",
|
| 132 |
" \n",
|
| 133 |
-
"
|
| 134 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 135 |
" \n",
|
| 136 |
-
"
|
| 137 |
-
"
|
| 138 |
-
"
|
| 139 |
-
"
|
| 140 |
-
"
|
| 141 |
-
" if sub_score > 0.0: reward += sub_score * 3.0\n",
|
| 142 |
-
" elif action.action_type == \"edit\" and reward > 0:\n",
|
| 143 |
-
" reward *= 1.5\n",
|
| 144 |
"\n",
|
| 145 |
-
"
|
| 146 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
"\n",
|
| 148 |
"def molforge_reward_func(prompts, completions, **kwargs) -> list[float]:\n",
|
| 149 |
" rewards = []\n",
|
| 150 |
-
" # When using dynamic dataset, columns like scenario_id, pre_actions etc are in kwargs\n",
|
| 151 |
" for i in range(len(completions)):\n",
|
| 152 |
-
"
|
| 153 |
-
"
|
| 154 |
-
" \"pre_actions\": kwargs[\"record\"][i][\"pre_actions\"] if \"record\" in kwargs else [],\n",
|
| 155 |
-
" \"randomized\": True,\n",
|
| 156 |
-
" \"random_seed\": \"dynamic-rl\"\n",
|
| 157 |
-
" }\n",
|
| 158 |
-
" prompt_str = prompts[i][-1][\"content\"] if isinstance(prompts[i], list) else str(prompts[i])\n",
|
| 159 |
-
" completion_str = completions[i][0][\"content\"] if isinstance(completions[i], list) else str(completions[i])\n",
|
| 160 |
-
" reward, _ = evaluate_completion(prompt_str, completion_str, record)\n",
|
| 161 |
" rewards.append(reward)\n",
|
|
|
|
|
|
|
| 162 |
" return rewards\n"
|
| 163 |
]
|
| 164 |
},
|
|
@@ -267,8 +268,12 @@
|
|
| 267 |
"source": [
|
| 268 |
"from trl import GRPOConfig, GRPOTrainer\n",
|
| 269 |
"import inspect\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 270 |
"\n",
|
| 271 |
-
"# Safe GRPO Configuration (detects supported arguments automatically)\n",
|
| 272 |
"config_kwargs = {\n",
|
| 273 |
" \"output_dir\": str(OUTPUT_DIR),\n",
|
| 274 |
" \"learning_rate\": LEARNING_RATE,\n",
|
|
@@ -280,18 +285,17 @@
|
|
| 280 |
" \"max_steps\": RL_MAX_STEPS,\n",
|
| 281 |
" \"logging_steps\": 1,\n",
|
| 282 |
" \"save_steps\": 25,\n",
|
| 283 |
-
" \"bf16\":
|
|
|
|
| 284 |
" \"report_to\": \"none\",\n",
|
| 285 |
" \"log_completions\": True,\n",
|
| 286 |
"}\n",
|
| 287 |
"\n",
|
| 288 |
-
"# Filter arguments to only pass what the current library version supports\n",
|
| 289 |
"supported_params = inspect.signature(GRPOConfig.__init__).parameters\n",
|
| 290 |
"filtered_kwargs = {k: v for k, v in config_kwargs.items() if k in supported_params}\n",
|
| 291 |
"\n",
|
| 292 |
"training_args = GRPOConfig(**filtered_kwargs)\n",
|
| 293 |
"\n",
|
| 294 |
-
"# Initialize Trainer\n",
|
| 295 |
"trainer = GRPOTrainer(\n",
|
| 296 |
" model=model,\n",
|
| 297 |
" reward_funcs=molforge_reward_func,\n",
|
|
@@ -307,48 +311,6 @@
|
|
| 307 |
"trainer.save_model(str(ADAPTER_SAVE_DIR))\n",
|
| 308 |
"tokenizer.save_pretrained(str(ADAPTER_SAVE_DIR))\n"
|
| 309 |
]
|
| 310 |
-
},
|
| 311 |
-
{
|
| 312 |
-
"cell_type": "code",
|
| 313 |
-
"execution_count": null,
|
| 314 |
-
"metadata": {},
|
| 315 |
-
"outputs": [],
|
| 316 |
-
"source": [
|
| 317 |
-
"import matplotlib.pyplot as plt\n",
|
| 318 |
-
"import pandas as pd\n",
|
| 319 |
-
"\n",
|
| 320 |
-
"# Extract metrics from trainer log history\n",
|
| 321 |
-
"log_history = trainer.state.log_history\n",
|
| 322 |
-
"df = pd.DataFrame(log_history)\n",
|
| 323 |
-
"\n",
|
| 324 |
-
"print(\"Generating plots...\")\n",
|
| 325 |
-
"\n",
|
| 326 |
-
"if \"loss\" in df.columns:\n",
|
| 327 |
-
" plt.figure(figsize=(10, 5))\n",
|
| 328 |
-
" df_loss = df.dropna(subset=[\"loss\"])\n",
|
| 329 |
-
" plt.plot(df_loss[\"step\"], df_loss[\"loss\"], label=\"Loss\", color=\"blue\")\n",
|
| 330 |
-
" plt.title(\"Training Loss\")\n",
|
| 331 |
-
" plt.grid(True)\n",
|
| 332 |
-
" plt.show()\n",
|
| 333 |
-
"\n",
|
| 334 |
-
"if \"reward\" in df.columns:\n",
|
| 335 |
-
" plt.figure(figsize=(10, 5))\n",
|
| 336 |
-
" df_reward = df.dropna(subset=[\"reward\"])\n",
|
| 337 |
-
" plt.plot(df_reward[\"step\"], df_reward[\"reward\"], label=\"Reward\", color=\"green\")\n",
|
| 338 |
-
" plt.title(\"Reward Curve\")\n",
|
| 339 |
-
" plt.grid(True)\n",
|
| 340 |
-
" plt.show()\n",
|
| 341 |
-
"\n",
|
| 342 |
-
"import shutil\n",
|
| 343 |
-
"from google.colab import files\n",
|
| 344 |
-
"\n",
|
| 345 |
-
"print(f\"Zipping results from {OUTPUT_DIR}...\")\n",
|
| 346 |
-
"zip_filename = f\"{RUN_NAME}_results\"\n",
|
| 347 |
-
"shutil.make_archive(zip_filename, \"zip\", OUTPUT_DIR)\n",
|
| 348 |
-
"\n",
|
| 349 |
-
"print(f\"Downloading {zip_filename}.zip...\")\n",
|
| 350 |
-
"files.download(f\"{zip_filename}.zip\")\n"
|
| 351 |
-
]
|
| 352 |
}
|
| 353 |
],
|
| 354 |
"metadata": {
|
|
|
|
| 69 |
"PLOT_DIR = OUTPUT_DIR / \"plots\"\n",
|
| 70 |
"\n",
|
| 71 |
"OUTPUT_DIR.mkdir(parents=True, exist_ok=True)\n",
|
| 72 |
+
"PLOT_DIR.mkdir(parents=True, exist_ok=True)\n",
|
| 73 |
+
"LOG_DIR = OUTPUT_DIR / \"logs\"\n",
|
| 74 |
+
"LOG_DIR.mkdir(parents=True, exist_ok=True)\n"
|
| 75 |
]
|
| 76 |
},
|
| 77 |
{
|
|
|
|
| 90 |
"outputs": [],
|
| 91 |
"source": [
|
| 92 |
"import json\n",
|
| 93 |
+
"import time\n",
|
| 94 |
"from typing import Any, Dict, Tuple\n",
|
| 95 |
+
"from inference_common import MolForgeAction, attach_reasoning_fields, attach_team_messages, extract_json\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
"from server.molforge_environment import MolForgeEnvironment\n",
|
| 97 |
+
"\n",
|
| 98 |
+
"COMPLETION_LOG = LOG_DIR / \"completion_diagnostics.jsonl\"\n",
|
| 99 |
"\n",
|
| 100 |
"def replay_to_state(record: dict[str, Any]) -> MolForgeEnvironment:\n",
|
|
|
|
| 101 |
" env = MolForgeEnvironment()\n",
|
| 102 |
+
" if record.get(\"randomized\"): os.environ[\"MOLFORGE_TRAINING_RANDOMIZATION\"] = \"1\"\n",
|
|
|
|
|
|
|
| 103 |
" os.environ[\"MOLFORGE_RAND_SEED\"] = str(record.get(\"random_seed\", \"rl\"))\n",
|
|
|
|
| 104 |
" observation = env.reset()\n",
|
| 105 |
" for action_payload in record.get(\"pre_actions\", []):\n",
|
| 106 |
" action = MolForgeAction(**action_payload)\n",
|
| 107 |
" observation = env.step(attach_team_messages(observation, attach_reasoning_fields(observation, action)))\n",
|
|
|
|
|
|
|
| 108 |
" return env\n",
|
| 109 |
"\n",
|
| 110 |
+
"def evaluate_completion(prompt_str, completion_str, record) -> Tuple[float, dict]:\n",
|
|
|
|
| 111 |
" try:\n",
|
| 112 |
" action_dict = extract_json(completion_str)\n",
|
| 113 |
" action = MolForgeAction(**action_dict)\n",
|
| 114 |
+
" valid_json = True\n",
|
| 115 |
" except Exception:\n",
|
| 116 |
+
" return -1.5, {\"valid_json\": False, \"action_type\": \"invalid\"}\n",
|
| 117 |
"\n",
|
|
|
|
| 118 |
" env = replay_to_state(record)\n",
|
|
|
|
|
|
|
| 119 |
" observation = env._build_observation(reward=0.0, done=False, reward_components=[])\n",
|
| 120 |
" action = attach_team_messages(observation, attach_reasoning_fields(observation, action))\n",
|
| 121 |
" next_observation = env.step(action)\n",
|
| 122 |
" \n",
|
| 123 |
+
" # --- ANTI-REWARD HACKING FILTER ---\n",
|
| 124 |
+
" # We manually sum only the scientific reward components, ignoring \"chatter\" rewards\n",
|
| 125 |
+
" filtered_reward = 0.0\n",
|
| 126 |
+
" keep_components = {\n",
|
| 127 |
+
" \"edit_delta\", \"submission_quality\", \"hard_constraints\", \"baseline_gate\",\n",
|
| 128 |
+
" \"submission_evidence\", \"curriculum_terminal_progress\", \"curriculum_evidence_gate\"\n",
|
| 129 |
+
" }\n",
|
| 130 |
+
" penalties = {\"invalid_action\", \"budget_exhausted\", \"step_limit\", \"policy_veto\", \"loop_penalty\"}\n",
|
| 131 |
" \n",
|
| 132 |
+
" for component in next_observation.reward_components:\n",
|
| 133 |
+
" if component.name in keep_components:\n",
|
| 134 |
+
" filtered_reward += component.value\n",
|
| 135 |
+
" elif component.name in penalties:\n",
|
| 136 |
+
" filtered_reward += component.value\n",
|
|
|
|
|
|
|
|
|
|
| 137 |
"\n",
|
| 138 |
+
" # Add a mandatory time pressure penalty for every step\n",
|
| 139 |
+
" filtered_reward -= 0.15 \n",
|
| 140 |
+
" \n",
|
| 141 |
+
" grader_scores = next_observation.metadata.get(\"terminal_grader_scores\", {})\n",
|
| 142 |
+
" \n",
|
| 143 |
+
" # Extra multipliers for reaching the goal\n",
|
| 144 |
+
" if action.action_type == \"submit\" and grader_scores.get(\"submission_score\", 0) > 0:\n",
|
| 145 |
+
" filtered_reward += float(grader_scores[\"submission_score\"]) * 4.0\n",
|
| 146 |
+
" \n",
|
| 147 |
+
" reward = round(filtered_reward, 4)\n",
|
| 148 |
+
" \n",
|
| 149 |
+
" return reward, {\n",
|
| 150 |
+
" \"valid_json\": True, \"action_type\": action.action_type, \"reward\": reward, \n",
|
| 151 |
+
" \"done\": next_observation.done, \"scores\": grader_scores, \n",
|
| 152 |
+
" \"raw_completion\": completion_str, \"timestamp\": time.time()\n",
|
| 153 |
+
" }\n",
|
| 154 |
"\n",
|
| 155 |
"def molforge_reward_func(prompts, completions, **kwargs) -> list[float]:\n",
|
| 156 |
" rewards = []\n",
|
|
|
|
| 157 |
" for i in range(len(completions)):\n",
|
| 158 |
+
" record = {\"pre_actions\": kwargs[\"record\"][i][\"pre_actions\"] if \"record\" in kwargs else []}\n",
|
| 159 |
+
" reward, diagnostics = evaluate_completion(\"\", completions[i][0][\"content\"], record)\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 160 |
" rewards.append(reward)\n",
|
| 161 |
+
" with open(COMPLETION_LOG, \"a\") as f:\n",
|
| 162 |
+
" f.write(json.dumps(diagnostics) + \"\\n\")\n",
|
| 163 |
" return rewards\n"
|
| 164 |
]
|
| 165 |
},
|
|
|
|
| 268 |
"source": [
|
| 269 |
"from trl import GRPOConfig, GRPOTrainer\n",
|
| 270 |
"import inspect\n",
|
| 271 |
+
"import torch\n",
|
| 272 |
+
"\n",
|
| 273 |
+
"# Check for BF16 support (T4 does not support it, A100/L4 do)\n",
|
| 274 |
+
"has_bf16 = torch.cuda.is_bf16_supported()\n",
|
| 275 |
+
"print(f\"GPU supports BF16: {has_bf16}\")\n",
|
| 276 |
"\n",
|
|
|
|
| 277 |
"config_kwargs = {\n",
|
| 278 |
" \"output_dir\": str(OUTPUT_DIR),\n",
|
| 279 |
" \"learning_rate\": LEARNING_RATE,\n",
|
|
|
|
| 285 |
" \"max_steps\": RL_MAX_STEPS,\n",
|
| 286 |
" \"logging_steps\": 1,\n",
|
| 287 |
" \"save_steps\": 25,\n",
|
| 288 |
+
" \"bf16\": has_bf16,\n",
|
| 289 |
+
" \"fp16\": not has_bf16,\n",
|
| 290 |
" \"report_to\": \"none\",\n",
|
| 291 |
" \"log_completions\": True,\n",
|
| 292 |
"}\n",
|
| 293 |
"\n",
|
|
|
|
| 294 |
"supported_params = inspect.signature(GRPOConfig.__init__).parameters\n",
|
| 295 |
"filtered_kwargs = {k: v for k, v in config_kwargs.items() if k in supported_params}\n",
|
| 296 |
"\n",
|
| 297 |
"training_args = GRPOConfig(**filtered_kwargs)\n",
|
| 298 |
"\n",
|
|
|
|
| 299 |
"trainer = GRPOTrainer(\n",
|
| 300 |
" model=model,\n",
|
| 301 |
" reward_funcs=molforge_reward_func,\n",
|
|
|
|
| 311 |
"trainer.save_model(str(ADAPTER_SAVE_DIR))\n",
|
| 312 |
"tokenizer.save_pretrained(str(ADAPTER_SAVE_DIR))\n"
|
| 313 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 314 |
}
|
| 315 |
],
|
| 316 |
"metadata": {
|