---
license: apache-2.0
tags:
- mental-health
- diagnosis
- text-generation
- gemma
- qlora
- transformers
- huggingface
datasets:
- Jaamie/mental-health-custom-dataset
pipeline_tag: text-generation
language:
- en
base_model: google/gemma-2-9b-it
library_name: peft
---

# Gemma Mental Health QLoRA v2

A QLoRA fine-tune of `google/gemma-2-9b-it` for **mental health diagnosis**, trained on instruction-style examples. Given a user statement, the model predicts the most likely mental disorder in a fixed, structured dialogue format.

---

## Model Details

- **Base Model**: [`google/gemma-2-9b-it`](https://huggingface.co/google/gemma-2-9b-it)
- **Fine-Tuning Method**: QLoRA (4-bit quantization with `bitsandbytes`)
- **Tokenizer**: Included
- **LoRA Target Modules**: `["q_proj", "k_proj", "v_proj", "o_proj"]`
- **Sequence Format**: a `User:` line followed by a `Diagnosed Mental Disorder:` line, as listed under Output Format

## Output Format

- `User: <statement>` followed, on a new line, by `Diagnosed Mental Disorder: <Predicted_Mental_Health>`

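The format above can be wrapped in two small helpers: one that builds the prompt and one that pulls the predicted label out of the decoded text. This is a minimal sketch, not part of the released code; the names `build_prompt` and `extract_label` are hypothetical, and it assumes the label is whatever follows the first `Diagnosed Mental Disorder:` prefix, up to the end of that line.

```python
LABEL_PREFIX = "Diagnosed Mental Disorder:"
PROMPT_TEMPLATE = "User: {statement}\n" + LABEL_PREFIX


def build_prompt(statement: str) -> str:
    """Format a user statement so the prompt ends at the label prefix."""
    return PROMPT_TEMPLATE.format(statement=statement.strip())


def extract_label(decoded: str) -> str:
    """Return the text after the first label prefix, up to the end of that line.

    Returns an empty string if the prefix is missing or nothing follows it.
    """
    _, sep, tail = decoded.partition(LABEL_PREFIX)
    if not sep:
        return ""  # prefix not found in the decoded text
    stripped = tail.strip()
    if not stripped:
        return ""  # the model produced nothing after the prefix
    return stripped.splitlines()[0].strip()


if __name__ == "__main__":
    prompt = build_prompt("I can't sleep and my thoughts are spiraling out of control.")
    print(prompt)
    # Simulated decoded output; a real run would use the model's generation here.
    print(extract_label(prompt + " Anxiety\nUser: another statement"))
```

Cutting the label at the first line break guards against the model continuing with another `User:` turn after it has emitted the label.
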

---

## Use Cases

- Mental health Q&A assistant
- Conversational diagnosis suggestion
- NLP research and experimentation

> ⚠️ **Disclaimer**: This model is for research and educational purposes **only**. It is **not** intended for use in real-world clinical diagnosis without medical supervision.

---

## How to Use

```python
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the tokenizer shipped with the adapter, then the base model plus the LoRA adapter
tokenizer = AutoTokenizer.from_pretrained("Jaamie/gemma_mental_health_qlora_v2")
base_model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-9b-it",
    device_map="auto",
    torch_dtype=torch.float16,
)
model = PeftModel.from_pretrained(base_model, "Jaamie/gemma_mental_health_qlora_v2")
model.eval()

# Inference example: the prompt ends at the label prefix so the model completes it
prompt = "User: I can't sleep and my thoughts are spiraling out of control.\nDiagnosed Mental Disorder:"
# Move inputs to the model's device instead of hard-coding "cuda"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=30)

print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

## Training Details

- **Epochs**: 2
- **Batch Size**: 4 (with `gradient_accumulation_steps = 2`)
- **Max Length**: 512
- **Quantization**: 4-bit QLoRA (NF4) with `bitsandbytes`
- **Precision**: bf16

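The batch settings above imply an effective batch size and a rough optimizer step count, which the sketch below works out. Two inputs are assumptions rather than facts from this card: that training ran on a single device, and that all of the roughly 22,000 examples reported in the evaluation table were used for training (the size of the validation split is not given).

```python
import math

per_device_batch_size = 4        # from the card
gradient_accumulation_steps = 2  # from the card
epochs = 2                       # from the card
num_devices = 1                  # assumption: single-GPU training
num_train_examples = 22_000      # assumption: the approximate total, all used for training


def effective_batch_size(batch: int, accumulation: int, devices: int = 1) -> int:
    """Number of examples contributing to each optimizer update."""
    return batch * accumulation * devices


def optimizer_steps(examples: int, effective_batch: int, num_epochs: int) -> int:
    """Optimizer updates over the whole run, rounding each epoch up."""
    return math.ceil(examples / effective_batch) * num_epochs


effective = effective_batch_size(per_device_batch_size, gradient_accumulation_steps, num_devices)
total_steps = optimizer_steps(num_train_examples, effective, epochs)

print(f"effective batch size: {effective}")
print(f"approximate optimizer steps: {total_steps}")
```

If part of the data was held out for validation, the real step count would be somewhat lower than this estimate.
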

## Evaluation Results

| Metric | Score |
| --- | --- |
| Training Loss | 3.74 |
| Validation Loss | 3.79 |
| Total Examples | ~22,000 |

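One way to read the loss figures is to convert them to perplexity. The sketch below does so under an assumption the card does not state: that both losses are mean per-token cross-entropy values in nats, as is usual for causal language model training. If the losses were computed differently, the conversion does not apply.

```python
import math


def perplexity(mean_cross_entropy: float) -> float:
    """Convert a mean per-token cross-entropy (in nats) to perplexity."""
    return math.exp(mean_cross_entropy)


train_loss = 3.74  # reported training loss
val_loss = 3.79    # reported validation loss

train_ppl = perplexity(train_loss)
val_ppl = perplexity(val_loss)
loss_gap = val_loss - train_loss

print(f"training perplexity:   {train_ppl:.1f}")
print(f"validation perplexity: {val_ppl:.1f}")
print(f"validation - training loss: {loss_gap:.2f}")
```

The small gap between the two losses suggests the adapter is not strongly overfitting the training sample, though neither number says anything direct about classification accuracy.
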

The model was trained on a sample of the dataset: instruction-style examples with labeled disorders, sampled to keep the classes reasonably balanced.

| Mental Health Class | Sample Count |
| --- | --- |
| Depression | 4,000 |
| Anxiety | 4,000 |
| Suicidal Thoughts | 3,000 |
| Personality Disorder | 2,000 |
| Bipolar | 2,000 |
| Stress | 2,000 |
| Normal | 5,000 |

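The per-class counts can be checked against the reported total of about 22,000 and turned into class shares. The snippet below is a small sketch that uses only the numbers listed in this card.

```python
# Per-class sample counts as listed in this card
class_counts = {
    "Depression": 4000,
    "Anxiety": 4000,
    "Suicidal Thoughts": 3000,
    "Personality Disorder": 2000,
    "Bipolar": 2000,
    "Stress": 2000,
    "Normal": 5000,
}

total = sum(class_counts.values())
shares = {label: count / total for label, count in class_counts.items()}
# Ratio between the largest and smallest class, a rough measure of imbalance
imbalance_ratio = max(class_counts.values()) / min(class_counts.values())

for label, count in class_counts.items():
    print(f"{label:<22}{count:>6}  {shares[label]:.1%}")
print(f"total: {total}, largest/smallest class ratio: {imbalance_ratio}")
```

The counts sum to the reported total, and the largest class is 2.5 times the size of the smallest, so the sample is capped per class rather than perfectly uniform.
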

## Contact

Created by Jaamie Maarsh Joy Martin

- LinkedIn: https://www.linkedin.com/in/jaamie-maarsh-joy-martin/
- Email: jaamiemaarsh@gmail.com