shethjenil committed
Commit f3c2e37 · verified · 1 Parent(s): 83bd703

Update modeling_conformer.py

Files changed (1)
  1. modeling_conformer.py +13 -151
modeling_conformer.py CHANGED
@@ -1,96 +1,13 @@
-from datetime import timedelta
-import json
 from huggingface_hub import hf_hub_download
+from torch import nn
+from transformers import Wav2Vec2ConformerModel
+from safetensors.torch import load_file
+from torch_state_bridge import state_bridge
+import json
 import torch
 import torch.nn.functional as F
 import torchaudio
 import librosa
-from torch import nn
-from transformers import Wav2Vec2ConformerModel
-from torch_state_bridge import state_bridge
-from torch.nn.utils.rnn import pad_sequence
-from safetensors.torch import load_file
-import webrtcvad
-from torch.utils.data import Dataset, DataLoader
-import srt
-
-class ChunkedData(Dataset):
-    def __init__(self, wav, sr):
-        if sr != 16000:
-            wav = torchaudio.functional.resample(wav, sr, 16000)
-
-        self.wav = wav.mean(0, keepdim=True)
-        self.sr = 16000
-
-        # Store only the timestamps, not the actual chunks
-        self.ts = self.make_chunk_timestamps(self.wav)
-
-    def __len__(self):
-        return len(self.ts)
-
-    def __getitem__(self, i):
-        st, ed = self.ts[i]
-        st_i = int(st * self.sr)
-        ed_i = int(ed * self.sr)
-        chunk = self.wav[:, st_i:ed_i].squeeze()
-        return chunk, self.ts[i]
-
-    def make_chunk_timestamps(self, wav, sr=16000, ag=2, min_s=10, max_s=15, ms=30):
-
-        wav_int16 = (wav * 32768).clamp(-32768, 32767).short().squeeze(0)
-
-        frame_len = int(sr * ms / 1000)
-        num_frames = len(wav_int16) // frame_len
-        wav_int16 = wav_int16[: num_frames * frame_len]
-
-        frames = wav_int16.view(num_frames, frame_len)
-
-        vad = webrtcvad.Vad(ag)
-        speech = torch.tensor(
-            [vad.is_speech(frame.numpy().tobytes(), sr) for frame in frames],
-            dtype=torch.bool
-        )
-
-        timestamps = []
-        total_samples = len(wav_int16)
-
-        min_len = int(min_s * sr)
-        max_len = int(max_s * sr)
-
-        st = 0
-
-        while st < total_samples:
-            ed = min(st + max_len, total_samples)
-
-            if ed - st < min_len and ed < total_samples:
-                ed = min(st + min_len, total_samples)
-
-            timestamps.append((
-                round(st / sr, 2),
-                round(ed / sr, 2)
-            ))
-
-            st = ed
-
-        return timestamps
-
-
-
-def padding_audio(batch):
-    audios, times = zip(*batch)
-
-    lengths = torch.tensor([audio.numel() for audio in audios])
-    times = torch.tensor(times, dtype=torch.float32)
-
-    padded = pad_sequence(audios, batch_first=True)
-
-    return padded, lengths, times
-
-def calc_length(lengths, all_paddings=2, kernel_size=3, stride=2, repeat_num=1):
-    add_pad = all_paddings - kernel_size
-    for _ in range(repeat_num):
-        lengths = torch.floor((lengths.float() + add_pad) / stride + 1)
-    return lengths
 
 class Op(nn.Module):
     def __init__(self, func, allow_self=False):
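
Note: the calc_length helper removed above (and re-added as a method in a later hunk) is the standard output-length formula for a stack of strided convolutions, floor((L + all_paddings - kernel_size) / stride + 1) applied once per layer. A minimal standalone sketch of the same arithmetic; conv_out_length is an illustrative name, not part of the repo:

import math

def conv_out_length(length, padding=1, kernel_size=3, stride=2, repeat_num=1):
    # floor((L + 2*padding - kernel_size) / stride + 1) per conv layer;
    # all_paddings=2 in the diff corresponds to 2*padding with padding=1.
    for _ in range(repeat_num):
        length = math.floor((length + 2 * padding - kernel_size) / stride + 1)
    return length

assert conv_out_length(80, repeat_num=2) == 20  # 80 mel bins -> 40 -> 20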
@@ -155,7 +72,7 @@ class Wav2Vec2ConformerRNNT(Wav2Vec2ConformerModel):
             l.conv = nn.Sequential(nn.Conv2d(l.conv.in_channels,l.conv.out_channels,l.conv.kernel_size[0],l.conv.stride,1,groups=l.conv.out_channels),nn.Conv2d(l.conv.in_channels,l.conv.out_channels, 1))
 
         self.feature_extractor.conv_layers.append(Op(lambda x : x.transpose(1, 2)))
-        self.feature_projection.projection = nn.Linear(config.conv_dim[-1] * int(calc_length(torch.tensor(80.),repeat_num=self.config.num_feat_extract_layers)),config.hidden_size)
+        self.feature_projection.projection = nn.Linear(config.conv_dim[-1] * int(self.calc_length(torch.tensor(80.),repeat_num=self.config.num_feat_extract_layers)),config.hidden_size)
         self.feature_projection.layer_norm = Op(lambda x:x.permute(0, 2, 1, 3).flatten(2))
         for l in self.encoder.layers:
             l.conv_module.glu = nn.Sequential(l.conv_module.glu,self.mask_layer)
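
Note: the first context line in the hunk above replaces each subsampling Conv2d with a depthwise convolution (groups equal to the channel count) followed by a 1x1 pointwise convolution. A sketch of that decomposition, assuming in_channels == out_channels as groups=out_channels requires here; depthwise_separable is an illustrative helper, not repo code:

import torch
from torch import nn

def depthwise_separable(conv: nn.Conv2d) -> nn.Sequential:
    # One spatial filter per channel, then a 1x1 conv to mix channels.
    assert conv.in_channels == conv.out_channels
    depthwise = nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size[0],
                          conv.stride, 1, groups=conv.out_channels)
    pointwise = nn.Conv2d(conv.in_channels, conv.out_channels, 1)
    return nn.Sequential(depthwise, pointwise)

x = torch.randn(1, 256, 20, 20)
print(depthwise_separable(nn.Conv2d(256, 256, 3, stride=2))(x).shape)  # [1, 256, 10, 10]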
@@ -172,6 +89,12 @@ class Wav2Vec2ConformerRNNT(Wav2Vec2ConformerModel):
             self.mask_layer.cache_pad_mask = (torch.arange(hidden_states.size(1), device=hidden_states.device).unsqueeze(0) >= self.cache_length.unsqueeze(1))
         return super()._mask_hidden_states(hidden_states, mask_time_indices, attention_mask)
 
+    def calc_length(self, lengths, all_paddings=2, kernel_size=3, stride=2, repeat_num=1):
+        add_pad = all_paddings - kernel_size
+        for _ in range(repeat_num):
+            lengths = torch.floor((lengths.float() + add_pad) / stride + 1)
+        return lengths
+
     def preprocessing(self, x):
         x, l = x
         l = (l // self.hop + 1).long()
@@ -184,25 +107,12 @@ class Wav2Vec2ConformerRNNT(Wav2Vec2ConformerModel):
         denom = torch.clamp(l[:, None] - 1, min=1)
         σ = (((x - μ[..., None])**2).sum(-1) / denom + 1e-5).sqrt()
         x = ((x - μ[..., None]) / σ[..., None]).masked_fill(m[:, None], 0)
-        self.cache_length = calc_length(l, repeat_num=self.config.num_feat_extract_layers).long()
+        self.cache_length = self.calc_length(l, repeat_num=self.config.num_feat_extract_layers).long()
         return F.pad(x, (0, (-T) % self.pad_to)).transpose(1, 2)
 
     def forward(self, input_values):
         return self._greedy_decode(super().forward(self.preprocessing(input_values)).last_hidden_state)
 
-    @torch.inference_mode()
-    def transcribe(self, wav, sr, batch_size):
-        device = next(self.parameters()).device
-        subtitles = []
-        for batch, lengths, timestamp in DataLoader(ChunkedData(wav, sr), batch_size, collate_fn=padding_audio):
-            batch = batch.to(device)
-            lengths = lengths.to(device)
-            timestamp = timestamp.to(device)
-            subtitles.extend(self.make_srt(self.forward((batch, lengths)), timestamp))
-            yield srt.compose(subtitles)
-            del batch
-            del lengths
-
     def load_state_dict(self, state_dict, strict=True, assign=False):
         state_dict.pop('ctc_decoder.decoder_layers.0.bias', None)
         state_dict.pop('ctc_decoder.decoder_layers.0.weight', None)
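
Note: the preprocessing context above normalizes each utterance with a length-masked mean μ and an unbiased (n-1) standard deviation σ, zeroing padded frames so they do not leak into the statistics. A self-contained sketch of the same idea; masked_normalize and its argument layout are assumptions, not the repo's API:

import torch

def masked_normalize(x, lengths, eps=1e-5):
    # x: (batch, feat, time); lengths: (batch,) valid frame counts per utterance.
    T = x.size(-1)
    m = torch.arange(T, device=x.device)[None, :] >= lengths[:, None]  # True on padding
    x = x.masked_fill(m[:, None], 0)
    mu = x.sum(-1) / lengths[:, None]                        # per-feature mean over valid frames
    centered = (x - mu[..., None]).masked_fill(m[:, None], 0)
    denom = torch.clamp(lengths[:, None] - 1, min=1)         # unbiased variance denominator
    sigma = ((centered ** 2).sum(-1) / denom + eps).sqrt()
    return centered / sigma[..., None]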
@@ -320,60 +230,12 @@ encoder.pre_encode.conv_module.{n},feature_extractor.conv_layers.{(n//3+1)}.conv
 
         return tokens, starts
 
-    def make_srt(self, decoded, ts):
-
-        tokens_list, starts_list = decoded
-
-        start_token_segment = (
-            self.config.languages.index(self.language)
-            * self.joint.out_features
-        )
-
-        all_tokens = []
-        all_starts = []
-        all_ends = []
-
-        device = tokens_list[0].device
-
-        for tokens, starts, (seg_start, seg_end) in zip(
-                tokens_list, starts_list, ts):
-
-            tokens = tokens + start_token_segment
-            starts = starts + seg_start
-
-            all_tokens.append(tokens)
-            all_starts.append(starts)
-            all_ends.append(torch.cat([starts[1:], seg_end[None]]))
-
-            # newline marker
-            all_tokens.append(torch.tensor([-1], device=device))
-            all_starts.append(torch.tensor([seg_end], device=device))
-            all_ends.append(torch.tensor([seg_end + 0.005], device=device))
-
-        return [
-            srt.Subtitle(
-                i,
-                timedelta(seconds=float(st)),
-                timedelta(seconds=float(en)),
-                "<line>" if tok == -1 else self.config.vocab[int(tok)]
-            )
-            for i, (tok, st, en) in enumerate(
-                zip(
-                    torch.cat(all_tokens),
-                    torch.cat(all_starts),
-                    torch.cat(all_ends)
-                ), 1
-            )
-        ]
-
-
     @classmethod
     def from_pretrained(
         cls,
         pretrained_model_name_or_path,
         config=None,
         language=None,
-        use_jit=False,
         use_quantization=False):
 
         if config is None:
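
Note: the removed transcribe/make_srt pipeline paired each decoded token with a timestamp and emitted cues through the srt package (srt.Subtitle plus srt.compose). A minimal sketch of that final composition step, with hypothetical decoder output in place of the real token/time tensors:

from datetime import timedelta
import srt

decoded = [("hello", 0.00, 0.42), ("world", 0.42, 0.90)]  # (text, start_s, end_s), illustrative

subtitles = [
    srt.Subtitle(index=i, start=timedelta(seconds=st), end=timedelta(seconds=en), content=text)
    for i, (text, st, en) in enumerate(decoded, 1)
]
print(srt.compose(subtitles))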
 
 
 