| """ | |
| Kiki/Bouba Visual Classifier | |
| Uses vision-language models (SigLIP/CLIP) to classify images as "Kiki" (angular/spiky) | |
| or "Bouba" (rounded/soft) based on the cross-cultural sound-shape association phenomenon. | |
| """ | |
| import torch | |
| import torch.nn.functional as F | |
| from transformers import AutoProcessor, AutoModel, CLIPProcessor, CLIPModel | |
| from PIL import Image | |
| from typing import Dict, List, Union, Tuple, Optional | |


class KikiBoubaClassifier:
    """
    Classifier that determines if an image is more "Kiki" (angular/spiky)
    or "Bouba" (rounded/soft) using vision-language embeddings.

    Uses expanded cross-modal anchors (~200 total) based on NeurIPS 2023 research:
    "Kiki or Bouba? Sound Symbolism in Vision-and-Language Models"
    https://arxiv.org/abs/2310.16781

    Anchors span multiple sensory domains: shape, texture, taste, color/light,
    sound, sensation, movement, emotion, and abstract qualities.
    """
    # Kiki anchors organized by domain (Angular / Sharp / Intense)
    KIKI_ANCHORS_BY_DOMAIN = {
        # Shape & Geometry (Primary - highest confidence)
        "shape_primary": [
            "sharp", "spiky", "angular", "jagged", "pointed", "edgy", "geometric",
            "crystalline", "fractured", "serrated", "zigzag", "triangular", "diagonal",
            "hexagonal", "polygonal", "faceted", "prismatic", "chiseled", "carved", "etched",
        ],
        # Texture (High confidence)
        "texture": [
            "rough", "coarse", "gritty", "scratchy", "abrasive", "bristly", "prickly",
            "thorny", "barbed", "splintered", "grainy", "sandpapery", "rugged", "craggy", "uneven",
        ],
        # Taste & Flavor (Cross-modal - validated)
        "taste": [
            "acidic", "sour", "bitter", "tart", "tangy", "astringent", "pungent",
            "zesty", "biting", "acrid", "vinegary", "citrusy", "lemony", "sharp-tasting",
        ],
        # Color & Light (Cross-modal)
        "color_light": [
            "bright", "vivid", "glaring", "fluorescent", "neon", "blinding", "harsh",
            "saturated", "electric", "shocking", "stark", "contrasting", "yellow", "red", "white",
        ],
        # Sound (Cross-modal - validated)
        "sound": [
            "high-pitched", "shrill", "piercing", "screeching", "staccato", "clashing",
            "clanging", "crackling", "snapping", "clicking", "tinny", "metallic", "discordant", "jarring",
        ],
        # Temperature & Sensation
        "sensation": [
            "cold", "icy", "freezing", "stinging", "burning", "prickling", "tingling",
            "electric", "shocking", "intense",
        ],
        # Movement & Speed
        "movement": [
            "fast", "quick", "rapid", "jerky", "abrupt", "sudden", "darting",
            "twitchy", "erratic", "spasmodic", "snappy", "jagged-motion",
        ],
        # Emotion & Energy (Cross-modal)
        "emotion": [
            "tense", "anxious", "nervous", "stressed", "agitated", "alert", "aggressive",
            "hostile", "angry", "irritable", "fierce", "intense", "urgent", "frantic",
        ],
        # Abstract Qualities
        "abstract": [
            "harsh", "hard", "rigid", "stiff", "brittle", "crisp", "precise",
            "exact", "strict", "severe", "stern", "unforgiving", "dangerous", "threatening",
        ],
    }
    # Bouba anchors organized by domain (Rounded / Soft / Gentle)
    BOUBA_ANCHORS_BY_DOMAIN = {
        # Shape & Geometry (Primary - highest confidence)
        "shape_primary": [
            "round", "rounded", "circular", "curved", "bulbous", "spherical", "globular",
            "oval", "elliptical", "undulating", "wavy", "flowing", "organic", "amorphous",
            "blobby", "puffy", "billowy", "domed", "arched", "swooping",
        ],
        # Texture (High confidence)
        "texture": [
            "soft", "smooth", "silky", "velvety", "plush", "fluffy", "fuzzy",
            "downy", "cottony", "cushiony", "spongy", "supple", "tender", "delicate", "gentle",
        ],
        # Taste & Flavor (Cross-modal - validated)
        "taste": [
            "sweet", "creamy", "mild", "mellow", "bland", "buttery", "rich",
            "chocolatey", "caramel", "honeyed", "sugary", "milky", "vanilla", "smooth-tasting",
        ],
        # Color & Light (Cross-modal)
        "color_light": [
            "dim", "muted", "pastel", "soft-lit", "diffuse", "hazy", "foggy",
            "misty", "dusky", "twilight", "warm", "golden", "amber", "blue", "purple",
        ],
        # Sound (Cross-modal - validated)
        "sound": [
            "low-pitched", "deep", "resonant", "humming", "droning", "murmuring",
            "rumbling", "melodic", "flowing", "legato", "muffled", "soft-sounding", "gentle", "soothing",
        ],
        # Temperature & Sensation
        "sensation": [
            "warm", "lukewarm", "cozy", "comfortable", "soothing", "relaxing",
            "calming", "gentle", "caressing", "embracing",
        ],
        # Movement & Speed
        "movement": [
            "slow", "gradual", "languid", "lazy", "drifting", "floating",
            "gliding", "flowing", "swaying", "undulating", "graceful", "smooth-motion",
        ],
        # Emotion & Energy (Cross-modal)
        "emotion": [
            "calm", "peaceful", "relaxed", "serene", "tranquil", "content", "happy",
            "joyful", "gentle", "kind", "friendly", "welcoming", "safe", "comforting",
        ],
        # Abstract Qualities
        "abstract": [
            "soft", "gentle", "flexible", "yielding", "malleable", "pliable", "forgiving",
            "lenient", "easygoing", "laid-back", "nurturing", "maternal", "protective", "embracing",
        ],
    }
    # Default domain weights (can be overridden in __init__)
    # With centered cosine similarity, weights are more balanced since generic attractors are removed
    DEFAULT_DOMAIN_WEIGHTS = {
        "shape_primary": 1.5,  # Highest - most visually relevant
        "texture": 1.3,        # High - directly visual
        "color_light": 0.6,    # Reduced - often background-influenced
        "taste": 0.9,          # Medium - cross-modal but validated
        "sound": 0.7,          # Lower - less visually relevant
        "sensation": 0.8,      # Lower - abstract
        "movement": 1.0,       # Medium - can be visual
        "emotion": 0.6,        # Lower - very abstract
        "abstract": 0.7,       # Lower - very abstract
    }
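
    # For example, to emphasize geometry and mute emotion entirely, a caller could
    # pass partial overrides (values here are illustrative, not tuned):
    #   clf = KikiBoubaClassifier(domain_weights={"shape_primary": 2.0, "emotion": 0.0})
    # Domains left unspecified keep their DEFAULT_DOMAIN_WEIGHTS values; a weight of
    # 0.0 causes _apply_domain_weights() to skip that domain.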
    # Flattened, deduplicated anchor lists are built per-instance in __init__
    # (self.kiki_anchors / self.bouba_anchors); no class-level copies are kept.

    @classmethod
    def _flatten_anchors(cls, anchors_by_domain: dict) -> Tuple[List[str], List[str]]:
        """
        Flatten domain-organized anchors into a deduplicated list.

        Returns:
            Tuple of (flattened_anchors_list, domain_labels_list) where domain_labels_list
            maps each anchor to its domain.
        """
        seen = set()
        result = []
        domain_labels = []
        for domain, domain_anchors in anchors_by_domain.items():
            for anchor in domain_anchors:
                if anchor not in seen:
                    seen.add(anchor)
                    result.append(anchor)
                    domain_labels.append(domain)
        return result, domain_labels

    def __init__(self, model_id: str = "openai/clip-vit-large-patch14",
                 domain_weights: Optional[Dict[str, float]] = None):
        """
        Initialize the classifier with a vision-language model.

        Args:
            model_id: HuggingFace model identifier (default: CLIP ViT-Large)
            domain_weights: Optional dict mapping domain names to weights.
                If None, uses DEFAULT_DOMAIN_WEIGHTS.
                Weights control the influence of each domain on classification.
        """
        # Always initialize on CPU for ZeroGPU compatibility.
        # ZeroGPU allocates the GPU on demand and the GPU context changes between
        # requests; ensure_device() moves everything to GPU when classify() is called.
        self.device = "cpu"
        print(f"Initializing on device: {self.device}")

        # Load model and processor
        print(f"Loading model: {model_id}")
        # Use CLIPModel/CLIPProcessor for CLIP models, AutoModel/AutoProcessor for SigLIP
        if "clip" in model_id.lower():
            try:
                self.model = CLIPModel.from_pretrained(model_id)
                self.processor = CLIPProcessor.from_pretrained(model_id)
                print(f"Loaded CLIPModel - has get_text_features: {hasattr(self.model, 'get_text_features')}")
                print(f"Loaded CLIPModel - has get_image_features: {hasattr(self.model, 'get_image_features')}")
            except Exception as e:
                print(f"Warning: Failed to load as CLIPModel, trying AutoModel: {e}")
                self.model = AutoModel.from_pretrained(model_id)
                self.processor = AutoProcessor.from_pretrained(model_id)
        else:
            self.model = AutoModel.from_pretrained(model_id)
            self.processor = AutoProcessor.from_pretrained(model_id)

        # Move model to device and set to evaluation mode
        self.model.to(self.device)
        self.model.eval()

        # Set domain weights (use defaults if not provided)
        if domain_weights is None:
            self.domain_weights = self.DEFAULT_DOMAIN_WEIGHTS.copy()
        else:
            # Merge with defaults, allowing partial overrides
            self.domain_weights = self.DEFAULT_DOMAIN_WEIGHTS.copy()
            self.domain_weights.update(domain_weights)

        # Flatten and deduplicate anchor lists from the domain dictionaries,
        # tracking which domain each anchor belongs to
        self.kiki_anchors, self.kiki_anchor_domains = self._flatten_anchors(self.KIKI_ANCHORS_BY_DOMAIN)
        self.bouba_anchors, self.bouba_anchor_domains = self._flatten_anchors(self.BOUBA_ANCHORS_BY_DOMAIN)
        print(f"Using {len(self.kiki_anchors)} Kiki anchors and {len(self.bouba_anchors)} Bouba anchors")
        print(f"Anchor domains: {list(self.KIKI_ANCHORS_BY_DOMAIN.keys())}")
        print(f"Domain weights: {self.domain_weights}")

        # Pre-compute text anchor embeddings
        print("Pre-computing text anchor embeddings...")
        self.kiki_embeddings = self._embed_texts(self.kiki_anchors)
        self.bouba_embeddings = self._embed_texts(self.bouba_anchors)
        print("Classifier ready!")

    def ensure_device(self, device: str):
        """
        Ensure model and embeddings are on the specified device.
        Called before inference to handle ZeroGPU device switching.

        Args:
            device: Target device ('cuda' or 'cpu')
        """
        # Always move to the requested device: under ZeroGPU the GPU context can
        # change between requests, so even when self.device == device the tensors
        # may belong to a stale context. .to(device) is a no-op when nothing moved.
        if self.device != device:
            print(f"Moving classifier from {self.device} to {device}")
        self.device = device
        self.model.to(device)
        self.kiki_embeddings = self.kiki_embeddings.to(device)
        self.bouba_embeddings = self.bouba_embeddings.to(device)

    def _embed_texts(self, texts: List[str]) -> torch.Tensor:
        """
        Encode text anchors into normalized embeddings.

        Args:
            texts: List of text strings to embed

        Returns:
            Normalized text embeddings tensor
        """
        inputs = self.processor(text=texts, return_tensors="pt", padding=True)
        # Debug: print what keys the processor returns
        print(f"Text processor returned keys: {list(inputs.keys())}")
        # Explicitly extract only the keys the text encoder needs.
        # SigLIP processors may not return an attention_mask, so use .get()
        input_ids = inputs['input_ids'].to(self.device)
        attention_mask = inputs.get('attention_mask')
        if attention_mask is not None:
            attention_mask = attention_mask.to(self.device)
        with torch.no_grad():
            if hasattr(self.model, "get_text_features"):
                # Both CLIP and SigLIP expose get_text_features
                embeddings = self.model.get_text_features(
                    input_ids=input_ids, attention_mask=attention_mask
                )
            else:
                # Fallback: run the text tower and projection manually
                text_outputs = self.model.text_model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    return_dict=True
                )
                embeddings = self.model.text_projection(text_outputs.pooler_output)
        return F.normalize(embeddings, dim=-1)

    def _embed_image(self, image: Union[Image.Image, str]) -> torch.Tensor:
        """
        Encode an image into a normalized embedding.

        Args:
            image: PIL Image or path to image file

        Returns:
            Normalized image embedding tensor
        """
        # Handle string paths
        if isinstance(image, str):
            image = Image.open(image)
        inputs = self.processor(images=image, return_tensors="pt")
        # Debug: print what keys the processor returns
        print(f"Image processor returned keys: {list(inputs.keys())}")
        # Explicitly extract only pixel_values
        pixel_values = inputs['pixel_values'].to(self.device)
        with torch.no_grad():
            if hasattr(self.model, "get_image_features"):
                # Both CLIP and SigLIP expose get_image_features
                embedding = self.model.get_image_features(pixel_values=pixel_values)
            else:
                # Fallback: run the vision tower and projection manually
                vision_outputs = self.model.vision_model(
                    pixel_values=pixel_values,
                    return_dict=True
                )
                embedding = self.model.visual_projection(vision_outputs.pooler_output)
        return F.normalize(embedding, dim=-1)

    def _compute_domain_scores(self, similarities: torch.Tensor, anchor_domains: List[str],
                               top_k_per_domain: int = 3) -> Dict[str, float]:
        """
        Compute per-domain scores by grouping similarities by domain and taking top-K mean.

        Args:
            similarities: Tensor of similarity scores for all anchors
            anchor_domains: List of domain names, one per anchor (same length as similarities)
            top_k_per_domain: Number of top anchors to use per domain (default: 3)

        Returns:
            Dictionary mapping domain names to their mean top-K scores
        """
        domain_scores = {}
        # Group anchor indices by domain
        domain_groups = {}
        for i, domain in enumerate(anchor_domains):
            domain_groups.setdefault(domain, []).append(i)
        # Compute top-K mean for each domain
        for domain, indices in domain_groups.items():
            # Convert indices list to tensor for proper indexing
            indices_tensor = torch.tensor(indices, device=similarities.device)
            domain_sims = similarities[indices_tensor]
            k = min(top_k_per_domain, len(domain_sims))
            if k > 0:
                domain_scores[domain] = domain_sims.topk(k=k).values.mean().item()
            else:
                domain_scores[domain] = 0.0
        return domain_scores
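
    # Worked example (hypothetical numbers): if the "texture" domain has anchor
    # similarities [0.31, 0.28, 0.22, 0.10, ...] and top_k_per_domain=3, the
    # domain score is mean(0.31, 0.28, 0.22) = 0.27, so the weaker anchors in
    # the domain do not dilute a strong match.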

    def _apply_domain_weights(self, domain_scores: Dict[str, float]) -> float:
        """
        Apply domain weights to domain scores and compute weighted average.

        Args:
            domain_scores: Dictionary mapping domain names to their scores

        Returns:
            Weighted average score across all domains
        """
        weighted_sum = 0.0
        weight_sum = 0.0
        for domain, score in domain_scores.items():
            weight = self.domain_weights.get(domain, 0.0)
            if weight > 0:  # Skip domains with zero or negative weight
                weighted_sum += score * weight
                weight_sum += weight
        # If all weights are zero, fall back to equal weighting
        if weight_sum == 0:
            return sum(domain_scores.values()) / len(domain_scores) if domain_scores else 0.0
        return weighted_sum / weight_sum
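
    # Worked example (hypothetical numbers): with domain scores
    # {"shape_primary": 0.30, "texture": 0.20} and the default weights 1.5 and
    # 1.3, the weighted average is (0.30*1.5 + 0.20*1.3) / (1.5 + 1.3)
    # = 0.71 / 2.8 ≈ 0.254.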

    def classify(self, image: Union[Image.Image, str], top_k: int = 10) -> Dict:
        """
        Classify an image as Kiki or Bouba using domain-weighted top-K scoring
        and a difference-based spectrum.

        Args:
            image: PIL Image or path to image file
            top_k: Approximate number of top-scoring anchors to use (default: 10).
                This is divided across domains, with top-K per domain computed separately.
                Domain weights are then applied to combine domain scores.

        Returns:
            Dictionary containing:
                - kiki_score: Weighted average similarity to top-K Kiki anchors across domains
                - bouba_score: Weighted average similarity to top-K Bouba anchors across domains
                - spectrum_position: Position on 0-1 spectrum (0=Kiki, 1=Bouba) based on percentage difference
                - classification: "Kiki", "Neutral", or "Bouba"
                - confidence: Confidence score (0-1)
                - kiki_domain_scores: Dict mapping domain names to their Kiki scores
                - bouba_domain_scores: Dict mapping domain names to their Bouba scores
                - kiki_anchor_scores: Dict of individual Kiki anchor similarities
                - bouba_anchor_scores: Dict of individual Bouba anchor similarities
        """
        # Ensure all components are on the current device
        # (handles ZeroGPU device switching)
        current_device = "cuda" if torch.cuda.is_available() else "cpu"
        self.ensure_device(current_device)

        # Encode image
        image_emb = self._embed_image(image)

        # Calculate cosine similarities with anchor embeddings (standard approach);
        # embeddings are normalized, so the dot product is cosine similarity
        kiki_sims = (image_emb @ self.kiki_embeddings.T).squeeze()
        bouba_sims = (image_emb @ self.bouba_embeddings.T).squeeze()

        # Compute per-domain scores with top-K within each domain.
        # This ensures strong matches within a domain aren't diluted, and allows domain weighting.
        top_k_per_domain = max(1, top_k // 3)  # roughly 3-4 anchors per domain at the default top_k
        kiki_domain_scores = self._compute_domain_scores(
            kiki_sims, self.kiki_anchor_domains, top_k_per_domain=top_k_per_domain
        )
        bouba_domain_scores = self._compute_domain_scores(
            bouba_sims, self.bouba_anchor_domains, top_k_per_domain=top_k_per_domain
        )

        # Apply domain weights and compute weighted average
        kiki_score = self._apply_domain_weights(kiki_domain_scores)
        bouba_score = self._apply_domain_weights(bouba_domain_scores)

        # Calculate percentage-based difference spectrum position (0 = pure Kiki, 1 = pure Bouba).
        # Normalizing by the average score makes the position sensitive to relative rather than
        # absolute differences, which adapts better to different similarity ranges.
        avg_score = (kiki_score + bouba_score) / 2
        if avg_score > 0:
            percent_diff = (bouba_score - kiki_score) / avg_score
        else:
            percent_diff = 0.0  # Edge case: scores cancel out or are non-positive

        # Scale factor for percentage differences (tuned for CLIP embeddings,
        # which typically produce higher absolute similarities than SigLIP).
        # A ±12.5% relative difference saturates the spectrum at 0.0/1.0 (a 25% total range).
        scale_factor = 0.25
        # Convert percentage difference to spectrum position: 0.5 (neutral) + scaled percent_diff
        spectrum_position = 0.5 + (percent_diff / scale_factor)
        # Clamp to [0, 1] range
        spectrum_position = max(0.0, min(1.0, spectrum_position))
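        # Worked example (hypothetical scores): kiki_score=0.22, bouba_score=0.20
        # gives avg_score=0.21, percent_diff=(0.20-0.22)/0.21 ≈ -0.095, and
        # spectrum_position = 0.5 + (-0.095/0.25) ≈ 0.12, i.e. firmly "Kiki".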

        # Determine classification with a neutral zone
        # Kiki: 0.0-0.4, Neutral: 0.4-0.6, Bouba: 0.6-1.0
        if spectrum_position < 0.4:
            classification = "Kiki"
        elif spectrum_position > 0.6:
            classification = "Bouba"
        else:
            classification = "Neutral"

        # Calculate confidence.
        # For Kiki/Bouba: distance from center (0.5), scaled to 0-1.
        # For Neutral: closeness to center (0.5), scaled to 0-1.
        if classification == "Neutral":
            # Confidence is how close the position is to 0.5 (center of the neutral zone)
            confidence = 1.0 - abs(spectrum_position - 0.5) * 5  # 0.5 -> 1.0, 0.4/0.6 -> 0.5
            confidence = max(0.0, min(1.0, confidence))  # Clamp to [0, 1]
        else:
            # Distance from center, scaled to 0-1
            confidence = abs(spectrum_position - 0.5) * 2

        # Create anchor score dictionaries
        kiki_anchor_scores = dict(zip(self.kiki_anchors, kiki_sims.cpu().tolist()))
        bouba_anchor_scores = dict(zip(self.bouba_anchors, bouba_sims.cpu().tolist()))

        return {
            "kiki_score": kiki_score,
            "bouba_score": bouba_score,
            "spectrum_position": spectrum_position,  # 0=Kiki, 1=Bouba (percentage-based)
            "classification": classification,
            "confidence": confidence,
            "kiki_domain_scores": kiki_domain_scores,
            "bouba_domain_scores": bouba_domain_scores,
            "kiki_anchor_scores": kiki_anchor_scores,
            "bouba_anchor_scores": bouba_anchor_scores,
        }
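

# A minimal usage sketch (not part of the original module): it assumes a local
# image file "example.png" exists and that the default CLIP checkpoint can be
# downloaded. Shown only to illustrate the classify() API and its return dict.
if __name__ == "__main__":
    classifier = KikiBoubaClassifier()
    result = classifier.classify("example.png", top_k=10)
    print(f"Classification: {result['classification']} "
          f"(confidence: {result['confidence']:.2f})")
    print(f"Spectrum position (0=Kiki, 1=Bouba): {result['spectrum_position']:.3f}")
    print(f"Kiki domain scores: {result['kiki_domain_scores']}")
    print(f"Bouba domain scores: {result['bouba_domain_scores']}")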