phureexd committed
Commit 41169c9 · 0 Parent(s)

Clean deploy with LFS for all DB files
.gitattributes ADDED
@@ -0,0 +1,4 @@
+ *.gguf filter=lfs diff=lfs merge=lfs -text
+ chroma_db/chroma.sqlite3 filter=lfs diff=lfs merge=lfs -text
+ chroma_db/**/*.bin filter=lfs diff=lfs merge=lfs -text
+ chroma_db/**/*.pickle filter=lfs diff=lfs merge=lfs -text
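These four rules are what `git lfs track` writes into `.gitattributes`; a minimal sketch of how they would be reproduced, assuming Git LFS is installed:

```bash
# Sketch: generate the LFS tracking rules above with git-lfs
git lfs install
git lfs track "*.gguf"
git lfs track "chroma_db/chroma.sqlite3"
git lfs track "chroma_db/**/*.bin"
git lfs track "chroma_db/**/*.pickle"
git add .gitattributes
```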
.gitignore ADDED
@@ -0,0 +1 @@
+ unsloth.Q4_K_M.gguf
Dockerfile ADDED
@@ -0,0 +1,45 @@
+ FROM python:3.11-slim
+
+ WORKDIR /app
+
+ # Prevent Hugging Face from showing progress bars
+ ENV HF_HUB_DISABLE_PROGRESS_BARS=1
+ # Disable ChromaDB telemetry
+ ENV ANONYMIZED_TELEMETRY=False
+ # Prevent Python from buffering output
+ ENV PYTHONUNBUFFERED=1
+
+ # Install system dependencies
+ RUN apt-get update && apt-get install -y \
+     python3-dev \
+     build-essential \
+     curl \
+     dos2unix \
+     && rm -rf /var/lib/apt/lists/*
+
+ # Install Ollama
+ RUN curl -fsSL https://ollama.com/install.sh | sh
+
+ # Install Python dependencies
+ COPY requirements.txt .
+ RUN pip install --upgrade --no-cache-dir pip && pip install --no-cache-dir -r requirements.txt
+
+ # Set HF_HOME so models are stored in a consistent location
+ ENV HF_HOME=/app/hf_cache
+
+ # Pre-download models
+ COPY download_models.py .
+ RUN python download_models.py
+
+ # Copy application files
+ COPY . .
+ COPY chroma_db /app/chroma_db
+
+ # Fix line endings and permissions for shell scripts
+ RUN dos2unix /app/start-ollama.sh && chmod +x /app/start-ollama.sh
+
+ # Expose FastAPI port
+ EXPOSE 7860
+
+ # Start Ollama and FastAPI
+ CMD ["/bin/sh", "-c", "/app/start-ollama.sh && uvicorn main_v2:app --host 0.0.0.0 --port 7860"]
Modelfile.local ADDED
@@ -0,0 +1,62 @@
+ # Modelfile generated by "ollama show"
+ # To build a new Modelfile based on this, replace FROM with:
+ # FROM hf.co/phureexd/qwen3_v2_gguf:Q4_K_M
+ # You must have the .gguf file first
+ FROM /app/unsloth.Q4_K_M.gguf
+ TEMPLATE """{{- if .Messages }}
+ {{- if or .System .Tools }}<|im_start|>system
+ {{- if .System }}
+ {{ .System }}
+ {{- end }}
+ {{- if .Tools }}
+
+ # Tools
+
+ You may call one or more functions to assist with the user query.
+
+ You are provided with function signatures within <tools></tools> XML tags:
+ <tools>
+ {{- range .Tools }}
+ {"type": "function", "function": {{ .Function }}}
+ {{- end }}
+ </tools>
+
+ For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
+ <tool_call>
+ {"name": <function-name>, "arguments": <args-json-object>}
+ </tool_call>
+ {{- end }}<|im_end|>
+ {{ end }}
+ {{- range $i, $_ := .Messages }}
+ {{- $last := eq (len (slice $.Messages $i)) 1 -}}
+ {{- if eq .Role "user" }}<|im_start|>user
+ {{ .Content }}<|im_end|>
+ {{ else if eq .Role "assistant" }}<|im_start|>assistant
+ {{ if .Content }}{{ .Content }}
+ {{- else if .ToolCalls }}<tool_call>
+ {{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
+ {{ end }}</tool_call>
+ {{- end }}{{ if not $last }}<|im_end|>
+ {{ end }}
+ {{- else if eq .Role "tool" }}<|im_start|>user
+ <tool_response>
+ {{ .Content }}
+ </tool_response><|im_end|>
+ {{ end }}
+ {{- if and (ne .Role "assistant") $last }}<|im_start|>assistant
+ {{ end }}
+ {{- end }}
+ {{- else }}
+ {{- if .System }}<|im_start|>system
+ {{ .System }}<|im_end|>
+ {{ end }}{{ if .Prompt }}<|im_start|>user
+ {{ .Prompt }}<|im_end|>
+ {{ end }}<|im_start|>assistant
+ {{ end }}{{ .Response }}{{ if .Response }}<|im_end|>{{ end }}"""
+ PARAMETER temperature 0.7
+ PARAMETER top_p 0.8
+ PARAMETER top_k 20
+ PARAMETER num_predict 512
+ PARAMETER repeat_penalty 1
+ PARAMETER stop <|im_start|>
+ PARAMETER stop <|im_end|>
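To sanity-check this Modelfile outside the container, the same `ollama create` invocation used by `start-ollama.sh` below should work; a sketch, assuming a local Ollama install and the `FROM` path adjusted to wherever `unsloth.Q4_K_M.gguf` was downloaded:

```bash
# Sketch: register and smoke-test the custom model locally
ollama create custom-model -f Modelfile.local
ollama run custom-model "What are common symptoms of dehydration?"
```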
README.md ADDED
@@ -0,0 +1 @@
+ # qwen-healthcare-assistant
chroma_db/ba17ee65-4350-4399-8b7f-ca4660b2aab0/data_level0.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2475b61e969a2d81d22ce88d8c6e9f26f63f6bbf7bc5d834a43f414c368a34b1
+ size 8472000
chroma_db/ba17ee65-4350-4399-8b7f-ca4660b2aab0/header.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:01f7190252d8675a30dba956d9378f41be948e807392883f5a95ba08085e0efa
+ size 100
chroma_db/ba17ee65-4350-4399-8b7f-ca4660b2aab0/index_metadata.pickle ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:06f1337839fcfb42cb8148849ea14ad9148b1d24da8fbb177b80ea012228263d
+ size 113967
chroma_db/ba17ee65-4350-4399-8b7f-ca4660b2aab0/length.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:07995695399aa84304fb2a27cae4b5640590e1a6d85a43b2631c958c4c17ff15
+ size 8000
chroma_db/ba17ee65-4350-4399-8b7f-ca4660b2aab0/link_lists.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:732e4883582480d065df2f20149b50ad1ebbbe55c48725c909fde893fc64aabf
+ size 16976
chroma_db/chroma.sqlite3 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ef6c315aadf349adb83c6f4b522cce8aafdf726261e6be1ca5e21809409df2fd
+ size 50151424
docker.md ADDED
@@ -0,0 +1,7 @@
+ ```bash
+ docker build -t nlp-app .
+ ```
+
+ ```bash
+ docker run -p 7860:7860 --name nlp-container nlp-app
+ ```
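Once the container is up, the API streams server-sent events from `/generate` (see `main_v2.py` below); a sketch of a smoke test, assuming the port mapping above:

```bash
# Sketch: hit the SSE endpoint; -N disables curl's output buffering
curl -N "http://localhost:7860/generate?query=What%20are%20the%20symptoms%20of%20diabetes&useRAG=true"
```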
download_models.py ADDED
@@ -0,0 +1,63 @@
+ import time
+ from huggingface_hub import snapshot_download, hf_hub_download
+ from sentence_transformers import SentenceTransformer, CrossEncoder
+
+ # Define models to download
+ EMBEDDING_MODEL_NAME = "BAAI/bge-m3"
+ CROSS_ENCODER_MODEL_NAME = "BAAI/bge-reranker-v2-m3"
+
+ def download_with_retries(repo_id, retries=5, delay=10):
+     """Downloads a model repository with retry logic."""
+     for i in range(retries):
+         try:
+             print(f"Downloading {repo_id} (Attempt {i+1}/{retries})...")
+             # Interrupted downloads resume automatically in recent huggingface_hub
+             # versions (the old resume_download=True kwarg is deprecated)
+             snapshot_download(repo_id=repo_id)
+             print(f"Successfully downloaded {repo_id}")
+             return
+         except Exception as e:
+             print(f"Error downloading {repo_id}: {e}")
+             if i < retries - 1:
+                 print(f"Retrying in {delay} seconds...")
+                 time.sleep(delay)
+             else:
+                 print(f"Failed to download {repo_id} after {retries} attempts.")
+                 raise
+
+ def download_models():
+     print(f"Downloading embedding model: {EMBEDDING_MODEL_NAME}")
+     download_with_retries(EMBEDDING_MODEL_NAME)
+
+     # Also initialize SentenceTransformer to ensure it caches correctly for the library
+     print(f"Initializing SentenceTransformer for {EMBEDDING_MODEL_NAME} to populate cache...")
+     try:
+         SentenceTransformer(EMBEDDING_MODEL_NAME)
+     except Exception as e:
+         print(f"Warning: Failed to initialize SentenceTransformer: {e}")
+
+     print(f"Downloading cross-encoder model: {CROSS_ENCODER_MODEL_NAME}")
+     download_with_retries(CROSS_ENCODER_MODEL_NAME)
+
+     # Initialize CrossEncoder to populate cache
+     print(f"Initializing CrossEncoder for {CROSS_ENCODER_MODEL_NAME} to populate cache...")
+     try:
+         CrossEncoder(CROSS_ENCODER_MODEL_NAME)
+     except Exception as e:
+         print(f"Warning: Failed to initialize CrossEncoder: {e}")
+
+     # Download GGUF model for Ollama
+     llm_repo_id = "phureexd/qwen3_v2_gguf"
+     llm_filename = "unsloth.Q4_K_M.gguf"
+     print(f"Downloading LLM: {llm_filename} from {llm_repo_id}")
+     try:
+         # local_dir places the real file in the working directory
+         # (local_dir_use_symlinks is deprecated and ignored in recent hub versions)
+         hf_hub_download(repo_id=llm_repo_id, filename=llm_filename, local_dir=".")
+         print(f"Successfully downloaded {llm_filename}")
+     except Exception as e:
+         print(f"Error downloading LLM: {e}")
+         raise
+
+     print("All models downloaded successfully.")
+
+ if __name__ == "__main__":
+     download_models()
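The script runs at image build time (see the Dockerfile above), so the HF cache ends up at `/app/hf_cache` via `HF_HOME`. A sketch of running it locally with the same cache layout, assuming the pinned dependencies from `requirements.txt` are installed:

```bash
# Sketch: reproduce the build-time download step locally
export HF_HOME=./hf_cache
python download_models.py
```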
graph.png ADDED
main.py ADDED
@@ -0,0 +1,133 @@
+ import os
+
+ os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
+
+ # Import unsloth before transformers so it can apply its patches
+ from unsloth import FastLanguageModel
+
+ from threading import Thread
+
+ import uvicorn
+ from fastapi import FastAPI
+ from fastapi.middleware.cors import CORSMiddleware
+ from fastapi.responses import StreamingResponse
+ from transformers import TextIteratorStreamer
+
+ # Initialize FastAPI app
+ app = FastAPI()
+
+ # Enable CORS to allow frontend requests
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+ ###########################################################################
+ from langchain_huggingface import HuggingFaceEmbeddings
+
+ embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-m3")
+
+ from langchain_chroma import Chroma
+
+ vector_store = Chroma(
+     embedding_function=embeddings,
+     persist_directory="C:/Users/LENOVO/Downloads/chroma_langchain_db_3",  # Local dev path; adjust for your environment
+ )
+
+
+ # Model configuration
+ # model_name = "phureexd/qwen_model"
+ model_name = "unsloth/Qwen3-1.7B-unsloth-bnb-4bit"
+ max_seq_length = 2048
+ dtype = None
+ load_in_4bit = True
+
+ # Load model and tokenizer
+ model, tokenizer = FastLanguageModel.from_pretrained(
+     model_name=model_name,
+     max_seq_length=max_seq_length,
+     dtype=dtype,
+     load_in_4bit=load_in_4bit,
+ )
+
+ FastLanguageModel.for_inference(model)
+
+
+ # Prepare RAG messages: retrieve context and fold it into the system prompt
+ def prepare_rag_messages(messages, vector_store, k=2):
+     query = next(msg["content"] for msg in reversed(messages) if msg["role"] == "user")
+     print("this is query:\n", type(query), query)
+     docs = vector_store.similarity_search(query, k=k)
+     context = "\n\n".join(
+         f"Source: {doc.metadata['source']}\nContent: {doc.page_content}" for doc in docs
+     )
+     print("this is context:\n", context)
+     system_message = messages[0]["content"] + "\n\nContext:\n" + context
+     rag_messages = [
+         {"role": "system", "content": system_message},
+         {"role": "user", "content": query},
+     ]
+     return rag_messages
+
+
+ # Endpoint to generate a response using GET
+ @app.get("/generate")
+ async def generate(query: str):
+     messages = [
+         {
+             "role": "system",
+             "content": """You are a medical professional assistant. You will receive user queries along with relevant context retrieved via RAG.
+ Use the context if it is relevant. If not, rely on your own medical knowledge. If unsure, clearly state so.
+ Always respond in the same language used in the user's query. Keep responses clear, concise, and professional.
+
+ Extremely important: Answer in the same language as the user query.
+ """,
+         },
+         {"role": "user", "content": query},
+     ]
+
+     rag_messages = prepare_rag_messages(messages, vector_store, k=2)
+
+     text = tokenizer.apply_chat_template(
+         rag_messages,
+         tokenize=False,
+         add_generation_prompt=True,  # Must be set for generation
+         enable_thinking=False,
+     )
+
+     inputs = tokenizer(text, return_tensors="pt").to(device=model.device)
+
+     def stream_response():
+         streamer = TextIteratorStreamer(
+             tokenizer, skip_prompt=True, skip_special_tokens=True
+         )
+         # Recommended settings: reasoning inference temperature=0.6, top_p=0.95, top_k=20;
+         # normal chat inference temperature=0.7, top_p=0.8, top_k=20
+         generate_kwargs = dict(
+             **inputs,
+             max_new_tokens=1024,
+             do_sample=True,
+             temperature=0.7,
+             top_p=0.8,
+             top_k=20,
+             streamer=streamer,
+         )
+         # Run generation in a background thread so tokens can be streamed as they arrive
+         thread = Thread(target=model.generate, kwargs=generate_kwargs)
+         thread.start()
+
+         for new_text in streamer:
+             yield f"data: {new_text}\n\n"
+
+     return StreamingResponse(
+         stream_response(),
+         media_type="text/event-stream",
+         headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
+     )
+
+
+ # Run the server
+ if __name__ == "__main__":
+     uvicorn.run(app, host="0.0.0.0", port=8000)
main_v2.py ADDED
@@ -0,0 +1,426 @@
+ import os
+
+ # Suppress TensorFlow oneDNN optimization messages if not needed
+ os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
+ # Disable ChromaDB telemetry to prevent log errors
+ os.environ["ANONYMIZED_TELEMETRY"] = "False"
+ import uvicorn
+ from fastapi import FastAPI
+ from fastapi.middleware.cors import CORSMiddleware
+ from fastapi.responses import StreamingResponse
+ from langchain.retrievers import ContextualCompressionRetriever
+ from langchain.retrievers.document_compressors import CrossEncoderReranker
+ from langchain.tools.retriever import create_retriever_tool
+ from langchain_chroma import Chroma
+ from langchain_community.cross_encoders import HuggingFaceCrossEncoder
+ from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
+ from langchain_huggingface import HuggingFaceEmbeddings
+ from langchain_ollama import ChatOllama
+ from langgraph.checkpoint.memory import MemorySaver
+ from langgraph.graph import END, MessagesState, StateGraph
+ from langgraph.prebuilt import ToolNode, tools_condition
+
+ # Set the device for HuggingFace models
+ device = "cpu"
+
+ # --- Configuration Constants ---
+ APP_HOST = "0.0.0.0"
+ APP_PORT = 7860
+
+ THREAD_ID = "global_health_chat_session"  # Unique ID for the chat session
+
+ # Models and Paths
+ EMBEDDING_MODEL_NAME = "BAAI/bge-m3"
+ CROSS_ENCODER_MODEL_NAME = "BAAI/bge-reranker-v2-m3"
+ LLM_MODEL_NAME = "custom-model"  # Replace with your actual model, e.g., "hf.co/phureexd/qwen3_v2_gguf:Q4_K_M"
+ VECTOR_DB_PATH = "/app/chroma_db" if os.path.exists("/app/chroma_db") else "chroma_db"
+
+ # LLM Parameters
+ LLM_TEMPERATURE = 0.7
+ LLM_TOP_P = 0.8
+ LLM_TOP_K = 20
+ LLM_NUM_PREDICT = 512
+
+ # Retriever Parameters
+ RETRIEVER_SEARCH_K = 6  # Number of documents to fetch initially
+ RERANKER_TOP_N = 3  # Number of documents after reranking
+
+ # --- System Prompts ---
+
+ INITIAL_SYSTEM_MESSAGE = SystemMessage(
+     content="""
+ You are a health assistant designed to answer questions related to health, wellness, nutrition, exercise, symptoms, diseases, prevention, treatment, mental health, and medical advice. This explicitly includes general statements about feeling unwell or sick (e.g., "I'm sick", "I don't feel good"). For ANY query that falls into these categories, you MUST use the retrieve_health_info tool to fetch relevant information from the database before providing an answer. This ensures your responses are accurate and based on trusted sources. Do not answer health-related questions directly without using the tool, even if you think you know the answer.
+
+ If the query is clearly unrelated to health (e.g., general knowledge questions), you can answer directly without the tool.
+
+ **Important Guidelines:**
+ - If the query mentions or implies health, feeling unwell, sickness, treatment, symptoms, diseases, nutrition, exercise, mental health, or wellness, use the tool.
+ - Even if the query is only slightly related to health, or is a general statement about feeling unwell, use the tool to provide an informed answer.
+ - Always respond in the same language as the user's query.
+ - When in doubt, err on the side of using the tool.
+
+ **Examples:**
+
+ 1. **Health-Related (Use Tool):**
+    - User: "What are the symptoms of diabetes?"
+    - Assistant: [Uses retrieve_health_info tool] "Common symptoms of diabetes include frequent urination, excessive thirst, and fatigue."
+
+ 2. **Slightly Health-Related (Use Tool):**
+    - User: "Is it okay to exercise when I have a cold?"
+    - Assistant: [Uses retrieve_health_info tool] "Light exercise might be okay, but rest if you have a fever."
+
+ 3. **General Sickness Statement (Use Tool):**
+    - User: "I'm sick."
+    - Assistant: [Uses retrieve_health_info tool] "I'm sorry to hear you're not feeling well. Common advice includes resting and staying hydrated. If you have specific symptoms, I can try to provide more information."
+
+ 4. **Non-Health-Related (No Tool):**
+    - User: "What is the capital of France?"
+    - Assistant: "The capital of France is Paris."
+
+ 5. **Health-Related in Thai (Use Tool):**
+    - User: "อาการของโรคเบาหวานมีอะไรบ้าง?"
+    - Assistant: [Uses retrieve_health_info tool] "อาการทั่วไปของโรคเบาหวาน ได้แก่ ปัสสาวะบ่อย กระหายน้ำมาก และอ่อนเพลีย"
+
+ 6. **Non-Health-Related in Thai (No Tool):**
+    - User: "เมืองหลวงของฝรั่งเศสคืออะไร?"
+    - Assistant: "เมืองหลวงของฝรั่งเศสคือปารีส"
+ /no_think
+ """
+ )
+
+ RAG_SYSTEM_PROMPT_TEMPLATE = """
+ You are a health assistant for question-answering tasks.
+ Use the following retrieved documents to answer the question.
+ If you don't know the answer, say that you don't know.
+ Keep the answer concise and accurate.
+
+ **Extremely important: Answer in the same language as the user query.**
+
+ ### Retrieved documents (if applicable):
+ {docs_content}
+
+ ### Examples of the language model's responses:
+ **Example 1 (English):**
+ User: I feel a bit tired, what could it be?
+ Assistant: Fatigue can be caused by lack of sleep, stress, or dehydration. Ensure you get 7-8 hours of sleep and stay hydrated.
+
+ **Example 2 (English):**
+ User: Does coffee affect my health?
+ Assistant: Moderate coffee consumption can improve alertness but may cause insomnia or anxiety if overconsumed.
+
+ **Example 3 (Thai):**
+ User: ฉันรู้สึกเหนื่อยเล็กน้อย เกิดจากอะไรได้บ้าง?
+ Assistant: อาการเหนื่อยอาจเกิดจากการนอนหลับไม่เพียงพอ ความเครียด หรือภาวะขาดน้ำ ควรนอนหลับ 7-8 ชั่วโมงและดื่มน้ำให้เพียงพอ
+ /no_think
+ """
+
+ # --- Initialization of Langchain Components ---
+
+
+ def init_embeddings(model_name: str):
+     """Initializes HuggingFace embeddings."""
+     return HuggingFaceEmbeddings(model_name=model_name)
+
+
+ def init_vector_store(embedding_function, persist_directory: str):
+     """Initializes Chroma vector store."""
+     return Chroma(
+         embedding_function=embedding_function,
+         persist_directory=persist_directory,
+     )
+
+
+ def init_llm(
+     model_name: str, temperature: float, top_p: float, top_k: int, num_predict: int
+ ):
+     """Initializes ChatOllama LLM."""
+     return ChatOllama(
+         model=model_name,
+         temperature=temperature,
+         top_p=top_p,
+         top_k=top_k,
+         num_predict=num_predict,
+     )
+
+
+ def init_retriever_tool(
+     vector_store_instance,
+     cross_encoder_model_name: str,
+     base_retriever_k: int,
+     reranker_top_n: int,
+ ):
+     """Initializes the retriever tool with reranking."""
+     base_retriever = vector_store_instance.as_retriever(
+         search_kwargs={"k": base_retriever_k}
+     )
+
+     cross_encoder = HuggingFaceCrossEncoder(
+         model_name=cross_encoder_model_name,
+         model_kwargs={"device": device},  # Specify device if needed, e.g., "cuda"
+     )
+     reranker = CrossEncoderReranker(model=cross_encoder, top_n=reranker_top_n)
+
+     compression_retriever = ContextualCompressionRetriever(
+         base_compressor=reranker,
+         base_retriever=base_retriever,
+     )
+
+     return create_retriever_tool(
+         retriever=compression_retriever,
+         name="retrieve_health_info",
+         description=(
+             "Use this tool to retrieve documents relevant to queries about health, "
+             "wellness, nutrition, exercise, symptoms, diseases, treatment, prevention, "
+             "mental health, or medical advice from the database, "
+             "even if the query is only slightly related to health. "
+             f"Returns the top {reranker_top_n} most relevant documents."
+         ),
+         response_format="content_and_artifact",  # Ensures artifact contains Document objects
+     )
+
+
+ # Initialize components
+ print("Initializing Embeddings...")
+ embeddings = init_embeddings(EMBEDDING_MODEL_NAME)
+ print("Embeddings Initialized.")
+
+ print("Initializing Vector Store...")
+ vector_store = init_vector_store(embeddings, VECTOR_DB_PATH)
+ print("Vector Store Initialized.")
+
+ print("Initializing LLM...")
+ llm = init_llm(LLM_MODEL_NAME, LLM_TEMPERATURE, LLM_TOP_P, LLM_TOP_K, LLM_NUM_PREDICT)
+ print("LLM Initialized.")
+
+ print("Initializing Retriever Tool...")
+ retriever_tool = init_retriever_tool(
+     vector_store, CROSS_ENCODER_MODEL_NAME, RETRIEVER_SEARCH_K, RERANKER_TOP_N
+ )
+ print("Retriever Tool Initialized.")
+
+ # --- LangGraph Node Definitions ---
+
+
+ async def query_or_respond_node_logic(state: MessagesState):
+     """
+     Node function: Decides whether to call a tool for retrieval or respond directly.
+     Binds the retriever_tool to the LLM for this decision.
+     """
+     response = await llm.bind_tools([retriever_tool]).ainvoke(state["messages"])
+     return {"messages": [response]}
+
+
+ async def generate_rag_response_node_logic(state: MessagesState):
+     """
+     Node function: Generates a response using retrieved documents (if any).
+     """
+     # Extract the most recent contiguous block of tool messages
+     recent_tool_messages = []
+     for message in reversed(state["messages"]):
+         if message.type == "tool":  # or isinstance(message, ToolMessage)
+             recent_tool_messages.append(message)
+         else:
+             break
+     tool_messages = recent_tool_messages[::-1]
+
+     # Format retrieved document content for the prompt
+     doc_strings = []
+     for tool_msg in tool_messages:
+         # Ensure artifact is a list of Langchain Document objects
+         if hasattr(tool_msg, "artifact") and isinstance(tool_msg.artifact, list):
+             for doc in tool_msg.artifact:
+                 if hasattr(doc, "page_content") and hasattr(
+                     doc, "metadata"
+                 ):  # Document structure check
+                     source = doc.metadata.get("source", "Unknown source")
+                     content = doc.page_content
+                     doc_strings.append(f"Source: {source}\nContent: {content}")
+
+     docs_content = (
+         "\n\n".join(doc_strings)
+         if doc_strings
+         else "No relevant documents were found to answer the current question."
+     )
+
+     # Prepare messages for the generation LLM call (history + new system prompt with docs)
+     # Include human messages, initial system messages, and AI responses (not tool calls)
+     conversation_history_for_llm = [
+         msg
+         for msg in state["messages"]
+         if msg.type in ("human", "system") or (msg.type == "ai" and not msg.tool_calls)
+     ]
+
+     # Construct the system prompt with retrieved documents
+     current_system_prompt_content = RAG_SYSTEM_PROMPT_TEMPLATE.format(
+         docs_content=docs_content
+     )
+
+     prompt_for_generation = [
+         SystemMessage(content=current_system_prompt_content)
+     ] + conversation_history_for_llm
+
+     response = await llm.ainvoke(prompt_for_generation)
+     return {"messages": [response]}
+
+
+ # --- LangGraph Graph Construction ---
+
+
+ def create_lang_graph(checkpointer_instance):
+     """Creates and compiles the LangGraph."""
+     graph_builder = StateGraph(MessagesState)
+
+     # Define nodes
+     graph_builder.add_node("query_or_respond", query_or_respond_node_logic)
+     tools_node = ToolNode([retriever_tool])  # Define tool execution node
+     graph_builder.add_node("tools", tools_node)
+     graph_builder.add_node("generate_rag_response", generate_rag_response_node_logic)
+
+     # Define edges
+     graph_builder.set_entry_point("query_or_respond")
+     graph_builder.add_conditional_edges(
+         "query_or_respond",
+         tools_condition,  # Prebuilt condition to check for tool calls
+         {END: END, "tools": "tools"},
+     )
+     graph_builder.add_edge("tools", "generate_rag_response")
+     graph_builder.add_edge("generate_rag_response", END)
+
+     return graph_builder.compile(checkpointer=checkpointer_instance)
+
+
+ # Initialize checkpointer and compile graph
+ memory_saver = MemorySaver()
+ graph = create_lang_graph(memory_saver)
+
+ # Optional: Save graph visualization
+ # try:
+ #     graph.get_graph().draw_mermaid_png(output_file_path="graph.png")
+ #     print("Graph visualization saved to graph.png")
+ # except Exception as e:
+ #     print(f"Could not save graph visualization: {e}")
+
+
+ # --- FastAPI Application Setup ---
+ app = FastAPI(
+     title="Health Assistant API",
+     description="API for a health assistant using a retrieval-augmented generation approach.",
+ )
+
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],  # Allow all origins for simplicity; restrict in production
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ # --- FastAPI Endpoints ---
+
+
+ @app.get("/generate", summary="Generate a response from the health assistant model")
+ async def generate_endpoint(
+     query: str,
+     useRAG: bool = False,
+     thread_id: str = THREAD_ID,
+ ):
+     """
+     Handles a user query and streams back the assistant's responses.
+     - `query`: The user's question.
+     - `useRAG`: If true, forces the use of the retrieval tool via a system message.
+     - `thread_id`: Unique identifier for the conversation session.
+     """
+     print(f"Received query: '{query}', Force RAG: {useRAG}, Thread ID: {thread_id}")
+
+     config = {"configurable": {"thread_id": thread_id}}
+
+     # Prepare input messages for the graph
+     input_messages = []
+     current_checkpoint_tuple = memory_saver.get_tuple(config)  # Check if history exists
+
+     if current_checkpoint_tuple is None:  # No history, it's a new or cleared session
+         input_messages.append(INITIAL_SYSTEM_MESSAGE)
+         print("Starting new conversation: Added initial system message.")
+
+     if useRAG:
+         # This message is added to strongly encourage tool use for the current query,
+         # supplementing the INITIAL_SYSTEM_MESSAGE.
+         input_messages.append(
+             SystemMessage(
+                 content="You MUST use the retrieve_health_info tool for this query even if it seems unrelated to health."
+             )
+         )
+         print("Forcing RAG for this query with an additional system message.")
+
+     input_messages.append(HumanMessage(content=query))
+     graph_input = {"messages": input_messages}
+
+     async def stream_response_events():
+         # astream with stream_mode="messages" yields (message_chunk, metadata) tuples
+         # as each node emits LLM output.
+         async for messages_in_state in graph.astream(
+             graph_input, config, stream_mode="messages"
+         ):
+             if not messages_in_state:
+                 continue
+
+             # The first element of the tuple is the streamed message
+             latest_message = messages_in_state[0]
+
+             if isinstance(latest_message, AIMessage):
+                 if latest_message.content:  # Final textual response
+                     yield f"data: {latest_message.content}\n\n"
+                 elif latest_message.tool_calls:  # AI message requesting a tool call
+                     print(f"AI requested Tool call: {latest_message.tool_calls}")
+                     # You might want to send a status to the client, e.g., "Thinking..." or "Retrieving info..."
+                     # yield f"event: tool_call\ndata: {json.dumps(latest_message.tool_calls)}\n\n"
+             elif isinstance(
+                 latest_message, ToolMessage
+             ):  # Message containing tool execution results
+                 if latest_message.name == "retrieve_health_info" and hasattr(
+                     latest_message, "artifact"
+                 ):
+                     print(f"Tool '{latest_message.name}' executed. Artifact content:")
+                     if latest_message.artifact and isinstance(
+                         latest_message.artifact, list
+                     ):
+                         # Collect unique sources and log every document in the artifact
+                         source_list = set()
+                         for doc in latest_message.artifact:
+                             source = doc.metadata.get("source", "Unknown source")
+
+                             if source != "Unknown source":
+                                 source_list.add(source)
+
+                             print(f"  Source: {source}\n  Content: {doc.page_content}")
+                         yield f"data: **Source:**{str(source_list)}\n\n"
+
+     return StreamingResponse(
+         stream_response_events(),
+         media_type="text/event-stream",
+         headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
+     )
+
+
+ @app.get("/clear", summary="Clear conversation history")
+ async def clear_conversation_endpoint(thread_id: str = THREAD_ID):
+     """Clears the conversation history for the specified thread_id."""
+     try:
+         # MemorySaver keeps checkpoints in an in-memory dict keyed by thread_id;
+         # deleting the thread's entry clears its history. A persistent checkpointer
+         # (e.g., a database) would delete the stored rows here instead.
+         if hasattr(memory_saver, "storage"):
+             if thread_id in memory_saver.storage:
+                 del memory_saver.storage[thread_id]
+
+         print(f"Conversation history cleared for thread_id: {thread_id}")
+         return {"status": "success", "message": "Conversation history cleared."}
+     except Exception as e:
+         print(f"Error clearing conversation history for thread_id {thread_id}: {e}")
+         return {"status": "error", "message": f"Failed to clear history: {e}"}
+
+
+ # --- Main Execution ---
+ if __name__ == "__main__":
+     print(f"Starting Health Assistant API on {APP_HOST}:{APP_PORT}")
+     uvicorn.run(app, host=APP_HOST, port=APP_PORT)
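Conversation state lives in the in-process `MemorySaver`, keyed by `thread_id`, so a session can be reset over HTTP; a sketch against the default thread:

```bash
# Sketch: clear the default conversation thread
curl "http://localhost:7860/clear?thread_id=global_health_chat_session"
```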
requirements.txt ADDED
@@ -0,0 +1,15 @@
+ fastapi==0.115.12
+ uvicorn==0.34.2
+ langchain==0.3.25
+ langchain-core==0.3.58
+ langchain-chroma==0.2.3
+ langchain-huggingface==0.1.2
+ langchain-ollama==0.3.2
+ langchain-community==0.3.23
+ langgraph==0.4.1
+ chromadb==0.6.3
+ huggingface-hub==0.30.2
+ sentence-transformers==3.4.1
+ transformers==4.51.3
+ aiocron==1.8
+ aiohttp==3.11.11
start-ollama.sh ADDED
@@ -0,0 +1,11 @@
+ #!/bin/sh
+
+ # Start Ollama in the background
+ ollama serve &
+
+ # Wait for Ollama to start
+ sleep 5
+
+ # Create a custom model using the Modelfile
+ ollama create custom-model -f /app/Modelfile.local
+
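Since `ollama create` runs before uvicorn starts, the custom model should be registered by the time the API is reachable; a sketch of verifying it, assuming the container name from `docker.md`:

```bash
# Sketch: confirm custom-model is registered inside the running container
docker exec nlp-container ollama list
```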