"""
Kiki/Bouba Visual Classifier
Uses vision-language models (SigLIP/CLIP) to classify images as "Kiki" (angular/spiky)
or "Bouba" (rounded/soft) based on the cross-cultural sound-shape association phenomenon.
"""
import torch
import torch.nn.functional as F
from transformers import AutoProcessor, AutoModel, CLIPProcessor, CLIPModel
from PIL import Image
from typing import Dict, List, Union, Tuple, Optional
class KikiBoubaClassifier:
"""
Classifier that determines if an image is more "Kiki" (angular/spiky)
or "Bouba" (rounded/soft) using vision-language embeddings.
Uses expanded cross-modal anchors (~200 total) based on NeurIPS 2023 research:
"Kiki or Bouba? Sound Symbolism in Vision-and-Language Models"
https://arxiv.org/abs/2310.16781
Anchors span multiple sensory domains: shape, texture, taste, color/light,
sound, sensation, movement, emotion, and abstract qualities.
"""
# Kiki anchors organized by domain (Angular / Sharp / Intense)
KIKI_ANCHORS_BY_DOMAIN = {
# Shape & Geometry (Primary - highest confidence)
"shape_primary": [
"sharp", "spiky", "angular", "jagged", "pointed", "edgy", "geometric",
"crystalline", "fractured", "serrated", "zigzag", "triangular", "diagonal",
"hexagonal", "polygonal", "faceted", "prismatic", "chiseled", "carved", "etched",
],
# Texture (High confidence)
"texture": [
"rough", "coarse", "gritty", "scratchy", "abrasive", "bristly", "prickly",
"thorny", "barbed", "splintered", "grainy", "sandpapery", "rugged", "craggy", "uneven",
],
# Taste & Flavor (Cross-modal - validated)
"taste": [
"acidic", "sour", "bitter", "tart", "tangy", "astringent", "pungent",
"zesty", "biting", "acrid", "vinegary", "citrusy", "lemony", "sharp-tasting",
],
# Color & Light (Cross-modal)
"color_light": [
"bright", "vivid", "glaring", "fluorescent", "neon", "blinding", "harsh",
"saturated", "electric", "shocking", "stark", "contrasting", "yellow", "red", "white",
],
# Sound (Cross-modal - validated)
"sound": [
"high-pitched", "shrill", "piercing", "screeching", "staccato", "clashing",
"clanging", "crackling", "snapping", "clicking", "tinny", "metallic", "discordant", "jarring",
],
# Temperature & Sensation
"sensation": [
"cold", "icy", "freezing", "stinging", "burning", "prickling", "tingling",
"electric", "shocking", "intense",
],
# Movement & Speed
"movement": [
"fast", "quick", "rapid", "jerky", "abrupt", "sudden", "darting",
"twitchy", "erratic", "spasmodic", "snappy", "jagged-motion",
],
# Emotion & Energy (Cross-modal)
"emotion": [
"tense", "anxious", "nervous", "stressed", "agitated", "alert", "aggressive",
"hostile", "angry", "irritable", "fierce", "intense", "urgent", "frantic",
],
# Abstract Qualities
"abstract": [
"harsh", "hard", "rigid", "stiff", "brittle", "crisp", "precise",
"exact", "strict", "severe", "stern", "unforgiving", "dangerous", "threatening",
],
}
# Bouba anchors organized by domain (Rounded / Soft / Gentle)
BOUBA_ANCHORS_BY_DOMAIN = {
# Shape & Geometry (Primary - highest confidence)
"shape_primary": [
"round", "rounded", "circular", "curved", "bulbous", "spherical", "globular",
"oval", "elliptical", "undulating", "wavy", "flowing", "organic", "amorphous",
"blobby", "puffy", "billowy", "domed", "arched", "swooping",
],
# Texture (High confidence)
"texture": [
"soft", "smooth", "silky", "velvety", "plush", "fluffy", "fuzzy",
"downy", "cottony", "cushiony", "spongy", "supple", "tender", "delicate", "gentle",
],
# Taste & Flavor (Cross-modal - validated)
"taste": [
"sweet", "creamy", "mild", "mellow", "bland", "buttery", "rich",
"chocolatey", "caramel", "honeyed", "sugary", "milky", "vanilla", "smooth-tasting",
],
# Color & Light (Cross-modal)
"color_light": [
"dim", "muted", "pastel", "soft-lit", "diffuse", "hazy", "foggy",
"misty", "dusky", "twilight", "warm", "golden", "amber", "blue", "purple",
],
# Sound (Cross-modal - validated)
"sound": [
"low-pitched", "deep", "resonant", "humming", "droning", "murmuring",
"rumbling", "melodic", "flowing", "legato", "muffled", "soft-sounding", "gentle", "soothing",
],
# Temperature & Sensation
"sensation": [
"warm", "lukewarm", "cozy", "comfortable", "soothing", "relaxing",
"calming", "gentle", "caressing", "embracing",
],
# Movement & Speed
"movement": [
"slow", "gradual", "languid", "lazy", "drifting", "floating",
"gliding", "flowing", "swaying", "undulating", "graceful", "smooth-motion",
],
# Emotion & Energy (Cross-modal)
"emotion": [
"calm", "peaceful", "relaxed", "serene", "tranquil", "content", "happy",
"joyful", "gentle", "kind", "friendly", "welcoming", "safe", "comforting",
],
# Abstract Qualities
"abstract": [
"soft", "gentle", "flexible", "yielding", "malleable", "pliable", "forgiving",
"lenient", "easygoing", "laid-back", "nurturing", "maternal", "protective", "embracing",
],
}
    # Default domain weights (can be overridden via the domain_weights argument
    # to __init__). Visually grounded domains (shape, texture) get the highest
    # weights; abstract and cross-modal domains are down-weighted.
DEFAULT_DOMAIN_WEIGHTS = {
"shape_primary": 1.5, # Highest - most visually relevant
"texture": 1.3, # High - directly visual
"color_light": 0.6, # Reduced - often background-influenced
"taste": 0.9, # Medium - cross-modal but validated
"sound": 0.7, # Lower - less visually relevant
"sensation": 0.8, # Lower - abstract
"movement": 1.0, # Medium - can be visual
"emotion": 0.6, # Lower - very abstract
"abstract": 0.7, # Lower - very abstract
}
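    # Example override (hypothetical values): passing
    #     domain_weights={"shape_primary": 2.0, "emotion": 0.0}
    # to __init__ emphasizes geometry and excludes the emotion domain entirely
    # (zero-weight domains are skipped in _apply_domain_weights).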
    # Flattened, deduplicated anchor lists. These are populated as the instance
    # attributes self.kiki_anchors / self.bouba_anchors in __init__; the class
    # attributes below are placeholders only.
    KIKI_ANCHORS = None
    BOUBA_ANCHORS = None
@classmethod
def _flatten_anchors(cls, anchors_by_domain: dict) -> Tuple[List[str], List[str]]:
"""
Flatten domain-organized anchors into a deduplicated list.
Returns:
Tuple of (flattened_anchors_list, domain_labels_list) where domain_labels_list
maps each anchor to its domain.
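        Example (illustrative input):
            >>> KikiBoubaClassifier._flatten_anchors(
            ...     {"shape": ["round", "soft"], "texture": ["soft", "plush"]}
            ... )
            (['round', 'soft', 'plush'], ['shape', 'shape', 'texture'])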
"""
seen = set()
result = []
domain_labels = []
for domain, domain_anchors in anchors_by_domain.items():
for anchor in domain_anchors:
if anchor not in seen:
seen.add(anchor)
result.append(anchor)
domain_labels.append(domain)
return result, domain_labels
def __init__(self, model_id: str = "openai/clip-vit-large-patch14",
domain_weights: Optional[Dict[str, float]] = None):
"""
Initialize the classifier with a vision-language model.
Args:
model_id: HuggingFace model identifier (default: CLIP ViT-Large)
domain_weights: Optional dict mapping domain names to weights.
If None, uses DEFAULT_DOMAIN_WEIGHTS.
Weights control the influence of each domain on classification.
"""
# Always initialize on CPU for ZeroGPU compatibility
# ZeroGPU allocates GPU on-demand, and GPU context changes between requests
# ensure_device() will move to GPU when classify() is called
self.device = "cpu"
print(f"Initializing on device: {self.device}")
# Load model and processor
print(f"Loading model: {model_id}")
# Use CLIPModel/CLIPProcessor for CLIP models, AutoModel/AutoProcessor for SigLIP
if "clip" in model_id.lower():
try:
self.model = CLIPModel.from_pretrained(model_id)
self.processor = CLIPProcessor.from_pretrained(model_id)
print(f"Loaded CLIPModel - has get_text_features: {hasattr(self.model, 'get_text_features')}")
print(f"Loaded CLIPModel - has get_image_features: {hasattr(self.model, 'get_image_features')}")
except Exception as e:
print(f"Warning: Failed to load as CLIPModel, trying AutoModel: {e}")
self.model = AutoModel.from_pretrained(model_id)
self.processor = AutoProcessor.from_pretrained(model_id)
else:
self.model = AutoModel.from_pretrained(model_id)
self.processor = AutoProcessor.from_pretrained(model_id)
# Move model to device and set to evaluation mode
self.model.to(self.device)
self.model.eval()
# Set domain weights (use defaults if not provided)
if domain_weights is None:
self.domain_weights = self.DEFAULT_DOMAIN_WEIGHTS.copy()
else:
# Merge with defaults, allowing partial overrides
self.domain_weights = self.DEFAULT_DOMAIN_WEIGHTS.copy()
self.domain_weights.update(domain_weights)
# Flatten and deduplicate anchor lists from domain dictionaries
# Also track which domain each anchor belongs to
self.kiki_anchors, self.kiki_anchor_domains = self._flatten_anchors(self.KIKI_ANCHORS_BY_DOMAIN)
self.bouba_anchors, self.bouba_anchor_domains = self._flatten_anchors(self.BOUBA_ANCHORS_BY_DOMAIN)
print(f"Using {len(self.kiki_anchors)} Kiki anchors and {len(self.bouba_anchors)} Bouba anchors")
print(f"Anchor domains: {list(self.KIKI_ANCHORS_BY_DOMAIN.keys())}")
print(f"Domain weights: {self.domain_weights}")
# Pre-compute text anchor embeddings
print("Pre-computing text anchor embeddings...")
self.kiki_embeddings = self._embed_texts(self.kiki_anchors)
self.bouba_embeddings = self._embed_texts(self.bouba_anchors)
print("Classifier ready!")
    def ensure_device(self, device: str):
        """
        Ensure the model and pre-computed embeddings are on the specified device.
        Called before inference to handle ZeroGPU device switching.
        Args:
            device: Target device ('cuda' or 'cpu')
        """
        # Always re-issue the moves: under ZeroGPU the GPU context can change
        # between requests, so tensors may sit on a stale context even when
        # self.device already matches the target device.
        if self.device != device:
            print(f"Moving classifier from {self.device} to {device}")
        self.device = device
        self.model.to(device)
        self.kiki_embeddings = self.kiki_embeddings.to(device)
        self.bouba_embeddings = self.bouba_embeddings.to(device)
def _embed_texts(self, texts: List[str]) -> torch.Tensor:
"""
Encode text anchors into normalized embeddings.
Args:
texts: List of text strings to embed
Returns:
Normalized text embeddings tensor
"""
inputs = self.processor(text=texts, return_tensors="pt", padding=True)
# Debug: print what keys the processor returns
print(f"Text processor returned keys: {list(inputs.keys())}")
        # Explicitly extract only the keys the text tower needs; SigLIP
        # tokenizers may not return an attention mask, so treat it as optional
        input_ids = inputs["input_ids"].to(self.device)
        attention_mask = inputs.get("attention_mask")
        if attention_mask is not None:
            attention_mask = attention_mask.to(self.device)
        with torch.no_grad():
            if hasattr(self.model, "get_text_features"):
                # Works for both CLIPModel and SiglipModel; pass only the
                # required inputs, since extra processor keys can break forward()
                embeddings = self.model.get_text_features(
                    input_ids=input_ids, attention_mask=attention_mask
                )
            else:
                # Fallback: run the text tower and projection head directly
                text_outputs = self.model.text_model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    return_dict=True,
                )
                embeddings = self.model.text_projection(text_outputs.pooler_output)
        return F.normalize(embeddings, dim=-1)
def _embed_image(self, image: Union[Image.Image, str]) -> torch.Tensor:
"""
Encode an image into a normalized embedding.
Args:
image: PIL Image or path to image file
Returns:
Normalized image embedding tensor
"""
# Handle string paths
if isinstance(image, str):
image = Image.open(image)
inputs = self.processor(images=image, return_tensors="pt")
# Debug: print what keys the processor returns
print(f"Image processor returned keys: {list(inputs.keys())}")
        # Explicitly extract only pixel_values
        pixel_values = inputs["pixel_values"].to(self.device)
        with torch.no_grad():
            if hasattr(self.model, "get_image_features"):
                # Works for both CLIPModel and SiglipModel
                embedding = self.model.get_image_features(pixel_values=pixel_values)
            else:
                # Fallback: run the vision tower and projection head directly
                vision_outputs = self.model.vision_model(
                    pixel_values=pixel_values,
                    return_dict=True,
                )
                embedding = self.model.visual_projection(vision_outputs.pooler_output)
        return F.normalize(embedding, dim=-1)
def _compute_domain_scores(self, similarities: torch.Tensor, anchor_domains: List[str],
top_k_per_domain: int = 3) -> Dict[str, float]:
"""
Compute per-domain scores by grouping similarities by domain and taking top-K mean.
Args:
similarities: Tensor of similarity scores for all anchors
anchor_domains: List of domain names, one per anchor (same length as similarities)
top_k_per_domain: Number of top anchors to use per domain (default: 3)
Returns:
Dictionary mapping domain names to their mean top-K scores
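        Example (illustrative): similarities=[0.20, 0.50, 0.30] with
        anchor_domains=["shape", "shape", "texture"] and top_k_per_domain=1
        yields {"shape": 0.50, "texture": 0.30}.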
"""
domain_scores = {}
# Group similarities by domain
domain_groups = {}
for i, domain in enumerate(anchor_domains):
if domain not in domain_groups:
domain_groups[domain] = []
domain_groups[domain].append(i)
# Compute top-K mean for each domain
for domain, indices in domain_groups.items():
# Convert indices list to tensor for proper indexing
indices_tensor = torch.tensor(indices, device=similarities.device)
domain_sims = similarities[indices_tensor]
k = min(top_k_per_domain, len(domain_sims))
if k > 0:
domain_scores[domain] = domain_sims.topk(k=k).values.mean().item()
else:
domain_scores[domain] = 0.0
return domain_scores
def _apply_domain_weights(self, domain_scores: Dict[str, float]) -> float:
"""
Apply domain weights to domain scores and compute weighted average.
Args:
domain_scores: Dictionary mapping domain names to their scores
Returns:
Weighted average score across all domains
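        Example (illustrative): with weights {"shape": 1.5, "texture": 1.3}
        and scores {"shape": 0.30, "texture": 0.20}, the result is
        (0.30 * 1.5 + 0.20 * 1.3) / (1.5 + 1.3) = 0.71 / 2.8 ≈ 0.254.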
"""
weighted_sum = 0.0
weight_sum = 0.0
for domain, score in domain_scores.items():
weight = self.domain_weights.get(domain, 0.0)
if weight > 0: # Skip domains with zero or negative weight
weighted_sum += score * weight
weight_sum += weight
# If all weights are zero, fall back to equal weighting
if weight_sum == 0:
return sum(domain_scores.values()) / len(domain_scores) if domain_scores else 0.0
return weighted_sum / weight_sum
def classify(self, image: Union[Image.Image, str], top_k: int = 10) -> Dict:
"""
Classify an image as Kiki or Bouba using domain-weighted top-K scoring and difference-based spectrum.
Args:
image: PIL Image or path to image file
top_k: Approximate number of top-scoring anchors to use (default: 10)
This is divided across domains, with top-K per domain computed separately.
Domain weights are then applied to combine domain scores.
Returns:
Dictionary containing:
- kiki_score: Weighted average similarity to top-K Kiki anchors across domains
- bouba_score: Weighted average similarity to top-K Bouba anchors across domains
- spectrum_position: Position on 0-1 spectrum (0=Kiki, 1=Bouba) based on percentage difference
- classification: "Kiki", "Neutral", or "Bouba"
- confidence: Confidence score (0-1)
- kiki_domain_scores: Dict mapping domain names to their Kiki scores
- bouba_domain_scores: Dict mapping domain names to their Bouba scores
- kiki_anchor_scores: Dict of individual Kiki anchor similarities
- bouba_anchor_scores: Dict of individual Bouba anchor similarities
"""
# Ensure all components are on the current device
# This handles ZeroGPU device switching
current_device = "cuda" if torch.cuda.is_available() else "cpu"
self.ensure_device(current_device)
# Encode image
image_emb = self._embed_image(image)
# Calculate cosine similarities with anchor embeddings (standard approach)
# CLIP embeddings are normalized, so this computes cosine similarity
kiki_sims = (image_emb @ self.kiki_embeddings.T).squeeze()
bouba_sims = (image_emb @ self.bouba_embeddings.T).squeeze()
        # Compute per-domain scores with top-K within each domain.
        # This keeps strong matches within a domain from being diluted and
        # lets domain weights be applied afterwards. The heuristic top_k // 3
        # gives 3 anchors per domain with the default top_k=10.
        top_k_per_domain = max(1, top_k // 3)
kiki_domain_scores = self._compute_domain_scores(
kiki_sims, self.kiki_anchor_domains, top_k_per_domain=top_k_per_domain
)
bouba_domain_scores = self._compute_domain_scores(
bouba_sims, self.bouba_anchor_domains, top_k_per_domain=top_k_per_domain
)
# Apply domain weights and compute weighted average
kiki_score = self._apply_domain_weights(kiki_domain_scores)
bouba_score = self._apply_domain_weights(bouba_domain_scores)
# Calculate percentage-based difference spectrum position (0 = pure Kiki, 1 = pure Bouba)
# This approach normalizes for score magnitude, making it sensitive to relative differences
# rather than absolute values, which adapts better to different similarity ranges
# Calculate average score for normalization
avg_score = (kiki_score + bouba_score) / 2
# Calculate percentage difference (normalizes for overall similarity magnitude)
if avg_score > 0:
percent_diff = (bouba_score - kiki_score) / avg_score
else:
percent_diff = 0.0 # Edge case: both scores are zero
        # Scale factor for relative differences (tuned for CLIP embeddings,
        # which typically produce higher absolute similarities than SigLIP).
        # With the mapping below, a relative difference of +/-12.5% (half the
        # scale factor) saturates the spectrum at 0.0 or 1.0.
        scale_factor = 0.25
# Convert percentage difference to spectrum position: 0.5 (neutral) + scaled percent_diff
spectrum_position = 0.5 + (percent_diff / scale_factor)
# Clamp to [0, 1] range
spectrum_position = max(0.0, min(1.0, spectrum_position))
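        # Worked example (illustrative numbers): kiki_score=0.22 and
        # bouba_score=0.24 give avg_score=0.23, percent_diff of about +0.087,
        # and a spectrum position of 0.5 + 0.087 / 0.25, roughly 0.85,
        # which reads as a clear "Bouba".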
# Determine classification with neutral zone
# Kiki: 0.0-0.4, Neutral: 0.4-0.6, Bouba: 0.6-1.0
if spectrum_position < 0.4:
classification = "Kiki"
elif spectrum_position > 0.6:
classification = "Bouba"
else:
classification = "Neutral"
# Calculate confidence
# For Kiki/Bouba: distance from center (0.5), scaled to 0-1
# For Neutral: how close to center (0.5), scaled to 0-1
if classification == "Neutral":
# Confidence is how close to 0.5 (center of neutral zone)
confidence = 1.0 - abs(spectrum_position - 0.5) * 5 # Scale so 0.5 = 1.0, 0.4/0.6 = 0.5
confidence = max(0.0, min(1.0, confidence)) # Clamp to [0, 1]
else:
# Distance from center, scaled to 0-1
confidence = abs(spectrum_position - 0.5) * 2
# Create anchor score dictionaries
kiki_anchor_scores = dict(zip(self.kiki_anchors, kiki_sims.cpu().tolist()))
bouba_anchor_scores = dict(zip(self.bouba_anchors, bouba_sims.cpu().tolist()))
return {
"kiki_score": kiki_score,
"bouba_score": bouba_score,
"spectrum_position": spectrum_position, # 0=Kiki, 1=Bouba (percentage-based)
"classification": classification,
"confidence": confidence,
"kiki_domain_scores": kiki_domain_scores,
"bouba_domain_scores": bouba_domain_scores,
"kiki_anchor_scores": kiki_anchor_scores,
"bouba_anchor_scores": bouba_anchor_scores
}
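

if __name__ == "__main__":
    # Minimal usage sketch (illustrative): classifies a single image given on
    # the command line; "example.jpg" is a hypothetical fallback path.
    import sys

    classifier = KikiBoubaClassifier()
    image_path = sys.argv[1] if len(sys.argv) > 1 else "example.jpg"
    result = classifier.classify(image_path)
    print(f"{image_path}: {result['classification']} "
          f"(spectrum={result['spectrum_position']:.2f}, "
          f"confidence={result['confidence']:.2f})")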