stukenov committed on
Commit
77ec394
·
verified ·
1 Parent(s): 7768093

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +41 -207
app.py CHANGED
@@ -1,17 +1,12 @@
1
- """
2
- SozKZ -- Kazakh ASR Demo
3
- OmniAudio v2 Scratch 70M
4
- """
5
 
6
  import os
7
- import math
8
  import spaces
9
  import gradio as gr
10
  import torch
11
- import torch.nn as nn
12
- import torch.nn.functional as F
13
  import numpy as np
14
  import librosa
 
15
  import time
16
  from transformers import PreTrainedTokenizerFast
17
  from huggingface_hub import hf_hub_download, login
@@ -20,230 +15,68 @@ HF_TOKEN = os.environ.get("HF_TOKEN")
20
  if HF_TOKEN:
21
  login(token=HF_TOKEN)
22
 
23
- # -- Model (exact names matching model.pt state_dict) --
24
-
25
- class RotaryEmbedding(nn.Module):
26
- def __init__(self, dim, base=10000.0):
27
- super().__init__()
28
- inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
29
- self.register_buffer("inv_freq", inv_freq)
30
- def forward(self, seq_len):
31
- t = torch.arange(seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
32
- freqs = torch.outer(t, self.inv_freq)
33
- emb = torch.cat([freqs, freqs], dim=-1)
34
- return emb.cos(), emb.sin()
35
-
36
- def _rotate_half(x):
37
- x1, x2 = x[..., :x.shape[-1]//2], x[..., x.shape[-1]//2:]
38
- return torch.cat([-x2, x1], dim=-1)
39
-
40
- def apply_rotary_emb(x, cos, sin):
41
- s = x.shape[2]
42
- return x * cos[:s].unsqueeze(0).unsqueeze(0) + _rotate_half(x) * sin[:s].unsqueeze(0).unsqueeze(0)
43
-
44
- class RMSNorm(nn.Module):
45
- def __init__(self, dim, eps=1e-6):
46
- super().__init__()
47
- self.eps = eps
48
- self.weight = nn.Parameter(torch.ones(dim))
49
- def forward(self, x):
50
- return (x.float() * x.float().pow(2).mean(-1, keepdim=True).add(self.eps).rsqrt()).to(x.dtype) * self.weight
51
-
52
- class EncoderBlock(nn.Module):
53
- def __init__(self, d_model, n_heads, dropout=0.1):
54
- super().__init__()
55
- self.n_heads = n_heads
56
- self.head_dim = d_model // n_heads
57
- self.norm1 = RMSNorm(d_model)
58
- self.norm2 = RMSNorm(d_model)
59
- self.q_proj = nn.Linear(d_model, d_model)
60
- self.k_proj = nn.Linear(d_model, d_model)
61
- self.v_proj = nn.Linear(d_model, d_model)
62
- self.o_proj = nn.Linear(d_model, d_model)
63
- self.rope = RotaryEmbedding(self.head_dim)
64
- inter = int(d_model * 8 / 3)
65
- inter = ((inter + 63) // 64) * 64
66
- self.gate_proj = nn.Linear(d_model, inter, bias=False)
67
- self.up_proj = nn.Linear(d_model, inter, bias=False)
68
- self.down_proj = nn.Linear(inter, d_model, bias=False)
69
- self.dropout = nn.Dropout(dropout)
70
-
71
- def forward(self, x):
72
- B, T, C = x.shape
73
- h = self.norm1(x)
74
- q = self.q_proj(h).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
75
- k = self.k_proj(h).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
76
- v = self.v_proj(h).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
77
- cos, sin = self.rope(T)
78
- q = apply_rotary_emb(q, cos, sin)
79
- k = apply_rotary_emb(k, cos, sin)
80
- attn = F.scaled_dot_product_attention(q, k, v)
81
- x = x + self.dropout(self.o_proj(attn.transpose(1, 2).contiguous().view(B, T, C)))
82
- h = self.norm2(x)
83
- x = x + self.dropout(self.down_proj(F.silu(self.gate_proj(h)) * self.up_proj(h)))
84
- return x
85
-
86
- class AudioEncoder(nn.Module):
87
- def __init__(self, n_mels=80, d_model=256, n_heads=4, n_layers=6, n_conv=2, dropout=0.1):
88
- super().__init__()
89
- convs = []
90
- inch = n_mels
91
- for i in range(n_conv):
92
- convs += [nn.Conv1d(inch, d_model, 3, 2, 1), nn.SiLU(), nn.Dropout(dropout)]
93
- inch = d_model
94
- self.conv_stack = nn.Sequential(*convs)
95
- self.layers = nn.ModuleList([EncoderBlock(d_model, n_heads, dropout) for _ in range(n_layers)])
96
- self.norm = RMSNorm(d_model)
97
-
98
- def forward(self, mel):
99
- x = self.conv_stack(mel).transpose(1, 2)
100
- for layer in self.layers:
101
- x = layer(x)
102
- return self.norm(x)
103
-
104
- class DecoderBlock(nn.Module):
105
- def __init__(self, d_model, n_heads, dropout=0.1):
106
- super().__init__()
107
- self.n_heads = n_heads
108
- self.head_dim = d_model // n_heads
109
- self.norm1 = RMSNorm(d_model)
110
- self.norm2 = RMSNorm(d_model)
111
- self.q_proj = nn.Linear(d_model, d_model)
112
- self.k_proj = nn.Linear(d_model, d_model)
113
- self.v_proj = nn.Linear(d_model, d_model)
114
- self.o_proj = nn.Linear(d_model, d_model)
115
- inter = int(d_model * 8 / 3)
116
- inter = ((inter + 63) // 64) * 64
117
- self.gate_proj = nn.Linear(d_model, inter, bias=False)
118
- self.up_proj = nn.Linear(d_model, inter, bias=False)
119
- self.down_proj = nn.Linear(inter, d_model, bias=False)
120
- self.dropout = nn.Dropout(dropout)
121
-
122
- def forward(self, x, cos, sin):
123
- B, T, C = x.shape
124
- h = self.norm1(x)
125
- q = self.q_proj(h).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
126
- k = self.k_proj(h).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
127
- v = self.v_proj(h).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
128
- q = apply_rotary_emb(q, cos, sin)
129
- k = apply_rotary_emb(k, cos, sin)
130
- attn = F.scaled_dot_product_attention(q, k, v, is_causal=True)
131
- x = x + self.dropout(self.o_proj(attn.transpose(1, 2).contiguous().view(B, T, C)))
132
- h = self.norm2(x)
133
- x = x + self.dropout(self.down_proj(F.silu(self.gate_proj(h)) * self.up_proj(h)))
134
- return x
135
-
136
- class AudioProjectorV2(nn.Module):
137
- def __init__(self, audio_dim, llm_dim):
138
- super().__init__()
139
- self.linear = nn.Linear(audio_dim, llm_dim)
140
- self.norm = RMSNorm(llm_dim)
141
- def forward(self, x):
142
- return self.norm(self.linear(x))
143
-
144
-
145
- class OmniAudioScratchModel(nn.Module):
146
- def __init__(self, encoder_config, decoder_config, vocab_size=50257, dropout=0.1):
147
- super().__init__()
148
- enc_dim = encoder_config["d_model"]
149
- dec_dim = decoder_config["d_model"]
150
- self.encoder = AudioEncoder(**encoder_config, dropout=dropout)
151
- self.projector = AudioProjectorV2(enc_dim, dec_dim)
152
- self.embed_tokens = nn.Embedding(vocab_size, dec_dim)
153
- self.decoder_layers = nn.ModuleList([
154
- DecoderBlock(dec_dim, decoder_config["n_heads"], dropout)
155
- for _ in range(decoder_config["n_layers"])
156
- ])
157
- self.decoder_norm = RMSNorm(dec_dim)
158
- self.decoder_rope = RotaryEmbedding(dec_dim // decoder_config["n_heads"])
159
- self.lm_head = nn.Linear(dec_dim, vocab_size, bias=False)
160
- # CTC head (may exist in checkpoint, not used for inference)
161
- self.ctc_head = nn.Linear(enc_dim, vocab_size)
162
-
163
- def generate(self, mel, max_new_tokens=200, eos_token_id=0, repetition_penalty=1.2):
164
- enc_out = self.encoder(mel)
165
- audio_embeds = self.projector(enc_out)
166
- generated = []
167
- combined = audio_embeds
168
- for _ in range(max_new_tokens):
169
- cos, sin = self.decoder_rope(combined.size(1))
170
- x = combined
171
- for layer in self.decoder_layers:
172
- x = layer(x, cos, sin)
173
- logits = self.lm_head(self.decoder_norm(x)[:, -1:]).squeeze(0).squeeze(0)
174
- if repetition_penalty != 1.0 and generated:
175
- for t in set(generated):
176
- if logits[t] > 0:
177
- logits[t] /= repetition_penalty
178
- else:
179
- logits[t] *= repetition_penalty
180
- tok = logits.argmax(-1).item()
181
- if tok == eos_token_id:
182
- break
183
- generated.append(tok)
184
- combined = torch.cat([combined, self.embed_tokens(torch.tensor([[tok]], device=mel.device))], dim=1)
185
- return generated
186
-
187
-
188
- # Mel filterbank extracted from torchaudio (exact match, 0.0 diff)
189
- MEL_FB = torch.load(hf_hub_download("stukenov/sozkz-core-omniaudio-70m-kk-asr-v2", "mel_filterbank.pt", token=HF_TOKEN), map_location="cpu", weights_only=True)
190
  MEL_WINDOW = torch.hann_window(400)
191
 
192
- def compute_mel(wav_np, sr=16000, n_fft=400, hop=160):
193
- """Compute log-mel spectrogram matching torchaudio exactly."""
194
  wav = torch.from_numpy(wav_np).float()
195
- stft = torch.stft(wav, n_fft=n_fft, hop_length=hop, win_length=n_fft,
196
  window=MEL_WINDOW, center=True, pad_mode="reflect", return_complex=True)
197
- power = stft.abs().pow(2) # (n_freqs, T)
198
- mel = torch.matmul(MEL_FB.T, power) # (80, T)
199
- return torch.log(torch.clamp(mel, min=1e-10)).unsqueeze(0) # (1, 80, T)
200
 
201
 
202
- # -- Load --
203
-
204
  ASR_MODELS = {
205
  "v2 (CTC+CE)": "stukenov/sozkz-core-omniaudio-70m-kk-asr-v2",
206
  "v1 (pure CE)": "stukenov/sozkz-core-omniaudio-70m-kk-asr-v1",
207
  }
208
- TOK_REPO = "stukenov/sozkz-core-gpt2-50k-kk-base-v1"
 
209
 
210
- print("Loading tokenizer...")
211
  tok_file = hf_hub_download(TOK_REPO, "tokenizer.json")
212
  tokenizer = PreTrainedTokenizerFast(tokenizer_file=tok_file)
213
  tokenizer.eos_token = "<|endoftext|>"
214
  tokenizer.eos_token_id = 0
215
 
216
- def load_asr(repo):
217
- print(f"Loading ASR from {repo}...")
218
- model = OmniAudioScratchModel(
219
- encoder_config={"n_mels": 80, "d_model": 256, "n_heads": 4, "n_layers": 6, "n_conv": 2},
220
- decoder_config={"d_model": 512, "n_heads": 8, "n_layers": 8},
221
- vocab_size=50257,
222
  )
223
  w = hf_hub_download(repo, "model.pt")
224
  sd = torch.load(w, map_location="cpu", weights_only=True)
225
- missing, unexpected = model.load_state_dict(sd, strict=False)
226
- model.lm_head.weight = model.embed_tokens.weight
227
- model.requires_grad_(False)
228
- params = sum(p.numel() for p in model.parameters()) / 1e6
229
- print(f" {params:.0f}M params, missing: {len(missing)}, unexpected: {len(unexpected)}")
230
- return model
231
-
232
- loaded_asr = {}
233
- for name, repo in ASR_MODELS.items():
234
- loaded_asr[name] = load_asr(repo)
235
- print("All ASR models loaded.")
236
 
237
 
238
  @spaces.GPU
239
  def transcribe(audio, model_name):
240
- import soundfile as sf
241
-
242
  if audio is None:
243
- return "No audio provided"
244
  t0 = time.perf_counter()
245
 
246
- # Load audio as numpy float32 at 16kHz
247
  if isinstance(audio, str):
248
  wav, sr = sf.read(audio)
249
  wav = np.array(wav, dtype=np.float32)
@@ -261,7 +94,7 @@ def transcribe(audio, model_name):
261
  if sr != 16000:
262
  wav = librosa.resample(wav, orig_sr=sr, target_sr=16000)
263
  else:
264
- return "Unsupported audio format"
265
 
266
  wav = wav[:int(10.0 * 16000)]
267
  mel = compute_mel(wav)
@@ -314,7 +147,8 @@ with gr.Blocks(css=CSS, theme=theme, title="SozKZ ASR") as demo:
314
  gr.Markdown("Max 10 seconds. WAV/MP3/FLAC, 16kHz mono recommended.")
315
 
316
  gr.HTML("""<div style="text-align:center;padding:20px;font-size:12px;color:#aaa">
317
- <a href="https://huggingface.co/stukenov/sozkz-core-omniaudio-70m-kk-asr-v1" style="color:#888">Model</a> |
 
318
  <a href="https://huggingface.co/spaces/stukenov/sozkz-kazakh-llm-demo" style="color:#888">LLM Demo</a> |
319
  <a href="https://huggingface.co/stukenov" style="color:#888">stukenov</a>
320
  </div>""")
 
1
+ """SozKZ -- Kazakh ASR Demo. Uses original model_v2.py from HF repo."""
 
 
 
2
 
3
  import os
 
4
  import spaces
5
  import gradio as gr
6
  import torch
 
 
7
  import numpy as np
8
  import librosa
9
+ import soundfile as sf
10
  import time
11
  from transformers import PreTrainedTokenizerFast
12
  from huggingface_hub import hf_hub_download, login
 
15
  if HF_TOKEN:
16
  login(token=HF_TOKEN)
17
 
18
+ # Download and import original model code from HF repo
19
+ model_code_path = hf_hub_download("stukenov/sozkz-core-omniaudio-70m-kk-asr-v1", "src/model_v2.py")
20
+ import importlib.util
21
+ spec = importlib.util.spec_from_file_location("model_v2", model_code_path)
22
+ model_v2 = importlib.util.module_from_spec(spec)
23
+ spec.loader.exec_module(model_v2)
24
+
25
+ # Exact mel filterbank from torchaudio (pre-computed, diff=0.0)
26
+ MEL_FB = torch.load(
27
+ hf_hub_download("stukenov/sozkz-core-omniaudio-70m-kk-asr-v2", "mel_filterbank.pt"),
28
+ map_location="cpu", weights_only=True,
29
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  MEL_WINDOW = torch.hann_window(400)
31
 
32
+ def compute_mel(wav_np):
 
33
  wav = torch.from_numpy(wav_np).float()
34
+ stft = torch.stft(wav, n_fft=400, hop_length=160, win_length=400,
35
  window=MEL_WINDOW, center=True, pad_mode="reflect", return_complex=True)
36
+ power = stft.abs().pow(2)
37
+ mel = torch.matmul(MEL_FB.T, power)
38
+ return torch.log(torch.clamp(mel, min=1e-10)).unsqueeze(0)
39
 
40
 
41
+ # Load models
 
42
  ASR_MODELS = {
43
  "v2 (CTC+CE)": "stukenov/sozkz-core-omniaudio-70m-kk-asr-v2",
44
  "v1 (pure CE)": "stukenov/sozkz-core-omniaudio-70m-kk-asr-v1",
45
  }
46
+ ENC_CFG = {"n_mels": 80, "d_model": 256, "n_heads": 4, "n_layers": 6, "n_conv": 2}
47
+ DEC_CFG = {"d_model": 512, "n_heads": 8, "n_layers": 8}
48
 
49
+ TOK_REPO = "stukenov/sozkz-core-gpt2-50k-kk-base-v1"
50
  tok_file = hf_hub_download(TOK_REPO, "tokenizer.json")
51
  tokenizer = PreTrainedTokenizerFast(tokenizer_file=tok_file)
52
  tokenizer.eos_token = "<|endoftext|>"
53
  tokenizer.eos_token_id = 0
54
 
55
+ loaded_asr = {}
56
+ for name, repo in ASR_MODELS.items():
57
+ print(f"Loading {name} from {repo}...")
58
+ mdl = model_v2.OmniAudioScratchModel(
59
+ encoder_config=ENC_CFG, decoder_config=DEC_CFG, vocab_size=50257,
 
60
  )
61
  w = hf_hub_download(repo, "model.pt")
62
  sd = torch.load(w, map_location="cpu", weights_only=True)
63
+ info = mdl.load_state_dict(sd, strict=False)
64
+ print(f" missing: {len(info.missing_keys)}, unexpected: {len(info.unexpected_keys)}")
65
+ for k in info.missing_keys:
66
+ if "rope" not in k and "inv_freq" not in k:
67
+ print(f" MISSING: {k}")
68
+ mdl.requires_grad_(False)
69
+ loaded_asr[name] = mdl
70
+ print("Ready.")
 
 
 
71
 
72
 
73
  @spaces.GPU
74
  def transcribe(audio, model_name):
 
 
75
  if audio is None:
76
+ return "No audio"
77
  t0 = time.perf_counter()
78
 
79
+ # Load and resample to 16kHz mono
80
  if isinstance(audio, str):
81
  wav, sr = sf.read(audio)
82
  wav = np.array(wav, dtype=np.float32)
 
94
  if sr != 16000:
95
  wav = librosa.resample(wav, orig_sr=sr, target_sr=16000)
96
  else:
97
+ return "Unsupported format"
98
 
99
  wav = wav[:int(10.0 * 16000)]
100
  mel = compute_mel(wav)
 
147
  gr.Markdown("Max 10 seconds. WAV/MP3/FLAC, 16kHz mono recommended.")
148
 
149
  gr.HTML("""<div style="text-align:center;padding:20px;font-size:12px;color:#aaa">
150
+ <a href="https://huggingface.co/stukenov/sozkz-core-omniaudio-70m-kk-asr-v2" style="color:#888">v2 Model</a> |
151
+ <a href="https://huggingface.co/stukenov/sozkz-core-omniaudio-70m-kk-asr-v1" style="color:#888">v1 Model</a> |
152
  <a href="https://huggingface.co/spaces/stukenov/sozkz-kazakh-llm-demo" style="color:#888">LLM Demo</a> |
153
  <a href="https://huggingface.co/stukenov" style="color:#888">stukenov</a>
154
  </div>""")