trfms-fix
#9
by eustlb HF Staff - opened
- ultravox_model.py +16 -12
ultravox_model.py
CHANGED
|
@@ -426,12 +426,14 @@ class UltravoxModel(transformers.LlamaPreTrainedModel, GenerationMixin):
|
|
| 426 |
# We probably don't want to pass tp_plan or device_map to the audio tower
|
| 427 |
# But potentially other kwargs can be passed in. TODO
|
| 428 |
kwargs = {"torch_dtype": config.torch_dtype}
|
| 429 |
-
|
| 430 |
-
|
| 431 |
-
|
| 432 |
-
|
| 433 |
-
|
| 434 |
-
|
|
|
|
|
|
|
| 435 |
if "whisper" in config.audio_model_id.lower():
|
| 436 |
audio_tower = ModifiedWhisperEncoder.from_pretrained(
|
| 437 |
config.audio_model_id, **kwargs
|
|
@@ -482,12 +484,14 @@ class UltravoxModel(transformers.LlamaPreTrainedModel, GenerationMixin):
|
|
| 482 |
def _create_language_model(
|
| 483 |
cls, config: UltravoxConfig
|
| 484 |
) -> transformers.LlamaForCausalLM:
|
| 485 |
-
|
| 486 |
-
|
| 487 |
-
|
| 488 |
-
|
| 489 |
-
|
| 490 |
-
|
|
|
|
|
|
|
| 491 |
language_model = transformers.AutoModelForCausalLM.from_pretrained(
|
| 492 |
config.text_model_id,
|
| 493 |
**{
|
|
|
|
| 426 |
# We probably don't want to pass tp_plan or device_map to the audio tower
|
| 427 |
# But potentially other kwargs can be passed in. TODO
|
| 428 |
kwargs = {"torch_dtype": config.torch_dtype}
|
| 429 |
+
if hasattr(transformers.modeling_utils, "_init_weights"):
|
| 430 |
+
# v4 path
|
| 431 |
+
is_init = transformers.modeling_utils._init_weights
|
| 432 |
+
else:
|
| 433 |
+
# v5 path
|
| 434 |
+
_default_device = getattr(torch, "get_default_device", lambda: None)()
|
| 435 |
+
is_init = _default_device is None or _default_device.type != "meta"
|
| 436 |
+
if is_init and config.audio_model_id is not None:
|
| 437 |
if "whisper" in config.audio_model_id.lower():
|
| 438 |
audio_tower = ModifiedWhisperEncoder.from_pretrained(
|
| 439 |
config.audio_model_id, **kwargs
|
|
|
|
| 484 |
def _create_language_model(
|
| 485 |
cls, config: UltravoxConfig
|
| 486 |
) -> transformers.LlamaForCausalLM:
|
| 487 |
+
if hasattr(transformers.modeling_utils, "_init_weights"):
|
| 488 |
+
# v4 path
|
| 489 |
+
is_init = transformers.modeling_utils._init_weights
|
| 490 |
+
else:
|
| 491 |
+
# v5 path
|
| 492 |
+
_default_device = getattr(torch, "get_default_device", lambda: None)()
|
| 493 |
+
is_init = _default_device is None or _default_device.type != "meta"
|
| 494 |
+
if is_init and config.text_model_id is not None:
|
| 495 |
language_model = transformers.AutoModelForCausalLM.from_pretrained(
|
| 496 |
config.text_model_id,
|
| 497 |
**{
|