# NOTE(review): the following lines are Hugging Face web-page scrape residue
# (file listing header), not source code. Kept as comments so the file parses.
# stukenov's picture
# Upload app.py with huggingface_hub
# b9cb7dc verified
# raw / history blame / 6.46 kB
"""SozKZ -- Kazakh ASR Demo. Uses original model_v2.py from HF repo."""
import os
import spaces
import gradio as gr
import torch
import numpy as np
import librosa
import soundfile as sf
import time
from transformers import PreTrainedTokenizerFast
from huggingface_hub import hf_hub_download, login
# Optional HF auth: only needed if the model repos are gated/private.
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)

# Download and import original model code from HF repo.
# The architecture definition lives in src/model_v2.py inside the v1 model repo;
# it is loaded dynamically so this Space does not need to vendor a copy.
model_code_path = hf_hub_download("stukenov/sozkz-core-omniaudio-70m-kk-asr-v1", "src/model_v2.py")
import importlib.util
spec = importlib.util.spec_from_file_location("model_v2", model_code_path)
model_v2 = importlib.util.module_from_spec(spec)
spec.loader.exec_module(model_v2)

# Exact mel filterbank from torchaudio (pre-computed, diff=0.0 vs torchaudio's own).
# Shape is presumably (n_fft//2 + 1, n_mels) given the .T in compute_mel — TODO confirm.
MEL_FB = torch.load(
    hf_hub_download("stukenov/sozkz-core-omniaudio-70m-kk-asr-v2", "mel_filterbank.pt"),
    map_location="cpu", weights_only=True,
)
# 400-sample Hann window = 25 ms frame at 16 kHz, matching compute_mel's STFT params.
MEL_WINDOW = torch.hann_window(400)
def compute_mel(wav_np):
    """Turn a mono float waveform (numpy, 16 kHz) into a batched log-mel spectrogram.

    Pipeline: 400-point STFT with a 160-sample hop (25 ms window / 10 ms hop at
    16 kHz, reflect-padded, centered) -> power spectrum -> projection through the
    pre-computed torchaudio mel filterbank -> natural log with a 1e-10 floor.
    Returns a tensor of shape (1, n_mels, n_frames).
    """
    signal = torch.from_numpy(wav_np).float()
    spectrum = torch.stft(
        signal,
        n_fft=400,
        hop_length=160,
        win_length=400,
        window=MEL_WINDOW,
        center=True,
        pad_mode="reflect",
        return_complex=True,
    )
    power_spec = spectrum.abs() ** 2
    mel_energies = MEL_FB.T @ power_spec
    log_mel = torch.log(mel_energies.clamp(min=1e-10))
    # Add a leading batch dimension expected by the model.
    return log_mel.unsqueeze(0)
# Load models
# Display name -> HF repo id. v2 was trained with a CTC auxiliary loss on top of CE.
ASR_MODELS = {
    "v2 (CTC+CE)": "stukenov/sozkz-core-omniaudio-70m-kk-asr-v2",
    "v1 (pure CE)": "stukenov/sozkz-core-omniaudio-70m-kk-asr-v1",
}
# Encoder/decoder hyper-parameters; these must match the checkpoints' training config.
ENC_CFG = {"n_mels": 80, "d_model": 256, "n_heads": 4, "n_layers": 6, "n_conv": 2}
DEC_CFG = {"d_model": 512, "n_heads": 8, "n_layers": 8}

# GPT-2-style 50k Kazakh tokenizer shared by both ASR models.
TOK_REPO = "stukenov/sozkz-core-gpt2-50k-kk-base-v1"
tok_file = hf_hub_download(TOK_REPO, "tokenizer.json")
tokenizer = PreTrainedTokenizerFast(tokenizer_file=tok_file)
# tokenizer.json does not declare special tokens; EOS is set explicitly to id 0,
# which generate() below also uses as its stop token.
tokenizer.eos_token = "<|endoftext|>"
tokenizer.eos_token_id = 0
# Instantiate both checkpoints once at startup; they stay on CPU until a request
# moves the selected one to GPU inside transcribe().
loaded_asr = {}
for name, repo in ASR_MODELS.items():
    print(f"Loading {name} from {repo}...")
    mdl = model_v2.OmniAudioScratchModel(
        encoder_config=ENC_CFG, decoder_config=DEC_CFG, vocab_size=50257,
    )
    w = hf_hub_download(repo, "model.pt")
    sd = torch.load(w, map_location="cpu", weights_only=True)
    # strict=False: the checkpoint intentionally omits lm_head (weight tying) and
    # rotary-embedding buffers, which are reconstructed by the model itself.
    info = mdl.load_state_dict(sd, strict=False)
    # lm_head not in checkpoint — it's tied to embed_tokens
    mdl.lm_head.weight = mdl.embed_tokens.weight
    print(f" missing: {len(info.missing_keys)}, unexpected: {len(info.unexpected_keys)}, lm_head tied")
    # Surface any genuinely missing weights (rope/inv_freq/lm_head absences are expected).
    for k in info.missing_keys:
        if "rope" not in k and "inv_freq" not in k and "lm_head" not in k:
            print(f" MISSING: {k}")
    # Inference only: freeze all parameters.
    mdl.requires_grad_(False)
    loaded_asr[name] = mdl
print("Ready.")
def _load_wav_16k(audio):
    """Normalize a Gradio audio input to a float32 mono 16 kHz numpy array.

    Accepts either a filepath string or a (sample_rate, samples) tuple.
    Returns None when the input is neither (unsupported format).
    """
    if isinstance(audio, str):
        wav, sr = sf.read(audio)
        wav = np.asarray(wav, dtype=np.float32)
        if wav.ndim > 1:
            # soundfile returns (frames, channels); average channels to mono.
            wav = wav.mean(axis=-1)
    elif isinstance(audio, tuple):
        sr, wav = audio
        wav = np.asarray(wav, dtype=np.float32)
        if wav.ndim > 1:
            # Heuristic: a trailing dim of <=2 is (frames, channels), otherwise
            # assume (channels, frames).
            wav = wav.mean(axis=-1) if wav.shape[-1] <= 2 else wav.mean(axis=0)
        # Gradio numpy audio is typically int16; rescale to [-1, 1] if needed.
        # Guard wav.size: np.abs([]).max() raises ValueError on an empty recording.
        if wav.size and np.abs(wav).max() > 1.0:
            wav = wav / 32768.0
    else:
        return None
    if sr != 16000 and wav.size:
        wav = librosa.resample(wav, orig_sr=sr, target_sr=16000)
    return wav


@spaces.GPU
def transcribe(audio, model_name):
    """Transcribe up to 10 s of Kazakh speech with the selected ASR model.

    Args:
        audio: filepath string or (sample_rate, np.ndarray) tuple from gr.Audio.
        model_name: key into loaded_asr; unknown names fall back to the v2 model.

    Returns:
        Decoded transcription with elapsed time appended, or a status message.
    """
    if audio is None:
        return "No audio"
    t0 = time.perf_counter()
    wav = _load_wav_16k(audio)
    if wav is None:
        return "Unsupported format"
    # Hard cap at 10 seconds of 16 kHz audio.
    wav = wav[:int(10.0 * 16000)]
    if wav.size == 0:
        # An empty clip would crash the STFT in compute_mel; bail out early.
        return "No audio"
    mel = compute_mel(wav)
    asr = loaded_asr.get(model_name, loaded_asr["v2 (CTC+CE)"])
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    asr.to(device)
    mel = mel.to(device)
    try:
        with torch.no_grad():
            tokens = asr.generate(mel, max_new_tokens=100, eos_token_id=0, repetition_penalty=1.5)
    finally:
        # Always return the model to CPU so a failed generate() does not leak GPU memory.
        asr.to("cpu")
    text = tokenizer.decode(tokens, skip_special_tokens=True).strip()
    elapsed = time.perf_counter() - t0
    if not text:
        return f"(no speech detected, {elapsed:.1f}s)"
    return text + f"\n\n({elapsed:.1f}s)"
# Inline CSS: force a light theme (even when the browser prefers dark), center
# the layout at 700px, hide the Gradio footer, and style the hero header.
CSS = """
:root, :root.dark { color-scheme: light only !important; --body-background-fill: #fff !important; }
html, body { background: #fff !important; }
* { font-family: -apple-system, BlinkMacSystemFont, "SF Pro Text", sans-serif !important; }
.gradio-container[class] { max-width: 700px !important; margin: 0 auto !important; }
footer { display: none !important; }
#hero { padding: 28px 32px 16px; text-align: center; }
#hero h1 { font-size: 36px !important; font-weight: 700 !important; margin: 0 0 6px !important; }
#hero .sub { font-size: 16px; color: #86868b; margin: 0 0 12px; }
#hero .tag { font-size: 12px; color: #6e6e73; background: #f5f5f7; border: 1px solid #d2d2d7; padding: 4px 12px; border-radius: 100px; }
.output textarea { font-size: 20px !important; line-height: 1.6 !important; }
"""
# Plain light Gradio theme to match the forced-light CSS above.
theme = gr.themes.Base(primary_hue="blue")
theme.set(body_background_fill="#fff", block_background_fill="#fff", block_border_width="0px", block_shadow="none")

with gr.Blocks(css=CSS, theme=theme, title="SozKZ ASR") as demo:
    # Hero header styled by the #hero CSS rules.
    gr.HTML("""<div id="hero">
<h1>SozKZ ASR</h1>
<p class="sub">Kazakh Speech Recognition</p>
<span class="tag">OmniAudio 70M</span>
</div>""")
    # Model picker, upload widget, action button, and transcription output.
    model_sel = gr.Radio(list(ASR_MODELS.keys()), value="v2 (CTC+CE)", label="Model", interactive=True)
    audio_input = gr.Audio(sources=["upload"], type="filepath", label="Upload audio (WAV/MP3/FLAC, max 10s)")
    btn = gr.Button("Transcribe", variant="primary")
    output = gr.Textbox(label="Transcription", lines=4, elem_classes=["output"])
    btn.click(fn=transcribe, inputs=[audio_input, model_sel], outputs=output)
    gr.Markdown("Max 10 seconds. WAV/MP3/FLAC, 16kHz mono recommended.")
    # Footer links to the model repos and related Spaces.
    gr.HTML("""<div style="text-align:center;padding:20px;font-size:12px;color:#aaa">
<a href="https://huggingface.co/stukenov/sozkz-core-omniaudio-70m-kk-asr-v2" style="color:#888">v2 Model</a> |
<a href="https://huggingface.co/stukenov/sozkz-core-omniaudio-70m-kk-asr-v1" style="color:#888">v1 Model</a> |
<a href="https://huggingface.co/spaces/stukenov/sozkz-kazakh-llm-demo" style="color:#888">LLM Demo</a> |
<a href="https://huggingface.co/stukenov" style="color:#888">stukenov</a>
</div>""")

if __name__ == "__main__":
    # ssr_mode=False avoids server-side-rendering issues on Spaces;
    # show_error surfaces Python tracebacks in the UI for easier debugging.
    demo.launch(ssr_mode=False, show_error=True)