Update modeling_conformer.py
modeling_conformer.py  CHANGED  (+38 -63)
@@ -1,6 +1,6 @@
 from huggingface_hub import hf_hub_download
 from torch import nn
-from transformers import Wav2Vec2ConformerModel
+from transformers import Wav2Vec2ConformerModel
 from safetensors.torch import load_file
 from torch_state_bridge import state_bridge
 import torch
@@ -30,7 +30,7 @@ class Wav2Vec2ConformerRNNT(Wav2Vec2ConformerModel):
         self.joint = nn.Linear(config.joint_hidden, config.vocab_size // len(config.languages) + 1)
         self.embed = nn.Embedding(config.vocab_size + 1, config.pred_hidden, padding_idx=config.vocab_size)
         self.lstm = nn.LSTM(config.pred_hidden, config.pred_hidden, config.lstm_layer, batch_first=True)
-        self.act = nn.ReLU(
+        self.act = nn.ReLU()
         self.spec = torchaudio.transforms.Spectrogram(n_fft=512, hop_length=160, win_length=400, center=False)
         self.mask_layer = Op(lambda self_obj, x: x.masked_fill(self_obj.cache_pad_mask.unsqueeze(1), 0), True)
         self.register_buffer("mel_fb", torch.tensor(librosa.filters.mel(sr=config.sampling_rate, n_fft=512, n_mels=80)))
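Note on the feature frontend wired up here: `self.spec` produces a 257-bin power spectrogram (n_fft=512) and `mel_fb` is an (80, 257) librosa mel filterbank, so the matmul in `preprocessing` yields 80 log-mel features per frame. A minimal standalone sketch, assuming 16 kHz input (the sampling rate is an assumption; the model reads it from `config.sampling_rate`):

import torch
import torchaudio
import librosa

spec = torchaudio.transforms.Spectrogram(n_fft=512, hop_length=160, win_length=400, center=False)
mel_fb = torch.tensor(librosa.filters.mel(sr=16000, n_fft=512, n_mels=80))  # (80, 257)

wav = torch.randn(1, 16000)                  # 1 s of dummy audio (assumed 16 kHz)
feats = (mel_fb @ spec(wav) + 2**-24).log()  # (1, 80, n_frames); 2**-24 matches the self.eps added below
print(feats.shape)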
@@ -51,7 +51,7 @@ class Wav2Vec2ConformerRNNT(Wav2Vec2ConformerModel):
         self.encoder.layer_norm = nn.Identity()
         if config.multilingual:
             self.lang_joint_net = nn.ModuleDict({l: nn.Linear(config.joint_hidden, config.vocab_size // len(config.languages) + 1) for l in config.languages})
-        self.
+        self.eps = 2**-24
         self.denorm = (2 ** config.num_feat_extract_layers) * self.spec.hop_length / config.sampling_rate
         self.scaler = config.hidden_size ** (1/2)
         return super().init_weights()
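For orientation: `denorm` converts encoder frame indices back to seconds, since the frontend hops `hop_length` samples per feature frame and the feature extractor halves the frame rate once per layer. With the hop_length=160 visible above and, hypothetically, sampling_rate=16000 and num_feat_extract_layers=2, denorm = (2**2) * 160 / 16000 = 0.04, i.e. one encoder frame per 40 ms; `postprocessing` below multiplies frame indices by this value to timestamp emitted tokens.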
@@ -70,7 +70,7 @@ class Wav2Vec2ConformerRNNT(Wav2Vec2ConformerModel):
     def preprocessing(self, x):
         x, l = x
         l = (l // self.spec.hop_length + 1).long()
-        x = torch.cat((x[:, :1], x[:, 1:] - self.preemph * x[:, :-1]), 1)
+        x = torch.cat((x[:, :1], x[:, 1:] - self.config.preemph * x[:, :-1]), 1)
         x = (self.mel_fb @ self.spec(x) + self.eps).log()
         T = x.size(-1)
         m = torch.arange(T, device=x.device)[None] >= l[:, None]
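The changed line applies a standard preemphasis filter, y[0] = x[0] and y[n] = x[n] - α·x[n-1], now reading α from the config instead of the module. A self-contained sketch of the same expression (the alpha=0.97 default here is a common choice, not a value taken from this repo):

import torch

def preemphasis(x: torch.Tensor, alpha: float = 0.97) -> torch.Tensor:
    # keep the first sample as-is; every later sample subtracts alpha times its predecessor
    return torch.cat((x[:, :1], x[:, 1:] - alpha * x[:, :-1]), dim=1)

x = torch.randn(2, 16000)
y = preemphasis(x)
assert torch.equal(y[:, 0], x[:, 0])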
@@ -80,76 +80,51 @@ class Wav2Vec2ConformerRNNT(Wav2Vec2ConformerModel):
         σ = (((x - μ[..., None])**2).sum(-1) / denom + 1e-5).sqrt()
         x = ((x - μ[..., None]) / σ[..., None]).masked_fill(m[:, None], 0)
         self.cache_length = self.calc_length(l, repeat_num=self.config.num_feat_extract_layers).long()
-        return F.pad(x, (0, (-T) % self.pad_to)).transpose(1, 2)
+        return F.pad(x, (0, (-T) % self.config.pad_to)).transpose(1, 2)

     def forward(self, input_values):
-        return self.
+        return self.postprocessing(super().forward(self.preprocessing(input_values)).last_hidden_state)

-    def _greedy_decode(self, enc_out: torch.Tensor):
-        B, T, _ = enc_out.size()
-        device = enc_out.device
-        enc_proj = self.enc(enc_out)
-        max_symbols = self.config.max_symbols_per_step
-        max_len = T * max_symbols
-        token_buffer = torch.full(
-            (B, max_len),
-            -1,
-            dtype=torch.long,
-            device=device
-        )
-        start_buffer = torch.zeros(
-            (B, max_len),
-            device=device
-        )
-        lengths = torch.zeros(B, dtype=torch.long, device=device)
-        last = torch.full(
-            (B, 1),
-            self.config.blank_id,
-            dtype=torch.long,
-            device=device
-        )
-
-        h = None
-
-                break
-
-                start_buffer[emit_mask, pos] = t * self.denorm
-
-                torch.where(keep_mask, h[0], h2[0]),
-                torch.where(keep_mask, h[1], h2[1]),
-            )
-
-            starts.append(start_buffer[b, :L])
+    def postprocessing(self, enc_out):
+        B, T, _ = enc_out.shape
+        H = self.lstm.hidden_size
+        blank = self.config.blank_id
+        pad = self.config.pad_id
+        max_len = T * self.config.max_symbols_per_step
+
+        tokens = torch.full((B, max_len), pad, dtype=torch.long, device=enc_out.device)
+        starts = torch.full((B, max_len), -1.0, dtype=enc_out.dtype, device=enc_out.device)
+        lengths = torch.zeros(B, dtype=torch.long, device=enc_out.device)
+        hx = torch.zeros(1, B, H, dtype=enc_out.dtype, device=enc_out.device)
+        cx = torch.zeros_like(hx)
+        last = torch.full((B, 1), blank, dtype=torch.long, device=enc_out.device)
+
+        enc_proj = self.enc(enc_out)  # (B, T, D)
+
+        for t in range(T):
+            e = enc_proj[:, t:t+1]
+            t_sec = torch.full((B, 1), t * self.denorm, dtype=enc_out.dtype, device=enc_out.device)
+
+            for _ in range(self.config.max_symbols_per_step):
+                hx_prev, cx_prev = hx, cx
+                p, (hx, cx) = self.lstm(self.embed(last), (hx, cx))
+                n = self.joint(self.act(e + self.pred(p))).squeeze(1).argmax(-1)  # (B,)
+                emitted = n.ne(blank)
+
+                # revert hidden state for entries that emitted blank
+                mask = emitted.view(1, B, 1)
+                hx = torch.where(mask, hx, hx_prev)
+                cx = torch.where(mask, cx, cx_prev)
+                last = torch.where(emitted.unsqueeze(1), n.unsqueeze(1), last)
+
+                pos = lengths.unsqueeze(1).clamp(max=max_len - 1)
+                fill_t = torch.where(emitted.unsqueeze(1), n.unsqueeze(1), torch.full_like(n.unsqueeze(1), pad))
+                fill_s = torch.where(emitted.unsqueeze(1), t_sec, torch.full_like(t_sec, -1.0))
+
+                tokens = tokens.scatter(1, pos, fill_t)
+                starts = starts.scatter(1, pos, fill_s)
+                lengths = lengths + emitted.long()
+
         return tokens, starts
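The rewritten decoder replaces the per-utterance Python lists of the old `_greedy_decode` with fixed-size `tokens`/`starts` buffers updated via masked `torch.where`/`scatter`, so the whole batch stays on-device. A hedged usage sketch (the `model`, `waveforms`, `lengths` names are assumptions, not code from this repo); slots never written remain `pad_id` in `tokens` and -1.0 in `starts`, so they can be filtered out afterwards:

# waveforms: (B, num_samples) float tensor, lengths: (B,) sample counts -- hypothetical inputs
tokens, starts = model((waveforms, lengths))   # tokens: (B, T * max_symbols_per_step)
pad = model.config.pad_id
for b in range(tokens.size(0)):
    keep = tokens[b].ne(pad)
    ids = tokens[b, keep].tolist()             # token ids in emission order
    secs = starts[b, keep].tolist()            # emission times in seconds (t * denorm)
    print(list(zip(ids, [round(s, 2) for s in secs])))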