shethjenil committed
Commit 83bd703 · verified · 1 Parent(s): fd5e502

Update modeling_conformer.py

Files changed (1):
  1. modeling_conformer.py +233 -96
modeling_conformer.py CHANGED
@@ -1,5 +1,4 @@
 from datetime import timedelta
-import gc
 import json
 from huggingface_hub import hf_hub_download
 import torch
@@ -15,56 +14,83 @@ import webrtcvad
 from torch.utils.data import Dataset , DataLoader
 import srt

-def calc_length(lengths, all_paddings=2, kernel_size=3, stride=2, repeat_num=1):
-    add_pad = all_paddings - kernel_size
-    for _ in range(repeat_num):
-        lengths = torch.floor((lengths.float() + add_pad) / stride + 1)
-    return lengths
-
 class ChunkedData(Dataset):
     def __init__(self, wav, sr):
-        if sr != 16000: wav = torchaudio.functional.resample(wav, sr, 16000)
-        wav = wav.mean(0, keepdim=True)
-        self.data, self.ts = self.make_chunks(wav)
-
-    def __len__(self): return len(self.data)
-    def __getitem__(self, i): return self.data[i], self.ts[i]
-
-    def make_chunks(self, wav, sr=16000, ag=2, min_s=10, max_s=15, ms=30):
-        w = (wav * 32768).clamp(-32768, 32767).short().squeeze(0)
-        fl = int(sr * ms / 1000)
-        nf = len(w) // fl
-        w = w[: nf * fl]
-        fr = w.view(nf, fl)
+        if sr != 16000:
+            wav = torchaudio.functional.resample(wav, sr, 16000)
+
+        self.wav = wav.mean(0, keepdim=True)
+        self.sr = 16000
+
+        # Store only the timestamps, not the actual chunks
+        self.ts = self.make_chunk_timestamps(self.wav)
+
+    def __len__(self):
+        return len(self.ts)
+
+    def __getitem__(self, i):
+        st, ed = self.ts[i]
+        st_i = int(st * self.sr)
+        ed_i = int(ed * self.sr)
+        chunk = self.wav[:, st_i:ed_i].squeeze()
+        return chunk, self.ts[i]
+
+    def make_chunk_timestamps(self, wav, sr=16000, ag=2, min_s=10, max_s=15, ms=30):
+        wav_int16 = (wav * 32768).clamp(-32768, 32767).short().squeeze(0)
+
+        frame_len = int(sr * ms / 1000)
+        num_frames = len(wav_int16) // frame_len
+        wav_int16 = wav_int16[: num_frames * frame_len]
+
+        frames = wav_int16.view(num_frames, frame_len)
+
         vad = webrtcvad.Vad(ag)
-        sp = torch.zeros(nf, dtype=torch.bool)
-        for i, f in enumerate(fr):
-            try: sp[i] = vad.is_speech(f.cpu().numpy().tobytes(), sr)
-            except: pass
-        seg, s = [], None
-        for i, v in enumerate(sp):
-            if v and s is None: s = i
-            elif not v and s is not None: seg.append((s, i)); s = None
-        if s is not None: seg.append((s, len(sp)))
-        cs, ts, st = [], [], 0
-        mn, mx, N = int(min_s * sr), int(max_s * sr), len(w)
-        while st < N:
-            ed = min(st + mx, N)
-            f = ed // fl
-            while f < len(sp) and sp[f]:
-                f += 1; ed = min(f * fl, N)
-                if ed - st > mx * 1.5: break
-            if ed - st < mn and ed < N: ed = min(st + mn, N)
-            cs.append(wav[:, st:ed].squeeze())
-            ts.append([round(st / sr, 2), round(ed / sr, 2)])
+        speech = torch.tensor(
+            [vad.is_speech(frame.numpy().tobytes(), sr) for frame in frames],
+            dtype=torch.bool
+        )
+
+        timestamps = []
+        total_samples = len(wav_int16)
+
+        min_len = int(min_s * sr)
+        max_len = int(max_s * sr)
+
+        st = 0
+
+        while st < total_samples:
+            ed = min(st + max_len, total_samples)
+
+            if ed - st < min_len and ed < total_samples:
+                ed = min(st + min_len, total_samples)
+
+            timestamps.append((
+                round(st / sr, 2),
+                round(ed / sr, 2)
+            ))
+
             st = ed
-        return cs, torch.tensor(ts)
+
+        return timestamps



 def padding_audio(batch):
     audios, times = zip(*batch)
-    return pad_sequence(audios, batch_first=True), torch.tensor([audio.numel() for audio in audios]), torch.stack(times)
+
+    lengths = torch.tensor([audio.numel() for audio in audios])
+    times = torch.tensor(times, dtype=torch.float32)
+
+    padded = pad_sequence(audios, batch_first=True)
+
+    return padded, lengths, times
+
+def calc_length(lengths, all_paddings=2, kernel_size=3, stride=2, repeat_num=1):
+    add_pad = all_paddings - kernel_size
+    for _ in range(repeat_num):
+        lengths = torch.floor((lengths.float() + add_pad) / stride + 1)
+    return lengths

 class Op(nn.Module):
     def __init__(self, func,allow_self=False):
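A minimal usage sketch (not part of the commit) of the new lazy dataset: ChunkedData now stores only (start, end) timestamps and slices the waveform on demand in __getitem__, so memory no longer scales with the number of chunks. The file path and batch size below are illustrative, and ChunkedData/padding_audio are assumed importable from this module:

    import torchaudio
    from torch.utils.data import DataLoader

    wav, sr = torchaudio.load("sample.wav")     # (channels, samples); path illustrative
    dataset = ChunkedData(wav, sr)              # stores only (start, end) seconds
    loader = DataLoader(dataset, batch_size=4, collate_fn=padding_audio)

    for batch, lengths, times in loader:
        # batch:   (B, max_samples) zero-padded chunks
        # lengths: (B,) true sample counts before padding
        # times:   (B, 2) chunk boundaries in seconds
        print(batch.shape, lengths, times)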
@@ -110,7 +136,16 @@ class Wav2Vec2ConformerRNNT(Wav2Vec2ConformerModel):
         self.act = nn.ReLU(inplace=True)
         self.spec = torchaudio.transforms.Spectrogram(n_fft=512, hop_length=160, win_length=400, center=False)
         self.mask_layer = Op(lambda self_obj,x : x.masked_fill(self_obj.cache_pad_mask.unsqueeze(1), 0),True)
-        self.mel_fb = nn.Parameter(torch.tensor(librosa.filters.mel(sr=self.config.sampling_rate, n_fft=512, n_mels=80)),False)
+        self.register_buffer(
+            "mel_fb",
+            torch.tensor(
+                librosa.filters.mel(
+                    sr=self.config.sampling_rate,
+                    n_fft=512,
+                    n_mels=80
+                )
+            )
+        )

         for idx,l in enumerate(self.feature_extractor.conv_layers):
             if len(self.config.languages) == 1 or idx == 0:
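The commit swaps a frozen nn.Parameter for register_buffer. A buffer still follows .to()/.cuda() and is saved in the state_dict, but it no longer appears in model.parameters(), so optimizers and weight-decay bookkeeping skip it. A standalone sketch (module and shapes illustrative):

    import torch
    import torch.nn as nn

    class FrozenFB(nn.Module):
        def __init__(self):
            super().__init__()
            # 80 mel bins x 257 freq bins, as librosa.filters.mel(n_fft=512) would return
            self.register_buffer("mel_fb", torch.randn(80, 257))

    m = FrozenFB()
    print(sum(1 for _ in m.parameters()))   # 0 -> no optimizer ever touches it
    print("mel_fb" in m.state_dict())       # True -> still serialized with the model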
@@ -146,13 +181,14 @@ class Wav2Vec2ConformerRNNT(Wav2Vec2ConformerModel):
         m = torch.arange(T, device=x.device)[None] >= l[:, None]
         x = x.masked_fill(m[:, None], 0)
         μ = x.sum(-1) / l[:, None]
-        σ = (((x - μ[..., None])**2).sum(-1) / (l[:, None] - 1) + 1e-5).sqrt()
+        denom = torch.clamp(l[:, None] - 1, min=1)
+        σ = (((x - μ[..., None])**2).sum(-1) / denom + 1e-5).sqrt()
         x = ((x - μ[..., None]) / σ[..., None]).masked_fill(m[:, None], 0)
         self.cache_length = calc_length(l, repeat_num=self.config.num_feat_extract_layers).long()
         return F.pad(x, (0, (-T) % self.pad_to)).transpose(1, 2)

     def forward(self, input_values):
-        return self.postprocessing(super().forward(self.preprocessing(input_values)).last_hidden_state)
+        return self._greedy_decode(super().forward(self.preprocessing(input_values)).last_hidden_state)

     @torch.inference_mode()
     def transcribe(self,wav,sr,batch_size):
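The clamp added above guards the variance denominator: in the old code, a sequence whose valid length l is 1 divides by l - 1 = 0 and the whole normalized row turns to NaN. A toy reproduction (values illustrative):

    import torch

    l = torch.tensor([1.0])                  # one valid frame after padding
    x = torch.tensor([[0.5]])                # (batch, time)

    mu = x.sum(-1) / l
    var_old = ((x - mu[..., None]) ** 2).sum(-1) / (l - 1)    # 0 / 0 -> nan
    denom = torch.clamp(l - 1, min=1)
    var_new = ((x - mu[..., None]) ** 2).sum(-1) / denom      # well-defined

    print(var_old, var_new)                  # tensor([nan]) tensor([0.])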
@@ -164,12 +200,13 @@ class Wav2Vec2ConformerRNNT(Wav2Vec2ConformerModel):
             timestamp = timestamp.to(device)
             subtitles.extend(self.make_srt(self.forward((batch, lengths)),timestamp))
             yield srt.compose(subtitles)
-            torch.cuda.empty_cache()
-            gc.collect()
+            del batch
+            del lengths

     def load_state_dict(self, state_dict, strict=True, assign=False):
-        del state_dict['ctc_decoder.decoder_layers.0.bias']
-        del state_dict['ctc_decoder.decoder_layers.0.weight']
+        state_dict.pop('ctc_decoder.decoder_layers.0.bias', None)
+        state_dict.pop('ctc_decoder.decoder_layers.0.weight', None)
+
         state_dict['preprocessor.featurizer.fb'] = state_dict['preprocessor.featurizer.fb'].squeeze(0)
         changes = """
 preprocessor.featurizer.fb,mel_fb
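dict.pop(key, None) tolerates checkpoints that never contained the CTC head, where the old del raised KeyError:

    state_dict = {"encoder.weight": 0}
    # del state_dict['ctc_decoder.decoder_layers.0.bias']       # KeyError if absent
    state_dict.pop('ctc_decoder.decoder_layers.0.bias', None)   # no-op when missing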
@@ -207,61 +244,161 @@ encoder.pre_encode.conv_module.{n},feature_extractor.conv_layers.{(n//3+1)}.conv
         state_dict = {k: v for k, v in state_dict.items() if "lang_joint_net" not in k}
         return super().load_state_dict(state_dict, strict, assign)

-    def postprocessing(self, x):
-        if len(self.config.languages) > 1:
-            self.joint.load_state_dict(self.lang_joint_net[self.language].state_dict())
-        B = x.size(0)
-        last = x.new_full((B, 1), self.config.blank_id, dtype=torch.long)
-        h, tok, st = None, [[] for _ in range(B)], [[] for _ in range(B)]
-        for t, e in enumerate(x.unbind(1)):
-            v = t < self.cache_length
-            if not v.any(): break
-            e = e[:, None]
-            for _ in range(self.config.max_symbols_per_step):
+    @torch.jit.export
+    def _greedy_decode(self, enc_out: torch.Tensor):
+        B, T, _ = enc_out.size()
+        device = enc_out.device
+
+        enc_proj = self.enc(enc_out)
+
+        max_symbols = self.config.max_symbols_per_step
+        max_len = T * max_symbols
+
+        token_buffer = torch.full(
+            (B, max_len),
+            -1,
+            dtype=torch.long,
+            device=device
+        )
+
+        start_buffer = torch.zeros(
+            (B, max_len),
+            device=device
+        )
+
+        lengths = torch.zeros(B, dtype=torch.long, device=device)
+
+        last = torch.full(
+            (B, 1),
+            self.config.blank_id,
+            dtype=torch.long,
+            device=device
+        )
+
+        h = None
+
+        for t in range(T):
+            e = enc_proj[:, t:t+1]
+
+            for _ in range(max_symbols):
                 p, h2 = self.lstm(self.embed(last), h)
-                lg = self.joint(self.act(self.enc(e) + self.pred(p))).squeeze(1)
-                n = torch.where(v, lg.argmax(-1), self.config.blank_id)
-                b = n.eq(self.config.blank_id)
-                if b.all(): break
-                a = v & ~b
-                for i in a.nonzero().flatten().tolist():
-                    tok[i].append(n[i]); st[i].append(t * self.denorm)
-                last = torch.where(a[:, None], n[:, None], last)
-                if h is None: h = h2
+                joint = self.joint(self.act(e + self.pred(p))).squeeze(1)
+
+                n = joint.argmax(-1)
+                blank = n.eq(self.config.blank_id)
+                emit_mask = ~blank
+
+                if not emit_mask.any():
+                    break
+
+                pos = lengths[emit_mask]
+
+                token_buffer[emit_mask, pos] = n[emit_mask]
+                start_buffer[emit_mask, pos] = t * self.denorm
+
+                lengths[emit_mask] += 1
+
+                last = torch.where(emit_mask[:, None], n[:, None], last)
+
+                if h is None:
+                    h = h2
                 else:
-                    k = (b | ~v).view(1, -1, 1)
-                    h = (torch.where(k, h[0], h2[0]), torch.where(k, h[1], h2[1]))
-        self.cache_length = None
-        device = next(self.parameters()).device
-        return [torch.tensor(i,device=device) for i in tok], [torch.tensor(i,device=device) for i in st]
-
-    def make_srt(self, x, ts):
-        t , s = x
-        start_token_segment = self.config.languages.index(self.language) * self.joint.out_features
-        all_tokens, all_starts, all_ends = [], [], []
-        device = t[0].device
-        for tokens, starts, (s, e) in zip(t,s, ts):
-            tokens += start_token_segment
-            starts += s
+                    keep_mask = blank.view(1, -1, 1)
+                    h = (
+                        torch.where(keep_mask, h[0], h2[0]),
+                        torch.where(keep_mask, h[1], h2[1]),
+                    )
+
+        tokens = []
+        starts = []
+
+        for b in range(B):
+            L = lengths[b]
+            tokens.append(token_buffer[b, :L])
+            starts.append(start_buffer[b, :L])
+
+        return tokens, starts
+
+    def make_srt(self, decoded, ts):
+        tokens_list, starts_list = decoded
+
+        start_token_segment = (
+            self.config.languages.index(self.language)
+            * self.joint.out_features
+        )
+
+        all_tokens = []
+        all_starts = []
+        all_ends = []
+
+        device = tokens_list[0].device
+
+        for tokens, starts, (seg_start, seg_end) in zip(
+                tokens_list, starts_list, ts):
+
+            tokens = tokens + start_token_segment
+            starts = starts + seg_start
+
             all_tokens.append(tokens)
             all_starts.append(starts)
-            all_ends.append(torch.cat([starts[1:], e[None]]))
-            all_tokens.append(torch.tensor([-1],device=device))
-            all_starts.append(torch.tensor([e],device=device))
-            all_ends.append(torch.tensor([e + 0.005],device=device))
-        return [srt.Subtitle(i,timedelta(seconds=float(st)),timedelta(seconds=float(en)),"<line>" if tok == -1 else self.config.vocab[int(tok)]) for i, (tok, st, en) in enumerate(zip(torch.cat(all_tokens), torch.cat(all_starts), torch.cat(all_ends)), 1)]
+            all_ends.append(torch.cat([starts[1:], seg_end[None]]))
+
+            # newline marker
+            all_tokens.append(torch.tensor([-1], device=device))
+            all_starts.append(torch.tensor([seg_end], device=device))
+            all_ends.append(torch.tensor([seg_end + 0.005], device=device))
+
+        return [
+            srt.Subtitle(
+                i,
+                timedelta(seconds=float(st)),
+                timedelta(seconds=float(en)),
+                "<line>" if tok == -1 else self.config.vocab[int(tok)]
+            )
+            for i, (tok, st, en) in enumerate(
+                zip(
+                    torch.cat(all_tokens),
+                    torch.cat(all_starts),
+                    torch.cat(all_ends)
+                ), 1
+            )
+        ]


     @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, config = None, language=None,**kwargs):
+    def from_pretrained(
+            cls,
+            pretrained_model_name_or_path,
+            config=None,
+            language=None,
+            use_jit=False,
+            use_quantization=False):
+
+        if config is None:
+            raise ValueError("config must be provided")
+
         if language:
             config.languages = [language]
-            config.vocab = ['<unk>'] + json.load(open(hf_hub_download(pretrained_model_name_or_path, "vocab.json")))['small'][language]
-        else:
-            temp_vocab = json.load(open(hf_hub_download(pretrained_model_name_or_path, "vocab.json")))['large']
-            config.vocab = []
-            for i in sorted(config.languages):
-                config.vocab.extend(['<unk>'] + temp_vocab[i])
+
+        vocab_file = hf_hub_download(
+            pretrained_model_name_or_path,
+            "vocab.json"
+        )
+
+        vocab_json = json.load(open(vocab_file))
+        config.vocab = ['<unk>'] + vocab_json['small'][language]
+
         model = cls(config)
-        model.load_state_dict(load_file(hf_hub_download(pretrained_model_name_or_path, f"{language or 'all'}.safetensors")))
+
+        weight_file = hf_hub_download(
+            pretrained_model_name_or_path,
+            f"{language or 'all'}.safetensors"
+        )
+
+        model.load_state_dict(load_file(weight_file))
+        if use_quantization:
+            model = torch.quantization.quantize_dynamic(model)
+
        return model
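The list-of-lists decode was replaced by pre-allocated (B, T * max_symbols) buffers plus a per-row write cursor, which keeps _greedy_decode free of Python-list appends (and hence friendlier to @torch.jit.export). The emission step boils down to this indexing pattern (shapes and blank id illustrative):

    import torch

    B, max_len = 3, 5
    token_buffer = torch.full((B, max_len), -1, dtype=torch.long)
    lengths = torch.zeros(B, dtype=torch.long)     # per-row write cursor

    n = torch.tensor([7, 0, 9])                    # this step's argmax ids (0 = blank here)
    emit_mask = n.ne(0)                            # rows that emitted a real token

    token_buffer[emit_mask, lengths[emit_mask]] = n[emit_mask]
    lengths[emit_mask] += 1

    print(token_buffer)
    # rows 0 and 2 got a token written at column 0; row 1 stays untouched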
 
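The new use_quantization flag runs the loaded model through PyTorch dynamic quantization. With no qconfig_spec, torch.quantization.quantize_dynamic targets nn.Linear and the recurrent layer types by default; a hedged sketch on a stand-in module, with the targets spelled out explicitly:

    import torch
    import torch.nn as nn

    # Stand-in for the loaded Conformer; layer sizes are illustrative.
    model = nn.Sequential(nn.Linear(80, 256), nn.ReLU(), nn.Linear(256, 128))

    qmodel = torch.quantization.quantize_dynamic(
        model,
        qconfig_spec={nn.Linear},   # layer types to replace with int8 kernels
        dtype=torch.qint8,
    )
    print(qmodel)                   # Linear -> DynamicQuantizedLinear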