| import asyncio |
| from pathlib import Path |
| from typing import Dict, List |
|
|
| import streamlit as st |
| import yaml |
| from loguru import logger as _logger |
| import shutil |
| import uuid |
|
|
| from metagpt.const import METAGPT_ROOT |
| from metagpt.ext.spo.components.optimizer import PromptOptimizer |
| from metagpt.ext.spo.utils.llm_client import SPO_LLM, RequestType |
|
|
|
|
def get_user_workspace():
    """Return the private workspace directory of the current browser session.

    On the first call in a session a random UUID4 string is stored as
    ``st.session_state.user_id``. The directory ``workspace/<user_id>``
    (relative to the current working directory) is then created if it does not
    exist yet, and returned as a ``Path``.
    """
    session = st.session_state
    if "user_id" not in session:
        session.user_id = str(uuid.uuid4())

    user_dir = Path("workspace").joinpath(session.user_id)
    user_dir.mkdir(parents=True, exist_ok=True)
    return user_dir
|
|
|
|
def cleanup_workspace(workspace_dir: Path) -> None:
    """Best-effort recursive removal of a workspace directory.

    Does nothing if *workspace_dir* does not exist. Any exception raised while
    checking or deleting is logged and swallowed, so the caller never fails
    because of cleanup.
    """
    try:
        if not workspace_dir.exists():
            return
        shutil.rmtree(workspace_dir)
        _logger.info(f"Cleaned up workspace directory: {workspace_dir}")
    except Exception as e:
        _logger.error(f"Error cleaning up workspace: {e}")
|
|
|
|
def get_template_path(template_name: str, is_new_template: bool = False) -> str:
    """
    Build a template file path relative to the SPO settings directory.

    :param template_name: Name of the template. Names listed by
        ``get_all_templates`` for user templates already carry a
        ``"<user_id>/"`` prefix.
    :param is_new_template: Whether it's a new template created by the user.
        New templates are placed in the session's private ``<user_id>/``
        sub-folder, and a session ``user_id`` is created if missing.
    :return: The relative path as a ``str`` (not a ``Path``), e.g.
        ``"name.yaml"`` or ``"<user_id>/name.yaml"``; ``main`` joins it onto
        the settings directory.
    """
    if not is_new_template:
        return f"{template_name}.yaml"

    # Lazily assign a per-session id, the same way get_user_workspace() does.
    if "user_id" not in st.session_state:
        st.session_state.user_id = str(uuid.uuid4())
    # NOTE(review): template_name is not validated here; a name containing "/"
    # or ".." would resolve outside the user's folder, so callers must pass a
    # plain file name.
    return f"{st.session_state.user_id}/{template_name}.yaml"
|
|
|
|
def get_all_templates() -> List[str]:
    """
    Get the sorted, de-duplicated list of templates visible to this session.

    Shared templates are the ``*.yaml`` files directly inside the settings
    directory and are listed by file stem (``"name"``). If the session already
    has a ``user_id``, its private templates in ``settings/<user_id>/`` are
    listed as ``"<user_id>/name"`` so ``get_template_path`` can resolve them
    relative to the settings directory.

    :return: List of template names
    """
    # NOTE(review): relative to the current working directory, so the app must
    # be started from the repository root.
    settings_path = Path("metagpt/ext/spo/settings")

    # A set both collects and de-duplicates the names.
    templates = {f.stem for f in settings_path.glob("*.yaml")}

    if "user_id" in st.session_state:
        user_id = st.session_state.user_id
        user_path = settings_path / user_id
        if user_path.exists():
            templates.update(f"{user_id}/{f.stem}" for f in user_path.glob("*.yaml"))

    return sorted(templates)
|
|
|
|
def load_yaml_template(template_path: Path) -> Dict:
    """Load a template mapping from a YAML file, or return an empty template.

    :param template_path: YAML file to read (UTF-8).
    :return: The parsed content. A fresh default template is returned instead
        (empty prompt and requirements, ``count`` None, one blank Q&A pair)
        when the file is missing, is empty or whitespace-only, or parses to
        nothing (e.g. contains only comments).
    """
    if template_path.exists():
        text = template_path.read_text(encoding="utf-8")
        # yaml.safe_load yields None for a document without content. The caller
        # calls .get() on the result, so fall through to the default instead.
        data = yaml.safe_load(text) if text.strip() else None
        if data is not None:
            return data
    return {"prompt": "", "requirements": "", "count": None, "qa": [{"question": "", "answer": ""}]}
|
|
|
|
def save_yaml_template(template_path: Path, data: Dict, is_new: bool) -> bool:
    """Write a newly created template to *template_path* as YAML.

    Only new templates are written. When ``is_new`` is False the call
    deliberately writes nothing (presumably so shared templates are never
    overwritten; confirm with the app owners). The return value tells the
    caller which case happened, instead of the no-op going unnoticed.

    :param template_path: Destination file; missing parent directories are created.
    :param data: Mapping with ``prompt``, ``requirements``, ``count`` and ``qa``
        (a list of mappings with ``question``/``answer``). Absent keys default
        to empty values.
    :param is_new: Whether this is a newly created template.
    :return: True if the file was written, False if nothing was saved.
    """
    if not is_new:
        return False

    template_format = {
        "prompt": str(data.get("prompt", "")),
        "requirements": str(data.get("requirements", "")),
        "count": data.get("count"),
        # `or []` also tolerates an explicit "qa": None.
        "qa": [
            {"question": str(qa.get("question", "")).strip(), "answer": str(qa.get("answer", "")).strip()}
            for qa in data.get("qa") or []
        ],
    }

    template_path.parent.mkdir(parents=True, exist_ok=True)

    with open(template_path, "w", encoding="utf-8") as f:
        # sort_keys=False keeps prompt/requirements/count/qa in the order above.
        yaml.dump(template_format, f, allow_unicode=True, sort_keys=False, default_flow_style=False, indent=2)
    return True
|
|
def display_optimization_results(result_data):
    """Render one expander per optimization round, followed by a summary.

    Each entry of *result_data* is read for the keys ``round``, ``succeed``,
    ``prompt``, ``tokens`` and ``answers``. ``answers`` is a list of mappings
    with ``question`` and ``answer``. The summary shows the total number of
    rounds and how many of them succeeded.
    """
    for entry in result_data:
        round_num, succeeded, prompt_text = entry["round"], entry["succeed"], entry["prompt"]

        with st.expander(f"轮次 {round_num} {':white_check_mark:' if succeeded else ':x:'}"):
            st.markdown("**提示词:**")
            st.code(prompt_text, language="text")
            st.markdown("<br>", unsafe_allow_html=True)

            status_col, token_col = st.columns(2)
            with status_col:
                status_text = "成功 ✅ " if succeeded else "失败 ❌ "
                st.markdown(f"**状态:** {status_text}")
            with token_col:
                st.markdown(f"**令牌数:** {entry['tokens']}")

            st.markdown("**回答:**")
            for number, qa in enumerate(entry["answers"], start=1):
                st.markdown(f"**问题 {number}:**")
                st.text(qa["question"])
                st.markdown("**答案:**")
                st.text(qa["answer"])
                st.markdown("---")

    # Summary across all rounds.
    success_count = sum(bool(r["succeed"]) for r in result_data)
    total_rounds = len(result_data)

    st.markdown("### 总结")
    total_col, success_col = st.columns(2)
    with total_col:
        st.metric("总轮次", total_rounds)
    with success_col:
        st.metric("成功轮次", success_count)
|
|
|
|
# Model identifiers offered by each of the three model selectors in the sidebar.
_MODEL_CHOICES = ["gpt-4o-mini", "gpt-4o", "deepseek-chat", "claude-3-5-sonnet-20240620"]


def _is_safe_template_name(name: str) -> bool:
    """Return True if *name* can be used as a plain file name for a new template.

    New template names come straight from a text input and are joined onto
    filesystem paths. Path separators and the "." / ".." segments are rejected
    so writes stay inside the user's template folder.
    """
    return bool(name) and "/" not in name and "\\" not in name and name not in (".", "..")


def _llm_kwargs(model: str, temperature: float, base_url: str, api_key: str) -> dict:
    """Build the per-role keyword dict passed to ``SPO_LLM.initialize``."""
    return {"model": model, "temperature": temperature, "base_url": base_url, "api_key": api_key}


def main():
    """Render the SPO Streamlit app for the current session.

    Layout, top to bottom:
    - title banner;
    - sidebar with template choice and LLM/optimizer settings;
    - template editor (prompt, requirements, Q&A examples, save button, YAML preview);
    - optimization log and start button;
    - stored optimization results;
    - a panel to try a prompt against the execution model.
    """
    if "optimization_results" not in st.session_state:
        st.session_state.optimization_results = []

    workspace_dir = get_user_workspace()

    st.markdown(
        """
    <div style="background-color: #f0f2f6; padding: 20px; border-radius: 10px; margin-bottom: 25px">
        <div style="display: flex; justify-content: space-between; align-items: center; margin-bottom: 10px">
            <h1 style="margin: 0;">SPO | 自监督提示词优化 🤖</h1>
        </div>
        <div style="display: flex; gap: 20px; align-items: center">
            <a href="https://arxiv.org/pdf/2502.06855" target="_blank" style="text-decoration: none;">
                <img src="https://img.shields.io/badge/论文-PDF-red.svg" alt="论文">
            </a>
            <a href="https://github.com/geekan/MetaGPT/blob/main/examples/spo/README.md" target="_blank" style="text-decoration: none;">
                <img src="https://img.shields.io/badge/GitHub-仓库-blue.svg" alt="GitHub">
            </a>
            <span style="color: #666;">一个自监督提示词优化框架</span>
        </div>
    </div>
    """,
        unsafe_allow_html=True,
    )

    with st.sidebar:
        st.header("配置")

        # Template selection / creation.
        template_mode = st.radio("模板模式", ["使用现有", "创建新模板"])
        existing_templates = get_all_templates()

        if template_mode == "使用现有":
            template_name = st.selectbox("选择模板", existing_templates)
            is_new_template = False
        else:
            template_name = st.text_input("新模板名称")
            is_new_template = True

        # LLM settings; the same endpoint and key are shared by all three roles.
        st.subheader("LLM 设置")
        base_url = st.text_input("基础 URL", value="https://api.example.com")
        api_key = st.text_input("API 密钥", type="password")

        opt_model = st.selectbox("优化模型", _MODEL_CHOICES, index=0)
        opt_temp = st.slider("优化温度", 0.0, 1.0, 0.7)

        eval_model = st.selectbox("评估模型", _MODEL_CHOICES, index=0)
        eval_temp = st.slider("评估温度", 0.0, 1.0, 0.3)

        exec_model = st.selectbox("执行模型", _MODEL_CHOICES, index=0)
        exec_temp = st.slider("执行温度", 0.0, 1.0, 0.0)

        st.subheader("优化器设置")
        initial_round = st.number_input("初始轮次", 1, 100, 1)
        max_rounds = st.number_input("最大轮次", 1, 100, 10)

    st.header("模板配置")

    if not template_name:
        # No template selected / no name typed yet: only the header is shown.
        return
    if is_new_template and not _is_safe_template_name(template_name):
        st.error("模板名称不能包含路径分隔符(/ 或 \\),也不能是 “.” 或 “..”。")
        return

    def init_llm() -> None:
        """(Re)initialize SPO_LLM with the sidebar's current settings."""
        SPO_LLM.initialize(
            optimize_kwargs=_llm_kwargs(opt_model, opt_temp, base_url, api_key),
            evaluate_kwargs=_llm_kwargs(eval_model, eval_temp, base_url, api_key),
            execute_kwargs=_llm_kwargs(exec_model, exec_temp, base_url, api_key),
        )

    template_real_name = get_template_path(template_name, is_new_template)
    template_path = Path("metagpt/ext/spo/settings") / template_real_name
    # `or {}` guards against a loader that returns None for an empty YAML file.
    template_data = load_yaml_template(template_path) or {}

    # Reset the editable Q&A list only when the selected template changes, so
    # additions/deletions survive Streamlit's rerun on every interaction.
    if st.session_state.get("current_template") != template_name:
        st.session_state.current_template = template_name
        st.session_state.qas = template_data.get("qa") or []

    prompt = st.text_area("提示词", template_data.get("prompt", ""), height=100)
    requirements = st.text_area("要求", template_data.get("requirements", ""), height=100)

    st.subheader("问答示例")

    if st.button("添加新问答"):
        st.session_state.qas.append({"question": "", "answer": ""})

    new_qas = []
    for i, qa in enumerate(st.session_state.qas):
        st.markdown(f"**问答 #{i + 1}**")
        col1, col2, col3 = st.columns([45, 45, 10])

        with col1:
            question = st.text_area(f"问题 {i + 1}", qa.get("question", ""), key=f"q_{i}", height=100)
        with col2:
            answer = st.text_area(f"答案 {i + 1}", qa.get("answer", ""), key=f"a_{i}", height=100)
        with col3:
            if st.button("🗑️", key=f"delete_{i}"):
                st.session_state.qas.pop(i)
                # st.rerun() halts this run immediately, so the loop never
                # continues over the list it just mutated.
                st.rerun()

        new_qas.append({"question": question, "answer": answer})

    if st.button("保存模板"):
        st.session_state.qas = new_qas
        if is_new_template:
            new_template_data = {"prompt": prompt, "requirements": requirements, "count": None, "qa": new_qas}
            save_yaml_template(template_path, new_template_data, is_new_template)
            st.success(f"模板已保存到 {template_path}")
        else:
            # save_yaml_template never writes existing templates; say so
            # instead of reporting a save that did not happen.
            st.warning("现有模板为只读,未保存任何更改。如需保存,请使用“创建新模板”模式。")

    st.subheader("当前模板预览")
    preview_data = {"qa": new_qas, "requirements": requirements, "prompt": prompt}
    st.code(yaml.dump(preview_data, allow_unicode=True), language="yaml")

    st.subheader("优化日志")
    log_container = st.empty()

    def streamlit_sink(message):
        """Append one log line to the session's log history and redraw it."""
        current_logs = st.session_state.get("logs", [])
        current_logs.append(message.strip())
        st.session_state.logs = current_logs
        log_container.code("\n".join(current_logs), language="plaintext")

    def prompt_optimizer_filter(record):
        """Only forward records from modules whose name contains "optimizer"."""
        return "optimizer" in (record["name"] or "").lower()

    # Handlers are rebuilt on every rerun so the sink targets this run's container.
    # NOTE(review): loguru's `logger` is a single process-wide object; with several
    # concurrent sessions, remove()/add() here replaces other sessions' sinks, so
    # logs may reach the wrong browser tab. TODO confirm / bind per session.
    _logger.remove()
    _logger.add(
        streamlit_sink,
        format="{time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <8} | {name}:{function}:{line} - {message}",
        filter=prompt_optimizer_filter,
    )
    _logger.add(METAGPT_ROOT / "logs/{time:YYYYMMDD}.txt", level="DEBUG")

    if st.button("开始优化"):
        try:
            init_llm()

            optimizer = PromptOptimizer(
                optimized_path=str(workspace_dir),
                initial_round=initial_round,
                max_rounds=max_rounds,
                template=template_real_name,
                name=template_name,
            )

            with st.spinner("Optimizing prompts..."):
                optimizer.optimize()

            st.success("优化完成!")
            # The "优化结果" header is rendered once, together with the results below.
            prompt_path = optimizer.root_path / "prompts"
            st.session_state.optimization_results = optimizer.data_utils.load_results(prompt_path)

        except Exception as e:
            # UI boundary: show the error, and keep the traceback in the log file.
            st.error(f"发生错误:{str(e)}")
            _logger.exception(f"优化过程中出错:{str(e)}")

    if st.session_state.optimization_results:
        st.header("优化结果")
        display_optimization_results(st.session_state.optimization_results)

    st.markdown("---")
    st.subheader("测试优化后的提示词")
    col1, col2 = st.columns(2)

    with col1:
        test_prompt = st.text_area("优化后的提示词", value="", height=200, key="test_prompt")

    with col2:
        test_question = st.text_area("你的问题", value="", height=200, key="test_question")

    if st.button("测试提示词"):
        if not (test_prompt and test_question):
            st.warning("请输入提示词和问题。")
            return
        try:
            with st.spinner("正在生成回答..."):
                init_llm()
                llm = SPO_LLM.get_instance()
                messages = [{"role": "user", "content": f"{test_prompt}\n\n{test_question}"}]
                # asyncio.run creates, runs and closes a fresh loop, without
                # leaving a closed loop installed as this thread's current loop.
                response = asyncio.run(llm.responser(request_type=RequestType.EXECUTE, messages=messages))

            st.subheader("回答:")
            st.markdown(response)

        except Exception as e:
            st.error(f"生成回答时出错:{str(e)}")
|
|
|
|
# Entry point: render the app when this file is executed as the main script.
if __name__ == "__main__":
    main()
|
|