| """ |
| EpiRAG - server.py |
| ------------------ |
| Flask server with: |
| - /api/query — standard JSON response |
| - /api/query/stream — SSE streaming (live debate events) |
| - /api/stats — corpus stats |
| - /api/metrics — session performance metrics |
| - /performance — performance dashboard page |
| """ |
|
|
| import os |
| import time |
| import json |
| import queue |
| import threading |
| import chromadb |
| from flask import Flask, jsonify, request, send_from_directory, Response, stream_with_context |
| from flask_cors import CORS |
| from query import rag_query, set_components |
|
|
| from dotenv import load_dotenv |
|
|
| load_dotenv() |
|
|
app = Flask(__name__, static_folder="static")
CORS(app)  # allow cross-origin requests from the browser UI


# Name of the ChromaDB collection holding the embedded corpus chunks.
COLLECTION_NAME = "epirag"
# Deployment switch: EPIRAG_ENV=cloud builds the corpus in memory at startup;
# anything else loads the persistent ./chroma_db store (see init_corpus()).
IS_CLOUD = os.environ.get("EPIRAG_ENV", "").lower() == "cloud"
|
|
| |
# Cap on retained latency samples. Without a bound, latencies_ms grows for
# the lifetime of the process (a slow leak); /api/metrics only reports the
# tail anyway, so keeping a rolling window preserves its behavior.
MAX_LATENCY_SAMPLES = 1000

# In-memory, per-process metrics; reset whenever the server restarts.
SESSION_METRICS = {
    "queries_total": 0,     # all queries served (JSON and SSE endpoints)
    "queries_local": 0,     # answered from the local corpus only
    "queries_web": 0,       # answered from web search only
    "queries_hybrid": 0,    # corpus + web combined
    "queries_debate": 0,    # queries that ran the multi-agent debate
    "latencies_ms": [],     # rolling window of per-query latencies
    "started_at": time.time(),
}


def record_metric(result, elapsed_ms):
    """Fold one completed query into the session-wide metrics.

    Args:
        result: dict returned by rag_query(); only the optional "mode"
            and "is_debate" keys are read.
        elapsed_ms: wall-clock latency of the query in milliseconds.
    """
    SESSION_METRICS["queries_total"] += 1
    SESSION_METRICS["latencies_ms"].append(elapsed_ms)
    # Trim in place so any existing references to the list stay valid.
    if len(SESSION_METRICS["latencies_ms"]) > MAX_LATENCY_SAMPLES:
        del SESSION_METRICS["latencies_ms"][:-MAX_LATENCY_SAMPLES]
    mode = result.get("mode", "")
    # The modes are mutually exclusive, so elif is both correct and clearer
    # than the original chain of independent ifs.
    if mode == "local":
        SESSION_METRICS["queries_local"] += 1
    elif mode == "web":
        SESSION_METRICS["queries_web"] += 1
    elif mode == "hybrid":
        SESSION_METRICS["queries_hybrid"] += 1
    if result.get("is_debate"):
        SESSION_METRICS["queries_debate"] += 1
|
|
|
|
| |
# Retrieval components, populated once by init_corpus() at import time.
_collection = None  # ChromaDB collection of embedded corpus chunks
_embedder = None    # sentence-transformer model used to embed queries
CORPUS_STATS = {}   # filled by init_corpus(); served verbatim at /api/stats
|
|
def init_corpus():
    """Load (or build) the vector collection, hand it to the query module,
    and fill CORPUS_STATS with chunk/paper counts for /api/stats."""
    global _collection, _embedder, CORPUS_STATS

    if IS_CLOUD:
        print("Cloud mode - building in-memory corpus", flush=True)
        from ingest import build_collection_in_memory
        _collection, _embedder = build_collection_in_memory()
    else:
        print("Local mode - loading from ./chroma_db/", flush=True)
        from sentence_transformers import SentenceTransformer
        chroma_client = chromadb.PersistentClient(path="./chroma_db")
        _collection = chroma_client.get_collection(COLLECTION_NAME)
        _embedder = SentenceTransformer("all-MiniLM-L6-v2")

    # Wire the retrieval components into query.rag_query.
    set_components(_embedder, _collection)

    # Pull every chunk's metadata once to derive the distinct paper list.
    chunk_count = _collection.count()
    fetched = _collection.get(limit=chunk_count, include=["metadatas"])
    unique_papers = {
        meta.get("paper_name", meta.get("source", "Unknown"))
        for meta in fetched["metadatas"]
    }
    paper_names = sorted(unique_papers)

    CORPUS_STATS["chunks"] = chunk_count
    CORPUS_STATS["papers"] = len(paper_names)
    CORPUS_STATS["paperList"] = paper_names
    CORPUS_STATS["status"] = "online"
    CORPUS_STATS["mode"] = "cloud (in-memory)" if IS_CLOUD else "local (persistent)"

    print(f"Corpus ready: {chunk_count} chunks / {len(paper_names)} papers", flush=True)
|
|
|
|
| init_corpus() |
|
|
|
|
| |
@app.route("/")
def index():
    """Serve the single-page UI shell from the static folder."""
    static_dir = "static"
    return send_from_directory(static_dir, "index.html")
|
|
|
|
@app.route("/performance")
def performance():
    """Serve the performance dashboard page."""
    static_dir = "static"
    return send_from_directory(static_dir, "performance.html")
|
|
|
|
@app.route("/api/stats")
def stats():
    """Return the corpus statistics gathered by init_corpus() at startup."""
    snapshot = CORPUS_STATS
    return jsonify(snapshot)
|
|
|
|
@app.route("/api/metrics")
def metrics():
    """Report session counters plus derived average latency and uptime.

    Only the most recent 50 latency samples are included in the payload.
    """
    samples = SESSION_METRICS["latencies_ms"]
    if samples:
        avg_latency = int(sum(samples) / len(samples))
    else:
        avg_latency = 0
    payload = dict(SESSION_METRICS)
    payload["avg_latency_ms"] = avg_latency
    payload["uptime_seconds"] = int(time.time() - SESSION_METRICS["started_at"])
    payload["latencies_ms"] = samples[-50:]
    return jsonify(payload)
|
|
|
|
@app.route("/api/query", methods=["POST"])
def query():
    """Answer a question through the RAG pipeline (blocking JSON response).

    Expects a JSON body: {"question": "..."}.
    Returns the answer plus sources, routing mode, similarity, latency and
    debate metadata. 400 on a missing/empty question; 500 when the server
    is missing its GROQ_API_KEY.
    """
    # silent=True: a missing or non-JSON body yields None instead of Flask
    # raising its own 400/415 — so we can return our explicit error payload.
    data = request.get_json(silent=True) or {}
    question = (data.get("question") or "").strip()
    if not question:
        return jsonify({"error": "No question provided"}), 400

    groq_key = os.environ.get("GROQ_API_KEY")
    tavily_key = os.environ.get("TAVILY_API_KEY")
    hf_token = os.environ.get("HF_TOKEN")
    if not groq_key:
        return jsonify({"error": "GROQ_API_KEY not set on server"}), 500

    start = time.time()
    result = rag_query(
        question,
        groq_api_key=groq_key,
        tavily_api_key=tavily_key,
        hf_token=hf_token,
        # The debate path requires an HF token; fall back to plain RAG without it.
        use_debate=bool(hf_token),
    )
    elapsed_ms = int((time.time() - start) * 1000)
    record_metric(result, elapsed_ms)

    return jsonify({
        "answer": result["answer"],
        "sources": result["sources"],
        "mode": result["mode"],
        "avg_sim": result["avg_sim"],
        "latency_ms": elapsed_ms,
        # Rough token estimate (~4 chars/token); for display only.
        "tokens": len(result["answer"]) // 4,
        "question": question,
        "is_debate": result.get("is_debate", False),
        "debate_rounds": result.get("debate_rounds", []),
        "consensus": result.get("consensus", False),
        "rounds_run": result.get("rounds_run", 0),
    })
|
|
|
|
@app.route("/api/query/stream", methods=["POST"])
def query_stream():
    """
    SSE endpoint. Streams debate events in real time, then sends final result.

    Event types sent to browser:
        data: {"type": "status", "text": "..."}
        data: {"type": "round_start", "round": N}
        data: {"type": "agent_done", "round": N, "name": "...", "color": "...", "text": "..."}
        data: {"type": "synthesizing"}
        data: {"type": "result", ...full result payload...}
        data: {"type": "error", "text": "..."}
    """
    # silent=True: a missing or non-JSON body yields None instead of Flask
    # raising its own 400/415 — so we can return our explicit error payload.
    data = request.get_json(silent=True) or {}
    question = (data.get("question") or "").strip()
    if not question:
        return jsonify({"error": "No question provided"}), 400

    groq_key = os.environ.get("GROQ_API_KEY")
    tavily_key = os.environ.get("TAVILY_API_KEY")
    hf_token = os.environ.get("HF_TOKEN")

    # rag_query runs in a worker thread and feeds events through this queue;
    # the generator below drains it into the SSE response.
    event_queue = queue.Queue()

    def callback(event):
        # Invoked by rag_query for every intermediate debate event.
        event_queue.put(event)

    def run_in_thread():
        try:
            start = time.time()
            result = rag_query(
                question,
                groq_api_key=groq_key,
                tavily_api_key=tavily_key,
                hf_token=hf_token,
                use_debate=bool(hf_token),
                sse_callback=callback,
            )
            elapsed_ms = int((time.time() - start) * 1000)
            record_metric(result, elapsed_ms)
            event_queue.put({
                "type": "result",
                "answer": result["answer"],
                "sources": result["sources"],
                "mode": result["mode"],
                "avg_sim": result["avg_sim"],
                "latency_ms": elapsed_ms,
                # Rough token estimate (~4 chars/token); for display only.
                "tokens": len(result["answer"]) // 4,
                "is_debate": result.get("is_debate", False),
                "debate_rounds": result.get("debate_rounds", []),
                "consensus": result.get("consensus", False),
                "rounds_run": result.get("rounds_run", 0),
            })
        except Exception as e:
            # Surface the failure to the browser instead of dying silently.
            event_queue.put({"type": "error", "text": str(e)})
        finally:
            # Sentinel: tells the generator below to close the stream.
            event_queue.put(None)

    thread = threading.Thread(target=run_in_thread, daemon=True)
    thread.start()

    def generate():
        yield "data: " + json.dumps({"type": "status", "text": "Retrieving context..."}) + "\n\n"
        while True:
            try:
                event = event_queue.get(timeout=60)
            except queue.Empty:
                # Nothing for 60s: assume the worker hung and end the stream.
                yield "data: " + json.dumps({"type": "error", "text": "Timeout"}) + "\n\n"
                break
            if event is None:  # sentinel from run_in_thread's finally block
                break
            yield "data: " + json.dumps(event) + "\n\n"

    return Response(
        stream_with_context(generate()),
        mimetype="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            # Disable nginx proxy buffering so events reach the client promptly.
            "X-Accel-Buffering": "no",
        },
    )
|
|
|
|
@app.route("/api/health")
def health():
    """Liveness probe; also reflects whether the corpus loaded at startup."""
    corpus_state = CORPUS_STATS.get("status", "unknown")
    return jsonify({"status": "ok", "corpus": corpus_state})
|
|
|
|
if __name__ == "__main__":
    # Honor the platform-provided PORT; fall back to 7860.
    port = int(os.environ.get("PORT", 7860))
    # threaded=True lets long-lived SSE streams and other requests run concurrently.
    app.run(debug=False, host="0.0.0.0", port=port, threaded=True)