import math
from fractions import Fraction
from typing import Optional, Union

import numpy as np
import torch
from torch import nn

import transformers
from transformers import (
    AutoModel,
    LlavaNextForConditionalGeneration,
)

# Major-version switch: masking helpers, cache_position plumbing and decoder-layer
# return types differ between transformers v4 and v5 (see the branches below).
_V5 = int(transformers.__version__.split(".")[0]) >= 5

if _V5:
    from transformers.masking_utils import create_causal_mask

from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.models.granitemoehybrid.modeling_granitemoehybrid import (
    HybridMambaAttentionDynamicCache,
)
from transformers.models.llava_next.modeling_llava_next import (
    LlavaNextCausalLMOutputWithPast,
    LlavaNextModelOutputWithPast,
    LlavaNextPreTrainedModel,
    get_anyres_image_grid_shape,
    image_size_to_num_patches,
    unpad_image,
)
from transformers.processing_utils import Unpack
from transformers.utils import TransformersKwargs, can_return_tuple, logging

from .configuration import Granite4VisionConfig
from .downsampling import WindowQFormerDownsampler

logger = logging.get_logger(__name__)


class Granite4VisionForConditionalGeneration(LlavaNextForConditionalGeneration):
    """Causal-LM head on top of :class:`Granite4VisionModel`.

    Final hidden states are projected with ``lm_head`` and divided by
    ``config.text_config.logits_scaling`` to produce logits.
    """

    config_class = Granite4VisionConfig

    def __init__(self, config: Granite4VisionConfig):
        # Deliberately skip LlavaNextForConditionalGeneration.__init__ so that the
        # Granite-specific inner model is built instead of the stock LLaVA-NeXT one.
        LlavaNextPreTrainedModel.__init__(self, config)
        self.model = Granite4VisionModel(config)
        self.lm_head = nn.Linear(
            config.text_config.hidden_size, config.text_config.vocab_size, bias=False
        )
        self.post_init()

    def merge_lora_adapters(self):
        """Merge LoRA adapter weights into base weights in-place and disable adapter toggling.

        Returns:
            ``self``, to allow chaining.
        """
        from peft.tuners.tuners_utils import BaseTunerLayer

        for module in self.modules():
            if isinstance(module, BaseTunerLayer):
                module.merge()
        # After merging, generate() must no longer enable/disable adapters.
        self._hf_peft_config_loaded = False
        return self

    def generate(self, *args, **kwargs) -> torch.LongTensor:
        """Run generation, toggling a loaded LoRA adapter based on the presence of images."""
        # When loaded with a LoRA adapter, disable the adapter for text-only
        # inputs (no pixel_values) so the base LLM runs standalone.
        pixel_values = kwargs.get("pixel_values", None)
        if getattr(self, "_hf_peft_config_loaded", False):
            if pixel_values is not None:
                self.enable_adapters()
            else:
                self.disable_adapters()
        return super().generate(*args, **kwargs)

    @can_return_tuple
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        vision_feature_layer: Optional[Union[int, list[int]]] = None,
        vision_feature_select_strategy: Optional[str] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[tuple, LlavaNextCausalLMOutputWithPast]:
        """Run the multimodal model and compute (optionally) the LM loss.

        ``logits_to_keep``: a positive int keeps only the last ``logits_to_keep``
        positions of the returned logits; ``0`` keeps all. Non-int values are
        currently not applied (full logits are returned).
        When ``labels`` is given, the loss is computed on the full-sequence logits
        before any slicing.
        """
        cache_position = kwargs.pop("cache_position", None)
        output_attentions = (
            output_attentions if output_attentions is not None else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        vision_feature_layer = (
            vision_feature_layer
            if vision_feature_layer is not None
            else self.config.vision_feature_layer
        )
        vision_feature_select_strategy = (
            vision_feature_select_strategy
            if vision_feature_select_strategy is not None
            else self.config.vision_feature_select_strategy
        )

        model_kwargs = dict(
            pixel_values=pixel_values,
            image_sizes=image_sizes,
            vision_feature_layer=vision_feature_layer,
            vision_feature_select_strategy=vision_feature_select_strategy,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )
        if not _V5:
            # v4 code paths still thread an explicit cache_position through the model.
            model_kwargs["cache_position"] = cache_position

        outputs = self.model(input_ids, **model_kwargs, **kwargs)
        hidden_states = outputs.last_hidden_state

        keep_last = isinstance(logits_to_keep, int) and logits_to_keep > 0
        if keep_last and labels is None:
            # Only the trailing positions are returned, so project just those onto the
            # vocabulary instead of materialising full-sequence logits (large on
            # long multimodal prefills) and discarding most of them.
            hidden_states = hidden_states[:, -logits_to_keep:, :]

        logits = self.lm_head(hidden_states)
        logits = logits / self.config.text_config.logits_scaling

        loss = None
        if labels is not None:
            # The loss needs every position, so slicing happens after it is computed.
            loss = self.loss_function(
                logits,
                labels,
                vocab_size=self.config.text_config.vocab_size,
                **kwargs,
            )
            if keep_last:
                logits = logits[:, -logits_to_keep:, :]

        return LlavaNextCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            image_hidden_states=outputs.image_hidden_states,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        inputs_embeds=None,
        pixel_values=None,
        image_sizes=None,
        attention_mask=None,
        cache_position=None,
        logits_to_keep=None,
        **kwargs,
    ):
        """Build per-step model inputs; images are only forwarded on the first step."""
        if _V5:
            is_first = kwargs.get("is_first_iteration", False)
            model_inputs = super().prepare_inputs_for_generation(
                input_ids,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                logits_to_keep=logits_to_keep,
                **kwargs,
            )
        else:
            is_first = cache_position[0] == 0 if cache_position is not None else True
            model_inputs = super().prepare_inputs_for_generation(
                input_ids,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                cache_position=cache_position,
                logits_to_keep=logits_to_keep,
                **kwargs,
            )
        model_inputs = self._init_hybrid_cache(**model_inputs)

        # Image features are injected once during prefill; later decode steps only
        # see the cached text/vision states.
        if is_first:
            model_inputs["pixel_values"] = pixel_values
            model_inputs["image_sizes"] = image_sizes
        return model_inputs

    def _init_hybrid_cache(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        **kwargs,
    ):
        """Handle HybridMambaAttentionDynamicCache for GraniteMoeHybrid language model.

        Creates a fresh hybrid cache when the incoming one is missing or an empty
        ``DynamicCache``, trims ``input_ids`` to unprocessed tokens (v4 only), and
        derives ``position_ids`` from ``attention_mask`` when not provided.
        Extra keyword arguments are passed through unless they collide with keys
        already set here.
        """
        empty_past_kv = past_key_values is None or (
            isinstance(past_key_values, DynamicCache) and past_key_values.get_seq_length() == 0
        )

        if not empty_past_kv and not _V5:
            # Keep only the tokens not yet processed, as indicated by cache_position.
            if inputs_embeds is not None or cache_position[-1] >= input_ids.shape[1]:
                input_ids = input_ids[:, -cache_position.shape[0] :]
            elif input_ids.shape[1] != cache_position.shape[0]:
                input_ids = input_ids[:, cache_position]
        elif use_cache and empty_past_kv:
            past_key_values = HybridMambaAttentionDynamicCache(
                self.model.language_model.config, input_ids.shape[0], self.dtype, device=self.device
            )

        if attention_mask is not None and position_ids is None:
            # Left-padding aware positions: count real tokens; padded slots get 1.
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if not empty_past_kv:
                position_ids = position_ids[:, -input_ids.shape[1] :]

        # inputs_embeds are only used on the first step (empty cache).
        if inputs_embeds is not None and empty_past_kv:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids.contiguous()}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": use_cache,
                "attention_mask": attention_mask,
            }
        )
        if not _V5:
            model_inputs["cache_position"] = cache_position

        for key, value in kwargs.items():
            if key not in model_inputs:
                model_inputs[key] = value
        return model_inputs


class Granite4VisionModel(LlavaNextPreTrainedModel):
    """Vision tower + language model with layer-wise ("deepstack") vision injection.

    Image features from several vision-tower layers are projected by
    ``WindowQFormerDownsampler`` modules and added to the hidden states at the
    image-token positions of configured language-model layers.
    """

    config_class = Granite4VisionConfig

    def __init__(self, config: Granite4VisionConfig):
        super().__init__(config)
        self.vision_tower = AutoModel.from_config(config.vision_config)
        self.spatial_projectors = None

        # Explicit checks instead of `assert`, which is stripped under `python -O`.
        if config.deepstack_layer_map is None:
            raise ValueError("Granite4VisionConfig.deepstack_layer_map must be set.")
        if config.downsample_rate is None:
            raise ValueError("Granite4VisionConfig.downsample_rate must be set.")
        self.downsample_rate = config.downsample_rate

        # Deepstack projectors: one per (vision_layer, llm_layer) pair
        self.layerwise_projectors = nn.ModuleList(
            [WindowQFormerDownsampler(config) for _ in range(len(config.deepstack_layer_map))]
        )

        # Spatial sampling projectors: 4 offset groups (TL, TR, BL, BR)
        if config.use_spatial_sampling:
            self.spatial_projectors = nn.ModuleList(
                [WindowQFormerDownsampler(config, spatial_offset=i) for i in range(4)]
            )

        self.image_newline = None
        if config.use_image_newline_parameter:
            embed_std = 1 / math.sqrt(config.text_config.hidden_size)
            self.image_newline = nn.Parameter(
                torch.randn(config.text_config.hidden_size, dtype=self.dtype) * embed_std
            )

        self.vocab_size = config.text_config.vocab_size
        self.language_model = AutoModel.from_config(config.text_config)
        # `x or -1` would wrongly turn a valid pad_token_id of 0 into -1.
        pad_token_id = getattr(self.config, "pad_token_id", None)
        self.pad_token_id = pad_token_id if pad_token_id is not None else -1
        self.post_init()

    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    def set_decoder(self, decoder):
        self.language_model = decoder

    def get_decoder(self):
        return self.language_model

    def pack_and_unpad_image_features(
        self, image_features, image_sizes, vision_feature_select_strategy, image_newline=None
    ):
        """
        Reshape, unpad and then pack each image_feature into a single tensor per image
        containing all of its visual vectors.

        Args:
            image_features (`list[torch.Tensor]` of length num_images, each of shape `(num_patches, image_length, embed_dim)`)
                List of image feature tensor, each contains all the visual feature of all patches.
                The first patch is treated as the base (overview) image.
            image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
                Actual image size of each images (H, W).
            vision_feature_select_strategy (`str`)
                The feature selection strategy used to select the vision feature from the vision backbone.
                Only used here to decide whether to warn about a shape mismatch.
            image_newline (`torch.Tensor` of shape `(embed_dim)`)
                New line embedding vector.
        Returns:
            image_features (`list[torch.Tensor]`, one per image, each of shape `(feat_len_i, embed_dim)`)
            feature_lens (`torch.LongTensor` of shape `(num_images,)`)
                token length of each image in image_features
        """
        new_image_features = []
        feature_lens = []
        for image_idx, image_feature in enumerate(image_features):
            if image_feature.shape[0] > 1:
                base_image_feature = image_feature[0]
                image_feature = image_feature[1:]
                height = width = (
                    self.config.vision_config.image_size // self.config.vision_config.patch_size
                )

                num_patch_height, num_patch_width = get_anyres_image_grid_shape(
                    image_sizes[image_idx],
                    self.config.image_grid_pinpoints,
                    self.config.vision_config.image_size,
                )

                if self.layerwise_projectors is not None:
                    # Per-patch grid side is scaled by downsample_rate; assumed to match
                    # how the projectors downsample each spatial side.
                    ds_rate = Fraction(self.downsample_rate)
                    height = int(height * ds_rate)
                    width = int(width * ds_rate)

                if (
                    np.prod(image_feature.shape) % (num_patch_height * num_patch_width * height * width) != 0
                    and vision_feature_select_strategy == "default"
                ):
                    logger.warning_once(
                        "Image feature shape does not line up with the provided patch size. "
                        "You may be using the `default` vision_feature_select_strategy with a"
                        " visual encoder that does not have CLS."
                    )

                # (patches, tokens, D) -> (D, H_total, W_total) spatial map, then unpad.
                image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
                image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
                image_feature = image_feature.flatten(1, 2).flatten(2, 3)
                image_feature = unpad_image(image_feature, image_sizes[image_idx])
                if image_newline is not None:
                    # Append one newline column at the end of every row.
                    image_feature = torch.cat(
                        (
                            image_feature,
                            image_newline[:, None, None]
                            .expand(*image_feature.shape[:-1], 1)
                            .to(image_feature.device, image_feature.dtype),
                        ),
                        dim=-1,
                    )
                image_feature = image_feature.flatten(1, 2).transpose(0, 1)
                image_feature = torch.cat((base_image_feature, image_feature), dim=0)
            else:
                image_feature = image_feature[0]
                if image_newline is not None:
                    image_feature = torch.cat((image_feature, image_newline[None].to(image_feature)), dim=0)
            new_image_features.append(image_feature)
            feature_lens.append(image_feature.size(0))
        feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features[0].device)
        return new_image_features, feature_lens

    def get_image_features(
        self,
        pixel_values: torch.FloatTensor,
        image_sizes: torch.Tensor,
        vision_feature_layer: Optional[Union[int, list[int]]] = None,
        vision_feature_select_strategy: Optional[str] = None,
    ):
        """
        Extract image features via deepstack (multi-layer) and spatial sampling projections.

        Runs the vision tower once, then:
        1. Deepstack: for each (vision_layer, llm_layer) in deepstack_layer_map, takes the
           hidden states of that vision layer, projects them with the matching
           layerwise projector, and pairs them with the target LLM layer.
        2. Spatial: if enabled, takes the spatial_vision_layer hidden states and creates
           4 spatial offset groups (TL, TR, BL, BR), each targeting an LLM layer from
           spatial_target_layers.

        Args:
            pixel_values: Image tensors of shape (batch, num_patches, C, H, W) or (N, C, H, W).
            image_sizes: Actual image sizes (num_images, 2).
            vision_feature_layer: Unused (kept for API compatibility).
            vision_feature_select_strategy: "default" (remove CLS) or "full".

        Returns:
            List of (llm_layer_idx, packed_features) tuples for injection during forward
            pass, where packed_features is a list with one tensor per image.

        Raises:
            ValueError: if pixel_values is neither 4- nor 5-dimensional.
        """
        vision_feature_select_strategy = (
            vision_feature_select_strategy
            if vision_feature_select_strategy is not None
            else self.config.vision_feature_select_strategy
        )

        image_num_patches = [
            image_size_to_num_patches(
                image_size=imsize,
                grid_pinpoints=self.config.image_grid_pinpoints,
                patch_size=self.config.vision_config.image_size,
            )
            for imsize in image_sizes
        ]

        if pixel_values.dim() == 5:
            # Drop padding patches and flatten (batch, patches) into one dimension.
            _pixel_values_list = [
                pix_val[:num_patch] for pix_val, num_patch in zip(pixel_values, image_num_patches)
            ]
            pixel_values = torch.cat(_pixel_values_list, dim=0)
        elif pixel_values.dim() != 4:
            raise ValueError(f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions")

        vision_outputs = self.vision_tower(pixel_values, output_hidden_states=True)

        # Deepstack features: extract from multiple vision layers and project each one
        all_features = []
        for projection_idx, (vision_layer, llm_layer) in enumerate(self.config.deepstack_layer_map):
            selected_feature = vision_outputs.hidden_states[vision_layer]
            if vision_feature_select_strategy == "default":
                selected_feature = selected_feature[:, 1:]
            projected_features = self.layerwise_projectors[projection_idx](selected_feature)
            projected_features = torch.split(projected_features, image_num_patches, dim=0)
            packed_features, _ = self.pack_and_unpad_image_features(
                projected_features,
                image_sizes,
                vision_feature_select_strategy=vision_feature_select_strategy,
                image_newline=self.image_newline,
            )
            all_features.append((llm_layer, packed_features))

        # Spatial features: extract 4 offset groups from a single vision layer
        if self.config.use_spatial_sampling:
            spatial_feature = vision_outputs.hidden_states[self.config.spatial_vision_layer]
            if vision_feature_select_strategy == "default":
                spatial_feature = spatial_feature[:, 1:]
            for group_idx, llm_layer in enumerate(self.config.spatial_target_layers):
                projected_group = self.spatial_projectors[group_idx](spatial_feature)
                projected_group_split = torch.split(projected_group, image_num_patches, dim=0)
                packed_group, _ = self.pack_and_unpad_image_features(
                    projected_group_split,
                    image_sizes,
                    vision_feature_select_strategy=vision_feature_select_strategy,
                    image_newline=self.image_newline,
                )
                all_features.append((llm_layer, packed_group))

        return all_features

    def get_image_token_mask(
        self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
    ):
        """
        Build a boolean mask over inputs_embeds marking positions of image tokens, and verify
        that the masked element count matches the number of image feature elements.

        When ``input_ids`` is None, image positions are found by comparing
        ``inputs_embeds`` with the embedding of ``config.image_token_id``.

        Returns:
            Boolean tensor with the same shape as ``inputs_embeds``.

        Raises:
            ValueError: if the image token count and image features do not match.
        """
        if input_ids is None:
            special_image_mask = inputs_embeds == self.get_input_embeddings()(
                torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
            )
            special_image_mask = special_image_mask.all(-1)
        else:
            special_image_mask = input_ids == self.config.image_token_id

        n_image_tokens = special_image_mask.sum()
        special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
        if inputs_embeds[special_image_mask].numel() != image_features.numel():
            raise ValueError(
                f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {image_features.shape[0]}"
            )
        return special_image_mask

    @can_return_tuple
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        vision_feature_layer: Optional[Union[int, list[int]]] = None,
        vision_feature_select_strategy: Optional[str] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Union[tuple, LlavaNextModelOutputWithPast]:
        """Run the language model layer by layer, injecting vision features.

        Image-token embeddings are zeroed, and each projected feature set is added at
        the image-token positions right before its target decoder layer runs.
        The language model is expected to expose ``embedding_multiplier``,
        ``rotary_emb``, ``layers``, ``norm`` and ``_update_mamba_mask``.

        Raises:
            ValueError: unless exactly one of ``input_ids`` / ``inputs_embeds`` is given.
        """
        cache_position = kwargs.pop("cache_position", None)
        output_attentions = (
            output_attentions if output_attentions is not None else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        vision_feature_layer = (
            vision_feature_layer
            if vision_feature_layer is not None
            else self.config.vision_feature_layer
        )
        vision_feature_select_strategy = (
            vision_feature_select_strategy
            if vision_feature_select_strategy is not None
            else self.config.vision_feature_select_strategy
        )

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

        # Extract deepstack + spatial features and prepare for layer-by-layer injection
        deepstack_features = []
        vision_mask = None
        image_features = None
        if pixel_values is not None and pixel_values.size(0) > 0:
            image_features = self.get_image_features(
                pixel_values,
                image_sizes,
                vision_feature_layer=vision_feature_layer,
                vision_feature_select_strategy=vision_feature_select_strategy,
            )
            for idx, (llm_layer_idx, packed_features) in enumerate(image_features):
                concat_features = torch.cat(packed_features, dim=0).to(
                    inputs_embeds.device, inputs_embeds.dtype
                )
                if idx == 0:
                    # All feature sets share the same image-token positions, so the
                    # mask is built (and validated) once; image-token embeddings are
                    # zeroed so only injected features occupy those slots.
                    vision_mask = self.get_image_token_mask(
                        input_ids, inputs_embeds=inputs_embeds, image_features=concat_features
                    )
                    inputs_embeds = inputs_embeds.masked_fill(vision_mask, 0.0)
                deepstack_features.append((llm_layer_idx, concat_features))

        # Custom forward pass with vision injection at specific LLM layers
        hidden_states = inputs_embeds * self.language_model.embedding_multiplier

        if _V5:
            if position_ids is None:
                past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
                position_ids = torch.arange(
                    past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
                ).unsqueeze(0)
            causal_mask = create_causal_mask(
                config=self.language_model.config,
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
            )
            mamba_mask = self.language_model._update_mamba_mask(attention_mask, past_key_values)
        else:
            if cache_position is None:
                past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
                cache_position = torch.arange(
                    past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
                )
            if position_ids is None:
                position_ids = cache_position.unsqueeze(0)
            causal_mask = self.language_model._update_causal_mask(
                attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
            )
            mamba_mask = self.language_model._update_mamba_mask(attention_mask, cache_position)

        position_embeddings = None
        if self.language_model.rotary_emb is not None:
            position_embeddings = self.language_model.rotary_emb(hidden_states, position_ids)

        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        # Layer-by-layer forward with vision injection
        for layer_idx, decoder_layer in enumerate(self.language_model.layers):
            # Inject vision features at this layer if configured (several sets may
            # target the same layer; they are added in order)
            for target_layer, features_for_layer in deepstack_features:
                if layer_idx == target_layer:
                    hidden_states = hidden_states.masked_scatter(
                        vision_mask, (hidden_states[vision_mask] + features_for_layer.flatten()).view(-1)
                    )

            layer_mask = mamba_mask if decoder_layer.layer_type == "mamba" else causal_mask

            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_kwargs = dict(
                attention_mask=layer_mask,
                past_key_values=past_key_values,
                use_cache=use_cache,
                position_embeddings=position_embeddings,
            )
            if not _V5:
                layer_kwargs["output_attentions"] = output_attentions
                layer_kwargs["cache_position"] = cache_position

            layer_outputs = decoder_layer(hidden_states, **layer_kwargs, **kwargs)

            # v5 decoder layers return a bare tensor; v4 returns a tuple
            if isinstance(layer_outputs, torch.Tensor):
                hidden_states = layer_outputs
            else:
                hidden_states = layer_outputs[0]
                if output_attentions and layer_outputs[1] is not None:
                    all_self_attns += (layer_outputs[1],)

        hidden_states = self.language_model.norm(hidden_states)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        # Only the hybrid Mamba cache tracks `has_previous_state`; use an explicit None
        # check (not cache truthiness) and skip caches that lack the attribute instead
        # of raising AttributeError.
        if past_key_values is not None and not getattr(past_key_values, "has_previous_state", True):
            past_key_values.has_previous_state = True

        return LlavaNextModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            image_hidden_states=image_features if pixel_values is not None else None,
        )