Yehor commited on
Commit
9720263
·
1 Parent(s): 16b04a0

Improve the code

Browse files
Files changed (3) hide show
  1. .gitignore +9 -0
  2. app.py +140 -21
  3. requirements.txt +1 -1
.gitignore ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ __pycache__
2
+ .venv
3
+ .ruff_cache
4
+
5
+ rttm_outputs
6
+ vad_frame_outputs
7
+
8
+ tmp.json
9
+ tmp.zip
app.py CHANGED
@@ -1,14 +1,18 @@
1
  import os
 
2
  from os.path import basename
3
  from zipfile import ZipFile, ZIP_DEFLATED
4
  from shutil import rmtree
 
5
 
 
6
  import sentry_sdk
7
- import sphn
8
  import gradio as gr
 
 
9
  from inference import inference_file
10
 
11
- if os.environ["SENTRY_DSN"]:
12
  sentry_sdk.init(
13
  dsn=os.environ["SENTRY_DSN"],
14
  send_default_pii=True,
@@ -16,21 +20,76 @@ if os.environ["SENTRY_DSN"]:
16
  print("Sentry SDK is activated")
17
 
18
 
19
- def extract_chunks(file, min_sec, max_sec):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  archive_name = "tmp.zip"
21
  n_files = 0
22
  duration_secs = 0
23
 
24
  with ZipFile(
25
- archive_name, "w", compression=ZIP_DEFLATED, allowZip64=True, compresslevel=9
 
 
 
 
26
  ) as zip_file:
27
- results = inference_file(file)
28
-
29
- filenames = [it["filename"] for it in results]
30
- durations = sphn.durations(filenames)
 
31
 
32
  for idx, result in enumerate(results):
33
- duration = durations[idx]
34
 
35
  print(result, duration)
36
 
@@ -56,15 +115,75 @@ def extract_chunks(file, min_sec, max_sec):
56
  return archive_name
57
 
58
 
59
- demo = gr.Interface(
60
- title="MarbleNet",
61
- fn=extract_chunks,
62
- inputs=[
63
- gr.File(label="WAV file to process", file_count="single", file_types=[".wav"]),
64
- gr.Number(label="Minimum seconds", value=0.1, minimum=0.01, maximum=59.99),
65
- gr.Number(label="Maximum seconds", value=30, minimum=0.02, maximum=60),
66
- ],
67
- outputs=[gr.File(label="ZIP file with voice chunks")],
68
- submit_btn="Inference",
69
- )
70
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
+ import sys
3
  from os.path import basename
4
  from zipfile import ZipFile, ZIP_DEFLATED
5
  from shutil import rmtree
6
+ from importlib.metadata import version
7
 
8
+ import torch
9
  import sentry_sdk
 
10
  import gradio as gr
11
+ from gradio.themes import Soft
12
+
13
  from inference import inference_file
14
 
15
+ if "SENTRY_DSN" in os.environ:
16
  sentry_sdk.init(
17
  dsn=os.environ["SENTRY_DSN"],
18
  send_default_pii=True,
 
20
  print("Sentry SDK is activated")
21
 
22
 
23
+ use_cuda = torch.cuda.is_available()
24
+
25
+ if use_cuda:
26
+ print("CUDA is available, setting correct inference_device variable.")
27
+ device = "cuda"
28
+ else:
29
+ device = "cpu"
30
+
31
+ # https://www.tablesgenerator.com/markdown_tables
32
+ authors_table = """
33
+ ## Authors
34
+ Follow them in social networks and **contact** if you need any help or have any questions:
35
+ | **Yehor Smoliakov** |
36
+ |-------------------------------------------------------------------------------------------------|
37
+ | https://t.me/smlkw in Telegram |
38
+ | https://x.com/yehor_smoliakov at X |
39
+ | https://github.com/egorsmkv at GitHub |
40
+ | https://huggingface.co/Yehor at Hugging Face |
41
+ | or use egorsmkv@gmail.com |
42
+ """.strip()
43
+
44
+ tech_env = f"""
45
+ #### Environment
46
+ - Python: {sys.version}
47
+ - Torch device: {device}
48
+
49
+ #### Models
50
+
51
+ ##### Acoustic model (Voice Activity Detection)
52
+ - Name: MarbleNet
53
+ - URL: https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/vad_marblenet
54
+ """.strip()
55
+
56
+ tech_libraries = f"""
57
+ #### Libraries
58
+ - torch: {version("torch")}
59
+ - sphn: {version("sphn")}
60
+ - gradio: {version("gradio")}
61
+ - sentry_sdk: {version("sentry_sdk")}
62
+ """.strip()
63
+
64
+ description_head = """
65
+ # MarbleNet
66
+
67
+ Split an audio file to voice chunks.
68
+ """.strip()
69
+
70
+ concurrency_limit = 5
71
+
72
+
73
+ def extract_chunks(wav_file, min_sec, max_sec):
74
  archive_name = "tmp.zip"
75
  n_files = 0
76
  duration_secs = 0
77
 
78
  with ZipFile(
79
+ archive_name,
80
+ "w",
81
+ compression=ZIP_DEFLATED,
82
+ allowZip64=True,
83
+ compresslevel=9,
84
  ) as zip_file:
85
+ try:
86
+ results = inference_file(wav_file)
87
+ except Exception as exc:
88
+ sentry_sdk.capture_exception(exc)
89
+ raise gr.Error("Something went wrong, we will be notified about this")
90
 
91
  for idx, result in enumerate(results):
92
+ duration = result["speech"]["duration"]
93
 
94
  print(result, duration)
95
 
 
115
  return archive_name
116
 
117
 
118
+ def create_app():
119
+ tab = gr.Blocks(
120
+ title="MarbleNet",
121
+ analytics_enabled=False,
122
+ theme=Soft(),
123
+ )
124
+
125
+ with tab:
126
+ gr.Markdown(description_head)
127
+
128
+ gr.Markdown("## Usage")
129
+
130
+ with gr.Column():
131
+ wav_file = gr.File(
132
+ label="WAV file to process",
133
+ file_count="single",
134
+ file_types=[".wav"],
135
+ )
136
+ min_sec = gr.Number(
137
+ label="Minimum seconds", value=0.1, minimum=0.01, maximum=59.99
138
+ )
139
+ max_sec = gr.Number(
140
+ label="Maximum seconds", value=30, minimum=0.02, maximum=60
141
+ )
142
+
143
+ with gr.Column():
144
+ zip_file = gr.File(label="ZIP file with voice chunks")
145
+
146
+ gr.Button("Run").click(
147
+ extract_chunks,
148
+ concurrency_limit=concurrency_limit,
149
+ inputs=[wav_file, min_sec, max_sec],
150
+ outputs=[zip_file],
151
+ )
152
+
153
+ return tab
154
+
155
+
156
+ def create_env():
157
+ with gr.Blocks(theme=Soft()) as tab:
158
+ gr.Markdown(tech_env)
159
+ gr.Markdown(tech_libraries)
160
+
161
+ return tab
162
+
163
+
164
+ def create_authors():
165
+ with gr.Blocks(theme=Soft()) as tab:
166
+ gr.Markdown(authors_table)
167
+
168
+ return tab
169
+
170
+
171
+ def create_demo():
172
+ app_tab = create_app()
173
+ authors_tab = create_authors()
174
+ env_tab = create_env()
175
+
176
+ return gr.TabbedInterface(
177
+ [app_tab, authors_tab, env_tab],
178
+ tab_names=[
179
+ "🎙️ VAD",
180
+ "👥 Authors",
181
+ "📦 Environment, Models, and Libraries",
182
+ ],
183
+ )
184
+
185
+
186
+ if __name__ == "__main__":
187
+ demo = create_demo()
188
+ demo.queue()
189
+ demo.launch()
requirements.txt CHANGED
@@ -11,4 +11,4 @@ plotly
11
 
12
  gradio
13
 
14
- sentry-sdk
 
11
 
12
  gradio
13
 
14
+ sentry-sdk[huggingface_hub]