shethjenil committed on
Commit
f073b16
·
verified ·
1 Parent(s): 09c536a

Update modeling_conformer.py

Browse files
Files changed (1) hide show
  1. modeling_conformer.py +38 -63
modeling_conformer.py CHANGED
@@ -1,6 +1,6 @@
1
  from huggingface_hub import hf_hub_download
2
  from torch import nn
3
- from transformers import Wav2Vec2ConformerModel , Wav2Vec2CTCTokenizer
4
  from safetensors.torch import load_file
5
  from torch_state_bridge import state_bridge
6
  import torch
@@ -30,7 +30,7 @@ class Wav2Vec2ConformerRNNT(Wav2Vec2ConformerModel):
30
  self.joint = nn.Linear(config.joint_hidden, config.vocab_size // len(config.languages) + 1)
31
  self.embed = nn.Embedding(config.vocab_size+1, config.pred_hidden, padding_idx=config.vocab_size)
32
  self.lstm = nn.LSTM(config.pred_hidden, config.pred_hidden, config.lstm_layer, batch_first=True)
33
- self.act = nn.ReLU(inplace=True)
34
  self.spec = torchaudio.transforms.Spectrogram(n_fft=512, hop_length=160, win_length=400, center=False)
35
  self.mask_layer = Op(lambda self_obj,x : x.masked_fill(self_obj.cache_pad_mask.unsqueeze(1), 0),True)
36
  self.register_buffer("mel_fb",torch.tensor(librosa.filters.mel(sr=config.sampling_rate,n_fft=512,n_mels=80)))
@@ -51,7 +51,7 @@ class Wav2Vec2ConformerRNNT(Wav2Vec2ConformerModel):
51
  self.encoder.layer_norm = nn.Identity()
52
  if config.multilingual:
53
  self.lang_joint_net = nn.ModuleDict({l: nn.Linear(config.joint_hidden, config.vocab_size // len(config.languages) + 1) for l in config.languages})
54
- self.preemph, self.eps, self.pad_to = 0.97, 2**-24, 16
55
  self.denorm = (2 ** config.num_feat_extract_layers) * self.spec.hop_length / config.sampling_rate
56
  self.scaler = config.hidden_size ** (1/2)
57
  return super().init_weights()
@@ -70,7 +70,7 @@ class Wav2Vec2ConformerRNNT(Wav2Vec2ConformerModel):
70
  def preprocessing(self, x):
71
  x, l = x
72
  l = (l // self.spec.hop_length + 1).long()
73
- x = torch.cat((x[:, :1], x[:, 1:] - self.preemph * x[:, :-1]), 1)
74
  x = (self.mel_fb @ self.spec(x) + self.eps).log()
75
  T = x.size(-1)
76
  m = torch.arange(T, device=x.device)[None] >= l[:, None]
@@ -80,76 +80,51 @@ class Wav2Vec2ConformerRNNT(Wav2Vec2ConformerModel):
80
  σ = (((x - μ[..., None])**2).sum(-1) / denom + 1e-5).sqrt()
81
  x = ((x - μ[..., None]) / σ[..., None]).masked_fill(m[:, None], 0)
82
  self.cache_length = self.calc_length(l, repeat_num=self.config.num_feat_extract_layers).long()
83
- return F.pad(x, (0, (-T) % self.pad_to)).transpose(1, 2)
84
 
85
  def forward(self, input_values):
86
- return self._greedy_decode(super().forward(self.preprocessing(input_values)).last_hidden_state)
87
-
88
- def _greedy_decode(self, enc_out: torch.Tensor):
89
- B, T, _ = enc_out.size()
90
- device = enc_out.device
91
- enc_proj = self.enc(enc_out)
92
- max_symbols = self.config.max_symbols_per_step
93
- max_len = T * max_symbols
94
- token_buffer = torch.full(
95
- (B, max_len),
96
- -1,
97
- dtype=torch.long,
98
- device=device
99
- )
100
- start_buffer = torch.zeros(
101
- (B, max_len),
102
- device=device
103
- )
104
- lengths = torch.zeros(B, dtype=torch.long, device=device)
105
- last = torch.full(
106
- (B, 1),
107
- self.config.blank_id,
108
- dtype=torch.long,
109
- device=device
110
- )
111
-
112
- h = None
113
 
114
- for t in range(T):
115
- e = enc_proj[:, t:t+1]
116
-
117
- for _ in range(max_symbols):
118
- p, h2 = self.lstm(self.embed(last), h)
119
- joint = self.joint(self.act(e + self.pred(p))).squeeze(1)
120
 
121
- n = joint.argmax(-1)
122
- blank = n.eq(self.config.blank_id)
123
- emit_mask = ~blank
 
 
 
124
 
125
- if not emit_mask.any():
126
- break
127
 
128
- pos = lengths[emit_mask]
129
-
130
- token_buffer[emit_mask, pos] = n[emit_mask]
131
- start_buffer[emit_mask, pos] = t * self.denorm
132
 
133
- lengths[emit_mask] += 1
 
134
 
135
- last = torch.where(emit_mask[:, None], n[:, None], last)
 
 
136
 
137
- if h is None:
138
- h = h2
139
- else:
140
- keep_mask = blank.view(1, -1, 1)
141
- h = (
142
- torch.where(keep_mask, h[0], h2[0]),
143
- torch.where(keep_mask, h[1], h2[1]),
144
- )
145
 
146
- tokens = []
147
- starts = []
 
148
 
149
- for b in range(B):
150
- L = lengths[b]
151
- tokens.append(token_buffer[b, :L])
152
- starts.append(start_buffer[b, :L])
153
 
154
  return tokens, starts
155
 
 
1
  from huggingface_hub import hf_hub_download
2
  from torch import nn
3
+ from transformers import Wav2Vec2ConformerModel
4
  from safetensors.torch import load_file
5
  from torch_state_bridge import state_bridge
6
  import torch
 
30
  self.joint = nn.Linear(config.joint_hidden, config.vocab_size // len(config.languages) + 1)
31
  self.embed = nn.Embedding(config.vocab_size+1, config.pred_hidden, padding_idx=config.vocab_size)
32
  self.lstm = nn.LSTM(config.pred_hidden, config.pred_hidden, config.lstm_layer, batch_first=True)
33
+ self.act = nn.ReLU()
34
  self.spec = torchaudio.transforms.Spectrogram(n_fft=512, hop_length=160, win_length=400, center=False)
35
  self.mask_layer = Op(lambda self_obj,x : x.masked_fill(self_obj.cache_pad_mask.unsqueeze(1), 0),True)
36
  self.register_buffer("mel_fb",torch.tensor(librosa.filters.mel(sr=config.sampling_rate,n_fft=512,n_mels=80)))
 
51
  self.encoder.layer_norm = nn.Identity()
52
  if config.multilingual:
53
  self.lang_joint_net = nn.ModuleDict({l: nn.Linear(config.joint_hidden, config.vocab_size // len(config.languages) + 1) for l in config.languages})
54
+ self.eps = 2**-24
55
  self.denorm = (2 ** config.num_feat_extract_layers) * self.spec.hop_length / config.sampling_rate
56
  self.scaler = config.hidden_size ** (1/2)
57
  return super().init_weights()
 
70
  def preprocessing(self, x):
71
  x, l = x
72
  l = (l // self.spec.hop_length + 1).long()
73
+ x = torch.cat((x[:, :1], x[:, 1:] - self.config.preemph * x[:, :-1]), 1)
74
  x = (self.mel_fb @ self.spec(x) + self.eps).log()
75
  T = x.size(-1)
76
  m = torch.arange(T, device=x.device)[None] >= l[:, None]
 
80
  σ = (((x - μ[..., None])**2).sum(-1) / denom + 1e-5).sqrt()
81
  x = ((x - μ[..., None]) / σ[..., None]).masked_fill(m[:, None], 0)
82
  self.cache_length = self.calc_length(l, repeat_num=self.config.num_feat_extract_layers).long()
83
+ return F.pad(x, (0, (-T) % self.config.pad_to)).transpose(1, 2)
84
 
85
  def forward(self, input_values):
86
+ return self.postprocessing(super().forward(self.preprocessing(input_values)).last_hidden_state)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
 
88
+ def postprocessing(self, enc_out):
89
+ B, T, _ = enc_out.shape
90
+ H = self.lstm.hidden_size
91
+ blank = self.config.blank_id
92
+ pad = self.config.pad_id
93
+ max_len = T * self.config.max_symbols_per_step
94
 
95
+ tokens = torch.full((B, max_len), pad, dtype=torch.long, device=enc_out.device)
96
+ starts = torch.full((B, max_len), -1.0, dtype=enc_out.dtype, device=enc_out.device)
97
+ lengths = torch.zeros(B, dtype=torch.long, device=enc_out.device)
98
+ hx = torch.zeros(1, B, H, dtype=enc_out.dtype, device=enc_out.device)
99
+ cx = torch.zeros_like(hx)
100
+ last = torch.full((B, 1), blank, dtype=torch.long, device=enc_out.device)
101
 
102
+ enc_proj = self.enc(enc_out) # (B, T, D)
 
103
 
104
+ for t in range(T):
105
+ e = enc_proj[:, t:t+1]
106
+ t_sec = torch.full((B, 1), t * self.denorm, dtype=enc_out.dtype, device=enc_out.device)
 
107
 
108
+ for _ in range(self.config.max_symbols_per_step):
109
+ hx_prev, cx_prev = hx, cx
110
 
111
+ p, (hx, cx) = self.lstm(self.embed(last), (hx, cx))
112
+ n = self.joint(self.act(e + self.pred(p))).squeeze(1).argmax(-1) # (B,)
113
+ emitted = n.ne(blank)
114
 
115
+ # revert hidden for blanks
116
+ mask = emitted.view(1, B, 1)
117
+ hx = torch.where(mask, hx, hx_prev)
118
+ cx = torch.where(mask, cx, cx_prev)
119
+ last = torch.where(emitted.unsqueeze(1), n.unsqueeze(1), last)
 
 
 
120
 
121
+ pos = lengths.unsqueeze(1).clamp(max=max_len - 1)
122
+ fill_t = torch.where(emitted.unsqueeze(1), n.unsqueeze(1), torch.full_like(n.unsqueeze(1), pad))
123
+ fill_s = torch.where(emitted.unsqueeze(1), t_sec, torch.full_like(t_sec, -1.0))
124
 
125
+ tokens = tokens.scatter(1, pos, fill_t)
126
+ starts = starts.scatter(1, pos, fill_s)
127
+ lengths = lengths + emitted.long()
 
128
 
129
  return tokens, starts
130