Liangjiejie commited on
Commit
d8c96d3
·
verified ·
1 Parent(s): 724b57f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +158 -103
app.py CHANGED
@@ -1,121 +1,176 @@
1
  import gradio as gr
2
- import torch
3
- import scipy.io.wavfile as wav
4
- import numpy as np
5
  import os
6
- import logging
7
- from pypinyin import Style, pinyin
8
- import jieba
9
-
10
- # ----------------------------------------------------
11
- # ❗ 注意:为了运行此代码,您需要将 BERT-VITS2 的
12
- # 推理逻辑(如 text_to_sequence 和 g2p 模块)
13
- # 复制到您的 Space 中。由于无法直接提供,这里假设
14
- # 您已手动下载并添加到名为 'custom_modules' 的文件夹。
15
- # ----------------------------------------------------
16
-
17
- # 配置日志
18
- logging.basicConfig(level=logging.INFO)
19
-
20
- # --- 1. 模型初始化 ---
21
- # 警告:此 ID 仅用于占位,您需要手动上传 BERT-VITS2 的
22
- # .pth 权重文件和 config.json 文件到您的 Space 中。
23
- # 这里无法使用 from_pretrained 方法!
24
-
25
- # **您需要手动上传的模型文件到 Space 根目录**
26
- MODEL_NAME = "model_best"
27
- MODEL_PATH = f"logs/44k/{MODEL_NAME}.pth"
28
- CONFIG_PATH = "configs/bert_vits2_ljs_44k.json"
29
-
30
- # 假设模型和配置已经存在(需要您手动上传)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  try:
32
- # 模拟加载 BERT-VITS2 所需的定制化组件
33
- # 这一步在实际部署中将失败,因为缺少 BERT-VITS2 库文件
34
- # 真正的 BERT-VITS2 部署需要将整个推理代码克隆下来
35
 
36
- # 暂时使用占位符,如果模型加失败,则退回纯文状态
37
- class DummyModel:
38
- def __init__(self):
39
- self.config = type('Config', (object,), {'sampling_rate': 44100})
40
- logging.error("使用占位模型!请手上传 BERT-VITS2 。")
41
-
42
- def generate(self, *args, **kwargs):
43
- # 模拟生成一个 2 秒的静音音频
44
- sample_rate = 44100
45
- duration = 2
46
- t = np.linspace(0, duration, int(sample_rate * duration), False)
47
- audio = 0.5 * np.sin(2 * np.pi * 440 * t)
48
-
49
- # 模拟模型输出
50
- class Output:
51
- def __init__(self, waveform):
52
- self.waveform = torch.tensor(waveform).unsqueeze(0).float()
53
-
54
- return Output(audio)
55
-
56
- model = DummyModel()
57
- tokenizer = None
58
- device = "cpu"
59
- logging.info("模型加载状态:当前为 BERT-VITS2 占位模式。")
60
 
 
 
 
 
61
  except Exception as e:
62
- logging.error(f"BERT-VITS2 复杂加载失败,错误: {e}")
 
63
  model = None
64
- tokenizer = None
 
 
 
 
 
 
 
 
65
 
66
- # --- 2. 核心 TTS 函数 (使用拼音和分词) ---
67
- def generate_speech(text):
68
- if not text or model is None:
69
- return None, "❌ 模型未加载或文本为空。请检查 Logs 并上 BERT-VITS2 文件。"
 
 
 
 
 
 
 
 
70
 
71
  try:
72
- # 1. 中文分词和预处理 (BERT-VITS2需要复杂的素处理)
73
- processed_text = " ".join(jieba.cut(text))
74
- logging.info(f"使用 BERT-VITS2 预处理文本: {processed_text[:60]}...")
75
-
76
- # 2. 调用模型生成 (此处是占位符,实际应调用 BERT-VITS2.synthesize)
77
- output = model.generate(text=processed_text)
78
-
79
- # 3. 提取音频波形
80
- speech = output.waveform.squeeze(0).cpu().numpy()
81
- sampling_rate = model.config.sampling_rate
82
 
83
- # 4. 保存为临时的WAV文件
84
- output_file_path = "output_speech.wav"
85
- speech_int16 = (speech * 32767).astype(np.int16)
86
 
87
- wav.write(output_file_path, sampling_rate, speech_int16)
88
 
89
- logging.info("语音生成占位成功。")
90
- return output_file_path, f"⚠️ 占位模型成功运行。请上传 BERT-VITS2 文件以获得真人效果。"
91
-
92
  except Exception as e:
93
- logging.error(f"语音生成过程中发生错误: {e}")
94
- return None, f"❌ 语音生成过程中发生错误:{e}"
95
 
96
- # --- 3. Gradio 界面 (保持不变) ---
97
- with gr.Blocks() as demo:
98
- gr.Markdown("# 🏆 BERT-VITS2 中文 TTS 部署 (需手动上传文件)")
99
- gr.Markdown("**注意:这是目前开源领域效果最好的模型架构,但需要您手动上传模型权重文件!**")
100
- gr.Markdown(f"当前模型状态:{'已加载占位模型' if model else '模型加载失败'}")
101
 
102
- # 定义所有组件变量
103
- text_input = gr.Textbox(lines=5, label="输入中文文本", placeholder="你好,这是 BERT-VITS2 部署。效果应该非常自然!")
104
- audio_output = gr.Audio(label="生成的语音", type="filepath")
105
- status_text = gr.Textbox(label="状态信息", interactive=False)
106
- generate_btn = gr.Button("🚀 开始生成语音")
 
 
107
 
108
- # 定义布局
 
 
 
 
 
 
109
  with gr.Row():
110
- gr.Column(text_input, scale=3)
111
- with gr.Column(scale=2):
112
- audio_output
113
- status_text
114
- generate_btn
115
-
116
- # 绑定按钮和函数
117
- generate_btn.click(fn=generate_speech,
118
- inputs=text_input,
119
- outputs=[audio_output, status_text])
120
-
121
- demo.launch()
 
 
 
1
  import gradio as gr
2
+ import requests
 
 
3
  import os
4
+ import json
5
+ from tqdm import tqdm
6
+ from transformers import AutoTokenizer, AutoModelForTextToSpeech, AutoProcessor, AutoConfig
7
+ from transformers.utils import is_flash_attn_available
8
+
9
+ # 仓库信息
10
+ REPO_ID = "BricksDisplay/ellie-Bert-VITS2"
11
+ FILES_TO_DOWNLOAD = [
12
+ "model.safetensors",
13
+ "config.json",
14
+ "configuration_bert_vits2.py",
15
+ "modeling_bert_vits2.py",
16
+ "preprocessor_config.json",
17
+ "processing_bert_vits2.py",
18
+ "processor_config.json",
19
+ "special_tokens_map.json",
20
+ "tokenizer_bert_vits2.py",
21
+ "tokenizer.json",
22
+ "tokenizer_config.json",
23
+ # 文件夹内容
24
+ "bert_zh/config.json",
25
+ "bert_zh/tokenizer_config.json",
26
+ "bert_zh/vocab.txt",
27
+ "data/pinyin.json",
28
+ "data/symbols.json",
29
+ "onnx/config.json",
30
+ "onnx/model_index.json",
31
+ "onnx/tokenizer_config.json",
32
+ ]
33
+
34
+ def download_file(file_path):
35
+ """从 Hugging Face CDN 下载文件到本地。"""
36
+ if os.path.exists(file_path):
37
+ print(f"文件已存在: {file_path}")
38
+ return
39
+
40
+ # 构造下载链接
41
+ url = f"https://huggingface.co/{REPO_ID}/resolve/main/{file_path}"
42
+
43
+ # 确保文件夹存在
44
+ os.makedirs(os.path.dirname(file_path) or '.', exist_ok=True)
45
+
46
+ print(f"正在下载: {file_path}")
47
+
48
+ try:
49
+ response = requests.get(url, stream=True)
50
+ response.raise_for_status() # 检查是否有 HTTP 错误
51
+
52
+ total_size = int(response.headers.get('content-length', 0))
53
+ block_size = 1024 # 1 Kibibyte
54
+
55
+ with open(file_path, 'wb') as f, tqdm(
56
+ desc=file_path,
57
+ total=total_size,
58
+ unit='iB',
59
+ unit_scale=True,
60
+ unit_divisor=1024,
61
+ ) as bar:
62
+ for data in response.iter_content(block_size):
63
+ bar.update(len(data))
64
+ f.write(data)
65
+
66
+ print(f"下载完成: {file_path}")
67
+ except Exception as e:
68
+ print(f"下载 {file_path} 失败: {e}")
69
+ # 如果是文件夹路径,确保创建它
70
+ if file_path.endswith('/') or '.' not in file_path.split('/')[-1]:
71
+ os.makedirs(file_path, exist_ok=True)
72
+ else:
73
+ raise Exception(f"无法下载文件: {file_path}")
74
+
75
+ def ensure_files_exist():
76
+ """检查并下载所有必需的文件。"""
77
+ print("开始检查模型文件...")
78
+ for file in FILES_TO_DOWNLOAD:
79
+ download_file(file)
80
+
81
+ # 确保文件存在
82
+ ensure_files_exist()
83
+
84
+ # --- 模型加载和推理 ---
85
  try:
86
+ # 尝试使用 flash_attn, 仅在支持时
87
+ attn_implementation = "flash_attention_2" if is_flash_attn_available() else "eager"
 
88
 
89
+ # 由于文件已下载到本地,直接从本地加载(路径为 '.')
90
+ model = AutoModelForTextToSpeech.from_pretrained(
91
+ ".",
92
+ attn_implementation=attn_implementation,
93
+ device_map="auto" # 自映射到 CPU (您的 Space 硬件)
94
+ )
95
+ processor = AutoProcessor.from_pretrained(".")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
+ # 设定默认说话人ID
98
+ DEFAULT_SPEAKER_ID = 0
99
+ print("模型加载成功!")
100
+
101
  except Exception as e:
102
+ print(f"模型加载失败,请检查下载的文件是否完整: {e}")
103
+ # 设置一个假模型以避免 Gradio 启动失败
104
  model = None
105
+ processor = None
106
+ DEFAULT_SPEAKER_ID = 0
107
+
108
+ def tts_generate(text):
109
+ """
110
+ 文本转语音生成函数
111
+ """
112
+ if model is None or processor is None:
113
+ return None, "模型未加载成功,请检查日志或上传文件。"
114
 
115
+ if not text:
116
+ return None, "请输入中文文本。"
117
+
118
+ # 文本和说话人ID入处理器
119
+ inputs = processor(
120
+ text=text,
121
+ speaker_id=DEFAULT_SPEAKER_ID,
122
+ return_tensors="pt"
123
+ )
124
+
125
+ # 移到模型所在设备
126
+ inputs = {k: v.to(model.device) for k, v in inputs.items()}
127
 
128
  try:
129
+ # 生成语
130
+ with torch.no_grad():
131
+ output = model.generate(**inputs, do_sample=True, top_k=50, top_p=0.9, temperature=0.7)
 
 
 
 
 
 
 
132
 
133
+ # 获取采样率
134
+ sampling_rate = model.config.sampling_rate
 
135
 
136
+ return (sampling_rate, output.squeeze().cpu().numpy()), "成功!"
137
 
 
 
 
138
  except Exception as e:
139
+ return None, f"生成语音失败: {e}"
 
140
 
141
+ # --- Gradio 界面 ---
142
+ import torch
143
+
144
+ # 使用 torch 库作为依赖
145
+ torch_version = torch.__version__
146
 
147
+ title = "免费中文文本转语音演示 (BERT-VITS2 模型)"
148
+ description = f"""
149
+ **当前模型:** BricksDisplay/ellie-Bert-VITS2 (目前中文语音效果最好的开源模型之一)
150
+ - **注意:** 由于您使用的是免费的 CPU Basic 硬件,模型较大(1.59 GB),语音生成会有较长延迟,这是正常的。
151
+ - **状态:** 模型和文件已在启动时自动下载并加载。
152
+ - **PyTorch 版本:** {torch_version}
153
+ """
154
 
155
+ with gr.Blocks() as demo:
156
+ gr.Markdown(f"# {title}")
157
+ gr.Markdown(description)
158
+
159
+ with gr.Row():
160
+ text_input = gr.Textbox(label="输入中文文本", placeholder="请输入您想合成的中文语句...")
161
+
162
  with gr.Row():
163
+ generate_button = gr.Button("🚀 开始生成语音")
164
+
165
+ with gr.Row():
166
+ audio_output = gr.Audio(label="生成的语音", type="numpy")
167
+ status_text = gr.Textbox(label="状态信息", value="模型等待输入...")
168
+
169
+ generate_button.click(
170
+ fn=tts_generate,
171
+ inputs=[text_input],
172
+ outputs=[audio_output, status_text]
173
+ )
174
+
175
+ # 启动 Gradio
176
+ demo.queue().launch()