Biorrith's picture
Small refinements and new repo id
55fea91
raw
history blame
4.88 kB
import random
import os
import numpy as np
import torch
import gradio as gr
import spaces
from chatterbox.tts_turbo import ChatterboxTurboTTS
# Load the TTS model onto the "cuda" device once, at import time; the same
# instance is reused by every `generate` call.
MODEL = ChatterboxTurboTTS.from_pretrained("cuda" )
# Extra stylesheet passed to `demo.queue().launch(css=...)` in the __main__ guard.
# It styles a flex-wrapping container (`.tag-container`) and small pill-shaped
# buttons (`.tag-btn`) that lift slightly on hover.
# NOTE(review): no component in this file sets `elem_classes` to either class,
# so these rules appear unused — possibly left over from a tag-button toolbar
# (see INSERT_TAG_JS). Confirm before removing.
CUSTOM_CSS = """
.tag-container {
display: flex !important;
flex-wrap: wrap !important;
gap: 8px !important;
margin-top: 5px !important;
margin-bottom: 10px !important;
border: none !important;
background: transparent !important;
}
.tag-btn {
min-width: fit-content !important;
width: auto !important;
height: 32px !important;
font-size: 13px !important;
background: #eef2ff !important;
border: 1px solid #c7d2fe !important;
color: #3730a3 !important;
border-radius: 6px !important;
padding: 0 10px !important;
margin: 0 !important;
box-shadow: none !important;
}
.tag-btn:hover {
background: #c7d2fe !important;
transform: translateY(-1px);
}
"""
# JavaScript arrow function kept as a Python string, presumably meant for a
# Gradio event's `js=` argument (two inputs: the tag text and the current
# textbox value; returns the new textbox value).
#
# It finds the textarea inside the element with id `main_textbox` (the text
# input's `elem_id`) and inserts `tag_val` at the caret, replacing any
# selected text. Spacing rules:
#   - a space goes before the tag unless the caret is at position 0 or the
#     preceding character is already a space;
#   - a space goes after the tag unless the character right after the
#     selection is a space (so at the very end of the text a trailing space
#     is always added).
# If the textarea cannot be found, it falls back to appending " " + tag_val.
# NOTE(review): nothing in this file passes this snippet to an event handler;
# confirm whether it is still needed.
INSERT_TAG_JS = """
(tag_val, current_text) => {
const textarea = document.querySelector('#main_textbox textarea');
if (!textarea) return current_text + " " + tag_val;
const start = textarea.selectionStart;
const end = textarea.selectionEnd;
let prefix = " ";
let suffix = " ";
if (start === 0) prefix = "";
else if (current_text[start - 1] === ' ') prefix = "";
if (end < current_text.length && current_text[end] === ' ') suffix = "";
return current_text.slice(0, start) + prefix + tag_val + suffix + current_text.slice(end);
}
"""
def set_seed(seed: int):
    """Make sampling reproducible by seeding every RNG the app may touch.

    Seeds, in order: PyTorch's default generator, the current CUDA device,
    all CUDA devices, Python's ``random`` module, and NumPy's global RNG.

    Args:
        seed: Integer seed applied to all of the generators above.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
        random.seed,
        np.random.seed,
    )
    for seeder in seeders:
        seeder(seed)
# Hard cap on input length; mirrors the "(max chars 300)" promise in the UI label.
_MAX_TEXT_CHARS = 300


@spaces.GPU
def generate(
    text,
    audio_prompt_path,
    temperature,
    seed_num,
    top_p,
    top_k,
    repetition_penalty,
    norm_loudness,
):
    """Synthesize speech for ``text`` in the style of a reference recording.

    Bound to the "Generate" button; Gradio passes the click handler's
    ``inputs`` positionally, in the order of the parameters below.

    Args:
        text: Text to speak. Only the first ``_MAX_TEXT_CHARS`` characters
            are used.
        audio_prompt_path: Filepath of the reference clip (the ``gr.Audio``
            input uses ``type="filepath"``); forwarded to the model as-is.
        temperature: Sampling temperature, forwarded to ``MODEL.generate``.
        seed_num: Seed from the numeric field. ``0``, or an emptied field
            (delivered by Gradio as ``None``), leaves the RNGs unseeded.
            Other values are wrapped into ``[0, 2**32)`` before seeding.
        top_p: Forwarded to ``MODEL.generate``.
        top_k: Forwarded to ``MODEL.generate`` after ``int`` conversion
            (slider values may arrive as floats).
        repetition_penalty: Forwarded to ``MODEL.generate``.
        norm_loudness: Forwarded to ``MODEL.generate`` (UI label:
            "Normalize Loudness (-27 LUFS)").

    Returns:
        ``(sample_rate, samples)`` tuple for the output ``gr.Audio``.

    Raises:
        gr.Error: If ``text`` is empty or whitespace-only.
    """
    text = (text or "")[:_MAX_TEXT_CHARS]
    if not text.strip():
        # Surface a readable message in the UI instead of running the model on nothing.
        raise gr.Error("Please enter some text to synthesize.")

    # 0 means "don't reseed"; an emptied gr.Number arrives as None, which is
    # treated the same way instead of reaching int(None).
    if seed_num:
        # np.random.seed only accepts [0, 2**32 - 1]; wrap negative or huge
        # values into that range rather than failing the request.
        set_seed(int(seed_num) % 2**32)

    wav = MODEL.generate(
        text,
        audio_prompt_path=audio_prompt_path,
        temperature=temperature,
        top_p=top_p,
        top_k=int(top_k),
        repetition_penalty=repetition_penalty,
        norm_loudness=norm_loudness,
    )
    # squeeze(0) drops a leading singleton axis — presumably batch/channel; confirm.
    return (MODEL.sr, wav.squeeze(0).cpu().numpy())
# Preset voices for the "Voice Selection" dropdown, mapped to the relative
# paths of their reference clips. Insertion order sets the dropdown order and
# makes "mic" the first (fallback) entry.
VOICE_OPTIONS = dict(
    mic="samples/mic_trimmed.wav",
    nic="samples/nic_trimmed.wav",
)
def update_ref_audio(voice_name):
    """Return the reference-clip path for a preset voice name.

    Names not found in ``VOICE_OPTIONS`` fall back to the clip of its first
    entry.
    """
    first_clip = [*VOICE_OPTIONS.values()][0]
    if voice_name in VOICE_OPTIONS:
        return VOICE_OPTIONS[voice_name]
    return first_clip
with gr.Blocks(title="Chatterbox Turbo") as demo:
    gr.Markdown(
        "# ⚡ Røst v3 Chatterbox 350m\n"
        "Generate high-quality Danish speech from text with reference audio "
        "styling. This model was developed as part of the CoRal project, and "
        "is a finetuned version of Chatterbox Multilingual."
    )

    with gr.Row():
        # Left column: text, voice choice and reference clip.
        with gr.Column():
            text = gr.Textbox(
                value="København er Danmarks hovedstad og ligger på øerne Sjælland og Amager, hvor mange turister besøger de smukke kanaler og historiske bygninger.",
                label="Text to synthesize (max chars 300)",
                max_lines=5,
                # Enforce the advertised limit in the browser as well.
                max_length=300,
                # Stable DOM id; INSERT_TAG_JS looks up '#main_textbox textarea'.
                elem_id="main_textbox",
            )
            voice = gr.Dropdown(
                choices=list(VOICE_OPTIONS.keys()),
                value="mic",
                label="Voice Selection",
                info="Choose a voice or upload your own below",
            )
            # Pre-filled with the default voice's clip; replaced when the
            # dropdown changes, or by a user upload/recording.
            ref_wav = gr.Audio(
                sources=["upload", "microphone"],
                type="filepath",
                label="Reference Audio File",
                value=VOICE_OPTIONS["mic"],
            )
            run_btn = gr.Button("Generate ⚡", variant="primary")

        # Right column: synthesized audio plus sampling controls.
        with gr.Column():
            audio_output = gr.Audio(label="Output Audio")

            with gr.Accordion("Advanced Options", open=False):
                # precision=0 makes Gradio round to and deliver an int.
                seed_num = gr.Number(value=0, precision=0, label="Random seed (0 for random)")
                temp = gr.Slider(0.05, 2.0, step=0.05, label="Temperature", value=0.7)
                top_p = gr.Slider(0.00, 1.00, step=0.01, label="Top P", value=0.95)
                top_k = gr.Slider(0, 1000, step=10, label="Top K", value=600)
                repetition_penalty = gr.Slider(
                    1.00, 2.00, step=0.05, label="Repetition Penalty", value=1.2
                )
                norm_loudness = gr.Checkbox(value=True, label="Normalize Loudness (-27 LUFS)")

    # Picking a preset voice swaps its sample clip into the reference slot.
    voice.change(
        fn=update_ref_audio,
        inputs=[voice],
        outputs=[ref_wav],
    )

    # The order of `inputs` must match generate()'s positional parameters.
    run_btn.click(
        fn=generate,
        inputs=[
            text,
            ref_wav,
            temp,
            seed_num,
            top_p,
            top_k,
            repetition_penalty,
            norm_loudness,
        ],
        outputs=audio_output,
    )
if __name__ == "__main__":
    # Serve the app with the request queue enabled, the custom stylesheet,
    # an MCP server endpoint, and server-side rendering turned off.
    launch_options = dict(
        css=CUSTOM_CSS,
        mcp_server=True,
        ssr_mode=False,
    )
    demo.queue().launch(**launch_options)