# Copyright 2025 Nanyang Technological University (NTU), Singapore
# and the verl-agent (GiGPO) team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time
import uuid
from typing import Dict, List

import numpy as np
import torch
from transformers import PreTrainedTokenizer

import verl.utils.torch_functional as verl_F
from verl import DataProto
from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto
from verl.utils.dataset.rl_dataset import collate_fn
from verl.utils.model import compute_position_id_with_mask

from agent_system.environments import EnvironmentManagerBase
from agent_system.multi_turn_rollout.utils import (
    filter_group_data,
    process_image,
    to_list_of_dict,
    torch_to_numpy,
)


class TrajectoryCollector:
    def __init__(self, config, tokenizer: PreTrainedTokenizer, processor=None):
        """
        Initialize the TrajectoryCollector.

        Parameters:
            config: Configuration object containing data processing settings
            tokenizer (PreTrainedTokenizer): Tokenizer for text encoding and decoding
            processor: Image processor for multimodal inputs
        """
        self.config = config
        self.tokenizer = tokenizer
        self.processor = processor

    def preprocess_single_sample(
        self,
        item: int,
        gen_batch: DataProto,
        obs: Dict,
    ):
        """
        Process a single observation sample, organizing environment observations
        (text and/or images) into a format processable by the model.

        Parameters:
            item (int): Sample index in the batch
            gen_batch (DataProto): Batch data containing original prompts
            obs (Dict): Environment observation, may contain 'text', 'image', 'anchor' keys

        Returns:
            dict: Contains processed input data such as input_ids, attention_mask, etc.
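
        Example (illustrative only; assumes a text-only observation and a
        chat-template tokenizer, with `collector` and `gen_batch` coming from
        the surrounding trainer):

            obs = {'text': ['You see a key on the table.']}
            row = collector.preprocess_single_sample(0, gen_batch, obs)
            row['input_ids']       # left-padded prompt ids, length max_prompt_length
            row['attention_mask']  # 1 for real tokens, 0 for padding
            row['anchor_obs']      # None here, since no 'anchor' key was provided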
""" raw_prompt = gen_batch.non_tensor_batch['raw_prompt'][item] data_source = gen_batch.non_tensor_batch['data_source'][item] apply_chat_template_kwargs = self.config.data.get("apply_chat_template_kwargs", {}) # Get observation components obs_texts = obs.get('text', None) obs_images = obs.get('image', None) obs_anchors = obs.get('anchor', None) obs_text = obs_texts[item] if obs_texts is not None else None obs_image = obs_images[item] if obs_images is not None else None obs_anchor = obs_anchors[item] if obs_anchors is not None else None is_multi_modal = obs_image is not None _obs_anchor = torch_to_numpy(obs_anchor, is_object=True) if isinstance(obs_anchor, torch.Tensor) else obs_anchor # Build chat structure # obs_content = raw_prompt[0]['content'] # if '' in obs_content: # obs_content = obs_content.replace('', '') # Build chat structure obs_content = '' if obs_text is not None: obs_content += obs_text else: print(f"Warning: No text observation found!") chat = np.array([{ "content": obs_content, "role": "user", }]) # Apply chat template prompt_with_chat_template = self.tokenizer.apply_chat_template( chat, add_generation_prompt=True, tokenize=False, **apply_chat_template_kwargs ) # Initialize return dict row_dict = {} # Process multimodal data if is_multi_modal: # Replace image placeholder with vision tokens raw_prompt = prompt_with_chat_template.replace('', '<|vision_start|><|image_pad|><|vision_end|>') row_dict['multi_modal_data'] = {'image': [process_image(obs_image)]} image_inputs = self.processor.image_processor(row_dict['multi_modal_data']['image'], return_tensors='pt') image_grid_thw = image_inputs['image_grid_thw'] row_dict['multi_modal_inputs'] = {key: val for key, val in image_inputs.items()} if image_grid_thw is not None: merge_length = self.processor.image_processor.merge_size**2 index = 0 while '' in prompt_with_chat_template: prompt_with_chat_template = prompt_with_chat_template.replace( '', '<|vision_start|>' + '<|placeholder|>' * (image_grid_thw[index].prod() // merge_length) + '<|vision_end|>', 1, ) index += 1 prompt_with_chat_template = prompt_with_chat_template.replace('<|placeholder|>', self.processor.image_token) else: raw_prompt = prompt_with_chat_template input_ids, attention_mask = verl_F.tokenize_and_postprocess_data(prompt=prompt_with_chat_template, tokenizer=self.tokenizer, max_length=self.config.data.max_prompt_length, pad_token_id=self.tokenizer.pad_token_id, left_pad=True, truncation=self.config.data.truncation,) if is_multi_modal: if "Qwen3VLProcessor" in self.processor.__class__.__name__: from verl.models.transformers.qwen3_vl import get_rope_index else: from verl.models.transformers.qwen2_vl import get_rope_index vision_position_ids = get_rope_index( self.processor, input_ids=input_ids[0], image_grid_thw=image_grid_thw, attention_mask=attention_mask[0], ) # (3, seq_length) valid_mask = attention_mask[0].bool() text_position_ids = torch.ones((1, len(input_ids[0])), dtype=torch.long) text_position_ids[0, valid_mask] = torch.arange(valid_mask.sum().item()) position_ids = [torch.cat((text_position_ids, vision_position_ids), dim=0)] # (1, 4, seq_length) else: position_ids = compute_position_id_with_mask(attention_mask) raw_prompt_ids = self.tokenizer.encode(raw_prompt, add_special_tokens=False) if len(raw_prompt_ids) > self.config.data.max_prompt_length: if self.config.data.truncation == "left": raw_prompt_ids = raw_prompt_ids[-self.config.data.max_prompt_length :] elif self.config.data.truncation == "right": raw_prompt_ids = raw_prompt_ids[: 
        if len(raw_prompt_ids) > self.config.data.max_prompt_length:
            if self.config.data.truncation == "left":
                raw_prompt_ids = raw_prompt_ids[-self.config.data.max_prompt_length:]
            elif self.config.data.truncation == "right":
                raw_prompt_ids = raw_prompt_ids[:self.config.data.max_prompt_length]
            elif self.config.data.truncation == "middle":
                left_half = self.config.data.max_prompt_length // 2
                right_half = self.config.data.max_prompt_length - left_half
                raw_prompt_ids = raw_prompt_ids[:left_half] + raw_prompt_ids[-right_half:]
            elif self.config.data.truncation == "error":
                raise RuntimeError(f"Prompt length {len(raw_prompt_ids)} is longer than {self.config.data.max_prompt_length}.")

        # Build final output dict
        row_dict.update({
            'input_ids': input_ids[0],
            'attention_mask': attention_mask[0],
            'position_ids': position_ids[0],
            'raw_prompt_ids': raw_prompt_ids,
            'anchor_obs': _obs_anchor,
            'index': item,
            'data_source': data_source,
        })

        if self.config.data.get('return_raw_chat', False):
            row_dict['raw_prompt'] = chat.tolist()

        return row_dict

    def preprocess_batch(
        self,
        gen_batch: DataProto,
        obs: Dict,
    ) -> DataProto:
        """
        Process a batch of observation samples, converting environment
        observations into a model-processable format.

        Parameters:
            gen_batch (DataProto): Batch data containing original prompts
            obs (Dict): Environment observation dictionary
                - 'text' (None or List[str]): Text observation data
                - 'image' (np.ndarray or torch.Tensor): Image observation data
                - 'anchor' (None or Any): Anchor observation without any histories
                  or additional info (for GiGPO only).

        Returns:
            DataProto: Contains processed batch data with preserved metadata
        """
        batch_size = len(gen_batch.batch['input_ids'])
        processed_samples = []

        # Process each sample in the batch
        for item in range(batch_size):
            processed = self.preprocess_single_sample(
                item=item,
                gen_batch=gen_batch,
                obs=obs,
            )
            processed_samples.append(processed)

        # Aggregate batch data
        batch = collate_fn(processed_samples)

        # Create DataProto with preserved metadata
        new_batch = DataProto.from_single_dict(
            data=batch,
            meta_info=gen_batch.meta_info,
        )
        return new_batch

    def gather_rollout_data(
        self,
        total_batch_list: List[List[Dict]],
        episode_rewards: np.ndarray,
        episode_lengths: np.ndarray,
        success: Dict[str, np.ndarray],
        traj_uid: np.ndarray,
        tool_callings: np.ndarray,
    ) -> DataProto:
        """
        Collect and organize trajectory data, handling batch size adjustments
        to meet parallel training requirements.
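
        Example (illustrative): with 2 trajectories containing 3 and 2 active
        steps respectively, the flattened output batch has 5 rows, and every
        row carries its trajectory's episode_rewards, episode_lengths,
        tool_callings, and per-key success rates.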

        Parameters:
            total_batch_list (List[List[Dict]]): List of trajectory data for each environment
            episode_rewards (np.ndarray): Total rewards for each environment
            episode_lengths (np.ndarray): Total steps for each environment
            success (Dict[str, np.ndarray]): Success samples for each environment
            traj_uid (np.ndarray): Trajectory unique identifiers
            tool_callings (np.ndarray): Number of tool callings for each environment

        Returns:
            DataProto: Collected and organized trajectory data
        """
        batch_size = len(total_batch_list)
        success_rate = {}
        for key, value in success.items():
            success_rate[key] = np.mean(value)

        effective_batch = []
        for bs in range(batch_size):
            # Annotate every active step of this trajectory with its episode-level statistics
            for data in total_batch_list[bs]:
                assert traj_uid[bs] == data['traj_uid'], "data is not from the same trajectory"
                if data['active_masks']:
                    data['episode_rewards'] = episode_rewards[bs]
                    data['episode_lengths'] = episode_lengths[bs]
                    data['tool_callings'] = tool_callings[bs]
                    for key, value in success_rate.items():
                        data[key] = value
                    effective_batch.append(data)

        # Convert trajectory data to DataProto format
        gen_batch_output = DataProto.from_single_dict(
            data=collate_fn(effective_batch)
        )
        return gen_batch_output

    def vanilla_multi_turn_loop(
        self,
        gen_batch: DataProto,
        actor_rollout_wg,
        envs: EnvironmentManagerBase,
    ) -> DataProto:
        """
        Collect trajectories through parallel agent-environment interaction.

        Parameters:
            gen_batch (DataProto): Initial batch with prompts to start the agent loop
            actor_rollout_wg (WorkerGroup): Worker group containing the actor model for policy decisions
            envs (EnvironmentManagerBase): Environment manager containing parallel environment instances

        Returns:
            total_batch_list (List[Dict]): List of trajectory data for each environment
            episode_rewards (np.ndarray): Total rewards for each environment
            episode_lengths (np.ndarray): Total steps for each environment
            success (Dict[str, np.ndarray]): Success samples for each environment
            traj_uid (np.ndarray): Trajectory unique identifiers
            tool_callings (np.ndarray): Number of tool calls per environment
            rollout_timing (Dict[str, float]): Wall-clock time spent in preprocessing, inference, and env stepping
        """
        batch_size = len(gen_batch.batch)

        # Initial observations from the environment
        obs, infos = envs.reset(kwargs=gen_batch.non_tensor_batch.pop('env_kwargs', None))
        length_obs = len(obs['text']) if obs['text'] is not None else len(obs['image'])
        assert len(gen_batch.batch) == length_obs, f"gen_batch size {len(gen_batch.batch)} does not match obs size {length_obs}"

        if self.config.env.rollout.n > 0:
            # Env grouping: consecutive blocks of `rollout.n` environments share a group uid
            uid_batch = []
            for i in range(batch_size):
                if i % self.config.env.rollout.n == 0:
                    uid = str(uuid.uuid4())
                uid_batch.append(uid)
            uid_batch = np.array(uid_batch, dtype=object)
        else:
            # No env grouping: all environments share the same uid
            uid = str(uuid.uuid4())
            uid_batch = np.array([uid for _ in range(len(gen_batch.batch))], dtype=object)
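
        # Grouping layout (illustrative): with batch_size = 4 and rollout.n = 2,
        # uid_batch looks like [u0, u0, u1, u1]; environments that share a uid
        # form one group for group-relative advantage estimation downstream.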

        is_done = np.zeros(batch_size, dtype=bool)
        traj_uid = np.array([str(uuid.uuid4()) for _ in range(batch_size)], dtype=object)

        total_batch_list = [[] for _ in range(batch_size)]
        total_infos = [[] for _ in range(batch_size)]
        episode_lengths = np.zeros(batch_size, dtype=np.float32)
        episode_rewards = np.zeros(batch_size, dtype=np.float32)
        tool_callings = np.zeros(batch_size, dtype=np.float32)

        _total_preprocess_time = 0.0
        _total_infer_time = 0.0
        _total_env_time = 0.0

        # Trajectory collection loop
        for _step in range(self.config.env.max_steps):
            active_masks = np.logical_not(is_done)

            _t0 = time.time()
            batch = self.preprocess_batch(gen_batch=gen_batch, obs=obs)
            _total_preprocess_time += time.time() - _t0

            batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"]
            non_tensor_batch_keys_to_pop = ["raw_prompt_ids"]
            if "multi_modal_data" in batch.non_tensor_batch:
                non_tensor_batch_keys_to_pop.append("multi_modal_data")
            if "raw_prompt" in batch.non_tensor_batch:
                non_tensor_batch_keys_to_pop.append("raw_prompt")
            if "tools_kwargs" in batch.non_tensor_batch:
                non_tensor_batch_keys_to_pop.append("tools_kwargs")
            batch_input = batch.pop(
                batch_keys=batch_keys_to_pop,
                non_tensor_batch_keys=non_tensor_batch_keys_to_pop,
            )
            batch_input.meta_info = gen_batch.meta_info

            # Pad to be divisible by dp_size
            batch_input_padded, pad_size = pad_dataproto_to_divisor(batch_input, actor_rollout_wg.world_size)
            _t0 = time.time()
            batch_output_padded = actor_rollout_wg.generate_sequences(batch_input_padded)
            _total_infer_time += time.time() - _t0
            # Unpad
            batch_output = unpad_dataproto(batch_output_padded, pad_size=pad_size)

            batch.non_tensor_batch['uid'] = uid_batch
            batch.non_tensor_batch['traj_uid'] = traj_uid
            batch = batch.union(batch_output)

            text_actions = self.tokenizer.batch_decode(batch.batch['responses'], skip_special_tokens=True)
            _t0 = time.time()
            next_obs, rewards, dones, infos = envs.step(text_actions)
            _total_env_time += time.time() - _t0

            if len(rewards.shape) == 2:
                rewards = rewards.squeeze(1)
            if len(dones.shape) == 2:
                # dones is a numpy array; drop the trailing dimension
                dones = dones.squeeze(1)

            if 'is_action_valid' in infos[0]:
                batch.non_tensor_batch['is_action_valid'] = np.array([info['is_action_valid'] for info in infos], dtype=bool)
            else:
                batch.non_tensor_batch['is_action_valid'] = np.ones(batch_size, dtype=bool)

            if 'tool_calling' in infos[0]:
                tool_callings[active_masks] += np.array([info['tool_calling'] for info in infos], dtype=np.float32)[active_masks]

            # Only accumulate rewards and lengths for environments that are still active
            episode_rewards[active_masks] += torch_to_numpy(rewards)[active_masks]
            episode_lengths[active_masks] += 1

            assert len(rewards) == batch_size, f"env should return rewards for all environments, got {len(rewards)} rewards for {batch_size} environments"
            batch.non_tensor_batch['rewards'] = torch_to_numpy(rewards, is_object=True)
            batch.non_tensor_batch['active_masks'] = torch_to_numpy(active_masks, is_object=True)

            # Convert the step batch to per-environment dicts and record this step
            batch_list: list[dict] = to_list_of_dict(batch)
            for i in range(batch_size):
                total_batch_list[i].append(batch_list[i])
                total_infos[i].append(infos[i])

            # Update done states
            is_done = np.logical_or(is_done, dones)
            # Update observations for next step
            obs = next_obs

            # Break if all environments are done
            if is_done.all():
                break

        success: Dict[str, np.ndarray] = envs.success_evaluator(
            total_infos=total_infos,
            total_batch_list=total_batch_list,
            episode_rewards=episode_rewards,
            episode_lengths=episode_lengths,
        )

        rollout_timing = {"inference_s": _total_infer_time, "env_s": _total_env_time, "preprocess_s": _total_preprocess_time}

        return total_batch_list, episode_rewards, episode_lengths, success, traj_uid, tool_callings, rollout_timing
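
    # Dynamic-sampling target (illustrative): with train_batch_size = 8 and
    # rollout.n = 2, dynamic_multi_turn_loop below keeps generating until at
    # least 8 * 2 = 16 trajectories survive filter_group_data, or until
    # max_num_gen_batches attempts are exhausted.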
    def dynamic_multi_turn_loop(
        self,
        gen_batch: DataProto,
        actor_rollout_wg,
        envs: EnvironmentManagerBase,
    ) -> DataProto:
        """
        Conduct dynamic rollouts until a target batch size is met.
        Keeps sampling until the desired number of effective trajectories is collected.
        Adapted from DAPO (https://arxiv.org/abs/2503.14476).

        Args:
            gen_batch (DataProto): Initial batch for rollout.
            actor_rollout_wg: Actor model workers for generating responses.
            envs (EnvironmentManagerBase): Environment manager instance.

        Returns:
            total_batch_list (List[Dict]): Complete set of rollout steps.
            total_episode_rewards (np.ndarray): Accumulated rewards.
            total_episode_lengths (np.ndarray): Lengths per episode.
            total_success (Dict[str, np.ndarray]): Success metrics.
            total_traj_uid (np.ndarray): Trajectory IDs.
            total_tool_callings (np.ndarray): Tool-call counts per trajectory.
        """
        total_batch_list = []
        total_episode_rewards = []
        total_episode_lengths = []
        total_success = []
        total_traj_uid = []
        total_tool_callings = []
        try_count: int = 0
        max_try_count = self.config.algorithm.filter_groups.max_num_gen_batches
        target_size = self.config.data.train_batch_size * self.config.env.rollout.n

        while len(total_batch_list) < target_size and try_count < max_try_count:
            if len(total_batch_list) > 0:
                print(f"valid num={len(total_batch_list)} < target num={target_size}. Keep generating... ({try_count}/{max_try_count})")
            try_count += 1

            batch_list, episode_rewards, episode_lengths, success, traj_uid, tool_callings, _ = self.vanilla_multi_turn_loop(
                gen_batch=gen_batch,
                actor_rollout_wg=actor_rollout_wg,
                envs=envs,
            )

            batch_list, episode_rewards, episode_lengths, success, traj_uid, tool_callings = filter_group_data(
                batch_list=batch_list,
                episode_rewards=episode_rewards,
                episode_lengths=episode_lengths,
                success=success,
                traj_uid=traj_uid,
                tool_callings=tool_callings,
                config=self.config,
                last_try=(try_count == max_try_count),
            )

            total_batch_list += batch_list
            total_episode_rewards.append(episode_rewards)
            total_episode_lengths.append(episode_lengths)
            total_success.append(success)
            total_traj_uid.append(traj_uid)
            total_tool_callings.append(tool_callings)

        total_episode_rewards = np.concatenate(total_episode_rewards, axis=0)
        total_episode_lengths = np.concatenate(total_episode_lengths, axis=0)
        total_success = {key: np.concatenate([success[key] for success in total_success], axis=0) for key in total_success[0].keys()}
        total_traj_uid = np.concatenate(total_traj_uid, axis=0)
        total_tool_callings = np.concatenate(total_tool_callings, axis=0)

        return total_batch_list, total_episode_rewards, total_episode_lengths, total_success, total_traj_uid, total_tool_callings, {}

    def multi_turn_loop(
        self,
        gen_batch: DataProto,
        actor_rollout_wg,
        envs: EnvironmentManagerBase,
        is_train: bool = True,
    ) -> DataProto:
        """
        Select and run the appropriate rollout loop (dynamic or vanilla).

        Args:
            gen_batch (DataProto): Initial prompt batch.
            actor_rollout_wg: Actor model workers.
            envs (EnvironmentManagerBase): Environment manager for interaction.
            is_train (bool): Whether in training mode (affects dynamic sampling).

        Returns:
            DataProto: Final collected trajectory data with metadata.
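
        Example (illustrative; `collector`, `actor_wg`, and `envs` are assumed
        to be constructed by the training driver):

            out = collector.multi_turn_loop(gen_batch, actor_wg, envs, is_train=True)
            out.meta_info['rollout_timing']  # {'inference_s': ..., 'env_s': ..., 'preprocess_s': ...}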
""" if is_train: gen_batch = gen_batch.repeat(repeat_times=self.config.env.rollout.n, interleave=True) # Initial observations from the environment if self.config.algorithm.filter_groups.enable and is_train: # Dynamic Sampling (for DAPO and Dynamic GiGPO) total_batch_list, total_episode_rewards, total_episode_lengths, total_success, total_traj_uid, totoal_tool_callings, rollout_timing = \ self.dynamic_multi_turn_loop( gen_batch=gen_batch, actor_rollout_wg=actor_rollout_wg, envs=envs, ) else: # Vanilla Sampling total_batch_list, total_episode_rewards, total_episode_lengths, total_success, total_traj_uid, totoal_tool_callings, rollout_timing = \ self.vanilla_multi_turn_loop( gen_batch=gen_batch, actor_rollout_wg=actor_rollout_wg, envs=envs, ) assert len(total_batch_list) == len(total_episode_rewards) assert len(total_batch_list) == len(total_episode_lengths) assert len(total_batch_list) == len(total_traj_uid) assert len(total_batch_list) == len(totoal_tool_callings) # Create trajectory data gen_batch_output: DataProto = self.gather_rollout_data( total_batch_list=total_batch_list, episode_rewards=total_episode_rewards, episode_lengths=total_episode_lengths, success=total_success, traj_uid=total_traj_uid, tool_callings=totoal_tool_callings, ) if gen_batch_output.meta_info is None: gen_batch_output.meta_info = {} gen_batch_output.meta_info['rollout_timing'] = rollout_timing return gen_batch_output