# vad-marblenet / app.py
# Author: Yehor
# Last commit: 59e958d ("ruff format")
import os
import sys
import threading
from os import rename
from os.path import basename
from zipfile import ZipFile, ZIP_DEFLATED
from shutil import rmtree
from importlib.metadata import version

import sphn
import torch
import sentry_sdk
import gradio as gr
from gradio.themes import Soft

from inference import inference_file
# Error reporting is opt-in: Sentry is only initialised when SENTRY_DSN is set
# in the environment. send_default_pii=True allows Sentry integrations to
# attach personally identifiable data to reported events.
if os.environ.get("SENTRY_DSN") is not None:
    sentry_sdk.init(
        send_default_pii=True,
        dsn=os.environ["SENTRY_DSN"],
    )
    print("Sentry SDK is activated")
# Torch device string for this process; it is reported on the environment tab
# through tech_env.
use_cuda = torch.cuda.is_available()
if use_cuda:
    print("CUDA is available, setting correct device variable.")
device = "cuda" if use_cuda else "cpu"
# Markdown shown on the "Authors" tab.
# Table layout generated with https://www.tablesgenerator.com/markdown_tables
# The blank line before the table keeps it from being parsed as part of the
# preceding paragraph by stricter Markdown renderers.
authors_table = """
## Authors
Follow them on social networks and **contact** them if you need any help or have any questions:

| **Yehor Smoliakov** |
|-------------------------------------------------------------------------------------------------|
| https://t.me/smlkw in Telegram |
| https://x.com/yehor_smoliakov at X |
| https://github.com/egorsmkv at GitHub |
| https://huggingface.co/Yehor at Hugging Face |
| or use egorsmkv@gmail.com |
""".strip()
# Markdown for the environment tab (rendered by create_env()): the running
# Python version, the torch device string chosen above, and the VAD model
# this demo advertises. Evaluated once at import time.
tech_env = f"""
#### Environment
- Python: {sys.version}
- Torch device: {device}
#### Models
##### Acoustic model (Voice Activity Detection)
- Name: MarbleNet
- URL: https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/vad_marblenet
""".strip()
# Markdown listing installed versions of the main runtime dependencies, looked
# up with importlib.metadata.version() at import time (it raises
# PackageNotFoundError if a distribution is not installed). Rendered by
# create_env().
tech_libraries = f"""
#### Libraries
- torch: {version("torch")}
- sphn: {version("sphn")}
- gradio: {version("gradio")}
- sentry_sdk: {version("sentry_sdk")}
""".strip()
# Title and one-line description shown at the top of the main tab.
description_head = """
# MarbleNet
Split an audio file into voice chunks.
""".strip()
# Maximum number of concurrent inference_func runs for the "Run" button; it is
# passed as concurrency_limit to .click() in create_app().
# NOTE(review): inference_func works on fixed paths ("input.wav", "tmp.zip",
# the "chunks" directory), so concurrent runs can interfere with each other
# unless that function serializes its work — confirm before raising this.
concurrency_limit = 5
# inference_func works on fixed paths in the current directory ("input.wav",
# "tmp.zip" and the "chunks" directory it removes afterwards). The Run button
# allows several concurrent requests, so without this lock one user's run
# could overwrite or delete another user's files.
_inference_lock = threading.Lock()


def _validate_audio(path):
    """Check that *path* is a readable, mono audio file of at least 0.1 s.

    Raises:
        gr.Error: with a user-facing message when the file cannot be read,
            is too short, or has more than one channel.
    """
    try:
        data, sr = sphn.read(path)
        # len(data) is used as the channel count and data[0] as the first
        # channel's samples, so the duration comes from the first channel.
        n_channels = len(data)
        duration = len(data[0]) / sr
    except Exception as exc:
        raise gr.Error(f"Can't read your file, the problem: {exc}") from exc

    # Validation errors are raised outside the try block so they are not
    # re-wrapped as "Can't read your file" errors.
    if duration < 0.1:
        raise gr.Error("The duration is too low")
    if n_channels > 1:
        raise gr.Error(
            f"Your file must be in the mono format. The file has {n_channels} channels."
        )


def inference_func(wav_file, min_sec, max_sec):
    """Split an uploaded WAV file into voice chunks and return them as a ZIP.

    Args:
        wav_file: Path of the uploaded WAV file, as provided by ``gr.File``.
            The file is moved to ``input.wav`` in the current directory.
        min_sec: Shortest chunk duration to keep, in seconds (inclusive).
        max_sec: Longest chunk duration to keep, in seconds (inclusive).

    Returns:
        Path of the ZIP archive (``"tmp.zip"``) containing the kept chunks.

    Raises:
        gr.Error: when the input is missing or invalid, or inference fails
            (inference failures are also reported to Sentry).
    """
    from shutil import move

    archive_name = "tmp.zip"
    input_path = "input.wav"

    if not wav_file:
        raise gr.Error("Please upload a WAV file.")
    if min_sec is None or max_sec is None:
        raise gr.Error("Please set both the minimum and maximum seconds.")
    if min_sec > max_sec:
        raise gr.Error(
            f"Minimum seconds ({min_sec}) must not exceed maximum seconds ({max_sec})."
        )

    _validate_audio(wav_file)

    n_files = 0
    duration_secs = 0
    with _inference_lock:
        # shutil.move (unlike os.rename) works across filesystems and replaces
        # a leftover input.wav from a previous run on every platform.
        move(wav_file, input_path)
        try:
            try:
                # Materialize the results so that errors raised lazily by an
                # iterator are also captured here.
                results = list(inference_file(input_path))
            except Exception as exc:
                sentry_sdk.capture_exception(exc)
                raise gr.Error(
                    "Something went wrong, we will be notified about this"
                ) from exc

            with ZipFile(
                archive_name,
                "w",
                compression=ZIP_DEFLATED,
                allowZip64=True,
                compresslevel=9,
            ) as zip_file:
                for result in results:
                    chunk_duration = result["speech"]["duration"]
                    print(result, chunk_duration)
                    # Inclusive bounds, matching the "[min:max]" summary below.
                    if not min_sec <= chunk_duration <= max_sec:
                        print("Skipping...")
                        continue
                    zip_file.write(result["filename"], basename(result["filename"]))
                    duration_secs += chunk_duration
                    n_files += 1
        finally:
            # Remove the chunk files (presumably written under "chunks" by
            # inference_file) even when inference or zipping fails; a missing
            # directory is not an error.
            rmtree("chunks", ignore_errors=True)

    mins = round(duration_secs / 60, 4)
    gr.Success(
        f"VAD model identified {n_files} files in interval [{min_sec}:{max_sec}], total duration = {mins} min."
    )
    # NOTE(review): Gradio copies this file after the function returns, outside
    # the lock; a unique archive name would close that remaining window.
    return archive_name
def create_app():
    """Build the main "VAD" tab.

    The tab takes a WAV upload plus minimum/maximum chunk lengths in seconds,
    calls inference_func when "Run" is clicked, and shows the returned ZIP.
    """
    with gr.Blocks(
        title="MarbleNet",
        analytics_enabled=False,
        theme=Soft(),
    ) as vad_tab:
        gr.Markdown(description_head)
        gr.Markdown("## Usage")

        # Inputs: the audio file and the accepted chunk-length window.
        with gr.Column():
            audio_input = gr.File(
                file_types=[".wav"],
                file_count="single",
                label="WAV file to process",
            )
            lower_bound = gr.Number(
                value=0.1,
                minimum=0.01,
                maximum=59.99,
                label="Minimum seconds",
            )
            upper_bound = gr.Number(
                value=30,
                minimum=0.02,
                maximum=60,
                label="Maximum seconds",
            )

        # Output: the archive with the kept voice chunks.
        with gr.Column():
            archive_output = gr.File(label="ZIP file with voice chunks")

        run_button = gr.Button("Run")
        run_button.click(
            inference_func,
            inputs=[audio_input, lower_bound, upper_bound],
            outputs=[archive_output],
            concurrency_limit=concurrency_limit,
        )
    return vad_tab
def create_env():
    """Build the tab that shows the environment, model and library details."""
    env_tab = gr.Blocks(theme=Soft())
    with env_tab:
        for section in (tech_env, tech_libraries):
            gr.Markdown(section)
    return env_tab
def create_authors():
    """Build the tab that shows the authors' contact details."""
    authors_tab = gr.Blocks(theme=Soft())
    with authors_tab:
        gr.Markdown(authors_table)
    return authors_tab
def create_demo():
    """Assemble the VAD, authors and environment tabs into one tabbed UI."""
    # (tab title, factory) pairs, built in display order.
    tab_specs = (
        ("🎙️ VAD", create_app),
        ("👥 Authors", create_authors),
        ("📦 Environment, Models, and Libraries", create_env),
    )
    interfaces = [build() for _, build in tab_specs]
    titles = [title for title, _ in tab_specs]
    return gr.TabbedInterface(interfaces, tab_names=titles)
if __name__ == "__main__":
    # Build the tabbed UI, enable Gradio's request queue, then start serving.
    demo = create_demo()
    demo.queue().launch()