import os
import json
import numpy as np
from fastapi import FastAPI, Request, Form
from fastapi.responses import HTMLResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import faiss
from sentence_transformers import SentenceTransformer
from transformers import pipeline, GenerationConfig
from rank_bm25 import BM25Okapi

app = FastAPI(title="NDPA RAG System")

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global variables to hold models and data
chunks = []
index = None
embedding_model = None
bm25 = None
text_generator = None
generation_config = None


@app.on_event("startup")
def load_models_and_data():
    global chunks, index, embedding_model, bm25, text_generator, generation_config

    print("Loading chunks.json...")
    try:
        with open("chunks.json", "r", encoding="utf-8") as f:
            chunks = json.load(f)
    except Exception as e:
        print(f"Error loading chunks.json: {e}. Make sure to run save_data.py first.")
        chunks = []

    print("Loading FAISS index...")
    try:
        index = faiss.read_index("ndpa_faiss.index")
    except Exception as e:
        print(f"Error loading FAISS index: {e}")

    print("Initializing BM25...")
    if chunks:
        tokenized_chunks = [chunk.split(" ") for chunk in chunks]
        bm25 = BM25Okapi(tokenized_chunks)

    print("Loading SentenceTransformer model...")
    embedding_model = SentenceTransformer("all-MiniLM-L6-v2")

    print("Loading TinyLlama text generator locally (this might take a minute)...")
    # Set up a generation config that keeps answers short and deterministic,
    # which also limits memory use and response time on CPU
    generation_config = GenerationConfig(
        max_new_tokens=200,
        do_sample=False
    )
    text_generator = pipeline(
        "text-generation",
        model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        device=-1  # CPU
    )
    print("Startup complete!")


def hybrid_retrieve(query, top_k=5):
    # Dense retrieval
    query_embedding = embedding_model.encode([query]).astype("float32")
    distances, dense_indices = index.search(query_embedding, top_k)
    # FAISS pads results with -1 when the index holds fewer than top_k vectors
    dense_results = [chunks[idx] for idx in dense_indices[0] if idx != -1]

    # BM25 retrieval
    tokenized_query = query.split(" ")
    bm25_scores = bm25.get_scores(tokenized_query)
    bm25_indices = np.argsort(bm25_scores)[::-1][:top_k]
    bm25_results = [chunks[idx] for idx in bm25_indices]

    # Merge both result lists, deduplicating while preserving order (dense hits first)
    merged = list(dict.fromkeys(dense_results + bm25_results))
    return merged[:top_k]


def build_prompt(query, contexts):
    context_text = "\n\n".join(contexts)
    prompt = f"""<|system|>
You are a legal assistant specialized in the Nigerian Data Protection Act 2023. Answer ONLY using the provided context. If the answer is not in the context, say: 'I could not find the answer in the provided document.'
<|user|>
Context:
{context_text}

Question: {query}
<|assistant|>
"""
    return prompt


class QueryRequest(BaseModel):
    query: str


@app.post("/ask")
def ask_question(request: QueryRequest):
    if not chunks or index is None or text_generator is None:
        return {"error": "System is not fully initialized. Check server logs."}

    query = request.query
    contexts = hybrid_retrieve(query)
    prompt = build_prompt(query, contexts)

    response = text_generator(
        prompt,
        generation_config=generation_config,
        clean_up_tokenization_spaces=False
    )
    generated_text = response[0]["generated_text"]

    # The pipeline returns the full prompt plus the completion;
    # keep only the assistant's part after the final role tag
    answer = generated_text.split("<|assistant|>\n")[-1].strip()

    return {
        "query": query,
        "answer": answer,
        "sources": contexts
    }
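
# Example call once the server is running (a sketch: it assumes this file is
# saved as main.py and served on uvicorn's default port 8000):
#
#   uvicorn main:app --reload
#   curl -X POST http://127.0.0.1:8000/ask \
#        -H "Content-Type: application/json" \
#        -d '{"query": "What is the role of the Nigeria Data Protection Commission?"}'
#
# The JSON response contains the original "query", the generated "answer",
# and the retrieved context "sources".
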
# HTML UI
HTML_CONTENT = """
Ask any question about the Nigerian Data Protection Act 2023
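"""


# Minimal route serving the UI string above (a sketch: only the subtitle text
# of the page is shown here, so this simply renders whatever HTML_CONTENT holds)
@app.get("/", response_class=HTMLResponse)
def home():
    return HTML_CONTENT


if __name__ == "__main__":
    # Convenience entry point; assumes uvicorn is installed
    # (it ships with fastapi[standard], otherwise: pip install uvicorn)
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)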