Avihu committed
Commit b2adc46 · verified · 1 Parent(s): d536e5e

Upload modeling.py with huggingface_hub

Files changed (1)
  1. modeling.py +77 -44
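
The commit message indicates the file was pushed with the huggingface_hub client. A minimal sketch of that kind of upload, assuming a hypothetical repo id (the actual target repo is not shown on this page):

from huggingface_hub import HfApi

api = HfApi()  # uses a saved token or the HF_TOKEN environment variable
api.upload_file(
    path_or_fileobj="modeling.py",        # local file to push
    path_in_repo="modeling.py",           # destination path inside the repo
    repo_id="your-org/granite-4-vision",  # hypothetical repo id
    repo_type="model",
    commit_message="Upload modeling.py with huggingface_hub",
)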
modeling.py CHANGED
@@ -5,10 +5,15 @@ from typing import Optional, Union
 import numpy as np
 import torch
 from torch import nn
+import transformers
 from transformers import (
     AutoModel,
     LlavaNextForConditionalGeneration,
 )
+
+_V5 = int(transformers.__version__.split(".")[0]) >= 5
+if _V5:
+    from transformers.masking_utils import create_causal_mask
 from transformers.cache_utils import Cache, DynamicCache
 from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
 from transformers.models.granitemoehybrid.modeling_granitemoehybrid import (
@@ -81,10 +86,10 @@ class Granite4VisionForConditionalGeneration(LlavaNextForConditionalGeneration):
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
-        cache_position: Optional[torch.LongTensor] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
         **kwargs: Unpack[TransformersKwargs],
     ) -> Union[tuple, LlavaNextCausalLMOutputWithPast]:
+        cache_position = kwargs.pop("cache_position", None)
 
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -99,8 +104,7 @@ class Granite4VisionForConditionalGeneration(LlavaNextForConditionalGeneration):
             else self.config.vision_feature_select_strategy
         )
 
-        outputs = self.model(
-            input_ids,
+        model_kwargs = dict(
             pixel_values=pixel_values,
             image_sizes=image_sizes,
             vision_feature_layer=vision_feature_layer,
@@ -113,9 +117,10 @@ class Granite4VisionForConditionalGeneration(LlavaNextForConditionalGeneration):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=True,
-            cache_position=cache_position,
-            **kwargs,
         )
+        if not _V5:
+            model_kwargs["cache_position"] = cache_position
+        outputs = self.model(input_ids, **model_kwargs, **kwargs)
 
         hidden_states = outputs.last_hidden_state
 
@@ -154,17 +159,29 @@ class Granite4VisionForConditionalGeneration(LlavaNextForConditionalGeneration):
         logits_to_keep=None,
         **kwargs,
     ):
-        model_inputs = super().prepare_inputs_for_generation(
-            input_ids,
-            past_key_values=past_key_values,
-            inputs_embeds=inputs_embeds,
-            attention_mask=attention_mask,
-            cache_position=cache_position,
-            logits_to_keep=logits_to_keep,
-            **kwargs,
-        )
+        if _V5:
+            is_first = kwargs.get("is_first_iteration", False)
+            model_inputs = super().prepare_inputs_for_generation(
+                input_ids,
+                past_key_values=past_key_values,
+                inputs_embeds=inputs_embeds,
+                attention_mask=attention_mask,
+                logits_to_keep=logits_to_keep,
+                **kwargs,
+            )
+        else:
+            is_first = cache_position[0] == 0 if cache_position is not None else True
+            model_inputs = super().prepare_inputs_for_generation(
+                input_ids,
+                past_key_values=past_key_values,
+                inputs_embeds=inputs_embeds,
+                attention_mask=attention_mask,
+                cache_position=cache_position,
+                logits_to_keep=logits_to_keep,
+                **kwargs,
+            )
         model_inputs = self._init_hybrid_cache(**model_inputs)
-        if cache_position[0] == 0:
+        if is_first:
             model_inputs["pixel_values"] = pixel_values
             model_inputs["image_sizes"] = image_sizes
 
@@ -182,9 +199,9 @@ class Granite4VisionForConditionalGeneration(LlavaNextForConditionalGeneration):
         **kwargs,
     ):
         """Handle HybridMambaAttentionDynamicCache for GraniteMoeHybrid language model."""
-        empty_past_kv = past_key_values is None or (isinstance(past_key_values, DynamicCache) and past_key_values[0][0] is None)
+        empty_past_kv = past_key_values is None or (isinstance(past_key_values, DynamicCache) and past_key_values.get_seq_length() == 0)
 
-        if not empty_past_kv:
+        if not empty_past_kv and not _V5:
             if (
                 inputs_embeds is not None
                 or cache_position[-1] >= input_ids.shape[1]
@@ -192,7 +209,7 @@ class Granite4VisionForConditionalGeneration(LlavaNextForConditionalGeneration):
                 input_ids = input_ids[:, -cache_position.shape[0] :]
             elif input_ids.shape[1] != cache_position.shape[0]:
                 input_ids = input_ids[:, cache_position]
-        elif use_cache:
+        elif use_cache and empty_past_kv:
             past_key_values = HybridMambaAttentionDynamicCache(
                 self.model.language_model.config, input_ids.shape[0], self.dtype, device=self.device
             )
@@ -214,9 +231,10 @@ class Granite4VisionForConditionalGeneration(LlavaNextForConditionalGeneration):
                 "past_key_values": past_key_values,
                 "use_cache": use_cache,
                 "attention_mask": attention_mask,
-                "cache_position": cache_position,
             }
         )
+        if not _V5:
+            model_inputs["cache_position"] = cache_position
 
         for key, value in kwargs.items():
             if key not in model_inputs:
@@ -258,7 +276,7 @@ class Granite4VisionModel(LlavaNextPreTrainedModel):
 
         self.vocab_size = config.text_config.vocab_size
         self.language_model = AutoModel.from_config(config.text_config)
-        self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
+        self.pad_token_id = getattr(self.config, "pad_token_id", None) or -1
         self.post_init()
 
     def get_input_embeddings(self):
@@ -473,14 +491,14 @@ class Granite4VisionModel(LlavaNextPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-        cache_position: Optional[torch.LongTensor] = None,
         **kwargs: Unpack[FlashAttentionKwargs],
     ) -> Union[tuple, LlavaNextModelOutputWithPast]:
+        cache_position = kwargs.pop("cache_position", None)
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        return_dict = return_dict if return_dict is not None else self.config.return_dict
         vision_feature_layer = (
             vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
         )
@@ -522,19 +540,31 @@ class Granite4VisionModel(LlavaNextPreTrainedModel):
         # Custom forward pass with vision injection at specific LLM layers
         hidden_states = inputs_embeds * self.language_model.embedding_multiplier
 
-        if cache_position is None:
-            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
-            cache_position = torch.arange(
-                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+        if _V5:
+            if position_ids is None:
+                past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+                position_ids = torch.arange(
+                    past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+                ).unsqueeze(0)
+            causal_mask = create_causal_mask(
+                config=self.language_model.config,
+                inputs_embeds=inputs_embeds,
+                attention_mask=attention_mask,
+                past_key_values=past_key_values,
             )
-
-        if position_ids is None:
-            position_ids = cache_position.unsqueeze(0)
-
-        causal_mask = self.language_model._update_causal_mask(
-            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
-        )
-        mamba_mask = self.language_model._update_mamba_mask(attention_mask, cache_position)
+            mamba_mask = self.language_model._update_mamba_mask(attention_mask, past_key_values)
+        else:
+            if cache_position is None:
+                past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+                cache_position = torch.arange(
+                    past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+                )
+            if position_ids is None:
+                position_ids = cache_position.unsqueeze(0)
+            causal_mask = self.language_model._update_causal_mask(
+                attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+            )
+            mamba_mask = self.language_model._update_mamba_mask(attention_mask, cache_position)
 
         position_embeddings = None
         if self.language_model.rotary_emb is not None:
@@ -558,21 +588,24 @@ class Granite4VisionModel(LlavaNextPreTrainedModel):
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
 
-            layer_outputs = decoder_layer(
-                hidden_states,
+            layer_kwargs = dict(
                 attention_mask=layer_mask,
                 past_key_values=past_key_values,
-                output_attentions=output_attentions,
                 use_cache=use_cache,
-                cache_position=cache_position,
                 position_embeddings=position_embeddings,
-                **kwargs,
             )
-
-            hidden_states = layer_outputs[0]
-
-            if output_attentions and layer_outputs[1] is not None:
-                all_self_attns += (layer_outputs[1],)
+            if not _V5:
+                layer_kwargs["output_attentions"] = output_attentions
+                layer_kwargs["cache_position"] = cache_position
+            layer_outputs = decoder_layer(hidden_states, **layer_kwargs, **kwargs)
+
+            # v5 decoder layers return a bare tensor; v4 returns a tuple
+            if isinstance(layer_outputs, torch.Tensor):
+                hidden_states = layer_outputs
+            else:
+                hidden_states = layer_outputs[0]
+                if output_attentions and layer_outputs[1] is not None:
+                    all_self_attns += (layer_outputs[1],)
 
         hidden_states = self.language_model.norm(hidden_states)
 
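
For context, a custom modeling.py like this one only runs when the checkpoint is loaded with trust_remote_code=True, and the _V5 gate introduced above is what lets the same file work on both transformers v4 and v5. A minimal usage sketch, assuming a hypothetical repo id and a local test image (neither is taken from this commit; the exact auto classes and prompt format depend on the repo's config and chat template):

from PIL import Image
from transformers import AutoProcessor, AutoModelForImageTextToText

repo_id = "your-org/granite-4-vision"  # hypothetical; use the Hub repo that hosts this modeling.py
image = Image.open("example.jpg")      # placeholder input image

# trust_remote_code=True makes transformers import the Granite4Vision* classes from modeling.py
processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForImageTextToText.from_pretrained(repo_id, trust_remote_code=True)

# Real prompts would normally go through processor.apply_chat_template; plain text is used here for brevity.
inputs = processor(images=image, text="Describe the image.", return_tensors="pt")
generated = model.generate(**inputs, max_new_tokens=64)
print(processor.batch_decode(generated, skip_special_tokens=True)[0])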