Spaces:
Build error
Build error
| import gradio as gr | |
| import sys | |
| import threading | |
| import queue | |
| from io import TextIOBase | |
| import datetime | |
| import subprocess | |
| import os | |
| from inference import postprocess_inst_names | |
| # 如果你的 inference、convert 等逻辑和原来一致,可以直接用 | |
| from inference import inference_patch | |
| from convert import abc2xml, xml2, pdf2img | |
# Load the allowed prompt combinations. Each non-blank line of prompts.txt is
# split on '_' and its first three fields are used as
# (period, composer, instrumentation).
with open('prompts.txt', 'r', encoding='utf-8') as f:
    prompts = f.readlines()

valid_combinations = set()
for prompt in prompts:
    prompt = prompt.strip()
    if not prompt:
        # A blank line would otherwise fail with IndexError on parts[1].
        continue
    parts = prompt.split('_')
    if len(parts) < 3:
        raise ValueError(
            f"Malformed line in prompts.txt (expected 3 '_'-separated fields): {prompt!r}"
        )
    valid_combinations.add((parts[0], parts[1], parts[2]))

# Dropdown options: the distinct values of each field, sorted.
periods = sorted({p for p, _, _ in valid_combinations})
composers = sorted({c for _, c, _ in valid_combinations})
instruments = sorted({i for _, _, i in valid_combinations})
# Keep the composer / instrumentation dropdowns in sync with the selection.
def update_components(period, composer):
    """Compute dropdown updates for the composer and instrumentation boxes.

    Args:
        period: currently selected period (falsy when nothing is selected).
        composer: currently selected composer (falsy when nothing is selected).

    Returns:
        A two-element list of ``gr.update`` values: the first for the composer
        dropdown, the second for the instrumentation dropdown.
    """
    if not period:
        # Without a period nothing can be chosen: empty and disable both.
        composer_update = gr.update(choices=[], value=None, interactive=False)
        instrument_update = gr.update(choices=[], value=None, interactive=False)
        return [composer_update, instrument_update]

    composers_for_period = set()
    instruments_for_pair = set()
    for p, c, i in valid_combinations:
        if p != period:
            continue
        composers_for_period.add(c)
        if composer and c == composer:
            instruments_for_pair.add(i)

    valid_composers = sorted(composers_for_period)
    valid_instruments = sorted(instruments_for_pair)

    # Keep the current composer only if it is still valid for this period.
    kept_composer = composer if composer in valid_composers else None

    composer_update = gr.update(
        choices=valid_composers,
        value=kept_composer,
        interactive=True,
    )
    # The instrumentation selection is always cleared; the box is enabled only
    # when there is something to choose from.
    instrument_update = gr.update(
        choices=valid_instruments,
        value=None,
        interactive=bool(valid_instruments),
    )
    return [composer_update, instrument_update]
# Custom stream used to forward the model's inference output to the front end.
class RealtimeStream(TextIOBase):
    """Write-only text stream that pushes every written chunk onto a queue.

    An instance is installed as ``sys.stdout`` while inference runs so that a
    consumer can read the printed text from the queue as it is produced.
    """

    def __init__(self, queue):
        """Store the queue (any object with a ``put`` method) to write into."""
        super().__init__()
        self.queue = queue

    def writable(self):
        """Report the stream as writable (the TextIOBase default is False)."""
        return True

    def write(self, text):
        """Enqueue ``text`` and return the number of characters accepted."""
        self.queue.put(text)
        return len(text)
def convert_files(abc_content, period, composer, instrumentation):
    """Write the generated ABC text to disk and convert it to the other formats.

    Two ABC files are written into the current working directory: the raw
    content and a copy passed through ``postprocess_inst_names``. Both share a
    ``<timestamp>_<period>_<composer>_<instrumentation>`` base name (the second
    with a ``_postinst`` suffix).

    Args:
        abc_content: ABC notation text to save and convert.
        period, composer, instrumentation: the prompt fields; used for the
            file name and required to be non-empty.

    Returns:
        A dict with the keys 'abc', 'xml', 'pdf', 'mid', 'wav' (file paths),
        'pages' (number of rendered page images), 'current_page' (0) and
        'base' (base file name used for the page PNGs).

    Raises:
        gr.Error: if a prompt field is missing or any conversion step fails.
    """
    if not all([period, composer, instrumentation]):
        raise gr.Error("Please complete a valid generation first before saving")

    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    prompt_str = f"{period}_{composer}_{instrumentation}"
    filename_base = f"{timestamp}_{prompt_str}"

    abc_filename = f"{filename_base}.abc"
    with open(abc_filename, "w", encoding="utf-8") as f:
        f.write(abc_content)

    # Instrument-name replacement, saved as a separate "_postinst" variant.
    postprocessed_inst_abc = postprocess_inst_names(abc_content)
    filename_base_postinst = f"{filename_base}_postinst"
    with open(filename_base_postinst + ".abc", "w", encoding="utf-8") as f:
        f.write(postprocessed_inst_abc)

    # Convert the files.
    # NOTE(review): the paths returned below assume the convert helpers write
    # "<base>.<ext>" next to the .abc file — confirm against the convert module.
    file_paths = {'abc': abc_filename}
    try:
        # abc -> xml
        abc2xml(filename_base)
        abc2xml(filename_base_postinst)

        # xml -> pdf
        xml2(filename_base, 'pdf')

        # xml -> mid
        xml2(filename_base, 'mid')
        xml2(filename_base_postinst, 'mid')

        # xml -> wav
        xml2(filename_base, 'wav')
        xml2(filename_base_postinst, 'wav')

        # Render the PDF pages as PNG images (1-based page numbers in the name).
        images = pdf2img(filename_base)
        for i, image in enumerate(images):
            image.save(f"{filename_base}_page_{i+1}.png", "PNG")

        file_paths.update({
            'xml': f"{filename_base_postinst}.xml",
            'pdf': f"{filename_base}.pdf",
            'mid': f"{filename_base_postinst}.mid",
            'wav': f"{filename_base_postinst}.wav",
            'pages': len(images),
            'current_page': 0,
            'base': filename_base
        })
    except Exception as e:
        # Chain the cause so the original traceback is preserved in the logs.
        raise gr.Error(f"文件处理失败: {e}") from e

    return file_paths
# Page-flip handler for the sheet-music preview.
def update_page(direction, data):
    """Move the preview one page backwards or forwards.

    Args:
        direction: "prev" or "next"; any other value leaves the page unchanged.
        data: state dict with the keys 'pages', 'current_page' and 'base'
            (or a falsy value when nothing has been generated yet). The dict
            is updated in place.

    Returns:
        A 4-tuple: image path for the current page (or None), update for the
        previous-page button, update for the next-page button, and ``data``.
    """
    if not data:
        return None, gr.update(interactive=False), gr.update(interactive=False), data

    last_index = data['pages'] - 1
    if last_index < 0:
        # No page images exist, so do not return a path to a missing PNG.
        return None, gr.update(interactive=False), gr.update(interactive=False), data

    if direction == "prev" and data['current_page'] > 0:
        data['current_page'] -= 1
    elif direction == "next" and data['current_page'] < last_index:
        data['current_page'] += 1

    current_page_index = data['current_page']
    # Page image files are numbered from 1.
    new_image = f"{data['base']}_page_{current_page_index+1}.png"

    # "Previous" is disabled on the first page, "next" on the last page.
    prev_btn_state = gr.update(interactive=(current_page_index > 0))
    next_btn_state = gr.update(interactive=(current_page_index < last_index))
    return new_image, prev_btn_state, next_btn_state, data
def generate_music(period, composer, instrumentation):
    """Run inference for the selected prompt and stream progress to the UI.

    This is a generator. Every yield is a 5-tuple, in the order of the
    outputs wired to the Generate button:
      1) process_output - text captured from stdout so far
      2) final_output   - status message, error message, or the final ABC text
      3) pdf_image      - path of the PNG for page 1, or None
      4) audio_player   - path of the WAV file, or None
      5) pdf_state      - dict returned by convert_files, or None

    Raises:
        gr.Error: if (period, composer, instrumentation) is not a known
            prompt combination.
    """
    if (period, composer, instrumentation) not in valid_combinations:
        # Reject combinations that are not listed in prompts.txt.
        raise gr.Error("Invalid prompt combination! Please re-select from the period options")

    # Redirect stdout into a queue so printed progress can be streamed.
    # NOTE(review): sys.stdout is process-wide, so overlapping calls would
    # interfere with each other's capture — confirm requests are serialized.
    output_queue = queue.Queue()
    original_stdout = sys.stdout
    sys.stdout = RealtimeStream(output_queue)

    result_container = []
    error_container = []

    def run_inference():
        try:
            result = inference_patch(period, composer, instrumentation)
            result_container.append(result)
        except Exception as exc:
            # Record the failure for the generator, then re-raise so the
            # thread's default handler still reports the traceback.
            error_container.append(exc)
            raise
        finally:
            sys.stdout = original_stdout

    thread = threading.Thread(target=run_inference)
    thread.start()

    process_output = ""
    final_output_abc = ""
    pdf_image = None
    audio_file = None
    pdf_state = None

    # Stream intermediate output while inference is running.
    while thread.is_alive():
        try:
            text = output_queue.get(timeout=0.1)
        except queue.Empty:
            continue
        process_output += text
        # No final ABC and no converted files yet.
        yield process_output, final_output_abc, pdf_image, audio_file, pdf_state

    # The thread has finished: drain whatever is still queued.
    while not output_queue.empty():
        text = output_queue.get()
        process_output += text

    if error_container:
        # Report the inference failure instead of converting an empty result.
        yield process_output, f"Error during inference: {error_container[0]}", None, None, None
        return

    # Final inference result.
    final_result = result_container[0] if result_container else ""

    # Tell the user that conversion is in progress.
    final_output_abc = "Converting files..."
    yield process_output, final_output_abc, pdf_image, audio_file, pdf_state

    # Convert the result to the downloadable / previewable formats.
    try:
        file_paths = convert_files(final_result, period, composer, instrumentation)
        final_output_abc = final_result
        # Use the first page image (if any) and the WAV file for the preview.
        if file_paths['pages'] > 0:
            pdf_image = f"{file_paths['base']}_page_1.png"
        audio_file = file_paths['wav']
        pdf_state = file_paths  # the whole dict is kept as the UI state
    except Exception as e:
        # On failure, show the error message in the output box.
        yield process_output, f"Error converting files: {str(e)}", None, None, None
        return

    # Final yield with everything filled in.
    yield process_output, final_output_abc, pdf_image, audio_file, pdf_state
def get_file(file_type, period, composer, instrumentation):
    """Return the most recently modified ``*.<file_type>`` entry in the cwd.

    This is a simplified lookup: it does not match on a timestamp or on the
    prompt, and ``period``, ``composer`` and ``instrumentation`` are unused.

    Returns:
        The entry name (relative to the current directory), or None when no
        entry has the requested extension.
    """
    suffix = f'.{file_type}'
    matches = []
    for entry in os.listdir('.'):
        if entry.endswith(suffix):
            matches.append(entry)

    if matches:
        # Oldest first, so the newest entry is the last one.
        return sorted(matches, key=os.path.getmtime)[-1]
    return None
# Custom stylesheet passed to gr.Blocks(css=...). The text inside the string
# (including its CSS comments) is sent to the browser as-is.
css = """
/* 紧凑按钮样式 */
button[size="sm"] {
padding: 4px 8px !important;
margin: 2px !important;
min-width: 60px;
}
/* PDF预览区 */
#pdf-preview {
border-radius: 8px; /* 圆角 */
box-shadow: 0 2px 8px rgba(0,0,0,0.1); /* 阴影 */
}
.page-btn {
padding: 12px !important; /* 增大点击区域 */
margin: auto !important; /* 垂直居中 */
}
/* 按钮悬停效果 */
.page-btn:hover {
background: #f0f0f0 !important;
transform: scale(1.05);
}
/* 布局调整 */
.gr-row {
gap: 10px !important; /* 元素间距 */
}
/* 音频播放器 */
.audio-panel {
margin-top: 15px !important;
max-width: 400px;
}
#audio-preview audio {
height: 200px !important;
}
/* 保存功能区 */
.save-as-row {
margin-top: 15px;
padding: 10px;
border-top: 1px solid #eee;
}
.save-as-label {
font-weight: bold;
margin-right: 10px;
align-self: center;
}
.save-buttons {
gap: 5px; /* 按钮间距 */
}
"""
# Build the Gradio interface and wire up its event handlers.
# NOTE(review): the nesting of the layout containers below was reconstructed
# from statement order — confirm it against the rendered page.
with gr.Blocks(css=css) as demo:
    gr.Markdown("## NotaGen")
    # Holds the dict produced by convert_files (page count, current page, paths).
    pdf_state = gr.State()
    with gr.Column():
        with gr.Row():
            # Left column
            with gr.Column():
                with gr.Row():
                    period_dd = gr.Dropdown(
                        choices=periods,
                        value=None,
                        label="Period",
                        interactive=True
                    )
                    composer_dd = gr.Dropdown(
                        choices=[],
                        value=None,
                        label="Composer",
                        interactive=False
                    )
                    instrument_dd = gr.Dropdown(
                        choices=[],
                        value=None,
                        label="Instrumentation",
                        interactive=False
                    )
                generate_btn = gr.Button("Generate!", variant="primary")
                process_output = gr.Textbox(
                    label="Generation process",
                    interactive=False,
                    lines=2,
                    max_lines=2,
                    placeholder="Generation progress will be shown here..."
                )
                final_output = gr.Textbox(
                    label="Post-processed ABC notation scores",
                    interactive=True,
                    lines=8,
                    max_lines=8,
                    placeholder="Post-processed ABC scores will be shown here..."
                )
                # Audio playback
                audio_player = gr.Audio(
                    label="Audio Preview",
                    format="wav",
                    interactive=False,
                    # container=False,
                    # elem_id="audio-preview"
                )
            # Right column
            with gr.Column():
                # Sheet-music image container
                pdf_image = gr.Image(
                    label="Sheet Music Preview",
                    show_label=False,
                    height=650,
                    type="filepath",
                    elem_id="pdf-preview",
                    interactive=False,
                    show_download_button=False
                )
                # Page-flip buttons
                with gr.Row():
                    prev_btn = gr.Button(
                        "⬅️ Last Page",
                        variant="secondary",
                        size="sm",
                        elem_classes="page-btn"
                    )
                    next_btn = gr.Button(
                        "Next Page ➡️",
                        variant="secondary",
                        size="sm",
                        elem_classes="page-btn"
                    )
                # Save-as button group
                with gr.Row():
                    gr.Markdown("**Save As: (Scroll down to get the link)**")
                    save_abc = gr.Button("🅰️ ABC", variant="secondary", size="sm")
                    save_xml = gr.Button("🎼 XML", variant="secondary", size="sm")
                    save_pdf = gr.Button("📑 PDF", variant="secondary", size="sm")
                    save_mid = gr.Button("🎹 MIDI", variant="secondary", size="sm")
                    save_wav = gr.Button("🎧 WAV", variant="secondary", size="sm")
                # save_status = gr.Textbox(
                #     label="Save Status",
                #     interactive=False,
                #     visible=True,
                #     max_lines=1
                # )
    # Linked dropdowns: a change of period or composer refreshes the
    # composer and instrumentation choices.
    period_dd.change(
        update_components,
        inputs=[period_dd, composer_dd],
        outputs=[composer_dd, instrument_dd]
    )
    composer_dd.change(
        update_components,
        inputs=[period_dd, composer_dd],
        outputs=[composer_dd, instrument_dd]
    )
    # Generate button: the outputs list must match, in order, the five values
    # of every yield in generate_music.
    generate_btn.click(
        generate_music,
        inputs=[period_dd, composer_dd, instrument_dd],
        outputs=[process_output, final_output, pdf_image, audio_player, pdf_state]
    )
    # Page flipping: hidden textboxes carry the constant direction argument.
    prev_signal = gr.Textbox(value="prev", visible=False)
    next_signal = gr.Textbox(value="next", visible=False)
    prev_btn.click(
        update_page,
        inputs=[prev_signal, pdf_state],  # the direction is passed as a component
        outputs=[pdf_image, prev_btn, next_btn, pdf_state]
    )
    next_btn.click(
        update_page,
        inputs=[next_signal, pdf_state],  # the direction is passed as a component
        outputs=[pdf_image, prev_btn, next_btn, pdf_state]
    )
    # Save buttons: each one looks up its path in the state dict and shows it
    # in a gr.File component created here (None when nothing was generated).
    save_abc.click(
        lambda state: state.get('abc') if state else None,
        inputs=[pdf_state],
        outputs=gr.File(label="abc", visible=True)
    )
    save_xml.click(
        lambda state: state.get('xml') if state else None,
        inputs=[pdf_state],
        outputs=gr.File(label="xml", visible=True)
    )
    save_pdf.click(
        lambda state: state.get('pdf') if state else None,
        inputs=[pdf_state],
        outputs=gr.File(label="pdf", visible=True)
    )
    save_mid.click(
        lambda state: state.get('mid') if state else None,
        inputs=[pdf_state],
        outputs=gr.File(label="midi", visible=True)
    )
    save_wav.click(
        lambda state: state.get('wav') if state else None,
        inputs=[pdf_state],
        outputs=gr.File(label="wav", visible=True)
    )
if __name__ == "__main__":
    # Running on Hugging Face Spaces is detected via the SPACE_ID variable.
    is_spaces = os.environ.get('SPACE_ID') is not None
    # On Spaces use the PORT environment variable (default 7860); for local
    # development always use 7860.
    port = int(os.environ.get('PORT', 7860)) if is_spaces else 7860
    demo.launch(server_name="0.0.0.0", server_port=port)