| from transformers import LlavaForConditionalGeneration,PretrainedConfig |
| from configuration_bit_vla import Bitvla_Config |
| import numpy as np |
| import torch |
| from prismatic.vla.constants import ( |
| ACTION_DIM, |
| ACTION_PROPRIO_NORMALIZATION_TYPE, |
| NUM_ACTIONS_CHUNK, |
| NormalizationType, |
| ) |
| from typing import Optional, Dict, Any,List,Tuple |
|
|
| from transformers.models.llava.modeling_llava import LlavaCausalLMOutputWithPast |
|
|
| from prismatic.training.train_utils import ( |
| get_current_action_mask, |
| get_next_actions_mask, |
| ) |
|
|
|
|
class BitVLAForActionPrediction(LlavaForConditionalGeneration):
    """Llava model extended with proprio-feature injection and action-chunk prediction.

    `set_constant` must be called before `forward` or `predict_action`. The special
    token ids it stores (image / proprio placeholders, ignore index, first action
    token, stop token) are not set by `__init__`.
    """

    config_class = Bitvla_Config

    def __init__(self, config) -> None:
        super().__init__(config)
        # Per-dataset statistics consumed by `_unnormalize_actions` / `get_action_stats`.
        self.norm_stats = config.norm_stats

        # Uniform bins over [-1, 1]; decoded discrete action tokens map to bin centers.
        self.bins = np.linspace(-1, 1, config.n_action_bins)
        self.bin_centers = (self.bins[:-1] + self.bins[1:]) / 2.0

        self.vocab_size = self.config.vocab_size

    def set_constant(self, image_token_idx, proprio_pad_idx, ignore_idx, action_token_begin_idx, stop_index):
        """Store the special token ids used to locate image, proprio, action and stop positions.

        Args:
            image_token_idx: id of the image placeholder token in `input_ids`.
            proprio_pad_idx: id of the proprio placeholder token in `input_ids`.
            ignore_idx: label value marking positions that are not supervised.
            action_token_begin_idx: reference id for action tokens (passed to the
                action-mask helpers; `action_token_begin_idx + 1` is used as a
                dummy action label at inference).
            stop_index: id of the stop token appended after the action chunk.
        """
        self.image_token_idx = image_token_idx
        self.proprio_pad_idx = proprio_pad_idx
        self.action_token_begin_idx = action_token_begin_idx
        self.stop_index = stop_index
        self.ignore_idx = ignore_idx

    # ------------------------------------------------------------------ #
    # Embedding helpers (shared by `forward` and `predict_action`)
    # ------------------------------------------------------------------ #

    def _fill_placeholder_embeddings(self, input_ids, inputs_embeds, features, placeholder_idx, kind):
        """Write `features` (one row per placeholder token) into the placeholder slots of `inputs_embeds`.

        Args:
            input_ids: token ids used to locate `placeholder_idx` positions.
            inputs_embeds: token embeddings with the same leading dims as `input_ids`.
            features: 2-D tensor with one row per placeholder token.
            placeholder_idx: token id marking the slots to overwrite.
            kind: capitalized label used in the error message ("Image" / "Proprio").

        Returns:
            A new embeddings tensor with the placeholder slots replaced.

        Raises:
            ValueError: if the number of placeholder tokens differs from the number of feature rows.
        """
        placeholder_mask = input_ids == placeholder_idx
        n_tokens = placeholder_mask.sum().item()
        n_features = features.shape[0]
        if n_tokens != n_features:
            raise ValueError(
                f"{kind} features and {kind.lower()} tokens do not match: tokens: {n_tokens}, features {n_features}"
            )

        scatter_mask = placeholder_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
        features = features.to(inputs_embeds.device, inputs_embeds.dtype)
        return inputs_embeds.masked_scatter(scatter_mask, features)

    def _insert_image_features(
        self,
        input_ids,
        inputs_embeds,
        pixel_values,
        vision_feature_layer=None,
        vision_feature_select_strategy=None,
    ):
        """Encode `pixel_values` and place the image features at the image-token positions.

        Vision options left as None fall back to the values in `self.config`.
        """
        if vision_feature_layer is None:
            vision_feature_layer = self.config.vision_feature_layer
        if vision_feature_select_strategy is None:
            vision_feature_select_strategy = self.config.vision_feature_select_strategy

        # pixel_values is unpacked as (batch, num_images, C, H, W); fold the image axis into the batch axis.
        _, _, c, h, w = pixel_values.shape
        image_embeds = self.get_image_features(
            pixel_values=pixel_values.view(-1, c, h, w),
            vision_feature_layer=vision_feature_layer,
            vision_feature_select_strategy=vision_feature_select_strategy,
        )
        image_embeds = image_embeds.view(-1, image_embeds.shape[-1])
        return self._fill_placeholder_embeddings(
            input_ids, inputs_embeds, image_embeds, self.image_token_idx, "Image"
        )

    def _insert_proprio_features(self, input_ids, inputs_embeds, proprio, proprio_projector, batch_size):
        """Project `proprio` (flattened per sample) and place the features at the proprio-pad positions."""
        proprio_features = proprio_projector(proprio.reshape(batch_size, -1))
        proprio_features = proprio_features.view(-1, proprio_features.shape[-1])
        return self._fill_placeholder_embeddings(
            input_ids, inputs_embeds, proprio_features, self.proprio_pad_idx, "Proprio"
        )

    # ------------------------------------------------------------------ #
    # Training forward
    # ------------------------------------------------------------------ #

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_projector_features: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        proprio=None,
        proprio_projector=None,
        cache_position: Optional[torch.LongTensor] = None,
        vision_feature_layer=None,
        vision_feature_select_strategy=None,
    ) -> LlavaCausalLMOutputWithPast:
        """Run a multimodal forward pass and return a `LlavaCausalLMOutputWithPast`.

        Embeddings are always rebuilt from `input_ids`. Image features (and proprio
        features, when both `proprio` and `proprio_projector` are given) are scattered
        into their placeholder slots. When `labels` is given, embeddings at action
        positions (labels that are neither `ignore_idx` nor `stop_index`) are zeroed,
        and `labels` is passed through to the base Llava forward.

        `position_ids`, `inputs_embeds`, `output_attentions`, `output_hidden_states`,
        `output_projector_features`, `return_dict` and `cache_position` are accepted for
        signature compatibility but not used. The base call always requests hidden
        states and a dict-style output.

        Raises:
            ValueError: if `input_ids` or `pixel_values` is missing, if their batch sizes
                differ, or if placeholder-token counts do not match the feature counts.
            AssertionError: if `past_key_values` is provided.
        """
        use_cache = use_cache and not self.training

        if input_ids is None or pixel_values is None:
            raise ValueError(
                "Invalid PrismaticForConditionalGeneration `forward()` call with provided arguments:\n"
                f"=> `input_ids` = {input_ids is not None}\n"
                f"=> `attention_mask` = {attention_mask is not None}\n"
                f"=> `pixel_values` = {pixel_values is not None}\n"
                f"=> `labels` = {labels is not None}\n"
                f"=> `input_embeds` = {inputs_embeds is not None}\n"
                f"=> `past_key_values` = {past_key_values is not None}\n"
                f"=> `use_cache` = {use_cache}"
            )
        if input_ids.shape[0] != pixel_values.shape[0]:
            raise ValueError("Non-homogenous batch of (text, image) input -- forward() does not support mixed batches!")
        assert past_key_values is None, "Unexpected key `past_key_values` provided during multimodal forward!"

        batch_size = input_ids.shape[0]
        inputs_embeds = self.get_input_embeddings()(input_ids)
        inputs_embeds = self._insert_image_features(
            input_ids, inputs_embeds, pixel_values, vision_feature_layer, vision_feature_select_strategy
        )

        if proprio_projector is not None and proprio is not None:
            inputs_embeds = self._insert_proprio_features(
                input_ids, inputs_embeds, proprio, proprio_projector, batch_size
            )

        # Hide the (teacher-forced) action tokens from the model by zeroing their embeddings.
        # Without labels there are no known action positions, so nothing is zeroed.
        if labels is not None:
            action_positions = (labels != self.ignore_idx) & (labels != self.stop_index)
            inputs_embeds = inputs_embeds * ~action_positions.unsqueeze(-1)

        return LlavaForConditionalGeneration.forward(
            self,
            input_ids=None,
            attention_mask=attention_mask,
            position_ids=None,
            pixel_values=None,
            labels=labels,
            inputs_embeds=inputs_embeds,
            past_key_values=None,
            use_cache=None,
            output_attentions=False,
            output_hidden_states=True,
            return_dict=True,
        )

    # ------------------------------------------------------------------ #
    # Inference helpers
    # ------------------------------------------------------------------ #

    def _prepare_input_for_action_prediction(self, input_ids, attention_mask):
        """Append action placeholders and a stop token to every prompt and extend the attention mask.

        Adds `ACTION_DIM * NUM_ACTIONS_CHUNK` placeholder ids (value 1) followed by one
        `self.stop_index` token per row. The new positions get attention-mask value 1.

        Returns:
            (input_ids, attention_mask) with the extra positions appended.
        """
        batch_size = input_ids.shape[0]
        placeholder_action_token_ids = torch.ones(
            (batch_size, ACTION_DIM * NUM_ACTIONS_CHUNK), dtype=input_ids.dtype, device=input_ids.device
        )
        stop_token_ids = torch.full((batch_size, 1), self.stop_index, dtype=input_ids.dtype, device=input_ids.device)
        input_ids = torch.cat([input_ids, placeholder_action_token_ids, stop_token_ids], dim=-1)

        num_new_positions = input_ids.shape[-1] - attention_mask.shape[-1]
        mask_extension = torch.ones(
            (attention_mask.shape[0], num_new_positions), dtype=attention_mask.dtype, device=attention_mask.device
        )
        attention_mask = torch.cat([attention_mask, mask_extension], dim=-1)
        return input_ids, attention_mask

    def _prepare_labels_for_action_prediction(self, labels, input_ids):
        """Extend `labels` to the length of `input_ids` with dummy action labels and a final stop label.

        The appended positions are labelled `action_token_begin_idx + 1` so the
        action-mask helpers recognise them. The last column is then set to `stop_index`.
        """
        arbitrary_action_token_idx = self.action_token_begin_idx + 1
        labels_extension = torch.full(
            (labels.shape[0], input_ids.shape[-1] - labels.shape[-1]),
            arbitrary_action_token_idx,
            dtype=labels.dtype,
            device=labels.device,
        )
        labels = torch.cat([labels, labels_extension], dim=-1)
        labels[:, -1] = self.stop_index
        return labels

    def _process_action_masks(self, labels):
        """Return the boolean mask of current-action OR next-actions positions derived from `labels`."""
        current_action_mask = get_current_action_mask(
            labels, ignore_index=self.ignore_idx, action_token_begin_idx=self.action_token_begin_idx
        )
        next_actions_mask = get_next_actions_mask(
            labels, ignore_index=self.ignore_idx, action_token_begin_idx=self.action_token_begin_idx
        )
        return current_action_mask | next_actions_mask

    def _unnormalize_actions(self, normalized_actions, unnorm_key=None):
        """Map normalized actions back to dataset units.

        Applies `0.5 * (x + 1) * (high - low + 1e-8) + low`, which sends -1 to `low` and
        1 to `high`. The bounds are min/max or q01/q99 depending on
        `ACTION_PROPRIO_NORMALIZATION_TYPE`. Dimensions where the optional "mask" entry
        of the statistics is False are returned unchanged.

        Raises:
            ValueError: for an unsupported normalization type.
        """
        action_norm_stats = self.get_action_stats(unnorm_key)

        if ACTION_PROPRIO_NORMALIZATION_TYPE == NormalizationType.BOUNDS:
            high_key, low_key = "max", "min"
        elif ACTION_PROPRIO_NORMALIZATION_TYPE == NormalizationType.BOUNDS_Q99:
            high_key, low_key = "q99", "q01"
        else:
            raise ValueError("Unsupported action/proprio normalization type detected!")

        mask = action_norm_stats.get("mask", np.ones_like(action_norm_stats[low_key], dtype=bool))
        action_high = np.array(action_norm_stats[high_key])
        action_low = np.array(action_norm_stats[low_key])

        return np.where(
            mask,
            0.5 * (normalized_actions + 1) * (action_high - action_low + 1e-8) + action_low,
            normalized_actions,
        )

    def _regression_or_discrete_prediction(
        self,
        input_ids,
        input_embeddings,
        all_actions_mask,
        attention_mask,
        labels,
        action_head=None,
        pixel_values=None,
    ):
        """Predict one normalized action chunk.

        If `action_head` is given, its `predict_action` is applied to the action hidden
        states (continuous regression). Otherwise the argmax action tokens are decoded
        to bin centers.

        `input_ids` and `pixel_values` are not used: images and proprio are expected to
        be merged into `input_embeddings` already. They are kept for call compatibility.

        Returns:
            (normalized_actions, actions_hidden_states): a numpy array reshaped to
            (NUM_ACTIONS_CHUNK, ACTION_DIM), and the selected last-layer hidden states
            with a leading dim of 1. The reshape assumes a single-sample batch.
        """
        # Zero the placeholder action embeddings so the model cannot read them.
        input_embeddings = input_embeddings * ~all_actions_mask.unsqueeze(-1)

        llava_output = LlavaForConditionalGeneration.forward(
            self,
            input_ids=None,
            attention_mask=attention_mask,
            position_ids=None,
            pixel_values=None,
            labels=None,
            inputs_embeds=input_embeddings,
            past_key_values=None,
            use_cache=None,
            output_attentions=False,
            output_hidden_states=True,
            return_dict=True,
        )

        # Output position t predicts token t+1: build the mask from labels[:, 1:] and
        # drop the last output position so both are (B, L-1).
        shifted_mask = self._process_action_masks(labels[:, 1:]).squeeze(-1)

        last_hidden_states = llava_output.hidden_states[-1][:, :-1, :]
        actions_hidden_states = last_hidden_states[shifted_mask].unsqueeze(0)

        if action_head is not None:
            normalized_actions = action_head.predict_action(actions_hidden_states)
            normalized_actions = normalized_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM)
            normalized_actions = normalized_actions.float().cpu().detach().numpy()
        else:
            # The logits need the same one-position shift as the hidden states. Indexing
            # the full (B, L, V) logits with the (B, L-1) mask raises an IndexError.
            action_logits = llava_output.logits[:, :-1, :][shifted_mask].unsqueeze(0)
            predicted_action_token_ids = action_logits.argmax(dim=2).cpu().numpy()

            # Action tokens are counted back from the end of the vocabulary.
            discretized_actions = self.vocab_size - predicted_action_token_ids
            discretized_actions = np.clip(discretized_actions - 1, a_min=0, a_max=self.bin_centers.shape[0] - 1)
            normalized_actions = self.bin_centers[discretized_actions]
            normalized_actions = normalized_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM)

        return normalized_actions, actions_hidden_states

    def predict_action(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        unnorm_key: Optional[str] = None,
        proprio=None,
        proprio_projector=None,
        action_head=None,
        vision_feature_layer=None,
        vision_feature_select_strategy=None,
        **kwargs: Any,
    ) -> Tuple[np.ndarray, torch.Tensor]:
        """Predict an unnormalized action chunk from a prompt.

        Args:
            input_ids: prompt token ids. Action placeholders and a stop token are appended internally.
            unnorm_key: key into `self.norm_stats` selecting the unnormalization statistics.
            proprio: proprioceptive state (array-like or tensor). Used only together with `proprio_projector`.
            proprio_projector: module mapping flattened proprio to proprio-token features.
            action_head: optional head for continuous (L1 regression) prediction. When
                None, discrete action tokens are decoded instead.
            vision_feature_layer / vision_feature_select_strategy: fall back to `self.config` when None.
            **kwargs: must contain `pixel_values` (may be None) and `attention_mask`.

        Returns:
            Tuple of (unnormalized_actions, action_hidden_states).
        """
        pixel_values = kwargs["pixel_values"]
        attention_mask = kwargs["attention_mask"]

        # Prompt positions are unsupervised; only the appended action/stop positions get labels.
        labels = torch.full_like(input_ids, self.ignore_idx)

        input_ids, attention_mask = self._prepare_input_for_action_prediction(input_ids, attention_mask)
        labels = self._prepare_labels_for_action_prediction(labels, input_ids)

        input_embeddings = self.get_input_embeddings()(input_ids)
        all_actions_mask = self._process_action_masks(labels)

        if pixel_values is not None:
            input_embeddings = self._insert_image_features(
                input_ids, input_embeddings, pixel_values, vision_feature_layer, vision_feature_select_strategy
            )

        if proprio_projector is not None and proprio is not None:
            # as_tensor accepts numpy arrays, lists and tensors on any device.
            proprio = torch.as_tensor(proprio, device=input_embeddings.device, dtype=input_embeddings.dtype)
            input_embeddings = self._insert_proprio_features(
                input_ids, input_embeddings, proprio, proprio_projector, input_ids.shape[0]
            )

        normalized_actions, actions_hidden_states = self._regression_or_discrete_prediction(
            input_ids,
            input_embeddings,
            all_actions_mask,
            attention_mask,
            labels,
            action_head,
            pixel_values,
        )

        actions = self._unnormalize_actions(normalized_actions, unnorm_key)
        return actions, actions_hidden_states

    # ------------------------------------------------------------------ #
    # Normalization statistics
    # ------------------------------------------------------------------ #

    @staticmethod
    def _check_unnorm_key(norm_stats: Dict[str, Dict[str, Any]], unnorm_key: Optional[str]) -> str:
        """Resolve `unnorm_key`, defaulting to the only key when `norm_stats` has exactly one entry.

        Raises:
            AssertionError: if the key is None while several datasets are present, or if
                the key is not in `norm_stats`.
        """
        if unnorm_key is None:
            assert len(norm_stats) == 1, (
                f"Your model was trained on more than one dataset, "
                f"please pass a `unnorm_key` from the following options to choose the statistics "
                f"used for un-normalizing actions: {norm_stats.keys()}"
            )
            unnorm_key = next(iter(norm_stats.keys()))

        assert unnorm_key in norm_stats, (
            f"The `unnorm_key` you chose is not in the set of available dataset statistics, "
            f"please choose from: {norm_stats.keys()}"
        )
        return unnorm_key

    def get_action_dim(self, unnorm_key: Optional[str] = None) -> int:
        """Get the dimensionality of the policy's action space (length of the "min" statistic)."""
        return len(self.get_action_stats(unnorm_key)["min"])

    def get_action_stats(self, unnorm_key: Optional[str] = None) -> Dict[str, Any]:
        """Get the logged action statistics for the given dataset."""
        unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key)
        return self.norm_stats[unnorm_key]["action"]
| |