jnalv committed
Commit eb8ac19 · 1 Parent(s): d547345

Improve CLIP tensor extraction from BaseModelOutputWithPooling and add better error handling

Files changed (1)
  1. classifier.py +83 -22
classifier.py CHANGED
@@ -182,8 +182,15 @@ class KikiBoubaClassifier:
         print(f"Loading model: {model_id}")
         # Use CLIPModel/CLIPProcessor for CLIP models, AutoModel/AutoProcessor for SigLIP
         if "clip" in model_id.lower():
-            self.model = CLIPModel.from_pretrained(model_id)
-            self.processor = CLIPProcessor.from_pretrained(model_id)
+            try:
+                self.model = CLIPModel.from_pretrained(model_id)
+                self.processor = CLIPProcessor.from_pretrained(model_id)
+                print(f"Loaded CLIPModel - has get_text_features: {hasattr(self.model, 'get_text_features')}")
+                print(f"Loaded CLIPModel - has get_image_features: {hasattr(self.model, 'get_image_features')}")
+            except Exception as e:
+                print(f"Warning: Failed to load as CLIPModel, trying AutoModel: {e}")
+                self.model = AutoModel.from_pretrained(model_id)
+                self.processor = AutoProcessor.from_pretrained(model_id)
         else:
             self.model = AutoModel.from_pretrained(model_id)
             self.processor = AutoProcessor.from_pretrained(model_id)
@@ -229,26 +236,53 @@ class KikiBoubaClassifier:
         inputs = {k: v.to(self.device) for k, v in inputs.items()}
 
         with torch.no_grad():
-            # Handle both SigLIP (has get_text_features) and CLIP (uses different API)
+            # CLIPModel has get_text_features method that returns tensor directly
             if hasattr(self.model, 'get_text_features'):
-                embeddings = self.model.get_text_features(**inputs)
+                try:
+                    embeddings = self.model.get_text_features(**inputs)
+                except Exception:
+                    # Fallback: use text_model directly
+                    outputs = self.model.text_model(**inputs)
+                    if hasattr(outputs, 'pooler_output') and outputs.pooler_output is not None:
+                        embeddings = outputs.pooler_output
+                    else:
+                        embeddings = outputs.last_hidden_state.mean(dim=1)
             elif hasattr(self.model, 'text_model'):
-                # CLIP models: access text_model and get pooled output
+                # Direct access to text_model
                 outputs = self.model.text_model(**inputs)
-                # CLIP text_model returns BaseModelOutputWithPooling
+                # Extract tensor from BaseModelOutputWithPooling
                 if hasattr(outputs, 'pooler_output') and outputs.pooler_output is not None:
                     embeddings = outputs.pooler_output
-                else:
-                    # Fallback: mean pool the last hidden state
+                elif hasattr(outputs, 'last_hidden_state'):
                     embeddings = outputs.last_hidden_state.mean(dim=1)
+                else:
+                    # Try to get the first attribute that's a tensor
+                    for attr in ['pooler_output', 'last_hidden_state', 'hidden_states']:
+                        if hasattr(outputs, attr):
+                            val = getattr(outputs, attr)
+                            if isinstance(val, torch.Tensor):
+                                if len(val.shape) > 2:
+                                    embeddings = val.mean(dim=1)
+                                else:
+                                    embeddings = val
+                                break
+                    else:
+                        raise ValueError(f"Could not extract tensor from text_model output: {type(outputs)}, attributes: {dir(outputs)}")
             else:
-                # Fallback: try calling model directly
+                # Final fallback: use model forward pass
                 outputs = self.model(**inputs)
-                embeddings = outputs.text_embeds if hasattr(outputs, 'text_embeds') else outputs[0]
+                if hasattr(outputs, 'text_embeds'):
+                    embeddings = outputs.text_embeds
+                elif isinstance(outputs, tuple) and len(outputs) > 0:
+                    embeddings = outputs[0]
+                else:
+                    raise ValueError(f"Could not extract text embeddings from model output: {type(outputs)}")
 
-        # Ensure embeddings is a tensor and normalize
+        # Ensure embeddings is a tensor
         if not isinstance(embeddings, torch.Tensor):
-            raise ValueError(f"Expected tensor, got {type(embeddings)}")
+            raise ValueError(f"Expected tensor, got {type(embeddings)}: {embeddings}")
+
+        # Normalize embeddings
         return F.normalize(embeddings, dim=-1)
 
     def _embed_image(self, image: Union[Image.Image, str]) -> torch.Tensor:
@@ -269,26 +303,53 @@ class KikiBoubaClassifier:
         inputs = {k: v.to(self.device) for k, v in inputs.items()}
 
         with torch.no_grad():
-            # Handle both SigLIP (has get_image_features) and CLIP (uses different API)
+            # CLIPModel has get_image_features method that returns tensor directly
             if hasattr(self.model, 'get_image_features'):
-                embedding = self.model.get_image_features(**inputs)
+                try:
+                    embedding = self.model.get_image_features(**inputs)
+                except Exception:
+                    # Fallback: use vision_model directly
+                    outputs = self.model.vision_model(**inputs)
+                    if hasattr(outputs, 'pooler_output') and outputs.pooler_output is not None:
+                        embedding = outputs.pooler_output
+                    else:
+                        embedding = outputs.last_hidden_state.mean(dim=1)
             elif hasattr(self.model, 'vision_model'):
-                # CLIP models: access vision_model and get pooled output
+                # Direct access to vision_model
                 outputs = self.model.vision_model(**inputs)
-                # CLIP vision_model returns BaseModelOutputWithPooling
+                # Extract tensor from BaseModelOutputWithPooling
                 if hasattr(outputs, 'pooler_output') and outputs.pooler_output is not None:
                     embedding = outputs.pooler_output
-                else:
-                    # Fallback: mean pool the last hidden state
+                elif hasattr(outputs, 'last_hidden_state'):
                     embedding = outputs.last_hidden_state.mean(dim=1)
+                else:
+                    # Try to get the first attribute that's a tensor
+                    for attr in ['pooler_output', 'last_hidden_state', 'hidden_states']:
+                        if hasattr(outputs, attr):
+                            val = getattr(outputs, attr)
+                            if isinstance(val, torch.Tensor):
+                                if len(val.shape) > 2:
+                                    embedding = val.mean(dim=1)
+                                else:
+                                    embedding = val
+                                break
+                    else:
+                        raise ValueError(f"Could not extract tensor from vision_model output: {type(outputs)}, attributes: {dir(outputs)}")
             else:
-                # Fallback: try calling model directly
+                # Final fallback: use model forward pass
                 outputs = self.model(**inputs)
-                embedding = outputs.image_embeds if hasattr(outputs, 'image_embeds') else outputs[0]
+                if hasattr(outputs, 'image_embeds'):
+                    embedding = outputs.image_embeds
+                elif isinstance(outputs, tuple) and len(outputs) > 0:
+                    embedding = outputs[0]
+                else:
+                    raise ValueError(f"Could not extract image embeddings from model output: {type(outputs)}")
 
-        # Ensure embedding is a tensor and normalize
+        # Ensure embedding is a tensor
         if not isinstance(embedding, torch.Tensor):
-            raise ValueError(f"Expected tensor, got {type(embedding)}")
+            raise ValueError(f"Expected tensor, got {type(embedding)}: {embedding}")
+
+        # Normalize embedding
        return F.normalize(embedding, dim=-1)
 
     def _compute_domain_scores(self, similarities: torch.Tensor, anchor_domains: List[str],
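For reference, the extraction order this commit settles on (get_text_features, then pooler_output, then a mean-pool over last_hidden_state) can be exercised in isolation. The sketch below is illustrative rather than part of the repo: the checkpoint id and prompt are placeholder choices, and in classifier.py the same logic runs inside _embed_text against self.model and self.processor.

import torch
import torch.nn.functional as F
from transformers import CLIPModel, CLIPProcessor

# Illustrative checkpoint; the repo picks model_id elsewhere.
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

inputs = processor(text=["a spiky kiki shape"], return_tensors="pt", padding=True)

with torch.no_grad():
    if hasattr(model, 'get_text_features'):
        # Preferred path: projected embedding in the joint image-text space.
        embeddings = model.get_text_features(**inputs)
    else:
        # Fallback path from the commit: pooled or mean-pooled encoder states.
        outputs = model.text_model(**inputs)  # BaseModelOutputWithPooling
        if getattr(outputs, 'pooler_output', None) is not None:
            embeddings = outputs.pooler_output
        else:
            embeddings = outputs.last_hidden_state.mean(dim=1)

embeddings = F.normalize(embeddings, dim=-1)
print(embeddings.shape)  # torch.Size([1, 512]) for this checkpoint

One caveat: the fallback branches return pre-projection encoder states (pooler_output or mean-pooled last_hidden_state), which live in a different space than the projected text_embeds/image_embeds that get_text_features and get_image_features produce. Text-image similarities are therefore only meaningful when both sides come from the same branch.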