Chidiebere commited on
Commit
3851d3d
·
verified ·
1 Parent(s): bb38fa7

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +237 -0
  2. requirements.txt +11 -0
app.py ADDED
@@ -0,0 +1,237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import numpy as np
4
+ from fastapi import FastAPI, Request, Form
5
+ from fastapi.responses import HTMLResponse
6
+ from fastapi.middleware.cors import CORSMiddleware
7
+ from pydantic import BaseModel
8
+ import faiss
9
+ from sentence_transformers import SentenceTransformer
10
+ from transformers import pipeline, GenerationConfig
11
+ from rank_bm25 import BM25Okapi
12
+
13
+ app = FastAPI(title="NDPA RAG System")
14
+
15
+ # Add CORS middleware
16
+ app.add_middleware(
17
+ CORSMiddleware,
18
+ allow_origins=["*"],
19
+ allow_credentials=True,
20
+ allow_methods=["*"],
21
+ allow_headers=["*"],
22
+ )
23
+
24
+ # Global variables to hold models and data
25
+ chunks = []
26
+ index = None
27
+ embedding_model = None
28
+ bm25 = None
29
+ text_generator = None
30
+ generation_config = None
31
+
32
+ @app.on_event("startup")
33
+ def load_models_and_data():
34
+ global chunks, index, embedding_model, bm25, text_generator, generation_config
35
+
36
+ print("Loading chunks.json...")
37
+ try:
38
+ with open("chunks.json", "r", encoding="utf-8") as f:
39
+ chunks = json.load(f)
40
+ except Exception as e:
41
+ print(f"Error loading chunks.json: {e}. Make sure to run save_data.py first.")
42
+ chunks = []
43
+
44
+ print("Loading FAISS index...")
45
+ try:
46
+ index = faiss.read_index("ndpa_faiss.index")
47
+ except Exception as e:
48
+ print(f"Error loading FAISS index: {e}")
49
+
50
+ print("Initializing BM25...")
51
+ if chunks:
52
+ tokenized_chunks = [chunk.split(" ") for chunk in chunks]
53
+ bm25 = BM25Okapi(tokenized_chunks)
54
+
55
+ print("Loading SentenceTransformer model...")
56
+ embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
57
+
58
+ print("Loading TinyLlama text generator locally (this might take a minute)...")
59
+ # Setup generation config to avoid memory/timeout issues if possible
60
+ generation_config = GenerationConfig(
61
+ max_new_tokens=200,
62
+ do_sample=False
63
+ )
64
+
65
+ text_generator = pipeline(
66
+ "text-generation",
67
+ model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
68
+ device=-1 # CPU
69
+ )
70
+ print("Startup complete!")
71
+
72
+ def hybrid_retrieve(query, top_k=5):
73
+ # Dense retrieval
74
+ query_embedding = embedding_model.encode([query])
75
+ query_embedding = query_embedding.astype("float32")
76
+
77
+ distances, dense_indices = index.search(query_embedding, top_k)
78
+ dense_results = [chunks[idx] for idx in dense_indices[0]]
79
+
80
+ # BM25 retrieval
81
+ tokenized_query = query.split(" ")
82
+ bm25_scores = bm25.get_scores(tokenized_query)
83
+ bm25_indices = np.argsort(bm25_scores)[::-1][:top_k]
84
+ bm25_results = [chunks[idx] for idx in bm25_indices]
85
+
86
+ # Merged Result
87
+ merged = list(dict.fromkeys(dense_results + bm25_results))
88
+ return merged[:top_k]
89
+
90
+ def build_prompt(query, contexts):
91
+ context_text = "\n\n".join(contexts)
92
+ prompt = f"""<|system|>
93
+ You are a legal assistant specialized in the Nigerian Data Protection Act 2023. Answer ONLY using the provided context. If the answer is not in the context, say: 'I could not find the answer in the provided document.'</s>
94
+ <|user|>
95
+ Context:
96
+ {context_text}
97
+
98
+ Question:
99
+ {query}</s>
100
+ <|assistant|>
101
+ """
102
+ return prompt
103
+
104
+ class QueryRequest(BaseModel):
105
+ query: str
106
+
107
+ @app.post("/ask")
108
+ def ask_question(request: QueryRequest):
109
+ if not chunks or index is None or text_generator is None:
110
+ return {"error": "System is not fully initialized. Check server logs."}
111
+
112
+ query = request.query
113
+ contexts = hybrid_retrieve(query)
114
+ prompt = build_prompt(query, contexts)
115
+
116
+ response = text_generator(
117
+ prompt,
118
+ generation_config=generation_config,
119
+ clean_up_tokenization_spaces=False
120
+ )
121
+
122
+ generated_text = response[0]["generated_text"]
123
+
124
+ # Extract only the assistant's response part
125
+ answer = generated_text.split("<|assistant|>\n")[-1].strip()
126
+
127
+ return {
128
+ "query": query,
129
+ "answer": answer,
130
+ "sources": contexts
131
+ }
132
+
133
+ # HTML UI
134
+ HTML_CONTENT = """
135
+ <!DOCTYPE html>
136
+ <html lang="en">
137
+ <head>
138
+ <meta charset="UTF-8">
139
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
140
+ <title>NDPA RAG System</title>
141
+ <style>
142
+ body { font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; max-width: 800px; margin: 0 auto; padding: 20px; background-color: #f9fafb; color: #111827; }
143
+ h1 { color: #2563eb; text-align: center; }
144
+ .container { background-color: white; padding: 30px; border-radius: 12px; box-shadow: 0 4px 6px rgba(0,0,0,0.05); }
145
+ .chat-box { height: 400px; overflow-y: auto; border: 1px solid #e5e7eb; border-radius: 8px; padding: 15px; margin-bottom: 20px; display: flex; flex-direction: column; gap: 15px; background-color: #f3f4f6; }
146
+ .message { padding: 12px 16px; border-radius: 8px; max-width: 80%; line-height: 1.5; }
147
+ .user-message { background-color: #2563eb; color: white; align-self: flex-end; border-bottom-right-radius: 0; }
148
+ .bot-message { background-color: white; color: #1f2937; align-self: flex-start; border-bottom-left-radius: 0; border: 1px solid #e5e7eb; box-shadow: 0 1px 2px rgba(0,0,0,0.05); }
149
+ .input-group { display: flex; gap: 10px; }
150
+ input[type="text"] { flex: 1; padding: 12px; border: 1px solid #d1d5db; border-radius: 8px; outline: none; font-size: 16px; }
151
+ input[type="text"]:focus { border-color: #2563eb; }
152
+ button { padding: 12px 24px; background-color: #2563eb; color: white; border: none; border-radius: 8px; cursor: pointer; font-size: 16px; font-weight: 500; transition: background-color 0.2s; }
153
+ button:hover { background-color: #1d4ed8; }
154
+ button:disabled { background-color: #93c5fd; cursor: not-allowed; }
155
+ .loading { font-size: 14px; color: #6b7280; text-align: center; display: none; margin-top: 10px; }
156
+ </style>
157
+ </head>
158
+ <body>
159
+ <h1>NDPA 2023 Legal Assistant</h1>
160
+ <div class="container">
161
+ <p style="text-align: center; color: #4b5563; margin-bottom: 20px;">Ask any question about the Nigerian Data Protection Act 2023</p>
162
+ <div class="chat-box" id="chatBox">
163
+ <div class="message bot-message">Hello! I am an AI legal assistant trained on the Nigerian Data Protection Act (NDPA) 2023. What would you like to know?</div>
164
+ </div>
165
+ <div class="input-group">
166
+ <input type="text" id="queryInput" placeholder="E.g., What are the rights of a data subject?" onkeypress="handleKeyPress(event)">
167
+ <button id="sendBtn" onclick="askQuestion()">Ask</button>
168
+ </div>
169
+ <div class="loading" id="loadingIndicator">Generating answer... this might take a moment. (Using local TinyLlama, please be patient)</div>
170
+ </div>
171
+
172
+ <script>
173
+ async function askQuestion() {
174
+ const queryInput = document.getElementById('queryInput');
175
+ const chatBox = document.getElementById('chatBox');
176
+ const sendBtn = document.getElementById('sendBtn');
177
+ const loadingIndicator = document.getElementById('loadingIndicator');
178
+
179
+ const query = queryInput.value.trim();
180
+ if (!query) return;
181
+
182
+ // Add user message
183
+ appendMessage(query, 'user-message');
184
+ queryInput.value = '';
185
+
186
+ // Disable input and show loading
187
+ queryInput.disabled = true;
188
+ sendBtn.disabled = true;
189
+ loadingIndicator.style.display = 'block';
190
+
191
+ try {
192
+ const response = await fetch('/ask', {
193
+ method: 'POST',
194
+ headers: { 'Content-Type': 'application/json' },
195
+ body: JSON.stringify({ query: query })
196
+ });
197
+
198
+ const data = await response.json();
199
+
200
+ if (data.error) {
201
+ appendMessage("Error: " + data.error, 'bot-message');
202
+ } else {
203
+ appendMessage(data.answer, 'bot-message');
204
+ }
205
+ } catch (error) {
206
+ appendMessage("Error connecting to the server.", 'bot-message');
207
+ } finally {
208
+ // Enable input and hide loading
209
+ queryInput.disabled = false;
210
+ sendBtn.disabled = false;
211
+ loadingIndicator.style.display = 'none';
212
+ queryInput.focus();
213
+ }
214
+ }
215
+
216
+ function appendMessage(text, className) {
217
+ const chatBox = document.getElementById('chatBox');
218
+ const msgDiv = document.createElement('div');
219
+ msgDiv.className = `message ${className}`;
220
+ msgDiv.textContent = text;
221
+ chatBox.appendChild(msgDiv);
222
+ chatBox.scrollTop = chatBox.scrollHeight;
223
+ }
224
+
225
+ function handleKeyPress(event) {
226
+ if (event.key === 'Enter') {
227
+ askQuestion();
228
+ }
229
+ }
230
+ </script>
231
+ </body>
232
+ </html>
233
+ """
234
+
235
+ @app.get("/", response_class=HTMLResponse)
236
+ def read_root():
237
+ return HTML_CONTENT
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ numpy
2
+ pandas
3
+ sentence-transformers
4
+ faiss-cpu
5
+ transformers
6
+ torch
7
+ fastapi
8
+ uvicorn
9
+ rank-bm25
10
+ python-multipart
11
+ jinja2