import sys
import os
import json
import base64
import io

import torch
from fastapi import FastAPI, File, UploadFile, Form
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import RedirectResponse
from PIL import Image
from huggingface_hub import snapshot_download, login
from transformers import AutoProcessor, AutoModelForImageTextToText


def _log(message: str) -> None:
    """Print a startup progress line and flush it immediately.

    Flushing matters because stdout is block-buffered when it is not a
    terminal, and the model downloads/loads below can take a long time.
    """
    print(message, flush=True)


# Log in to the Hugging Face Hub only when a token is supplied; without one,
# downloads of gated repositories are expected to be refused.
HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    _log("WARNING: HF_TOKEN not set — gated models will fail.")
else:
    login(token=HF_TOKEN)
    _log("Authenticated with HF token.")

# ---------------------------------------------------------------------------
# MedImageInsight — CLIP-style encoder for zero-shot label scoring
# ---------------------------------------------------------------------------
_log("Downloading MedImageInsights repo...")
repo_path = snapshot_download("lion-ai/MedImageInsights")
_log(f"Downloaded to: {repo_path}")

# The model wrapper is a plain module inside the downloaded snapshot, so the
# snapshot directory has to be on sys.path before it can be imported.
sys.path.insert(0, repo_path)
from medimageinsightmodel import MedImageInsight  # noqa: E402

model_dir = os.path.join(repo_path, "2024.09.27")
_log("Loading MedImageInsight...")
classifier = MedImageInsight(
    model_dir=model_dir,
    # NOTE(review): "medimageinsigt" (missing "h") presumably matches the
    # checkpoint file name shipped in the repo — verify before "fixing" it.
    vision_model_name="medimageinsigt-v1.0.0.pt",
    language_model_name="language_model.pth",
)
classifier.load_model()
_log("MedImageInsight ready.")

# ---------------------------------------------------------------------------
# MedGemma — generative VLM for free-text image description
# ---------------------------------------------------------------------------
MEDGEMMA_ID = "google/medgemma-1.5-4b-it"

_log("Loading MedGemma processor...")
gemma_processor = AutoProcessor.from_pretrained(MEDGEMMA_ID)

_log("Loading MedGemma model (bfloat16)...")
# nn.Module.eval() returns the module itself, so it can be chained here.
gemma_model = AutoModelForImageTextToText.from_pretrained(
    MEDGEMMA_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
).eval()
_log("MedGemma ready.")
# ---------------------------------------------------------------------------
# FastAPI app
# ---------------------------------------------------------------------------
import asyncio  # noqa: E402

from fastapi import HTTPException  # noqa: E402
from fastapi.concurrency import run_in_threadpool  # noqa: E402

app = FastAPI(title="Medical Image Analysis API")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Serializes model calls so at most one inference runs at a time. It is
# created lazily, on first use inside a request, so that it belongs to the
# event loop actually serving requests (before Python 3.10, an asyncio.Lock
# built at import time binds to whichever loop exists at that moment).
_inference_lock = None


async def _run_inference(func, *args):
    """Run the blocking callable ``func(*args)`` in a worker thread, one call at a time.

    Model inference can take seconds. Calling it directly inside an
    ``async def`` endpoint would stall the event loop, and with it every
    other request, including ``/health``. Callers waiting for the lock
    queue asynchronously and do not occupy a worker thread.
    """
    global _inference_lock
    if _inference_lock is None:
        # No await between the check and the assignment, so this is atomic
        # with respect to other coroutines on the same loop.
        _inference_lock = asyncio.Lock()
    async with _inference_lock:
        return await run_in_threadpool(func, *args)


def _load_image(data: bytes) -> Image.Image:
    """Decode uploaded image bytes into an RGB PIL image.

    Raises:
        HTTPException: 400 if Pillow cannot read ``data`` as an image
            (including empty uploads).
    """
    try:
        return Image.open(io.BytesIO(data)).convert("RGB")
    except OSError as exc:  # UnidentifiedImageError is a subclass of OSError
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a readable image."
        ) from exc


def _encode_image(data: bytes) -> str:
    """Convert raw image bytes → base64 PNG string for MedImageInsight.

    Raises:
        HTTPException: 400 if ``data`` is not a readable image.
    """
    img = _load_image(data)
    buf = io.BytesIO()
    img.save(buf, format="PNG")
    return base64.encodebytes(buf.getvalue()).decode("utf-8")


def _parse_labels(raw: str) -> list:
    """Parse the ``labels`` form field, which must be a JSON array of strings.

    Raises:
        HTTPException: 422 if ``raw`` is not valid JSON, or if it does not
            decode to a non-empty list whose items are all strings.
    """
    try:
        labels_list = json.loads(raw)
    except json.JSONDecodeError as exc:
        raise HTTPException(
            status_code=422,
            detail=f"labels must be a JSON array of strings ({exc.msg}).",
        ) from exc
    # Without this check, a bare JSON string or object would be handed to
    # the classifier unchanged.
    if (
        not isinstance(labels_list, list)
        or not labels_list
        or not all(isinstance(label, str) for label in labels_list)
    ):
        raise HTTPException(
            status_code=422,
            detail="labels must be a non-empty JSON array of strings.",
        )
    return labels_list


def _scores_to_list(scores: dict) -> list:
    """Turn a ``{label: score}`` mapping into ``[{"label": ..., "score": ...}]`` rows.

    Scores are converted to float and rounded to 6 decimal places.
    """
    return [{"label": k, "score": round(float(v), 6)} for k, v in scores.items()]


def _classify_one(img_b64: str, labels_list: list, multilabel: bool) -> dict:
    """Score one base64-encoded image against ``labels_list`` with MedImageInsight.

    Returns the classifier's result for that single image, a mapping from
    label to score (``_scores_to_list`` consumes it via ``.items()``).
    """
    results = classifier.predict([img_b64], labels_list, multilabel=multilabel)
    return results[0]


def _describe_image(img: Image.Image, prompt: str, max_new_tokens: int = 512) -> str:
    """Generate MedGemma's free-text answer to ``prompt`` about ``img``.

    Uses greedy decoding (``do_sample=False``). Returns only the newly
    generated text, with special tokens removed.
    """
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": img},
                {"type": "text", "text": prompt},
            ],
        }
    ]
    # BatchFeature.to() casts only floating-point tensors (the pixel values)
    # to bfloat16; token ids keep their integer dtype.
    inputs = gemma_processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(gemma_model.device, dtype=torch.bfloat16)
    input_len = inputs["input_ids"].shape[-1]

    # Grad mode is thread-local, so inference_mode must be entered here, in
    # the worker thread that runs generate().
    with torch.inference_mode():
        generation = gemma_model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
        )
    # The output row begins with the prompt tokens; keep only the completion.
    generation = generation[0][input_len:]
    return gemma_processor.decode(generation, skip_special_tokens=True)


# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@app.get("/")
def root():
    """Redirect the bare root URL to the health check."""
    return RedirectResponse(url="/health")


@app.get("/health")
def health():
    """Report that the service is up."""
    return {"status": "ok"}


@app.post("/classify")
async def classify(
    image: UploadFile = File(...),
    labels: str = Form(...),
):
    """Zero-shot classification via MedImageInsight. Scores sum to ~1 (softmax).

    ``labels`` is a JSON array of label strings. Responds 422 for invalid
    labels and 400 for an unreadable image.
    """
    labels_list = _parse_labels(labels)
    img_b64 = _encode_image(await image.read())
    scores = await _run_inference(_classify_one, img_b64, labels_list, False)
    return {"results": _scores_to_list(scores)}


@app.post("/multilabel")
async def multilabel(
    image: UploadFile = File(...),
    labels: str = Form(...),
):
    """Multi-label classification via MedImageInsight. Each score is independent (sigmoid).

    ``labels`` is a JSON array of label strings. Responds 422 for invalid
    labels and 400 for an unreadable image.
    """
    labels_list = _parse_labels(labels)
    img_b64 = _encode_image(await image.read())
    scores = await _run_inference(_classify_one, img_b64, labels_list, True)
    return {"results": _scores_to_list(scores)}


@app.post("/describe")
async def describe(
    image: UploadFile = File(...),
    prompt: str = Form(default="Describe the medical findings visible in this image."),
):
    """Free-text image description via MedGemma 1.5-4B.

    Responds 400 if the upload is not a readable image.
    """
    img = _load_image(await image.read())
    description = await _run_inference(_describe_image, img, prompt)
    return {"description": description}