danielhanchen commited on
Commit
d57bbdf
·
verified ·
1 Parent(s): bd3c174

Update modeling_ernie4_5_moe.py

Browse files
Files changed (1) hide show
  1. modeling_ernie4_5_moe.py +5 -1
modeling_ernie4_5_moe.py CHANGED
@@ -1113,7 +1113,11 @@ class Ernie4_5_Model(Ernie4_5_PretrainedModel):
1113
  past_key_values = DynamicCache()
1114
 
1115
  if inputs_embeds is None:
1116
- inputs_embeds = self.embed_tokens(input_ids)
 
 
 
 
1117
 
1118
  inputs_embeds = inputs_embeds.to(self.embed_tokens.weight.dtype)
1119
 
 
1113
  past_key_values = DynamicCache()
1114
 
1115
  if inputs_embeds is None:
1116
+ # Account for CPU offloaded embed_tokens
1117
+ embed_device = self.embed_tokens.weight.device
1118
+ inputs_embeds = self.embed_tokens(input_ids.to(embed_device, non_blocking = True)).to(input_ids.device)
1119
+ if not self.training:
1120
+ inputs_embeds.requires_grad_(False)
1121
 
1122
  inputs_embeds = inputs_embeds.to(self.embed_tokens.weight.dtype)
1123