File size: 27,193 Bytes
5e10c51
 
 
a0bef21
5e10c51
a0bef21
5e10c51
b2adc46
5e10c51
 
 
a0bef21
b2adc46
 
 
 
5e10c51
 
 
 
a0bef21
5e10c51
 
 
 
 
 
 
a0bef21
5e10c51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8b4591e
 
 
 
 
 
 
 
 
5e10c51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b2adc46
5e10c51
 
 
 
 
 
 
 
 
 
 
 
 
 
b2adc46
5e10c51
 
 
 
 
 
 
 
 
 
 
 
 
b2adc46
 
 
5e10c51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b2adc46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5e10c51
b2adc46
5e10c51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b2adc46
5e10c51
b2adc46
5e10c51
 
 
 
 
 
 
b2adc46
5e10c51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b2adc46
 
5e10c51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b2adc46
5e10c51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b2adc46
5e10c51
 
 
 
b2adc46
5e10c51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b2adc46
 
 
 
 
 
 
 
 
 
 
5e10c51
b2adc46
 
 
 
 
 
 
 
 
 
 
 
 
5e10c51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b2adc46
5e10c51
 
 
 
 
b2adc46
 
 
 
 
 
 
 
 
 
 
 
5e10c51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
import math
from fractions import Fraction
from typing import Optional, Union

import numpy as np
import torch
from torch import nn
import transformers
from transformers import (
    AutoModel,
    LlavaNextForConditionalGeneration,
)

_V5 = int(transformers.__version__.split(".")[0]) >= 5
if _V5:
    from transformers.masking_utils import create_causal_mask
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.models.granitemoehybrid.modeling_granitemoehybrid import (
    HybridMambaAttentionDynamicCache,
)
from transformers.models.llava_next.modeling_llava_next import (
    LlavaNextCausalLMOutputWithPast,
    LlavaNextModelOutputWithPast,
    LlavaNextPreTrainedModel,
    get_anyres_image_grid_shape,
    image_size_to_num_patches,
    unpad_image,
)
from transformers.processing_utils import Unpack
from transformers.utils import TransformersKwargs, can_return_tuple, logging

from .configuration import Granite4VisionConfig
from .downsampling import WindowQFormerDownsampler

logger = logging.get_logger(__name__)


class Granite4VisionForConditionalGeneration(LlavaNextForConditionalGeneration):
    """Granite 4 Vision model with a language-modeling head.

    Wraps :class:`Granite4VisionModel` and adds an ``lm_head`` whose logits are
    divided by ``config.text_config.logits_scaling``. Generation uses a
    ``HybridMambaAttentionDynamicCache`` (see :meth:`_init_hybrid_cache`).
    """

    config_class = Granite4VisionConfig

    def __init__(self, config: Granite4VisionConfig):
        # Initialise the pretrained-model base directly rather than calling
        # LlavaNextForConditionalGeneration.__init__, presumably so the parent
        # constructor does not build its own sub-model; `self.model` is set below.
        LlavaNextPreTrainedModel.__init__(self, config)

        self.model = Granite4VisionModel(config)

        self.lm_head = nn.Linear(
            config.text_config.hidden_size, config.text_config.vocab_size, bias=False
        )

        self.post_init()

    def merge_lora_adapters(self):
        """Merge LoRA adapter weights into the base weights in place.

        Every PEFT ``BaseTunerLayer`` in the module tree is merged, then the
        ``_hf_peft_config_loaded`` flag is cleared so :meth:`generate` stops
        toggling adapters.

        Returns:
            ``self``, to allow call chaining.
        """
        from peft.tuners.tuners_utils import BaseTunerLayer

        for module in self.modules():
            if isinstance(module, BaseTunerLayer):
                module.merge()
        self._hf_peft_config_loaded = False
        return self

    def generate(self, *args, **kwargs) -> torch.LongTensor:
        """Generate tokens, enabling a loaded LoRA adapter only for image inputs.

        When a PEFT adapter is loaded, it is enabled if ``pixel_values`` is passed
        as a (non-None) keyword argument and disabled otherwise, so text-only
        prompts run through the base language model alone.
        """
        if getattr(self, "_hf_peft_config_loaded", False):
            if kwargs.get("pixel_values") is not None:
                self.enable_adapters()
            else:
                self.disable_adapters()
        return super().generate(*args, **kwargs)

    @can_return_tuple
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        vision_feature_layer: Optional[Union[int, list[int]]] = None,
        vision_feature_select_strategy: Optional[str] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[tuple, LlavaNextCausalLMOutputWithPast]:
        """Run the multimodal model and the LM head.

        Args:
            logits_to_keep: If a positive int, only the last ``logits_to_keep``
                positions are returned in ``logits`` (0 keeps all positions). If a
                tensor, it is used as explicit sequence indices.
            labels: When given, the loss is computed over the logits of *all*
                positions before ``logits_to_keep`` trimming is applied.

        Returns:
            ``LlavaNextCausalLMOutputWithPast`` (or a tuple via ``can_return_tuple``).
        """
        # `cache_position` is only forwarded to the inner model on transformers v4.
        cache_position = kwargs.pop("cache_position", None)

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        vision_feature_layer = (
            vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
        )
        vision_feature_select_strategy = (
            vision_feature_select_strategy
            if vision_feature_select_strategy is not None
            else self.config.vision_feature_select_strategy
        )

        model_kwargs = dict(
            pixel_values=pixel_values,
            image_sizes=image_sizes,
            vision_feature_layer=vision_feature_layer,
            vision_feature_select_strategy=vision_feature_select_strategy,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )
        if not _V5:
            model_kwargs["cache_position"] = cache_position
        outputs = self.model(input_ids, **model_kwargs, **kwargs)

        hidden_states = outputs.last_hidden_state
        logits_scaling = self.config.text_config.logits_scaling

        # A positive int keeps the trailing positions; any other int keeps all;
        # a tensor selects explicit sequence indices.
        if isinstance(logits_to_keep, int):
            slice_indices = slice(-logits_to_keep, None) if logits_to_keep > 0 else slice(None)
        else:
            slice_indices = logits_to_keep

        loss = None
        if labels is None:
            # Project only the positions that will be returned; during prefill
            # this avoids materialising a (seq_len x vocab_size) logits tensor.
            logits = self.lm_head(hidden_states[:, slice_indices, :]) / logits_scaling
        else:
            # The loss needs every position, so trim only after computing it.
            logits = self.lm_head(hidden_states) / logits_scaling
            loss = self.loss_function(
                logits,
                labels,
                vocab_size=self.config.text_config.vocab_size,
                **kwargs,
            )
            logits = logits[:, slice_indices, :]

        return LlavaNextCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            image_hidden_states=outputs.image_hidden_states,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        inputs_embeds=None,
        pixel_values=None,
        image_sizes=None,
        attention_mask=None,
        cache_position=None,
        logits_to_keep=None,
        **kwargs,
    ):
        """Build per-step model inputs for ``generate``.

        Delegates to the parent implementation, then swaps in a hybrid
        Mamba/attention cache via :meth:`_init_hybrid_cache`. ``pixel_values`` and
        ``image_sizes`` are forwarded only on the first step. On v5 the first
        step is detected via the ``is_first_iteration`` kwarg; on v4 it is when
        ``cache_position`` is absent or starts at 0.
        """
        if _V5:
            is_first = kwargs.get("is_first_iteration", False)
            model_inputs = super().prepare_inputs_for_generation(
                input_ids,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                logits_to_keep=logits_to_keep,
                **kwargs,
            )
        else:
            is_first = cache_position[0] == 0 if cache_position is not None else True
            model_inputs = super().prepare_inputs_for_generation(
                input_ids,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                cache_position=cache_position,
                logits_to_keep=logits_to_keep,
                **kwargs,
            )
        model_inputs = self._init_hybrid_cache(**model_inputs)
        if is_first:
            model_inputs["pixel_values"] = pixel_values
            model_inputs["image_sizes"] = image_sizes

        return model_inputs

    def _init_hybrid_cache(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        **kwargs,
    ):
        """Handle HybridMambaAttentionDynamicCache for GraniteMoeHybrid language model.

        - Empty (or missing) cache with ``use_cache``: create a
          ``HybridMambaAttentionDynamicCache`` sized to the batch.
        - Non-empty cache on v4: trim ``input_ids`` to the positions in
          ``cache_position``.
        - Derive ``position_ids`` from ``attention_mask`` when not given.
        - Feed ``inputs_embeds`` only on the first (empty-cache) step, else
          ``input_ids``.

        Extra keyword arguments are passed through unless they collide with a key
        already set here.
        """
        empty_past_kv = past_key_values is None or (isinstance(past_key_values, DynamicCache) and past_key_values.get_seq_length() == 0)

        if not empty_past_kv and not _V5:
            if (
                inputs_embeds is not None
                or cache_position[-1] >= input_ids.shape[1]
            ):
                input_ids = input_ids[:, -cache_position.shape[0] :]
            elif input_ids.shape[1] != cache_position.shape[0]:
                input_ids = input_ids[:, cache_position]
        elif use_cache and empty_past_kv:
            # When generating from inputs_embeds the parent passes input_ids=None,
            # so take the batch size from whichever input is present.
            batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0]
            past_key_values = HybridMambaAttentionDynamicCache(
                self.model.language_model.config, batch_size, self.dtype, device=self.device
            )

        if attention_mask is not None and position_ids is None:
            # Positions count only attended (non-padding) tokens; padded slots get 1.
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if not empty_past_kv:
                position_ids = position_ids[:, -input_ids.shape[1] :]

        if inputs_embeds is not None and empty_past_kv:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids.contiguous()}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": use_cache,
                "attention_mask": attention_mask,
            }
        )
        if not _V5:
            model_inputs["cache_position"] = cache_position

        for key, value in kwargs.items():
            if key not in model_inputs:
                model_inputs[key] = value

        return model_inputs


class Granite4VisionModel(LlavaNextPreTrainedModel):
    """Vision tower, deepstack/spatial projectors and the hybrid language model.

    Projected image features are not placed in the input embeddings once.
    Instead, each projected feature set is *added* to the hidden states at the
    ``<image>`` token positions just before a configured language-model layer
    runs (see :meth:`forward`).
    """

    config_class = Granite4VisionConfig

    def __init__(self, config: Granite4VisionConfig):
        super().__init__(config)

        # Explicit exceptions rather than `assert`, which is stripped under `python -O`.
        if config.deepstack_layer_map is None:
            raise ValueError("Granite4VisionConfig.deepstack_layer_map must be set")
        if config.downsample_rate is None:
            raise ValueError("Granite4VisionConfig.downsample_rate must be set")

        self.vision_tower = AutoModel.from_config(config.vision_config)
        self.spatial_projectors = None

        self.downsample_rate = config.downsample_rate

        # Deepstack projectors: one per (vision_layer, llm_layer) pair
        self.layerwise_projectors = nn.ModuleList([
            WindowQFormerDownsampler(config)
            for _ in range(len(config.deepstack_layer_map))
        ])

        # Spatial sampling projectors: 4 offset groups (TL, TR, BL, BR)
        if config.use_spatial_sampling:
            self.spatial_projectors = nn.ModuleList([
                WindowQFormerDownsampler(config, spatial_offset=i)
                for i in range(4)
            ])

        self.image_newline = None
        if config.use_image_newline_parameter:
            embed_std = 1 / math.sqrt(config.text_config.hidden_size)
            self.image_newline = nn.Parameter(torch.randn(config.text_config.hidden_size, dtype=self.dtype) * embed_std)

        self.vocab_size = config.text_config.vocab_size
        self.language_model = AutoModel.from_config(config.text_config)
        # An `or -1` fallback would wrongly turn a legitimate pad_token_id of 0 into -1.
        pad_token_id = getattr(self.config, "pad_token_id", None)
        self.pad_token_id = pad_token_id if pad_token_id is not None else -1
        self.post_init()

    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    def set_decoder(self, decoder):
        self.language_model = decoder

    def get_decoder(self):
        return self.language_model

    def pack_and_unpad_image_features(self, image_features, image_sizes, vision_feature_select_strategy, image_newline=None):
        """
        Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors.

        The per-tile grid side (``image_size // patch_size``) is scaled by
        ``self.downsample_rate`` to match the projector output resolution.

        Args:
            image_features (`list[torch.Tensor]` of length num_images, each of shape `(num_patches, image_length, embed_dim)`)
                List of image feature tensor, each contains all the visual feature of all patches.
            image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
                Actual image size of each images (H, W).
            vision_feature_select_strategy (`str`)
                The feature selection strategy used to select the vision feature from the vision backbone.
            image_newline (`torch.Tensor` of shape `(embed_dim)`)
                New line embedding vector.
        Returns:
            image_features (`list[torch.Tensor]`, one `(feat_len, embed_dim)` tensor per image)
            feature_lens (`torch.LongTensor`)
                token length of each image in image_features
        """
        new_image_features = []
        feature_lens = []
        for image_idx, image_feature in enumerate(image_features):
            if image_feature.shape[0] > 1:
                # Tile 0 is the base (whole-image) view; the rest form the anyres grid.
                base_image_feature = image_feature[0]
                image_feature = image_feature[1:]
                height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size

                num_patch_height, num_patch_width = get_anyres_image_grid_shape(
                    image_sizes[image_idx],
                    self.config.image_grid_pinpoints,
                    self.config.vision_config.image_size,
                )
                if self.layerwise_projectors is not None:
                    # Fraction keeps rates such as 1/3 exact before truncating to int.
                    ds_rate = Fraction(self.downsample_rate)
                    height = int(height * ds_rate)
                    width = int(width * ds_rate)

                if (
                    np.prod(image_feature.shape) % (num_patch_height * num_patch_width * height * width) != 0
                    and vision_feature_select_strategy == "default"
                ):
                    logger.warning_once(
                        "Image feature shape does not line up with the provided patch size. "
                        "You may be using the `default` vision_feature_select_strategy with a"
                        " visual encoder that does not have CLS."
                    )

                image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
                image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
                image_feature = image_feature.flatten(1, 2).flatten(2, 3)
                image_feature = unpad_image(image_feature, image_sizes[image_idx])
                if image_newline is not None:
                    image_feature = torch.cat(
                        (
                            image_feature,
                            image_newline[:, None, None]
                            .expand(*image_feature.shape[:-1], 1)
                            .to(image_feature.device, image_feature.dtype),
                        ),
                        dim=-1,
                    )
                image_feature = image_feature.flatten(1, 2).transpose(0, 1)
                image_feature = torch.cat((base_image_feature, image_feature), dim=0)
            else:
                image_feature = image_feature[0]
                if image_newline is not None:
                    image_feature = torch.cat((image_feature, image_newline[None].to(image_feature)), dim=0)
            new_image_features.append(image_feature)
            feature_lens.append(image_feature.size(0))
        feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features[0].device)
        return new_image_features, feature_lens

    def get_image_features(
        self,
        pixel_values: torch.FloatTensor,
        image_sizes: torch.Tensor,
        vision_feature_layer: Optional[Union[int, list[int]]] = None,
        vision_feature_select_strategy: Optional[str] = None,
    ):
        """
        Extract image features via deepstack (multi-layer) and spatial sampling projections.

        Runs the vision tower once, then:
        1. Deepstack: for each (vision_layer, llm_layer) in deepstack_layer_map,
           takes that vision layer's hidden states, projects them with the matching
           layerwise projector, and pairs them with the target LLM layer.
        2. Spatial: if enabled, takes the spatial_vision_layer hidden states and
           projects them with each spatial projector (4 offset groups: TL, TR, BL,
           BR), each paired with an entry of spatial_target_layers.

        Args:
            pixel_values: Image tensors of shape (batch, num_patches, C, H, W) or (N, C, H, W).
            image_sizes: Actual image sizes (num_images, 2).
            vision_feature_layer: Unused (kept for API compatibility).
            vision_feature_select_strategy: "default" (drop the first token) or "full".
        Returns:
            List of (llm_layer_idx, packed_features) tuples for injection during forward pass,
            where packed_features is a list with one tensor per image.
        Raises:
            ValueError: if pixel_values is neither 4- nor 5-dimensional.
        """
        vision_feature_select_strategy = (
            vision_feature_select_strategy
            if vision_feature_select_strategy is not None
            else self.config.vision_feature_select_strategy
        )

        image_num_patches = [
            image_size_to_num_patches(
                image_size=imsize,
                grid_pinpoints=self.config.image_grid_pinpoints,
                patch_size=self.config.vision_config.image_size,
            )
            for imsize in image_sizes
        ]

        if pixel_values.dim() == 5:
            # Drop padding tiles beyond each image's real patch count, then flatten.
            _pixel_values_list = [pix_val[:num_patch] for pix_val, num_patch in zip(pixel_values, image_num_patches)]
            pixel_values = torch.cat(_pixel_values_list, dim=0)
        elif pixel_values.dim() != 4:
            raise ValueError(f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions")

        vision_outputs = self.vision_tower(pixel_values, output_hidden_states=True)

        # Deepstack features: project hidden states from several vision layers.
        all_features = []
        for projection_idx, (vision_layer, llm_layer) in enumerate(self.config.deepstack_layer_map):
            selected_feature = vision_outputs.hidden_states[vision_layer]

            if vision_feature_select_strategy == "default":
                selected_feature = selected_feature[:, 1:]

            projected_features = self.layerwise_projectors[projection_idx](selected_feature)
            projected_features = torch.split(projected_features, image_num_patches, dim=0)

            packed_features, _ = self.pack_and_unpad_image_features(
                projected_features,
                image_sizes,
                vision_feature_select_strategy=vision_feature_select_strategy,
                image_newline=self.image_newline,
            )

            all_features.append((llm_layer, packed_features))

        # Spatial features: extract 4 offset groups from a single vision layer
        if self.config.use_spatial_sampling:
            spatial_feature = vision_outputs.hidden_states[self.config.spatial_vision_layer]

            if vision_feature_select_strategy == "default":
                spatial_feature = spatial_feature[:, 1:]

            for group_idx, llm_layer in enumerate(self.config.spatial_target_layers):
                projected_group = self.spatial_projectors[group_idx](spatial_feature)
                projected_group_split = torch.split(projected_group, image_num_patches, dim=0)

                packed_group, _ = self.pack_and_unpad_image_features(
                    projected_group_split,
                    image_sizes,
                    vision_feature_select_strategy=vision_feature_select_strategy,
                    image_newline=self.image_newline,
                )

                all_features.append((llm_layer, packed_group))

        return all_features

    def get_image_token_mask(
        self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
    ):
        """
        Build a boolean mask over inputs_embeds marking positions of <image> tokens,
        and verify that the count matches the number of image feature vectors.

        When input_ids is None, image positions are found by comparing each
        embedding with the embedding of ``config.image_token_id``.

        Returns:
            Boolean mask expanded to ``inputs_embeds.shape``.
        Raises:
            ValueError: if the masked element count differs from ``image_features.numel()``.
        """
        if input_ids is None:
            special_image_mask = inputs_embeds == self.get_input_embeddings()(
                torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
            )
            special_image_mask = special_image_mask.all(-1)
        else:
            special_image_mask = input_ids == self.config.image_token_id

        n_image_tokens = special_image_mask.sum()
        special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
        if inputs_embeds[special_image_mask].numel() != image_features.numel():
            raise ValueError(
                f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {image_features.shape[0]}"
            )
        return special_image_mask

    @can_return_tuple
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        vision_feature_layer: Optional[Union[int, list[int]]] = None,
        vision_feature_select_strategy: Optional[str] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Union[tuple, LlavaNextModelOutputWithPast]:
        """Run the language model layer by layer, injecting vision features.

        Image-token embeddings are zeroed, then before each decoder layer whose
        index appears as a target in the image features, those features are added
        to the hidden states at the image-token positions.

        Raises:
            ValueError: if not exactly one of input_ids / inputs_embeds is given,
                or if the image features do not match the number of image tokens.
        """
        cache_position = kwargs.pop("cache_position", None)
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict
        vision_feature_layer = (
            vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
        )
        vision_feature_select_strategy = (
            vision_feature_select_strategy
            if vision_feature_select_strategy is not None
            else self.config.vision_feature_select_strategy
        )

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

        # Extract deepstack + spatial features, grouped by target LLM layer.
        features_by_layer: dict[int, list[torch.Tensor]] = {}
        vision_mask = None
        image_features = None
        if pixel_values is not None and pixel_values.size(0) > 0:
            image_features = self.get_image_features(
                pixel_values,
                image_sizes,
                vision_feature_layer=vision_feature_layer,
                vision_feature_select_strategy=vision_feature_select_strategy,
            )

            expected_numel = None
            for idx, (llm_layer_idx, packed_features) in enumerate(image_features):
                concat_features = torch.cat(packed_features, dim=0).to(
                    inputs_embeds.device, inputs_embeds.dtype
                )
                if idx == 0:
                    # Build (and validate) the image-token mask once, then zero the
                    # placeholder embeddings so image positions hold only injected features.
                    vision_mask = self.get_image_token_mask(
                        input_ids, inputs_embeds=inputs_embeds, image_features=concat_features
                    )
                    inputs_embeds = inputs_embeds.masked_fill(vision_mask, 0.0)
                    expected_numel = concat_features.numel()
                elif concat_features.numel() != expected_numel:
                    # Every feature set is scattered into the same mask, so all must
                    # match it; fail clearly here instead of inside masked_scatter.
                    raise ValueError(
                        f"Vision features for LLM layer {llm_layer_idx} have shape "
                        f"{tuple(concat_features.shape)}, which does not match the "
                        f"{expected_numel} elements covered by the image tokens"
                    )
                features_by_layer.setdefault(llm_layer_idx, []).append(concat_features)

        # Custom forward pass with vision injection at specific LLM layers
        hidden_states = inputs_embeds * self.language_model.embedding_multiplier

        if _V5:
            if position_ids is None:
                past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
                position_ids = torch.arange(
                    past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
                ).unsqueeze(0)
            causal_mask = create_causal_mask(
                config=self.language_model.config,
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
            )
            mamba_mask = self.language_model._update_mamba_mask(attention_mask, past_key_values)
        else:
            if cache_position is None:
                past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
                cache_position = torch.arange(
                    past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
                )
            if position_ids is None:
                position_ids = cache_position.unsqueeze(0)
            causal_mask = self.language_model._update_causal_mask(
                attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
            )
            mamba_mask = self.language_model._update_mamba_mask(attention_mask, cache_position)

        position_embeddings = None
        if self.language_model.rotary_emb is not None:
            position_embeddings = self.language_model.rotary_emb(hidden_states, position_ids)

        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        # Layer-by-layer forward with vision injection
        for layer_idx, decoder_layer in enumerate(self.language_model.layers):
            # Add (not replace) each feature set targeting this layer at the image positions.
            for features_for_layer in features_by_layer.get(layer_idx, ()):
                hidden_states = hidden_states.masked_scatter(
                    vision_mask,
                    (hidden_states[vision_mask] + features_for_layer.flatten()).view(-1)
                )

            layer_mask = mamba_mask if decoder_layer.layer_type == "mamba" else causal_mask

            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_kwargs = dict(
                attention_mask=layer_mask,
                past_key_values=past_key_values,
                use_cache=use_cache,
                position_embeddings=position_embeddings,
            )
            if not _V5:
                layer_kwargs["output_attentions"] = output_attentions
                layer_kwargs["cache_position"] = cache_position
            layer_outputs = decoder_layer(hidden_states, **layer_kwargs, **kwargs)

            # v5 decoder layers return a bare tensor; v4 returns a tuple
            if isinstance(layer_outputs, torch.Tensor):
                hidden_states = layer_outputs
            else:
                hidden_states = layer_outputs[0]
                if output_attentions and layer_outputs[1] is not None:
                    all_self_attns += (layer_outputs[1],)

        hidden_states = self.language_model.norm(hidden_states)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        # Mark the hybrid cache as populated after its first forward; caches without
        # this attribute are left untouched instead of raising AttributeError.
        if past_key_values is not None and not getattr(past_key_values, "has_previous_state", True):
            past_key_values.has_previous_state = True

        return LlavaNextModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            image_hidden_states=image_features if pixel_values is not None else None,
        )