import os
import gradio as gr
from transformers import pipeline
import torch
from typing import List, Dict
import logging
import traceback

# Set up logging to help us track what's happening
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
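
# The underlying Med42 weights may be gated on the Hugging Face Hub, so the
# code below reads an access token from the HUGGING_FACE_TOKEN environment
# variable (this app's own convention, not a library default); export it
# before launching.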

class MedicalAssistant:
    def __init__(self):
        """
        Initialize the medical assistant with the Llama3-Med42 model.
        This checkpoint is trained on medical data and quantized to 4-bit
        precision for better memory efficiency while maintaining good performance.
        """
        try:
            logger.info("Starting model initialization...")
            # Updated model to use Llama3-Med42
            self.model_name = "emircanerol/Llama3-Med42-8B-4bit"
            self.max_length = 2048
            # Initialize the pipeline for simplified text generation.
            # The pipeline loads the model and tokenizer together; quantization
            # flags belong in model_kwargs rather than as top-level pipeline
            # arguments.
            logger.info("Initializing pipeline...")
            self.pipe = pipeline(
                "text-generation",
                model=self.model_name,
                token=os.getenv('HUGGING_FACE_TOKEN'),
                device_map="auto",
                torch_dtype=torch.float16,  # half precision to match the quantized weights
                model_kwargs={"load_in_4bit": True}  # 4-bit loading via bitsandbytes
            )
            # Reuse the pipeline's tokenizer rather than downloading it a second time
            self.tokenizer = self.pipe.tokenizer
            # Ensure proper padding token configuration
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            logger.info("Medical Assistant initialized successfully!")
        except Exception as e:
            logger.error(f"Initialization failed: {str(e)}")
            logger.error(traceback.format_exc())
            raise
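
    # Note: an equivalent, more explicit way to request 4-bit loading is to pass
    # quantization_config=BitsAndBytesConfig(load_in_4bit=True) (imported from
    # transformers) inside model_kwargs. Either route assumes bitsandbytes is
    # installed; neither is verified against this exact checkpoint, which may
    # already ship with its own quantization config.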
    def generate_response(self, message: str, chat_history: List[Dict] = None) -> str:
        """
        Generate a response using the Llama3-Med42 pipeline.
        Formats the conversation history and produces a medically-framed reply.
        """
        try:
            logger.info("Preparing message for generation")
            # Create a medical context-aware prompt
            system_prompt = """You are a medical AI assistant based on Llama3-Med42,
            specifically trained on medical knowledge. Provide accurate, professional
            medical guidance while acknowledging limitations. Always recommend
            consulting healthcare providers for specific medical advice."""
            # Format the conversation for the model: system prompt first,
            # then prior turns in order, then the current user message last
            messages = [{"role": "system", "content": system_prompt}]
            if chat_history:
                for chat in chat_history:
                    messages.append({
                        "role": "user" if chat["role"] == "user" else "assistant",
                        "content": chat["content"]
                    })
            messages.append({"role": "user", "content": message})
            logger.info("Generating response")
            # Generate a continuation of the conversation
            output = self.pipe(
                messages,
                max_new_tokens=256,
                do_sample=True,
                temperature=0.7,
                top_p=0.95,
                repetition_penalty=1.1
            )[0]["generated_text"]
            # For chat-style input the pipeline returns the whole conversation,
            # with the new assistant turn appended as the final message
            if isinstance(output, list):
                response = output[-1]["content"].strip()
            else:
                response = str(output).split("assistant:")[-1].strip()
            logger.info("Response generated successfully")
            return response
        except Exception as e:
            logger.error(f"Error during response generation: {str(e)}")
            logger.error(traceback.format_exc())
            return f"I apologize, but I encountered an error: {str(e)}"

# Initialize the assistant
assistant = None

def initialize_assistant():
    """Initialize the assistant with proper error handling"""
    global assistant
    try:
        logger.info("Attempting to initialize assistant")
        assistant = MedicalAssistant()
        logger.info("Assistant initialized successfully")
        return True
    except Exception as e:
        logger.error(f"Failed to initialize assistant: {str(e)}")
        logger.error(traceback.format_exc())
        return False
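
# Gradio's history format varies by version: older releases pass a list of
# [user, bot] pairs, while newer ones (ChatInterface(type="messages")) pass
# role/content dicts. This defensive helper normalizes both shapes so
# generate_response can rely on a single format.
def normalize_history(history) -> List[Dict]:
    """Normalize Gradio chat history into role/content dicts."""
    normalized = []
    for item in history or []:
        if isinstance(item, dict):
            # Already in messages format
            normalized.append({"role": item.get("role", "user"),
                               "content": item.get("content", "")})
        else:
            # Tuple/list pair: (user message, bot reply)
            user_msg, bot_msg = item
            if user_msg:
                normalized.append({"role": "user", "content": user_msg})
            if bot_msg:
                normalized.append({"role": "assistant", "content": bot_msg})
    return normalized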

def chat_response(message: str, history: List[Dict]):
    """Handle chat interactions with error recovery"""
    global assistant
    if assistant is None:
        logger.info("Assistant not initialized, attempting initialization")
        if not initialize_assistant():
            return "I apologize, but I'm currently unavailable. Please try again later."
    try:
        # Normalize the history first so generate_response always sees
        # role/content dicts regardless of the Gradio version in use
        return assistant.generate_response(message, normalize_history(history))
    except Exception as e:
        logger.error(f"Error in chat response: {str(e)}")
        logger.error(traceback.format_exc())
        return f"I encountered an error: {str(e)}"

# Create the Gradio interface
demo = gr.ChatInterface(
    fn=chat_response,
    title="Medical Assistant (Llama3-Med42)",
    description="""This medical assistant is powered by Llama3-Med42,
    a model specifically trained on medical knowledge. It provides
    guidance and information about health-related queries while
    maintaining professional medical standards.""",
    examples=[
        "What are the symptoms of malaria?",
        "How can I prevent type 2 diabetes?",
        "What should I do for a mild headache?"
    ]
)

# Launch the interface
if __name__ == "__main__":
    logger.info("Starting the application")
    demo.launch()
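
# To run locally (assumes a CUDA-capable GPU plus the gradio, torch,
# transformers, accelerate, and bitsandbytes packages):
#   export HUGGING_FACE_TOKEN=hf_...   # your Hub access token
#   python app.py
# Gradio serves the UI at http://127.0.0.1:7860 by default.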