Commit ·
41169c9
0
Parent(s):
Clean deploy with LFS for all DB files
Browse files- .gitattributes +4 -0
- .gitignore +1 -0
- Dockerfile +41 -0
- Modelfile.local +62 -0
- README.md +1 -0
- chroma_db/ba17ee65-4350-4399-8b7f-ca4660b2aab0/data_level0.bin +3 -0
- chroma_db/ba17ee65-4350-4399-8b7f-ca4660b2aab0/header.bin +3 -0
- chroma_db/ba17ee65-4350-4399-8b7f-ca4660b2aab0/index_metadata.pickle +3 -0
- chroma_db/ba17ee65-4350-4399-8b7f-ca4660b2aab0/length.bin +3 -0
- chroma_db/ba17ee65-4350-4399-8b7f-ca4660b2aab0/link_lists.bin +3 -0
- chroma_db/chroma.sqlite3 +3 -0
- docker.md +7 -0
- download_models.py +63 -0
- graph.png +0 -0
- main.py +146 -0
- main_v2.py +431 -0
- requirements.txt +15 -0
- start-ollama.sh +11 -0
.gitattributes
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.gguf filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
chroma_db/chroma.sqlite3 filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
chroma_db/**/*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
chroma_db/**/*.pickle filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
unsloth.Q4_K_M.gguf
|
Dockerfile
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.11-slim
|
| 2 |
+
|
| 3 |
+
WORKDIR /app
|
| 4 |
+
ENV HF_HUB_DISABLE_PROGRESS_BARS=1 # Prevents Hugging Face from showing progress bars
|
| 5 |
+
ENV ANONYMIZED_TELEMETRY=False # Disables telemetry
|
| 6 |
+
ENV PYTHONUNBUFFERED=1 # Prevents Python from buffering output
|
| 7 |
+
|
| 8 |
+
# Install system dependencies
|
| 9 |
+
RUN apt-get update && apt-get install -y \
|
| 10 |
+
python3-dev \
|
| 11 |
+
build-essential \
|
| 12 |
+
curl \
|
| 13 |
+
dos2unix \
|
| 14 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 15 |
+
|
| 16 |
+
# Install Ollama
|
| 17 |
+
RUN curl -fsSL https://ollama.com/install.sh | sh
|
| 18 |
+
|
| 19 |
+
# Install Python dependencies
|
| 20 |
+
COPY requirements.txt .
|
| 21 |
+
RUN pip install --upgrade --no-cache-dir pip && pip install --no-cache-dir -r requirements.txt
|
| 22 |
+
|
| 23 |
+
# Set HF_HOME to ensure models are stored in a consistent location
|
| 24 |
+
ENV HF_HOME=/app/hf_cache
|
| 25 |
+
|
| 26 |
+
# Pre-download models
|
| 27 |
+
COPY download_models.py .
|
| 28 |
+
RUN python download_models.py
|
| 29 |
+
|
| 30 |
+
# Copy application files
|
| 31 |
+
COPY . .
|
| 32 |
+
COPY chroma_db /app/chroma_db
|
| 33 |
+
|
| 34 |
+
# Fix line endings and permissions for shell scripts
|
| 35 |
+
RUN dos2unix /app/start-ollama.sh && chmod +x /app/start-ollama.sh
|
| 36 |
+
|
| 37 |
+
# Expose FastAPI port
|
| 38 |
+
EXPOSE 7860
|
| 39 |
+
|
| 40 |
+
# Start Ollama and FastAPI
|
| 41 |
+
CMD ["/bin/sh", "-c", "/app/start-ollama.sh && uvicorn main_v2:app --host 0.0.0.0 --port 7860"]
|
Modelfile.local
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Modelfile generated by "ollama show"
|
| 2 |
+
# To build a new Modelfile based on this, replace FROM with:
|
| 3 |
+
# FROM hf.co/phureexd/qwen3_v2_gguf:Q4_K_M
|
| 4 |
+
# You must have .gguf file first
|
| 5 |
+
FROM /app/unsloth.Q4_K_M.gguf
|
| 6 |
+
TEMPLATE """{{- if .Messages }}
|
| 7 |
+
{{- if or .System .Tools }}<|im_start|>system
|
| 8 |
+
{{- if .System }}
|
| 9 |
+
{{ .System }}
|
| 10 |
+
{{- end }}
|
| 11 |
+
{{- if .Tools }}
|
| 12 |
+
|
| 13 |
+
# Tools
|
| 14 |
+
|
| 15 |
+
You may call one or more functions to assist with the user query.
|
| 16 |
+
|
| 17 |
+
You are provided with function signatures within <tools></tools> XML tags:
|
| 18 |
+
<tools>
|
| 19 |
+
{{- range .Tools }}
|
| 20 |
+
{"type": "function", "function": {{ .Function }}}
|
| 21 |
+
{{- end }}
|
| 22 |
+
</tools>
|
| 23 |
+
|
| 24 |
+
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
|
| 25 |
+
<tool_call>
|
| 26 |
+
{"name": <function-name>, "arguments": <args-json-object>}
|
| 27 |
+
</tool_call>
|
| 28 |
+
{{- end }}<|im_end|>
|
| 29 |
+
{{ end }}
|
| 30 |
+
{{- range $i, $_ := .Messages }}
|
| 31 |
+
{{- $last := eq (len (slice $.Messages $i)) 1 -}}
|
| 32 |
+
{{- if eq .Role "user" }}<|im_start|>user
|
| 33 |
+
{{ .Content }}<|im_end|>
|
| 34 |
+
{{ else if eq .Role "assistant" }}<|im_start|>assistant
|
| 35 |
+
{{ if .Content }}{{ .Content }}
|
| 36 |
+
{{- else if .ToolCalls }}<tool_call>
|
| 37 |
+
{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
|
| 38 |
+
{{ end }}</tool_call>
|
| 39 |
+
{{- end }}{{ if not $last }}<|im_end|>
|
| 40 |
+
{{ end }}
|
| 41 |
+
{{- else if eq .Role "tool" }}<|im_start|>user
|
| 42 |
+
<tool_response>
|
| 43 |
+
{{ .Content }}
|
| 44 |
+
</tool_response><|im_end|>
|
| 45 |
+
{{ end }}
|
| 46 |
+
{{- if and (ne .Role "assistant") $last }}<|im_start|>assistant
|
| 47 |
+
{{ end }}
|
| 48 |
+
{{- end }}
|
| 49 |
+
{{- else }}
|
| 50 |
+
{{- if .System }}<|im_start|>system
|
| 51 |
+
{{ .System }}<|im_end|>
|
| 52 |
+
{{ end }}{{ if .Prompt }}<|im_start|>user
|
| 53 |
+
{{ .Prompt }}<|im_end|>
|
| 54 |
+
{{ end }}<|im_start|>assistant
|
| 55 |
+
{{ end }}{{ .Response }}{{ if .Response }}<|im_end|>{{ end }}"""
|
| 56 |
+
PARAMETER temperature 0.7
|
| 57 |
+
PARAMETER top_p 0.8
|
| 58 |
+
PARAMETER top_k 20
|
| 59 |
+
PARAMETER num_predict 512
|
| 60 |
+
PARAMETER repeat_penalty 1
|
| 61 |
+
PARAMETER stop <|im_start|>
|
| 62 |
+
PARAMETER stop <|im_end|>
|
README.md
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"# qwen-healthcare-assistant"
|
chroma_db/ba17ee65-4350-4399-8b7f-ca4660b2aab0/data_level0.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2475b61e969a2d81d22ce88d8c6e9f26f63f6bbf7bc5d834a43f414c368a34b1
|
| 3 |
+
size 8472000
|
chroma_db/ba17ee65-4350-4399-8b7f-ca4660b2aab0/header.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:01f7190252d8675a30dba956d9378f41be948e807392883f5a95ba08085e0efa
|
| 3 |
+
size 100
|
chroma_db/ba17ee65-4350-4399-8b7f-ca4660b2aab0/index_metadata.pickle
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:06f1337839fcfb42cb8148849ea14ad9148b1d24da8fbb177b80ea012228263d
|
| 3 |
+
size 113967
|
chroma_db/ba17ee65-4350-4399-8b7f-ca4660b2aab0/length.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:07995695399aa84304fb2a27cae4b5640590e1a6d85a43b2631c958c4c17ff15
|
| 3 |
+
size 8000
|
chroma_db/ba17ee65-4350-4399-8b7f-ca4660b2aab0/link_lists.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:732e4883582480d065df2f20149b50ad1ebbbe55c48725c909fde893fc64aabf
|
| 3 |
+
size 16976
|
chroma_db/chroma.sqlite3
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ef6c315aadf349adb83c6f4b522cce8aafdf726261e6be1ca5e21809409df2fd
|
| 3 |
+
size 50151424
|
docker.md
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
```bash
|
| 2 |
+
docker build -t nlp-app .
|
| 3 |
+
```
|
| 4 |
+
|
| 5 |
+
```bash
|
| 6 |
+
docker run -p 8000:8000 --name nlp-container nlp-app
|
| 7 |
+
```
|
download_models.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import time
|
| 3 |
+
from huggingface_hub import snapshot_download, hf_hub_download
|
| 4 |
+
from sentence_transformers import SentenceTransformer, CrossEncoder
|
| 5 |
+
|
| 6 |
+
# Define models to download
|
| 7 |
+
EMBEDDING_MODEL_NAME = "BAAI/bge-m3"
|
| 8 |
+
CROSS_ENCODER_MODEL_NAME = "BAAI/bge-reranker-v2-m3"
|
| 9 |
+
|
| 10 |
+
def download_with_retries(repo_id, retries=5, delay=10):
|
| 11 |
+
"""Downloads a model with retry logic."""
|
| 12 |
+
for i in range(retries):
|
| 13 |
+
try:
|
| 14 |
+
print(f"Downloading {repo_id} (Attempt {i+1}/{retries})...")
|
| 15 |
+
# resume_download=True ensures we don't start from scratch if interrupted
|
| 16 |
+
snapshot_download(repo_id=repo_id, resume_download=True)
|
| 17 |
+
print(f"Successfully downloaded {repo_id}")
|
| 18 |
+
return
|
| 19 |
+
except Exception as e:
|
| 20 |
+
print(f"Error downloading {repo_id}: {e}")
|
| 21 |
+
if i < retries - 1:
|
| 22 |
+
print(f"Retrying in {delay} seconds...")
|
| 23 |
+
time.sleep(delay)
|
| 24 |
+
else:
|
| 25 |
+
print(f"Failed to download {repo_id} after {retries} attempts.")
|
| 26 |
+
raise e
|
| 27 |
+
|
| 28 |
+
def download_models():
|
| 29 |
+
print(f"Downloading embedding model: {EMBEDDING_MODEL_NAME}")
|
| 30 |
+
download_with_retries(EMBEDDING_MODEL_NAME)
|
| 31 |
+
|
| 32 |
+
# Also initialize SentenceTransformer to ensure it caches correctly for the library
|
| 33 |
+
print(f"Initializing SentenceTransformer for {EMBEDDING_MODEL_NAME} to populate cache...")
|
| 34 |
+
try:
|
| 35 |
+
SentenceTransformer(EMBEDDING_MODEL_NAME)
|
| 36 |
+
except Exception as e:
|
| 37 |
+
print(f"Warning: Failed to initialize SentenceTransformer: {e}")
|
| 38 |
+
|
| 39 |
+
print(f"Downloading cross-encoder model: {CROSS_ENCODER_MODEL_NAME}")
|
| 40 |
+
download_with_retries(CROSS_ENCODER_MODEL_NAME)
|
| 41 |
+
|
| 42 |
+
# Initialize CrossEncoder to populate cache
|
| 43 |
+
print(f"Initializing CrossEncoder for {CROSS_ENCODER_MODEL_NAME} to populate cache...")
|
| 44 |
+
try:
|
| 45 |
+
CrossEncoder(CROSS_ENCODER_MODEL_NAME)
|
| 46 |
+
except Exception as e:
|
| 47 |
+
print(f"Warning: Failed to initialize CrossEncoder: {e}")
|
| 48 |
+
|
| 49 |
+
# Download GGUF model
|
| 50 |
+
llm_repo_id = "phureexd/qwen3_v2_gguf"
|
| 51 |
+
llm_filename = "unsloth.Q4_K_M.gguf"
|
| 52 |
+
print(f"Downloading LLM: {llm_filename} from {llm_repo_id}")
|
| 53 |
+
try:
|
| 54 |
+
hf_hub_download(repo_id=llm_repo_id, filename=llm_filename, local_dir=".", local_dir_use_symlinks=False)
|
| 55 |
+
print(f"Successfully downloaded {llm_filename}")
|
| 56 |
+
except Exception as e:
|
| 57 |
+
print(f"Error downloading LLM: {e}")
|
| 58 |
+
raise e
|
| 59 |
+
|
| 60 |
+
print("All models downloaded successfully.")
|
| 61 |
+
|
| 62 |
+
if __name__ == "__main__":
|
| 63 |
+
download_models()
|
graph.png
ADDED
|
main.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
from unsloth import FastLanguageModel
|
| 7 |
+
|
| 8 |
+
print()
|
| 9 |
+
import time
|
| 10 |
+
from threading import Thread
|
| 11 |
+
|
| 12 |
+
import uvicorn
|
| 13 |
+
from fastapi import FastAPI
|
| 14 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 15 |
+
from fastapi.responses import StreamingResponse
|
| 16 |
+
from transformers import TextIteratorStreamer
|
| 17 |
+
|
| 18 |
+
# Initialize FastAPI app
|
| 19 |
+
app = FastAPI()
|
| 20 |
+
|
| 21 |
+
# Enable CORS to allow frontend requests
|
| 22 |
+
app.add_middleware(
|
| 23 |
+
CORSMiddleware,
|
| 24 |
+
allow_origins=["*"],
|
| 25 |
+
allow_credentials=True,
|
| 26 |
+
allow_methods=["*"],
|
| 27 |
+
allow_headers=["*"],
|
| 28 |
+
)
|
| 29 |
+
###########################################################################
|
| 30 |
+
from langchain_huggingface import HuggingFaceEmbeddings
|
| 31 |
+
|
| 32 |
+
embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-m3")
|
| 33 |
+
|
| 34 |
+
from langchain_chroma import Chroma
|
| 35 |
+
|
| 36 |
+
vector_store = Chroma(
|
| 37 |
+
embedding_function=embeddings,
|
| 38 |
+
persist_directory="C:/Users/LENOVO/Downloads/chroma_langchain_db_3", # Where to save data locally, remove if not necessary
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
# Model configuration
|
| 43 |
+
# model_name = "phureexd/qwen_model"
|
| 44 |
+
model_name = "unsloth/Qwen3-1.7B-unsloth-bnb-4bit"
|
| 45 |
+
max_seq_length = 2048
|
| 46 |
+
dtype = None
|
| 47 |
+
load_in_4bit = True
|
| 48 |
+
|
| 49 |
+
# Load model and tokenizer
|
| 50 |
+
model, tokenizer = FastLanguageModel.from_pretrained(
|
| 51 |
+
model_name=model_name,
|
| 52 |
+
max_seq_length=max_seq_length,
|
| 53 |
+
dtype=dtype,
|
| 54 |
+
load_in_4bit=load_in_4bit,
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
FastLanguageModel.for_inference(model)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
# Step 3: Prepare RAG messages
|
| 61 |
+
def prepare_rag_messages(messages, vector_store, k=2):
|
| 62 |
+
query = next(msg["content"] for msg in reversed(messages) if msg["role"] == "user")
|
| 63 |
+
print("this is query:\n", type(query), query)
|
| 64 |
+
docs = vector_store.similarity_search(query, k=k)
|
| 65 |
+
context = "\n\n".join(
|
| 66 |
+
f"Source: {doc.metadata['source']}\nContent: {doc.page_content}" for doc in docs
|
| 67 |
+
)
|
| 68 |
+
print("this is context:\n", context)
|
| 69 |
+
system_message = messages[0]["content"] + "\n\nContext:\n" + context
|
| 70 |
+
rag_messages = [
|
| 71 |
+
{"role": "system", "content": system_message},
|
| 72 |
+
{"role": "user", "content": query},
|
| 73 |
+
]
|
| 74 |
+
return rag_messages
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
# Endpoint to generate response using GET
|
| 78 |
+
@app.get("/generate")
|
| 79 |
+
async def generate(query: str):
|
| 80 |
+
|
| 81 |
+
messages = [
|
| 82 |
+
{
|
| 83 |
+
"role": "system",
|
| 84 |
+
"content": f"""You are a medical professional assistant. You will receive user queries along with relevant context retrieved via RAG.
|
| 85 |
+
Use the context if it is relevant. If not, rely on your own medical knowledge. If unsure, clearly state so.
|
| 86 |
+
Always respond in the same language used in the user's query. Keep responses clear, concise, and professional.
|
| 87 |
+
|
| 88 |
+
Extremely important: Answer in the same language as the user query.
|
| 89 |
+
""",
|
| 90 |
+
},
|
| 91 |
+
{"role": "user", "content": f"{query}"},
|
| 92 |
+
]
|
| 93 |
+
|
| 94 |
+
rag_messages = prepare_rag_messages(messages, vector_store, k=2)
|
| 95 |
+
|
| 96 |
+
text = tokenizer.apply_chat_template(
|
| 97 |
+
rag_messages,
|
| 98 |
+
tokenize=False,
|
| 99 |
+
add_generation_prompt=True, # Must add for generation
|
| 100 |
+
enable_thinking=False,
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
inputs = tokenizer(text, return_tensors="pt").to(device=model.device)
|
| 104 |
+
|
| 105 |
+
def stream_response():
|
| 106 |
+
streamer = TextIteratorStreamer(
|
| 107 |
+
tokenizer, skip_prompt=True, skip_special_tokens=True
|
| 108 |
+
)
|
| 109 |
+
# * the recommended settings for reasoning inference are temperature = 0.6, top_p = 0.95, top_k = 20
|
| 110 |
+
# * For normal chat based inference, temperature = 0.7, top_p = 0.8, top_k = 20
|
| 111 |
+
# generate_kwargs = dict(
|
| 112 |
+
# **inputs,
|
| 113 |
+
# max_new_tokens=2048,
|
| 114 |
+
# do_sample=True,
|
| 115 |
+
# temperature=0.6,
|
| 116 |
+
# top_p=0.95,
|
| 117 |
+
# top_k=20,
|
| 118 |
+
# streamer=streamer,
|
| 119 |
+
# )
|
| 120 |
+
generate_kwargs = dict(
|
| 121 |
+
**inputs,
|
| 122 |
+
max_new_tokens=1024,
|
| 123 |
+
do_sample=True,
|
| 124 |
+
temperature=0.7,
|
| 125 |
+
top_p=0.8,
|
| 126 |
+
top_k=20,
|
| 127 |
+
streamer=streamer,
|
| 128 |
+
)
|
| 129 |
+
thread = Thread(target=model.generate, kwargs=generate_kwargs)
|
| 130 |
+
# thread.daemon = True
|
| 131 |
+
thread.start()
|
| 132 |
+
|
| 133 |
+
for new_text in streamer:
|
| 134 |
+
yield f"data: {new_text}\n\n"
|
| 135 |
+
# time.sleep(0.01)
|
| 136 |
+
|
| 137 |
+
return StreamingResponse(
|
| 138 |
+
stream_response(),
|
| 139 |
+
media_type="text/event-stream",
|
| 140 |
+
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
# Run the server
|
| 145 |
+
if __name__ == "__main__":
|
| 146 |
+
uvicorn.run(app, host="0.0.0.0", port=8000)
|
main_v2.py
ADDED
|
@@ -0,0 +1,431 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
# Suppress TensorFlow oneDNN optimization messages if not needed
|
| 4 |
+
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
|
| 5 |
+
# Disable ChromaDB telemetry to prevent log errors
|
| 6 |
+
os.environ["ANONYMIZED_TELEMETRY"] = "False"
|
| 7 |
+
import uvicorn
|
| 8 |
+
from fastapi import FastAPI
|
| 9 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 10 |
+
from fastapi.responses import StreamingResponse
|
| 11 |
+
from langchain.retrievers import ContextualCompressionRetriever
|
| 12 |
+
from langchain.retrievers.document_compressors import CrossEncoderReranker
|
| 13 |
+
from langchain.tools.retriever import create_retriever_tool
|
| 14 |
+
from langchain_chroma import Chroma
|
| 15 |
+
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
|
| 16 |
+
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
| 17 |
+
from langchain_huggingface import HuggingFaceEmbeddings
|
| 18 |
+
from langchain_ollama import ChatOllama
|
| 19 |
+
from langgraph.checkpoint.memory import MemorySaver
|
| 20 |
+
from langgraph.graph import END, MessagesState, StateGraph
|
| 21 |
+
from langgraph.prebuilt import ToolNode, tools_condition
|
| 22 |
+
|
| 23 |
+
# Set the device for HuggingFace models
|
| 24 |
+
device = "cpu"
|
| 25 |
+
|
| 26 |
+
# --- Configuration Constants ---
|
| 27 |
+
APP_HOST = "0.0.0.0"
|
| 28 |
+
APP_PORT = 7860
|
| 29 |
+
|
| 30 |
+
THREAD_ID = "global_health_chat_session" # Unique ID for the chat session
|
| 31 |
+
|
| 32 |
+
# Models and Paths
|
| 33 |
+
EMBEDDING_MODEL_NAME = "BAAI/bge-m3"
|
| 34 |
+
CROSS_ENCODER_MODEL_NAME = "BAAI/bge-reranker-v2-m3"
|
| 35 |
+
LLM_MODEL_NAME = "custom-model" # Replace with your actual model, e.g., "hf.co/phureexd/qwen3_v2_gguf:Q4_K_M"
|
| 36 |
+
VECTOR_DB_PATH = "/app/chroma_db" if os.path.exists("/app/chroma_db") else "chroma_db"
|
| 37 |
+
|
| 38 |
+
# LLM Parameters
|
| 39 |
+
LLM_TEMPERATURE = 0.7
|
| 40 |
+
LLM_TOP_P = 0.8
|
| 41 |
+
LLM_TOP_K = 20
|
| 42 |
+
LLM_NUM_PREDICT = 512
|
| 43 |
+
|
| 44 |
+
# Retriever Parameters
|
| 45 |
+
RETRIEVER_SEARCH_K = 6 # Number of documents to fetch initially
|
| 46 |
+
RERANKER_TOP_N = 3 # Number of documents after reranking
|
| 47 |
+
|
| 48 |
+
# --- System Prompts ---
|
| 49 |
+
|
| 50 |
+
INITIAL_SYSTEM_MESSAGE = SystemMessage(
|
| 51 |
+
content="""
|
| 52 |
+
You are a health assistant designed to answer questions related to health, wellness, nutrition, exercise, symptoms, diseases, prevention, treatment, mental health, and medical advice. This explicitly includes general statements about feeling unwell or sick (e.g., "I'm sick", "I don't feel good"). For ANY query that falls into these categories, you MUST use the retrieve_health_info tool to fetch relevant information from the database before providing an answer. This ensures your responses are accurate and based on trusted sources. Do not answer health-related questions directly without using the tool, even if you think you know the answer.
|
| 53 |
+
|
| 54 |
+
If the query is clearly unrelated to health (e.g., general knowledge questions), you can answer directly without the tool.
|
| 55 |
+
|
| 56 |
+
**Important Guidelines:**
|
| 57 |
+
- If the query mentions or implies health, feeling unwell, sickness, treatment, symptoms, diseases, nutrition, exercise, mental health, or wellness, use the tool.
|
| 58 |
+
- Even if the query is only slightly related to health, or is a general statement about feeling unwell, use the tool to provide an informed answer.
|
| 59 |
+
- Always respond in the same language as the user's query.
|
| 60 |
+
- When in doubt, err on the side of using the tool.
|
| 61 |
+
|
| 62 |
+
**Examples:**
|
| 63 |
+
|
| 64 |
+
1. **Health-Related (Use Tool):**
|
| 65 |
+
- User: "What are the symptoms of diabetes?"
|
| 66 |
+
- Assistant: [Uses retrieve_health_info tool] "Common symptoms of diabetes include frequent urination, excessive thirst, and fatigue."
|
| 67 |
+
|
| 68 |
+
2. **Slightly Health-Related (Use Tool):**
|
| 69 |
+
- User: "Is it okay to exercise when I have a cold?"
|
| 70 |
+
- Assistant: [Uses retrieve_health_info tool] "Light exercise might be okay, but rest if you have a fever."
|
| 71 |
+
|
| 72 |
+
3. **General Sickness Statement (Use Tool):**
|
| 73 |
+
- User: "I'm sick."
|
| 74 |
+
- Assistant: [Uses retrieve_health_info tool] "I'm sorry to hear you're not feeling well. Common advice includes resting and staying hydrated. If you have specific symptoms, I can try to provide more information."
|
| 75 |
+
|
| 76 |
+
4. **Non-Health-Related (No Tool):**
|
| 77 |
+
- User: "What is the capital of France?"
|
| 78 |
+
- Assistant: "The capital of France is Paris."
|
| 79 |
+
|
| 80 |
+
5. **Health-Related in Thai (Use Tool):**
|
| 81 |
+
- User: "อาการของโรคเบาหวานมีอะไรบ้าง?"
|
| 82 |
+
- Assistant: [Uses retrieve_health_info tool] "อาการทั่วไปของโรคเบาหวาน ได้แก่ ปัสสาวะบ่อย กระหายน้ำมาก และอ่อนเพลีย"
|
| 83 |
+
|
| 84 |
+
6. **Non-Health-Related in Thai (No Tool):**
|
| 85 |
+
- User: "เมืองหลวงของฝรั่งเศสคืออะไร?"
|
| 86 |
+
- Assistant: "เมืองหลวงของฝรั่งเศสคือปารีส"
|
| 87 |
+
/no_think
|
| 88 |
+
"""
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
RAG_SYSTEM_PROMPT_TEMPLATE = """
|
| 92 |
+
You are a health assistant for question-answering tasks.
|
| 93 |
+
Use the following pieces of retrieved documents to answer the question.
|
| 94 |
+
If you don't know the answer, say that you don't know.
|
| 95 |
+
Keep the answer concise and accurate.
|
| 96 |
+
|
| 97 |
+
**Extremely important: Answer in the same language as the user query.**
|
| 98 |
+
|
| 99 |
+
### Retrieved documents (if applicable):
|
| 100 |
+
{docs_content}
|
| 101 |
+
|
| 102 |
+
### Examples of the language model's responses:
|
| 103 |
+
**Example 1 (English):**
|
| 104 |
+
User: I feel a bit tired, what could it be?
|
| 105 |
+
Assistant: Fatigue can be caused by lack of sleep, stress, or dehydration. Ensure you get 7-8 hours of sleep and stay hydrated.
|
| 106 |
+
|
| 107 |
+
**Example 2 (English):**
|
| 108 |
+
User: Does coffee affect my health?
|
| 109 |
+
Assistant: Moderate coffee consumption can improve alertness but may cause insomnia or anxiety if overconsumed.
|
| 110 |
+
|
| 111 |
+
**Example 3 (Thai):**
|
| 112 |
+
User: ฉันรู้สึกเหนื่อยเล็กน้อย เกิดจากอะไรได้บ้าง?
|
| 113 |
+
Assistant: อาการเหนื่อยอาจเกิดจากการนอนหลับไม่เพียงพอ ความเครียด หรือภาวะขาดน้ำ ควรนอนหลับ 7-8 ชั่วโมงและดื่มน้ำให้เพียงพอ
|
| 114 |
+
/no_think
|
| 115 |
+
"""
|
| 116 |
+
|
| 117 |
+
# --- Initialization of Langchain Components ---
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def init_embeddings(model_name: str):
|
| 121 |
+
"""Initializes HuggingFace embeddings."""
|
| 122 |
+
return HuggingFaceEmbeddings(model_name=model_name)
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def init_vector_store(embedding_function, persist_directory: str):
|
| 126 |
+
"""Initializes Chroma vector store."""
|
| 127 |
+
return Chroma(
|
| 128 |
+
embedding_function=embedding_function,
|
| 129 |
+
persist_directory=persist_directory,
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def init_llm(
|
| 134 |
+
model_name: str, temperature: float, top_p: float, top_k: int, num_predict: int
|
| 135 |
+
):
|
| 136 |
+
"""Initializes ChatOllama LLM."""
|
| 137 |
+
return ChatOllama(
|
| 138 |
+
model=model_name,
|
| 139 |
+
temperature=temperature,
|
| 140 |
+
top_p=top_p,
|
| 141 |
+
top_k=top_k,
|
| 142 |
+
num_predict=num_predict,
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def init_retriever_tool(
|
| 147 |
+
vector_store_instance,
|
| 148 |
+
cross_encoder_model_name: str,
|
| 149 |
+
base_retriever_k: int,
|
| 150 |
+
reranker_top_n: int,
|
| 151 |
+
):
|
| 152 |
+
"""Initializes the retriever tool with reranking."""
|
| 153 |
+
base_retriever = vector_store_instance.as_retriever(
|
| 154 |
+
search_kwargs={"k": base_retriever_k}
|
| 155 |
+
)
|
| 156 |
+
|
| 157 |
+
cross_encoder = HuggingFaceCrossEncoder(
|
| 158 |
+
model_name=cross_encoder_model_name,
|
| 159 |
+
model_kwargs={"device": device}, # Specify device if needed, e.g., "cuda"
|
| 160 |
+
)
|
| 161 |
+
reranker = CrossEncoderReranker(model=cross_encoder, top_n=reranker_top_n)
|
| 162 |
+
|
| 163 |
+
compression_retriever = ContextualCompressionRetriever(
|
| 164 |
+
base_compressor=reranker,
|
| 165 |
+
base_retriever=base_retriever,
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
return create_retriever_tool(
|
| 169 |
+
retriever=compression_retriever,
|
| 170 |
+
name="retrieve_health_info",
|
| 171 |
+
description=(
|
| 172 |
+
"Use this tool to retrieve relevant documents from the query related to health, "
|
| 173 |
+
"wellness, nutrition, exercise, symptoms, diseases, treatment, prevention, "
|
| 174 |
+
"mental health, or medical advice information from the database. "
|
| 175 |
+
"Even if the query is slightly related. "
|
| 176 |
+
f"Return the top {reranker_top_n} most relevant documents."
|
| 177 |
+
),
|
| 178 |
+
response_format="content_and_artifact", # Ensures artifact contains Document objects
|
| 179 |
+
)
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
# Initialize components
|
| 183 |
+
print("Initializing Embeddings...")
|
| 184 |
+
embeddings = init_embeddings(EMBEDDING_MODEL_NAME)
|
| 185 |
+
print("Embeddings Initialized.")
|
| 186 |
+
|
| 187 |
+
print("Initializing Vector Store...")
|
| 188 |
+
vector_store = init_vector_store(embeddings, VECTOR_DB_PATH)
|
| 189 |
+
print("Vector Store Initialized.")
|
| 190 |
+
|
| 191 |
+
print("Initializing LLM...")
|
| 192 |
+
llm = init_llm(LLM_MODEL_NAME, LLM_TEMPERATURE, LLM_TOP_P, LLM_TOP_K, LLM_NUM_PREDICT)
|
| 193 |
+
print("LLM Initialized.")
|
| 194 |
+
|
| 195 |
+
print("Initializing Retriever Tool...")
|
| 196 |
+
retriever_tool = init_retriever_tool(
|
| 197 |
+
vector_store, CROSS_ENCODER_MODEL_NAME, RETRIEVER_SEARCH_K, RERANKER_TOP_N
|
| 198 |
+
)
|
| 199 |
+
print("Retriever Tool Initialized.")
|
| 200 |
+
|
| 201 |
+
# --- LangGraph Node Definitions ---
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
async def query_or_respond_node_logic(state: MessagesState):
|
| 205 |
+
"""
|
| 206 |
+
Node function: Decides whether to call a tool for retrieval or respond directly.
|
| 207 |
+
Binds the retriever_tool to the LLM for this decision.
|
| 208 |
+
"""
|
| 209 |
+
response = await llm.bind_tools([retriever_tool]).ainvoke(state["messages"])
|
| 210 |
+
return {"messages": [response]}
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
async def generate_rag_response_node_logic(state: MessagesState):
|
| 214 |
+
"""
|
| 215 |
+
Node function: Generates a response using retrieved documents (if any).
|
| 216 |
+
"""
|
| 217 |
+
# Extract the most recent contiguous block of tool messages
|
| 218 |
+
recent_tool_messages = []
|
| 219 |
+
for message in reversed(state["messages"]):
|
| 220 |
+
if message.type == "tool": # or isinstance(message, ToolMessage)
|
| 221 |
+
recent_tool_messages.append(message)
|
| 222 |
+
else:
|
| 223 |
+
break
|
| 224 |
+
tool_messages = recent_tool_messages[::-1]
|
| 225 |
+
|
| 226 |
+
# Format retrieved document content for the prompt
|
| 227 |
+
doc_strings = []
|
| 228 |
+
for tool_msg in tool_messages:
|
| 229 |
+
# Ensure artifact is a list of Langchain Document objects
|
| 230 |
+
if hasattr(tool_msg, "artifact") and isinstance(tool_msg.artifact, list):
|
| 231 |
+
for doc in tool_msg.artifact:
|
| 232 |
+
if hasattr(doc, "page_content") and hasattr(
|
| 233 |
+
doc, "metadata"
|
| 234 |
+
): # Document structure check
|
| 235 |
+
source = doc.metadata.get("source", "Unknown source")
|
| 236 |
+
content = doc.page_content
|
| 237 |
+
doc_strings.append(f"Source: {source}\nContent: {content}")
|
| 238 |
+
|
| 239 |
+
docs_content = (
|
| 240 |
+
"\n\n".join(doc_strings)
|
| 241 |
+
if doc_strings
|
| 242 |
+
else "No relevant documents were found to answer the current question."
|
| 243 |
+
)
|
| 244 |
+
|
| 245 |
+
# Prepare messages for the generation LLM call (history + new system prompt with docs)
|
| 246 |
+
# Include human messages, initial system messages, and AI responses (not tool calls)
|
| 247 |
+
conversation_history_for_llm = [
|
| 248 |
+
msg
|
| 249 |
+
for msg in state["messages"]
|
| 250 |
+
if msg.type in ("human", "system") or (msg.type == "ai" and not msg.tool_calls)
|
| 251 |
+
]
|
| 252 |
+
|
| 253 |
+
# Construct the system prompt with retrieved documents
|
| 254 |
+
current_system_prompt_content = RAG_SYSTEM_PROMPT_TEMPLATE.format(
|
| 255 |
+
docs_content=docs_content
|
| 256 |
+
)
|
| 257 |
+
|
| 258 |
+
prompt_for_generation = [
|
| 259 |
+
SystemMessage(content=current_system_prompt_content)
|
| 260 |
+
] + conversation_history_for_llm
|
| 261 |
+
|
| 262 |
+
response = await llm.ainvoke(prompt_for_generation)
|
| 263 |
+
return {"messages": [response]}
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
# --- LangGraph Graph Construction ---
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
def create_lang_graph(checkpointer_instance):
|
| 270 |
+
"""Creates and compiles the LangGraph."""
|
| 271 |
+
graph_builder = StateGraph(MessagesState)
|
| 272 |
+
|
| 273 |
+
# Define nodes
|
| 274 |
+
graph_builder.add_node("query_or_respond", query_or_respond_node_logic)
|
| 275 |
+
tools_node = ToolNode([retriever_tool]) # Define tool execution node
|
| 276 |
+
graph_builder.add_node("tools", tools_node)
|
| 277 |
+
graph_builder.add_node("generate_rag_response", generate_rag_response_node_logic)
|
| 278 |
+
|
| 279 |
+
# Define edges
|
| 280 |
+
graph_builder.set_entry_point("query_or_respond")
|
| 281 |
+
graph_builder.add_conditional_edges(
|
| 282 |
+
"query_or_respond",
|
| 283 |
+
tools_condition, # Prebuilt condition to check for tool calls
|
| 284 |
+
{END: END, "tools": "tools"},
|
| 285 |
+
)
|
| 286 |
+
graph_builder.add_edge("tools", "generate_rag_response")
|
| 287 |
+
graph_builder.add_edge("generate_rag_response", END)
|
| 288 |
+
|
| 289 |
+
return graph_builder.compile(checkpointer=checkpointer_instance)
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
# Initialize checkpointer and compile graph
|
| 293 |
+
memory_saver = MemorySaver()
|
| 294 |
+
graph = create_lang_graph(memory_saver)
|
| 295 |
+
|
| 296 |
+
# Optional: Save graph visualization
|
| 297 |
+
# try:
|
| 298 |
+
# graph.get_graph().draw_mermaid_png(output_file_path="graph.png")
|
| 299 |
+
# print("Graph visualization saved to graph.png")
|
| 300 |
+
# except Exception as e:
|
| 301 |
+
# print(f"Could not save graph visualization: {e}")
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
# --- FastAPI Application Setup ---
|
| 305 |
+
app = FastAPI(
|
| 306 |
+
title="Health Assistant API",
|
| 307 |
+
description="API for a health assistant using a retrieval-augmented generation approach.",
|
| 308 |
+
)
|
| 309 |
+
|
| 310 |
+
app.add_middleware(
|
| 311 |
+
CORSMiddleware,
|
| 312 |
+
allow_origins=["*"], # Allow all origins for simplicity; restrict in production
|
| 313 |
+
allow_credentials=True,
|
| 314 |
+
allow_methods=["*"],
|
| 315 |
+
allow_headers=["*"],
|
| 316 |
+
)
|
| 317 |
+
|
| 318 |
+
# --- FastAPI Endpoints ---
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
@app.get("/generate", summary="Generate a response from the health assistant model")
|
| 322 |
+
async def generate_endpoint(
|
| 323 |
+
query: str,
|
| 324 |
+
useRAG: bool = False,
|
| 325 |
+
thread_id: str = THREAD_ID,
|
| 326 |
+
):
|
| 327 |
+
"""
|
| 328 |
+
Handles a user query, streams back the assistant's responses.
|
| 329 |
+
- `query`: The user's question.
|
| 330 |
+
- `useRAG`: If true, forces the use of the retrieval tool via a system message.
|
| 331 |
+
- `thread_id`: Unique identifier for the conversation session.
|
| 332 |
+
"""
|
| 333 |
+
print(f"Received query: '{query}', Force RAG: {useRAG}, Thread ID: {thread_id}")
|
| 334 |
+
|
| 335 |
+
config = {"configurable": {"thread_id": thread_id}}
|
| 336 |
+
|
| 337 |
+
# Prepare input messages for the graph
|
| 338 |
+
input_messages = []
|
| 339 |
+
current_checkpoint_tuple = memory_saver.get_tuple(config) # Check if history exists
|
| 340 |
+
|
| 341 |
+
if current_checkpoint_tuple is None: # No history, it's a new or cleared session
|
| 342 |
+
input_messages.append(INITIAL_SYSTEM_MESSAGE)
|
| 343 |
+
print("Starting new conversation: Added initial system message.")
|
| 344 |
+
|
| 345 |
+
if useRAG:
|
| 346 |
+
# This message is added to strongly encourage tool use for the current query,
|
| 347 |
+
# supplementing the INITIAL_SYSTEM_MESSAGE.
|
| 348 |
+
input_messages.append(
|
| 349 |
+
SystemMessage(
|
| 350 |
+
content="You MUST use the retrieve_health_info tool for this query even if it seems unrelated to health."
|
| 351 |
+
)
|
| 352 |
+
)
|
| 353 |
+
print("Forcing RAG for this query with an additional system message.")
|
| 354 |
+
|
| 355 |
+
input_messages.append(HumanMessage(content=query))
|
| 356 |
+
graph_input = {"messages": input_messages}
|
| 357 |
+
|
| 358 |
+
async def stream_response_events():
|
| 359 |
+
# graph.stream with stream_mode="messages" yields the ENTIRE list of messages
|
| 360 |
+
# in the current state each time a node completes.
|
| 361 |
+
async for messages_in_state in graph.astream(
|
| 362 |
+
graph_input, config, stream_mode="messages"
|
| 363 |
+
):
|
| 364 |
+
if not messages_in_state:
|
| 365 |
+
continue
|
| 366 |
+
|
| 367 |
+
# Get the current message from the state
|
| 368 |
+
latest_message = messages_in_state[0]
|
| 369 |
+
|
| 370 |
+
if isinstance(latest_message, AIMessage):
|
| 371 |
+
if latest_message.content: # Final textual response
|
| 372 |
+
# print(
|
| 373 |
+
# f"Streaming AI content: {latest_message.content}"
|
| 374 |
+
# )
|
| 375 |
+
yield f"data: {latest_message.content}\n\n"
|
| 376 |
+
elif latest_message.tool_calls: # AI message requesting a tool call
|
| 377 |
+
print(f"AI requested Tool call: {latest_message.tool_calls}")
|
| 378 |
+
# You might want to send a status to the client, e.g., "Thinking..." or "Retrieving info..."
|
| 379 |
+
# yield f"event: tool_call\ndata: {json.dumps(latest_message.tool_calls)}\n\n"
|
| 380 |
+
elif isinstance(
|
| 381 |
+
latest_message, ToolMessage
|
| 382 |
+
): # Message containing tool execution results
|
| 383 |
+
if latest_message.name == "retrieve_health_info" and hasattr(
|
| 384 |
+
latest_message, "artifact"
|
| 385 |
+
):
|
| 386 |
+
print(f"Tool '{latest_message.name}' executed. Artifact content:")
|
| 387 |
+
if latest_message.artifact and isinstance(
|
| 388 |
+
latest_message.artifact, list
|
| 389 |
+
):
|
| 390 |
+
# print every document in the artifact
|
| 391 |
+
source_list = set()
|
| 392 |
+
for doc in latest_message.artifact:
|
| 393 |
+
source = doc.metadata.get("source", "Unknown source")
|
| 394 |
+
|
| 395 |
+
if source != "Unknown source":
|
| 396 |
+
source_list.add(source)
|
| 397 |
+
|
| 398 |
+
print(f" Source: {source}\n Content: {doc.page_content}")
|
| 399 |
+
yield f"data: **Source:**{str(source_list)}\n\n"
|
| 400 |
+
|
| 401 |
+
return StreamingResponse(
|
| 402 |
+
stream_response_events(),
|
| 403 |
+
media_type="text/event-stream",
|
| 404 |
+
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
|
| 405 |
+
)
|
| 406 |
+
|
| 407 |
+
|
| 408 |
+
@app.get("/clear", summary="Clear conversation history")
|
| 409 |
+
async def clear_conversation_endpoint(thread_id: str = THREAD_ID):
|
| 410 |
+
"""Clears the conversation history for the specified thread_id."""
|
| 411 |
+
try:
|
| 412 |
+
# Note: MemorySaver in some versions might not support explicit deletion easily via public API
|
| 413 |
+
# This is a best-effort attempt or placeholder for actual persistence deletion
|
| 414 |
+
# If using a real DB checkpointer, you would delete rows here.
|
| 415 |
+
# For MemorySaver, we might just need to reset the state or let it be if it's per-request instance (it's not here).
|
| 416 |
+
# Actually, MemorySaver stores in a dict. We can try accessing it if we really need to clear.
|
| 417 |
+
if hasattr(memory_saver, "storage"):
|
| 418 |
+
if thread_id in memory_saver.storage:
|
| 419 |
+
del memory_saver.storage[thread_id]
|
| 420 |
+
|
| 421 |
+
print(f"Conversation history cleared for thread_id: {thread_id}")
|
| 422 |
+
return {"status": "success", "message": "Conversation history cleared."}
|
| 423 |
+
except Exception as e:
|
| 424 |
+
print(f"Error clearing conversation history for thread_id {thread_id}: {e}")
|
| 425 |
+
return {"status": "error", "message": f"Failed to clear history: {e}"}
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
# --- Main Execution ---
|
| 429 |
+
if __name__ == "__main__":
|
| 430 |
+
print(f"Starting Health Assistant API on {APP_HOST}:{APP_PORT}")
|
| 431 |
+
uvicorn.run(app, host=APP_HOST, port=APP_PORT)
|
requirements.txt
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi==0.115.12
|
| 2 |
+
uvicorn==0.34.2
|
| 3 |
+
langchain==0.3.25
|
| 4 |
+
langchain-core==0.3.58
|
| 5 |
+
langchain-chroma==0.2.3
|
| 6 |
+
langchain-huggingface==0.1.2
|
| 7 |
+
langchain-ollama==0.3.2
|
| 8 |
+
langchain-community==0.3.23
|
| 9 |
+
langgraph==0.4.1
|
| 10 |
+
chromadb==0.6.3
|
| 11 |
+
huggingface-hub==0.30.2
|
| 12 |
+
sentence-transformers==3.4.1
|
| 13 |
+
transformers==4.51.3
|
| 14 |
+
aiocron==1.8
|
| 15 |
+
aiohttp==3.11.11
|
start-ollama.sh
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/sh
|
| 2 |
+
|
| 3 |
+
# Start Ollama in the background
|
| 4 |
+
ollama serve &
|
| 5 |
+
|
| 6 |
+
# Wait for Ollama to start
|
| 7 |
+
sleep 5
|
| 8 |
+
|
| 9 |
+
# Create a custom model using the Modelfile
|
| 10 |
+
ollama create custom-model -f /app/Modelfile.local
|
| 11 |
+
|