Update modeling_ernie4_5_moe.py
Browse files- modeling_ernie4_5_moe.py +5 -1
modeling_ernie4_5_moe.py
CHANGED
|
@@ -1113,7 +1113,11 @@ class Ernie4_5_Model(Ernie4_5_PretrainedModel):
|
|
| 1113 |
past_key_values = DynamicCache()
|
| 1114 |
|
| 1115 |
if inputs_embeds is None:
|
| 1116 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1117 |
|
| 1118 |
inputs_embeds = inputs_embeds.to(self.embed_tokens.weight.dtype)
|
| 1119 |
|
|
|
|
| 1113 |
past_key_values = DynamicCache()
|
| 1114 |
|
| 1115 |
if inputs_embeds is None:
|
| 1116 |
+
# Account for CPU offloaded embed_tokens
|
| 1117 |
+
embed_device = self.embed_tokens.weight.device
|
| 1118 |
+
inputs_embeds = self.embed_tokens(input_ids.to(embed_device, non_blocking = True)).to(input_ids.device)
|
| 1119 |
+
if not self.training:
|
| 1120 |
+
inputs_embeds.requires_grad_(False)
|
| 1121 |
|
| 1122 |
inputs_embeds = inputs_embeds.to(self.embed_tokens.weight.dtype)
|
| 1123 |
|