# granite-4.0-3b-vision / modeling.py
# Uploaded by Avihu with huggingface_hub (commit b2adc46, verified).
import math
from fractions import Fraction
from typing import Optional, Union
import numpy as np
import torch
from torch import nn
import transformers
from transformers import (
AutoModel,
LlavaNextForConditionalGeneration,
)
# Several transformers internals used in this file (causal-mask construction,
# cache_position handling, decoder-layer return values) differ between major
# versions 4 and 5, so the code branches on the installed major version.
_V5 = int(transformers.__version__.partition(".")[0]) >= 5
if _V5:
    # Only the v5 code paths below call create_causal_mask.
    from transformers.masking_utils import create_causal_mask
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.models.granitemoehybrid.modeling_granitemoehybrid import (
HybridMambaAttentionDynamicCache,
)
from transformers.models.llava_next.modeling_llava_next import (
LlavaNextCausalLMOutputWithPast,
LlavaNextModelOutputWithPast,
LlavaNextPreTrainedModel,
get_anyres_image_grid_shape,
image_size_to_num_patches,
unpad_image,
)
from transformers.processing_utils import Unpack
from transformers.utils import TransformersKwargs, can_return_tuple, logging
from .configuration import Granite4VisionConfig
from .downsampling import WindowQFormerDownsampler
# Module-level logger (e.g. for warning_once in pack_and_unpad_image_features).
logger = logging.get_logger(__name__)
class Granite4VisionForConditionalGeneration(LlavaNextForConditionalGeneration):
    """Granite 4 Vision model with a language-modeling head.

    Wraps :class:`Granite4VisionModel` (vision tower, projectors and the hybrid
    language model) and adds ``lm_head``; logits are divided by
    ``config.text_config.logits_scaling``. It also sets up the
    ``HybridMambaAttentionDynamicCache`` used by the language model during
    generation, and optionally toggles LoRA adapters in :meth:`generate`.
    """

    config_class = Granite4VisionConfig

    def __init__(self, config: Granite4VisionConfig):
        # Call the grandparent initializer directly, skipping
        # LlavaNextForConditionalGeneration.__init__ (presumably so the parent does
        # not build its own sub-model); the Granite-specific model is built here.
        LlavaNextPreTrainedModel.__init__(self, config)
        self.model = Granite4VisionModel(config)
        self.lm_head = nn.Linear(
            config.text_config.hidden_size, config.text_config.vocab_size, bias=False
        )
        self.post_init()

    def merge_lora_adapters(self):
        """Merge LoRA adapter weights into base weights in-place and disable adapter toggling.

        Every PEFT ``BaseTunerLayer`` in the module tree is merged. Afterwards
        ``_hf_peft_config_loaded`` is cleared, so :meth:`generate` no longer
        enables/disables adapters.

        Returns:
            ``self``, to allow chaining.
        """
        from peft.tuners.tuners_utils import BaseTunerLayer

        for module in self.modules():
            if isinstance(module, BaseTunerLayer):
                module.merge()
        self._hf_peft_config_loaded = False
        return self

    def generate(self, *args, **kwargs) -> torch.LongTensor:
        """Generate, toggling a loaded LoRA adapter on image inputs only.

        When a PEFT adapter is loaded, it is enabled if ``pixel_values`` is passed
        (as a keyword argument) and disabled otherwise, so text-only prompts run the
        base LLM standalone. All arguments are forwarded to the parent ``generate``.
        """
        if getattr(self, "_hf_peft_config_loaded", False):
            if kwargs.get("pixel_values") is not None:
                self.enable_adapters()
            else:
                self.disable_adapters()
        return super().generate(*args, **kwargs)

    @can_return_tuple
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        vision_feature_layer: Optional[Union[int, list[int]]] = None,
        vision_feature_select_strategy: Optional[str] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[tuple, LlavaNextCausalLMOutputWithPast]:
        """Run the multimodal model and project hidden states to vocabulary logits.

        ``output_attentions``, ``output_hidden_states``, ``vision_feature_layer`` and
        ``vision_feature_select_strategy`` fall back to the config when ``None``.
        The remaining arguments are forwarded to :class:`Granite4VisionModel`.

        Args:
            labels: Optional targets. When given, the loss is computed over the full
                sequence, before any ``logits_to_keep`` trimming.
            logits_to_keep: A positive ``int`` keeps only the last ``logits_to_keep``
                positions in the returned ``logits``; ``0`` (or a non-positive int)
                keeps all positions. A 1D tensor selects explicit sequence indices.

        Returns:
            ``LlavaNextCausalLMOutputWithPast`` with ``loss`` (if ``labels``),
            ``logits`` and the model's cache / hidden states / attentions /
            image features.
        """
        # cache_position is only forwarded to the model on transformers v4.
        cache_position = kwargs.pop("cache_position", None)
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        vision_feature_layer = (
            vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
        )
        vision_feature_select_strategy = (
            vision_feature_select_strategy
            if vision_feature_select_strategy is not None
            else self.config.vision_feature_select_strategy
        )
        model_kwargs = dict(
            pixel_values=pixel_values,
            image_sizes=image_sizes,
            vision_feature_layer=vision_feature_layer,
            vision_feature_select_strategy=vision_feature_select_strategy,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )
        if not _V5:
            model_kwargs["cache_position"] = cache_position
        outputs = self.model(input_ids, **model_kwargs, **kwargs)
        hidden_states = outputs.last_hidden_state

        # Which sequence positions end up in the returned logits.
        if isinstance(logits_to_keep, int):
            slice_indices = slice(-logits_to_keep, None) if logits_to_keep > 0 else slice(None)
        else:
            # 1D tensor of sequence indices.
            slice_indices = logits_to_keep
        logits_scaling = self.config.text_config.logits_scaling

        loss = None
        if labels is not None:
            # The loss needs logits for every position, so project the full
            # sequence first and trim afterwards.
            logits = self.lm_head(hidden_states) / logits_scaling
            loss = self.loss_function(
                logits,
                labels,
                vocab_size=self.config.text_config.vocab_size,
                **kwargs,
            )
            logits = logits[:, slice_indices, :]
        else:
            # Project only the kept positions; during generation this avoids
            # materialising (seq_len x vocab_size) logits for the whole prompt.
            logits = self.lm_head(hidden_states[:, slice_indices, :]) / logits_scaling

        return LlavaNextCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            image_hidden_states=outputs.image_hidden_states,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        inputs_embeds=None,
        pixel_values=None,
        image_sizes=None,
        attention_mask=None,
        cache_position=None,
        logits_to_keep=None,
        **kwargs,
    ):
        """Build the per-step model inputs for ``generate``.

        Delegates to the parent implementation, then post-processes the result
        with :meth:`_init_hybrid_cache` (hybrid cache creation, input trimming,
        position ids). ``pixel_values`` and ``image_sizes`` are only included on
        the first generation step.
        """
        if _V5:
            # v5 tells us explicitly whether this is the first step.
            is_first = kwargs.get("is_first_iteration", False)
            model_inputs = super().prepare_inputs_for_generation(
                input_ids,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                logits_to_keep=logits_to_keep,
                **kwargs,
            )
        else:
            # v4: the first step starts at cache position 0.
            is_first = cache_position[0] == 0 if cache_position is not None else True
            model_inputs = super().prepare_inputs_for_generation(
                input_ids,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                cache_position=cache_position,
                logits_to_keep=logits_to_keep,
                **kwargs,
            )
        model_inputs = self._init_hybrid_cache(**model_inputs)
        if is_first:
            model_inputs["pixel_values"] = pixel_values
            model_inputs["image_sizes"] = image_sizes
        return model_inputs

    def _init_hybrid_cache(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        **kwargs,
    ):
        """Adapt generation inputs for the GraniteMoeHybrid (Mamba + attention) LM.

        - If the cache is missing or an empty ``DynamicCache`` and ``use_cache`` is
          set, create a ``HybridMambaAttentionDynamicCache`` instead.
        - On transformers v4 with a non-empty cache, trim ``input_ids`` to the
          positions covered by ``cache_position``.
        - Derive ``position_ids`` from ``attention_mask`` when not supplied.
        - Pass ``inputs_embeds`` only while the cache is empty, else ``input_ids``.

        Returns:
            dict of model inputs; extra ``kwargs`` not already set are forwarded.
        """
        empty_past_kv = past_key_values is None or (
            isinstance(past_key_values, DynamicCache) and past_key_values.get_seq_length() == 0
        )
        if not empty_past_kv and not _V5:
            # Keep only the tokens the cache has not processed yet.
            if inputs_embeds is not None or cache_position[-1] >= input_ids.shape[1]:
                input_ids = input_ids[:, -cache_position.shape[0] :]
            elif input_ids.shape[1] != cache_position.shape[0]:
                input_ids = input_ids[:, cache_position]
        elif use_cache and empty_past_kv:
            # input_ids can be None on the first step when generating from
            # inputs_embeds, so take the batch size from whichever is present.
            batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0]
            past_key_values = HybridMambaAttentionDynamicCache(
                self.model.language_model.config, batch_size, self.dtype, device=self.device
            )
        if attention_mask is not None and position_ids is None:
            # Positions count only attended tokens, so left padding does not shift them.
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if not empty_past_kv:
                position_ids = position_ids[:, -input_ids.shape[1] :]
        if inputs_embeds is not None and empty_past_kv:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids.contiguous()}
        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": use_cache,
                "attention_mask": attention_mask,
            }
        )
        if not _V5:
            model_inputs["cache_position"] = cache_position
        for key, value in kwargs.items():
            if key not in model_inputs:
                model_inputs[key] = value
        return model_inputs
class Granite4VisionModel(LlavaNextPreTrainedModel):
    """Vision tower, deepstack/spatial projectors and the hybrid Granite LM.

    The ``<image>`` placeholder embeddings are zeroed. Each projected image
    feature set is then *added* to the hidden states at the ``<image>`` positions
    right before a specific language-model layer ("deepstack" injection). The
    target layers come from ``config.deepstack_layer_map`` and, when
    ``config.use_spatial_sampling`` is set, ``config.spatial_target_layers``.
    """

    config_class = Granite4VisionConfig

    def __init__(self, config: Granite4VisionConfig):
        super().__init__(config)
        # Explicit exceptions instead of `assert`, so validation survives `python -O`.
        if config.deepstack_layer_map is None:
            raise ValueError("`config.deepstack_layer_map` must be set for Granite4VisionModel.")
        if config.downsample_rate is None:
            raise ValueError("`config.downsample_rate` must be set for Granite4VisionModel.")
        self.vision_tower = AutoModel.from_config(config.vision_config)
        self.spatial_projectors = None
        self.downsample_rate = config.downsample_rate
        # Deepstack projectors: one per (vision_layer, llm_layer) pair.
        self.layerwise_projectors = nn.ModuleList(
            [WindowQFormerDownsampler(config) for _ in config.deepstack_layer_map]
        )
        # Spatial sampling projectors: 4 offset groups (TL, TR, BL, BR).
        if config.use_spatial_sampling:
            self.spatial_projectors = nn.ModuleList(
                [WindowQFormerDownsampler(config, spatial_offset=i) for i in range(4)]
            )
        self.image_newline = None
        if config.use_image_newline_parameter:
            embed_std = 1 / math.sqrt(config.text_config.hidden_size)
            self.image_newline = nn.Parameter(
                torch.randn(config.text_config.hidden_size, dtype=self.dtype) * embed_std
            )
        self.vocab_size = config.text_config.vocab_size
        self.language_model = AutoModel.from_config(config.text_config)
        # Explicit None check: `pad_token_id or -1` would turn a valid pad id of 0 into -1.
        pad_token_id = getattr(self.config, "pad_token_id", None)
        self.pad_token_id = pad_token_id if pad_token_id is not None else -1
        self.post_init()

    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    def set_decoder(self, decoder):
        self.language_model = decoder

    def get_decoder(self):
        return self.language_model

    def pack_and_unpad_image_features(self, image_features, image_sizes, vision_feature_select_strategy, image_newline=None):
        """
        Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors.

        For an image with more than one tile, tile 0 is kept as the base feature.
        The remaining tiles are stitched into an ``(embed_dim, rows, cols)`` grid
        using the any-res grid shape and unpadded with ``unpad_image``. If given,
        ``image_newline`` is appended as an extra column, and the grid is flattened
        to tokens after the base features. A single-tile image keeps tile 0 and,
        if given, gets one ``image_newline`` token appended.

        Args:
            image_features (`list[torch.Tensor]` of length num_images, each of shape `(num_patches, image_length, embed_dim)`)
                List of image feature tensor, each contains all the visual feature of all patches.
            image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
                Actual image size of each images (H, W).
            vision_feature_select_strategy (`str`)
                The feature selection strategy used to select the vision feature from the vision backbone.
            image_newline (`torch.Tensor` of shape `(embed_dim)`)
                New line embedding vector.
        Returns:
            image_features (`list[torch.Tensor]`, one `(num_tokens, embed_dim)` tensor per image)
            feature_lens (`torch.Tensor` of dtype long)
                token length of each image in image_features
        """
        # Per-tile feature-grid side (in encoder patches), shrunk by the projectors'
        # downsample rate. This is the same for every image, so compute it once.
        height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size
        if self.layerwise_projectors is not None:
            ds_rate = Fraction(self.downsample_rate)
            height = int(height * ds_rate)
            width = int(width * ds_rate)

        new_image_features = []
        feature_lens = []
        for image_idx, image_feature in enumerate(image_features):
            if image_feature.shape[0] > 1:
                base_image_feature = image_feature[0]
                image_feature = image_feature[1:]
                num_patch_height, num_patch_width = get_anyres_image_grid_shape(
                    image_sizes[image_idx],
                    self.config.image_grid_pinpoints,
                    self.config.vision_config.image_size,
                )
                if (
                    np.prod(image_feature.shape) % (num_patch_height * num_patch_width * height * width) != 0
                    and vision_feature_select_strategy == "default"
                ):
                    logger.warning_once(
                        "Image feature shape does not line up with the provided patch size. "
                        "You may be using the `default` vision_feature_select_strategy with a"
                        " visual encoder that does not have CLS."
                    )
                # (tiles_h, tiles_w, h, w, C) -> (C, tiles_h, h, tiles_w, w) -> (C, tiles_h*h, tiles_w*w)
                image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
                image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
                image_feature = image_feature.flatten(1, 2).flatten(2, 3)
                image_feature = unpad_image(image_feature, image_sizes[image_idx])
                if image_newline is not None:
                    # Append one newline column at the end of every row.
                    image_feature = torch.cat(
                        (
                            image_feature,
                            image_newline[:, None, None]
                            .expand(*image_feature.shape[:-1], 1)
                            .to(image_feature.device, image_feature.dtype),
                        ),
                        dim=-1,
                    )
                # (C, rows, cols) -> (rows*cols, C)
                image_feature = image_feature.flatten(1, 2).transpose(0, 1)
                image_feature = torch.cat((base_image_feature, image_feature), dim=0)
            else:
                image_feature = image_feature[0]
                if image_newline is not None:
                    image_feature = torch.cat((image_feature, image_newline[None].to(image_feature)), dim=0)
            new_image_features.append(image_feature)
            feature_lens.append(image_feature.size(0))
        feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features[0].device)
        return new_image_features, feature_lens

    def get_image_features(
        self,
        pixel_values: torch.FloatTensor,
        image_sizes: torch.Tensor,
        vision_feature_layer: Optional[Union[int, list[int]]] = None,
        vision_feature_select_strategy: Optional[str] = None,
    ):
        """
        Extract image features via deepstack (multi-layer) and spatial sampling projections.

        Runs the vision tower once, then:
        1. Deepstack: for each (vision_layer, llm_layer) in deepstack_layer_map,
           extracts features from that vision layer, downsamples via interpolation + QFormer,
           and pairs them with the target LLM layer.
        2. Spatial: if enabled, extracts the spatial_vision_layer and creates 4 spatial
           offset groups (TL, TR, BL, BR), each targeting a different LLM layer.

        Args:
            pixel_values: Image tensors of shape (batch, num_patches, C, H, W) or (N, C, H, W).
            image_sizes: Actual image sizes (num_images, 2).
            vision_feature_layer: Unused (kept for API compatibility).
            vision_feature_select_strategy: "default" (remove CLS) or "full".

        Returns:
            List of (llm_layer_idx, packed_features) tuples for injection during forward pass,
            where packed_features is a per-image list of token tensors.

        Raises:
            ValueError: if pixel_values is neither 4- nor 5-dimensional.
        """
        vision_feature_select_strategy = (
            vision_feature_select_strategy
            if vision_feature_select_strategy is not None
            else self.config.vision_feature_select_strategy
        )
        image_num_patches = [
            image_size_to_num_patches(
                image_size=imsize,
                grid_pinpoints=self.config.image_grid_pinpoints,
                patch_size=self.config.vision_config.image_size,
            )
            for imsize in image_sizes
        ]
        if pixel_values.dim() == 5:
            # Drop padding tiles beyond each image's real tile count, then flatten to (N, C, H, W).
            _pixel_values_list = [pix_val[:num_patch] for pix_val, num_patch in zip(pixel_values, image_num_patches)]
            pixel_values = torch.cat(_pixel_values_list, dim=0)
        elif pixel_values.dim() != 4:
            raise ValueError(f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions")
        vision_outputs = self.vision_tower(pixel_values, output_hidden_states=True)

        # Deepstack features: extract from multiple vision layers, downsample via interpolation.
        all_features = []
        for projection_idx, (vision_layer, llm_layer) in enumerate(self.config.deepstack_layer_map):
            selected_feature = vision_outputs.hidden_states[vision_layer]
            if vision_feature_select_strategy == "default":
                selected_feature = selected_feature[:, 1:]
            projected_features = self.layerwise_projectors[projection_idx](selected_feature)
            projected_features = torch.split(projected_features, image_num_patches, dim=0)
            packed_features, _ = self.pack_and_unpad_image_features(
                projected_features,
                image_sizes,
                vision_feature_select_strategy=vision_feature_select_strategy,
                image_newline=self.image_newline,
            )
            all_features.append((llm_layer, packed_features))

        # Spatial features: extract 4 offset groups from a single vision layer.
        if self.config.use_spatial_sampling:
            spatial_feature = vision_outputs.hidden_states[self.config.spatial_vision_layer]
            if vision_feature_select_strategy == "default":
                spatial_feature = spatial_feature[:, 1:]
            for group_idx, llm_layer in enumerate(self.config.spatial_target_layers):
                projected_group = self.spatial_projectors[group_idx](spatial_feature)
                projected_group_split = torch.split(projected_group, image_num_patches, dim=0)
                packed_group, _ = self.pack_and_unpad_image_features(
                    projected_group_split,
                    image_sizes,
                    vision_feature_select_strategy=vision_feature_select_strategy,
                    image_newline=self.image_newline,
                )
                all_features.append((llm_layer, packed_group))
        return all_features

    def get_image_token_mask(
        self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
    ):
        """
        Build a boolean mask over inputs_embeds marking positions of <image> tokens,
        and verify that the count matches the number of image feature vectors.

        When input_ids is None, <image> positions are found by comparing each
        embedding with the embedding of config.image_token_id.

        Returns:
            Boolean mask with the same shape as inputs_embeds.

        Raises:
            ValueError: if the masked element count differs from image_features.numel().
        """
        if input_ids is None:
            special_image_mask = inputs_embeds == self.get_input_embeddings()(
                torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
            )
            special_image_mask = special_image_mask.all(-1)
        else:
            special_image_mask = input_ids == self.config.image_token_id
        n_image_tokens = special_image_mask.sum()
        special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
        if inputs_embeds[special_image_mask].numel() != image_features.numel():
            raise ValueError(
                f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {image_features.shape[0]}"
            )
        return special_image_mask

    @can_return_tuple
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        vision_feature_layer: Optional[Union[int, list[int]]] = None,
        vision_feature_select_strategy: Optional[str] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Union[tuple, LlavaNextModelOutputWithPast]:
        """Run the language model layer by layer, injecting image features.

        Exactly one of ``input_ids`` / ``inputs_embeds`` must be given. When
        ``pixel_values`` is non-empty, the image features from
        :meth:`get_image_features` are added to the hidden states at the
        ``<image>`` positions right before each feature set's target layer.
        ``return_dict`` is not read here; this method always builds a
        ``LlavaNextModelOutputWithPast``.

        Raises:
            ValueError: if both or neither of ``input_ids`` / ``inputs_embeds`` are
                given, or if image tokens and image features do not match.
        """
        # cache_position is only used on the transformers v4 path.
        cache_position = kwargs.pop("cache_position", None)
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        vision_feature_layer = (
            vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
        )
        vision_feature_select_strategy = (
            vision_feature_select_strategy
            if vision_feature_select_strategy is not None
            else self.config.vision_feature_select_strategy
        )
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

        # Extract deepstack + spatial features and prepare them for layer-by-layer injection.
        deepstack_features = []
        vision_mask = None
        image_features = None
        if pixel_values is not None and pixel_values.size(0) > 0:
            image_features = self.get_image_features(
                pixel_values,
                image_sizes,
                vision_feature_layer=vision_feature_layer,
                vision_feature_select_strategy=vision_feature_select_strategy,
            )
            for idx, (llm_layer_idx, packed_features) in enumerate(image_features):
                concat_features = torch.cat(packed_features, dim=0).to(
                    inputs_embeds.device, inputs_embeds.dtype
                )
                if idx == 0:
                    # One mask is shared by all feature sets. The <image> placeholder
                    # embeddings are zeroed so they do not contribute themselves.
                    vision_mask = self.get_image_token_mask(
                        input_ids, inputs_embeds=inputs_embeds, image_features=concat_features
                    )
                    inputs_embeds = inputs_embeds.masked_fill(vision_mask, 0.0)
                deepstack_features.append((llm_layer_idx, concat_features))

        # Custom forward pass with vision injection at specific LLM layers.
        hidden_states = inputs_embeds * self.language_model.embedding_multiplier
        if _V5:
            if position_ids is None:
                past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
                position_ids = torch.arange(
                    past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
                ).unsqueeze(0)
            causal_mask = create_causal_mask(
                config=self.language_model.config,
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
            )
            mamba_mask = self.language_model._update_mamba_mask(attention_mask, past_key_values)
        else:
            if cache_position is None:
                past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
                cache_position = torch.arange(
                    past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
                )
            if position_ids is None:
                position_ids = cache_position.unsqueeze(0)
            causal_mask = self.language_model._update_causal_mask(
                attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
            )
            mamba_mask = self.language_model._update_mamba_mask(attention_mask, cache_position)
        position_embeddings = None
        if self.language_model.rotary_emb is not None:
            position_embeddings = self.language_model.rotary_emb(hidden_states, position_ids)

        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        # Layer-by-layer forward with vision injection.
        for layer_idx, decoder_layer in enumerate(self.language_model.layers):
            # Add every feature set targeting this layer to the <image> positions.
            for target_layer, features_for_layer in deepstack_features:
                if layer_idx == target_layer:
                    hidden_states = hidden_states.masked_scatter(
                        vision_mask,
                        (hidden_states[vision_mask] + features_for_layer.flatten()).view(-1),
                    )
            # Mamba layers get the Mamba padding mask, attention layers the causal mask.
            layer_mask = mamba_mask if decoder_layer.layer_type == "mamba" else causal_mask
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            layer_kwargs = dict(
                attention_mask=layer_mask,
                past_key_values=past_key_values,
                use_cache=use_cache,
                position_embeddings=position_embeddings,
            )
            if not _V5:
                layer_kwargs["output_attentions"] = output_attentions
                layer_kwargs["cache_position"] = cache_position
            layer_outputs = decoder_layer(hidden_states, **layer_kwargs, **kwargs)
            # v5 decoder layers return a bare tensor; v4 returns a tuple.
            if isinstance(layer_outputs, torch.Tensor):
                hidden_states = layer_outputs
            else:
                hidden_states = layer_outputs[0]
                if output_attentions and layer_outputs[1] is not None:
                    all_self_attns += (layer_outputs[1],)

        hidden_states = self.language_model.norm(hidden_states)
        if output_hidden_states:
            all_hidden_states += (hidden_states,)
        if past_key_values and not past_key_values.has_previous_state:
            past_key_values.has_previous_state = True
        return LlavaNextModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            image_hidden_states=image_features if pixel_values is not None else None,
        )