zhangzf01 committed on
Commit ·
7ffd03d
1
Parent(s): 777420e
accelerate env
Browse files
agent_system/multi_turn_rollout/rollout_loop.py
CHANGED
|
@@ -328,11 +328,17 @@ class TrajectoryCollector:
|
|
| 328 |
episode_lengths = np.zeros(batch_size, dtype=np.float32)
|
| 329 |
episode_rewards = np.zeros(batch_size, dtype=np.float32)
|
| 330 |
tool_callings = np.zeros(batch_size, dtype=np.float32)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 331 |
# Trajectory collection loop
|
| 332 |
for _step in range(self.config.env.max_steps):
|
| 333 |
active_masks = np.logical_not(is_done)
|
| 334 |
|
|
|
|
| 335 |
batch = self.preprocess_batch(gen_batch=gen_batch, obs=obs)
|
|
|
|
| 336 |
|
| 337 |
batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"]
|
| 338 |
non_tensor_batch_keys_to_pop = ["raw_prompt_ids"]
|
|
@@ -351,7 +357,9 @@ class TrajectoryCollector:
|
|
| 351 |
|
| 352 |
# pad to be divisible by dp_size
|
| 353 |
batch_input_padded, pad_size = pad_dataproto_to_divisor(batch_input, actor_rollout_wg.world_size)
|
|
|
|
| 354 |
batch_output_padded = actor_rollout_wg.generate_sequences(batch_input_padded)
|
|
|
|
| 355 |
# # unpad
|
| 356 |
batch_output = unpad_dataproto(batch_output_padded, pad_size=pad_size)
|
| 357 |
|
|
@@ -359,10 +367,12 @@ class TrajectoryCollector:
|
|
| 359 |
batch.non_tensor_batch['traj_uid'] = traj_uid
|
| 360 |
|
| 361 |
batch = batch.union(batch_output)
|
| 362 |
-
|
| 363 |
text_actions = self.tokenizer.batch_decode(batch.batch['responses'], skip_special_tokens=True)
|
| 364 |
-
|
|
|
|
| 365 |
next_obs, rewards, dones, infos = envs.step(text_actions)
|
|
|
|
| 366 |
|
| 367 |
|
| 368 |
if len(rewards.shape) == 2:
|
|
@@ -411,8 +421,9 @@ class TrajectoryCollector:
|
|
| 411 |
episode_lengths=episode_lengths,
|
| 412 |
)
|
| 413 |
|
| 414 |
-
|
| 415 |
-
|
|
|
|
| 416 |
def dynamic_multi_turn_loop(
|
| 417 |
self,
|
| 418 |
gen_batch: DataProto,
|
|
@@ -451,7 +462,7 @@ class TrajectoryCollector:
|
|
| 451 |
print(f"valid num={len(total_batch_list)} < target num={self.config.data.train_batch_size * self.config.env.rollout.n}. Keep generating... ({try_count}/{max_try_count})")
|
| 452 |
try_count += 1
|
| 453 |
|
| 454 |
-
batch_list, episode_rewards, episode_lengths, success, traj_uid, tool_callings = self.vanilla_multi_turn_loop(
|
| 455 |
gen_batch=gen_batch,
|
| 456 |
actor_rollout_wg=actor_rollout_wg,
|
| 457 |
envs=envs,
|
|
@@ -479,7 +490,7 @@ class TrajectoryCollector:
|
|
| 479 |
total_traj_uid = np.concatenate(total_traj_uid, axis=0)
|
| 480 |
total_tool_callings = np.concatenate(total_tool_callings, axis=0)
|
| 481 |
|
| 482 |
-
return total_batch_list, total_episode_rewards, total_episode_lengths, total_success, total_traj_uid, total_tool_callings
|
| 483 |
|
| 484 |
def multi_turn_loop(
|
| 485 |
self,
|
|
@@ -506,15 +517,15 @@ class TrajectoryCollector:
|
|
| 506 |
# Initial observations from the environment
|
| 507 |
if self.config.algorithm.filter_groups.enable and is_train:
|
| 508 |
# Dynamic Sampling (for DAPO and Dynamic GiGPO)
|
| 509 |
-
total_batch_list, total_episode_rewards, total_episode_lengths, total_success, total_traj_uid, totoal_tool_callings = \
|
| 510 |
self.dynamic_multi_turn_loop(
|
| 511 |
gen_batch=gen_batch,
|
| 512 |
actor_rollout_wg=actor_rollout_wg,
|
| 513 |
envs=envs,
|
| 514 |
)
|
| 515 |
else:
|
| 516 |
-
# Vanilla Sampling
|
| 517 |
-
total_batch_list, total_episode_rewards, total_episode_lengths, total_success, total_traj_uid, totoal_tool_callings = \
|
| 518 |
self.vanilla_multi_turn_loop(
|
| 519 |
gen_batch=gen_batch,
|
| 520 |
actor_rollout_wg=actor_rollout_wg,
|
|
@@ -524,7 +535,7 @@ class TrajectoryCollector:
|
|
| 524 |
assert len(total_batch_list) == len(total_episode_lengths)
|
| 525 |
assert len(total_batch_list) == len(total_traj_uid)
|
| 526 |
assert len(total_batch_list) == len(totoal_tool_callings)
|
| 527 |
-
|
| 528 |
|
| 529 |
# Create trajectory data
|
| 530 |
gen_batch_output: DataProto = self.gather_rollout_data(
|
|
@@ -535,5 +546,9 @@ class TrajectoryCollector:
|
|
| 535 |
traj_uid=total_traj_uid,
|
| 536 |
tool_callings=totoal_tool_callings,
|
| 537 |
)
|
| 538 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 539 |
return gen_batch_output
|
|
|
|
| 328 |
episode_lengths = np.zeros(batch_size, dtype=np.float32)
|
| 329 |
episode_rewards = np.zeros(batch_size, dtype=np.float32)
|
| 330 |
tool_callings = np.zeros(batch_size, dtype=np.float32)
|
| 331 |
+
import time as _time
|
| 332 |
+
_total_preprocess_time = 0.0
|
| 333 |
+
_total_infer_time = 0.0
|
| 334 |
+
_total_env_time = 0.0
|
| 335 |
# Trajectory collection loop
|
| 336 |
for _step in range(self.config.env.max_steps):
|
| 337 |
active_masks = np.logical_not(is_done)
|
| 338 |
|
| 339 |
+
_t0 = _time.time()
|
| 340 |
batch = self.preprocess_batch(gen_batch=gen_batch, obs=obs)
|
| 341 |
+
_total_preprocess_time += _time.time() - _t0
|
| 342 |
|
| 343 |
batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"]
|
| 344 |
non_tensor_batch_keys_to_pop = ["raw_prompt_ids"]
|
|
|
|
| 357 |
|
| 358 |
# pad to be divisible by dp_size
|
| 359 |
batch_input_padded, pad_size = pad_dataproto_to_divisor(batch_input, actor_rollout_wg.world_size)
|
| 360 |
+
_t0 = _time.time()
|
| 361 |
batch_output_padded = actor_rollout_wg.generate_sequences(batch_input_padded)
|
| 362 |
+
_total_infer_time += _time.time() - _t0
|
| 363 |
# # unpad
|
| 364 |
batch_output = unpad_dataproto(batch_output_padded, pad_size=pad_size)
|
| 365 |
|
|
|
|
| 367 |
batch.non_tensor_batch['traj_uid'] = traj_uid
|
| 368 |
|
| 369 |
batch = batch.union(batch_output)
|
| 370 |
+
|
| 371 |
text_actions = self.tokenizer.batch_decode(batch.batch['responses'], skip_special_tokens=True)
|
| 372 |
+
|
| 373 |
+
_t0 = _time.time()
|
| 374 |
next_obs, rewards, dones, infos = envs.step(text_actions)
|
| 375 |
+
_total_env_time += _time.time() - _t0
|
| 376 |
|
| 377 |
|
| 378 |
if len(rewards.shape) == 2:
|
|
|
|
| 421 |
episode_lengths=episode_lengths,
|
| 422 |
)
|
| 423 |
|
| 424 |
+
rollout_timing = {"inference_s": _total_infer_time, "env_s": _total_env_time, "preprocess_s": _total_preprocess_time}
|
| 425 |
+
return total_batch_list, episode_rewards, episode_lengths, success, traj_uid, tool_callings, rollout_timing
|
| 426 |
+
|
| 427 |
def dynamic_multi_turn_loop(
|
| 428 |
self,
|
| 429 |
gen_batch: DataProto,
|
|
|
|
| 462 |
print(f"valid num={len(total_batch_list)} < target num={self.config.data.train_batch_size * self.config.env.rollout.n}. Keep generating... ({try_count}/{max_try_count})")
|
| 463 |
try_count += 1
|
| 464 |
|
| 465 |
+
batch_list, episode_rewards, episode_lengths, success, traj_uid, tool_callings, _ = self.vanilla_multi_turn_loop(
|
| 466 |
gen_batch=gen_batch,
|
| 467 |
actor_rollout_wg=actor_rollout_wg,
|
| 468 |
envs=envs,
|
|
|
|
| 490 |
total_traj_uid = np.concatenate(total_traj_uid, axis=0)
|
| 491 |
total_tool_callings = np.concatenate(total_tool_callings, axis=0)
|
| 492 |
|
| 493 |
+
return total_batch_list, total_episode_rewards, total_episode_lengths, total_success, total_traj_uid, total_tool_callings, {}
|
| 494 |
|
| 495 |
def multi_turn_loop(
|
| 496 |
self,
|
|
|
|
| 517 |
# Initial observations from the environment
|
| 518 |
if self.config.algorithm.filter_groups.enable and is_train:
|
| 519 |
# Dynamic Sampling (for DAPO and Dynamic GiGPO)
|
| 520 |
+
total_batch_list, total_episode_rewards, total_episode_lengths, total_success, total_traj_uid, totoal_tool_callings, rollout_timing = \
|
| 521 |
self.dynamic_multi_turn_loop(
|
| 522 |
gen_batch=gen_batch,
|
| 523 |
actor_rollout_wg=actor_rollout_wg,
|
| 524 |
envs=envs,
|
| 525 |
)
|
| 526 |
else:
|
| 527 |
+
# Vanilla Sampling
|
| 528 |
+
total_batch_list, total_episode_rewards, total_episode_lengths, total_success, total_traj_uid, totoal_tool_callings, rollout_timing = \
|
| 529 |
self.vanilla_multi_turn_loop(
|
| 530 |
gen_batch=gen_batch,
|
| 531 |
actor_rollout_wg=actor_rollout_wg,
|
|
|
|
| 535 |
assert len(total_batch_list) == len(total_episode_lengths)
|
| 536 |
assert len(total_batch_list) == len(total_traj_uid)
|
| 537 |
assert len(total_batch_list) == len(totoal_tool_callings)
|
| 538 |
+
|
| 539 |
|
| 540 |
# Create trajectory data
|
| 541 |
gen_batch_output: DataProto = self.gather_rollout_data(
|
|
|
|
| 546 |
traj_uid=total_traj_uid,
|
| 547 |
tool_callings=totoal_tool_callings,
|
| 548 |
)
|
| 549 |
+
|
| 550 |
+
if gen_batch_output.meta_info is None:
|
| 551 |
+
gen_batch_output.meta_info = {}
|
| 552 |
+
gen_batch_output.meta_info['rollout_timing'] = rollout_timing
|
| 553 |
+
|
| 554 |
return gen_batch_output
|
poisonclaw/envs/browsergym_env.py
CHANGED
|
@@ -6,6 +6,9 @@ enabling GRPO/GiGPO/PPO training on web environments:
|
|
| 6 |
- VisualWebArena: requires VWA servers (Docker or remote)
|
| 7 |
- WebArena: requires WebArena servers
|
| 8 |
|
|
|
|
|
|
|
|
|
|
| 9 |
Config fields (under env.*):
|
| 10 |
env_name routing key, must contain "browsergym" (e.g. "browsergym-miniwob")
|
| 11 |
gym_id single BrowserGym task ID (mutually exclusive with task_list)
|
|
@@ -16,6 +19,7 @@ Config fields (under env.*):
|
|
| 16 |
seed base random seed (default: 42)
|
| 17 |
viewport_width screenshot width (default: 1280)
|
| 18 |
viewport_height screenshot height (default: 720)
|
|
|
|
| 19 |
"""
|
| 20 |
|
| 21 |
from __future__ import annotations
|
|
@@ -26,6 +30,7 @@ from collections import defaultdict
|
|
| 26 |
from typing import Optional
|
| 27 |
|
| 28 |
import numpy as np
|
|
|
|
| 29 |
|
| 30 |
from agent_system.environments.base import EnvironmentManagerBase
|
| 31 |
|
|
@@ -33,20 +38,74 @@ logger = logging.getLogger(__name__)
|
|
| 33 |
|
| 34 |
# ── Action regex patterns (coordinate-based, matching VLM output format) ─────
|
| 35 |
_RE_ACTION_TAG = re.compile(r"<action>(.*?)</action>", re.DOTALL | re.IGNORECASE)
|
| 36 |
-
_RE_CLICK = re.compile(r"click\(\s*(\d+(?:\.\d+)?)\s*,\s*(\d+(?:\.\d+)?)\s*\)")
|
| 37 |
-
_RE_TYPE = re.compile(r"type\((.+?)\)", re.DOTALL)
|
| 38 |
-
_RE_PRESS = re.compile(r"press\((.+?)\)")
|
| 39 |
-
_RE_NAVIGATE = re.compile(r"(?:navigate|goto)\((.+?)\)", re.IGNORECASE)
|
| 40 |
-
_RE_SCROLL = re.compile(r"scroll\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(.+?)\s*\)")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
|
| 42 |
|
| 43 |
class BrowserGymEnvManager(EnvironmentManagerBase):
|
| 44 |
"""Wraps BrowserGym gym environments for verl-agent training.
|
| 45 |
|
| 46 |
-
Each parallel slot
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
assigned round-robin to the parallel slots.
|
| 50 |
"""
|
| 51 |
|
| 52 |
def __init__(self, config, split: str = "train") -> None:
|
|
@@ -71,26 +130,25 @@ class BrowserGymEnvManager(EnvironmentManagerBase):
|
|
| 71 |
split, self.num_envs, self.task_ids,
|
| 72 |
)
|
| 73 |
|
| 74 |
-
#
|
| 75 |
-
self._import_browsergym_namespaces()
|
| 76 |
-
|
| 77 |
-
# Build coordinate-based action mapping (VLM outputs pixel coords,
|
| 78 |
-
# not element bids). Include 'coord' for mouse_click/keyboard_*,
|
| 79 |
-
# 'nav' for goto/go_back, and 'chat' for send_msg_to_user.
|
| 80 |
from browsergym.core.action.highlevel import HighLevelActionSet
|
| 81 |
self._action_set = HighLevelActionSet(subsets=["coord", "nav"])
|
| 82 |
|
| 83 |
-
# Create one
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
|
|
|
|
|
|
|
|
|
| 88 |
action_mapping=self._action_set.to_python_code,
|
|
|
|
| 89 |
)
|
| 90 |
for i in range(self.num_envs)
|
| 91 |
]
|
| 92 |
|
| 93 |
-
# Per-env runtime state
|
| 94 |
self._last_obs: list[Optional[dict]] = [None] * self.num_envs
|
| 95 |
self._steps: list[int] = [0] * self.num_envs
|
| 96 |
self._done: list[bool] = [True] * self.num_envs
|
|
@@ -106,28 +164,87 @@ class BrowserGymEnvManager(EnvironmentManagerBase):
|
|
| 106 |
# ── EnvironmentManagerBase interface ─────────────────────────────────────
|
| 107 |
|
| 108 |
def reset(self, kwargs=None) -> tuple[dict, list[dict]]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
obs_list, info_list = [], []
|
| 110 |
-
for i in
|
| 111 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
obs_list.append(obs)
|
| 113 |
info_list.append(info)
|
|
|
|
| 114 |
return self._pack_obs(obs_list), info_list
|
| 115 |
|
| 116 |
def step(
|
| 117 |
self, text_actions: list[str]
|
| 118 |
) -> tuple[dict, np.ndarray, np.ndarray, list[dict]]:
|
| 119 |
-
obs_list, info_list = [], []
|
| 120 |
rewards = np.zeros(self.num_envs, dtype=np.float32)
|
| 121 |
dones = np.zeros(self.num_envs, dtype=bool)
|
| 122 |
|
|
|
|
|
|
|
|
|
|
| 123 |
for i, action_text in enumerate(text_actions):
|
| 124 |
if self._done[i]:
|
| 125 |
-
|
| 126 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
else:
|
| 128 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
rewards[i] = reward
|
| 130 |
dones[i] = done
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 131 |
obs_list.append(obs)
|
| 132 |
info_list.append(info)
|
| 133 |
|
|
@@ -138,9 +255,13 @@ class BrowserGymEnvManager(EnvironmentManagerBase):
|
|
| 138 |
return self._make_text_obs(obs_list)
|
| 139 |
|
| 140 |
def close(self) -> None:
|
| 141 |
-
for
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
try:
|
| 143 |
-
|
| 144 |
except Exception:
|
| 145 |
pass
|
| 146 |
|
|
@@ -155,71 +276,10 @@ class BrowserGymEnvManager(EnvironmentManagerBase):
|
|
| 155 |
assert len(success["success_rate"]) == batch_size
|
| 156 |
return {k: np.array(v) for k, v in success.items()}
|
| 157 |
|
| 158 |
-
# ── Internal helpers ─────────────────────────────────────────────────────
|
| 159 |
-
|
| 160 |
-
@staticmethod
|
| 161 |
-
def _import_browsergym_namespaces() -> None:
|
| 162 |
-
"""Import BrowserGym sub-packages to register their tasks in gymnasium."""
|
| 163 |
-
_known = {
|
| 164 |
-
"miniwob": "browsergym.miniwob",
|
| 165 |
-
"visualwebarena": "browsergym.visualwebarena",
|
| 166 |
-
"webarena": "browsergym.webarena",
|
| 167 |
-
"workarena": "browsergym.workarena",
|
| 168 |
-
"weblinx": "browsergym.weblinx",
|
| 169 |
-
}
|
| 170 |
-
import importlib
|
| 171 |
-
for key, module in _known.items():
|
| 172 |
-
try:
|
| 173 |
-
importlib.import_module(module)
|
| 174 |
-
except ImportError:
|
| 175 |
-
pass # optional package not installed
|
| 176 |
-
|
| 177 |
-
def _reset_env(self, idx: int) -> tuple[dict, dict]:
|
| 178 |
-
seed = self._seeds[idx]
|
| 179 |
-
self._seeds[idx] += self.num_envs # advance seed for next episode
|
| 180 |
-
|
| 181 |
-
obs, info = self._gym_envs[idx].reset(seed=seed)
|
| 182 |
-
self._last_obs[idx] = obs
|
| 183 |
-
self._steps[idx] = 0
|
| 184 |
-
self._done[idx] = False
|
| 185 |
-
self._history[idx] = []
|
| 186 |
-
self._goals[idx] = self._extract_goal(obs)
|
| 187 |
-
|
| 188 |
-
info.setdefault("won", False)
|
| 189 |
-
info["is_action_valid"] = np.array(True)
|
| 190 |
-
return obs, info
|
| 191 |
-
|
| 192 |
-
def _step_env(
|
| 193 |
-
self, idx: int, action_text: str
|
| 194 |
-
) -> tuple[dict, float, bool, dict]:
|
| 195 |
-
bg_action, is_valid = self._parse_action(action_text)
|
| 196 |
-
|
| 197 |
-
obs, reward, terminated, truncated, info = self._gym_envs[idx].step(bg_action)
|
| 198 |
-
|
| 199 |
-
self._last_obs[idx] = obs
|
| 200 |
-
self._steps[idx] += 1
|
| 201 |
-
done = terminated or truncated or (self._steps[idx] >= self.max_steps)
|
| 202 |
-
self._done[idx] = done
|
| 203 |
-
|
| 204 |
-
self._history[idx].append(action_text)
|
| 205 |
-
|
| 206 |
-
info["won"] = bool(terminated and reward > 0)
|
| 207 |
-
info["is_action_valid"] = np.array(is_valid)
|
| 208 |
-
info["last_action_error"] = obs.get("last_action_error", "")
|
| 209 |
-
return obs, float(reward), done, info
|
| 210 |
-
|
| 211 |
# ── Action parsing ───────────────────────────────────────────────────────
|
| 212 |
|
| 213 |
def _parse_action(self, text: str) -> tuple[str, bool]:
|
| 214 |
-
"""Convert VLM text output → BrowserGym coordinate-based action string.
|
| 215 |
-
|
| 216 |
-
BrowserGym uses its own action API (not raw Playwright calls):
|
| 217 |
-
- mouse_click(x, y) for coordinate clicks
|
| 218 |
-
- keyboard_type(text) for typing
|
| 219 |
-
- keyboard_press(key_comb) for key presses
|
| 220 |
-
- goto(url) for navigation
|
| 221 |
-
- scroll(dx, dy) for scrolling
|
| 222 |
-
"""
|
| 223 |
# Unwrap optional <action> tags
|
| 224 |
m = _RE_ACTION_TAG.search(text)
|
| 225 |
text = m.group(1).strip() if m else text.strip()
|
|
@@ -240,6 +300,7 @@ class BrowserGymEnvManager(EnvironmentManagerBase):
|
|
| 240 |
m = _RE_PRESS.search(text)
|
| 241 |
if m:
|
| 242 |
key = m.group(1).strip().strip("\"'")
|
|
|
|
| 243 |
return f'keyboard_press("{key}")', True
|
| 244 |
|
| 245 |
# navigate(url) / goto(url)
|
|
@@ -288,9 +349,12 @@ class BrowserGymEnvManager(EnvironmentManagerBase):
|
|
| 288 |
if err:
|
| 289 |
parts.append(f"Last action error: {err}")
|
| 290 |
parts.append(
|
| 291 |
-
"
|
| 292 |
-
"
|
| 293 |
-
"
|
|
|
|
|
|
|
|
|
|
| 294 |
)
|
| 295 |
texts.append("\n\n".join(parts))
|
| 296 |
return texts
|
|
|
|
| 6 |
- VisualWebArena: requires VWA servers (Docker or remote)
|
| 7 |
- WebArena: requires WebArena servers
|
| 8 |
|
| 9 |
+
Each BrowserGym env runs in its own Ray Actor process, so all envs
|
| 10 |
+
step/reset in parallel (no GIL or Playwright thread-affinity issues).
|
| 11 |
+
|
| 12 |
Config fields (under env.*):
|
| 13 |
env_name routing key, must contain "browsergym" (e.g. "browsergym-miniwob")
|
| 14 |
gym_id single BrowserGym task ID (mutually exclusive with task_list)
|
|
|
|
| 19 |
seed base random seed (default: 42)
|
| 20 |
viewport_width screenshot width (default: 1280)
|
| 21 |
viewport_height screenshot height (default: 720)
|
| 22 |
+
pre_observation_delay seconds to wait before obs extraction (default: 0.5)
|
| 23 |
"""
|
| 24 |
|
| 25 |
from __future__ import annotations
|
|
|
|
| 30 |
from typing import Optional
|
| 31 |
|
| 32 |
import numpy as np
|
| 33 |
+
import ray
|
| 34 |
|
| 35 |
from agent_system.environments.base import EnvironmentManagerBase
|
| 36 |
|
|
|
|
| 38 |
|
| 39 |
# ── Action regex patterns (coordinate-based, matching VLM output format) ─────
|
| 40 |
_RE_ACTION_TAG = re.compile(r"<action>(.*?)</action>", re.DOTALL | re.IGNORECASE)
|
| 41 |
+
_RE_CLICK = re.compile(r"click\s*\(\s*(\d+(?:\.\d+)?)\s*,\s*(\d+(?:\.\d+)?)\s*\)", re.IGNORECASE)
|
| 42 |
+
_RE_TYPE = re.compile(r"type\s*\(\s*(.+?)\s*\)", re.DOTALL)
|
| 43 |
+
_RE_PRESS = re.compile(r"press\s*\(\s*(.+?)\s*\)", re.IGNORECASE)
|
| 44 |
+
_RE_NAVIGATE = re.compile(r"(?:navigate|goto)\s*\(\s*(.+?)\s*\)", re.IGNORECASE)
|
| 45 |
+
_RE_SCROLL = re.compile(r"scroll\s*\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(.+?)\s*\)")
|
| 46 |
+
|
| 47 |
+
# Playwright key name mapping (VLM may output lowercase)
|
| 48 |
+
_KEY_MAP = {
|
| 49 |
+
"enter": "Enter", "tab": "Tab", "escape": "Escape", "esc": "Escape",
|
| 50 |
+
"backspace": "Backspace", "delete": "Delete", "space": " ",
|
| 51 |
+
"arrowup": "ArrowUp", "arrowdown": "ArrowDown",
|
| 52 |
+
"arrowleft": "ArrowLeft", "arrowright": "ArrowRight",
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
# ── Ray Actor: one BrowserGym env per process ────────────────────────────────
|
| 57 |
+
|
| 58 |
+
class BrowserGymWorker:
|
| 59 |
+
"""Ray Actor wrapping a single BrowserGym gymnasium env.
|
| 60 |
+
|
| 61 |
+
Runs in its own process — no GIL or Playwright thread issues.
|
| 62 |
+
"""
|
| 63 |
+
|
| 64 |
+
def __init__(self, task_id: str, action_mapping, pre_obs_delay: float = 0.5):
|
| 65 |
+
import gymnasium as gym
|
| 66 |
+
self._import_browsergym_namespaces()
|
| 67 |
+
self.env = gym.make(
|
| 68 |
+
task_id,
|
| 69 |
+
action_mapping=action_mapping,
|
| 70 |
+
pre_observation_delay=pre_obs_delay,
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
def step(self, action: str):
|
| 74 |
+
obs, reward, terminated, truncated, info = self.env.step(action)
|
| 75 |
+
return obs, float(reward), terminated, truncated, info
|
| 76 |
+
|
| 77 |
+
def reset(self, seed: int):
|
| 78 |
+
obs, info = self.env.reset(seed=seed)
|
| 79 |
+
return obs, info
|
| 80 |
+
|
| 81 |
+
def close(self):
|
| 82 |
+
try:
|
| 83 |
+
self.env.close()
|
| 84 |
+
except Exception:
|
| 85 |
+
pass
|
| 86 |
+
|
| 87 |
+
@staticmethod
|
| 88 |
+
def _import_browsergym_namespaces():
|
| 89 |
+
import importlib
|
| 90 |
+
for module in [
|
| 91 |
+
"browsergym.miniwob",
|
| 92 |
+
"browsergym.visualwebarena",
|
| 93 |
+
"browsergym.webarena",
|
| 94 |
+
"browsergym.workarena",
|
| 95 |
+
"browsergym.weblinx",
|
| 96 |
+
]:
|
| 97 |
+
try:
|
| 98 |
+
importlib.import_module(module)
|
| 99 |
+
except ImportError:
|
| 100 |
+
pass
|
| 101 |
|
| 102 |
|
| 103 |
class BrowserGymEnvManager(EnvironmentManagerBase):
|
| 104 |
"""Wraps BrowserGym gym environments for verl-agent training.
|
| 105 |
|
| 106 |
+
Each parallel slot is a Ray Actor running a BrowserGym env in its own
|
| 107 |
+
process. ``step()`` and ``reset()`` dispatch to all actors in parallel
|
| 108 |
+
via ``ray.get([actor.step.remote(...) for ...])``.
|
|
|
|
| 109 |
"""
|
| 110 |
|
| 111 |
def __init__(self, config, split: str = "train") -> None:
|
|
|
|
| 130 |
split, self.num_envs, self.task_ids,
|
| 131 |
)
|
| 132 |
|
| 133 |
+
# Build coordinate-based action mapping
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
from browsergym.core.action.highlevel import HighLevelActionSet
|
| 135 |
self._action_set = HighLevelActionSet(subsets=["coord", "nav"])
|
| 136 |
|
| 137 |
+
# Create Ray Actor workers (one browser per actor, fully parallel)
|
| 138 |
+
pre_obs_delay = float(getattr(config.env, "pre_observation_delay", 0.5))
|
| 139 |
+
resources = {"num_cpus": config.env.resources_per_worker.get("num_cpus", 0.5)}
|
| 140 |
+
WorkerActor = ray.remote(**resources)(BrowserGymWorker)
|
| 141 |
+
|
| 142 |
+
self._workers = [
|
| 143 |
+
WorkerActor.remote(
|
| 144 |
+
task_id=self.task_ids[i % len(self.task_ids)],
|
| 145 |
action_mapping=self._action_set.to_python_code,
|
| 146 |
+
pre_obs_delay=pre_obs_delay,
|
| 147 |
)
|
| 148 |
for i in range(self.num_envs)
|
| 149 |
]
|
| 150 |
|
| 151 |
+
# Per-env runtime state (kept on manager side for obs building)
|
| 152 |
self._last_obs: list[Optional[dict]] = [None] * self.num_envs
|
| 153 |
self._steps: list[int] = [0] * self.num_envs
|
| 154 |
self._done: list[bool] = [True] * self.num_envs
|
|
|
|
| 164 |
# ── EnvironmentManagerBase interface ─────────────────────────────────────
|
| 165 |
|
| 166 |
def reset(self, kwargs=None) -> tuple[dict, list[dict]]:
|
| 167 |
+
futures = [
|
| 168 |
+
self._workers[i].reset.remote(self._seeds[i])
|
| 169 |
+
for i in range(self.num_envs)
|
| 170 |
+
]
|
| 171 |
+
results = ray.get(futures)
|
| 172 |
+
|
| 173 |
obs_list, info_list = [], []
|
| 174 |
+
for i, (obs, info) in enumerate(results):
|
| 175 |
+
self._last_obs[i] = obs
|
| 176 |
+
self._steps[i] = 0
|
| 177 |
+
self._done[i] = False
|
| 178 |
+
self._history[i] = []
|
| 179 |
+
self._goals[i] = self._extract_goal(obs)
|
| 180 |
+
self._seeds[i] += self.num_envs
|
| 181 |
+
info.setdefault("won", False)
|
| 182 |
+
info["is_action_valid"] = np.array(True)
|
| 183 |
obs_list.append(obs)
|
| 184 |
info_list.append(info)
|
| 185 |
+
|
| 186 |
return self._pack_obs(obs_list), info_list
|
| 187 |
|
| 188 |
def step(
|
| 189 |
self, text_actions: list[str]
|
| 190 |
) -> tuple[dict, np.ndarray, np.ndarray, list[dict]]:
|
|
|
|
| 191 |
rewards = np.zeros(self.num_envs, dtype=np.float32)
|
| 192 |
dones = np.zeros(self.num_envs, dtype=bool)
|
| 193 |
|
| 194 |
+
# Dispatch step or reset to each worker in parallel
|
| 195 |
+
futures = []
|
| 196 |
+
action_map = {} # track which envs are stepping (vs resetting)
|
| 197 |
for i, action_text in enumerate(text_actions):
|
| 198 |
if self._done[i]:
|
| 199 |
+
seed = self._seeds[i]
|
| 200 |
+
self._seeds[i] += self.num_envs
|
| 201 |
+
futures.append(self._workers[i].reset.remote(seed))
|
| 202 |
+
action_map[i] = "reset"
|
| 203 |
+
else:
|
| 204 |
+
bg_action, is_valid = self._parse_action(action_text)
|
| 205 |
+
futures.append(self._workers[i].step.remote(bg_action))
|
| 206 |
+
action_map[i] = ("step", action_text, is_valid)
|
| 207 |
+
|
| 208 |
+
results = ray.get(futures)
|
| 209 |
+
|
| 210 |
+
obs_list, info_list = [], []
|
| 211 |
+
for i, result in enumerate(results):
|
| 212 |
+
if action_map[i] == "reset":
|
| 213 |
+
obs, info = result
|
| 214 |
+
self._last_obs[i] = obs
|
| 215 |
+
self._steps[i] = 0
|
| 216 |
+
self._done[i] = False
|
| 217 |
+
self._history[i] = []
|
| 218 |
+
self._goals[i] = self._extract_goal(obs)
|
| 219 |
+
info.setdefault("won", False)
|
| 220 |
+
info["is_action_valid"] = np.array(True)
|
| 221 |
+
dones[i] = False
|
| 222 |
else:
|
| 223 |
+
_, action_text, is_valid = action_map[i]
|
| 224 |
+
obs, reward, terminated, truncated, info = result
|
| 225 |
+
self._last_obs[i] = obs
|
| 226 |
+
self._steps[i] += 1
|
| 227 |
+
done = terminated or truncated or (self._steps[i] >= self.max_steps)
|
| 228 |
+
self._done[i] = done
|
| 229 |
+
self._history[i].append(action_text)
|
| 230 |
rewards[i] = reward
|
| 231 |
dones[i] = done
|
| 232 |
+
info["won"] = bool(terminated and reward > 0)
|
| 233 |
+
info["is_action_valid"] = np.array(is_valid)
|
| 234 |
+
info["last_action_error"] = obs.get("last_action_error", "")
|
| 235 |
+
|
| 236 |
+
# Debug: log first env's actions for the first few steps
|
| 237 |
+
if i == 0 and self._steps[i] <= 3:
|
| 238 |
+
import sys
|
| 239 |
+
err = obs.get("last_action_error", "")
|
| 240 |
+
bg_action = action_text # approximate for logging
|
| 241 |
+
print(
|
| 242 |
+
f"[DEBUG env0 step{self._steps[i]}] "
|
| 243 |
+
f"vlm={action_text[:80]!r} "
|
| 244 |
+
f"valid={is_valid} r={reward} term={terminated} err={err!r}",
|
| 245 |
+
file=sys.stderr, flush=True,
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
obs_list.append(obs)
|
| 249 |
info_list.append(info)
|
| 250 |
|
|
|
|
| 255 |
return self._make_text_obs(obs_list)
|
| 256 |
|
| 257 |
def close(self) -> None:
|
| 258 |
+
for worker in self._workers:
|
| 259 |
+
try:
|
| 260 |
+
ray.get(worker.close.remote())
|
| 261 |
+
except Exception:
|
| 262 |
+
pass
|
| 263 |
try:
|
| 264 |
+
ray.kill(worker)
|
| 265 |
except Exception:
|
| 266 |
pass
|
| 267 |
|
|
|
|
| 276 |
assert len(success["success_rate"]) == batch_size
|
| 277 |
return {k: np.array(v) for k, v in success.items()}
|
| 278 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 279 |
# ── Action parsing ───────────────────────────────────────────────────────
|
| 280 |
|
| 281 |
def _parse_action(self, text: str) -> tuple[str, bool]:
|
| 282 |
+
"""Convert VLM text output → BrowserGym coordinate-based action string."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 283 |
# Unwrap optional <action> tags
|
| 284 |
m = _RE_ACTION_TAG.search(text)
|
| 285 |
text = m.group(1).strip() if m else text.strip()
|
|
|
|
| 300 |
m = _RE_PRESS.search(text)
|
| 301 |
if m:
|
| 302 |
key = m.group(1).strip().strip("\"'")
|
| 303 |
+
key = _KEY_MAP.get(key.lower(), key)
|
| 304 |
return f'keyboard_press("{key}")', True
|
| 305 |
|
| 306 |
# navigate(url) / goto(url)
|
|
|
|
| 349 |
if err:
|
| 350 |
parts.append(f"Last action error: {err}")
|
| 351 |
parts.append(
|
| 352 |
+
"Respond with exactly ONE action using the formats below. "
|
| 353 |
+
"Replace the placeholders with actual values.\n"
|
| 354 |
+
" click(x, y) β click at pixel coordinates, e.g. click(120, 55)\n"
|
| 355 |
+
" type(text) β type a string, e.g. type(hello world)\n"
|
| 356 |
+
" press(key) β press a key, e.g. press(Enter)\n"
|
| 357 |
+
"Your response must start with the action, nothing else."
|
| 358 |
)
|
| 359 |
texts.append("\n\n".join(parts))
|
| 360 |
return texts
|
scripts/run_browsergym_miniwob.sh
CHANGED
|
@@ -22,6 +22,9 @@ export MINIWOB_URL="http://localhost:${MINIWOB_PORT}/miniwob/"
|
|
| 22 |
CONDA_ENV_LIB="$(python3 -c 'import sys, os; print(os.path.join(sys.prefix, "lib"))')"
|
| 23 |
LOCAL_LIBS="/home/jovyan/project/verl-agent/local-libs/extracted/usr/lib/x86_64-linux-gnu"
|
| 24 |
export LD_LIBRARY_PATH="${CONDA_ENV_LIB}:${LOCAL_LIBS}:${LD_LIBRARY_PATH:-}"
|
|
|
|
|
|
|
|
|
|
| 25 |
echo "[run_browsergym_miniwob] MINIWOB_URL=$MINIWOB_URL (pid=$HTTP_PID)"
|
| 26 |
|
| 27 |
# Cleanup HTTP server on exit
|
|
@@ -29,9 +32,9 @@ cleanup() { kill "$HTTP_PID" 2>/dev/null || true; }
|
|
| 29 |
trap cleanup EXIT
|
| 30 |
|
| 31 |
# ── Tunable knobs ─────────────────────────────────────────────────────────────
|
| 32 |
-
train_data_size=
|
| 33 |
-
val_data_size=
|
| 34 |
-
group_size=
|
| 35 |
HF_MODEL_ID="Qwen/Qwen2.5-VL-3B-Instruct"
|
| 36 |
HF_CACHE_SNAPSHOT="$HOME/.cache/huggingface/hub/models--Qwen--Qwen2.5-VL-3B-Instruct/snapshots/66285546d2b821cf421d4f5eb2576359d3770cd3"
|
| 37 |
LOCAL_MODEL_PATH="/tmp/Qwen2.5-VL-3B-Instruct"
|
|
@@ -92,14 +95,16 @@ fi
|
|
| 92 |
# ── Training ──────────────────────────────────────────────────────────────────
|
| 93 |
python3 -m verl.trainer.main_ppo \
|
| 94 |
algorithm.adv_estimator=grpo \
|
| 95 |
-
algorithm.use_kl_in_reward=
|
| 96 |
-
algorithm.
|
|
|
|
|
|
|
| 97 |
\
|
| 98 |
data.train_files="$HOME/data/verl-agent/visual/train.parquet" \
|
| 99 |
data.val_files="$HOME/data/verl-agent/visual/test.parquet" \
|
| 100 |
data.train_batch_size="$train_data_size" \
|
| 101 |
data.val_batch_size="$val_data_size" \
|
| 102 |
-
data.max_prompt_length=
|
| 103 |
data.max_response_length=256 \
|
| 104 |
data.filter_overlong_prompts=True \
|
| 105 |
data.truncation=left \
|
|
@@ -115,7 +120,7 @@ python3 -m verl.trainer.main_ppo \
|
|
| 115 |
actor_rollout_ref.model.enable_gradient_checkpointing=True \
|
| 116 |
actor_rollout_ref.model.use_remove_padding=False \
|
| 117 |
actor_rollout_ref.actor.strategy=fsdp \
|
| 118 |
-
actor_rollout_ref.actor.optim.lr=
|
| 119 |
actor_rollout_ref.actor.ppo_mini_batch_size=8 \
|
| 120 |
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
|
| 121 |
actor_rollout_ref.actor.use_kl_loss=False \
|
|
@@ -126,37 +131,39 @@ python3 -m verl.trainer.main_ppo \
|
|
| 126 |
actor_rollout_ref.rollout.name="$ENGINE" \
|
| 127 |
actor_rollout_ref.rollout.n=1 \
|
| 128 |
actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
|
| 129 |
-
actor_rollout_ref.rollout.gpu_memory_utilization=0.
|
| 130 |
-
actor_rollout_ref.rollout.enable_chunked_prefill=
|
| 131 |
actor_rollout_ref.rollout.enforce_eager=False \
|
| 132 |
actor_rollout_ref.rollout.free_cache_engine=False \
|
| 133 |
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \
|
| 134 |
actor_rollout_ref.rollout.val_kwargs.temperature=0.0 \
|
| 135 |
actor_rollout_ref.rollout.val_kwargs.do_sample=False \
|
| 136 |
-
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=
|
| 137 |
actor_rollout_ref.ref.fsdp_config.param_offload=False \
|
| 138 |
"+ray_init.runtime_env.env_vars.LD_LIBRARY_PATH=${CONDA_ENV_LIB}:${LOCAL_LIBS}" \
|
|
|
|
| 139 |
\
|
| 140 |
\
|
| 141 |
env.env_name=browsergym-miniwob \
|
| 142 |
env.seed=42 \
|
| 143 |
-
env.max_steps=
|
| 144 |
env.rollout.n="$group_size" \
|
| 145 |
-
env.resources_per_worker.num_cpus=0.
|
| 146 |
++env.history_length=3 \
|
| 147 |
-
++env.task_list="[browsergym/miniwob.click-
|
|
|
|
| 148 |
++env.viewport_width=332 \
|
| 149 |
++env.viewport_height=214 \
|
| 150 |
\
|
| 151 |
trainer.critic_warmup=0 \
|
| 152 |
-
trainer.logger="[console]" \
|
| 153 |
trainer.project_name="$project_name" \
|
| 154 |
trainer.experiment_name="$experiment_name" \
|
| 155 |
trainer.n_gpus_per_node=1 \
|
| 156 |
trainer.nnodes=1 \
|
| 157 |
trainer.save_freq=100 \
|
| 158 |
trainer.test_freq=-1 \
|
| 159 |
-
trainer.total_epochs=
|
| 160 |
-
trainer.val_before_train=
|
| 161 |
+ray_init.include_dashboard=False \
|
| 162 |
"$@"
|
|
|
|
| 22 |
CONDA_ENV_LIB="$(python3 -c 'import sys, os; print(os.path.join(sys.prefix, "lib"))')"
|
| 23 |
LOCAL_LIBS="/home/jovyan/project/verl-agent/local-libs/extracted/usr/lib/x86_64-linux-gnu"
|
| 24 |
export LD_LIBRARY_PATH="${CONDA_ENV_LIB}:${LOCAL_LIBS}:${LD_LIBRARY_PATH:-}"
|
| 25 |
+
export WANDB__SERVICE_WAIT=120
|
| 26 |
+
WANDB_API_KEY=$(python3 -c "import wandb; print(wandb.api.api_key)" 2>/dev/null)
|
| 27 |
+
export WANDB_API_KEY
|
| 28 |
echo "[run_browsergym_miniwob] MINIWOB_URL=$MINIWOB_URL (pid=$HTTP_PID)"
|
| 29 |
|
| 30 |
# Cleanup HTTP server on exit
|
|
|
|
| 32 |
trap cleanup EXIT
|
| 33 |
|
| 34 |
# ── Tunable knobs ─────────────────────────────────────────────────────────────
|
| 35 |
+
train_data_size=4 # parallel train envs (= train_batch_size)
|
| 36 |
+
val_data_size=32 # parallel val envs
|
| 37 |
+
group_size=8 # GRPO group size (rollout.n)
|
| 38 |
HF_MODEL_ID="Qwen/Qwen2.5-VL-3B-Instruct"
|
| 39 |
HF_CACHE_SNAPSHOT="$HOME/.cache/huggingface/hub/models--Qwen--Qwen2.5-VL-3B-Instruct/snapshots/66285546d2b821cf421d4f5eb2576359d3770cd3"
|
| 40 |
LOCAL_MODEL_PATH="/tmp/Qwen2.5-VL-3B-Instruct"
|
|
|
|
| 95 |
# ── Training ──────────────────────────────────────────────────────────────────
|
| 96 |
python3 -m verl.trainer.main_ppo \
|
| 97 |
algorithm.adv_estimator=grpo \
|
| 98 |
+
algorithm.use_kl_in_reward=True \
|
| 99 |
+
algorithm.kl_ctrl.type=fixed \
|
| 100 |
+
algorithm.kl_ctrl.kl_coef=0.001 \
|
| 101 |
+
algorithm.gamma=0.95 \
|
| 102 |
\
|
| 103 |
data.train_files="$HOME/data/verl-agent/visual/train.parquet" \
|
| 104 |
data.val_files="$HOME/data/verl-agent/visual/test.parquet" \
|
| 105 |
data.train_batch_size="$train_data_size" \
|
| 106 |
data.val_batch_size="$val_data_size" \
|
| 107 |
+
data.max_prompt_length=1024 \
|
| 108 |
data.max_response_length=256 \
|
| 109 |
data.filter_overlong_prompts=True \
|
| 110 |
data.truncation=left \
|
|
|
|
| 120 |
actor_rollout_ref.model.enable_gradient_checkpointing=True \
|
| 121 |
actor_rollout_ref.model.use_remove_padding=False \
|
| 122 |
actor_rollout_ref.actor.strategy=fsdp \
|
| 123 |
+
actor_rollout_ref.actor.optim.lr=2e-5 \
|
| 124 |
actor_rollout_ref.actor.ppo_mini_batch_size=8 \
|
| 125 |
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
|
| 126 |
actor_rollout_ref.actor.use_kl_loss=False \
|
|
|
|
| 131 |
actor_rollout_ref.rollout.name="$ENGINE" \
|
| 132 |
actor_rollout_ref.rollout.n=1 \
|
| 133 |
actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
|
| 134 |
+
actor_rollout_ref.rollout.gpu_memory_utilization=0.75 \
|
| 135 |
+
actor_rollout_ref.rollout.enable_chunked_prefill=True \
|
| 136 |
actor_rollout_ref.rollout.enforce_eager=False \
|
| 137 |
actor_rollout_ref.rollout.free_cache_engine=False \
|
| 138 |
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \
|
| 139 |
actor_rollout_ref.rollout.val_kwargs.temperature=0.0 \
|
| 140 |
actor_rollout_ref.rollout.val_kwargs.do_sample=False \
|
| 141 |
+
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=8 \
|
| 142 |
actor_rollout_ref.ref.fsdp_config.param_offload=False \
|
| 143 |
"+ray_init.runtime_env.env_vars.LD_LIBRARY_PATH=${CONDA_ENV_LIB}:${LOCAL_LIBS}" \
|
| 144 |
+
"+ray_init.runtime_env.env_vars.WANDB_API_KEY=${WANDB_API_KEY}" \
|
| 145 |
\
|
| 146 |
\
|
| 147 |
env.env_name=browsergym-miniwob \
|
| 148 |
env.seed=42 \
|
| 149 |
+
env.max_steps=7 \
|
| 150 |
env.rollout.n="$group_size" \
|
| 151 |
+
env.resources_per_worker.num_cpus=0.5 \
|
| 152 |
++env.history_length=3 \
|
| 153 |
+
++env.task_list="[browsergym/miniwob.click-checkboxes,browsergym/miniwob.click-tab-2,browsergym/miniwob.email-inbox,browsergym/miniwob.search-engine,browsergym/miniwob.login-user,browsergym/miniwob.social-media,browsergym/miniwob.click-collapsible-2,browsergym/miniwob.book-flight]" \
|
| 154 |
+
++env.pre_observation_delay=0.1 \
|
| 155 |
++env.viewport_width=332 \
|
| 156 |
++env.viewport_height=214 \
|
| 157 |
\
|
| 158 |
trainer.critic_warmup=0 \
|
| 159 |
+
trainer.logger="[console,wandb]" \
|
| 160 |
trainer.project_name="$project_name" \
|
| 161 |
trainer.experiment_name="$experiment_name" \
|
| 162 |
trainer.n_gpus_per_node=1 \
|
| 163 |
trainer.nnodes=1 \
|
| 164 |
trainer.save_freq=100 \
|
| 165 |
trainer.test_freq=-1 \
|
| 166 |
+
trainer.total_epochs=50 \
|
| 167 |
+
trainer.val_before_train=False \
|
| 168 |
+ray_init.include_dashboard=False \
|
| 169 |
"$@"
|
verl/trainer/ppo/ray_trainer.py
CHANGED
|
@@ -410,7 +410,8 @@ def _print_timing_breakdown(timing_raw: Dict[str, float], global_step: int):
|
|
| 410 |
other = total - accounted
|
| 411 |
if other > 0.5:
|
| 412 |
parts.append(f"Other: {other:.1f}s ({other / total * 100:.0f}%)")
|
| 413 |
-
|
|
|
|
| 414 |
|
| 415 |
|
| 416 |
class RayPPOTrainer:
|
|
@@ -1062,7 +1063,12 @@ class RayPPOTrainer:
|
|
| 1062 |
if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):
|
| 1063 |
val_metrics = self._validate()
|
| 1064 |
assert val_metrics, f"{val_metrics=}"
|
| 1065 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1066 |
logger.log(data=val_metrics, step=self.global_steps)
|
| 1067 |
if self.config.trainer.get("val_only", False):
|
| 1068 |
return
|
|
@@ -1115,6 +1121,14 @@ class RayPPOTrainer:
|
|
| 1115 |
envs=self.envs,
|
| 1116 |
is_train=True,
|
| 1117 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1118 |
if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:
|
| 1119 |
with _timer("gen_max", timing_raw):
|
| 1120 |
gen_baseline_batch = deepcopy(gen_batch)
|
|
|
|
| 410 |
other = total - accounted
|
| 411 |
if other > 0.5:
|
| 412 |
parts.append(f"Other: {other:.1f}s ({other / total * 100:.0f}%)")
|
| 413 |
+
import sys
|
| 414 |
+
print(f"\n[Step {global_step}] Total: {total:.1f}s | {' | '.join(parts)}\n", file=sys.stderr, flush=True)
|
| 415 |
|
| 416 |
|
| 417 |
class RayPPOTrainer:
|
|
|
|
| 1063 |
if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):
|
| 1064 |
val_metrics = self._validate()
|
| 1065 |
assert val_metrics, f"{val_metrics=}"
|
| 1066 |
+
import sys
|
| 1067 |
+
print("\n" + "=" * 60, file=sys.stderr, flush=True)
|
| 1068 |
+
print("INITIAL VALIDATION METRICS:", file=sys.stderr, flush=True)
|
| 1069 |
+
for k, v in val_metrics.items():
|
| 1070 |
+
print(f" {k}: {v:.4f}", file=sys.stderr, flush=True)
|
| 1071 |
+
print("=" * 60 + "\n", file=sys.stderr, flush=True)
|
| 1072 |
logger.log(data=val_metrics, step=self.global_steps)
|
| 1073 |
if self.config.trainer.get("val_only", False):
|
| 1074 |
return
|
|
|
|
| 1121 |
envs=self.envs,
|
| 1122 |
is_train=True,
|
| 1123 |
)
|
| 1124 |
+
# Extract rollout sub-timing into timing_raw
|
| 1125 |
+
_rt = getattr(gen_batch_output, 'meta_info', None) or {}
|
| 1126 |
+
_rt = _rt.get('rollout_timing', {})
|
| 1127 |
+
if _rt:
|
| 1128 |
+
timing_raw['gen_preprocess'] = _rt.get('preprocess_s', 0.0)
|
| 1129 |
+
timing_raw['gen_inference'] = _rt.get('inference_s', 0.0)
|
| 1130 |
+
timing_raw['gen_env'] = _rt.get('env_s', 0.0)
|
| 1131 |
+
|
| 1132 |
if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:
|
| 1133 |
with _timer("gen_max", timing_raw):
|
| 1134 |
gen_baseline_batch = deepcopy(gen_batch)
|
verl/utils/logger/aggregate_logger.py
CHANGED
|
@@ -40,7 +40,8 @@ class LocalLogger:
|
|
| 40 |
|
| 41 |
def log(self, data, step):
|
| 42 |
if self.print_to_console:
|
| 43 |
-
|
|
|
|
| 44 |
|
| 45 |
|
| 46 |
class DecoratorLoggerBase:
|
|
|
|
| 40 |
|
| 41 |
def log(self, data, step):
|
| 42 |
if self.print_to_console:
|
| 43 |
+
import sys
|
| 44 |
+
print(concat_dict_to_str(data, step=step), file=sys.stderr, flush=True)
|
| 45 |
|
| 46 |
|
| 47 |
class DecoratorLoggerBase:
|
verl/utils/tracking.py
CHANGED
|
@@ -51,7 +51,12 @@ class Tracking:
|
|
| 51 |
if "tracking" in default_backend or "wandb" in default_backend:
|
| 52 |
import wandb
|
| 53 |
|
| 54 |
-
wandb.init(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
self.logger["wandb"] = wandb
|
| 56 |
|
| 57 |
if "mlflow" in default_backend:
|
|
|
|
| 51 |
if "tracking" in default_backend or "wandb" in default_backend:
|
| 52 |
import wandb
|
| 53 |
|
| 54 |
+
wandb.init(
|
| 55 |
+
project=project_name,
|
| 56 |
+
name=experiment_name,
|
| 57 |
+
config=config,
|
| 58 |
+
settings=wandb.Settings(start_method="thread"),
|
| 59 |
+
)
|
| 60 |
self.logger["wandb"] = wandb
|
| 61 |
|
| 62 |
if "mlflow" in default_backend:
|