sidharth-pm commited on
Commit
b260a50
·
verified ·
1 Parent(s): f64dbd2

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +260 -0
app.py ADDED
@@ -0,0 +1,260 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import re
4
+ import tempfile
5
+ import pandas as pd
6
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
7
+ from database import init_database, get_schema, execute_query
8
+
9
+ # Model Setup
10
+ MODEL_ID = "microsoft/tapex-large-sql-execution"
11
+ tokenizer = None
12
+ sql_pipeline = None
13
+
14
+ def load_model():
15
+ global tokenizer, sql_pipeline
16
+ print("Loading SQLCoder-7B-2 ...")
17
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
18
+ model = AutoModelForCausalLM.from_pretrained(
19
+ MODEL_ID,
20
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
21
+ device_map="auto",
22
+ trust_remote_code=True,
23
+ )
24
+ sql_pipeline = pipeline(
25
+ "text-generation",
26
+ model=model,
27
+ tokenizer=tokenizer,
28
+ max_new_tokens=512,
29
+ do_sample=False,
30
+ return_full_text=False,
31
+ pad_token_id=tokenizer.eos_token_id,
32
+ )
33
+ print("Model loaded.")
34
+
35
+
36
+ PROMPT_TEMPLATE = """### Task
37
+ Generate a SQL query to answer [QUESTION]{question}[/QUESTION]
38
+
39
+ ### Database Schema
40
+ The query will run on a database with the following schema:
41
+ {schema}
42
+
43
+ ### Answer
44
+ Given the database schema, here is the SQL query that [QUESTION]{question}[/QUESTION]
45
+ [SQL]
46
+ """
47
+
48
+ def build_prompt(question: str, schema: str) -> str:
49
+ return PROMPT_TEMPLATE.format(question=question, schema=schema)
50
+
51
+
52
+ def extract_sql(raw: str) -> str:
53
+ match = re.search(r"(SELECT[\s\S]+?);", raw, re.IGNORECASE)
54
+ if match:
55
+ return match.group(0).strip()
56
+ return raw.strip().split("[/SQL]")[0].strip()
57
+
58
+
59
+ def nl_to_sql_and_run(question: str, history: list):
60
+ if not question.strip():
61
+ yield history, "", gr.update(visible=False), gr.update(visible=False)
62
+ return
63
+
64
+ schema = get_schema()
65
+ prompt = build_prompt(question, schema)
66
+
67
+ yield history, "Generating SQL query...", gr.update(visible=False), gr.update(visible=False)
68
+
69
+ try:
70
+ output = sql_pipeline(prompt)[0]["generated_text"]
71
+ sql = extract_sql(output)
72
+ except Exception as e:
73
+ new_hist = history + [{"role": "user", "content": question},
74
+ {"role": "assistant", "content": f"Model error: {e}"}]
75
+ yield new_hist, "", gr.update(visible=False), gr.update(visible=False)
76
+ return
77
+
78
+ yield history, f"```sql\n{sql}\n```\n\nExecuting...", gr.update(visible=False), gr.update(visible=False)
79
+
80
+ try:
81
+ columns, rows = execute_query(sql)
82
+ except Exception as e:
83
+ answer = f"**Generated SQL:**\n```sql\n{sql}\n```\n\nExecution error: `{e}`"
84
+ new_hist = history + [{"role": "user", "content": question},
85
+ {"role": "assistant", "content": answer}]
86
+ yield new_hist, "", gr.update(visible=False), gr.update(visible=False)
87
+ return
88
+
89
+ if not rows:
90
+ result_md = "*(query returned no rows)*"
91
+ df = pd.DataFrame()
92
+ csv_path = None
93
+ else:
94
+ df = pd.DataFrame(rows, columns=columns)
95
+ result_md = df.to_markdown(index=False)
96
+ tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".csv", mode="w", newline="")
97
+ df.to_csv(tmp.name, index=False)
98
+ tmp.close()
99
+ csv_path = tmp.name
100
+
101
+ row_label = "rows" if len(rows) != 1 else "row"
102
+ answer = f"**Generated SQL:**\n```sql\n{sql}\n```\n\n**Results ({len(rows)} {row_label}):**\n{result_md}"
103
+ new_hist = history + [{"role": "user", "content": question},
104
+ {"role": "assistant", "content": answer}]
105
+
106
+ yield (
107
+ new_hist,
108
+ "",
109
+ gr.update(value=df, visible=bool(rows)),
110
+ gr.update(value=csv_path, visible=bool(rows)),
111
+ )
112
+
113
+
114
+ def view_schema():
115
+ return f"```sql\n{get_schema()}\n```"
116
+
117
+
118
+ CSS = """
119
+ @import url('https://fonts.googleapis.com/css2?family=Space+Mono:wght@400;700&family=DM+Sans:wght@300;400;500&display=swap');
120
+
121
+ body, .gradio-container {
122
+ background: #0d0f14 !important;
123
+ font-family: 'DM Sans', sans-serif;
124
+ color: #e2e8f0;
125
+ }
126
+
127
+ .title-block {
128
+ text-align: center;
129
+ padding: 2rem 0 1rem;
130
+ }
131
+
132
+ .title-block h1 {
133
+ font-size: 2rem;
134
+ background: linear-gradient(135deg, #38bdf8, #818cf8);
135
+ -webkit-background-clip: text;
136
+ -webkit-text-fill-color: transparent;
137
+ font-family: 'Space Mono', monospace;
138
+ margin-bottom: 0.3rem;
139
+ }
140
+
141
+ .title-block p { color: #64748b; font-size: 0.95rem; }
142
+
143
+ .badge {
144
+ display: inline-block;
145
+ background: #1e2535;
146
+ border: 1px solid #2d3748;
147
+ border-radius: 20px;
148
+ padding: 2px 12px;
149
+ font-size: 0.75rem;
150
+ color: #94a3b8;
151
+ margin: 4px;
152
+ font-family: 'Space Mono', monospace;
153
+ }
154
+ """
155
+
156
+ EXAMPLE_QUERIES = [
157
+ "Show me all employees in Engineering with salary above 120000",
158
+ "Which department has the highest total salary budget?",
159
+ "List all active projects with their budgets",
160
+ "Who are the top 3 sales performers by total amount?",
161
+ "How many employees are in each department?",
162
+ "Show me all sales made in the East region in 2024",
163
+ ]
164
+
165
+
166
+ def create_app():
167
+ init_database()
168
+
169
+ with gr.Blocks(css=CSS, title="SQLCoder Studio") as demo:
170
+
171
+ gr.HTML("""
172
+ <div class="title-block">
173
+ <h1>SQLCoder Studio</h1>
174
+ <p>Natural language to SQL to Results &nbsp;|&nbsp; Powered by defog/sqlcoder-7b-2</p>
175
+ <div style="margin-top:0.8rem">
176
+ <span class="badge">employees</span>
177
+ <span class="badge">departments</span>
178
+ <span class="badge">projects</span>
179
+ <span class="badge">sales</span>
180
+ </div>
181
+ </div>
182
+ """)
183
+
184
+ with gr.Row():
185
+ with gr.Column(scale=3):
186
+ chatbot = gr.Chatbot(
187
+ label="Conversation",
188
+ height=460,
189
+ show_label=False,
190
+ render_markdown=True,
191
+ bubble_full_width=False,
192
+ type="messages",
193
+ )
194
+
195
+ with gr.Row():
196
+ question_input = gr.Textbox(
197
+ placeholder="Ask anything about the database...",
198
+ show_label=False,
199
+ scale=5,
200
+ lines=1,
201
+ )
202
+ submit_btn = gr.Button("RUN", variant="primary", scale=1)
203
+
204
+ with gr.Row():
205
+ clear_btn = gr.Button("Clear chat", variant="secondary", size="sm")
206
+
207
+ gr.HTML("<p style='color:#475569;font-size:0.78rem;margin-top:0.5rem'>Try an example:</p>")
208
+ example_btns = []
209
+ with gr.Row(wrap=True):
210
+ for eq in EXAMPLE_QUERIES:
211
+ b = gr.Button(eq, size="sm", variant="secondary")
212
+ example_btns.append(b)
213
+
214
+ with gr.Column(scale=2):
215
+ gr.HTML("<p style='color:#94a3b8;font-size:0.85rem;font-weight:500;margin-bottom:4px'>Result Table</p>")
216
+ result_table = gr.Dataframe(
217
+ visible=False,
218
+ wrap=True,
219
+ height=220,
220
+ )
221
+ download_file = gr.File(
222
+ label="Download CSV",
223
+ visible=False,
224
+ )
225
+ gr.HTML("<p style='color:#94a3b8;font-size:0.85rem;font-weight:500;margin:1rem 0 4px'>Database Schema</p>")
226
+ gr.Markdown(value=view_schema())
227
+
228
+ status_md = gr.Markdown(visible=False)
229
+ history_state = gr.State([])
230
+
231
+ def run(question, history):
232
+ gen = nl_to_sql_and_run(question, history)
233
+ for h, status, table_update, dl_update in gen:
234
+ yield h, h, status, table_update, dl_update
235
+
236
+ submit_btn.click(
237
+ fn=run,
238
+ inputs=[question_input, history_state],
239
+ outputs=[chatbot, history_state, status_md, result_table, download_file],
240
+ )
241
+ question_input.submit(
242
+ fn=run,
243
+ inputs=[question_input, history_state],
244
+ outputs=[chatbot, history_state, status_md, result_table, download_file],
245
+ )
246
+ clear_btn.click(
247
+ fn=lambda: ([], [], "", gr.update(visible=False), gr.update(visible=False)),
248
+ outputs=[chatbot, history_state, status_md, result_table, download_file],
249
+ )
250
+
251
+ for btn, eq in zip(example_btns, EXAMPLE_QUERIES):
252
+ btn.click(fn=lambda q=eq: q, outputs=[question_input])
253
+
254
+ return demo
255
+
256
+
257
+ if __name__ == "__main__":
258
+ load_model()
259
+ app = create_app()
260
+ app.launch()