epirag / server.py
RohanB67's picture
add feature
189df32
"""
EpiRAG - server.py
------------------
Flask server with:
- /api/query — standard JSON response
- /api/query/stream — SSE streaming (live debate events)
- /api/stats — corpus stats
- /api/metrics — session performance metrics
- /performance — performance dashboard page
"""
import os
import time
import json
import queue
import threading
import chromadb
from flask import Flask, jsonify, request, send_from_directory, Response, stream_with_context
from flask_cors import CORS
from query import rag_query, set_components
from dotenv import load_dotenv
load_dotenv()
app = Flask(__name__, static_folder="static")
CORS(app)
COLLECTION_NAME = "epirag"
IS_CLOUD = os.environ.get("EPIRAG_ENV", "").lower() == "cloud"
# -- Session metrics ------------------------------------------------------------
SESSION_METRICS = {
"queries_total": 0,
"queries_local": 0,
"queries_web": 0,
"queries_hybrid": 0,
"queries_debate": 0,
"latencies_ms": [],
"started_at": time.time(),
}
def record_metric(result, elapsed_ms):
SESSION_METRICS["queries_total"] += 1
SESSION_METRICS["latencies_ms"].append(elapsed_ms)
mode = result.get("mode", "")
if mode == "local": SESSION_METRICS["queries_local"] += 1
if mode == "web": SESSION_METRICS["queries_web"] += 1
if mode == "hybrid": SESSION_METRICS["queries_hybrid"] += 1
if result.get("is_debate"): SESSION_METRICS["queries_debate"] += 1
# -- Corpus startup------------------------------------------------------------
_collection = None
_embedder = None
CORPUS_STATS = {}
def init_corpus():
global _collection, _embedder, CORPUS_STATS
if IS_CLOUD:
print("Cloud mode - building in-memory corpus", flush=True)
from ingest import build_collection_in_memory
_collection, _embedder = build_collection_in_memory()
else:
print("Local mode - loading from ./chroma_db/", flush=True)
from sentence_transformers import SentenceTransformer
client = chromadb.PersistentClient(path="./chroma_db")
_collection = client.get_collection(COLLECTION_NAME)
_embedder = SentenceTransformer("all-MiniLM-L6-v2")
set_components(_embedder, _collection)
count = _collection.count()
results = _collection.get(limit=count, include=["metadatas"])
papers = sorted(set(
m.get("paper_name", m.get("source", "Unknown"))
for m in results["metadatas"]
))
CORPUS_STATS.update({
"chunks": count,
"papers": len(papers),
"paperList": papers,
"status": "online",
"mode": "cloud (in-memory)" if IS_CLOUD else "local (persistent)"
})
print(f"Corpus ready: {count} chunks / {len(papers)} papers", flush=True)
init_corpus()
# -- Routes ------------------------------------------------------------
@app.route("/")
def index():
return send_from_directory("static", "index.html")
@app.route("/performance")
def performance():
return send_from_directory("static", "performance.html")
@app.route("/api/stats")
def stats():
return jsonify(CORPUS_STATS)
@app.route("/api/metrics")
def metrics():
lats = SESSION_METRICS["latencies_ms"]
avg = int(sum(lats) / len(lats)) if lats else 0
return jsonify({
**SESSION_METRICS,
"avg_latency_ms": avg,
"uptime_seconds": int(time.time() - SESSION_METRICS["started_at"]),
"latencies_ms": lats[-50:], # last 50 only
})
@app.route("/api/query", methods=["POST"])
def query():
data = request.json or {}
question = (data.get("question") or "").strip()
if not question:
return jsonify({"error": "No question provided"}), 400
groq_key = os.environ.get("GROQ_API_KEY")
tavily_key = os.environ.get("TAVILY_API_KEY")
hf_token = os.environ.get("HF_TOKEN")
if not groq_key:
return jsonify({"error": "GROQ_API_KEY not set on server"}), 500
start = time.time()
result = rag_query(
question,
groq_api_key = groq_key,
tavily_api_key = tavily_key,
hf_token = hf_token,
use_debate = bool(hf_token)
)
elapsed_ms = int((time.time() - start) * 1000)
record_metric(result, elapsed_ms)
return jsonify({
"answer": result["answer"],
"sources": result["sources"],
"mode": result["mode"],
"avg_sim": result["avg_sim"],
"latency_ms": elapsed_ms,
"tokens": len(result["answer"]) // 4,
"question": question,
"is_debate": result.get("is_debate", False),
"debate_rounds": result.get("debate_rounds", []),
"consensus": result.get("consensus", False),
"rounds_run": result.get("rounds_run", 0),
})
@app.route("/api/query/stream", methods=["POST"])
def query_stream():
"""
SSE endpoint. Streams debate events in real time, then sends final result.
Event types sent to browser:
data: {"type": "status", "text": "..."}
data: {"type": "round_start", "round": N}
data: {"type": "agent_done", "round": N, "name": "...", "color": "...", "text": "..."}
data: {"type": "synthesizing"}
data: {"type": "result", ...full result payload...}
data: {"type": "error", "text": "..."}
"""
data = request.json or {}
question = (data.get("question") or "").strip()
if not question:
return jsonify({"error": "No question provided"}), 400
groq_key = os.environ.get("GROQ_API_KEY")
tavily_key = os.environ.get("TAVILY_API_KEY")
hf_token = os.environ.get("HF_TOKEN")
event_queue = queue.Queue()
def callback(event):
event_queue.put(event)
def run_in_thread():
try:
start = time.time()
result = rag_query(
question,
groq_api_key = groq_key,
tavily_api_key = tavily_key,
hf_token = hf_token,
use_debate = bool(hf_token),
sse_callback = callback
)
elapsed_ms = int((time.time() - start) * 1000)
record_metric(result, elapsed_ms)
event_queue.put({
"type": "result",
"answer": result["answer"],
"sources": result["sources"],
"mode": result["mode"],
"avg_sim": result["avg_sim"],
"latency_ms": elapsed_ms,
"tokens": len(result["answer"]) // 4,
"is_debate": result.get("is_debate", False),
"debate_rounds": result.get("debate_rounds", []),
"consensus": result.get("consensus", False),
"rounds_run": result.get("rounds_run", 0),
})
except Exception as e:
event_queue.put({"type": "error", "text": str(e)})
finally:
event_queue.put(None) # sentinel
thread = threading.Thread(target=run_in_thread, daemon=True)
thread.start()
def generate():
yield "data: " + json.dumps({"type": "status", "text": "Retrieving context..."}) + "\n\n"
while True:
try:
event = event_queue.get(timeout=60)
except queue.Empty:
yield "data: " + json.dumps({"type": "error", "text": "Timeout"}) + "\n\n"
break
if event is None:
break
yield "data: " + json.dumps(event) + "\n\n"
return Response(
stream_with_context(generate()),
mimetype="text/event-stream",
headers={
"Cache-Control": "no-cache",
"X-Accel-Buffering": "no",
}
)
@app.route("/api/health")
def health():
return jsonify({"status": "ok", "corpus": CORPUS_STATS.get("status", "unknown")})
if __name__ == "__main__":
port = int(os.environ.get("PORT", 7860))
app.run(debug=False, host="0.0.0.0", port=port, threaded=True)