| import subprocess |
| import tempfile |
| import os |
| import json |
| import shutil |
| import time |
| import librosa |
| import torch |
| import argparse |
| import soundfile as sf |
| from pathlib import Path |
| import cn2an |
| import requests |
| import re |
| import numpy as np |
| import onnxruntime as ort |
| import axengine as axe |
|
|
| |
| from model import SinusoidalPositionEncoder |
| from utils.ax_model_bin import AX_SenseVoiceSmall |
| from utils.ax_vad_bin import AX_Fsmn_vad |
| from utils.vad_utils import merge_vad |
| from funasr.tokenizer.sentencepiece_tokenizer import SentencepiecesTokenizer |
|
|
| |
| from libmelotts.python.split_utils import split_sentence |
| from libmelotts.python.text import cleaned_text_to_sequence |
| from libmelotts.python.text.cleaner import clean_text |
| from libmelotts.python.symbols import LANG_TO_SYMBOL_MAP |
|
|
| |
| |
# Directory (relative to the working directory) that holds the MeloTTS model files.
TTS_MODEL_DIR = "libmelotts/models"

# File names inside TTS_MODEL_DIR, keyed by the role each file plays:
#   "g"       - raw binary blob read with np.fromfile as float32
#   "encoder" - model loaded through onnxruntime
#   "decoder" - model loaded through axengine
TTS_MODEL_FILES = dict(
    g="g-zh_mix_en.bin",
    encoder="encoder-zh.onnx",
    decoder="decoder-zh.axmodel",
)

# Default base URL of the Qwen translation HTTP service.
QWEN_API_URL = "http://10.126.29.158:8000"
|
|
|
|
| |
def intersperse(lst, item):
    """Return a new list with ``item`` before, between and after the elements of ``lst``.

    The result always has ``2 * len(lst) + 1`` entries, so an empty ``lst``
    yields ``[item]``.
    """
    padded = [item]
    for element in lst:
        padded.append(element)
        padded.append(item)
    return padded
|
|
|
|
def get_text_for_tts_infer(text, language_str, symbol_to_id=None):
    """Turn raw text into the int32 id arrays consumed by the TTS encoder.

    Args:
        text: Sentence to synthesize.
        language_str: Language tag forwarded to ``clean_text`` and
            ``cleaned_text_to_sequence``.
        symbol_to_id: Optional symbol-to-id mapping forwarded to
            ``cleaned_text_to_sequence``.

    Returns:
        ``(phone, tone, language, norm_text, word2ph)``. The first three are
        int32 arrays with a ``0`` inserted before, between and after every id;
        ``word2ph`` is the per-word phone count adjusted for those insertions.
    """
    norm_text, raw_phones, raw_tones, word2ph = clean_text(text, language_str)
    sequences = cleaned_text_to_sequence(raw_phones, raw_tones, language_str, symbol_to_id)

    # Interleave a blank id (0) around every entry of each sequence.
    phone, tone, language = (
        np.array(intersperse(sequence, 0), dtype=np.int32) for sequence in sequences
    )

    # Interleaving adds one blank per phone (hence * 2); the extra leading
    # blank is attributed to the first word (hence + 1).
    word2ph = np.array(word2ph, dtype=np.int32) * 2
    word2ph[0] += 1

    return phone, tone, language, norm_text, word2ph
|
|
|
|
def audio_numpy_concat(segment_data_list, sr, speed=1.):
    """Join audio segments into one float32 array with a short silence after each.

    Args:
        segment_data_list: Iterable of array-likes; each one is flattened
            before being appended.
        sr: Sample rate used to size the silence gap.
        speed: Speed factor; the gap is ``int(sr * 0.05 / speed)`` samples.

    Returns:
        1-D ``np.float32`` array: every segment followed by the silence gap.
        An empty input yields an empty float32 array.
    """
    segments = [
        np.asarray(segment_data, dtype=np.float32).reshape(-1)
        for segment_data in segment_data_list
    ]
    if not segments:
        return np.zeros(0, dtype=np.float32)

    # One shared zero buffer is reused as the gap after every segment; a
    # negative length is clamped to "no gap".
    gap = np.zeros(max(int((sr * 0.05) / speed), 0), dtype=np.float32)

    # Concatenate at the array level instead of expanding every sample into a
    # Python object inside a list.
    pieces = []
    for segment in segments:
        pieces.append(segment)
        pieces.append(gap)
    return np.concatenate(pieces)
|
|
|
|
def merge_sub_audio(sub_audio_list, pad_size, audio_len):
    """Concatenate decoded audio chunks, averaging ``pad_size`` overlapping samples.

    When ``pad_size`` is positive, the last ``pad_size`` samples of every chunk
    except the final one are averaged in place with the first ``pad_size``
    samples of the following chunk, and every chunk except the first and the
    last has its leading ``pad_size`` samples dropped. The caller's list and
    arrays are modified by this step.

    Args:
        sub_audio_list: List of 1-D float arrays.
        pad_size: Number of overlapping samples; ``0`` disables blending.
        audio_len: Maximum number of samples to return.

    Returns:
        The concatenated audio truncated to ``audio_len`` samples.
    """
    if pad_size > 0:
        last_index = len(sub_audio_list) - 1
        for idx in range(last_index):
            # A basic slice is a view, so these updates land in the stored chunk.
            tail = sub_audio_list[idx][-pad_size:]
            tail += sub_audio_list[idx + 1][:pad_size]
            tail /= 2
            if idx > 0:
                sub_audio_list[idx] = sub_audio_list[idx][pad_size:]

    merged = np.concatenate(sub_audio_list, axis=-1)
    return merged[:audio_len]
|
|
|
|
def calc_word2pronoun(word2ph, pronoun_lens):
    """Sum ``pronoun_lens`` over each word's consecutive run of phones.

    Args:
        word2ph: Number of phones belonging to each word, in order.
        pronoun_lens: Per-phone lengths, sliced along the first axis.

    Returns:
        A list with one ``np.sum`` result per word.
    """
    totals = []
    offset = 0
    for phone_count in word2ph:
        totals.append(np.sum(pronoun_lens[offset : offset + phone_count]))
        offset = offset + phone_count
    return totals
|
|
|
|
def generate_slices(word2pronoun, dec_len):
    """Split a sentence into decoder-sized chunks of whole words.

    Words are packed greedily so that each chunk needs at most ``dec_len``
    frames. When the previous chunk held more than two words and its last two
    words plus the next word still fit in ``dec_len``, the new chunk starts two
    words earlier, so consecutive chunks overlap by two words.

    Args:
        word2pronoun: Frame count of each word, in order.
        dec_len: Maximum number of frames in one chunk.

    Returns:
        ``(pn_slices, zp_slices)``: parallel lists of ``slice`` objects, the
        first indexing words and the second indexing frames.

    Raises:
        ValueError: If a single word needs more than ``dec_len`` frames. No
            chunk can hold it, and without this check the loop would never
            terminate.
    """
    pn_start, pn_end = 0, 0
    zp_start, zp_end = 0, 0
    pn_slices = []
    zp_slices = []
    word_count = len(word2pronoun)

    while pn_end < word_count:
        first_new_word = pn_end

        if pn_end - pn_start > 2 and np.sum(word2pronoun[pn_end - 2 : pn_end + 1]) <= dec_len:
            # Re-use the last two words of the previous chunk as leading context.
            zp_len = np.sum(word2pronoun[pn_end - 2 : pn_end])
            zp_start = zp_end - zp_len
            pn_start = pn_end - 2
        else:
            zp_len = 0
            zp_start = zp_end
            pn_start = pn_end

        # Greedily add words while the chunk still fits in the decoder window.
        while pn_end < word_count and zp_len + word2pronoun[pn_end] <= dec_len:
            zp_len += word2pronoun[pn_end]
            pn_end += 1

        if pn_end == first_new_word:
            raise ValueError(
                f"word at index {pn_end} needs {word2pronoun[pn_end]} frames, "
                f"which exceeds dec_len={dec_len}"
            )

        zp_end = zp_start + zp_len
        pn_slices.append(slice(pn_start, pn_end))
        zp_slices.append(slice(zp_start, zp_end))

    return pn_slices, zp_slices
|
|
|
|
| |
def lang_detect_with_regex(text):
    """Classify ``text`` as ``'chinese'``, ``'english'`` or ``'unknown'``.

    Digits are ignored. Text containing any character in U+4E00-U+9FFF is
    Chinese (even when Latin letters are present too); otherwise any ASCII
    letter makes it English; anything else is unknown.
    """
    stripped = re.sub(r'\d+', '', text)
    if not stripped:
        return 'unknown'

    # Order matters: the Chinese check wins for mixed-language text.
    ordered_checks = (
        ('chinese', r'[\u4e00-\u9fff]'),
        ('english', r'[a-zA-Z]'),
    )
    for label, pattern in ordered_checks:
        if re.search(pattern, stripped):
            return label
    return 'unknown'
|
|
class QwenTranslationAPI:
    """Small HTTP client that asks the Qwen service to translate text.

    A translation is started with a POST to ``/api/generate`` and the output is
    collected by repeatedly GET-ing ``/api/generate_provider`` until the
    service reports ``done``.
    """

    def __init__(self, api_url=QWEN_API_URL):
        """Store the service base URL and create a timestamp-based session id."""
        self.api_url = api_url
        # Not sent with any request in this class; kept as a public attribute.
        self.session_id = f"speech_translate_{int(time.time())}"

    def translate(self, text_content, max_retries=3, timeout=120):
        """Translate Chinese text to English, or any other text to Chinese.

        Args:
            text_content: Text to translate.
            max_retries: Number of generate-and-poll attempts.
            timeout: Seconds to keep polling for the result in each attempt.

        Returns:
            The translated text on success. Failures are NOT raised: a
            human-readable message is returned instead (empty input, request
            failure after the last retry, unexpected error, or timeout).
        """
        if not text_content or not text_content.strip():
            return "输入文本为空"

        # Chinese input is translated to English; everything else to Chinese.
        if lang_detect_with_regex(text_content) == 'chinese':
            prompt_f = "翻译成英文"
        else:
            prompt_f = "翻译成中文"

        prompt = f"{prompt_f}:{text_content}"
        print(f"[翻译API] 发送请求: {prompt}")

        generate_url = f"{self.api_url}/api/generate"
        result_url = f"{self.api_url}/api/generate_provider"
        payload = {
            "prompt": prompt,
            "temperature": 0.1,
            "repetition_penalty": 1.0,
            "top-p": 0.9,
            "top-k": 40,
            "max_new_tokens": 512
        }

        for attempt in range(max_retries):
            try:
                print(f"[翻译API] 开始生成请求 (尝试 {attempt + 1}/{max_retries})")
                response = requests.post(generate_url, json=payload, timeout=30)
                response.raise_for_status()
                print("[翻译API] 生成请求成功")

                full_translation = self._poll_translation(result_url, timeout)
                if full_translation is not None:
                    print(f"[翻译API] 翻译完成: {full_translation}")
                    return full_translation

                print(f"[翻译API] 轮询超时,尝试第 {attempt + 1} 次重试")

            except requests.exceptions.RequestException as e:
                print(f"[翻译API] 请求失败 (尝试 {attempt + 1}/{max_retries}): {e}")
                if attempt < max_retries - 1:
                    # Exponential back-off between attempts: 1s, 2s, 4s, ...
                    wait_time = 2 ** attempt
                    print(f"[翻译API] 等待 {wait_time} 秒后重试...")
                    time.sleep(wait_time)
                else:
                    return f"翻译失败: {str(e)}"
            except Exception as e:
                print(f"[翻译API] 翻译过程出错: {e}")
                return f"翻译失败: {str(e)}"

        return "翻译超时,请检查API服务状态"

    def _poll_translation(self, result_url, timeout, poll_interval=0.05, retry_delay=0.5):
        """Collect streamed output chunks until the service reports completion.

        Args:
            result_url: URL to GET for the next chunk.
            timeout: Seconds to keep polling.
            poll_interval: Pause between successful polls.
            retry_delay: Pause after a failed poll, so an unreachable service
                is not hit in a tight loop.

        Returns:
            The concatenated ``response`` chunks once a payload has a truthy
            ``done`` field, or ``None`` if ``timeout`` elapses first.
        """
        deadline = time.time() + timeout
        chunks = []

        while time.time() < deadline:
            try:
                result_response = requests.get(result_url, timeout=10)
                # Treat HTTP error statuses as failed polls rather than data.
                result_response.raise_for_status()
                result_data = result_response.json()
            except requests.exceptions.RequestException as e:
                print(f"[翻译API] 轮询请求失败: {e}")
                time.sleep(retry_delay)
                continue

            # A missing or null chunk contributes nothing.
            chunks.append(result_data.get("response") or "")

            if result_data.get("done", False):
                return "".join(chunks)

            time.sleep(poll_interval)

        return None
|
|
class SpeechTranslationPipeline:
    """Speech-to-speech translation: VAD + ASR, Qwen translation, MeloTTS synthesis."""

    # Factor used to convert z_p frame counts into audio sample offsets.
    # NOTE(review): presumably the decoder's samples-per-frame ratio - confirm
    # against the decoder model.
    _SAMPLES_PER_FRAME = 512

    def __init__(self,
                 tts_model_dir, tts_model_files,
                 asr_model_dir="ax_model", seq_len=132,
                 tts_dec_len=128, sample_rate=44100, tts_speed=0.8,
                 qwen_api_url=QWEN_API_URL):
        """Validate required files, then load the ASR and TTS models.

        Args:
            tts_model_dir: Directory containing the TTS model files.
            tts_model_files: Mapping with "g", "encoder" and "decoder" file names.
            asr_model_dir: Directory containing the ASR/VAD models and tokenizer.
            seq_len: Sequence length passed to the ASR model and used for the
                position encoding.
            tts_dec_len: Number of z_p frames the TTS decoder takes per call.
            sample_rate: Sample rate of the synthesized audio file.
            tts_speed: Speech speed; the encoder gets ``1 / tts_speed`` as its
                length scale.
            qwen_api_url: Base URL of the Qwen translation service.

        Raises:
            FileNotFoundError: If a TTS model file is missing.
        """
        self.tts_model_dir = tts_model_dir
        self.tts_model_files = tts_model_files
        self.asr_model_dir = asr_model_dir
        self.seq_len = seq_len
        self.tts_dec_len = tts_dec_len
        self.sample_rate = sample_rate
        self.tts_speed = tts_speed
        self.qwen_api_url = qwen_api_url

        # Check the files before loading anything, so a missing file is
        # reported with a clear message instead of a loader-specific error.
        self._validate_files()

        self._init_asr_models()
        self._init_tts_models()

        self.translator = QwenTranslationAPI(api_url=qwen_api_url)

    def _init_asr_models(self):
        """Load the VAD model, position encoding, ASR model and tokenizer."""
        print("Initializing SenseVoice models...")

        self.model_vad = AX_Fsmn_vad(self.asr_model_dir)

        # NOTE(review): a random tensor is passed here; presumably only its
        # shape matters to get_position_encoding - confirm.
        self.embed = SinusoidalPositionEncoder()
        self.position_encoding = self.embed.get_position_encoding(
            torch.randn(1, self.seq_len, 560)).numpy()

        self.model_bin = AX_SenseVoiceSmall(self.asr_model_dir, seq_len=self.seq_len)

        tokenizer_path = os.path.join(self.asr_model_dir, "chn_jpn_yue_eng_ko_spectok.bpe.model")
        self.tokenizer = SentencepiecesTokenizer(bpemodel=tokenizer_path)

        print("SenseVoice models initialized successfully.")

    def _init_tts_models(self):
        """Load the TTS encoder/decoder and speaker data, then warm up text processing."""
        print("Initializing MeloTTS models...")
        init_start = time.time()

        enc_model = os.path.join(self.tts_model_dir, self.tts_model_files["encoder"])
        dec_model = os.path.join(self.tts_model_dir, self.tts_model_files["decoder"])

        model_load_start = time.time()
        self.sess_enc = ort.InferenceSession(enc_model, providers=["CPUExecutionProvider"], sess_options=ort.SessionOptions())
        self.sess_dec = axe.InferenceSession(dec_model)
        print(f" Load encoder/decoder models: {(time.time() - model_load_start)*1000:.2f}ms")

        # Raw float32 blob reshaped to (1, 256, 1); fed to both models as "g".
        g_file = os.path.join(self.tts_model_dir, self.tts_model_files["g"])
        self.tts_g = np.fromfile(g_file, dtype=np.float32).reshape(1, 256, 1)

        self.tts_language = "ZH_MIX_EN"
        self.symbol_to_id = {s: i for i, s in enumerate(LANG_TO_SYMBOL_MAP[self.tts_language])}

        print(" Warming up TTS modules (loading language models, tokenizers, etc.)...")

        # Best-effort: run the text front-end once so its one-time setup cost
        # is paid here. A failure is only a warning.
        try:
            warmup_start_mix = time.time()
            warmup_text_mix = "这是一个test测试。"
            get_text_for_tts_infer(warmup_text_mix, self.tts_language, symbol_to_id=self.symbol_to_id)
            print(f" Mixed ZH-EN warm-up: {(time.time() - warmup_start_mix)*1000:.2f}ms")
        except Exception as e:
            print(f" Warning: Mixed warm-up failed: {e}")

        total_init_time = (time.time() - init_start) * 1000
        print(f"MeloTTS models initialized successfully. Total init time: {total_init_time:.2f}ms ({total_init_time/1000:.2f}s)")

    def _validate_files(self):
        """Check that every TTS model file exists and probe the translation API.

        Raises:
            FileNotFoundError: If a TTS model file is missing. An unreachable
                API only prints a warning.
        """
        for filename in self.tts_model_files.values():
            filepath = os.path.join(self.tts_model_dir, filename)
            if not os.path.exists(filepath):
                raise FileNotFoundError(f"TTS模型文件不存在: {filepath}")

        # Connectivity probe only; the response content is not inspected.
        try:
            requests.get(f"{self.qwen_api_url}/api/generate_provider", timeout=5)
            print("[API检查] 千问API服务连接正常")
        except requests.exceptions.RequestException:
            print("[API警告] 无法连接到千问API服务,请确保已启动API服务")

    def speech_recognition(self, speech, fs):
        """Step 1: run VAD, then ASR on each detected segment.

        Args:
            speech: 1-D audio samples.
            fs: Sample rate of ``speech``.

        Returns:
            The recognized text of all segments concatenated and stripped.
            Segments whose recognition raises are logged and skipped.
        """
        speech_lengths = len(speech)

        print("Running VAD...")
        vad_start_time = time.time()
        res_vad = self.model_vad(speech)[0]
        # NOTE(review): 15 * 1000 is presumably a maximum merged segment
        # length in milliseconds - confirm against merge_vad.
        vad_segments = merge_vad(res_vad, 15 * 1000)
        vad_time_cost = time.time() - vad_start_time
        print(f"VAD processing time: {vad_time_cost:.2f} seconds")
        print(f"VAD segments detected: {len(vad_segments)}")

        print("Running ASR...")
        asr_start_time = time.time()
        all_results = ""

        # The ASR model is called with a file path, so each segment is written
        # to a private temporary directory that is removed on exit, even when
        # recognition fails.
        with tempfile.TemporaryDirectory(prefix="asr_segments_") as segment_dir:
            for i, (segment_start, segment_end) in enumerate(vad_segments):
                # Segment bounds are divided by 1000 before scaling by fs,
                # i.e. they are treated as milliseconds.
                start_sample = int(segment_start / 1000 * fs)
                end_sample = min(int(segment_end / 1000 * fs), speech_lengths)
                segment_speech = speech[start_sample:end_sample]

                segment_filename = os.path.join(segment_dir, f"temp_segment_{i}.wav")
                sf.write(segment_filename, segment_speech, fs)

                try:
                    segment_res = self.model_bin(
                        segment_filename,
                        "auto",
                        True,
                        self.position_encoding,
                        tokenizer=self.tokenizer,
                    )
                    all_results += segment_res
                except Exception as e:
                    print(f"Error processing segment {i}: {e}")
                    continue

        asr_time_cost = time.time() - asr_start_time
        print(f"ASR processing time: {asr_time_cost:.2f} seconds")
        print(f"ASR Result: {all_results}")

        return all_results.strip()

    def run_translation(self, text_content):
        """Step 2: translate between Chinese and English through the Qwen API.

        Returns whatever the translator returns.
        NOTE(review): QwenTranslationAPI.translate reports failures as message
        strings rather than exceptions, so such a message would be passed on to
        TTS unchanged.
        """
        print("Starting translation via API...")
        translation_start_time = time.time()

        translate_content = self.translator.translate(text_content)

        translation_time_cost = time.time() - translation_start_time
        print(f"Translation processing time: {translation_time_cost:.2f} seconds")
        print(f"Translation Result: {translate_content}")

        return translate_content

    def run_tts(self, translate_content, output_dir, output_wav=None):
        """Step 3: synthesize ``translate_content`` to a WAV file.

        Args:
            translate_content: Text to speak.
            output_dir: Target directory; ``None`` means the current directory.
                It is created if missing.
            output_wav: Output file name; ``None`` means ``"output.wav"``.

        Returns:
            Path of the written audio file.

        Raises:
            Exception: Any synthesis error is logged with a traceback and
                re-raised unchanged.
        """
        if output_dir is None:
            output_dir = "."
        if output_wav is None:
            output_wav = "output.wav"
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        output_path = os.path.join(output_dir, output_wav)

        try:
            # Spell out Arabic numerals as Chinese numerals for Chinese text.
            if lang_detect_with_regex(translate_content) == "chinese":
                translate_content = cn2an.transform(translate_content, "an2cn")

            print(f"TTS synthesis for text: {translate_content}")

            sens = split_sentence(translate_content, language_str=self.tts_language)
            print(f"Text split into {len(sens)} sentences")

            split_camel_case = self.tts_language in ['EN', 'ZH_MIX_EN']
            audio_list = []

            for n, se in enumerate(sens):
                if split_camel_case:
                    # Insert a space at lower-to-upper case boundaries ("fooBar" -> "foo Bar").
                    se = re.sub(r'([a-z])([A-Z])', r'\1 \2', se)

                print(f"Processing sentence[{n}]: {se}")
                audio_list.append(self._synthesize_sentence(se))

            audio = audio_numpy_concat(audio_list, sr=self.sample_rate, speed=self.tts_speed)

            sf.write(output_path, audio, self.sample_rate)
            print(f"TTS audio saved to {output_path}")

            return output_path

        except Exception as e:
            print(f"TTS synthesis failed: {e}")
            import traceback
            traceback.print_exc()
            raise

    def _synthesize_sentence(self, sentence):
        """Run the encoder once and the decoder chunk by chunk for one sentence.

        Returns the sentence audio as a 1-D array, truncated to the length
        reported by the encoder.
        """
        phones, tones, lang_ids, _, word2ph = get_text_for_tts_infer(
            sentence, self.tts_language, symbol_to_id=self.symbol_to_id)

        encoder_start = time.time()
        z_p, pronoun_lens, audio_len = self.sess_enc.run(None, input_feed={
            'phone': phones, 'g': self.tts_g,
            'tone': tones, 'language': lang_ids,
            'noise_scale': np.array([0], dtype=np.float32),
            'length_scale': np.array([1.0 / self.tts_speed], dtype=np.float32),
            'noise_scale_w': np.array([0], dtype=np.float32),
            'sdp_ratio': np.array([0], dtype=np.float32)})
        print(f"Encoder run time: {1000 * (time.time() - encoder_start):.2f}ms")

        # Per-word frame counts, then word/frame slices that each fit the
        # decoder's fixed input length.
        word2pronoun = calc_word2pronoun(word2ph, pronoun_lens)
        pn_slices, zp_slices = generate_slices(word2pronoun, self.tts_dec_len)

        audio_len = audio_len[0]
        hop = self._SAMPLES_PER_FRAME
        sub_audio_list = []

        for i, (ps, zs) in enumerate(zip(pn_slices, zp_slices)):
            zp_slice = z_p[..., zs]

            sub_dec_len = zp_slice.shape[-1]
            sub_audio_len = hop * sub_dec_len

            # Zero-pad the last axis up to the decoder's fixed input length.
            if sub_dec_len < self.tts_dec_len:
                padding = np.zeros(
                    (*zp_slice.shape[:-1], self.tts_dec_len - sub_dec_len),
                    dtype=np.float32)
                zp_slice = np.concatenate((zp_slice, padding), axis=-1)

            decoder_start = time.time()
            audio = self.sess_dec.run(None, input_feed={"z_p": zp_slice, "g": self.tts_g})[0].flatten()

            # If this chunk overlaps the previous one, drop the audio of its
            # first word.
            audio_start = 0
            if len(sub_audio_list) > 0 and pn_slices[i - 1].stop > ps.start:
                audio_start = hop * word2pronoun[ps.start]

            # If the next chunk overlaps this one, drop the audio of this
            # chunk's last word; the padding beyond sub_audio_len is always
            # dropped.
            audio_end = sub_audio_len
            if i < len(pn_slices) - 1 and ps.stop > pn_slices[i + 1].start:
                audio_end = sub_audio_len - hop * word2pronoun[ps.stop - 1]

            audio = audio[audio_start:audio_end]
            print(f"Decode slice[{i}]: decoder run time {1000 * (time.time() - decoder_start):.2f}ms")
            sub_audio_list.append(audio)

        return merge_sub_audio(sub_audio_list, 0, audio_len)

    def full_pipeline(self, speech, fs, output_dir=None, output_tts=None):
        """Run the full pipeline: speech recognition -> translation -> TTS.

        Args:
            speech: 1-D audio samples.
            fs: Sample rate of ``speech``.
            output_dir: Directory for the synthesized audio (see ``run_tts``).
            output_tts: File name for the synthesized audio (see ``run_tts``).

        Returns:
            Dict with "original_text", "translated_text" and "audio_path".

        Raises:
            ValueError: If ASR produced no text.
        """
        print("\n----------------------VAD+ASR----------------------------\n")
        start_time = time.time()
        text_content = self.speech_recognition(speech, fs)
        asr_time = time.time() - start_time
        print(f"语音识别耗时: {asr_time:.2f} 秒")

        if not text_content or not text_content.strip():
            raise ValueError("ASR未能识别出有效文本")

        print("\n---------------------Qwen翻译---------------------------\n")
        start_time = time.time()
        translate_content = self.run_translation(text_content)
        translate_time = time.time() - start_time
        print(f"翻译耗时: {translate_time:.2f} 秒")

        print("-------------------------TTS-------------------------------\n")
        start_time = time.time()
        output_path = self.run_tts(translate_content, output_dir, output_tts)
        tts_time = time.time() - start_time
        print(f"TTS合成耗时: {tts_time:.2f} 秒")

        return {
            "original_text": text_content,
            "translated_text": translate_content,
            "audio_path": output_path
        }
|
|
def main():
    """Command-line entry point: recognize, translate and re-synthesize one audio file.

    Returns:
        ``0`` when the pipeline completed, ``1`` when it failed. The failure is
        printed with a traceback rather than raised.
    """
    parser = argparse.ArgumentParser(description="Speech Recognition, Translation and TTS Pipeline")
    parser.add_argument("--audio_file", type=str, required=True, help="Input audio file path")
    parser.add_argument("--output_dir", type=str, default="./output", help="Output directory")
    parser.add_argument("--output_tts", type=str, default="output.wav", help="Output TTS file name")
    parser.add_argument("--api_url", type=str, default=QWEN_API_URL, help="Qwen API server URL")

    args = parser.parse_args()
    print("-------------------START------------------------\n")
    os.makedirs(args.output_dir, exist_ok=True)

    print(f"Processing audio file: {args.audio_file}")

    # Load at the native rate, then resample to the 16 kHz used by the pipeline.
    speech, fs = librosa.load(args.audio_file, sr=None)
    if fs != 16000:
        print(f"Resampling audio from {fs}Hz to 16000Hz")
        speech = librosa.resample(y=speech, orig_sr=fs, target_sr=16000)
        fs = 16000
    audio_duration = librosa.get_duration(y=speech, sr=fs)

    pipeline = SpeechTranslationPipeline(
        tts_model_dir=TTS_MODEL_DIR,
        tts_model_files=TTS_MODEL_FILES,
        asr_model_dir="ax_model",
        seq_len=132,
        tts_dec_len=128,
        sample_rate=44100,
        tts_speed=0.8,
        qwen_api_url=args.api_url
    )

    start_time = time.time()
    try:
        result = pipeline.full_pipeline(speech, fs, args.output_dir, args.output_tts)
        # Stop the clock before report writing so it measures the pipeline only.
        time_cost = time.time() - start_time

        print("\n" + "="*50)
        print("speech translate 完成!")
        print("="*50 + "\n")
        print(f"原始音频: {args.audio_file}")
        print(f"原始文本: {result['original_text']}")
        print(f"翻译文本: {result['translated_text']}")
        print(f"生成音频: {result['audio_path']}")

        result_file = os.path.join(args.output_dir, "pipeline_result.txt")
        with open(result_file, 'w', encoding='utf-8') as f:
            f.write(f"原始音频: {args.audio_file}\n")
            f.write(f"识别文本: {result['original_text']}\n")
            f.write(f"翻译结果: {result['translated_text']}\n")
            f.write(f"合成音频: {result['audio_path']}\n")

        print(f"Inference time for {args.audio_file}: {time_cost:.2f} seconds")
        print(f"Audio duration: {audio_duration:.2f} seconds")
        # The real-time factor is undefined for zero-length audio.
        if audio_duration > 0:
            rtf = time_cost / audio_duration
            print(f"RTF: {rtf:.2f}\n")
        else:
            print("RTF: n/a (zero-length audio)\n")
    except Exception as e:
        print(f"Pipeline执行失败: {e}")
        import traceback
        traceback.print_exc()
        return 1

    return 0


if __name__ == "__main__":
    # Pass main()'s status to the shell so a failed run exits non-zero.
    raise SystemExit(main())