| """ |
| Custom Handler for MORBID v0.2.0 Insurance AI |
| HuggingFace Inference Endpoints - Mistral Small 22B Fine-tuned |
| """ |
|
|
import os
from typing import Any, Dict, List

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


| class EndpointHandler: |
| def __init__(self, path: str = ""): |
| """ |
| Initialize the handler with model and tokenizer |
| |
| Args: |
| path: Path to the model directory |
| """ |
| |
        # Prefer bfloat16 on GPU for memory savings; fall back to float32 on CPU.
        dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

        self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=dtype,
            device_map="auto",
            low_cpu_mem_usage=True,
        )

        # Mistral tokenizers ship without a pad token; reuse EOS for padding.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
| |
| |
| self.system_prompt = """You are Morbi, an expert AI assistant specializing in health and life insurance, actuarial science, and risk analysis. You are: |
| |
| 1. KNOWLEDGEABLE: You have deep expertise in: |
| - Life insurance products (term, whole, universal, variable) |
| - Health insurance (medical, dental, disability, LTC) |
| - Actuarial mathematics (mortality tables, interest theory, reserving) |
| - Underwriting and risk classification |
| - Claims analysis and management |
| - Regulatory compliance (state, federal, NAIC) |
| - ICD-10 medical codes and cause-of-death classification |
| |
| 2. CONVERSATIONAL: You communicate naturally and warmly while maintaining professionalism. |
| |
| 3. ACCURATE: You provide factual, well-reasoned responses. You never make up statistics. |
| |
| 4. HELPFUL: You aim to assist users effectively with actionable information.""" |
|
|
| def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]: |
| """ |
| Process the inference request |
| |
| Args: |
| data: Dictionary containing the input data |
| - inputs (str or list): The input text(s) |
| - parameters (dict): Generation parameters |
| |
| Returns: |
| List of generated responses |
| """ |
| |
| inputs = data.get("inputs", "") |
| parameters = data.get("parameters", {}) |
| |
| |
        # Normalize to a list of strings so single and batched requests share
        # one code path.
        if isinstance(inputs, str):
            inputs = [inputs]
        elif not isinstance(inputs, list):
            inputs = [str(inputs)]
| |
| |
        # Merge caller-supplied generation parameters with conservative defaults.
        generation_params = {
| "max_new_tokens": parameters.get("max_new_tokens", 512), |
| "temperature": parameters.get("temperature", 0.7), |
| "top_p": parameters.get("top_p", 0.9), |
| "do_sample": parameters.get("do_sample", True), |
| "repetition_penalty": parameters.get("repetition_penalty", 1.1), |
| "pad_token_id": self.tokenizer.pad_token_id, |
| "eos_token_id": self.tokenizer.eos_token_id, |
| } |
| |
| |
| results = [] |
| for input_text in inputs: |
| |
| prompt = self._format_prompt(input_text) |
| |
| |
            # Tokenize the prompt; the tokenizer adds the BOS token here, which
            # is why _format_prompt omits <s>. Move tensors to the model device.
            inputs_tokenized = self.tokenizer(
                prompt,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=4096,
            ).to(self.model.device)
| |
| |
| |
            # Discourage the model from drifting into multi-turn chat by
            # banning role-marker token sequences. This is best-effort: if the
            # tokenizer call fails for any reason, generate without the list.
            bad_words_ids = []
            try:
                role_tokens = ["Human:", "User:", "Assistant:", "SYSTEM:", "System:"]
                tokenized = self.tokenizer(role_tokens, add_special_tokens=False).input_ids
                for ids in tokenized:
                    if isinstance(ids, list) and len(ids) > 0:
                        bad_words_ids.append(ids)
            except Exception:
                pass
|
|
            decoding_kwargs = {
                **generation_params,
                # Block verbatim 3-gram repetition to curb degenerate loops.
                "no_repeat_ngram_size": 3,
            }
            if bad_words_ids:
                decoding_kwargs["bad_words_ids"] = bad_words_ids
|
|
| with torch.no_grad(): |
| outputs = self.model.generate( |
| **inputs_tokenized, |
| **decoding_kwargs |
| ) |
| |
| |
            # Decode only the newly generated tokens so the prompt (and the
            # system prompt embedded in it) cannot leak into the response.
            prompt_length = inputs_tokenized["input_ids"].shape[1]
            generated_text = self.tokenizer.decode(
                outputs[0][prompt_length:], skip_special_tokens=True
            )
|
|
| |
            # Strip any residual prompt text, then cut at stop markers.
            response = self._extract_response(generated_text, prompt)
            response = self._truncate_at_stops(response)
| |
| results.append({ |
| "generated_text": response, |
| "conversation": { |
| "user": input_text, |
| "assistant": response |
| } |
| }) |
| |
| return results |
| |
    def _format_prompt(self, user_input: str) -> str:
        """
        Format the user input into Mistral Instruct format.

        Note: the BOS token (<s>) is deliberately omitted here because the
        tokenizer prepends it during encoding; writing it into the string as
        well would feed the model a duplicated BOS token.

        Args:
            user_input: The user's message

        Returns:
            The prompt wrapped as "[INST] " + system prompt + two newlines
            + user input + " [/INST]"
        """
        return f"[INST] {self.system_prompt}\n\n{user_input} [/INST]"
| |
| def _extract_response(self, generated_text: str, prompt: str) -> str: |
| """ |
| Extract only the assistant's response from the generated text |
| |
| Args: |
| generated_text: Full generated text including prompt |
| prompt: The original prompt |
| |
| Returns: |
| Just the assistant's response |
| """ |
| |
| if "[/INST]" in generated_text: |
| response = generated_text.split("[/INST]")[-1].strip() |
| elif generated_text.startswith(prompt): |
| response = generated_text[len(prompt):].strip() |
| else: |
| response = generated_text.strip() |
| |
| |
| response = response.replace("</s>", "").strip() |
| |
| |
| if not response: |
| response = "I'm here to help! Could you please rephrase your question?" |
| |
| return response |
|
|
| def _truncate_at_stops(self, text: str) -> str: |
| """Truncate model output at conversation stop markers.""" |
| |
        # Markers that signal the model has begun a new turn or re-emitted
        # template tokens; everything from the first occurrence onward is cut.
        stop_markers = [
            "\n[INST]", "[INST]", "</s>", "<s>",
            "\nHuman:", "\nUser:", "\nAssistant:",
        ]
        cut_index = None
        for marker in stop_markers:
            idx = text.find(marker)
            if idx != -1:
                cut_index = idx if cut_index is None else min(cut_index, idx)
        if cut_index is not None:
            text = text[:cut_index].rstrip()

        # Hard character cap as a final guard against runaway output.
        if len(text) > 2000:
            text = text[:2000].rstrip()
| return text |
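

# ---------------------------------------------------------------------------
# Minimal local smoke test -- a sketch, not part of the endpoint contract.
# It assumes the fine-tuned weights are reachable at MODEL_PATH; both the
# environment variable name and the "." fallback are illustrative choices.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import json

    handler = EndpointHandler(path=os.environ.get("MODEL_PATH", "."))
    result = handler({
        "inputs": "What is the difference between term and whole life insurance?",
        "parameters": {"max_new_tokens": 128, "temperature": 0.3},
    })
    print(json.dumps(result, indent=2))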
|
|