# self-trained2 / main.py
# Author: DeepImagix — "Update main.py" (commit 473adb8, verified)
# NOTE: Hugging Face file-viewer chrome ("raw / history / blame / 5.13 kB")
# removed and preserved here as a comment so this file is valid Python.
import os
import json
import joblib
import pandas as pd
import logging
from fastapi import FastAPI, Form, HTTPException
import httpx
# Enable debug logging
logging.basicConfig(level=logging.DEBUG)

# --- 1. Basic Setup & Configuration ---
app = FastAPI(title="NeuraPrompt AI (Final Working Version)")

# Single shared "master" AI id: all training below funnels into this one model.
MASTER_AI_ID = "neurones_self"
# Persistent storage root for model artifacts; created eagerly at import time.
# NOTE(review): assumes /data is writable (e.g. a mounted volume) — confirm on deploy.
USER_MODELS_DIR = "/data/user_models_data"
os.makedirs(USER_MODELS_DIR, exist_ok=True)

# Your Groq API Key — read from the environment; empty string disables Groq calls.
GROQ_API_KEY = os.getenv("GROQ_API_KEY", "")

# --- Pydantic Model to define the expected JSON structure ---
# This tells FastAPI exactly what kind of JSON to expect.
from pydantic import BaseModel
class ChatMessage(BaseModel):
    """JSON request body for the /chat/ endpoint."""
    # Identifier of the chatting user (parsed by /chat/ but not used further yet).
    user_id: str
    # Raw user text; /chat/ strips surrounding whitespace before use.
    message: str
# --- 2. Helper Functions ---
def get_ai_paths(ai_id: str = MASTER_AI_ID):
    """Return the on-disk artifact paths for *ai_id*, creating its directory.

    Keys: "model_path" (joblib pipeline), "data_path" (training CSV),
    "responses_path" (JSON list of known replies).
    """
    base_dir = os.path.join(USER_MODELS_DIR, ai_id)
    os.makedirs(base_dir, exist_ok=True)
    filenames = {
        "model_path": "matcher_model.joblib",
        "data_path": "training_pairs.csv",
        "responses_path": "responses.json",
    }
    return {key: os.path.join(base_dir, name) for key, name in filenames.items()}
async def train_local_ai(prompt: str, reply: str):
    """Record a (prompt, reply) pair and retrain the local TF-IDF matcher.

    The reply text is deduplicated into responses.json; its index in that list
    is the classification label stored alongside the prompt in the training CSV.
    Retraining is skipped until at least two distinct labels exist (SGDClassifier
    needs >= 2 classes).

    NOTE(review): declared async but every operation here is blocking file I/O
    and CPU-bound fitting — it will block the event loop; consider
    run_in_executor if latency matters.
    """
    paths = get_ai_paths()

    # Load the list of known replies.
    # Fix: open text files with explicit UTF-8 — the platform default encoding
    # can mangle non-ASCII LLM replies (e.g. on Windows).
    if os.path.exists(paths["responses_path"]):
        with open(paths["responses_path"], "r", encoding="utf-8") as f:
            responses = json.load(f)
    else:
        responses = []

    # Register the reply if it is new, persisting the updated list.
    if reply not in responses:
        responses.append(reply)
        with open(paths["responses_path"], "w", encoding="utf-8") as f:
            json.dump(responses, f)

    # The label is the reply's position in the (stable, append-only) list.
    reply_index = responses.index(reply)

    # Append the training row; write the header only when creating the file.
    # mode="a" creates the file if missing, so one call covers both branches.
    write_header = not os.path.exists(paths["data_path"])
    new_row = pd.DataFrame([{"prompt": prompt, "label": reply_index}])
    new_row.to_csv(paths["data_path"], mode="a", header=write_header, index=False)

    # Retrain from the full accumulated dataset.
    df = pd.read_csv(paths["data_path"])
    if df["label"].nunique() < 2:
        return  # not enough classes to fit a classifier yet

    # Lazy imports keep sklearn off the request path until training is possible.
    from sklearn.feature_extraction.text import TfidfVectorizer
    from sklearn.linear_model import SGDClassifier
    from sklearn.pipeline import Pipeline

    model_pipeline = Pipeline([
        ('tfidf', TfidfVectorizer()),
        ('clf', SGDClassifier(loss='modified_huber', random_state=42)),
    ])
    model_pipeline.fit(df['prompt'], df['label'])
    joblib.dump(model_pipeline, paths["model_path"])
async def get_groq_reply(user_prompt: str):
    """Ask the Groq chat-completions API for a reply to *user_prompt*.

    Returns the assistant's reply text, or None when the API key is missing
    or the request fails for any reason (callers treat None as "unavailable").
    """
    if not GROQ_API_KEY:
        # Fix: route diagnostics through the configured logging module instead
        # of print(), matching the logging.basicConfig setup at module top.
        logging.error("Groq API key not set.")
        return None

    api_url = "https://api.groq.com/openai/v1/chat/completions"
    headers = {"Authorization": f"Bearer {GROQ_API_KEY}", "Content-Type": "application/json"}
    payload = {
        "model": "llama3-8b-8192",
        "messages": [
            {"role": "system", "content": "You are NeuraPrompt AI, a helpful assistant created by Toxic Dee Modder from South Africa. Be friendly and direct."},
            {"role": "user", "content": user_prompt}
        ]
    }
    try:
        async with httpx.AsyncClient(timeout=30.0) as client:
            response = await client.post(api_url, headers=headers, json=payload)
            response.raise_for_status()  # surface 4xx/5xx as exceptions
            return response.json()["choices"][0]["message"]["content"]
    except Exception:
        # Broad on purpose: any failure (network, HTTP status, unexpected JSON
        # shape) degrades to None so /chat/ can return a clean 503.
        # logging.exception records the full traceback.
        logging.exception("Groq request failed")
        return None
# --- 3. Main API Endpoints ---
@app.get("/")
def read_root():
    """Health-check endpoint: confirms the backend is reachable."""
    status_payload = {"message": "NeuraPrompt AI Backend is Online."}
    return status_payload
# +++ THIS IS THE SIMPLIFIED AND CORRECTED ENDPOINT +++
@app.post("/chat/")
async def chat(payload: ChatMessage):
"""This endpoint now ONLY accepts a JSON body with a user_id and a message."""
user_id = payload.user_id
user_message = payload.message.strip()
paths = get_ai_paths()
# Memory check
if os.path.exists(paths["model_path"]):
from sentence_transformers import util
model = joblib.load(paths["model_path"])
with open(paths["responses_path"], 'r') as f: responses = json.load(f)
# This part assumes your model and vectorizer are saved together in a pipeline
# If your model file is just the classifier, this needs adjustment.
# For simplicity, we'll assume the pipeline is what's saved.
try:
# We cannot get probabilities easily from the current matcher model
# So we will check for a very high similarity score instead.
# This requires a different approach than predict_proba
# This is a placeholder for the similarity logic which requires the sentence transformer model
# For now, we will proceed to Groq and focus on fixing the request issue.
pass # Skipping memory check for now to ensure base functionality
except Exception as e:
print("Memory prediction failed:", e)
# Fallback to Groq since memory check is complex with the current model
ai_reply = await get_groq_reply(user_message)
if ai_reply:
await train_local_ai(prompt=user_message, reply=ai_reply)
return {"response": ai_reply}
else:
raise HTTPException(status_code=503, detail="The Groq AI service is currently unavailable or failed.")