shethjenil commited on
Commit
e63e063
·
verified ·
1 Parent(s): d87b0f5

Update modeling_conformer.py

Browse files
Files changed (1) hide show
  1. modeling_conformer.py +5 -5
modeling_conformer.py CHANGED
@@ -95,7 +95,7 @@ class Wav2Vec2ConformerRNNT(Wav2Vec2ConformerModel):
95
  tokens = torch.full((B, max_len), pad, dtype=torch.long, device=enc_out.device)
96
  starts = torch.full((B, max_len), -1.0, dtype=enc_out.dtype, device=enc_out.device)
97
  lengths = torch.zeros(B, dtype=torch.long, device=enc_out.device)
98
- hx = torch.zeros(1, B, H, dtype=enc_out.dtype, device=enc_out.device)
99
  cx = torch.zeros_like(hx)
100
  last = torch.full((B, 1), blank, dtype=torch.long, device=enc_out.device)
101
 
@@ -173,11 +173,11 @@ encoder.pre_encode.conv_module.0,feature_extractor.conv_layers.0.conv
173
  encoder.pre_encode.out,feature_projection.projection
174
  """
175
  if not model.config.multilingual:
176
- changes += "encoder.pre_encode.conv_module.{n},feature_extractor.conv_layers.{(n/2)}.conv"
177
- changes += f"lang_joint_net.{model.config.language},joint"
178
  else:
179
- changes += "encoder.pre_encode.conv_module.{n},encoder.pre_encode.conv_module.{(n-2)}"
180
- changes += "encoder.pre_encode.conv_module.{n},feature_extractor.conv_layers.{(n//3+1)}.conv.{(n%3)}"
181
  state_dict = state_bridge(state_dict, changes)
182
  if not model.config.multilingual:
183
  state_dict = {k: v for k, v in state_dict.items() if "lang_joint_net" not in k}
 
95
  tokens = torch.full((B, max_len), pad, dtype=torch.long, device=enc_out.device)
96
  starts = torch.full((B, max_len), -1.0, dtype=enc_out.dtype, device=enc_out.device)
97
  lengths = torch.zeros(B, dtype=torch.long, device=enc_out.device)
98
+ hx = torch.zeros(self.config.lstm_layer, B, H, dtype=enc_out.dtype, device=enc_out.device)
99
  cx = torch.zeros_like(hx)
100
  last = torch.full((B, 1), blank, dtype=torch.long, device=enc_out.device)
101
 
 
173
  encoder.pre_encode.out,feature_projection.projection
174
  """
175
  if not model.config.multilingual:
176
+ changes += "encoder.pre_encode.conv_module.{n},feature_extractor.conv_layers.{(n/2)}.conv\n"
177
+ changes += f"lang_joint_net.{model.config.language},joint\n"
178
  else:
179
+ changes += "encoder.pre_encode.conv_module.{n},encoder.pre_encode.conv_module.{(n-2)}\n"
180
+ changes += "encoder.pre_encode.conv_module.{n},feature_extractor.conv_layers.{(n//3+1)}.conv.{(n%3)}\n"
181
  state_dict = state_bridge(state_dict, changes)
182
  if not model.config.multilingual:
183
  state_dict = {k: v for k, v in state_dict.items() if "lang_joint_net" not in k}