h3ir committed
Commit d7d1415 · verified · 1 Parent(s): 4ec6988

Add custom inference handler

Files changed (1):
  1. handler.py +204 -0
handler.py ADDED
@@ -0,0 +1,204 @@
"""
Custom Handler for MORBID v0.2.0 Insurance AI
HuggingFace Inference Endpoints - Mistral Small 22B Fine-tuned
"""

from typing import Any, Dict, List

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


class EndpointHandler:
    def __init__(self, path: str = ""):
        """
        Initialize the handler with the model and tokenizer.

        Args:
            path: Path to the model directory
        """
        # Load tokenizer and model
        dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

        self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=dtype,
            device_map="auto",
            low_cpu_mem_usage=True,
        )

        # Set padding token if not already set
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # System prompt for Morbi v0.2.0
        self.system_prompt = """You are Morbi, an expert AI assistant specializing in health and life insurance, actuarial science, and risk analysis. You are:

1. KNOWLEDGEABLE: You have deep expertise in:
   - Life insurance products (term, whole, universal, variable)
   - Health insurance (medical, dental, disability, LTC)
   - Actuarial mathematics (mortality tables, interest theory, reserving)
   - Underwriting and risk classification
   - Claims analysis and management
   - Regulatory compliance (state, federal, NAIC)
   - ICD-10 medical codes and cause-of-death classification

2. CONVERSATIONAL: You communicate naturally and warmly while maintaining professionalism.

3. ACCURATE: You provide factual, well-reasoned responses. You never make up statistics.

4. HELPFUL: You aim to assist users effectively with actionable information."""

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Process the inference request.

        Args:
            data: Dictionary containing the input data
                - inputs (str or list): The input text(s)
                - parameters (dict): Generation parameters

        Returns:
            List of generated responses
        """
        # Extract inputs
        inputs = data.get("inputs", "")
        parameters = data.get("parameters", {})

        # Handle both string and list inputs
        if isinstance(inputs, str):
            inputs = [inputs]
        elif not isinstance(inputs, list):
            inputs = [str(inputs)]

        # Set default generation parameters (optimized for Mistral Small 22B)
        generation_params = {
            "max_new_tokens": parameters.get("max_new_tokens", 512),
            "temperature": parameters.get("temperature", 0.7),
            "top_p": parameters.get("top_p", 0.9),
            "do_sample": parameters.get("do_sample", True),
            "repetition_penalty": parameters.get("repetition_penalty", 1.1),
            "pad_token_id": self.tokenizer.pad_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
        }

        # Process each input
        results = []
        for input_text in inputs:
            # Format the prompt with conversational context
            prompt = self._format_prompt(input_text)

            # Tokenize
            inputs_tokenized = self.tokenizer(
                prompt,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=4096,
            ).to(self.model.device)

            # Prepare additional decoding constraints
            bad_words_ids = []
            try:
                # Disallow role-tag leakage in generations
                role_tokens = ["Human:", "User:", "Assistant:", "SYSTEM:", "System:"]
                tokenized = self.tokenizer(role_tokens, add_special_tokens=False).input_ids
                # input_ids is a list of token-id lists, one per tokenized string
                for ids in tokenized:
                    if isinstance(ids, list) and len(ids) > 0:
                        bad_words_ids.append(ids)
            except Exception:
                pass

            decoding_kwargs = {
                **generation_params,
                # Encourage coherence and reduce repetition/artifacts
                "no_repeat_ngram_size": 3,
            }
            if bad_words_ids:
                decoding_kwargs["bad_words_ids"] = bad_words_ids

            # Generate response
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs_tokenized,
                    **decoding_kwargs,
                )

            # Decode the response
            generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

            # Extract only the assistant's response and trim at stop sequences
            response = self._extract_response(generated_text, prompt)
            response = self._truncate_at_stops(response)

            results.append({
                "generated_text": response,
                "conversation": {
                    "user": input_text,
                    "assistant": response,
                },
            })

        return results
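
    # Example payload handled by __call__, with illustrative values:
    #   {"inputs": "What does a term life policy cover?",
    #    "parameters": {"max_new_tokens": 256, "temperature": 0.5}}
    # which returns a list of the shape:
    #   [{"generated_text": "...",
    #     "conversation": {"user": "...", "assistant": "..."}}]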

    def _format_prompt(self, user_input: str) -> str:
        """
        Format the user input into Mistral Instruct format.

        Args:
            user_input: The user's message

        Returns:
            Formatted prompt string in Mistral format
        """
        # Mistral Instruct format: <s>[INST] system\n\nuser [/INST]
        # The tokenizer already prepends the <s> BOS token during encoding,
        # so it is omitted here to avoid a doubled BOS.
        return f"[INST] {self.system_prompt}\n\n{user_input} [/INST]"
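
    # For user_input "Hi", this returns (system prompt elided for brevity):
    #   '[INST] You are Morbi, an expert AI assistant ...\n\nHi [/INST]'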

    def _extract_response(self, generated_text: str, prompt: str) -> str:
        """
        Extract only the assistant's response from the generated text.

        Args:
            generated_text: Full generated text, including the prompt
            prompt: The original prompt

        Returns:
            Just the assistant's response
        """
        # In Mistral format, the response comes after [/INST]
        if "[/INST]" in generated_text:
            response = generated_text.split("[/INST]")[-1].strip()
        elif generated_text.startswith(prompt):
            response = generated_text[len(prompt):].strip()
        else:
            response = generated_text.strip()

        # Remove any trailing </s> token
        response = response.replace("</s>", "").strip()

        # Ensure we have a response
        if not response:
            response = "I'm here to help! Could you please rephrase your question?"

        return response

    def _truncate_at_stops(self, text: str) -> str:
        """Truncate model output at conversation stop markers."""
        # Mistral stop markers
        stop_markers = [
            "\n[INST]", "[INST]", "</s>", "<s>",
            "\nHuman:", "\nUser:", "\nAssistant:",
        ]
        # Cut at the earliest marker found, if any
        cut_index = None
        for marker in stop_markers:
            idx = text.find(marker)
            if idx != -1:
                cut_index = idx if cut_index is None else min(cut_index, idx)
        if cut_index is not None:
            text = text[:cut_index].rstrip()
        # Keep the response reasonably bounded
        if len(text) > 2000:
            text = text[:2000].rstrip()
        return text
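
For a quick local smoke test, the handler can be exercised directly, without deploying an endpoint. The sketch below is illustrative rather than definitive: the weights path ./morbid-v0.2.0 and the sample question are placeholders, and it assumes the fine-tuned model is available locally with enough GPU memory for a 22B-parameter model.

# smoke_test.py - minimal local check of the custom handler (illustrative)
from handler import EndpointHandler

handler = EndpointHandler(path="./morbid-v0.2.0")  # hypothetical local weights path

payload = {
    "inputs": "How do mortality tables feed into term life pricing?",
    "parameters": {"max_new_tokens": 256, "temperature": 0.5},
}

results = handler(payload)
print(results[0]["generated_text"])

Deployed on HuggingFace Inference Endpoints, the same JSON document is sent as the POST request body, and the list returned by __call__ comes back as the JSON response.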