granite-4.0-3b-vision / granite4_vision.py
artemspector
Fix full-merge LoRA crash with tensor parallelism (multi-GPU)
a50f5b5
"""vLLM implementation of Granite 4 Vision (ibm-granite/granite-4.0-3b-vision).
Fresh implementation based on the HF reference model, following the patterns
documented in hf_to_vllm_porting_guide.md.
LoRA support:
- Full merge (--hf-overrides '{"adapter_path": "..."}') merges LM-only LoRA
deltas into base weights at load time.
- Native LoRA (--enable-lora --default-mm-loras) lets vLLM runtime serve
LM LoRA deltas per-request.
Both modes expect a LM-only adapter (no modules_to_save).
"""
import json
import math
import os
from collections.abc import Iterable, Mapping
from fractions import Fraction
import torch
import torch.nn as nn
from safetensors.torch import load_file
from transformers import BatchFeature
from transformers.models.blip_2.configuration_blip_2 import Blip2QFormerConfig
from transformers.models.blip_2.modeling_blip_2 import Blip2QFormerModel
from transformers.models.llava_next.modeling_llava_next import (
get_anyres_image_grid_shape,
image_size_to_num_patches,
unpad_image,
)
from vllm.config import VllmConfig
from vllm.distributed.parallel_state import (
get_pp_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.logger import init_logger
from vllm.model_executor.models.interfaces import (
MultiModalEmbeddings,
SupportsLoRA,
SupportsMultiModal,
SupportsPP,
)
from vllm.model_executor.models.llava import (
LlavaDummyInputsBuilder,
init_vision_tower_for_llava,
)
from vllm.model_executor.models.llava_next import (
BaseLlavaNextMultiModalProcessor,
LlavaNextProcessingInfo,
LlavaNextImagePixelInputs,
LlavaNextImageEmbeddingInputs,
LlavaNextImageInputs,
)
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.utils import (
AutoWeightsLoader,
WeightsMapper,
init_vllm_registered_model,
maybe_prefix,
)
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalFieldConfig
from vllm.sequence import IntermediateTensors
logger = init_logger(__name__)
# ---------------------------------------------------------------------------
# Downsampler modules (translated from HF downsampling.py)
# ---------------------------------------------------------------------------
class InterpolateDownsampler:
"""Spatial downsampling via area interpolation."""
def __init__(self, config, mode="area"):
self.orig_image_side = (
config.vision_config.image_size // config.vision_config.patch_size
)
self.new_image_side = int(
self.orig_image_side * Fraction(config.downsample_rate)
)
self.mode = mode
def __call__(self, image_features: torch.Tensor) -> torch.Tensor:
batch_size, _, dim = image_features.size()
up_shape = [batch_size, self.orig_image_side, self.orig_image_side, dim]
large = image_features.view(up_shape).permute(0, 3, 1, 2)
small = torch.nn.functional.interpolate(
large,
size=(self.new_image_side, self.new_image_side),
mode=self.mode,
)
return small.permute(0, 2, 3, 1).flatten(1, 2)
class SpatialOffsetDownsampler:
"""Sample one position from each 2x2 block (offset 0-3 = TL/TR/BL/BR)."""
def __init__(self, config, offset: int = 0):
self.orig_image_side = (
config.vision_config.image_size // config.vision_config.patch_size
)
self.new_image_side = self.orig_image_side // 2
offsets = [(0, 0), (0, 1), (1, 0), (1, 1)]
self.offset_h, self.offset_w = offsets[offset]
def __call__(self, image_features: torch.Tensor) -> torch.Tensor:
B, _, C = image_features.shape
features_2d = image_features.reshape(
B, self.orig_image_side, self.orig_image_side, C
)
n = self.new_image_side
blocks = features_2d.reshape(B, n, 2, n, 2, C)
sampled = blocks[:, :, self.offset_h, :, self.offset_w, :]
return sampled.reshape(B, -1, C)
class WindowQFormerDownsampler(nn.Module):
"""Window-based QFormer downsampler (matches HF downsampling.py exactly)."""
def __init__(self, config, spatial_offset=None):
super().__init__()
llm_hidden_size = config.text_config.hidden_size
vision_hidden_size = config.vision_config.hidden_size
self.dropout = nn.Dropout(config.projector_dropout)
if spatial_offset is not None:
self.downsampler = SpatialOffsetDownsampler(config, offset=spatial_offset)
else:
self.downsampler = InterpolateDownsampler(config)
qformer_config = Blip2QFormerConfig(
hidden_size=vision_hidden_size,
num_attention_heads=vision_hidden_size // 64,
intermediate_size=3072,
num_hidden_layers=1,
encoder_hidden_size=vision_hidden_size,
cross_attention_frequency=1,
max_position_embeddings=2048,
use_qformer_text_input=False,
)
self.qformer = Blip2QFormerModel(qformer_config)
self.image_side = (
config.vision_config.image_size // config.vision_config.patch_size
)
q, w = config.downsample_rate.split("/")
self.query_side, self.window_side = int(q), int(w)
self.query_length = self.query_side**2
embed_std = 1 / math.sqrt(vision_hidden_size)
self.norm = nn.LayerNorm(vision_hidden_size, eps=1e-6)
self.query = nn.Parameter(
torch.randn(1, self.query_length, vision_hidden_size) * embed_std
)
self.image_positions = nn.Parameter(
torch.randn(1, self.window_side**2, vision_hidden_size) * embed_std
)
self.out_linear = nn.Linear(vision_hidden_size, llm_hidden_size, bias=True)
def _win(self, x: torch.Tensor, side: int, win: int) -> torch.Tensor:
"""(B, side*side, C) → (B*n*n, win*win, C) where n=side//win."""
B, _, C = x.shape
n = side // win
return (
x.view(B, side, side, C)
.view(B, n, win, n, win, C)
.transpose(2, 3)
.flatten(0, 2)
.flatten(1, 2)
)
def _unwin(self, xw: torch.Tensor, n: int, win: int) -> torch.Tensor:
"""(B*n*n, win*win, C) → (B, (n*win)^2, C)."""
Bnn, _, C = xw.shape
B = Bnn // (n * n)
side = n * win
return (
xw.view(B, n, n, win, win, C)
.transpose(2, 3)
.contiguous()
.view(B, side, side, C)
.flatten(1, 2)
)
def forward(self, image_features: torch.Tensor) -> torch.Tensor:
B, HW, C = image_features.shape
assert HW == self.image_side * self.image_side
n = self.image_side // self.window_side
image_features = self.norm(image_features)
enc = self._win(image_features, self.image_side, self.window_side)
downsampled = self.downsampler(image_features)
new_side = n * self.query_side
downsampled_w = self._win(downsampled, new_side, self.query_side)
query_embeds = self.query + downsampled_w
encoder_embeds = self.dropout(enc + self.image_positions)
out_w = self.qformer(
query_embeds=query_embeds,
encoder_hidden_states=encoder_embeds,
return_dict=True,
).last_hidden_state
out = self._unwin(out_w, n=n, win=self.query_side)
out = self.dropout(out)
return self.out_linear(out)
# ---------------------------------------------------------------------------
# Processing info / processor (reuses LlavaNext patterns)
# ---------------------------------------------------------------------------
class Granite4VisionProcessingInfo(LlavaNextProcessingInfo):
def get_hf_config(self):
return self.ctx.get_hf_config()
def get_hf_processor(self, **kwargs):
return self.ctx.get_hf_processor(**kwargs)
def get_num_image_tokens(
self,
*,
image_width: int,
image_height: int,
) -> int:
hf_config = self.get_hf_config()
vision_encoder_info = self.get_vision_encoder_info()
# After QFormer downsampling, patch grid is scaled by downsample_rate
ds_rate = Fraction(hf_config.downsample_rate)
patch_grid = vision_encoder_info.get_patch_grid_length() # 24 for 384/16
downsampled_grid = int(patch_grid * ds_rate) # 12 for rate 4/8
# Base feature: downsampled_grid^2
base_feature_size = downsampled_grid * downsampled_grid
num_patch_height, num_patch_width = get_anyres_image_grid_shape(
image_size=(image_height, image_width),
grid_pinpoints=hf_config.image_grid_pinpoints,
patch_size=vision_encoder_info.get_image_size(),
)
(
unpadded_feature_size,
newline_feature_size,
) = self._get_num_unpadded_features(
original_height=image_height,
original_width=image_width,
npatches=downsampled_grid,
num_patch_height=num_patch_height,
num_patch_width=num_patch_width,
)
return unpadded_feature_size + newline_feature_size + base_feature_size
class Granite4VisionMultiModalProcessor(
BaseLlavaNextMultiModalProcessor[Granite4VisionProcessingInfo]
):
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(
pixel_values=MultiModalFieldConfig.batched("image"),
image_sizes=MultiModalFieldConfig.batched("image"),
)
# ---------------------------------------------------------------------------
# Top-level model
# ---------------------------------------------------------------------------
@MULTIMODAL_REGISTRY.register_processor(
Granite4VisionMultiModalProcessor,
info=Granite4VisionProcessingInfo,
dummy_inputs=LlavaDummyInputsBuilder,
)
class Granite4VisionForConditionalGeneration(
nn.Module, SupportsLoRA, SupportsMultiModal, SupportsPP
):
"""vLLM implementation of Granite 4 Vision.
Architecture:
- SigLIP vision tower → WindowQFormerDownsampler projectors
- Deepstack: 4 vision layers projected and injected at 4 LLM layers
- Spatial: 4 offset groups from last vision layer injected at 4 more LLM layers
- GraniteMoeHybrid language backbone with embedding_multiplier and residual_multiplier
- logits_scaling via LogitsProcessor
The outer model runs the LLM layer loop directly (like HF does) to inject
deepstack features. This avoids wrapping the inner model and keeps weight
loading simple.
LoRA support:
- Full merge: --hf-overrides '{"adapter_path": "path/to/lora"}' merges
LM-only LoRA deltas at load time (W += scaling * B @ A).
- Native LoRA: --enable-lora --default-mm-loras '{"image": "path/to/lora"}'
lets vLLM runtime serve LM LoRA per-request.
Both modes expect a LM-only adapter (no modules_to_save).
"""
# LoRA class attributes (matches GraniteMoeHybridForCausalLM)
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"input_linear": ["input_linear"],
}
embedding_modules = {}
# Weight mapping: HF checkpoint → vLLM parameter names
# HF has: model.language_model.layers.0... → vLLM: language_model.model.layers.0...
# (because GraniteMoeHybridForCausalLM.model = GraniteMoeHybridModel)
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
"model.language_model.": "language_model.model.",
"model.layerwise_projectors.": "layerwise_projectors.",
"model.spatial_projectors.": "spatial_projectors.",
"model.image_newline": "image_newline",
"model.vision_tower.": "vision_tower.",
"lm_head.": "language_model.lm_head.",
}
)
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
if modality.startswith("image"):
return "<image>"
raise ValueError(f"Only image modality is supported, got {modality}")
def get_mm_mapping(self) -> MultiModelKeys:
return MultiModelKeys.from_string_field(
language_model="language_model",
connector=["layerwise_projectors", "spatial_projectors"],
tower_model="vision_tower",
)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
self.vllm_config = vllm_config
# ----- Vision tower + projectors (marked as tower) -----
with self._mark_tower_model(vllm_config, "image"):
self.vision_tower = init_vision_tower_for_llava(
config,
quant_config=quant_config,
require_post_norm=False,
prefix=maybe_prefix(prefix, "vision_tower"),
)
# image_newline parameter
if getattr(config, "use_image_newline_parameter", False):
self.image_newline = nn.Parameter(
torch.empty(config.text_config.hidden_size)
)
else:
self.image_newline = None
# Deepstack projectors: one per (vision_layer, llm_layer) pair
self.layerwise_projectors = nn.ModuleList([
WindowQFormerDownsampler(config)
for _ in range(len(config.deepstack_layer_map))
])
# Spatial projectors: 4 offset groups
self.spatial_projectors = None
if getattr(config, "use_spatial_sampling", False):
self.spatial_projectors = nn.ModuleList([
WindowQFormerDownsampler(config, spatial_offset=i)
for i in range(4)
])
# ----- Language model (marked as LM) -----
with self._mark_language_model(vllm_config):
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors
)
# Store config values we need
self._deepstack_layer_map = config.deepstack_layer_map # [[-19, 9], ...]
self._use_spatial_sampling = getattr(config, "use_spatial_sampling", False)
self._spatial_vision_layer = getattr(config, "spatial_vision_layer", -1)
self._spatial_target_layers = getattr(config, "spatial_target_layers", [])
self._vision_feature_select_strategy = getattr(
config, "vision_feature_select_strategy", "full"
)
self._downsample_rate = Fraction(config.downsample_rate)
# Deepstack state — set during embed_input_ids, consumed during forward
# list of (llm_layer_idx, features_buffer) where buffer is (N, hidden_size)
self._ds_features: list[tuple[int, torch.Tensor]] = []
self._ds_vision_mask: torch.Tensor | None = None
# ----- Vision feature extraction -----
def _get_vision_hidden_states(
self, pixel_values: torch.Tensor
) -> list[torch.Tensor]:
"""Run vision tower and return all hidden states (including input embeddings).
Uses SiglipEncoder's built-in return_all_hidden_states support.
Returns list[Tensor] where index 0 = embeddings, index i = after layer i-1.
"""
vt = self.vision_tower
vm = vt.vision_model if hasattr(vt, "vision_model") else vt
hidden_states = vm.embeddings(pixel_values)
all_hidden_states = vm.encoder(
inputs_embeds=hidden_states,
return_all_hidden_states=True,
)
return all_hidden_states
def _pack_and_unpad_image_features(
self,
image_features: list[torch.Tensor] | tuple[torch.Tensor, ...],
image_sizes: torch.Tensor,
) -> list[torch.Tensor]:
"""Reshape, unpad, and pack image features.
Matches HF Granite4VisionModel.pack_and_unpad_image_features exactly.
"""
config = self.config
ds_rate = self._downsample_rate
new_image_features = []
for image_idx, image_feature in enumerate(image_features):
if image_feature.shape[0] > 1:
# Multi-patch: first is base, rest are high-res
base_image_feature = image_feature[0]
image_feature = image_feature[1:]
height = width = (
config.vision_config.image_size
// config.vision_config.patch_size
)
# After QFormer downsampling
height = int(height * ds_rate)
width = int(width * ds_rate)
num_patch_height, num_patch_width = get_anyres_image_grid_shape(
image_sizes[image_idx],
config.image_grid_pinpoints,
config.vision_config.image_size,
)
image_feature = image_feature.view(
num_patch_height, num_patch_width, height, width, -1
)
image_feature = (
image_feature.permute(4, 0, 2, 1, 3).contiguous()
.flatten(1, 2)
.flatten(2, 3)
)
image_feature = unpad_image(
image_feature, image_sizes[image_idx]
)
if self.image_newline is not None:
image_feature = torch.cat(
(
image_feature,
self.image_newline[:, None, None]
.expand(*image_feature.shape[:-1], 1)
.to(image_feature.device, image_feature.dtype),
),
dim=-1,
)
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
image_feature = torch.cat(
(base_image_feature, image_feature), dim=0
)
else:
image_feature = image_feature[0]
if self.image_newline is not None:
image_feature = torch.cat(
(image_feature, self.image_newline[None].to(image_feature)),
dim=0,
)
new_image_features.append(image_feature)
return new_image_features
def _get_all_layer_features(
self,
pixel_values: torch.Tensor,
image_sizes: torch.Tensor,
) -> list[tuple[int, list[torch.Tensor]]]:
"""Extract deepstack + spatial features.
Returns list of (llm_layer_idx, [per_image_features, ...]) tuples.
This is the vLLM equivalent of HF's get_image_features.
"""
select_strategy = self._vision_feature_select_strategy
# Count patches per image for splitting
image_num_patches = [
image_size_to_num_patches(
image_size=imsize,
grid_pinpoints=self.config.image_grid_pinpoints,
patch_size=self.config.vision_config.image_size,
)
for imsize in image_sizes
]
# Flatten 5D → 4D if needed
if pixel_values.dim() == 5:
_pv_list = [
pv[:np_]
for pv, np_ in zip(pixel_values, image_num_patches)
]
pixel_values = torch.cat(_pv_list, dim=0)
# Run vision tower once, get all hidden states
all_hidden_states = self._get_vision_hidden_states(pixel_values)
all_features = []
# ----- Deepstack features -----
for proj_idx, (vision_layer, llm_layer) in enumerate(
self._deepstack_layer_map
):
selected = all_hidden_states[vision_layer]
if select_strategy == "default":
selected = selected[:, 1:] # remove CLS
projected = self.layerwise_projectors[proj_idx](selected)
projected_split = torch.split(projected, image_num_patches, dim=0)
packed = self._pack_and_unpad_image_features(
projected_split, image_sizes
)
all_features.append((llm_layer, packed))
# ----- Spatial features -----
if self._use_spatial_sampling and self.spatial_projectors is not None:
spatial_hidden = all_hidden_states[self._spatial_vision_layer]
if select_strategy == "default":
spatial_hidden = spatial_hidden[:, 1:]
for group_idx, llm_layer in enumerate(self._spatial_target_layers):
projected = self.spatial_projectors[group_idx](spatial_hidden)
projected_split = torch.split(
projected, image_num_patches, dim=0
)
packed = self._pack_and_unpad_image_features(
projected_split, image_sizes
)
all_features.append((llm_layer, packed))
return all_features
# ----- Multimodal interface -----
def _parse_and_validate_image_input(
self, **kwargs: object
) -> LlavaNextImageInputs | None:
pixel_values = kwargs.pop("pixel_values", None)
image_sizes = kwargs.pop("image_sizes", None)
image_embeds = kwargs.pop("image_embeds", None)
if pixel_values is None and image_embeds is None:
return None
if pixel_values is not None:
expected_h = expected_w = self.config.vision_config.image_size
return LlavaNextImagePixelInputs(
type="pixel_values",
pixel_values=pixel_values,
image_sizes=image_sizes,
resolve_bindings={"h": expected_h, "w": expected_w},
)
if image_embeds is not None:
return LlavaNextImageEmbeddingInputs(
type="image_embeds",
data=image_embeds,
)
raise AssertionError("Unreachable")
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
"""Convert pixel values → per-image placeholder tensors.
The actual vision features are stored in self._ds_level_features and
injected during the forward loop (like HF does). We return zero-
tensors with the right shape so that _merge_multimodal_embeddings
fills image positions with zeros (matching HF's masked_fill(mask, 0)).
"""
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return []
if image_input["type"] == "image_embeds":
return [image_input["data"]]
pixel_values = image_input["pixel_values"]
image_sizes = image_input.get("image_sizes")
if isinstance(pixel_values, list):
pixel_values = torch.cat(pixel_values, dim=0)
# Get all (llm_layer, [per_image_features]) pairs
all_features = self._get_all_layer_features(pixel_values, image_sizes)
# Store ALL level features for deepstack injection in forward()
self._ds_level_features = all_features
# Return zero-tensors matching the shape of level-0 features.
# This makes _merge_multimodal_embeddings write zeros at image
# positions (equivalent to HF's inputs_embeds.masked_fill(mask, 0)).
# All real features are injected during the layer loop.
if all_features:
return [
torch.zeros_like(feat)
for feat in all_features[0][1]
]
return []
def embed_input_ids(
self,
input_ids: torch.Tensor,
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = True,
) -> torch.Tensor:
"""Merge text and vision embeddings, apply embedding_multiplier.
HF flow:
1. inputs_embeds = embed_tokens(input_ids)
2. inputs_embeds.masked_fill(vision_mask, 0.0)
3. hidden_states = inputs_embeds * embedding_multiplier
4. layer loop with deepstack injection via masked_scatter
vLLM's GraniteMoeHybridModel.forward:
- if inputs_embeds given: hidden_states = inputs_embeds (NO multiplier)
- if input_ids given: hidden_states = embed(input_ids) * multiplier
So we apply embedding_multiplier here and pass inputs_embeds to forward.
"""
# Access the inner GraniteMoeHybridModel
lm_inner = self.language_model.model # GraniteMoeHybridModel
has_vision = (
multimodal_embeddings is not None
and is_multimodal is not None
and len(multimodal_embeddings) > 0
and is_multimodal.any()
)
if not has_vision:
# Text-only or decode: clear deepstack state
self._ds_features = []
self._ds_vision_mask = None
self._ds_level_features = []
# Apply embedding_multiplier here because forward() always receives
# inputs_embeds and skips the inner model's embed+multiply path.
embeds = lm_inner.embed_input_ids(input_ids)
return embeds * lm_inner.embedding_multiplier
# --- Vision path ---
# HF flow: embed → masked_fill(vision_mask, 0.0) → multiply by embedding_multiplier
# Then layer loop adds vision features via masked_scatter (never in inputs_embeds).
# 1. Get text embeddings
text_embeds = lm_inner.embed_input_ids(input_ids)
# 2. Zero out image positions (HF: inputs_embeds.masked_fill(vision_mask, 0.0))
# _merge_multimodal_embeddings writes our zero-tensors here, same effect.
from vllm.model_executor.models.utils import (
_merge_multimodal_embeddings,
)
_merge_multimodal_embeddings(
inputs_embeds=text_embeds,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
)
# 3. Apply embedding_multiplier to ALL positions (text + vision=0)
embedding_multiplier = lm_inner.embedding_multiplier
inputs_embeds = text_embeds * embedding_multiplier
# 5. Prepare deepstack feature buffers for the layer loop
N = inputs_embeds.size(0)
hidden_size = inputs_embeds.size(1)
prepared = []
for llm_layer, per_image_features in getattr(self, "_ds_level_features", []):
concat_features = torch.cat(per_image_features, dim=0)
buf = torch.zeros(
N, hidden_size,
dtype=inputs_embeds.dtype,
device=inputs_embeds.device,
)
buf[is_multimodal] = concat_features.to(dtype=inputs_embeds.dtype)
prepared.append((llm_layer, buf))
self._ds_features = prepared
self._ds_vision_mask = is_multimodal
self._ds_level_features = [] # consumed
return inputs_embeds
# ----- Forward -----
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
**kwargs: object,
) -> torch.Tensor | IntermediateTensors:
"""Forward pass with deepstack injection.
Runs the LLM layer loop directly (like HF Granite4VisionModel.forward)
to inject vision features at target layers via masked addition.
"""
if intermediate_tensors is not None:
inputs_embeds = None
# Access the inner GraniteMoeHybridModel
lm_inner = self.language_model.model
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
# embed_input_ids already applied embedding_multiplier
hidden_states = inputs_embeds
else:
# Text-only path: embed + multiply
hidden_states = lm_inner.embed_input_ids(input_ids)
hidden_states = hidden_states * lm_inner.embedding_multiplier
residual = None
else:
if intermediate_tensors is None:
raise RuntimeError("Intermediate tensors may not be None!")
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
num_tokens = hidden_states.size(0)
# Build O(1) lookup for deepstack features
ds_map: dict[int, torch.Tensor] = {}
for llm_layer_idx, features in self._ds_features:
ds_map[llm_layer_idx] = features
vision_mask = self._ds_vision_mask
if vision_mask is not None:
vision_mask = vision_mask[:num_tokens]
# Run through decoder layers with deepstack injection
for i in range(lm_inner.start_layer, lm_inner.end_layer):
layer = lm_inner.layers[i]
# Inject deepstack features at target layers (before layer forward)
if i in ds_map and vision_mask is not None and vision_mask.any():
features = ds_map[i][:num_tokens]
hidden_states = hidden_states.clone()
hidden_states[vision_mask] = (
hidden_states[vision_mask] + features[vision_mask]
)
hidden_states, residual = layer(
positions=positions,
hidden_states=hidden_states,
residual=residual,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors(
{"hidden_states": hidden_states, "residual": residual}
)
hidden_states = lm_inner.norm(hidden_states)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor | None:
# GraniteMoeHybridForCausalLM.compute_logits uses
# LogitsProcessor(scale=1/logits_scaling)
return self.language_model.compute_logits(hidden_states)
# ----- Full-merge LoRA support -----
# HF→vLLM key prefix mapping (same transforms as hf_to_vllm_mapper)
_ADAPTER_PREFIX_MAP = [
("model.language_model.", "language_model.model."),
]
# vLLM fuses q/k/v_proj into qkv_proj.
_STACKED_PARAMS_MAPPING = [
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
]
@staticmethod
def _peft_to_vllm(peft_key: str) -> str:
"""Strip 'base_model.model.' and apply HF→vLLM prefix mapping."""
name = peft_key
if name.startswith("base_model.model."):
name = name[len("base_model.model."):]
for old_pfx, new_pfx in (
Granite4VisionForConditionalGeneration._ADAPTER_PREFIX_MAP
):
if name.startswith(old_pfx):
name = new_pfx + name[len(old_pfx):]
break
return name
@staticmethod
def _load_adapter(adapter_path: str) -> tuple[dict, dict[str, torch.Tensor]]:
"""Load adapter config and safetensors from a directory or HF hub ID."""
# Resolve HF hub IDs to local cache path
if not os.path.isdir(adapter_path):
from huggingface_hub import snapshot_download
adapter_path = snapshot_download(adapter_path)
config_path = os.path.join(adapter_path, "adapter_config.json")
weights_path = os.path.join(adapter_path, "adapter_model.safetensors")
if not os.path.exists(config_path):
raise FileNotFoundError(f"No adapter_config.json in {adapter_path}")
if not os.path.exists(weights_path):
raise FileNotFoundError(
f"No adapter_model.safetensors in {adapter_path}")
with open(config_path) as f:
config = json.load(f)
weights = load_file(weights_path)
return config, weights
def _merge_lora_deltas(
self,
adapter_config: dict,
adapter_weights: dict[str, torch.Tensor],
) -> int:
"""Merge LM-only LoRA deltas into model weights: W += scaling * B @ A.
Uses _STACKED_PARAMS_MAPPING + module._get_shard_offset_mapping()
to handle packed QKV correctly (works with GQA automatically).
"""
lora_alpha = adapter_config.get("lora_alpha", 1)
lora_r = adapter_config.get("r", 1)
scaling = lora_alpha / lora_r
# Collect lora_A / lora_B by vLLM module key
lora_a: dict[str, torch.Tensor] = {}
lora_b: dict[str, torch.Tensor] = {}
for peft_key, tensor in adapter_weights.items():
if ".lora_A." in peft_key:
module_key = self._peft_to_vllm(
peft_key.replace(".lora_A.weight", ""))
lora_a[module_key] = tensor
elif ".lora_B." in peft_key:
module_key = self._peft_to_vllm(
peft_key.replace(".lora_B.weight", ""))
lora_b[module_key] = tensor
params_dict = dict(self.named_parameters())
modules_dict = dict(self.named_modules())
def _add_delta(name: str, delta: torch.Tensor) -> bool:
# Try stacked/fused params first (qkv_proj)
for fused_name, orig_name, shard_id in self._STACKED_PARAMS_MAPPING:
if orig_name not in name:
continue
fused_param_name = name.replace(orig_name, fused_name)
if fused_param_name not in params_dict:
continue
param = params_dict[fused_param_name]
module_path = fused_param_name.rsplit(".weight", 1)[0]
module = modules_dict.get(module_path)
if module is None:
continue
if hasattr(module, "_get_shard_offset_mapping"):
shard_offset = module._get_shard_offset_mapping(shard_id)
if shard_offset is not None:
# Under TP, the shard_offset and sizes from
# _get_shard_offset_mapping are already TP-local
# (num_heads is per-rank), but delta is full-size.
# Slice delta's dim 0 for this TP rank.
tp_rank = get_tensor_model_parallel_rank()
tp_size = get_tensor_model_parallel_world_size()
shard_size = delta.shape[0] // tp_size
tp_delta = delta.narrow(
0, tp_rank * shard_size, shard_size)
shard = param.data[shard_offset:shard_offset + shard_size]
param.data[shard_offset:shard_offset + shard_size] = (
shard.float() + tp_delta.to(shard.device)
).to(shard.dtype)
return True
# Direct param (o_proj, input_linear, output_linear)
if name in params_dict:
param = params_dict[name]
# Under TP, param is already sharded but delta is full-size.
# Slice delta to match: dim 0 for column-parallel, dim 1 for
# row-parallel.
if delta.shape != param.data.shape:
tp_rank = get_tensor_model_parallel_rank()
for dim in range(delta.dim()):
if delta.shape[dim] != param.data.shape[dim]:
shard_size = param.data.shape[dim]
offset = tp_rank * shard_size
delta = delta.narrow(dim, offset, shard_size)
break
param.data = (param.data.float() + delta.to(param.device)).to(param.dtype)
return True
return False
merge_device = next(self.parameters()).device
merged = 0
for module_key in sorted(lora_a):
if module_key not in lora_b:
logger.warning("LoRA B missing for %s, skipping", module_key)
continue
A = lora_a[module_key].to(merge_device).float()
B = lora_b[module_key].to(merge_device).float()
delta = scaling * (B @ A)
if _add_delta(module_key + ".weight", delta):
merged += 1
else:
logger.warning("LoRA target not found: %s", module_key)
return merged
def _apply_adapter(self) -> None:
"""Full-merge entry point: called when config.adapter_path is set."""
adapter_path = getattr(self.config, "adapter_path", None)
if not adapter_path:
return
logger.info("Full-merge LoRA from %s", adapter_path)
adapter_config, adapter_weights = self._load_adapter(adapter_path)
if adapter_config.get("modules_to_save"):
raise ValueError(
"Adapter has modules_to_save — only LM-only adapters "
"(no modules_to_save) are supported."
)
n = self._merge_lora_deltas(adapter_config, adapter_weights)
logger.info("Merged %d LoRA pairs into base weights", n)
def load_weights(
self, weights: Iterable[tuple[str, torch.Tensor]]
) -> set[str]:
loader = AutoWeightsLoader(self)
loaded = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
self._apply_adapter()
return loaded