Update modeling_conformer.py
Browse files- modeling_conformer.py +5 -5
modeling_conformer.py
CHANGED
|
@@ -95,7 +95,7 @@ class Wav2Vec2ConformerRNNT(Wav2Vec2ConformerModel):
|
|
| 95 |
tokens = torch.full((B, max_len), pad, dtype=torch.long, device=enc_out.device)
|
| 96 |
starts = torch.full((B, max_len), -1.0, dtype=enc_out.dtype, device=enc_out.device)
|
| 97 |
lengths = torch.zeros(B, dtype=torch.long, device=enc_out.device)
|
| 98 |
-
hx = torch.zeros(
|
| 99 |
cx = torch.zeros_like(hx)
|
| 100 |
last = torch.full((B, 1), blank, dtype=torch.long, device=enc_out.device)
|
| 101 |
|
|
@@ -173,11 +173,11 @@ encoder.pre_encode.conv_module.0,feature_extractor.conv_layers.0.conv
|
|
| 173 |
encoder.pre_encode.out,feature_projection.projection
|
| 174 |
"""
|
| 175 |
if not model.config.multilingual:
|
| 176 |
-
changes += "encoder.pre_encode.conv_module.{n},feature_extractor.conv_layers.{(n/2)}.conv"
|
| 177 |
-
changes += f"lang_joint_net.{model.config.language},joint"
|
| 178 |
else:
|
| 179 |
-
changes += "encoder.pre_encode.conv_module.{n},encoder.pre_encode.conv_module.{(n-2)}"
|
| 180 |
-
changes += "encoder.pre_encode.conv_module.{n},feature_extractor.conv_layers.{(n//3+1)}.conv.{(n%3)}"
|
| 181 |
state_dict = state_bridge(state_dict, changes)
|
| 182 |
if not model.config.multilingual:
|
| 183 |
state_dict = {k: v for k, v in state_dict.items() if "lang_joint_net" not in k}
|
|
|
|
| 95 |
tokens = torch.full((B, max_len), pad, dtype=torch.long, device=enc_out.device)
|
| 96 |
starts = torch.full((B, max_len), -1.0, dtype=enc_out.dtype, device=enc_out.device)
|
| 97 |
lengths = torch.zeros(B, dtype=torch.long, device=enc_out.device)
|
| 98 |
+
hx = torch.zeros(self.config.lstm_layer, B, H, dtype=enc_out.dtype, device=enc_out.device)
|
| 99 |
cx = torch.zeros_like(hx)
|
| 100 |
last = torch.full((B, 1), blank, dtype=torch.long, device=enc_out.device)
|
| 101 |
|
|
|
|
| 173 |
encoder.pre_encode.out,feature_projection.projection
|
| 174 |
"""
|
| 175 |
if not model.config.multilingual:
|
| 176 |
+
changes += "encoder.pre_encode.conv_module.{n},feature_extractor.conv_layers.{(n/2)}.conv\n"
|
| 177 |
+
changes += f"lang_joint_net.{model.config.language},joint\n"
|
| 178 |
else:
|
| 179 |
+
changes += "encoder.pre_encode.conv_module.{n},encoder.pre_encode.conv_module.{(n-2)}\n"
|
| 180 |
+
changes += "encoder.pre_encode.conv_module.{n},feature_extractor.conv_layers.{(n//3+1)}.conv.{(n%3)}\n"
|
| 181 |
state_dict = state_bridge(state_dict, changes)
|
| 182 |
if not model.config.multilingual:
|
| 183 |
state_dict = {k: v for k, v in state_dict.items() if "lang_joint_net" not in k}
|