Files changed (1) hide show
  1. ultravox_model.py +16 -12
ultravox_model.py CHANGED
@@ -426,12 +426,14 @@ class UltravoxModel(transformers.LlamaPreTrainedModel, GenerationMixin):
426
  # We probably don't want to pass tp_plan or device_map to the audio tower
427
  # But potentially other kwargs can be passed in. TODO
428
  kwargs = {"torch_dtype": config.torch_dtype}
429
- _default_device = getattr(torch, "get_default_device", lambda: None)()
430
- _is_init = _default_device is None or _default_device.type != "meta"
431
- if (
432
- _is_init
433
- and config.audio_model_id is not None
434
- ):
 
 
435
  if "whisper" in config.audio_model_id.lower():
436
  audio_tower = ModifiedWhisperEncoder.from_pretrained(
437
  config.audio_model_id, **kwargs
@@ -482,12 +484,14 @@ class UltravoxModel(transformers.LlamaPreTrainedModel, GenerationMixin):
482
  def _create_language_model(
483
  cls, config: UltravoxConfig
484
  ) -> transformers.LlamaForCausalLM:
485
- _default_device = getattr(torch, "get_default_device", lambda: None)()
486
- _is_init = _default_device is None or _default_device.type != "meta"
487
- if (
488
- _is_init
489
- and config.text_model_id is not None
490
- ):
 
 
491
  language_model = transformers.AutoModelForCausalLM.from_pretrained(
492
  config.text_model_id,
493
  **{
 
426
  # We probably don't want to pass tp_plan or device_map to the audio tower
427
  # But potentially other kwargs can be passed in. TODO
428
  kwargs = {"torch_dtype": config.torch_dtype}
429
+ if hasattr(transformers.modeling_utils, "_init_weights"):
430
+ # transformers v4 path: modeling_utils exposes a module-level `_init_weights` flag, so reuse it
431
+ is_init = transformers.modeling_utils._init_weights
432
+ else:
433
+ # transformers v5 path: no `_init_weights` flag, so a "meta" default device means we are not initializing
434
+ _default_device = getattr(torch, "get_default_device", lambda: None)()
435
+ is_init = _default_device is None or _default_device.type != "meta"
436
+ if is_init and config.audio_model_id is not None:
437
  if "whisper" in config.audio_model_id.lower():
438
  audio_tower = ModifiedWhisperEncoder.from_pretrained(
439
  config.audio_model_id, **kwargs
 
484
  def _create_language_model(
485
  cls, config: UltravoxConfig
486
  ) -> transformers.LlamaForCausalLM:
487
+ if hasattr(transformers.modeling_utils, "_init_weights"):
488
+ # transformers v4 path: modeling_utils exposes a module-level `_init_weights` flag, so reuse it
489
+ is_init = transformers.modeling_utils._init_weights
490
+ else:
491
+ # transformers v5 path: no `_init_weights` flag, so a "meta" default device means we are not initializing
492
+ _default_device = getattr(torch, "get_default_device", lambda: None)()
493
+ is_init = _default_device is None or _default_device.type != "meta"
494
+ if is_init and config.text_model_id is not None:
495
  language_model = transformers.AutoModelForCausalLM.from_pretrained(
496
  config.text_model_id,
497
  **{