| """PyTorch VideoLLaMA3 model.""" |
|
|
import importlib.util
import os.path as osp
import re
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple, Union
|
|
import torch
import torch.nn as nn
import torch.utils.checkpoint
|
|
from transformers import AutoModel, Qwen2ForCausalLM, Qwen2Model
from transformers.generation.utils import GenerateOutput
from transformers.modeling_outputs import CausalLMOutputWithPast
|
|
try:
    from .configuration_videollama3 import Videollama3Qwen2Config
except ModuleNotFoundError:
    spec = importlib.util.spec_from_file_location(
        "configuration_videollama3",
        osp.join(osp.dirname(__file__), "configuration_videollama3.py"),
    )
    configuration_videollama3 = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(configuration_videollama3)
    Videollama3Qwen2Config = getattr(
        configuration_videollama3,
        "Videollama3Qwen2Config",
    )
|
|
|
|
def build_mlp(depth, hidden_size, output_hidden_size):
    modules = [nn.Linear(hidden_size, output_hidden_size)]
    for _ in range(1, depth):
        modules.append(nn.GELU())
        modules.append(nn.Linear(output_hidden_size, output_hidden_size))
    return nn.Sequential(*modules)
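# For illustration: with depth=2, build_mlp(2, hidden_size, output_hidden_size)
# expands to
#     nn.Sequential(
#         nn.Linear(hidden_size, output_hidden_size),
#         nn.GELU(),
#         nn.Linear(output_hidden_size, output_hidden_size),
#     )
# i.e. `depth` counts the Linear layers, with a GELU between consecutive ones.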
|
|
|
|
def build_vision_projector(config, delay_load=False, **kwargs):
    projector_type = getattr(config, 'mm_projector_type', 'linear')
    if projector_type == "linear":
        return nn.Linear(config.mm_hidden_size, config.hidden_size)
    elif projector_type.startswith("mlp"):
        return MlpGeluProjector(config, projector_type)
    else:
        raise ValueError(f'Unknown projector type: {projector_type}')
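# For illustration: the projector variant is selected by the config field
# `mm_projector_type`, e.g. "linear" or "mlp2x_gelu" (the latter matches
# MlpGeluProjector's regex and builds a 2-layer GELU MLP). These example values
# are illustrative; released checkpoints may use other settings.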
|
|
|
|
class MlpGeluProjector(nn.Module):

    def __init__(self, config, projector_type):
        super().__init__()

        mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", projector_type)
        if mlp_gelu_match is None:
            raise ValueError(f"Invalid MLP projector type: {projector_type}")
        mlp_depth = int(mlp_gelu_match.group(1))

        self.readout = build_mlp(mlp_depth, config.vision_encoder_config.hidden_size, config.hidden_size)

    def forward(self, x):
        x = self.readout(x)
        return x
|
|
|
|
class Videollama3MetaModel:

    def __init__(self, config):
        super(Videollama3MetaModel, self).__init__(config)
        if config.vision_encoder is not None:
            self.vision_encoder = AutoModel.from_pretrained(
                config.vision_encoder,
                attn_implementation=self.config._attn_implementation,
                torch_dtype=self.dtype,
            )
            self.config.vision_encoder_config = self.vision_encoder.config
            self.config.vision_encoder = None
        elif config.vision_encoder_config is not None:
            self.vision_encoder = AutoModel.from_config(
                self.config.vision_encoder_config,
                attn_implementation=self.config._attn_implementation,
                torch_dtype=self.dtype,
            )
        else:
            raise ValueError("Vision encoder is not provided in config")
        self.mm_projector = build_vision_projector(config)

    def get_vision_encoder(self):
        return self.vision_encoder

    def get_mm_projector(self):
        return self.mm_projector
|
|
|
|
class Videollama3Qwen2Model(Videollama3MetaModel, Qwen2Model):

    config_class = Videollama3Qwen2Config

    def __init__(self, config: Videollama3Qwen2Config):
        super(Videollama3Qwen2Model, self).__init__(config)
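# For clarity: Videollama3MetaModel is used as a mixin here. Its __init__ first
# calls super().__init__(config), which under this MRO runs Qwen2Model.__init__,
# and only then attaches the vision encoder and mm_projector on top of the
# language model.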
|
|
|
|
class Videollama3MetaForCausalLM(ABC):

    @abstractmethod
    def get_model(self):
        pass

    def get_vision_encoder(self):
        return self.get_model().get_vision_encoder()

    def get_mm_projector(self):
        return self.get_model().get_mm_projector()
|
|
    def encode_images(
        self,
        pixel_values: torch.FloatTensor,
        grid_sizes: torch.LongTensor,
        merge_sizes: torch.LongTensor,
    ) -> torch.FloatTensor:
        mm_features = self.get_model().get_vision_encoder()(
            pixel_values=pixel_values,
            grid_sizes=grid_sizes,
            merge_sizes=merge_sizes,
        )
        mm_features = self.get_model().mm_projector(mm_features)
        return mm_features
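    # Shape sketch (an assumption based on how the features are consumed below):
    # the vision encoder emits one feature per merged patch, so after the
    # projector mm_features is (total_merged_patches, hidden_size) and can be
    # scattered directly into the image-token positions of the embeddings.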
|
|
    def _get_valid_visual_tokens(
        self,
        mm_features: torch.FloatTensor,
        batched_num_patches: torch.LongTensor,
        modals: List[str],
    ):
        valid_masks = []
        for num_patches, modal in zip(batched_num_patches, modals):
            valid_mask = torch.full((num_patches,), modal != "text", dtype=torch.bool, device=mm_features.device)
            valid_masks.append(valid_mask)
        mm_features = mm_features[torch.cat(valid_masks)]
        return mm_features
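    # For clarity: entries whose modal is "text" presumably carry placeholder
    # pixels only (e.g. for batching); their features are masked out here so
    # that only genuine image/video features remain.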
|
|
    def _maybe_truncate_visual_tokens(
        self,
        mm_features: torch.FloatTensor,
        compression_mask: torch.BoolTensor,
        batched_num_patches: torch.LongTensor,
        modals: List[str],
        input_ids: torch.LongTensor,
        position_ids: Optional[torch.LongTensor] = None,
    ):
        if position_ids is None or mm_features.shape[0] == input_ids.eq(self.config.image_token_index).sum():
            return mm_features, compression_mask

        truncation_mask = []
        for num_patches, modal in zip(batched_num_patches, modals):
            if modal == "text":
                truncation_mask.append(torch.ones((0,), dtype=torch.bool, device=input_ids.device))
            else:
                truncation_mask.append(torch.ones((num_patches,), dtype=torch.bool, device=input_ids.device))

        seq_end_indices = torch.nonzero(position_ids == 0)[:, 0]
        seq_end_indices = seq_end_indices[seq_end_indices > 0].tolist() + [len(input_ids)]
        seq_start_indices = [0] + seq_end_indices[:-1]
        num_visual_tokens = [
            input_ids[start:end].eq(self.config.image_token_index).sum()
            for start, end in zip(seq_start_indices, seq_end_indices)
        ]

        for n, mask in zip(num_visual_tokens, truncation_mask):
            if len(mask) > 0:
                mask[n:] = False
        truncation_mask = torch.cat(truncation_mask)

        return mm_features[truncation_mask], compression_mask[truncation_mask]
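    # For clarity: truncation only happens when the number of visual features
    # does not match the number of image-placeholder tokens in input_ids
    # (e.g. a packed or trimmed sequence that no longer contains all of its
    # visual tokens); positions equal to 0 mark sub-sequence boundaries.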
|
|
    def _get_compression_mask(
        self,
        pixel_values: torch.FloatTensor,
        batched_num_patches: torch.LongTensor,
        grid_sizes: torch.LongTensor,
        merge_sizes: torch.LongTensor,
        modals: List[str],
        threshold: float = 0.1,
        min_tokens: int = 1,
    ) -> torch.BoolTensor:
        batched_images = pixel_values.split(grid_sizes.prod(dim=1).tolist(), dim=0)
        compression_masks = []

        for images, num_patches, grid_size, merge_size, modal in zip(
            batched_images, batched_num_patches, grid_sizes, merge_sizes, modals
        ):
            t, h, w = grid_size
            if modal == "image" or (modal == "video" and t == 1):
                compression_masks.append(torch.ones((num_patches,), dtype=torch.bool, device=images.device))

            elif modal == "video":
                # regroup flattened patches into (t, merged_tokens_per_frame, -1)
                images = images.view(t, (h // merge_size) * (w // merge_size), -1)

                # mean absolute change between consecutive frames, per merged token
                pixel_diff = images[1:] - images[:-1]
                pixel_diff = torch.abs(pixel_diff).mean(dim=-1) * 255
                pixel_diff = torch.cat([torch.full_like(pixel_diff[0:1], threshold + 1), pixel_diff], dim=0)
                mask = pixel_diff > threshold
                padding_ids = torch.nonzero(mask.sum(dim=1) < min_tokens)[:, 0]
                # keep at least `min_tokens` tokens in frames where everything fell below the threshold
                mask[padding_ids, :min_tokens] = 1
                compression_masks.append(mask.flatten())

            else:
                # "text" entries contribute no visual tokens
                compression_masks.append(torch.ones((0,), dtype=torch.bool, device=images.device))

        return torch.cat(compression_masks)
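    # Worked example with assumed values: for a video with t=3 frames and
    # 4 merged tokens per frame, pixel_diff has shape (3, 4). Frame 0 is padded
    # with threshold + 1 so it is always kept in full; a token in a later frame
    # survives only if its mean absolute pixel change (scaled by 255) exceeds
    # `threshold`, and each frame keeps at least `min_tokens` tokens.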
|
|
    def _compress_visual_tokens(
        self,
        compression_mask: torch.BoolTensor,
        mm_features: torch.FloatTensor,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
    ):
        mm_features = mm_features[compression_mask]
        image_selected = (input_ids == self.config.image_token_index)

        text_masks = torch.logical_not(image_selected)
        text_masks[image_selected] = compression_mask
        input_ids = input_ids[text_masks]

        if attention_mask is not None:
            attention_mask = attention_mask[text_masks]
        if labels is not None:
            labels = labels[text_masks]
        if position_ids is not None:
            # rebuild per-sub-sequence position ids after dropping compressed tokens
            position_ids = position_ids[text_masks]
            pos_start = [0] + torch.nonzero(position_ids == 0)[:, 0].tolist()
            pos_end = pos_start[1:] + [len(input_ids)]
            position_ids = torch.cat([torch.arange(end - start, device=input_ids.device) for start, end in zip(pos_start, pos_end)])

        return mm_features, input_ids, attention_mask, position_ids, labels
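    # For clarity: position ids are rebuilt as fresh 0..len-1 ranges per
    # sub-sequence (boundaries are the surviving positions equal to 0), so
    # packed sequences stay consistent with the shortened token stream.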
|
|
    def prepare_inputs_labels_for_multimodal(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        labels: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        grid_sizes: Optional[torch.LongTensor] = None,
        merge_sizes: Optional[torch.LongTensor] = None,
        modals: Optional[List[str]] = None,
    ):
        vision_encoder = self.get_vision_encoder()
        # nothing to do for text-only inputs or single-token decoding steps
        if vision_encoder is None or pixel_values is None or input_ids.shape[1] == 1:
            return input_ids, attention_mask, position_ids, past_key_values, None, labels

        # flatten the batch so visual tokens can be scattered with boolean masks
        B, N = input_ids.shape
        input_ids = input_ids.view(B * N)
        if attention_mask is not None:
            attention_mask = attention_mask.view(B * N)
        if position_ids is not None:
            position_ids = position_ids.view(B * N)
        if labels is not None:
            labels = labels.view(B * N)

        # encode visual inputs and keep only features from non-text modals
        batched_num_patches = grid_sizes.prod(dim=1).div(merge_sizes ** 2).long()
        mm_features = self.encode_images(pixel_values, grid_sizes, merge_sizes).to(input_ids.device)
        mm_features = self._get_valid_visual_tokens(mm_features, batched_num_patches, modals)

        compression_mask = self._get_compression_mask(
            pixel_values, batched_num_patches, grid_sizes, merge_sizes, modals
        )
        mm_features, compression_mask = self._maybe_truncate_visual_tokens(
            mm_features, compression_mask, batched_num_patches, modals, input_ids, position_ids
        )

        if self.config.use_token_compression:
            assert B == 1, "Token compression is only supported for batch_size=1"
            mm_features, input_ids, attention_mask, position_ids, labels = self._compress_visual_tokens(
                compression_mask, mm_features, input_ids, attention_mask, position_ids, labels
            )

        inputs_embeds = self.get_model().embed_tokens(input_ids).clone()

        # overwrite the placeholder image-token embeddings with projected visual features
        image_selected = (input_ids == self.config.image_token_index)
        inputs_embeds[image_selected] = inputs_embeds[image_selected] * 0.0 + mm_features

        # restore the batch dimension
        C = inputs_embeds.shape[-1]
        inputs_embeds = inputs_embeds.reshape(B, -1, C)
        if attention_mask is not None:
            attention_mask = attention_mask.view(B, -1)
        if labels is not None:
            labels = labels.view(B, -1)
        if position_ids is not None:
            position_ids = position_ids.view(B, -1)

        return None, attention_mask, position_ids, past_key_values, inputs_embeds, labels
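# For clarity on the return contract: once visual inputs are consumed the method
# returns input_ids=None together with fully formed inputs_embeds, so the caller
# forwards embeddings rather than token ids to the language model; the text-only
# and single-token decoding path instead returns early with inputs_embeds=None.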
|
|
|
|
class Videollama3Qwen2ForCausalLM(Qwen2ForCausalLM, Videollama3MetaForCausalLM):

    config_class = Videollama3Qwen2Config

    def __init__(self, config, **kwargs):
        super(Qwen2ForCausalLM, self).__init__(config)
        self.model = Videollama3Qwen2Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        return self.model
|
|
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
        # multimodal inputs
        pixel_values: Optional[torch.FloatTensor] = None,
        grid_sizes: Optional[torch.LongTensor] = None,
        merge_sizes: Optional[torch.LongTensor] = None,
        modals: Optional[List[str]] = None,
        **loss_kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        if inputs_embeds is None:
            (
                input_ids,
                attention_mask,
                position_ids,
                past_key_values,
                inputs_embeds,
                labels,
            ) = self.prepare_inputs_labels_for_multimodal(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                labels=labels,
                pixel_values=pixel_values,
                grid_sizes=grid_sizes,
                merge_sizes=merge_sizes,
                modals=modals,
            )

        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            num_logits_to_keep=num_logits_to_keep,
            **loss_kwargs,
        )
|
|
    @torch.no_grad()
    def generate(
        self,
        # multimodal inputs
        pixel_values: Optional[torch.FloatTensor] = None,
        grid_sizes: Optional[torch.LongTensor] = None,
        merge_sizes: Optional[torch.LongTensor] = None,
        modals: Optional[List[str]] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        input_ids = kwargs.pop("input_ids", None)
        attention_mask = kwargs.pop("attention_mask", None)
        position_ids = kwargs.pop("position_ids", None)
        past_key_values = kwargs.pop("past_key_values", None)

        if "inputs_embeds" in kwargs:
            raise NotImplementedError("`inputs_embeds` is not supported")

        if pixel_values is not None:
            (
                input_ids,
                attention_mask,
                position_ids,
                past_key_values,
                inputs_embeds,
                labels,
            ) = self.prepare_inputs_labels_for_multimodal(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                labels=None,
                pixel_values=pixel_values,
                grid_sizes=grid_sizes,
                merge_sizes=merge_sizes,
                modals=modals,
            )
        else:
            inputs_embeds = self.get_model().embed_tokens(input_ids)

        return super().generate(
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            **kwargs,
        )
|
|
    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
        images = kwargs.pop("images", None)
        _inputs = super().prepare_inputs_for_generation(
            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
        )
        if images is not None:
            _inputs['images'] = images
        return _inputs
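# Minimal usage sketch (assumptions: the checkpoint path is hypothetical and the
# preprocessing that produces input_ids, attention_mask, pixel_values, grid_sizes,
# merge_sizes and modals is done by a matching processor; this mirrors the
# forward()/generate() signatures above rather than an official example):
#
#   import torch
#   model = Videollama3Qwen2ForCausalLM.from_pretrained(
#       "path/to/videollama3-checkpoint", torch_dtype=torch.bfloat16
#   ).eval()
#   output_ids = model.generate(**inputs, max_new_tokens=128, do_sample=False)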
|
|