| |
| import tempfile |
| import os |
| |
| |
| import time |
| import librosa |
| import torch |
| import argparse |
| import soundfile as sf |
| |
| import cn2an |
| import requests |
| import re |
| import numpy as np |
| import onnxruntime as ort |
| import axengine as axe |
| import threading |
| import queue |
| from collections import deque |
|
|
| |
| from model import SinusoidalPositionEncoder |
| from utils.ax_model_bin import AX_SenseVoiceSmall |
| from utils.ax_vad_bin import AX_Fsmn_vad |
| from utils.vad_utils import merge_vad |
| from funasr.tokenizer.sentencepiece_tokenizer import SentencepiecesTokenizer |
|
|
| |
| from libmelotts.python.split_utils import split_sentence |
| from libmelotts.python.text import cleaned_text_to_sequence |
| from libmelotts.python.text.cleaner import clean_text |
| from libmelotts.python.symbols import LANG_TO_SYMBOL_MAP |
|
|
| |
| |
# Directory holding the MeloTTS model files listed in TTS_MODEL_FILES.
TTS_MODEL_DIR = "libmelotts/models"

# Model file names (relative to TTS_MODEL_DIR):
#   g       -> raw float32 speaker-embedding blob (reshaped to (1, 256, 1) at load)
#   encoder -> ONNX encoder run with onnxruntime on CPU
#   decoder -> axmodel decoder run with axengine
TTS_MODEL_FILES = {
    "g": "g-zh_mix_en.bin",
    "encoder": "encoder-zh.onnx",
    "decoder": "decoder-zh.axmodel",
}

# Default base URL of the Qwen translation HTTP service. Override it with the
# QWEN_API_URL environment variable (or per run with --api_url) instead of
# editing the hard-coded LAN address.
QWEN_API_URL = os.environ.get("QWEN_API_URL", "http://10.126.29.158:8000")
|
|
| |
def intersperse(lst, item):
    """Return a new list with ``item`` before, between and after each element.

    The result has length ``2 * len(lst) + 1``: odd positions hold the
    elements of ``lst`` in order and even positions hold ``item``.
    """
    interleaved = [item]
    for element in lst:
        interleaved.extend((element, item))
    return interleaved
|
|
| |
| |
| |
|
|
| |
| |
| |
|
|
| |
| |
| |
| |
| |
|
|
| |
|
|
| |
# Symbols that are always removed from clean_text() output before encoding.
_TTS_DROPPED_PHONES = frozenset({
    'ɛ', 'æ', 'ʌ', 'ʊ', 'ɔ', 'ɪ', 'ɝ', 'ɚ', 'ɑ',
    'ʒ', 'θ', 'ð', 'ŋ', 'ʃ', 'ʧ', 'ʤ', 'ː', 'ˈ',
    'ˌ', 'ʰ', 'ʲ', 'ʷ', 'ʔ', 'ɾ', 'ɹ', 'ɫ', 'ɡ',
})

# Fallback used when every phone was filtered out: "ni3 hao3" spelled as
# pinyin initials/finals with '_' padding, tones as ints (tone arithmetic in
# cleaned_text_to_sequence needs ints, not strings).
# NOTE(review): symbol names assume the MeloTTS Chinese symbol set - confirm
# against LANG_TO_SYMBOL_MAP; entries missing from symbol_to_id are dropped.
_TTS_FALLBACK_PHONES = ('_', 'n', 'i', 'h', 'ao', '_')
_TTS_FALLBACK_TONES = (0, 3, 3, 3, 3, 0)


def get_text_for_tts_infer(text, language_str, symbol_to_id=None):
    """Turn one sentence into the int32 arrays consumed by the TTS encoder.

    Phones from ``clean_text`` are filtered together with their tones, so the
    two sequences always stay the same length. A phone is dropped if it is in
    ``_TTS_DROPPED_PHONES``, or, when ``symbol_to_id`` is given, if it is
    not a key of it. If filtering changes the phone count, the original
    ``word2ph`` no longer lines up and is replaced by one phone per "word".

    Args:
        text: Sentence to convert.
        language_str: Language key for ``clean_text`` and
            ``cleaned_text_to_sequence`` (e.g. "ZH_MIX_EN").
        symbol_to_id: Phone-symbol -> id mapping. When None, only the fixed
            drop-list is applied and None is forwarded to
            ``cleaned_text_to_sequence``.

    Returns:
        ``(phone, tone, language, norm_text, word2ph)``.
        - ``phone``, ``tone`` and ``language`` are int32 arrays interspersed
          with 0 (length ``2 * k + 1``).
        - ``word2ph`` is an int32 array: the per-word counts doubled, with the
          first entry incremented to cover the leading 0.

    Raises:
        Re-raises any exception from text processing after printing it.
        Raises ValueError if even the fallback phones are unknown.
    """
    try:
        norm_text, phone, tone, word2ph = clean_text(text, language_str)

        processed_phone = []
        processed_tone = []
        removed_symbols = set()
        for p, t in zip(phone, tone):
            unknown = symbol_to_id is not None and p not in symbol_to_id
            if p in _TTS_DROPPED_PHONES or unknown:
                removed_symbols.add(p)
            else:
                processed_phone.append(p)
                processed_tone.append(t)

        if removed_symbols:
            print(f"[音素过滤] 删除了 {len(removed_symbols)} 个特殊音素: {sorted(removed_symbols)}")
            print(f"[音素过滤] 处理后音素序列长度: {len(processed_phone)}")
            print(f"[音素过滤] 处理后音调序列长度: {len(processed_tone)}")

        if not processed_phone:
            print("[警告] 没有有效音素,使用默认中文音素")
            fallback = [
                (p, t) for p, t in zip(_TTS_FALLBACK_PHONES, _TTS_FALLBACK_TONES)
                if symbol_to_id is None or p in symbol_to_id
            ]
            if not fallback:
                raise ValueError("no fallback phone is present in symbol_to_id")
            processed_phone = [p for p, _ in fallback]
            processed_tone = [t for _, t in fallback]
            word2ph = [1] * len(processed_phone)

        if len(processed_phone) != len(phone):
            print(f"[警告] 音素序列长度变化: {len(phone)} -> {len(processed_phone)}")
            # Word boundaries are lost after filtering: treat each phone as a word.
            word2ph = [1] * len(processed_phone)

        phone, tone, language = cleaned_text_to_sequence(
            processed_phone, processed_tone, language_str, symbol_to_id)

        # Blank (0) tokens between and around every symbol.
        phone = np.array(intersperse(phone, 0), dtype=np.int32)
        tone = np.array(intersperse(tone, 0), dtype=np.int32)
        language = np.array(intersperse(language, 0), dtype=np.int32)

        # Each phone now occupies two slots; the extra leading blank belongs
        # to the first word.
        word2ph = np.array(word2ph, dtype=np.int32) * 2
        word2ph[0] += 1

        return phone, tone, language, norm_text, word2ph

    except Exception as e:
        print(f"[错误] 文本处理失败: {e}")
        import traceback
        traceback.print_exc()
        raise
|
|
def audio_numpy_concat(segment_data_list, sr, speed=1.):
    """Concatenate audio segments, appending a short silence after each one.

    Every segment is flattened and followed by ``int((sr * 0.05) / speed)``
    zero samples (50 ms of silence, scaled by the playback speed). The work is
    done with NumPy; samples are not round-tripped through Python lists.

    Args:
        segment_data_list: Iterable of array-like audio segments (any shape;
            each is flattened).
        sr: Sample rate in Hz, used only to size the silence gap.
        speed: Speed factor; the gap is divided by it.

    Returns:
        1-D float32 array (empty when ``segment_data_list`` is empty).
    """
    gap = np.zeros(int((sr * 0.05) / speed), dtype=np.float32)
    pieces = []
    for segment_data in segment_data_list:
        pieces.append(np.asarray(segment_data, dtype=np.float32).reshape(-1))
        pieces.append(gap)
    if not pieces:
        return np.zeros(0, dtype=np.float32)
    return np.concatenate(pieces)
|
|
def merge_sub_audio(sub_audio_list, pad_size, audio_len):
    """Join decoded sub-audio pieces, averaging ``pad_size`` overlapping samples.

    When ``pad_size > 0``, the last ``pad_size`` samples of piece *i* and the
    first ``pad_size`` samples of piece *i + 1* cover the same time span. They
    are averaged into piece *i*'s tail, and that head is dropped from *every*
    following piece, including the last one, so no overlap is duplicated.
    The input arrays are copied first and are not modified.

    Args:
        sub_audio_list: Sequence of 1-D float audio arrays.
        pad_size: Number of overlapping samples between neighbours (0 = none).
        audio_len: Maximum length of the result; longer output is truncated.

    Returns:
        1-D array of at most ``audio_len`` samples (empty float32 array when
        ``sub_audio_list`` is empty).
    """
    if pad_size > 0:
        parts = [np.array(part) for part in sub_audio_list]
        # Pass 1: average every overlap into the earlier piece's tail.
        for i in range(len(parts) - 1):
            parts[i][-pad_size:] += parts[i + 1][:pad_size]
            parts[i][-pad_size:] /= 2
        # Pass 2: drop the now-duplicated head of every later piece.
        for i in range(1, len(parts)):
            parts[i] = parts[i][pad_size:]
    else:
        parts = list(sub_audio_list)

    if not parts:
        return np.zeros(0, dtype=np.float32)
    return np.concatenate(parts, axis=-1)[:audio_len]
|
|
def calc_word2pronoun(word2ph, pronoun_lens):
    """Sum per-phone durations into per-word durations.

    ``word2ph[k]`` is the number of consecutive phones belonging to word k.
    The matching run of ``pronoun_lens`` entries is summed with ``np.sum``.
    The result is a list with one total per word.
    """
    per_word = []
    offset = 0
    for n_phones in word2ph:
        per_word.append(np.sum(pronoun_lens[offset:offset + n_phones]))
        offset += n_phones
    return per_word
|
|
def generate_slices(word2pronoun, dec_len):
    """Split a sentence into decoder windows of at most ``dec_len`` frames.

    Words are packed greedily into windows. When the previous window held
    more than two words and the last two words plus the next word still fit
    in ``dec_len``, the new window starts two words back, so neighbouring
    windows overlap by those two words.

    Args:
        word2pronoun: Per-word frame counts.
        dec_len: Maximum frames per decoder window.

    Returns:
        ``(pn_slices, zp_slices)``: parallel lists of ``slice`` objects, one
        over word indices and one over frame indices.

    Raises:
        ValueError: If a single word is longer than ``dec_len``. Such a word
            can never be placed, and the loop would otherwise spin forever.
    """
    pn_start, pn_end = 0, 0
    zp_start, zp_end = 0, 0
    pn_slices = []
    zp_slices = []
    while pn_end < len(word2pronoun):
        if pn_end - pn_start > 2 and np.sum(word2pronoun[pn_end - 2 : pn_end + 1]) <= dec_len:
            # Re-use the previous window's last two words as left context.
            zp_len = np.sum(word2pronoun[pn_end - 2 : pn_end])
            zp_start = zp_end - zp_len
            pn_start = pn_end - 2
        else:
            zp_len = 0
            zp_start = zp_end
            pn_start = pn_end

        first_new_word = pn_end
        while pn_end < len(word2pronoun) and zp_len + word2pronoun[pn_end] <= dec_len:
            zp_len += word2pronoun[pn_end]
            pn_end += 1
        if pn_end == first_new_word:
            raise ValueError(
                f"word {pn_end} spans {word2pronoun[pn_end]} frames, "
                f"more than dec_len={dec_len}")

        zp_end = zp_start + zp_len
        pn_slices.append(slice(pn_start, pn_end))
        zp_slices.append(slice(zp_start, zp_end))
    return pn_slices, zp_slices
|
|
| |
def lang_detect_with_regex(text):
    """Classify text as 'chinese', 'english' or 'unknown' with simple regexes.

    Digits are ignored. Any CJK unified ideograph (U+4E00..U+9FFF) makes the
    text 'chinese'. Otherwise any ASCII letter makes it 'english'. Everything
    else, including empty or digits-only text, is 'unknown'.
    """
    letters_only = re.sub(r'\d+', '', text)
    if not letters_only:
        return 'unknown'
    if re.search(r'[\u4e00-\u9fff]', letters_only):
        return 'chinese'
    if re.search(r'[a-zA-Z]', letters_only):
        return 'english'
    return 'unknown'
|
|
class QwenTranslationAPI:
    """HTTP client for the Qwen service used for Chinese <-> English translation.

    Endpoints used (relative to ``api_url``):
      * POST ``/api/generate``          - start a generation for a prompt
      * GET  ``/api/generate_provider`` - poll for JSON with a ``"response"``
        text chunk and a ``"done"`` flag
      * POST ``/api/reset``             - clear the server-side context
    """

    def __init__(self, api_url=QWEN_API_URL):
        self.api_url = api_url
        self.session_id = f"speech_translate_{int(time.time())}"
        self.last_reset_time = time.time()
        self.request_count = 0
        # Reset the server context after this many successful translations...
        self.max_requests_before_reset = 10
        # ...or once this many seconds have passed since the last reset.
        self.reset_interval = 120

    def reset_context(self):
        """Ask the server to reset its context.

        Returns:
            True on HTTP 200 (the local counter and timer are also cleared).
            False on any other status code or on an exception.
        """
        try:
            reset_url = f"{self.api_url}/api/reset"
            response = requests.post(reset_url, json={}, timeout=5)
            if response.status_code == 200:
                print("[翻译API] ✓ 上下文重置成功")
                self.last_reset_time = time.time()
                self.request_count = 0
                return True
            print(f"[翻译API] ✗ 重置失败,状态码: {response.status_code}, 响应: {response.text}")
        except Exception as e:
            print(f"[翻译API] ✗ 重置上下文失败: {e}")
        return False

    def check_and_reset_if_needed(self):
        """Reset the context if the request count or elapsed time limit is hit.

        Returns:
            True if no reset was needed, otherwise the result of
            ``reset_context()``.
        """
        elapsed = time.time() - self.last_reset_time
        if (self.request_count >= self.max_requests_before_reset
                or elapsed > self.reset_interval):
            print(f"[翻译API] 重置 (请求数: {self.request_count}, 时间: {elapsed:.1f}秒)")
            return self.reset_context()
        return True

    def _poll_generation(self, timeout):
        """Poll ``/api/generate_provider`` until done, an error, or ``timeout``.

        Chunks from the ``"response"`` field are appended together. A chunk
        containing "error:" or "setkvcache failed" triggers a context reset.

        Returns:
            ``(translation, error_detected)``. ``translation`` is the
            accumulated text when the server reported ``done`` with non-blank
            text, otherwise None.
        """
        result_url = f"{self.api_url}/api/generate_provider"
        start_time = time.time()
        full_translation = ""
        while time.time() - start_time < timeout:
            try:
                result_data = requests.get(result_url, timeout=10).json()
            except requests.exceptions.RequestException as e:
                print(f"[翻译API] 轮询请求失败: {e}")
                time.sleep(0.5)
                continue

            # A JSON null "response" would otherwise crash .lower().
            current_chunk = result_data.get("response") or ""
            lowered = current_chunk.lower()
            if "error:" in lowered or "setkvcache failed" in lowered:
                print(f"[翻译API] ✗ 检测到错误: {current_chunk}")
                print("[翻译API] 立即重置上下文...")
                self.reset_context()
                return None, True

            full_translation += current_chunk
            if result_data.get("done", False):
                if full_translation.strip():
                    return full_translation, False
                print("[翻译API] ✗ 翻译结果为空")
                return None, False

            time.sleep(0.05)
        return None, False

    def translate(self, text_content, max_retries=3, timeout=120):
        """Translate Chinese text to English, or any other text to Chinese.

        Args:
            text_content: Text to translate.
            max_retries: Number of generate+poll attempts.
            timeout: Seconds to poll for each attempt.

        Returns:
            The translation. The input is returned unchanged when it is
            shorter than 3 characters or when all attempts fail.
            "输入文本为空" is returned for empty or blank input.
        """
        if not text_content or not text_content.strip():
            return "输入文本为空"

        if len(text_content.strip()) < 3:
            return text_content

        if lang_detect_with_regex(text_content) == 'chinese':
            prompt_f = "翻译成英文"
        else:
            prompt_f = "翻译成中文"

        prompt = f"{prompt_f}:{text_content}"
        print(f"[翻译API] 发送请求: {prompt}")

        self.check_and_reset_if_needed()

        generate_url = f"{self.api_url}/api/generate"
        payload = {
            "prompt": prompt,
            "temperature": 0.1,
            "repetition_penalty": 1.0,
            "top-p": 0.9,
            "top-k": 40,
            "max_new_tokens": 512
        }

        for attempt in range(max_retries):
            try:
                print(f"[翻译API] 开始生成请求 (尝试 {attempt + 1}/{max_retries})")
                response = requests.post(generate_url, json=payload, timeout=30)
                response.raise_for_status()
                print("[翻译API] 生成请求成功")

                translation, error_detected = self._poll_generation(timeout)
                if translation is not None:
                    self.request_count += 1
                    print(f"[翻译API] ✓ 翻译完成: {translation}")
                    return translation

                if error_detected:
                    if attempt < max_retries - 1:
                        wait_time = 1
                        print(f"[翻译API] 等待 {wait_time} 秒后重试...")
                        time.sleep(wait_time)
                        continue
                    print("[翻译API] 达到最大重试次数,返回原文")
                    return text_content

                print(f"[翻译API] 轮询超时,尝试第 {attempt + 1} 次重试")

            except requests.exceptions.RequestException as e:
                print(f"[翻译API] 请求失败 (尝试 {attempt + 1}/{max_retries}): {e}")
                if attempt < max_retries - 1:
                    wait_time = 2 ** attempt  # exponential backoff
                    print(f"[翻译API] 等待 {wait_time} 秒后重试...")
                    time.sleep(wait_time)
                else:
                    return text_content
            except Exception as e:
                print(f"[翻译API] 翻译过程出错: {e}")
                if attempt < max_retries - 1:
                    time.sleep(1)
                    continue
                return text_content

        print("[翻译API] 翻译超时,返回原文")
        return text_content
|
|
class AudioResampler:
    """Resample audio to a fixed target rate (16 kHz by default, for ASR)."""

    def __init__(self, target_sr=16000):
        self.target_sr = target_sr

    def resample_audio(self, audio_data, original_sr):
        """Resample a whole signal with librosa.

        The input object itself is returned when the rate already matches.
        """
        if original_sr == self.target_sr:
            return audio_data

        print(f"[重采样] {original_sr}Hz -> {self.target_sr}Hz")
        return librosa.resample(y=audio_data, orig_sr=original_sr, target_sr=self.target_sr)

    def resample_chunk(self, audio_chunk, original_sr):
        """Resample one streaming chunk.

        Chunks shorter than 1000 samples use cheap linear interpolation;
        longer chunks go through librosa.
        """
        if original_sr == self.target_sr:
            return audio_chunk

        if len(audio_chunk) < 1000:
            return self._linear_resample(audio_chunk, original_sr, self.target_sr)
        return librosa.resample(y=audio_chunk, orig_sr=original_sr, target_sr=self.target_sr)

    def _linear_resample(self, audio_chunk, original_sr, target_sr):
        """Resample by linear interpolation over evenly spaced positions.

        Output length is ``int(len(audio_chunk) * target_sr / original_sr)``.
        A floating input dtype (e.g. float32) is preserved. ``np.interp``
        always returns float64, which would otherwise upcast the stream.
        Empty input yields an empty array instead of an ``np.interp`` error.
        """
        samples = np.asarray(audio_chunk)
        out_dtype = samples.dtype if np.issubdtype(samples.dtype, np.floating) else np.float64

        ratio = target_sr / original_sr
        old_length = len(samples)
        new_length = int(old_length * ratio)
        if old_length == 0 or new_length == 0:
            return np.zeros(0, dtype=out_dtype)

        old_indices = np.arange(old_length)
        new_indices = np.linspace(0, old_length - 1, new_length)
        resampled = np.interp(new_indices, old_indices, samples)
        return resampled.astype(out_dtype, copy=False)
|
|
class StreamProcessor:
    """Background worker: windowed VAD + ASR -> translation -> TTS.

    Audio added with ``add_audio_chunk`` is appended to a deque and
    accumulated by a daemon thread. Each window of ``chunk_duration``
    seconds is processed in turn:
      1. VAD segments the window.
      2. ``pipeline.model_bin`` transcribes each segment.
      3. ``pipeline.run_translation`` translates new text.
      4. ``pipeline.run_tts`` synthesizes the translation.
    Consecutive windows share ``overlap_duration`` seconds. Completed segments
    are put on ``result_queue`` as dicts and read with ``get_next_result``.
    """

    def __init__(self, pipeline, chunk_duration=7.0, overlap_duration=0.01, target_sr=16000):
        self.pipeline = pipeline
        self.chunk_duration = chunk_duration
        self.overlap_duration = overlap_duration
        self.target_sr = target_sr
        self.chunk_samples = int(chunk_duration * target_sr)
        self.overlap_samples = int(overlap_duration * target_sr)
        # With overlap >= window the buffer would never shrink, and the same
        # audio would be reprocessed forever.
        if self.chunk_samples <= 0 or not 0 <= self.overlap_samples < self.chunk_samples:
            raise ValueError(
                f"need 0 <= overlap < chunk, got chunk={self.chunk_samples} "
                f"overlap={self.overlap_samples} samples")
        # deque append/popleft are thread-safe, so the producer and worker
        # threads share it without a lock.
        self.audio_buffer = deque()
        self.result_queue = queue.Queue()
        self.is_running = False
        self.processing_thread = None
        self.resampler = AudioResampler(target_sr=target_sr)
        self.segment_counter = 0
        # Every transcript already handled; identical text is never re-sent.
        self.processed_texts = set()

    def start_processing(self):
        """Start the daemon processing thread."""
        self.is_running = True
        self.processing_thread = threading.Thread(target=self._process_loop)
        self.processing_thread.daemon = True
        self.processing_thread.start()

    def stop_processing(self):
        """Signal the worker to stop and wait up to 5 s for it to finish."""
        self.is_running = False
        if self.processing_thread:
            self.processing_thread.join(timeout=5)

    def add_audio_chunk(self, audio_chunk, original_sr=None):
        """Queue an audio chunk, resampling it first if ``original_sr`` differs."""
        if original_sr and original_sr != self.target_sr:
            audio_chunk = self.resampler.resample_chunk(audio_chunk, original_sr)

        self.audio_buffer.append(audio_chunk)

    def get_next_result(self, timeout=1.0):
        """Return the next result dict, or None if none arrives within ``timeout`` s."""
        try:
            return self.result_queue.get(timeout=timeout)
        except queue.Empty:
            return None

    def _process_loop(self):
        """Worker loop: accumulate audio and process one full window per pass."""
        accumulated_audio = np.array([], dtype=np.float32)
        last_asr_result = ""

        while self.is_running:
            if self.audio_buffer:
                accumulated_audio = np.concatenate([accumulated_audio, self.audio_buffer.popleft()])

            # Checked on every pass, not only when new audio arrives, so any
            # backlog of full windows is drained.
            if len(accumulated_audio) >= self.chunk_samples:
                process_chunk = accumulated_audio[:self.chunk_samples]
                accumulated_audio = accumulated_audio[self.chunk_samples - self.overlap_samples:]
                try:
                    last_asr_result = self._handle_window(process_chunk, last_asr_result)
                except Exception as e:
                    print(f"[流式处理错误] {e}")
                    import traceback
                    traceback.print_exc()

            time.sleep(0.01)

    def _handle_window(self, process_chunk, last_asr_result):
        """Transcribe one window and translate + synthesize any new text.

        Returns the value ``last_asr_result`` should hold afterwards.
        """
        asr_result = self._stream_asr(process_chunk)

        # Silence or no speech: nothing to report.
        if not asr_result or not asr_result.strip():
            return last_asr_result
        if asr_result == last_asr_result:
            print(f"[实时ASR] 重复内容已跳过: {asr_result}")
            return last_asr_result
        if asr_result in self.processed_texts:
            return last_asr_result

        print(f"[实时ASR] {asr_result}")
        self.processed_texts.add(asr_result)

        try:
            translation_result = self.pipeline.run_translation(asr_result)
            if self._is_valid_translation(asr_result, translation_result):
                print(f"[实时翻译] {translation_result}")
                self._synthesize_and_publish(asr_result, translation_result)
            else:
                print("[实时翻译] 翻译结果无效,已跳过")
        except Exception as translation_error:
            print(f"[实时翻译错误] {translation_error}")

        return asr_result

    @staticmethod
    def _is_valid_translation(asr_result, translation_result):
        """Reject empty or echoed translations and known error/sentinel strings."""
        return bool(
            translation_result
            and translation_result != asr_result
            and "翻译失败" not in translation_result
            and "error:" not in translation_result.lower()
            and "输入文本为空" not in translation_result
        )

    def _synthesize_and_publish(self, asr_result, translation_result):
        """Run TTS for one translation and queue a 'complete' result.

        TTS errors are logged and swallowed, so the stream keeps running.
        """
        try:
            self.segment_counter += 1
            tts_filename = f"stream_segment_{self.segment_counter:04d}.wav"
            tts_start_time = time.time()

            tts_path = self.pipeline.run_tts(
                translation_result,
                self.pipeline.output_dir,
                tts_filename
            )

            tts_time = time.time() - tts_start_time
            print(f"[实时TTS] 音频已保存: {tts_path} (耗时: {tts_time:.2f}秒)")

            self.result_queue.put({
                'type': 'complete',
                'original': asr_result,
                'translated': translation_result,
                'audio_path': tts_path,
                'timestamp': time.time(),
                'segment_id': self.segment_counter
            })
        except Exception as tts_error:
            print(f"[实时TTS错误] {tts_error}")
            import traceback
            traceback.print_exc()

    def _stream_asr(self, audio_chunk):
        """Run VAD, then ASR on each segment of at least 0.3 s.

        Returns the segment transcripts joined by spaces, or "" when there is
        no speech or on error.
        """
        try:
            res_vad = self.pipeline.model_vad(audio_chunk)[0]
            vad_segments = merge_vad(res_vad, 15 * 1000)

            if not vad_segments:
                print("[VAD] 未检测到语音活动,跳过此音频块")
                return ""

            print(f"[VAD] 检测到 {len(vad_segments)} 个语音段")

            min_samples = int(0.3 * self.target_sr)
            texts = []
            for i, (segment_start, segment_end) in enumerate(vad_segments):
                # VAD boundaries are in milliseconds.
                start_sample = int(segment_start / 1000 * self.target_sr)
                end_sample = min(int(segment_end / 1000 * self.target_sr), len(audio_chunk))
                segment_audio = audio_chunk[start_sample:end_sample]

                if len(segment_audio) < min_samples:
                    continue

                try:
                    segment_result = self._recognize_segment(segment_audio)
                except Exception as e:
                    print(f"[ASR错误] 处理VAD段 {i} 时出错: {e}")
                    continue

                if segment_result and segment_result.strip():
                    texts.append(segment_result)

            return " ".join(texts).strip()

        except Exception as e:
            print(f"[ASR错误] {e}")
            return ""

    def _recognize_segment(self, segment_audio):
        """Write one segment to a temporary WAV file and run SenseVoice on it.

        The temp file is always removed, even if writing or recognition fails.
        """
        with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as temp_file:
            temp_filename = temp_file.name
        try:
            sf.write(temp_filename, segment_audio, self.target_sr)
            return self.pipeline.model_bin(
                temp_filename,
                "auto",
                True,
                self.pipeline.position_encoding,
                tokenizer=self.pipeline.tokenizer,
            )
        finally:
            if os.path.exists(temp_filename):
                os.unlink(temp_filename)
|
|
class SpeechTranslationPipeline:
    """Speech -> text -> translation -> speech pipeline.

    Components:
      * ASR: SenseVoice (axengine) with FSMN VAD.
      * Translation: remote Qwen HTTP API.
      * TTS: MeloTTS, with an ONNX encoder and an axmodel decoder.
    Also owns a ``StreamProcessor`` for chunked, threaded processing.
    """

    def __init__(self,
                 tts_model_dir, tts_model_files,
                 asr_model_dir="ax_model", seq_len=132,
                 tts_dec_len=128, sample_rate=44100, tts_speed=0.8,
                 qwen_api_url=QWEN_API_URL, target_sr=16000,
                 output_dir="./output",
                 chunk_duration=7.0, overlap_duration=0.01):
        """Load all models and connect to the translation service.

        Args:
            tts_model_dir: Directory containing the TTS model files.
            tts_model_files: Dict with keys "g", "encoder", "decoder".
            asr_model_dir: Directory of the SenseVoice/VAD models + tokenizer.
            seq_len: ASR sequence length used for the position encoding.
            tts_dec_len: Decoder window length (frames).
            sample_rate: Output sample rate of the TTS audio.
            tts_speed: Speech speed (encoder length scale = 1 / speed).
            qwen_api_url: Base URL of the Qwen translation service.
            target_sr: ASR input sample rate.
            output_dir: Where TTS WAV files are written (created if missing).
            chunk_duration: Stream window length in seconds.
            overlap_duration: Overlap between stream windows in seconds.

        Raises:
            FileNotFoundError: If a TTS model file is missing. This is checked
                before any model is loaded.
        """
        self.tts_model_dir = tts_model_dir
        self.tts_model_files = tts_model_files
        self.asr_model_dir = asr_model_dir
        self.seq_len = seq_len
        self.tts_dec_len = tts_dec_len
        self.sample_rate = sample_rate
        self.tts_speed = tts_speed
        self.qwen_api_url = qwen_api_url
        self.target_sr = target_sr
        self.output_dir = output_dir

        os.makedirs(self.output_dir, exist_ok=True)

        # Fail fast with a clear message before spending time loading models.
        self._validate_files()

        self.resampler = AudioResampler(target_sr=target_sr)
        self._init_asr_models()
        self._init_tts_models()
        self.translator = QwenTranslationAPI(api_url=qwen_api_url)
        self.stream_processor = StreamProcessor(
            self,
            chunk_duration=chunk_duration,
            overlap_duration=overlap_duration,
            target_sr=target_sr,
        )

        print("[初始化] 重置API上下文...")
        self.translator.reset_context()

    def _init_asr_models(self):
        """Load VAD, SenseVoice, the position encoding and the BPE tokenizer."""
        print("Initializing SenseVoice models...")

        self.model_vad = AX_Fsmn_vad(self.asr_model_dir)

        self.embed = SinusoidalPositionEncoder()
        self.position_encoding = self.embed.get_position_encoding(
            torch.randn(1, self.seq_len, 560)).numpy()

        self.model_bin = AX_SenseVoiceSmall(self.asr_model_dir, seq_len=self.seq_len)

        tokenizer_path = os.path.join(self.asr_model_dir, "chn_jpn_yue_eng_ko_spectok.bpe.model")
        self.tokenizer = SentencepiecesTokenizer(bpemodel=tokenizer_path)

        print("SenseVoice models initialized successfully.")

    def _init_tts_models(self):
        """Load the TTS encoder/decoder, the g-vector and the symbol table; warm up."""
        print("Initializing MeloTTS models...")
        init_start = time.time()

        enc_model = os.path.join(self.tts_model_dir, self.tts_model_files["encoder"])
        dec_model = os.path.join(self.tts_model_dir, self.tts_model_files["decoder"])

        model_load_start = time.time()
        self.sess_enc = ort.InferenceSession(enc_model, providers=["CPUExecutionProvider"], sess_options=ort.SessionOptions())
        self.sess_dec = axe.InferenceSession(dec_model)
        print(f" Load encoder/decoder models: {(time.time() - model_load_start)*1000:.2f}ms")

        g_file = os.path.join(self.tts_model_dir, self.tts_model_files["g"])
        self.tts_g = np.fromfile(g_file, dtype=np.float32).reshape(1, 256, 1)

        self.tts_language = "ZH_MIX_EN"
        self.symbol_to_id = {s: i for i, s in enumerate(LANG_TO_SYMBOL_MAP[self.tts_language])}

        print(" Warming up TTS modules...")
        warmup_start = time.time()

        try:
            warmup_text_mix = "这是一个test测试。"
            get_text_for_tts_infer(warmup_text_mix, self.tts_language, symbol_to_id=self.symbol_to_id)
            print(f" Mixed ZH-EN warm-up: {(time.time() - warmup_start)*1000:.2f}ms")
        except Exception as e:
            print(f" Warning: Mixed warm-up failed: {e}")

        total_init_time = (time.time() - init_start) * 1000
        print(f"MeloTTS models initialized successfully. Total init time: {total_init_time:.2f}ms")

    def _validate_files(self):
        """Check that the TTS model files exist and probe the translation API.

        Raises:
            FileNotFoundError: If a TTS model file is missing. An unreachable
                API only prints a warning.
        """
        for filename in self.tts_model_files.values():
            filepath = os.path.join(self.tts_model_dir, filename)
            if not os.path.exists(filepath):
                raise FileNotFoundError(f"TTS模型文件不存在: {filepath}")

        try:
            requests.get(f"{self.qwen_api_url}/api/generate_provider", timeout=5)
            print("[API检查] 千问API服务连接正常")
        except requests.exceptions.RequestException:
            print("[API警告] 无法连接到千问API服务,请确保已启动API服务")

    def start_stream_processing(self):
        """Start the background stream worker."""
        self.stream_processor.start_processing()
        print("[流式处理] 已启动")

    def stop_stream_processing(self):
        """Stop the background stream worker."""
        self.stream_processor.stop_processing()
        print("[流式处理] 已停止")

    def process_audio_stream(self, audio_chunk, original_sr=None):
        """Feed one audio chunk to the stream worker."""
        self.stream_processor.add_audio_chunk(audio_chunk, original_sr)

    def get_stream_results(self):
        """Return the next stream result dict, or None after a ~1 s wait."""
        return self.stream_processor.get_next_result()

    def load_and_resample_audio(self, audio_file):
        """Load an audio file at its native rate and resample it to ``target_sr``.

        Returns:
            ``(samples, self.target_sr)``.
        """
        print(f"加载音频文件: {audio_file}")
        speech, original_sr = librosa.load(audio_file, sr=None)

        audio_duration = len(speech) / original_sr
        print(f"原始音频: {original_sr}Hz, 时长: {audio_duration:.2f}秒")

        if original_sr != self.target_sr:
            speech = self.resampler.resample_audio(speech, original_sr)
            print(f"重采样后: {self.target_sr}Hz, 时长: {len(speech)/self.target_sr:.2f}秒")

        return speech, self.target_sr

    def run_translation(self, text_content):
        """Translate ``text_content`` between Chinese and English via the Qwen API."""
        print("Starting translation via API...")
        translation_start_time = time.time()

        translate_content = self.translator.translate(text_content)

        translation_time_cost = time.time() - translation_start_time
        print(f"Translation processing time: {translation_time_cost:.2f} seconds")
        print(f"Translation Result: {translate_content}")

        return translate_content

    def run_tts(self, translate_content, output_dir, output_wav=None):
        """Synthesize ``translate_content`` and write it as a WAV file.

        Args:
            translate_content: Text to speak. For Chinese text, Arabic
                numerals are first converted to Chinese numerals.
            output_dir: Directory for the WAV file.
            output_wav: File name. When None, a millisecond-timestamped name
                is generated (passing None used to crash in os.path.join).

        Returns:
            Path of the written WAV file.

        Raises:
            Re-raises any synthesis error after printing it.
        """
        if output_wav is None:
            output_wav = f"tts_{int(time.time() * 1000)}.wav"
        output_path = os.path.join(output_dir, output_wav)

        try:
            if lang_detect_with_regex(translate_content) == "chinese":
                translate_content = cn2an.transform(translate_content, "an2cn")

            print(f"TTS synthesis for text: {translate_content}")

            sens = split_sentence(translate_content, language_str=self.tts_language)
            print(f"Text split into {len(sens)} sentences")

            audio_list = [self._synthesize_sentence(n, se) for n, se in enumerate(sens)]

            audio = audio_numpy_concat(audio_list, sr=self.sample_rate, speed=self.tts_speed)

            sf.write(output_path, audio, self.sample_rate)
            print(f"TTS audio saved to {output_path}")

            return output_path

        except Exception as e:
            print(f"TTS synthesis failed: {e}")
            import traceback
            traceback.print_exc()
            raise

    def _synthesize_sentence(self, n, se):
        """Encode one sentence, decode it window by window, and return its audio.

        Each z_p frame maps to 512 output samples (the factor used for
        ``sub_audio_len``). Where neighbouring windows overlap by whole words,
        the shared first/last word is cut from the later/earlier window.
        """
        if self.tts_language in ['EN', 'ZH_MIX_EN']:
            # Split camelCase words so they are pronounced separately.
            se = re.sub(r'([a-z])([A-Z])', r'\1 \2', se)

        print(f"Processing sentence[{n}]: {se}")

        phones, tones, lang_ids, norm_text, word2ph = get_text_for_tts_infer(
            se, self.tts_language, symbol_to_id=self.symbol_to_id)

        encoder_start = time.time()
        z_p, pronoun_lens, audio_len = self.sess_enc.run(None, input_feed={
            'phone': phones, 'g': self.tts_g,
            'tone': tones, 'language': lang_ids,
            'noise_scale': np.array([0], dtype=np.float32),
            'length_scale': np.array([1.0 / self.tts_speed], dtype=np.float32),
            'noise_scale_w': np.array([0], dtype=np.float32),
            'sdp_ratio': np.array([0], dtype=np.float32)})
        print(f"Encoder run time: {1000 * (time.time() - encoder_start):.2f}ms")

        word2pronoun = calc_word2pronoun(word2ph, pronoun_lens)
        pn_slices, zp_slices = generate_slices(word2pronoun, self.tts_dec_len)

        audio_len = audio_len[0]
        sub_audio_list = []

        for i, (ps, zs) in enumerate(zip(pn_slices, zp_slices)):
            zp_slice = z_p[..., zs]

            sub_dec_len = zp_slice.shape[-1]
            sub_audio_len = 512 * sub_dec_len

            # The decoder takes fixed-length input: zero-pad short windows.
            if zp_slice.shape[-1] < self.tts_dec_len:
                zp_slice = np.concatenate((zp_slice, np.zeros((*zp_slice.shape[:-1], self.tts_dec_len - zp_slice.shape[-1]), dtype=np.float32)), axis=-1)

            decoder_start = time.time()
            audio = self.sess_dec.run(None, input_feed={"z_p": zp_slice, "g": self.tts_g})[0].flatten()

            audio_start = 0
            if len(sub_audio_list) > 0 and pn_slices[i - 1].stop > ps.start:
                # Overlaps the previous window: drop its first word.
                audio_start = 512 * word2pronoun[ps.start]

            audio_end = sub_audio_len
            if i < len(pn_slices) - 1 and ps.stop > pn_slices[i + 1].start:
                # Overlaps the next window: drop its last word.
                audio_end = sub_audio_len - 512 * word2pronoun[ps.stop - 1]

            audio = audio[audio_start:audio_end]
            print(f"Decode slice[{i}]: decoder run time {1000 * (time.time() - decoder_start):.2f}ms")
            sub_audio_list.append(audio)

        return merge_sub_audio(sub_audio_list, 0, audio_len)

    @staticmethod
    def _print_stream_result(result, index):
        """Pretty-print one stream result dict."""
        print(f"\n{'='*70}")
        print(f"[实时结果 #{index}]")
        print(f"段落ID: {result['segment_id']}")
        print(f"原文: {result['original']}")
        print(f"翻译: {result['translated']}")
        print(f"音频: {result['audio_path']}")
        print(f"{'='*70}")

    def process_long_audio_stream(self, audio_file, chunk_size=64000):
        """Simulate a live stream by feeding a file to the worker in chunks.

        Args:
            audio_file: Path to the input audio.
            chunk_size: Samples per fed chunk. The default of 64000 is 4 s at
                16 kHz; ``main`` passes ``chunk_duration * target_sr``. The
                last chunk is zero-padded to full size.

        Returns:
            List of result dicts collected from the worker. Collection ends
            after 20 s without a new result. The worker is always stopped,
            even if an error occurs.
        """
        print(f"[流式处理] 开始处理长音频: {audio_file}")

        speech, fs = self.load_and_resample_audio(audio_file)

        self.start_stream_processing()
        all_results = []
        try:
            total_chunks = (len(speech) + chunk_size - 1) // chunk_size
            print(f"[流式处理] 音频总长度: {len(speech)/fs:.2f}秒, 分块数: {total_chunks}")

            for chunk_count, i in enumerate(range(0, len(speech), chunk_size), start=1):
                chunk = speech[i:i+chunk_size]

                if len(chunk) < chunk_size:
                    padding_size = chunk_size - len(chunk)
                    chunk = np.concatenate([chunk, np.zeros(padding_size, dtype=np.float32)])
                    print(f"\n[流式处理] 处理音频块 {chunk_count}/{total_chunks} (最后一块,已填零 {padding_size} 样本)")
                else:
                    print(f"\n[流式处理] 处理音频块 {chunk_count}/{total_chunks}")

                self.process_audio_stream(chunk, fs)

                result = self.get_stream_results()
                while result:
                    self._print_stream_result(result, len(all_results) + 1)
                    all_results.append(result)
                    result = self.get_stream_results()

                time.sleep(0.1)

            # Keep collecting until no new result arrives for max_wait_time seconds.
            max_wait_time = 20
            wait_start = time.time()
            while time.time() - wait_start < max_wait_time:
                result = self.get_stream_results()
                if result:
                    self._print_stream_result(result, len(all_results) + 1)
                    all_results.append(result)
                    wait_start = time.time()
                else:
                    time.sleep(0.2)
        finally:
            self.stop_stream_processing()

        print(f"\n[流式处理] 完成!共处理 {len(all_results)} 个有效结果")
        return all_results
|
|
def _save_stream_results(result_file, args, results):
    """Write the stream results to ``result_file`` as a UTF-8 text report."""
    with open(result_file, 'w', encoding='utf-8') as f:
        f.write(f"流式翻译+TTS结果 - {args.audio_file}\n")
        f.write(f"处理时间: {time.strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write(f"音频块时长: {args.chunk_duration}秒, 重叠时长: {args.overlap_duration}秒\n")
        f.write("="*70 + "\n\n")
        for idx, result in enumerate(results, 1):
            f.write(f"【段落 {idx}】(ID: {result['segment_id']})\n")
            f.write(f"原文: {result['original']}\n")
            f.write(f"译文: {result['translated']}\n")
            f.write(f"音频: {result['audio_path']}\n")
            f.write(f"时间: {time.strftime('%H:%M:%S', time.localtime(result['timestamp']))}\n")
            f.write("\n" + "-"*70 + "\n\n")


def main():
    """CLI entry point: stream one audio file through ASR -> translation -> TTS."""
    parser = argparse.ArgumentParser(description="实时语音翻译pipeline")
    parser.add_argument("--audio_file", type=str, required=True, help="输入音频文件路径")
    parser.add_argument("--output_dir", type=str, default="./output", help="输出目录")
    parser.add_argument("--api_url", type=str, default=QWEN_API_URL, help="Qwen API服务器URL")
    parser.add_argument("--target_sr", type=int, default=16000, help="ASR目标采样率 (默认: 16000)")
    parser.add_argument("--chunk_duration", type=float, default=7.0, help="音频块时长(秒) (默认: 7.0)")
    parser.add_argument("--overlap_duration", type=float, default=0.01, help="重叠时长(秒) (默认: 0.01)")

    args = parser.parse_args()
    print("-------------------实时语音翻译pipeline-------------------\n")
    os.makedirs(args.output_dir, exist_ok=True)

    print(f"处理音频文件: {args.audio_file}")
    print(f"输出目录: {args.output_dir}")
    print(f"音频块时长: {args.chunk_duration}秒")
    print(f"重叠时长: {args.overlap_duration}秒\n")

    pipeline = SpeechTranslationPipeline(
        tts_model_dir=TTS_MODEL_DIR,
        tts_model_files=TTS_MODEL_FILES,
        asr_model_dir="ax_model",
        seq_len=132,
        tts_dec_len=128,
        sample_rate=44100,
        tts_speed=0.8,
        qwen_api_url=args.api_url,
        target_sr=args.target_sr,
        output_dir=args.output_dir
    )
    # The pipeline builds its stream worker with default window sizes.
    # Rebuild it (before it is started) so --chunk_duration and
    # --overlap_duration actually take effect.
    pipeline.stream_processor = StreamProcessor(
        pipeline,
        chunk_duration=args.chunk_duration,
        overlap_duration=args.overlap_duration,
        target_sr=args.target_sr,
    )

    start_time = time.time()
    try:
        print("使用流式处理模式(包含TTS)...")
        print("="*70 + "\n")

        # Feed chunks the same size as the worker's window.
        chunk_size = int(args.chunk_duration * args.target_sr)
        results = pipeline.process_long_audio_stream(args.audio_file, chunk_size=chunk_size)

        print("\n" + "="*70)
        print(" 处理完成")
        print("="*70)
        print(f"\n 成功处理 {len(results)} 个有效翻译段落\n")

        if results:
            print("所有翻译结果:")
            print("-" * 70)
            for idx, result in enumerate(results, 1):
                print(f"\n【段落 {idx}】(ID: {result['segment_id']})")
                print(f" 原文: {result['original']}")
                print(f" 译文: {result['translated']}")
                print(f" 音频: {result['audio_path']}")
                print(f" 时间: {time.strftime('%H:%M:%S', time.localtime(result['timestamp']))}")
                print("-" * 70)

            result_file = os.path.join(args.output_dir, "stream_results.txt")
            _save_stream_results(result_file, args, results)
            print(f"\n✓ 结果已保存到: {result_file}")

            audio_files = [r['audio_path'] for r in results]
            print(f"\n 生成 {len(audio_files)} 个TTS音频文件:")
            for audio_file in audio_files:
                print(f" - {audio_file}")
        else:
            print("\n 未获取到有效的翻译结果")

        print("="*70)

        total_time = time.time() - start_time
        print(f"\n总处理时间: {total_time:.2f} 秒")

    except Exception as e:
        print(f"Pipeline执行失败: {e}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    main()
|
|
|
|