MedGemma 1.5 - Multiple Myeloma Progression Tracking (Module 4)
Model Overview
This repository contains a PEFT/LoRA adapter for MedGemma 1.5 4B-IT. It is designed to analyze raw clinical text, extract longitudinal M-spike measurements (from serum protein electrophoresis), and assess whether the data indicate rapid disease progression in Multiple Myeloma patients.
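The timeline format used in the usage example has one "- Day N: Marker: value unit, ..." line per visit. You can also parse it outside the model, for instance to cross-check the M-spike values the adapter reports. The sketch below is a hypothetical helper and is not part of this repository. `parse_timeline` and `m_spike_series` are illustrative names, and the sketch assumes comma-separated entries and the marker name "M Protein" as written in the example.

```python
import re
from typing import Dict, List, Tuple

# Matches one visit line such as "- Day -17: Platelets: 329.0 x10^9 cells/L, ..."
_VISIT_RE = re.compile(r"^\s*-\s*Day\s+(-?\d+):\s*(.+)$")
# Matches one "Marker: value unit" entry inside a visit line.
_ENTRY_RE = re.compile(r"^\s*([A-Za-z][A-Za-z ]*?):\s*(-?\d+(?:\.\d+)?)\s*(.*?)\s*$")


def parse_timeline(text: str) -> Dict[int, Dict[str, Tuple[float, str]]]:
    """Return {day: {marker: (value, unit)}} for every well-formed visit line."""
    visits: Dict[int, Dict[str, Tuple[float, str]]] = {}
    for line in text.splitlines():
        visit = _VISIT_RE.match(line)
        if not visit:
            continue  # skip blank or free-text lines
        day = int(visit.group(1))
        entries: Dict[str, Tuple[float, str]] = {}
        for chunk in visit.group(2).split(","):
            entry = _ENTRY_RE.match(chunk)
            if entry:
                entries[entry.group(1).strip()] = (float(entry.group(2)), entry.group(3))
        visits[day] = entries
    return visits


def m_spike_series(text: str, marker: str = "M Protein") -> List[Tuple[int, float]]:
    """Return the (day, value) pairs for one marker, sorted by day."""
    visits = parse_timeline(text)
    return sorted((day, vals[marker][0]) for day, vals in visits.items() if marker in vals)
```

Because `m_spike_series` sorts by day, the pairs come back in chronological order, which makes trends easy to inspect. Passing a different `marker` (e.g. "Calcium") works the same way.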
This adapter was developed as part of a broader agentic AI application for Multiple Myeloma patients. It serves as Module 4, working alongside upstream risk-assessment and vision modules and feeding structured progression data into a RAG-enabled clinical dashboard.
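This card does not document how Module 4 passes its output downstream. As a minimal sketch, assuming a JSON hand-off, you could wrap the adapter's prediction together with the patient ID and the M-spike series in a record like the hypothetical `ProgressionRecord` below. The class and its field names are illustrative only and do not describe the actual pipeline.

```python
import json
from dataclasses import asdict, dataclass, field
from typing import List, Tuple


@dataclass
class ProgressionRecord:
    """Hypothetical container for Module 4 output handed to downstream components."""
    patient_id: str
    m_spike_g_dl: List[Tuple[int, float]] = field(default_factory=list)  # (day, value) pairs
    model_output: str = ""  # raw text generated by the adapter

    def to_json(self) -> str:
        # json.dumps serializes the (day, value) tuples as JSON arrays.
        return json.dumps(asdict(self), sort_keys=True)
```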
Associated Code Repository
The complete source code for data preparation, training, and validation of this adapter, as well as the full agentic AI pipeline, is available on GitHub: here
Base Model Dependency
This is an adapter model. It requires the base weights from Google's MedGemma 1.5 4B-IT.
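The adapter stores only LoRA weights, so it helps to confirm which base checkpoint it was saved against before loading it. PEFT records this in the `base_model_name_or_path` field of `adapter_config.json`. The sketch below reads that file from a local path. You could obtain the file with, for example, `huggingface_hub.hf_hub_download(repo_id="shrishSVaidya/medgemma-1.5-mm-progression-module4", filename="adapter_config.json")`. The helper names are hypothetical. Depending on how the adapter was saved, the stored value may be a local path rather than the Hub ID.

```python
import json
from pathlib import Path

EXPECTED_BASE = "google/medgemma-1.5-4b-it"


def adapter_base_model(config_path: str) -> str:
    """Read the base model ID recorded in a PEFT adapter_config.json."""
    with Path(config_path).open(encoding="utf-8") as fh:
        config = json.load(fh)
    # PEFT stores the base checkpoint under this key when the adapter is saved.
    return config.get("base_model_name_or_path", "")


def check_adapter_base(config_path: str, expected: str = EXPECTED_BASE) -> bool:
    """Return True if the adapter config names the expected base model."""
    return adapter_base_model(config_path) == expected
```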
⚠️ License and Terms of Use
- LoRA Adapter Weights: The adapter weights and associated code in this repository are open-sourced under the Apache 2.0 license.
- Base Model: To use this adapter, you must agree to the Google Health AI Developer Foundations Terms of Use to access the underlying MedGemma 1.5 weights.
- Clinical Disclaimer: This model is for educational and research purposes only. It is not a medical device, is not intended for clinical use, and should not be used to diagnose, treat, or offer medical advice for any disease or condition.
How to Use
Because this model uses the MedGemma vision-language architecture, it is strongly recommended to load the model with 4-bit NF4 quantization and to supply a dummy image tensor to stabilize the vision layers during text-only generation. Note that the example below omits the dummy image and runs text-only inference.
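The card does not show how to supply the dummy image. One possible approach, sketched here as an assumption rather than the author's method, is to add an image entry to the chat message so the chat template emits an image placeholder, then pass a blank PIL image to the processor alongside the text. `build_messages` is a hypothetical helper. The processor calls in the trailing comments are untested. The 896x896 size is also an assumption, since the processor resizes inputs anyway.

```python
from typing import Any, Dict, List


def build_messages(prompt: str, include_image_slot: bool = True) -> List[Dict[str, Any]]:
    """Build a single-turn chat in the content-list format used by Gemma 3 style processors.

    When include_image_slot is True, an image entry is placed before the text so the
    chat template emits an image placeholder; a matching image must then be passed
    to the processor.
    """
    content: List[Dict[str, Any]] = []
    if include_image_slot:
        content.append({"type": "image"})
    content.append({"type": "text", "text": prompt})
    return [{"role": "user", "content": content}]


# Hypothetical usage with the processor and model from the loading example (not run here):
# from PIL import Image
# dummy = Image.new("RGB", (896, 896), color="black")  # blank placeholder image
# text = processor.apply_chat_template(build_messages(prompt), tokenize=False,
#                                      add_generation_prompt=True)
# inputs = processor(text=text, images=[dummy], return_tensors="pt").to(model.device)
```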
```python
from transformers import AutoModelForImageTextToText, AutoProcessor, BitsAndBytesConfig
from peft import PeftModel
import torch

# 1. Load the base model in 4-bit NF4
model_id = "google/medgemma-1.5-4b-it"
processor = AutoProcessor.from_pretrained(model_id)

quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)

base_model = AutoModelForImageTextToText.from_pretrained(
    model_id,
    device_map="auto",
    quantization_config=quant_config,
)

# 2. Load this LoRA adapter on top of the base model
model = PeftModel.from_pretrained(base_model, "shrishSVaidya/medgemma-1.5-mm-progression-module4")

# 3. Format the prompt
patient_id = ""  # fill in the patient ID
# One line per lab visit, for example:
# - Day -17: Platelets: 329.0 x10^9 cells/L, Hemoglobin: 6.7 mmol/L, Creatinine: 68.07 umol/L, M Protein: 2.98 g/dL, Calcium: 2.45 mmol/L
# - Day 78: Calcium: 2.25 mmol/L, M Protein: 1.97 g/dL, Hemoglobin: 7.01 mmol/L, Platelets: 286.0 x10^9 cells/L, Creatinine: 53.92 umol/L
timeline = ""  # fill in the subsequent lab test metrics in the format above

prompt = (
    f"Review the following longitudinal biomarker history for patient {patient_id}. "
    "Predict the disease trajectory: Is this patient showing biochemical progression toward Active Myeloma?\n"
    f"Timeline:{timeline}"
)

messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
formatted_prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

inputs = processor(
    text=formatted_prompt,
    return_tensors="pt",
    padding=True,
).to(model.device)
inputs.pop("token_type_ids", None)  # marks image tokens; unused for text-only input

# 4. Generate with greedy decoding and decode only the newly generated tokens
with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=300, do_sample=False)

input_length = inputs["input_ids"].shape[1]
generated_tokens = outputs[0, input_length:]
pred_text = processor.decode(generated_tokens, skip_special_tokens=True).strip()
print("Model Prediction:\n", pred_text)
```
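Instead of typing the timeline by hand, you can generate it from structured lab results. The sketch below reproduces the prompt text from the example above verbatim and renders visits in the same "- Day N: Marker: value unit" style. `format_timeline` and `build_prompt` are hypothetical helpers. Readings are passed as preformatted strings so that values such as "329.0" keep their original formatting. The card does not say whether the training data placed a line break after "Timeline:", so the sketch adds none.

```python
from typing import Dict, List, Tuple

# Prompt text copied from the usage example; {timeline} follows "Timeline:" directly.
PROMPT_TEMPLATE = (
    "Review the following longitudinal biomarker history for patient {patient_id}. "
    "Predict the disease trajectory: Is this patient showing biochemical progression "
    "toward Active Myeloma?\nTimeline:{timeline}"
)


def format_timeline(visits: List[Tuple[int, Dict[str, str]]]) -> str:
    """Render (day, {marker: "value unit"}) visits as '- Day N: Marker: value unit, ...' lines."""
    lines = []
    for day, markers in sorted(visits, key=lambda v: v[0]):  # chronological order
        entries = ", ".join(f"{name}: {reading}" for name, reading in markers.items())
        lines.append(f"- Day {day}: {entries}")
    return "\n".join(lines)


def build_prompt(patient_id: str, visits: List[Tuple[int, Dict[str, str]]]) -> str:
    """Fill the prompt template with a patient ID and a rendered timeline."""
    return PROMPT_TEMPLATE.format(patient_id=patient_id, timeline=format_timeline(visits))
```

Visits may be supplied in any order. They are sorted by day, so the model always sees the history chronologically.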