farzadab commited on
Commit
494d2fb
·
verified ·
1 Parent(s): 009e6a9

Update ultravox_model.py

Browse files
Files changed (1) hide show
  1. ultravox_model.py +333 -113
ultravox_model.py CHANGED
@@ -1,5 +1,6 @@
1
  import logging
2
- from typing import Any, Dict, Optional, Set, Tuple, Union
 
3
 
4
  import peft
5
  import torch
@@ -9,6 +10,7 @@ import transformers
9
  import transformers.activations
10
  import transformers.modeling_outputs
11
  import transformers.models
 
12
  from transformers.models.whisper import modeling_whisper as whisper
13
 
14
  # We must use relative import in this directory to allow uploading to HF Hub
@@ -18,7 +20,7 @@ from .ultravox_config import LossFunction
18
  from .ultravox_config import UltravoxConfig
19
 
20
 
21
- class UltravoxModel(transformers.LlamaPreTrainedModel):
22
  """
23
  The Ultravox model which consists of an audio encoder and a language model.
24
 
@@ -34,29 +36,72 @@ class UltravoxModel(transformers.LlamaPreTrainedModel):
34
 
35
  config_class = UltravoxConfig
36
  config: UltravoxConfig # for type hinting
37
- _no_split_modules = ["Wav2Vec2Model", "WhisperEncoder", "LlamaDecoderLayer"]
38
- # We minimize the weights in state_dict in order to reduce the size of the checkpoint
39
- # The issue is that load_pretrained() uses state_dict() keys to know what keys are expected
40
- # As such we have to tell is to ignore some keys that are not always in the model
41
- _keys_to_ignore_on_load_unexpected = ["audio_tower.*", "language_model.*"]
42
- # Usually we load encoder weights from a pretrained model, so we don't want to load the decoder weights
43
- # Technically we never hit this issue because these keys are already removed from state_dict() however,
44
- # but there's no harm in keeping it here for when we change that behavior.
45
- _keys_to_ignore_on_load_missing = ["audio_tower.*"]
46
 
47
  def __init__(self, config: UltravoxConfig):
48
  super().__init__(config)
 
49
 
50
  self.keep_params: Set[str] = set()
51
  self.vocab_size = config.vocab_size
52
 
53
  self.audio_tower = self._create_audio_tower(config)
54
- self.multi_modal_projector = UltravoxProjector(config)
 
 
 
55
  self.language_model = self._create_language_model(config)
56
 
 
 
 
 
 
 
 
 
 
 
57
  self.loss_config = LossConfig()
58
  self.post_init()
59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  def get_input_embeddings(self):
61
  return self.language_model.get_input_embeddings()
62
 
@@ -103,6 +148,30 @@ class UltravoxModel(transformers.LlamaPreTrainedModel):
103
  self.vocab_size = model_embeds.num_embeddings
104
  return model_embeds
105
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  def _compute_kl_loss(
107
  self,
108
  lm_output: transformers.modeling_outputs.CausalLMOutputWithPast,
@@ -127,11 +196,12 @@ class UltravoxModel(transformers.LlamaPreTrainedModel):
127
  # compute the KL divergence loss between the two models
128
  kl_loss = F.kl_div(
129
  F.log_softmax(
130
- lm_output.logits[labels != -100] / self.loss_config.kl_temperature,
 
131
  dim=-1,
132
  ),
133
  F.softmax(
134
- alt_lm_output.logits[alt_labels != -100]
135
  / self.loss_config.kl_temperature,
136
  dim=-1,
137
  ),
@@ -139,6 +209,24 @@ class UltravoxModel(transformers.LlamaPreTrainedModel):
139
  )
140
  return {"loss": kl_loss}
141
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
  def forward(
143
  self,
144
  input_ids: torch.Tensor,
@@ -147,7 +235,9 @@ class UltravoxModel(transformers.LlamaPreTrainedModel):
147
  labels: Optional[torch.Tensor] = None,
148
  attention_mask: Optional[torch.Tensor] = None,
149
  audio_token_start_idx: Optional[torch.Tensor] = None,
 
150
  audio_token_len: Optional[torch.Tensor] = None,
 
151
  past_key_values: Optional[Union[Tuple, transformers.cache_utils.Cache]] = None,
152
  # the alt_* fields are needed for KL divergence loss
153
  alt_input_ids: Optional[torch.Tensor] = None,
@@ -178,28 +268,37 @@ class UltravoxModel(transformers.LlamaPreTrainedModel):
178
  # B x T -> B x T x D
179
  inputs_embeds = self.get_input_embeddings().forward(input_ids)
180
 
181
- if audio_values is not None:
182
  assert (
183
- audio_token_start_idx is not None and audio_token_len is not None
184
- ), "audio_token_start_idx and audio_token_len must be provided if audio_values are provided."
 
 
 
185
  assert (
186
- len(audio_token_start_idx) == len(audio_token_len) == len(audio_values)
187
- ), "audio_token_start_idx, audio_token_len, and audio_values must have the same batch size."
188
-
189
- # B x A/3200 x D
 
 
 
 
 
 
190
  audio_tower_output = self.audio_tower.forward(
191
- audio_values
 
192
  ).last_hidden_state
193
  audio_tower_output = audio_tower_output.to(inputs_embeds.dtype)
194
-
195
  audio_embeds = self.multi_modal_projector.forward(audio_tower_output)
196
 
197
  # combine audio and text embeddings
198
- for i, (audio, start, length) in enumerate(
199
- zip(audio_embeds, audio_token_start_idx, audio_token_len)
200
- ):
201
- length = min(length, audio.shape[0])
202
- inputs_embeds[i, start : start + length] = audio[:length]
203
 
204
  lm_output = self.language_model.forward(
205
  inputs_embeds=inputs_embeds,
@@ -234,6 +333,8 @@ class UltravoxModel(transformers.LlamaPreTrainedModel):
234
  audio_values: Optional[torch.FloatTensor] = None,
235
  audio_token_start_idx: Optional[torch.Tensor] = None,
236
  audio_token_len: Optional[torch.Tensor] = None,
 
 
237
  past_key_values: Optional[Union[Tuple, transformers.cache_utils.Cache]] = None,
238
  attention_mask: Optional[torch.Tensor] = None,
239
  inputs_embeds: Optional[torch.Tensor] = None,
@@ -251,7 +352,9 @@ class UltravoxModel(transformers.LlamaPreTrainedModel):
251
 
252
  # include audio information in model_input only when it is needed during prefilling
253
  # audio_token_start_idx should always be relative to the current cache position
254
- prefill_start_idx = 0 if cache_position is None else cache_position[0]
 
 
255
  if (
256
  audio_values is not None
257
  and audio_token_start_idx is not None
@@ -262,32 +365,37 @@ class UltravoxModel(transformers.LlamaPreTrainedModel):
262
  audio_token_start_idx - prefill_start_idx
263
  )
264
  model_input["audio_token_len"] = audio_token_len
 
 
265
 
266
  return model_input
267
 
 
 
 
 
 
 
 
 
268
  @classmethod
269
  def _create_audio_tower(
270
  cls, config: UltravoxConfig
271
  ) -> Union[transformers.Wav2Vec2Model, "ModifiedWhisperEncoder"]:
272
- if config.audio_model_id is not None:
273
- if "whisper" in config.audio_model_id is not None:
274
- audio_tower = ModifiedWhisperEncoder.from_pretrained(
275
- config.audio_model_id
276
- )
277
- else:
278
- audio_tower = transformers.AutoModel.from_pretrained(
279
- config.audio_model_id
280
- )
281
- else:
282
- if "whisper" in config.audio_config._name_or_path:
283
  audio_tower = ModifiedWhisperEncoder(config.audio_config)
 
 
 
284
  else:
285
- with transformers.modeling_utils.no_init_weights():
286
- # we only ever use from_config if the weights are retrained, hence initializing is not
287
- # required. This makes the model quite creation faster since init on CPU is quite slow.
288
- audio_tower = transformers.AutoModel.from_config(
289
- config.audio_config
290
- )
291
 
292
  if isinstance(
293
  audio_tower,
@@ -305,23 +413,22 @@ class UltravoxModel(transformers.LlamaPreTrainedModel):
305
  def _create_language_model(
306
  cls, config: UltravoxConfig
307
  ) -> transformers.LlamaForCausalLM:
308
- if config.text_model_id is not None:
309
- language_model = transformers.AutoModelForCausalLM.from_pretrained(
310
- config.text_model_id, attn_implementation=config._attn_implementation
 
 
 
 
311
  )
312
- else:
313
- with transformers.modeling_utils.no_init_weights():
314
- # we only ever use from_config if the weights are retrained, hence initializing is not
315
- # required. This makes the model quite creation faster since init on CPU is quite slow.
316
- language_model = transformers.AutoModelForCausalLM.from_config(
317
- config.text_config, attn_implementation=config._attn_implementation
318
- )
319
 
320
  language_model = apply_lora(language_model, config.text_model_lora_config)
321
  return language_model
322
 
323
- def _add_language_model_weights_to_keep(self):
324
- if self.config.text_model_id is not None:
 
 
325
  self.config.text_model_id = None
326
  self.keep_params.update(
327
  set(
@@ -332,8 +439,9 @@ class UltravoxModel(transformers.LlamaPreTrainedModel):
332
  )
333
  )
334
 
335
- def _add_audio_tower_weights_to_keep(self):
336
- if self.config.audio_model_id is not None:
 
337
  self.config.audio_model_id = None
338
  self.keep_params.update(
339
  set(
@@ -344,46 +452,44 @@ class UltravoxModel(transformers.LlamaPreTrainedModel):
344
  )
345
  )
346
 
347
- def merge_and_unload(self):
348
- if isinstance(self.language_model, peft.PeftModel):
349
- self.language_model = self.language_model.merge_and_unload()
350
- # no need to download base language model weights anymore, so we can remove the id
351
- self._add_language_model_weights_to_keep()
352
-
353
- if isinstance(self.audio_tower, peft.PeftModel):
354
- self.audio_tower = self.audio_tower.merge_and_unload()
355
- # no need to download base audio model weights anymore, so we can remove the id
356
- self._add_audio_tower_weights_to_keep()
357
-
358
  for param in ["text_model_lora_config", "audio_model_lora_config"]:
359
  if hasattr(self.config, param):
360
  delattr(self.config, param)
361
 
362
  def push_to_hub(self, *args, **kwargs):
363
  self.merge_and_unload()
364
- self.to(self.language_model.dtype)
365
  return super().push_to_hub(*args, **kwargs)
366
 
367
- def state_dict(self, *args, **kwargs):
368
- named_params = dict(self.named_parameters())
369
- state_dict = super().state_dict(*args, **kwargs)
 
 
 
 
 
 
 
 
 
370
 
371
  state_dict = {
372
  k: v
373
  for k, v in state_dict.items()
374
- if k in self.keep_params
375
- or (k in named_params and named_params[k].requires_grad)
376
  }
 
377
  return state_dict
378
 
379
- def load_state_dict(
380
- self,
381
- state_dict: Dict[str, Any],
382
- *args,
383
- **kwargs,
384
  ):
 
 
 
 
 
385
  self.keep_params.update(set(state_dict.keys()))
386
- return super().load_state_dict(state_dict, *args, **kwargs)
387
 
388
  def print_trainable_parameters(self):
389
  """
@@ -414,8 +520,9 @@ class UltravoxModel(transformers.LlamaPreTrainedModel):
414
  )
415
 
416
 
 
417
  def is_cache_empty(
418
- past_key_values: Optional[Union[Tuple, transformers.cache_utils.Cache]]
419
  ) -> bool:
420
  """
421
  Check if the cache is empty.
@@ -427,16 +534,25 @@ def is_cache_empty(
427
  return past_key_values.get_seq_length() == 0
428
 
429
 
430
- def apply_lora(model: torch.nn.Module, lora_config: dict) -> torch.nn.Module:
 
 
 
431
  """
432
  Applies LoRA finetuning to the model. If the `r` parameter is set to 0, the model is frozen instead.
433
  """
 
434
  lora_config = peft.LoraConfig(**lora_config or {})
435
 
436
  if lora_config.r == 0:
437
- # freeze the model entirely
438
- for param in model.parameters():
439
- param.requires_grad = False
 
 
 
 
 
440
  else:
441
  model = peft.get_peft_model(model, lora_config)
442
 
@@ -445,12 +561,8 @@ def apply_lora(model: torch.nn.Module, lora_config: dict) -> torch.nn.Module:
445
 
446
  class StackAudioFrames(nn.Module):
447
  """
448
- Stack the audio embedding frames to reduce the sequence length by a factor of `stack_factor`.
449
-
450
- The number of output frames will be `ceil(T / stack_factor) + 1` where `T` is the number of input frames.
451
- NOTE: the extra +1 is intentional: in case the number of audio tokens are over-estimated by the processor,
452
- we want to make sure `processor.audio_token_replacement` (i.e. EOS) doesn't get leaked into the middle of embeddings.
453
- In most cases this extra padding will get removed in the model's forward function so it has no effect.
454
  """
455
 
456
  def __init__(self, stack_factor: int = 8):
@@ -460,7 +572,7 @@ class StackAudioFrames(nn.Module):
460
  def forward(self, audio_embeds: torch.Tensor) -> torch.Tensor:
461
  B, T, C = audio_embeds.shape
462
  T_pad = (T + self.stack_factor - 1) // self.stack_factor * self.stack_factor
463
- audio_embeds = F.pad(audio_embeds, (0, 0, 0, T_pad - T + self.stack_factor))
464
  B, T, C = audio_embeds.shape
465
  audio_embeds = audio_embeds.view(
466
  B, T // self.stack_factor, C * self.stack_factor
@@ -480,31 +592,67 @@ class SwiGLU(nn.Module):
480
  return F.silu(gate) * x
481
 
482
 
483
- class UltravoxProjector(nn.Sequential):
484
  def __init__(self, config: UltravoxConfig):
485
  super().__init__()
486
  self.hidden_dim = config.hidden_size
487
  self._pad_and_stack = StackAudioFrames(config.stack_factor)
488
- dim = config.audio_config.hidden_size * config.stack_factor
489
- self.ln_pre = RMSNorm(dim, init=config.norm_init)
490
- self.linear_1 = nn.Linear(dim, self.hidden_dim, bias=False)
491
- dim = self.hidden_dim
492
  self.act = transformers.activations.get_activation(config.projector_act)
493
- dim = dim // 2 if config.projector_act == "swiglu" else dim
494
- self.linear_2 = nn.Linear(dim, config.text_config.hidden_size, bias=False)
495
- self.ln_post = RMSNorm(config.text_config.hidden_size, init=config.norm_init)
 
 
 
 
 
 
 
 
 
496
 
497
  def forward(self, audio_features: torch.Tensor) -> torch.Tensor:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
498
  audio_features = self._pad_and_stack(audio_features)
499
  audio_features = self.ln_pre(audio_features)
 
500
  hidden_states = self.linear_1(audio_features)
 
501
  hidden_states = self.act(hidden_states)
 
 
502
  hidden_states = self.linear_2(hidden_states)
503
  hidden_states = self.ln_post(hidden_states)
504
  return hidden_states
505
 
506
 
507
- class ModifiedWhisperEncoder(whisper.WhisperEncoder):
 
 
508
  """
509
  Encoder portion of OpenAI's Whisper model.
510
 
@@ -518,21 +666,62 @@ class ModifiedWhisperEncoder(whisper.WhisperEncoder):
518
  """
519
 
520
  base_model_prefix = "model.encoder"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
521
 
522
  def forward(
523
  self,
524
  input_features,
525
- attention_mask=None,
526
  head_mask=None,
527
  output_attentions=None,
528
  output_hidden_states=None,
529
  return_dict=None,
530
  ):
531
- expected_seq_length = (
532
- self.config.max_source_positions
533
- * self.conv1.stride[0]
534
- * self.conv2.stride[0]
535
- )
536
  if input_features.shape[-1] > expected_seq_length:
537
  raise ValueError(
538
  f"Whisper expects the mel input features to be of length {expected_seq_length} or less, but found {input_features.shape[-1]}. Make sure to pad the input mel features to {expected_seq_length}."
@@ -565,6 +754,37 @@ class ModifiedWhisperEncoder(whisper.WhisperEncoder):
565
  encoder_states = () if output_hidden_states else None
566
  all_attentions = () if output_attentions else None
567
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
568
  # check if head_mask has a correct number of layers specified if desired
569
  if head_mask is not None:
570
  assert head_mask.size()[0] == (
@@ -588,14 +808,14 @@ class ModifiedWhisperEncoder(whisper.WhisperEncoder):
588
  layer_outputs = self._gradient_checkpointing_func(
589
  encoder_layer.__call__,
590
  hidden_states,
591
- None,
592
  (head_mask[idx] if head_mask is not None else None),
593
  output_attentions,
594
  )
595
  else:
596
  layer_outputs = encoder_layer(
597
  hidden_states,
598
- None,
599
  layer_head_mask=(
600
  head_mask[idx] if head_mask is not None else None
601
  ),
@@ -630,4 +850,4 @@ UltravoxModel.register_for_auto_class()
630
  transformers.AutoConfig.register("ultravox", UltravoxConfig)
631
  transformers.AutoModel.register(UltravoxConfig, UltravoxModel)
632
 
633
- transformers.activations.ACT2FN["swiglu"] = SwiGLU
 
1
  import logging
2
+ import re
3
+ from typing import Any, Dict, Generator, Optional, Set, Tuple, TypeVar, Union
4
 
5
  import peft
6
  import torch
 
10
  import transformers.activations
11
  import transformers.modeling_outputs
12
  import transformers.models
13
+ from transformers.generation.utils import GenerationMixin
14
  from transformers.models.whisper import modeling_whisper as whisper
15
 
16
  # We must use relative import in this directory to allow uploading to HF Hub
 
20
  from .ultravox_config import UltravoxConfig
21
 
22
 
23
+ class UltravoxModel(transformers.LlamaPreTrainedModel, GenerationMixin):
24
  """
25
  The Ultravox model which consists of an audio encoder and a language model.
26
 
 
36
 
37
  config_class = UltravoxConfig
38
  config: UltravoxConfig # for type hinting
39
+ # Usually we load encoder and LLM weights from a pretrained model separately, so they are allowed to be missing
40
+ _keys_to_ignore_on_load_missing = ["audio_tower.*", "language_model.*"]
41
+ # Since we have kwargs in forward, we need to set this to False, otherwise grad_accum_steps will cause incorrect train loss to be reported
42
+ # see https://github.com/huggingface/transformers/issues/35856 and https://github.com/huggingface/trl/pull/2615/files
43
+ accepts_loss_kwargs = False
 
 
 
 
44
 
45
  def __init__(self, config: UltravoxConfig):
46
  super().__init__(config)
47
+ self._register_load_state_dict_pre_hook(self._pre_load_state_dict_hook)
48
 
49
  self.keep_params: Set[str] = set()
50
  self.vocab_size = config.vocab_size
51
 
52
  self.audio_tower = self._create_audio_tower(config)
53
+ self.audio_tower_context_length: Optional[int] = None
54
+ self.audio_tower_context_length = self.audio_tower.max_context_length
55
+
56
+ self.multi_modal_projector = self._create_multi_modal_projector(config)
57
  self.language_model = self._create_language_model(config)
58
 
59
+ if self.language_model._tied_weights_keys is not None:
60
+ self._tied_weights_keys = [
61
+ f"language_model.{k}" for k in self.language_model._tied_weights_keys
62
+ ]
63
+
64
+ # Determine no_split_modules dynamically to use with FSDP auto_wrap policy.
65
+ # FSDP throws an error if some of the layer types are not found in the model.
66
+ # This would be something like ["LlamaDecoderLayer"] as we don't split audio encoder layers.
67
+ self._no_split_modules = self.language_model._no_split_modules
68
+
69
  self.loss_config = LossConfig()
70
  self.post_init()
71
 
72
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
    """Load the composite model, then materialize any child sub-model
    weights (audio tower / language model) that were left uninitialized
    on the meta device."""
    instance = super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
    instance._load_child_model_weights(*args, **kwargs)
    return instance
77
+
78
def _load_child_model_weights(self, *args, **kwargs) -> "UltravoxModel":
    """Fill in language-model / audio-tower weights from their own
    pretrained checkpoints.

    Only triggers for a child module when the config names a separate
    checkpoint id AND the module currently holds no real weights
    (its device type is "meta").
    """
    llm_needs_weights = (
        self.config.text_model_id is not None
        and self.language_model.device.type == "meta"
    )
    if llm_needs_weights:
        # Pull decoder weights from the standalone text checkpoint.
        self.language_model = transformers.AutoModelForCausalLM.from_pretrained(
            self.config.text_model_id,
            torch_dtype=self.config.torch_dtype,
            *args,
            **kwargs,
        )

    audio_needs_weights = (
        self.config.audio_model_id is not None
        and self.audio_tower.device.type == "meta"
    )
    if audio_needs_weights:
        # Pull encoder weights from the standalone audio checkpoint.
        self.audio_tower = transformers.AutoModel.from_pretrained(
            self.config.audio_model_id,
            torch_dtype=self.config.torch_dtype,
            *args,
            **kwargs,
        )

    return self
104
+
105
  def get_input_embeddings(self):
106
  return self.language_model.get_input_embeddings()
107
 
 
148
  self.vocab_size = model_embeds.num_embeddings
149
  return model_embeds
150
 
151
+ def _get_prediction_mask(self, labels: Optional[torch.Tensor]) -> torch.Tensor:
152
+ """Get a boolean mask for positions where we want to compute KL divergence.
153
+
154
+ For each label position, we want the position before it since that's where
155
+ the model makes the prediction for that label.
156
+
157
+ Args:
158
+ labels: Tensor of shape (B, T) where B is batch size and T is sequence length,
159
+ with -100 for masked positions and token ids for label positions
160
+
161
+ Returns:
162
+ Boolean tensor of shape (B, T) that's True for positions where we want to compute KL divergence
163
+ """
164
+ if labels is None:
165
+ raise ValueError("labels must be provided")
166
+ # Shift the label mask right by 1 along the sequence dimension
167
+ # This gives us positions where we make predictions for the next token
168
+ label_mask = labels != -100
169
+ pred_mask = torch.zeros_like(label_mask)
170
+ pred_mask[:, :-1] = label_mask[
171
+ :, 1:
172
+ ] # shift right by 1 along sequence dimension
173
+ return pred_mask
174
+
175
  def _compute_kl_loss(
176
  self,
177
  lm_output: transformers.modeling_outputs.CausalLMOutputWithPast,
 
196
  # compute the KL divergence loss between the two models
197
  kl_loss = F.kl_div(
198
  F.log_softmax(
199
+ lm_output.logits[self._get_prediction_mask(labels)]
200
+ / self.loss_config.kl_temperature,
201
  dim=-1,
202
  ),
203
  F.softmax(
204
+ alt_lm_output.logits[self._get_prediction_mask(alt_labels)]
205
  / self.loss_config.kl_temperature,
206
  dim=-1,
207
  ),
 
209
  )
210
  return {"loss": kl_loss}
211
 
212
+ def _audio_iter(
213
+ self, audio_batch_size: torch.Tensor
214
+ ) -> Generator[Tuple[int, int], None, None]:
215
+ """
216
+ Iterate over the audio batch size and yield the batch index and audio index of each audio item.
217
+
218
+ Args:
219
+ audio_batch_size: A tensor of shape (B,) where B is the batch size.
220
+
221
+ Returns:
222
+ A generator that yields a tuple of (start index, length) for each audio item.
223
+ """
224
+ audio_index = 0
225
+ for i_b, batch_count in enumerate(audio_batch_size):
226
+ for _ in range(batch_count):
227
+ yield i_b, audio_index
228
+ audio_index += 1
229
+
230
  def forward(
231
  self,
232
  input_ids: torch.Tensor,
 
235
  labels: Optional[torch.Tensor] = None,
236
  attention_mask: Optional[torch.Tensor] = None,
237
  audio_token_start_idx: Optional[torch.Tensor] = None,
238
+ audio_lens: Optional[torch.Tensor] = None,
239
  audio_token_len: Optional[torch.Tensor] = None,
240
+ audio_batch_size: Optional[torch.Tensor] = None,
241
  past_key_values: Optional[Union[Tuple, transformers.cache_utils.Cache]] = None,
242
  # the alt_* fields are needed for KL divergence loss
243
  alt_input_ids: Optional[torch.Tensor] = None,
 
268
  # B x T -> B x T x D
269
  inputs_embeds = self.get_input_embeddings().forward(input_ids)
270
 
271
+ if audio_values is not None and len(audio_values) > 0:
272
  assert (
273
+ audio_token_start_idx is not None
274
+ and audio_token_len is not None
275
+ and audio_lens is not None
276
+ and audio_batch_size is not None
277
+ ), "audio_token_start_idx/audio_token_len/audio_lens must be provided if audio_values are provided."
278
  assert (
279
+ len(audio_token_start_idx)
280
+ == len(audio_token_len)
281
+ == len(audio_lens)
282
+ == len(audio_values)
283
+ ), "audio_token_start_idx/audio_token_len/audio_lens/audio_values must have the same batch size."
284
+ assert len(audio_batch_size) == len(
285
+ inputs_embeds
286
+ ), "audio_batch_size and inputs_embeds must have the same batch size."
287
+
288
+ # B x A/3200 x (D=max-audio-length-in-batch)
289
  audio_tower_output = self.audio_tower.forward(
290
+ audio_values.to(self.audio_tower.dtype),
291
+ audio_len=audio_lens,
292
  ).last_hidden_state
293
  audio_tower_output = audio_tower_output.to(inputs_embeds.dtype)
 
294
  audio_embeds = self.multi_modal_projector.forward(audio_tower_output)
295
 
296
  # combine audio and text embeddings
297
+ for i_b, i_a in self._audio_iter(audio_batch_size):
298
+ start_idx = audio_token_start_idx[i_a]
299
+ token_len = audio_token_len[i_a]
300
+ item_embedding = audio_embeds[i_a][:token_len]
301
+ inputs_embeds[i_b][start_idx : start_idx + token_len] = item_embedding
302
 
303
  lm_output = self.language_model.forward(
304
  inputs_embeds=inputs_embeds,
 
333
  audio_values: Optional[torch.FloatTensor] = None,
334
  audio_token_start_idx: Optional[torch.Tensor] = None,
335
  audio_token_len: Optional[torch.Tensor] = None,
336
+ audio_lens: Optional[torch.Tensor] = None,
337
+ audio_batch_size: Optional[torch.Tensor] = None,
338
  past_key_values: Optional[Union[Tuple, transformers.cache_utils.Cache]] = None,
339
  attention_mask: Optional[torch.Tensor] = None,
340
  inputs_embeds: Optional[torch.Tensor] = None,
 
352
 
353
  # include audio information in model_input only when it is needed during prefilling
354
  # audio_token_start_idx should always be relative to the current cache position
355
+ prefill_start_idx: int | torch.Tensor = (
356
+ 0 if cache_position is None else cache_position[0]
357
+ )
358
  if (
359
  audio_values is not None
360
  and audio_token_start_idx is not None
 
365
  audio_token_start_idx - prefill_start_idx
366
  )
367
  model_input["audio_token_len"] = audio_token_len
368
+ model_input["audio_batch_size"] = audio_batch_size
369
+ model_input["audio_lens"] = audio_lens
370
 
371
  return model_input
372
 
373
@classmethod
def _create_multi_modal_projector(
    cls, config: UltravoxConfig
) -> "UltravoxProjector":
    """Build the audio->text projector and cast it to the configured dtype."""
    # nn.Module.to returns the module itself, so chaining is equivalent
    # to the two-step construct-then-cast form.
    return UltravoxProjector(config).to(config.torch_dtype)
380
+
381
  @classmethod
382
  def _create_audio_tower(
383
  cls, config: UltravoxConfig
384
  ) -> Union[transformers.Wav2Vec2Model, "ModifiedWhisperEncoder"]:
385
+ with transformers.modeling_utils.no_init_weights():
386
+ # we only ever use from_config if the weights are retrained, hence initializing is not
387
+ # required. This makes the model quite creation faster since init on CPU is quite slow.
388
+ if "whisper" in config.audio_config._name_or_path.lower():
 
 
 
 
 
 
 
389
  audio_tower = ModifiedWhisperEncoder(config.audio_config)
390
+ audio_tower.init_latency_mask(
391
+ config.audio_latency_block_size, dtype=config.torch_dtype
392
+ )
393
  else:
394
+ assert config.audio_latency_block_size in (
395
+ None,
396
+ 0,
397
+ ), "only whisper audio tower supports audio latency masking, got non-zero value for 'audio_latency_block_size'"
398
+ audio_tower = transformers.AutoModel.from_config(config.audio_config)
 
399
 
400
  if isinstance(
401
  audio_tower,
 
413
  def _create_language_model(
414
  cls, config: UltravoxConfig
415
  ) -> transformers.LlamaForCausalLM:
416
+ with transformers.modeling_utils.no_init_weights():
417
+ # we only ever use from_config if the weights are retrained, hence initializing is not
418
+ # required. This makes the model quite creation faster since init on CPU is quite slow.
419
+ language_model = transformers.AutoModelForCausalLM.from_config(
420
+ config.text_config,
421
+ attn_implementation=config.text_config._attn_implementation,
422
+ torch_dtype=config.torch_dtype,
423
  )
 
 
 
 
 
 
 
424
 
425
  language_model = apply_lora(language_model, config.text_model_lora_config)
426
  return language_model
427
 
428
+ def merge_and_unload(self):
429
+ if isinstance(self.language_model, peft.PeftModel):
430
+ self.language_model = self.language_model.merge_and_unload()
431
+ # no need to download base language model weights anymore, so we can remove the id
432
  self.config.text_model_id = None
433
  self.keep_params.update(
434
  set(
 
439
  )
440
  )
441
 
442
+ if isinstance(self.audio_tower, peft.PeftModel):
443
+ self.audio_tower = self.audio_tower.merge_and_unload()
444
+ # no need to download base audio model weights anymore, so we can remove the id
445
  self.config.audio_model_id = None
446
  self.keep_params.update(
447
  set(
 
452
  )
453
  )
454
 
 
 
 
 
 
 
 
 
 
 
 
455
  for param in ["text_model_lora_config", "audio_model_lora_config"]:
456
  if hasattr(self.config, param):
457
  delattr(self.config, param)
458
 
459
  def push_to_hub(self, *args, **kwargs):
460
  self.merge_and_unload()
 
461
  return super().push_to_hub(*args, **kwargs)
462
 
463
def diff_state_dict(
    self, state_dict: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
    """Return only the weights worth checkpointing.

    Keeps parameters that are trainable plus any keys recorded in
    ``self.keep_params``. FSDP inserts ``_fsdp_wrapped_module.`` segments
    into parameter names, so trainable names are normalized before
    filtering.

    Args:
        state_dict: optional precomputed state dict; defaults to the full
            ``state_dict()`` of this model.

    Returns:
        The filtered state dict.
    """
    if state_dict is None:
        state_dict = super().state_dict()

    trainable = {
        name.replace("_fsdp_wrapped_module.", "")
        for name, param in self.named_parameters()
        if param.requires_grad
    }

    return {
        key: value
        for key, value in state_dict.items()
        if key in self.keep_params or key in trainable
    }
483
 
484
def save_pretrained(
    self, *args, state_dict: Optional[Dict[str, Any]] = None, **kwargs
):
    """Save the checkpoint with only the diffed (trainable + kept) weights."""
    super().save_pretrained(
        *args, state_dict=self.diff_state_dict(state_dict), **kwargs
    )
490
+
491
+ def _pre_load_state_dict_hook(self, state_dict: Dict[str, Any], *args, **kwargs):
492
  self.keep_params.update(set(state_dict.keys()))
 
493
 
494
  def print_trainable_parameters(self):
495
  """
 
520
  )
521
 
522
 
523
+ # TODO: refactor common parts to a shared module
524
  def is_cache_empty(
525
+ past_key_values: Optional[Union[Tuple, transformers.cache_utils.Cache]],
526
  ) -> bool:
527
  """
528
  Check if the cache is empty.
 
534
  return past_key_values.get_seq_length() == 0
535
 
536
 
537
+ T = TypeVar("T", bound=torch.nn.Module)
538
+
539
+
540
+ def apply_lora(model: T, lora_config: dict) -> T:
541
  """
542
  Applies LoRA finetuning to the model. If the `r` parameter is set to 0, the model is frozen instead.
543
  """
544
+ unfreeze_layers = lora_config.pop("unfreeze_layers", None)
545
  lora_config = peft.LoraConfig(**lora_config or {})
546
 
547
  if lora_config.r == 0:
548
+ # freeze the model entirely, except for the specified layers
549
+ for name, param in model.named_parameters():
550
+ if not unfreeze_layers or not any(
551
+ re.match(layer, name) for layer in unfreeze_layers
552
+ ):
553
+ param.requires_grad = False
554
+ else:
555
+ logging.info(f"Unfreezing layer: {name} with #{param.numel()} params")
556
  else:
557
  model = peft.get_peft_model(model, lora_config)
558
 
 
561
 
562
  class StackAudioFrames(nn.Module):
563
  """
564
+ Stack the audio embedding frames to reduce the sequence length by a factor
565
+ of `stack_factor`.
 
 
 
 
566
  """
567
 
568
  def __init__(self, stack_factor: int = 8):
 
572
  def forward(self, audio_embeds: torch.Tensor) -> torch.Tensor:
573
  B, T, C = audio_embeds.shape
574
  T_pad = (T + self.stack_factor - 1) // self.stack_factor * self.stack_factor
575
+ audio_embeds = F.pad(audio_embeds, (0, 0, 0, T_pad - T))
576
  B, T, C = audio_embeds.shape
577
  audio_embeds = audio_embeds.view(
578
  B, T // self.stack_factor, C * self.stack_factor
 
592
  return F.silu(gate) * x
593
 
594
 
595
class UltravoxProjector(nn.Module):
    """Adapter that maps audio-encoder features into the text model's embedding space.

    Frames from the audio tower are stacked `stack_factor` at a time (shrinking the
    sequence length, growing the channel count), then passed through a two-layer MLP.
    Depending on `config.projector_ln_mid`, RMSNorm is applied either after the first
    linear layer (Ultravox v0.5.0+) or after the second (v0.4.1 and below).
    """

    def __init__(self, config: UltravoxConfig):
        super().__init__()
        self.hidden_dim = config.hidden_size
        self._pad_and_stack = StackAudioFrames(config.stack_factor)

        in_features = config.audio_config.hidden_size * config.stack_factor
        out_features = config.text_config.hidden_size

        self.ln_pre = RMSNorm(in_features, init=config.norm_init)
        self.linear_1 = nn.Linear(in_features, self.hidden_dim, bias=False)
        self.act = transformers.activations.get_activation(config.projector_act)
        # SwiGLU splits its input in half (gate/value), so the activation output —
        # and therefore the second linear layer's input — has half the channels.
        mid_features = (
            self.hidden_dim // 2 if config.projector_act == "swiglu" else self.hidden_dim
        )
        self.linear_2 = nn.Linear(mid_features, out_features, bias=False)

        # Ultravox v0.4.1 and below uses layer_norm after the second linear layer,
        # while v0.5.0 and above uses layer_norm after the first linear layer.
        if config.projector_ln_mid:
            self.ln_mid: nn.Module = RMSNorm(mid_features, init=config.norm_init)
            self.ln_post: nn.Module = nn.Identity()
        else:
            self.ln_mid = nn.Identity()
            self.ln_post = RMSNorm(out_features, init=config.norm_init)

    def forward(self, audio_features: torch.Tensor) -> torch.Tensor:
        """Project audio features (B, F, C) to text embeddings (B, T, D).

        F encoder frames are stacked S = `stack_factor` at a time, zero-padding the
        tail when F is not a multiple of S, giving T = ceil(F / S) output embeddings.
        H is the projector hidden size and D the text model's hidden size.
        """
        x = self._pad_and_stack(audio_features)  # B, F, C -> B, T, C*S
        x = self.ln_pre(x)
        x = self.linear_1(x)  # B, T, C*S -> B, T, H
        x = self.act(x)  # B, T, H -> B, T, H/2 when swiglu
        x = self.ln_mid(x)
        x = self.linear_2(x)  # -> B, T, D
        return self.ln_post(x)
651
 
652
 
653
+ class ModifiedWhisperEncoder(
654
+ whisper.WhisperEncoder, transformers.modeling_utils.ModuleUtilsMixin
655
+ ):
656
  """
657
  Encoder portion of OpenAI's Whisper model.
658
 
 
666
  """
667
 
668
    # Prefix under which these weights live inside a full Whisper checkpoint.
    base_model_prefix = "model.encoder"
    # Keep each Whisper encoder layer intact when sharding the model across devices.
    _no_split_modules = ["WhisperEncoderLayer"]
    # Full Whisper checkpoints also carry decoder weights; ignore them when loading
    # into this encoder-only module.
    _keys_to_ignore_on_load_unexpected = ["model.decoder.*"]
671
+
672
    def __init__(self, config: transformers.WhisperConfig):
        """Initialize the encoder and force encoder (non-causal) behavior."""
        super().__init__(config)
        # This module is used standalone as an encoder, so make sure the config
        # never marks it as a decoder regardless of the loaded checkpoint.
        self.config.is_decoder = False
675
+
676
+ @property
677
+ def max_context_length(self):
678
+ return (
679
+ self.config.max_source_positions
680
+ * self.conv1.stride[0]
681
+ * self.conv2.stride[0]
682
+ )
683
+
684
+ def init_latency_mask(
685
+ self, audio_latency_block_size: int | None, dtype: torch.dtype
686
+ ):
687
+ if audio_latency_block_size is None:
688
+ self.audio_streaming_mask = None
689
+ return
690
+
691
+ # Use max_context_length directly in the calculation
692
+ max_seqlen = self.max_context_length
693
+ assert (
694
+ max_seqlen > 0
695
+ ), f"maximum sequence length must be positive, got {max_seqlen}"
696
+ assert (
697
+ max_seqlen % audio_latency_block_size == 0
698
+ ), f"audio_latency_block_size {audio_latency_block_size} must divide {max_seqlen} evenly."
699
+ # Given the block size, we calculate number of blocks.
700
+ audio_latency_nblocks = max_seqlen // audio_latency_block_size
701
+ audio_streaming_mask = (
702
+ torch.tril(
703
+ torch.ones(audio_latency_nblocks, audio_latency_nblocks),
704
+ diagonal=0,
705
+ )
706
+ .repeat_interleave(audio_latency_block_size, dim=0)
707
+ .repeat_interleave(audio_latency_block_size, dim=1)
708
+ )
709
+ audio_streaming_mask = (1.0 - audio_streaming_mask) * torch.finfo(dtype).min
710
+ audio_streaming_mask = audio_streaming_mask[None, None, :, :]
711
+ self.register_buffer(
712
+ "audio_streaming_mask", audio_streaming_mask, persistent=False
713
+ )
714
 
715
  def forward(
716
  self,
717
  input_features,
718
+ audio_len=None,
719
  head_mask=None,
720
  output_attentions=None,
721
  output_hidden_states=None,
722
  return_dict=None,
723
  ):
724
+ expected_seq_length = self.max_context_length
 
 
 
 
725
  if input_features.shape[-1] > expected_seq_length:
726
  raise ValueError(
727
  f"Whisper expects the mel input features to be of length {expected_seq_length} or less, but found {input_features.shape[-1]}. Make sure to pad the input mel features to {expected_seq_length}."
 
754
  encoder_states = () if output_hidden_states else None
755
  all_attentions = () if output_attentions else None
756
 
757
+ # Create attention mask based on audio lengths to mask out padding tokens
758
+ # For each sample in batch:
759
+ # - Convert raw audio length to feature length after convolutions
760
+ # - Create boolean mask that is True for valid positions and False for padding
761
+ # - Convert to extended attention mask format expected by transformer layers
762
+ # (1.0 for positions to attend to, large negative for positions to ignore)
763
+ # This masking ensures consistent behavior between training and inference
764
+ # by preventing the model from attending to padding tokens in both cases
765
+ attention_mask = None
766
+ if audio_len != None:
767
+ audio_feature_len = self._get_feat_extract_output_lengths(audio_len)
768
+ max_seq_len = hidden_states.shape[1]
769
+ attention_mask = torch.arange(max_seq_len, device=hidden_states.device)[
770
+ None, :
771
+ ].lt(audio_feature_len.view(-1, 1))
772
+ attention_mask = self.get_extended_attention_mask(
773
+ attention_mask,
774
+ None,
775
+ dtype=hidden_states.dtype,
776
+ )
777
+
778
+ if self.audio_streaming_mask is not None:
779
+ seqlen = hidden_states.size(-2)
780
+ if attention_mask is not None:
781
+ attention_mask = torch.minimum(
782
+ self.audio_streaming_mask[:, :, :seqlen, :seqlen], attention_mask
783
+ ) # merge
784
+ else:
785
+ attention_mask = self.audio_streaming_mask[:, :, :seqlen, :seqlen]
786
+ attention_mask = attention_mask.to(hidden_states.dtype)
787
+
788
  # check if head_mask has a correct number of layers specified if desired
789
  if head_mask is not None:
790
  assert head_mask.size()[0] == (
 
808
  layer_outputs = self._gradient_checkpointing_func(
809
  encoder_layer.__call__,
810
  hidden_states,
811
+ attention_mask,
812
  (head_mask[idx] if head_mask is not None else None),
813
  output_attentions,
814
  )
815
  else:
816
  layer_outputs = encoder_layer(
817
  hidden_states,
818
+ attention_mask,
819
  layer_head_mask=(
820
  head_mask[idx] if head_mask is not None else None
821
  ),
 
850
# Register Ultravox with the HF Auto* factories so AutoConfig/AutoModel
# resolve the "ultravox" model type to these classes.
transformers.AutoConfig.register("ultravox", UltravoxConfig)
transformers.AutoModel.register(UltravoxConfig, UltravoxModel)

# Expose SwiGLU under the activation-registry name referenced by the
# projector's `projector_act` config value.
transformers.activations.ACT2FN["swiglu"] = SwiGLU