Improve CLIP tensor extraction from BaseModelOutputWithPooling and add better error handling
classifier.py CHANGED (+83 -22)
@@ -182,8 +182,15 @@ class KikiBoubaClassifier:
         print(f"Loading model: {model_id}")
         # Use CLIPModel/CLIPProcessor for CLIP models, AutoModel/AutoProcessor for SigLIP
         if "clip" in model_id.lower():
-            self.model = CLIPModel.from_pretrained(model_id)
-            self.processor = CLIPProcessor.from_pretrained(model_id)
+            try:
+                self.model = CLIPModel.from_pretrained(model_id)
+                self.processor = CLIPProcessor.from_pretrained(model_id)
+                print(f"Loaded CLIPModel - has get_text_features: {hasattr(self.model, 'get_text_features')}")
+                print(f"Loaded CLIPModel - has get_image_features: {hasattr(self.model, 'get_image_features')}")
+            except Exception as e:
+                print(f"Warning: Failed to load as CLIPModel, trying AutoModel: {e}")
+                self.model = AutoModel.from_pretrained(model_id)
+                self.processor = AutoProcessor.from_pretrained(model_id)
         else:
             self.model = AutoModel.from_pretrained(model_id)
             self.processor = AutoProcessor.from_pretrained(model_id)
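The loader selects CLIPModel by substring match on the model id, and the embedding code further down probes the loaded object with hasattr. As a sanity check that every probe path exists on a stock CLIP checkpoint, a minimal sketch (the checkpoint id is illustrative, not taken from this repo):

from transformers import CLIPModel

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")

# CLIPModel exposes both the projection-space extractors and the raw
# encoder submodules, so each branch of the extraction code below is reachable.
print(hasattr(model, "get_text_features"))   # True
print(hasattr(model, "get_image_features"))  # True
print(hasattr(model, "text_model"))          # True
print(hasattr(model, "vision_model"))        # True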
@@ -229,26 +236,53 @@ class KikiBoubaClassifier:
         inputs = {k: v.to(self.device) for k, v in inputs.items()}
 
         with torch.no_grad():
-            # …
+            # CLIPModel has get_text_features method that returns tensor directly
             if hasattr(self.model, 'get_text_features'):
-                embeddings = self.model.get_text_features(**inputs)
+                try:
+                    embeddings = self.model.get_text_features(**inputs)
+                except Exception:
+                    # Fallback: use text_model directly
+                    outputs = self.model.text_model(**inputs)
+                    if hasattr(outputs, 'pooler_output') and outputs.pooler_output is not None:
+                        embeddings = outputs.pooler_output
+                    else:
+                        embeddings = outputs.last_hidden_state.mean(dim=1)
             elif hasattr(self.model, 'text_model'):
-                # …
+                # Direct access to text_model
                 outputs = self.model.text_model(**inputs)
-                # …
+                # Extract tensor from BaseModelOutputWithPooling
                 if hasattr(outputs, 'pooler_output') and outputs.pooler_output is not None:
                     embeddings = outputs.pooler_output
-                else:
-                    # Fallback: mean pool the last hidden state
+                elif hasattr(outputs, 'last_hidden_state'):
                     embeddings = outputs.last_hidden_state.mean(dim=1)
+                else:
+                    # Try to get the first attribute that's a tensor
+                    for attr in ['pooler_output', 'last_hidden_state', 'hidden_states']:
+                        if hasattr(outputs, attr):
+                            val = getattr(outputs, attr)
+                            if isinstance(val, torch.Tensor):
+                                if len(val.shape) > 2:
+                                    embeddings = val.mean(dim=1)
+                                else:
+                                    embeddings = val
+                                break
+                    else:
+                        raise ValueError(f"Could not extract tensor from text_model output: {type(outputs)}, attributes: {dir(outputs)}")
             else:
-                # …
+                # Final fallback: use model forward pass
                 outputs = self.model(**inputs)
-                embeddings = outputs.text_embeds
+                if hasattr(outputs, 'text_embeds'):
+                    embeddings = outputs.text_embeds
+                elif isinstance(outputs, tuple) and len(outputs) > 0:
+                    embeddings = outputs[0]
+                else:
+                    raise ValueError(f"Could not extract text embeddings from model output: {type(outputs)}")
 
-            # Ensure embeddings is a tensor
+            # Ensure embeddings is a tensor
             if not isinstance(embeddings, torch.Tensor):
-                raise ValueError(f"Expected tensor, got {type(embeddings)}")
+                raise ValueError(f"Expected tensor, got {type(embeddings)}: {embeddings}")
+
+            # Normalize embeddings
             return F.normalize(embeddings, dim=-1)
 
     def _embed_image(self, image: Union[Image.Image, str]) -> torch.Tensor:
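The pooler_output / last_hidden_state handling above is the heart of the fix: self.model.text_model(**inputs) returns a BaseModelOutputWithPooling, not a tensor, so a tensor has to be picked out of it before F.normalize. A condensed sketch of that extraction as a standalone helper (same pattern, not code from the repo):

import torch

def pool_output(outputs) -> torch.Tensor:
    # BaseModelOutputWithPooling carries pooler_output of shape (batch, hidden)
    # when the model defines a pooler, and last_hidden_state of shape
    # (batch, seq, hidden) always.
    pooled = getattr(outputs, "pooler_output", None)
    if pooled is not None:
        return pooled
    # Fallback: mean pool over the sequence dimension.
    return outputs.last_hidden_state.mean(dim=1)

One caveat: pooler_output lives in the encoder's hidden space, while get_text_features applies CLIP's learned text projection, so the fallback branch can yield embeddings in a different space than the primary branch.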
@@ -269,26 +303,53 @@ class KikiBoubaClassifier:
         inputs = {k: v.to(self.device) for k, v in inputs.items()}
 
         with torch.no_grad():
-            # …
+            # CLIPModel has get_image_features method that returns tensor directly
             if hasattr(self.model, 'get_image_features'):
-                embedding = self.model.get_image_features(**inputs)
+                try:
+                    embedding = self.model.get_image_features(**inputs)
+                except Exception:
+                    # Fallback: use vision_model directly
+                    outputs = self.model.vision_model(**inputs)
+                    if hasattr(outputs, 'pooler_output') and outputs.pooler_output is not None:
+                        embedding = outputs.pooler_output
+                    else:
+                        embedding = outputs.last_hidden_state.mean(dim=1)
             elif hasattr(self.model, 'vision_model'):
-                # …
+                # Direct access to vision_model
                 outputs = self.model.vision_model(**inputs)
-                # …
+                # Extract tensor from BaseModelOutputWithPooling
                 if hasattr(outputs, 'pooler_output') and outputs.pooler_output is not None:
                     embedding = outputs.pooler_output
-                else:
-                    # Fallback: mean pool the last hidden state
+                elif hasattr(outputs, 'last_hidden_state'):
                     embedding = outputs.last_hidden_state.mean(dim=1)
+                else:
+                    # Try to get the first attribute that's a tensor
+                    for attr in ['pooler_output', 'last_hidden_state', 'hidden_states']:
+                        if hasattr(outputs, attr):
+                            val = getattr(outputs, attr)
+                            if isinstance(val, torch.Tensor):
+                                if len(val.shape) > 2:
+                                    embedding = val.mean(dim=1)
+                                else:
+                                    embedding = val
+                                break
+                    else:
+                        raise ValueError(f"Could not extract tensor from vision_model output: {type(outputs)}, attributes: {dir(outputs)}")
             else:
-                # …
+                # Final fallback: use model forward pass
                 outputs = self.model(**inputs)
-                embedding = outputs.image_embeds
+                if hasattr(outputs, 'image_embeds'):
+                    embedding = outputs.image_embeds
+                elif isinstance(outputs, tuple) and len(outputs) > 0:
+                    embedding = outputs[0]
+                else:
+                    raise ValueError(f"Could not extract image embeddings from model output: {type(outputs)}")
 
-            # Ensure embedding is a tensor
+            # Ensure embedding is a tensor
             if not isinstance(embedding, torch.Tensor):
-                raise ValueError(f"Expected tensor, got {type(embedding)}")
+                raise ValueError(f"Expected tensor, got {type(embedding)}: {embedding}")
+
+            # Normalize embedding
             return F.normalize(embedding, dim=-1)
 
     def _compute_domain_scores(self, similarities: torch.Tensor, anchor_domains: List[str],
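Both attribute-scan fallbacks rely on Python's for/else, which is easy to misread: the else clause runs only when the loop finishes without hitting break, which is exactly when no candidate attribute produced a tensor. The skeleton of the scan (condensed, not code from the repo):

import torch

def first_tensor(outputs) -> torch.Tensor:
    # Break on the first candidate attribute that is a tensor, mean-pooling
    # anything that still has a sequence dimension.
    for attr in ("pooler_output", "last_hidden_state", "hidden_states"):
        val = getattr(outputs, attr, None)
        if isinstance(val, torch.Tensor):
            found = val.mean(dim=1) if val.dim() > 2 else val
            break
    else:
        # No break fired: no candidate was a tensor.
        raise ValueError(f"Could not extract tensor from {type(outputs)}")
    return found

Worth knowing: on transformers model outputs, hidden_states is a tuple of tensors (when requested at all), so that candidate never passes the isinstance check; it is harmless but dead weight in the scan.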
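A final caveat on the fallback paths: get_text_features and get_image_features return embeddings in CLIP's shared projection space, while the pooler_output and mean-pooling fallbacks return encoder hidden states. On the standard ViT-B/32 checkpoint the vision hidden size differs from the projection dim, so a fallback image embedding is not comparable to a projected text embedding even after normalization. A quick way to see the mismatch (checkpoint id illustrative):

from transformers import CLIPModel

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
print(model.config.projection_dim)             # 512: shared embedding space
print(model.config.text_config.hidden_size)    # 512: text fallback happens to match
print(model.config.vision_config.hidden_size)  # 768: image fallback does not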