""" Kiki/Bouba Visual Classifier Uses vision-language models (SigLIP/CLIP) to classify images as "Kiki" (angular/spiky) or "Bouba" (rounded/soft) based on the cross-cultural sound-shape association phenomenon. """ import torch import torch.nn.functional as F from transformers import AutoProcessor, AutoModel, CLIPProcessor, CLIPModel from PIL import Image from typing import Dict, List, Union, Tuple, Optional class KikiBoubaClassifier: """ Classifier that determines if an image is more "Kiki" (angular/spiky) or "Bouba" (rounded/soft) using vision-language embeddings. Uses expanded cross-modal anchors (~200 total) based on NeurIPS 2023 research: "Kiki or Bouba? Sound Symbolism in Vision-and-Language Models" https://arxiv.org/abs/2310.16781 Anchors span multiple sensory domains: shape, texture, taste, color/light, sound, sensation, movement, emotion, and abstract qualities. """ # Kiki anchors organized by domain (Angular / Sharp / Intense) KIKI_ANCHORS_BY_DOMAIN = { # Shape & Geometry (Primary - highest confidence) "shape_primary": [ "sharp", "spiky", "angular", "jagged", "pointed", "edgy", "geometric", "crystalline", "fractured", "serrated", "zigzag", "triangular", "diagonal", "hexagonal", "polygonal", "faceted", "prismatic", "chiseled", "carved", "etched", ], # Texture (High confidence) "texture": [ "rough", "coarse", "gritty", "scratchy", "abrasive", "bristly", "prickly", "thorny", "barbed", "splintered", "grainy", "sandpapery", "rugged", "craggy", "uneven", ], # Taste & Flavor (Cross-modal - validated) "taste": [ "acidic", "sour", "bitter", "tart", "tangy", "astringent", "pungent", "zesty", "biting", "acrid", "vinegary", "citrusy", "lemony", "sharp-tasting", ], # Color & Light (Cross-modal) "color_light": [ "bright", "vivid", "glaring", "fluorescent", "neon", "blinding", "harsh", "saturated", "electric", "shocking", "stark", "contrasting", "yellow", "red", "white", ], # Sound (Cross-modal - validated) "sound": [ "high-pitched", "shrill", "piercing", "screeching", "staccato", "clashing", "clanging", "crackling", "snapping", "clicking", "tinny", "metallic", "discordant", "jarring", ], # Temperature & Sensation "sensation": [ "cold", "icy", "freezing", "stinging", "burning", "prickling", "tingling", "electric", "shocking", "intense", ], # Movement & Speed "movement": [ "fast", "quick", "rapid", "jerky", "abrupt", "sudden", "darting", "twitchy", "erratic", "spasmodic", "snappy", "jagged-motion", ], # Emotion & Energy (Cross-modal) "emotion": [ "tense", "anxious", "nervous", "stressed", "agitated", "alert", "aggressive", "hostile", "angry", "irritable", "fierce", "intense", "urgent", "frantic", ], # Abstract Qualities "abstract": [ "harsh", "hard", "rigid", "stiff", "brittle", "crisp", "precise", "exact", "strict", "severe", "stern", "unforgiving", "dangerous", "threatening", ], } # Bouba anchors organized by domain (Rounded / Soft / Gentle) BOUBA_ANCHORS_BY_DOMAIN = { # Shape & Geometry (Primary - highest confidence) "shape_primary": [ "round", "rounded", "circular", "curved", "bulbous", "spherical", "globular", "oval", "elliptical", "undulating", "wavy", "flowing", "organic", "amorphous", "blobby", "puffy", "billowy", "domed", "arched", "swooping", ], # Texture (High confidence) "texture": [ "soft", "smooth", "silky", "velvety", "plush", "fluffy", "fuzzy", "downy", "cottony", "cushiony", "spongy", "supple", "tender", "delicate", "gentle", ], # Taste & Flavor (Cross-modal - validated) "taste": [ "sweet", "creamy", "mild", "mellow", "bland", "buttery", "rich", "chocolatey", "caramel", 
"honeyed", "sugary", "milky", "vanilla", "smooth-tasting", ], # Color & Light (Cross-modal) "color_light": [ "dim", "muted", "pastel", "soft-lit", "diffuse", "hazy", "foggy", "misty", "dusky", "twilight", "warm", "golden", "amber", "blue", "purple", ], # Sound (Cross-modal - validated) "sound": [ "low-pitched", "deep", "resonant", "humming", "droning", "murmuring", "rumbling", "melodic", "flowing", "legato", "muffled", "soft-sounding", "gentle", "soothing", ], # Temperature & Sensation "sensation": [ "warm", "lukewarm", "cozy", "comfortable", "soothing", "relaxing", "calming", "gentle", "caressing", "embracing", ], # Movement & Speed "movement": [ "slow", "gradual", "languid", "lazy", "drifting", "floating", "gliding", "flowing", "swaying", "undulating", "graceful", "smooth-motion", ], # Emotion & Energy (Cross-modal) "emotion": [ "calm", "peaceful", "relaxed", "serene", "tranquil", "content", "happy", "joyful", "gentle", "kind", "friendly", "welcoming", "safe", "comforting", ], # Abstract Qualities "abstract": [ "soft", "gentle", "flexible", "yielding", "malleable", "pliable", "forgiving", "lenient", "easygoing", "laid-back", "nurturing", "maternal", "protective", "embracing", ], } # Default domain weights (can be overridden in __init__) # With centered cosine similarity, weights are more balanced since generic attractors are removed DEFAULT_DOMAIN_WEIGHTS = { "shape_primary": 1.5, # Highest - most visually relevant "texture": 1.3, # High - directly visual "color_light": 0.6, # Reduced - often background-influenced "taste": 0.9, # Medium - cross-modal but validated "sound": 0.7, # Lower - less visually relevant "sensation": 0.8, # Lower - abstract "movement": 1.0, # Medium - can be visual "emotion": 0.6, # Lower - very abstract "abstract": 0.7, # Lower - very abstract } # Flattened anchor lists (computed from domains, deduplicated) KIKI_ANCHORS = None # Will be computed in __init__ BOUBA_ANCHORS = None # Will be computed in __init__ @classmethod def _flatten_anchors(cls, anchors_by_domain: dict) -> Tuple[List[str], List[str]]: """ Flatten domain-organized anchors into a deduplicated list. Returns: Tuple of (flattened_anchors_list, domain_labels_list) where domain_labels_list maps each anchor to its domain. """ seen = set() result = [] domain_labels = [] for domain, domain_anchors in anchors_by_domain.items(): for anchor in domain_anchors: if anchor not in seen: seen.add(anchor) result.append(anchor) domain_labels.append(domain) return result, domain_labels def __init__(self, model_id: str = "openai/clip-vit-large-patch14", domain_weights: Optional[Dict[str, float]] = None): """ Initialize the classifier with a vision-language model. Args: model_id: HuggingFace model identifier (default: CLIP ViT-Large) domain_weights: Optional dict mapping domain names to weights. If None, uses DEFAULT_DOMAIN_WEIGHTS. Weights control the influence of each domain on classification. 
""" # Always initialize on CPU for ZeroGPU compatibility # ZeroGPU allocates GPU on-demand, and GPU context changes between requests # ensure_device() will move to GPU when classify() is called self.device = "cpu" print(f"Initializing on device: {self.device}") # Load model and processor print(f"Loading model: {model_id}") # Use CLIPModel/CLIPProcessor for CLIP models, AutoModel/AutoProcessor for SigLIP if "clip" in model_id.lower(): try: self.model = CLIPModel.from_pretrained(model_id) self.processor = CLIPProcessor.from_pretrained(model_id) print(f"Loaded CLIPModel - has get_text_features: {hasattr(self.model, 'get_text_features')}") print(f"Loaded CLIPModel - has get_image_features: {hasattr(self.model, 'get_image_features')}") except Exception as e: print(f"Warning: Failed to load as CLIPModel, trying AutoModel: {e}") self.model = AutoModel.from_pretrained(model_id) self.processor = AutoProcessor.from_pretrained(model_id) else: self.model = AutoModel.from_pretrained(model_id) self.processor = AutoProcessor.from_pretrained(model_id) # Move model to device and set to evaluation mode self.model.to(self.device) self.model.eval() # Set domain weights (use defaults if not provided) if domain_weights is None: self.domain_weights = self.DEFAULT_DOMAIN_WEIGHTS.copy() else: # Merge with defaults, allowing partial overrides self.domain_weights = self.DEFAULT_DOMAIN_WEIGHTS.copy() self.domain_weights.update(domain_weights) # Flatten and deduplicate anchor lists from domain dictionaries # Also track which domain each anchor belongs to self.kiki_anchors, self.kiki_anchor_domains = self._flatten_anchors(self.KIKI_ANCHORS_BY_DOMAIN) self.bouba_anchors, self.bouba_anchor_domains = self._flatten_anchors(self.BOUBA_ANCHORS_BY_DOMAIN) print(f"Using {len(self.kiki_anchors)} Kiki anchors and {len(self.bouba_anchors)} Bouba anchors") print(f"Anchor domains: {list(self.KIKI_ANCHORS_BY_DOMAIN.keys())}") print(f"Domain weights: {self.domain_weights}") # Pre-compute text anchor embeddings print("Pre-computing text anchor embeddings...") self.kiki_embeddings = self._embed_texts(self.kiki_anchors) self.bouba_embeddings = self._embed_texts(self.bouba_anchors) print("Classifier ready!") def ensure_device(self, device: str): """ Ensure model and embeddings are on the specified device. Called before inference to handle ZeroGPU device switching. Args: device: Target device ('cuda' or 'cpu') """ # Always move to current device - required for ZeroGPU where GPU context changes between requests # Even if self.device == device, the tensors may be on a stale GPU context if self.device != device: print(f"Moving classifier from {self.device} to {device}") self.device = device self.model.to(device) self.kiki_embeddings = self.kiki_embeddings.to(device) self.bouba_embeddings = self.bouba_embeddings.to(device) def _embed_texts(self, texts: List[str]) -> torch.Tensor: """ Encode text anchors into normalized embeddings. 

    def _embed_texts(self, texts: List[str]) -> torch.Tensor:
        """
        Encode text anchors into normalized embeddings.

        Args:
            texts: List of text strings to embed

        Returns:
            Normalized text embeddings tensor
        """
        inputs = self.processor(text=texts, return_tensors="pt", padding=True)

        # Debug: print what keys the processor returns
        print(f"Text processor returned keys: {list(inputs.keys())}")

        # Explicitly extract only the keys the text encoder needs
        input_ids = inputs['input_ids'].to(self.device)
        attention_mask = inputs['attention_mask'].to(self.device)

        with torch.no_grad():
            if hasattr(self.model, 'get_text_features'):
                # CLIPModel and SiglipModel both expose get_text_features,
                # which handles pooling and projection internally
                embeddings = self.model.get_text_features(
                    input_ids=input_ids, attention_mask=attention_mask
                )
            else:
                # Manual fallback: pool with the text tower, then project.
                # This assumes a CLIP-style text_projection head, which
                # SigLIP models do not have.
                text_outputs = self.model.text_model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    return_dict=True
                )
                embeddings = self.model.text_projection(text_outputs.pooler_output)

        return F.normalize(embeddings, dim=-1)

    def _embed_image(self, image: Union[Image.Image, str]) -> torch.Tensor:
        """
        Encode an image into a normalized embedding.

        Args:
            image: PIL Image or path to image file

        Returns:
            Normalized image embedding tensor
        """
        # Handle string paths
        if isinstance(image, str):
            image = Image.open(image)

        inputs = self.processor(images=image, return_tensors="pt")

        # Debug: print what keys the processor returns
        print(f"Image processor returned keys: {list(inputs.keys())}")

        # Explicitly extract only pixel_values
        pixel_values = inputs['pixel_values'].to(self.device)

        with torch.no_grad():
            if hasattr(self.model, 'get_image_features'):
                # CLIPModel and SiglipModel both expose get_image_features
                embedding = self.model.get_image_features(pixel_values=pixel_values)
            else:
                # Manual fallback, assuming a CLIP-style visual_projection head
                vision_outputs = self.model.vision_model(
                    pixel_values=pixel_values,
                    return_dict=True
                )
                embedding = self.model.visual_projection(vision_outputs.pooler_output)

        return F.normalize(embedding, dim=-1)

    def _compute_domain_scores(self, similarities: torch.Tensor,
                               anchor_domains: List[str],
                               top_k_per_domain: int = 3) -> Dict[str, float]:
        """
        Compute per-domain scores by grouping similarities by domain and
        taking the mean of the top-K in each group.

        Args:
            similarities: Tensor of similarity scores for all anchors
            anchor_domains: List of domain names, one per anchor
                (same length as similarities)
            top_k_per_domain: Number of top anchors to use per domain (default: 3)

        Returns:
            Dictionary mapping domain names to their mean top-K scores
        """
        domain_scores = {}

        # Group anchor indices by domain
        domain_groups = {}
        for i, domain in enumerate(anchor_domains):
            if domain not in domain_groups:
                domain_groups[domain] = []
            domain_groups[domain].append(i)

        # Compute the top-K mean for each domain
        for domain, indices in domain_groups.items():
            # Convert the index list to a tensor for proper indexing
            indices_tensor = torch.tensor(indices, device=similarities.device)
            domain_sims = similarities[indices_tensor]
            k = min(top_k_per_domain, len(domain_sims))
            if k > 0:
                domain_scores[domain] = domain_sims.topk(k=k).values.mean().item()
            else:
                domain_scores[domain] = 0.0

        return domain_scores
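
    # Worked example of the per-domain top-K scoring (illustrative numbers, not
    # from a real run): with top_k_per_domain=3 and shape_primary similarities
    # [0.31, 0.29, 0.27, 0.12, 0.10, ...], the domain score is
    # mean([0.31, 0.29, 0.27]) = 0.29. Taking only the top-K keeps the many
    # weak anchors in a domain from diluting a few strong matches.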

    def _apply_domain_weights(self, domain_scores: Dict[str, float]) -> float:
        """
        Apply domain weights to domain scores and compute a weighted average.

        Args:
            domain_scores: Dictionary mapping domain names to their scores

        Returns:
            Weighted average score across all domains
        """
        weighted_sum = 0.0
        weight_sum = 0.0

        for domain, score in domain_scores.items():
            weight = self.domain_weights.get(domain, 0.0)
            if weight > 0:  # Skip domains with zero or negative weight
                weighted_sum += score * weight
                weight_sum += weight

        # If all weights are zero, fall back to equal weighting
        if weight_sum == 0:
            return sum(domain_scores.values()) / len(domain_scores) if domain_scores else 0.0

        return weighted_sum / weight_sum
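
    # Worked example of the weighted average (illustrative numbers): with
    # domain scores {"shape_primary": 0.30, "texture": 0.24} and weights
    # {"shape_primary": 1.5, "texture": 1.3}, the result is
    # (0.30 * 1.5 + 0.24 * 1.3) / (1.5 + 1.3) = 0.762 / 2.8 ≈ 0.272.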

    def classify(self, image: Union[Image.Image, str], top_k: int = 10) -> Dict:
        """
        Classify an image as Kiki or Bouba using domain-weighted top-K scoring
        and a difference-based spectrum.

        Args:
            image: PIL Image or path to image file
            top_k: Approximate number of top-scoring anchors to use (default: 10).
                This is divided across domains, with top-K per domain computed
                separately; domain weights are then applied to combine the
                domain scores.

        Returns:
            Dictionary containing:
                - kiki_score: Weighted average similarity to top-K Kiki anchors across domains
                - bouba_score: Weighted average similarity to top-K Bouba anchors across domains
                - spectrum_position: Position on the 0-1 spectrum (0=Kiki, 1=Bouba),
                  based on the relative difference between the two scores
                - classification: "Kiki", "Neutral", or "Bouba"
                - confidence: Confidence score (0-1)
                - kiki_domain_scores: Dict mapping domain names to their Kiki scores
                - bouba_domain_scores: Dict mapping domain names to their Bouba scores
                - kiki_anchor_scores: Dict of individual Kiki anchor similarities
                - bouba_anchor_scores: Dict of individual Bouba anchor similarities
        """
        # Ensure all components are on the current device
        # (handles ZeroGPU device switching)
        current_device = "cuda" if torch.cuda.is_available() else "cpu"
        self.ensure_device(current_device)

        # Encode the image
        image_emb = self._embed_image(image)

        # Compute cosine similarities with the anchor embeddings; both sides
        # are L2-normalized, so the dot product is the cosine similarity
        kiki_sims = (image_emb @ self.kiki_embeddings.T).squeeze()
        bouba_sims = (image_emb @ self.bouba_embeddings.T).squeeze()

        # Compute per-domain scores with top-K within each domain, so strong
        # matches in one domain aren't diluted and domain weighting can apply
        top_k_per_domain = max(1, top_k // 3)  # Roughly 3-4 anchors per domain
        kiki_domain_scores = self._compute_domain_scores(
            kiki_sims, self.kiki_anchor_domains, top_k_per_domain=top_k_per_domain
        )
        bouba_domain_scores = self._compute_domain_scores(
            bouba_sims, self.bouba_anchor_domains, top_k_per_domain=top_k_per_domain
        )

        # Apply domain weights and compute the weighted averages
        kiki_score = self._apply_domain_weights(kiki_domain_scores)
        bouba_score = self._apply_domain_weights(bouba_domain_scores)

        # Map the scores to a spectrum position (0 = pure Kiki, 1 = pure Bouba)
        # using the relative (percentage) difference. Normalizing by the
        # average score makes the position sensitive to relative rather than
        # absolute differences, which adapts better across similarity ranges.
        avg_score = (kiki_score + bouba_score) / 2

        if avg_score > 0:
            percent_diff = (bouba_score - kiki_score) / avg_score
        else:
            percent_diff = 0.0  # Edge case: both scores are zero

        # Scale factor for the relative difference (tuned for CLIP embeddings,
        # which typically produce higher absolute similarities than SigLIP).
        # With scale_factor = 0.25, a relative difference of +/-12.5%
        # (percent_diff / scale_factor = +/-0.5) saturates the spectrum
        # at 0.0 or 1.0.
        scale_factor = 0.25

        # Spectrum position: 0.5 (neutral) plus the scaled relative difference
        spectrum_position = 0.5 + (percent_diff / scale_factor)

        # Clamp to [0, 1]
        spectrum_position = max(0.0, min(1.0, spectrum_position))

        # Determine the classification with a neutral zone:
        # Kiki: 0.0-0.4, Neutral: 0.4-0.6, Bouba: 0.6-1.0
        if spectrum_position < 0.4:
            classification = "Kiki"
        elif spectrum_position > 0.6:
            classification = "Bouba"
        else:
            classification = "Neutral"

        # Calculate confidence.
        # For Kiki/Bouba: distance from the center (0.5), scaled to 0-1.
        # For Neutral: closeness to the center, scaled so 0.5 -> 1.0 and the
        # zone edges (0.4 / 0.6) -> 0.5.
        if classification == "Neutral":
            confidence = 1.0 - abs(spectrum_position - 0.5) * 5
            confidence = max(0.0, min(1.0, confidence))  # Clamp to [0, 1]
        else:
            confidence = abs(spectrum_position - 0.5) * 2

        # Per-anchor score dictionaries
        kiki_anchor_scores = dict(zip(self.kiki_anchors, kiki_sims.cpu().tolist()))
        bouba_anchor_scores = dict(zip(self.bouba_anchors, bouba_sims.cpu().tolist()))

        return {
            "kiki_score": kiki_score,
            "bouba_score": bouba_score,
            "spectrum_position": spectrum_position,  # 0=Kiki, 1=Bouba (relative-difference based)
            "classification": classification,
            "confidence": confidence,
            "kiki_domain_scores": kiki_domain_scores,
            "bouba_domain_scores": bouba_domain_scores,
            "kiki_anchor_scores": kiki_anchor_scores,
            "bouba_anchor_scores": bouba_anchor_scores,
        }
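

# Minimal usage sketch ("star.png" is a hypothetical test-image path; running
# this downloads the default model from the Hugging Face Hub):
if __name__ == "__main__":
    classifier = KikiBoubaClassifier()  # defaults to openai/clip-vit-large-patch14

    result = classifier.classify("star.png")
    print(f"{result['classification']} "
          f"(spectrum={result['spectrum_position']:.2f}, "
          f"confidence={result['confidence']:.2f})")

    # Per-domain breakdown, e.g. to inspect which domains drove the call
    for domain, score in result["kiki_domain_scores"].items():
        print(f"  kiki/{domain}: {score:.3f}")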