""" Freckles Center-Mass Interception — Full Geometric Alignment Battery ====================================================================== Intercept every stage of the PatchSVAE pipeline and measure everything. Pipeline stages intercepted: 1. Raw patches (B, N, 48) 2. Encoder hidden states (B*N, 384) per block 3. Pre-SVD matrix M (B, N, 48, 4) 4. SVD components: U, S_orig, Vt (the bottleneck) 5. Cross-attention: S_in → S_out (the coordination) 6. Attention weights (B, 2, N, N) per layer 7. Reconstructed matrix M_hat (B*N, 192) 8. Decoder hidden states (B*N, 384) per block 9. Decoded patches (B, N, 48) 10. Stitched + boundary smooth (B, 3, H, W) Metrics at each stage: - Effective rank, condition number - Singular value spectrum - Spearman rank correlation with input/output - Procrustes alignment between stages - CV (coefficient of variation of volumes) - Cosine similarity distributions - Gradient magnitude (if training) - Information retention ratio - Dead neuron / activation statistics Usage: python freckles_observer.py --model v40_freckles_noise """ import os import math import json import time import torch import torch.nn as nn import torch.nn.functional as F import numpy as np from collections import OrderedDict from scipy import stats as scipy_stats # ═══════════════════════════════════════════════════════════════ # HOOK-BASED INTERCEPTOR # ═══════════════════════════════════════════════════════════════ class PipelineInterceptor: """Attach hooks to every stage of PatchSVAE, capture activations.""" def __init__(self, model): self.model = model self.captures = OrderedDict() self.hooks = [] self._attach_hooks() def _attach_hooks(self): m = self.model # Encoder input projection def hook_enc_in(module, inp, out): self.captures['enc_in'] = out.detach() self.hooks.append(m.enc_in.register_forward_hook(hook_enc_in)) # Encoder blocks for i, block in enumerate(m.enc_blocks): def make_hook(idx): def hook(module, inp, out): self.captures[f'enc_block_{idx}'] = out.detach() return hook self.hooks.append(block.register_forward_hook(make_hook(i))) # Encoder output (pre-normalize, pre-SVD) def hook_enc_out(module, inp, out): self.captures['enc_out_raw'] = out.detach() self.hooks.append(m.enc_out.register_forward_hook(hook_enc_out)) # Cross-attention layers for i, layer in enumerate(m.cross_attn): def make_ca_hook(idx): def hook(module, inp, out): self.captures[f'cross_attn_{idx}_in'] = inp[0].detach() self.captures[f'cross_attn_{idx}_out'] = out.detach() return hook self.hooks.append(layer.register_forward_hook(make_ca_hook(i))) # QKV hook for attention weights def make_qkv_hook(idx): def hook(module, inp, out): self.captures[f'cross_attn_{idx}_qkv'] = out.detach() return hook self.hooks.append(layer.qkv.register_forward_hook(make_qkv_hook(i))) # Decoder input def hook_dec_in(module, inp, out): self.captures['dec_in'] = out.detach() self.hooks.append(m.dec_in.register_forward_hook(hook_dec_in)) # Decoder blocks for i, block in enumerate(m.dec_blocks): def make_hook(idx): def hook(module, inp, out): self.captures[f'dec_block_{idx}'] = out.detach() return hook self.hooks.append(block.register_forward_hook(make_hook(i))) # Decoder output def hook_dec_out(module, inp, out): self.captures['dec_out'] = out.detach() self.hooks.append(m.dec_out.register_forward_hook(hook_dec_out)) # Boundary smooth def hook_boundary(module, inp, out): self.captures['boundary_in'] = inp[0].detach() self.captures['boundary_out'] = out.detach() self.hooks.append(m.boundary_smooth.register_forward_hook(hook_boundary)) def clear(self): self.captures.clear() def remove_hooks(self): for h in self.hooks: h.remove() self.hooks.clear() @torch.no_grad() def run(self, images): """Forward pass with full interception.""" self.clear() out = self.model(images) # Also capture SVD components from the output self.captures['svd_U'] = out['svd']['U'].detach() self.captures['svd_S_orig'] = out['svd']['S_orig'].detach() self.captures['svd_S'] = out['svd']['S'].detach() self.captures['svd_Vt'] = out['svd']['Vt'].detach() self.captures['svd_M'] = out['svd']['M'].detach() self.captures['recon'] = out['recon'].detach() self.captures['input'] = images.detach() return out # ═══════════════════════════════════════════════════════════════ # GEOMETRIC METRICS # ═══════════════════════════════════════════════════════════════ def effective_rank(X): """Effective rank of a matrix via Shannon entropy of singular values.""" if X.dim() == 1: X = X.unsqueeze(0) if X.dim() == 2: _, S, _ = torch.linalg.svd(X.float(), full_matrices=False) else: S = X.float() p = S / (S.sum(-1, keepdim=True) + 1e-8) p = p.clamp(min=1e-8) return (-(p * p.log()).sum(-1)).exp() def condition_number(X): """Condition number from singular values.""" if X.dim() < 2: return torch.tensor(1.0) try: S = torch.linalg.svdvals(X.float()) return (S[..., 0] / (S[..., -1] + 1e-10)).mean().item() except: return float('nan') def spearman_rank(a, b): """Spearman rank correlation between two flattened tensors.""" a_np = a.flatten().cpu().numpy() b_np = b.flatten().cpu().numpy() n = min(len(a_np), len(b_np), 100000) if n < len(a_np): a_np = a_np[:n] if n < len(b_np): b_np = b_np[:n] try: r, p = scipy_stats.spearmanr(a_np, b_np) return r except: return float('nan') def procrustes_alignment(A, B): """Procrustes alignment score between two (N, D) matrices. Returns: rotation error (lower = more aligned), scale ratio. """ if A.shape != B.shape: n = min(A.shape[0], B.shape[0]) A, B = A[:n], B[:n] A_c = A - A.mean(0, keepdim=True) B_c = B - B.mean(0, keepdim=True) A_n = A_c / (A_c.norm() + 1e-8) B_n = B_c / (B_c.norm() + 1e-8) M = A_n.T @ B_n U, S, Vt = torch.linalg.svd(M.float()) R = U @ Vt aligned = B_n @ R.T error = (A_n - aligned).pow(2).mean().item() alignment = S.sum().item() return {'error': error, 'alignment': alignment, 'scale_ratio': (A_c.norm() / (B_c.norm() + 1e-8)).item()} def cosine_sim_distribution(X): """Pairwise cosine similarity statistics for (N, D) matrix.""" X_n = F.normalize(X.float(), dim=-1) n = min(X_n.shape[0], 500) X_n = X_n[:n] sim = X_n @ X_n.T mask = ~torch.eye(n, dtype=torch.bool, device=sim.device) vals = sim[mask] return { 'mean': vals.mean().item(), 'std': vals.std().item(), 'min': vals.min().item(), 'max': vals.max().item(), 'median': vals.median().item(), } def activation_stats(X): """Activation statistics for a hidden state tensor.""" X_flat = X.float().flatten() return { 'mean': X_flat.mean().item(), 'std': X_flat.std().item(), 'min': X_flat.min().item(), 'max': X_flat.max().item(), 'abs_mean': X_flat.abs().mean().item(), 'dead_frac': (X_flat.abs() < 1e-6).float().mean().item(), 'sparsity': (X_flat == 0).float().mean().item(), 'kurtosis': ((X_flat - X_flat.mean()) / (X_flat.std() + 1e-8)).pow(4).mean().item() - 3, } def cayley_menger_vol2(points): B, N, D = points.shape pts = points.double() gram = torch.bmm(pts, pts.transpose(1, 2)) norms = torch.diagonal(gram, dim1=1, dim2=2) d2 = F.relu(norms.unsqueeze(2) + norms.unsqueeze(1) - 2 * gram) cm = torch.zeros(B, N + 1, N + 1, device=points.device, dtype=torch.float64) cm[:, 0, 1:] = 1.0; cm[:, 1:, 0] = 1.0; cm[:, 1:, 1:] = d2 k = N - 1 sign = (-1.0) ** (k + 1) fact = math.factorial(k) return sign * torch.linalg.det(cm) / ((2 ** k) * (fact ** 2)) def cv_of(emb, n_samples=200): if emb.dim() != 2 or emb.shape[0] < 5: return float('nan') N, D = emb.shape pool = min(N, 512) indices = torch.stack([torch.randperm(pool, device=emb.device)[:5] for _ in range(n_samples)]) vol2 = cayley_menger_vol2(emb[:pool][indices]) valid = vol2 > 1e-20 if valid.sum() < 10: return float('nan') vols = vol2[valid].sqrt() return (vols.std() / (vols.mean() + 1e-8)).item() # ═══════════════════════════════════════════════════════════════ # ATTENTION ANALYSIS # ═══════════════════════════════════════════════════════════════ def analyze_attention(qkv, n_heads, D): """Analyze attention patterns from QKV tensor.""" # qkv shape varies: (B, N, 3*D) or (B*N, 3*D) if qkv.dim() == 3: qkv = qkv.reshape(-1, qkv.shape[-1]) BN, three_D = qkv.shape # We need B and N — estimate from context # For now analyze the raw QKV statistics head_dim = D // n_heads qkv_r = qkv.reshape(-1, 3, D) q, k, v = qkv_r[:, 0], qkv_r[:, 1], qkv_r[:, 2] return { 'q_norm_mean': q.norm(dim=-1).mean().item(), 'k_norm_mean': k.norm(dim=-1).mean().item(), 'v_norm_mean': v.norm(dim=-1).mean().item(), 'q_std': q.std().item(), 'k_std': k.std().item(), 'v_std': v.std().item(), 'qk_cosine': F.cosine_similarity(q, k, dim=-1).mean().item(), 'qv_cosine': F.cosine_similarity(q, v, dim=-1).mean().item(), 'kv_cosine': F.cosine_similarity(k, v, dim=-1).mean().item(), } def analyze_cross_attn_delta(S_in, S_out, alpha_logits, max_alpha=0.2): """Analyze what cross-attention actually changed.""" delta = S_out - S_in alpha = max_alpha * torch.sigmoid(alpha_logits) return { 'delta_abs_mean': delta.abs().mean().item(), 'delta_abs_max': delta.abs().max().item(), 'delta_std': delta.std().item(), 'delta_per_mode': delta.abs().mean(dim=(0, 1)).tolist(), 'alpha_values': alpha.tolist(), 'alpha_mean': alpha.mean().item(), 'relative_change': (delta.abs().mean() / (S_in.abs().mean() + 1e-8)).item(), 'sign_agreement': (delta.sign() == S_in.sign()).float().mean().item(), } # ═══════════════════════════════════════════════════════════════ # SVD BOTTLENECK ANALYSIS # ═══════════════════════════════════════════════════════════════ def analyze_svd_bottleneck(U, S, Vt, M): """Deep analysis of the SVD bottleneck.""" B, N, V, D = U.shape # Singular value spectrum S_flat = S.reshape(-1, D) S_mean = S_flat.mean(0) S_std = S_flat.std(0) # Reconstruction quality: M vs U @ diag(S) @ Vt M_recon = torch.bmm( U.reshape(B*N, V, D) * S.reshape(B*N, D).unsqueeze(1), Vt.reshape(B*N, D, D)) M_flat = M.reshape(B*N, V, D) recon_error = (M_recon - M_flat).pow(2).mean().item() # Orthogonality of U columns U_flat = U.reshape(B*N, V, D) UtU = torch.bmm(U_flat.transpose(1, 2), U_flat) eye = torch.eye(D, device=U.device).unsqueeze(0) ortho_error = (UtU - eye).pow(2).mean().item() # Orthogonality of Vt rows Vt_flat = Vt.reshape(B*N, D, D) VtVt = torch.bmm(Vt_flat, Vt_flat.transpose(1, 2)) vt_ortho_error = (VtVt - eye).pow(2).mean().item() # Condition number of M cond = (S_mean[0] / (S_mean[-1] + 1e-8)).item() # Effective rank erank = effective_rank(S_flat).mean().item() # Energy distribution S2 = S_flat.pow(2) energy = S2 / (S2.sum(-1, keepdim=True) + 1e-8) energy_mean = energy.mean(0) # Sphere radius (norm of M rows after normalization) M_norms = M_flat.reshape(B*N*V, D).norm(dim=-1) return { 'S_mean': S_mean.tolist(), 'S_std': S_std.tolist(), 'S_ratio': (S_mean[0] / (S_mean[-1] + 1e-8)).item(), 'condition_number': cond, 'effective_rank': erank, 'recon_error': recon_error, 'U_orthogonality_error': ortho_error, 'Vt_orthogonality_error': vt_ortho_error, 'energy_per_mode': energy_mean.tolist(), 'sphere_radius_mean': M_norms.mean().item(), 'sphere_radius_std': M_norms.std().item(), } # ═══════════════════════════════════════════════════════════════ # ENCODER/DECODER SYMMETRY ANALYSIS # ═══════════════════════════════════════════════════════════════ def analyze_enc_dec_symmetry(interceptor): """Compare encoder and decoder at each stage.""" caps = interceptor.captures results = {} model = interceptor.model depth = len(model.enc_blocks) for i in range(depth): enc_key = f'enc_block_{i}' dec_key = f'dec_block_{depth - 1 - i}' # Mirror order if enc_key in caps and dec_key in caps: enc = caps[enc_key].reshape(-1, caps[enc_key].shape[-1]) dec = caps[dec_key].reshape(-1, caps[dec_key].shape[-1]) n = min(enc.shape[0], dec.shape[0], 10000) enc_s, dec_s = enc[:n], dec[:n] results[f'block_{i}_spearman'] = spearman_rank(enc_s, dec_s) results[f'block_{i}_cosine'] = F.cosine_similarity( enc_s.mean(0, keepdim=True), dec_s.mean(0, keepdim=True)).item() proc = procrustes_alignment(enc_s[:500].cpu(), dec_s[:500].cpu()) results[f'block_{i}_procrustes_error'] = proc['error'] results[f'block_{i}_procrustes_alignment'] = proc['alignment'] return results # ═══════════════════════════════════════════════════════════════ # INFORMATION FLOW ANALYSIS # ═══════════════════════════════════════════════════════════════ def analyze_information_flow(interceptor): """Track how information transforms through the pipeline.""" caps = interceptor.captures results = {} # Input → Encoder output if 'input' in caps and 'enc_out_raw' in caps: inp = caps['input'].reshape(caps['input'].shape[0], -1) enc = caps['enc_out_raw'].reshape(caps['enc_out_raw'].shape[0], -1) n = min(inp.shape[0], enc.shape[0]) results['input_to_enc_spearman'] = spearman_rank(inp[:n], enc[:n]) # Pre-SVD M → Post-SVD S if 'svd_M' in caps and 'svd_S_orig' in caps: M = caps['svd_M'].reshape(-1, caps['svd_M'].shape[-1]) S = caps['svd_S_orig'].reshape(-1, caps['svd_S_orig'].shape[-1]) results['M_to_S_compression'] = M.shape[-1] / S.shape[-1] # How much variance does S capture of M? M_var = M.var().item() S_var = S.var().item() results['M_variance'] = M_var results['S_variance'] = S_var results['variance_retention'] = S_var / (M_var + 1e-8) # S_orig → S (cross-attention effect) if 'svd_S_orig' in caps and 'svd_S' in caps: S_orig = caps['svd_S_orig'] S_coord = caps['svd_S'] delta = (S_coord - S_orig).abs() results['cross_attn_total_delta'] = delta.mean().item() results['cross_attn_max_delta'] = delta.max().item() results['cross_attn_relative_delta'] = (delta.mean() / (S_orig.abs().mean() + 1e-8)).item() # Decoder output → Reconstruction if 'dec_out' in caps and 'recon' in caps: dec = caps['dec_out'].reshape(-1) recon = caps['recon'].reshape(-1) n = min(len(dec), len(recon), 100000) results['dec_to_recon_spearman'] = spearman_rank(dec[:n], recon[:n]) # Input → Reconstruction (end-to-end) if 'input' in caps and 'recon' in caps: inp = caps['input'].reshape(-1) recon = caps['recon'].reshape(-1) n = min(len(inp), len(recon), 100000) results['end_to_end_spearman'] = spearman_rank(inp[:n], recon[:n]) results['end_to_end_mse'] = F.mse_loss( caps['recon'], caps['input']).item() # Boundary smooth effect if 'boundary_in' in caps and 'boundary_out' in caps: b_in = caps['boundary_in'] b_out = caps['boundary_out'] b_delta = (b_out - b_in).abs() results['boundary_delta_mean'] = b_delta.mean().item() results['boundary_delta_max'] = b_delta.max().item() results['boundary_relative'] = (b_delta.mean() / (b_in.abs().mean() + 1e-8)).item() return results # ═══════════════════════════════════════════════════════════════ # CV AT EVERY STAGE # ═══════════════════════════════════════════════════════════════ def analyze_cv_at_stages(interceptor): """Compute CV at key representation stages.""" caps = interceptor.captures results = {} stages = [ ('enc_in', 'Encoder input projection'), ('enc_block_0', 'Encoder block 0'), ('enc_block_1', 'Encoder block 1'), ('enc_block_2', 'Encoder block 2'), ('enc_block_3', 'Encoder block 3'), ('svd_S_orig', 'SVD S (pre cross-attn)'), ('svd_S', 'SVD S (post cross-attn)'), ('dec_in', 'Decoder input projection'), ('dec_block_0', 'Decoder block 0'), ('dec_block_3', 'Decoder block 3'), ] for key, name in stages: if key not in caps: continue X = caps[key] # Reshape to (N, D) for CV computation if X.dim() == 4: # (B, N, V, D) X = X.reshape(-1, X.shape[-1]) elif X.dim() == 3: # (B, N, D) or (B*N, ...) X = X.reshape(-1, X.shape[-1]) elif X.dim() == 2: pass else: continue if X.shape[-1] > 50: # PCA down for CV computation X_c = X - X.mean(0, keepdim=True) _, _, V = torch.linalg.svd(X_c[:min(1000, len(X_c))].float(), full_matrices=False) X = X @ V[:16].T try: cv = cv_of(X[:500].float()) results[key] = {'cv': cv, 'name': name, 'dim': X.shape[-1]} except: results[key] = {'cv': float('nan'), 'name': name, 'dim': X.shape[-1]} return results # ═══════════════════════════════════════════════════════════════ # WEIGHT ANALYSIS # ═══════════════════════════════════════════════════════════════ def analyze_weights(model): """Analyze model weight properties.""" results = {} for name, param in model.named_parameters(): p = param.data.float() entry = { 'shape': list(p.shape), 'norm': p.norm().item(), 'mean': p.mean().item(), 'std': p.std().item(), 'abs_mean': p.abs().mean().item(), 'sparsity': (p.abs() < 1e-6).float().mean().item(), } if p.dim() == 2: entry['condition'] = condition_number(p) entry['erank'] = effective_rank(p).item() results[name] = entry return results # ═══════════════════════════════════════════════════════════════ # NOISE TYPE FINGERPRINT ANALYSIS # ═══════════════════════════════════════════════════════════════ def _gen_noise(t, s): """Minimal noise gen for fingerprint analysis.""" if t == 0: return torch.randn(3, s, s) elif t == 1: return torch.rand(3, s, s) * 2 - 1 elif t == 4: w = torch.randn(3, s, s) S = torch.fft.rfft2(w) h, ww = s, s fy = torch.fft.fftfreq(h).unsqueeze(-1).expand(-1, ww // 2 + 1) fx = torch.fft.rfftfreq(ww).unsqueeze(0).expand(h, -1) img = torch.fft.irfft2(S / torch.sqrt(fx**2 + fy**2).clamp(min=1e-8), s=(h, ww)) return img / (img.std() + 1e-8) elif t == 6: return torch.where(torch.rand(3, s, s) > 0.5, torch.ones(3, s, s) * 2, -torch.ones(3, s, s) * 2) + torch.randn(3, s, s) * 0.1 elif t == 13: return torch.tan(math.pi * (torch.rand(3, s, s) - 0.5)).clamp(-3, 3) return torch.randn(3, s, s) def analyze_noise_fingerprints(interceptor, device, n_per=16): """How do different noise types look at each pipeline stage?""" results = {} type_names = {0: 'gaussian', 1: 'uniform', 4: 'pink', 6: 'salt_pepper', 13: 'cauchy'} for t, name in type_names.items(): imgs = torch.stack([_gen_noise(t, 64).clamp(-4, 4) for _ in range(n_per)]).to(device) interceptor.run(imgs) caps = interceptor.captures entry = {} # S profile per noise type if 'svd_S' in caps: S = caps['svd_S'] entry['S_mean'] = S.mean(dim=(0, 1)).tolist() entry['S_std'] = S.std(dim=(0, 1)).tolist() entry['erank'] = effective_rank(S.reshape(-1, S.shape[-1])).mean().item() # Encoder hidden activation pattern if 'enc_block_3' in caps: h = caps['enc_block_3'] entry['enc_final_abs_mean'] = h.abs().mean().item() entry['enc_final_dead_frac'] = (h.abs() < 1e-6).float().mean().item() # Cross-attention delta if 'svd_S_orig' in caps and 'svd_S' in caps: delta = (caps['svd_S'] - caps['svd_S_orig']).abs() entry['cross_attn_delta'] = delta.mean().item() # Reconstruction MSE if 'recon' in caps: entry['recon_mse'] = F.mse_loss(caps['recon'], imgs).item() results[name] = entry return results # ═══════════════════════════════════════════════════════════════ # FULL BATTERY # ═══════════════════════════════════════════════════════════════ @torch.no_grad() def run_full_battery(model, device, img_size=64, n_samples=64): """Run the complete center-mass interception battery.""" print("\n" + "=" * 70) print("FRECKLES CENTER-MASS INTERCEPTION") print("Full Geometric Alignment Battery") print("=" * 70) interceptor = PipelineInterceptor(model) D = model.D ps = model.patch_size # Generate test data (mixed noise types) imgs = [] for t in range(16): for _ in range(n_samples // 16): imgs.append(_gen_noise(t % 5 * 3, img_size).clamp(-4, 4)) imgs = torch.stack(imgs[:n_samples]).to(device) t0 = time.time() all_results = {} # ── 1. Forward with interception ── print("\n [1/8] Forward pass with interception...") out = interceptor.run(imgs) print(f" Captured {len(interceptor.captures)} activation tensors") for k, v in interceptor.captures.items(): print(f" {k:30s} {str(list(v.shape)):>30s}") # ── 2. Activation statistics at every stage ── print("\n [2/8] Activation statistics...") act_stats = {} for key, tensor in interceptor.captures.items(): if tensor.numel() > 0: act_stats[key] = activation_stats(tensor) all_results['activation_stats'] = act_stats for key in ['enc_in', 'enc_block_3', 'svd_S_orig', 'svd_S', 'dec_in', 'dec_block_3', 'dec_out']: if key in act_stats: s = act_stats[key] print(f" {key:25s} mean={s['mean']:+.4f} std={s['std']:.4f} " f"dead={s['dead_frac']:.3f} kurt={s['kurtosis']:.2f}") # ── 3. SVD bottleneck analysis ── print("\n [3/8] SVD bottleneck analysis...") caps = interceptor.captures svd_analysis = analyze_svd_bottleneck( caps['svd_U'], caps['svd_S_orig'], caps['svd_Vt'], caps['svd_M']) all_results['svd_bottleneck'] = svd_analysis print(f" S spectrum: {['%.3f' % s for s in svd_analysis['S_mean']]}") print(f" S ratio (S0/SD): {svd_analysis['S_ratio']:.3f}") print(f" Effective rank: {svd_analysis['effective_rank']:.3f}") print(f" U orthogonality error: {svd_analysis['U_orthogonality_error']:.6f}") print(f" Vt orthogonality error: {svd_analysis['Vt_orthogonality_error']:.6f}") print(f" Energy per mode: {['%.3f' % e for e in svd_analysis['energy_per_mode']]}") print(f" Sphere radius: {svd_analysis['sphere_radius_mean']:.4f} ± {svd_analysis['sphere_radius_std']:.4f}") # ── 4. Cross-attention analysis ── print("\n [4/8] Cross-attention analysis...") ca_results = {} for i in range(len(model.cross_attn)): in_key = f'cross_attn_{i}_in' out_key = f'cross_attn_{i}_out' qkv_key = f'cross_attn_{i}_qkv' if in_key in caps and out_key in caps: delta = analyze_cross_attn_delta( caps[in_key], caps[out_key], model.cross_attn[i].alpha_logits) ca_results[f'layer_{i}_delta'] = delta print(f" Layer {i}: delta={delta['delta_abs_mean']:.6f} " f"relative={delta['relative_change']:.4f} " f"alpha={delta['alpha_values']}") if qkv_key in caps: attn = analyze_attention(caps[qkv_key], model.cross_attn[i].n_heads, D) ca_results[f'layer_{i}_attention'] = attn print(f" Layer {i} QKV: q_norm={attn['q_norm_mean']:.3f} " f"qk_cos={attn['qk_cosine']:.3f} " f"kv_cos={attn['kv_cosine']:.3f}") all_results['cross_attention'] = ca_results # ── 5. Encoder/Decoder symmetry ── print("\n [5/8] Encoder/decoder symmetry...") sym = analyze_enc_dec_symmetry(interceptor) all_results['enc_dec_symmetry'] = sym for key, val in sorted(sym.items()): print(f" {key}: {val:.4f}") # ── 6. Information flow ── print("\n [6/8] Information flow analysis...") flow = analyze_information_flow(interceptor) all_results['information_flow'] = flow for key, val in sorted(flow.items()): print(f" {key}: {val:.6f}") # ── 7. CV at every stage ── print("\n [7/8] CV at pipeline stages...") cv_stages = analyze_cv_at_stages(interceptor) all_results['cv_stages'] = cv_stages for key, data in cv_stages.items(): print(f" {data['name']:30s} CV={data['cv']:.4f} dim={data['dim']}") # ── 8. Noise type fingerprints ── print("\n [8/8] Noise type fingerprints...") fingerprints = analyze_noise_fingerprints(interceptor, device) all_results['noise_fingerprints'] = fingerprints for name, fp in fingerprints.items(): print(f" {name:15s} S={['%.2f' % s for s in fp.get('S_mean', [])]}" f" er={fp.get('erank', 0):.2f}" f" Δca={fp.get('cross_attn_delta', 0):.5f}" f" mse={fp.get('recon_mse', 0):.6f}") # ── Weight analysis ── print("\n [BONUS] Weight analysis...") weights = analyze_weights(model) all_results['weights'] = weights total_params = sum(p.numel() for p in model.parameters()) total_sparse = sum(v['sparsity'] * np.prod(v['shape']) for v in weights.values()) print(f" Total params: {total_params:,}") print(f" Effective sparsity: {total_sparse / total_params:.4f}") # Key weight matrices for name in ['enc_out.weight', 'dec_in.weight', 'dec_out.weight']: if name in weights: w = weights[name] print(f" {name:25s} norm={w['norm']:.3f} cond={w.get('condition', 'N/A')}" f" erank={w.get('erank', 'N/A')}") elapsed = time.time() - t0 interceptor.remove_hooks() print(f"\n{'=' * 70}") print(f"BATTERY COMPLETE — {elapsed:.1f}s") print(f"{'=' * 70}") return all_results # ═══════════════════════════════════════════════════════════════ # LOAD + RUN # ═══════════════════════════════════════════════════════════════ def load_freckles(model_path=None, hf_version=None, device='cuda'): from geolip_svae import load_model if hf_version: return load_model(hf_version=hf_version, device=device) else: return load_model(checkpoint_path=model_path, device=device) if __name__ == "__main__": MODEL = 'v41_freckles_256' device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model, cfg = load_freckles(hf_version=MODEL, device=device) results = run_full_battery(model, device, img_size=cfg.get('img_size', 64)) # Save def to_json(obj): if isinstance(obj, (torch.Tensor, np.ndarray)): if hasattr(obj, 'tolist'): return obj.tolist() return float(obj) if isinstance(obj, dict): return {str(k): to_json(v) for k, v in obj.items()} if isinstance(obj, (list, tuple)): return [to_json(v) for v in obj] if isinstance(obj, float) and (math.isnan(obj) or math.isinf(obj)): return str(obj) return obj out_path = f'freckles_observer_{MODEL}.json' with open(out_path, 'w') as f: json.dump(to_json(results), f, indent=2) print(f" Saved: {out_path}")