| """ |
| Freckles Center-Mass Interception β Full Geometric Alignment Battery |
| ====================================================================== |
| Intercept every stage of the PatchSVAE pipeline and measure everything. |
| |
| Pipeline stages intercepted: |
| 1. Raw patches (B, N, 48) |
| 2. Encoder hidden states (B*N, 384) per block |
| 3. Pre-SVD matrix M (B, N, 48, 4) |
| 4. SVD components: U, S_orig, Vt (the bottleneck) |
| 5. Cross-attention: S_in β S_out (the coordination) |
| 6. Attention weights (B, 2, N, N) per layer |
| 7. Reconstructed matrix M_hat (B*N, 192) |
| 8. Decoder hidden states (B*N, 384) per block |
| 9. Decoded patches (B, N, 48) |
| 10. Stitched + boundary smooth (B, 3, H, W) |
| |
| Metrics at each stage: |
| - Effective rank, condition number |
| - Singular value spectrum |
| - Spearman rank correlation with input/output |
| - Procrustes alignment between stages |
| - CV (coefficient of variation of volumes) |
| - Cosine similarity distributions |
| - Gradient magnitude (if training) |
| - Information retention ratio |
| - Dead neuron / activation statistics |
| |
| Usage: |
| python freckles_observer.py --model v40_freckles_noise |
| """ |

import argparse
import math
import json
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from collections import OrderedDict
from scipy import stats as scipy_stats


# ----------------------------------------------------------------------
#  Interception hooks
# ----------------------------------------------------------------------


class PipelineInterceptor:
    """Attach hooks to every stage of PatchSVAE, capture activations."""

    def __init__(self, model):
        self.model = model
        self.captures = OrderedDict()
        self.hooks = []
        self._attach_hooks()

    def _attach_hooks(self):
        m = self.model

        # Stages 1-2: encoder input projection and transformer blocks.
        def hook_enc_in(module, inp, out):
            self.captures['enc_in'] = out.detach()
        self.hooks.append(m.enc_in.register_forward_hook(hook_enc_in))

        for i, block in enumerate(m.enc_blocks):
            def make_hook(idx):
                def hook(module, inp, out):
                    self.captures[f'enc_block_{idx}'] = out.detach()
                return hook
            self.hooks.append(block.register_forward_hook(make_hook(i)))

        def hook_enc_out(module, inp, out):
            self.captures['enc_out_raw'] = out.detach()
        self.hooks.append(m.enc_out.register_forward_hook(hook_enc_out))

        # Stage 5: cross-attention layers. Both hook factories take idx
        # explicitly so each closure binds its own layer index, and the QKV
        # hook is registered inside the loop so every layer is captured.
        for i, layer in enumerate(m.cross_attn):
            def make_ca_hook(idx):
                def hook(module, inp, out):
                    self.captures[f'cross_attn_{idx}_in'] = inp[0].detach()
                    self.captures[f'cross_attn_{idx}_out'] = out.detach()
                return hook
            self.hooks.append(layer.register_forward_hook(make_ca_hook(i)))

            def make_qkv_hook(idx):
                def hook(module, inp, out):
                    self.captures[f'cross_attn_{idx}_qkv'] = out.detach()
                return hook
            self.hooks.append(layer.qkv.register_forward_hook(make_qkv_hook(i)))

        # Stages 8-9: decoder input projection, blocks, and output head.
        def hook_dec_in(module, inp, out):
            self.captures['dec_in'] = out.detach()
        self.hooks.append(m.dec_in.register_forward_hook(hook_dec_in))

        for i, block in enumerate(m.dec_blocks):
            def make_hook(idx):
                def hook(module, inp, out):
                    self.captures[f'dec_block_{idx}'] = out.detach()
                return hook
            self.hooks.append(block.register_forward_hook(make_hook(i)))

        def hook_dec_out(module, inp, out):
            self.captures['dec_out'] = out.detach()
        self.hooks.append(m.dec_out.register_forward_hook(hook_dec_out))

        # Stage 10: boundary smoothing (capture both sides of the op).
        def hook_boundary(module, inp, out):
            self.captures['boundary_in'] = inp[0].detach()
            self.captures['boundary_out'] = out.detach()
        self.hooks.append(m.boundary_smooth.register_forward_hook(hook_boundary))

    def clear(self):
        self.captures.clear()

    def remove_hooks(self):
        for h in self.hooks:
            h.remove()
        self.hooks.clear()

    @torch.no_grad()
    def run(self, images):
        """Forward pass with full interception."""
        self.clear()
        out = self.model(images)

        # Stages 3-4 and 7 come out of the forward dict rather than hooks.
        self.captures['svd_U'] = out['svd']['U'].detach()
        self.captures['svd_S_orig'] = out['svd']['S_orig'].detach()
        self.captures['svd_S'] = out['svd']['S'].detach()
        self.captures['svd_Vt'] = out['svd']['Vt'].detach()
        self.captures['svd_M'] = out['svd']['M'].detach()
        self.captures['recon'] = out['recon'].detach()
        self.captures['input'] = images.detach()

        return out
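

# Typical lifecycle (a hedged sketch):
#
#     interceptor = PipelineInterceptor(model)
#     interceptor.run(images)                   # fills interceptor.captures
#     h = interceptor.captures['enc_block_0']   # inspect any stage
#     interceptor.remove_hooks()                # detach when done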


# ----------------------------------------------------------------------
#  Metric primitives
# ----------------------------------------------------------------------


def effective_rank(X):
    """Effective rank via Shannon entropy of the singular value spectrum.

    2-D input is treated as a matrix (its SVD is taken); higher-dimensional
    input is treated as already being batched singular values.
    """
    if X.dim() == 1:
        X = X.unsqueeze(0)
    if X.dim() == 2:
        _, S, _ = torch.linalg.svd(X.float(), full_matrices=False)
    else:
        S = X.float()
    p = S / (S.sum(-1, keepdim=True) + 1e-8)
    p = p.clamp(min=1e-8)
    return (-(p * p.log()).sum(-1)).exp()
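
# Intuition (hedged, numbers rounded): a flat spectrum has effective rank
# equal to its length, a collapsed one approaches 1:
#
#     effective_rank(torch.diag(torch.tensor([1., 1., 1., 1.])))     # ~4.0
#     effective_rank(torch.diag(torch.tensor([1., .01, .01, .01])))  # ~1.2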


def condition_number(X):
    """Condition number (largest / smallest singular value)."""
    if X.dim() < 2:
        return 1.0
    try:
        S = torch.linalg.svdvals(X.float())
        return (S[..., 0] / (S[..., -1] + 1e-10)).mean().item()
    except Exception:
        return float('nan')


def spearman_rank(a, b):
    """Spearman rank correlation between two flattened tensors
    (subsampled to at most 100k elements)."""
    a_np = a.flatten().cpu().numpy()
    b_np = b.flatten().cpu().numpy()
    n = min(len(a_np), len(b_np), 100000)
    a_np, b_np = a_np[:n], b_np[:n]
    try:
        r, _ = scipy_stats.spearmanr(a_np, b_np)
        return r
    except Exception:
        return float('nan')


def procrustes_alignment(A, B):
    """Procrustes alignment between two (N, D) matrices.
    Returns a dict with rotation error (lower = more aligned), the
    alignment score (Procrustes correlation), and the scale ratio.
    """
    if A.shape != B.shape:
        n = min(A.shape[0], B.shape[0])
        A, B = A[:n], B[:n]

    # Center and normalize to remove translation and overall scale.
    A_c = A - A.mean(0, keepdim=True)
    B_c = B - B.mean(0, keepdim=True)
    A_n = A_c / (A_c.norm() + 1e-8)
    B_n = B_c / (B_c.norm() + 1e-8)

    # Orthogonal Procrustes: the optimal rotation aligning B_n onto A_n
    # comes from the SVD of the cross-covariance.
    M = A_n.T @ B_n
    U, S, Vt = torch.linalg.svd(M.float())
    R = U @ Vt

    aligned = B_n @ R.T
    error = (A_n - aligned).pow(2).mean().item()
    alignment = S.sum().item()

    return {'error': error, 'alignment': alignment,
            'scale_ratio': (A_c.norm() / (B_c.norm() + 1e-8)).item()}
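
# Sanity check (hedged): a pure rotation of the same point cloud should align
# almost perfectly, i.e. error ~0 and alignment ~1:
#
#     A = torch.randn(200, 8)
#     Q, _ = torch.linalg.qr(torch.randn(8, 8))
#     procrustes_alignment(A, A @ Q)   # {'error': ~0, 'alignment': ~1, ...}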


def cosine_sim_distribution(X):
    """Pairwise cosine similarity statistics for an (N, D) matrix."""
    X_n = F.normalize(X.float(), dim=-1)
    n = min(X_n.shape[0], 500)
    X_n = X_n[:n]
    sim = X_n @ X_n.T
    mask = ~torch.eye(n, dtype=torch.bool, device=sim.device)
    vals = sim[mask]
    return {
        'mean': vals.mean().item(),
        'std': vals.std().item(),
        'min': vals.min().item(),
        'max': vals.max().item(),
        'median': vals.median().item(),
    }


def activation_stats(X):
    """Activation statistics for a hidden state tensor."""
    X_flat = X.float().flatten()
    return {
        'mean': X_flat.mean().item(),
        'std': X_flat.std().item(),
        'min': X_flat.min().item(),
        'max': X_flat.max().item(),
        'abs_mean': X_flat.abs().mean().item(),
        'dead_frac': (X_flat.abs() < 1e-6).float().mean().item(),
        'sparsity': (X_flat == 0).float().mean().item(),
        # Excess kurtosis: 0 for a Gaussian, positive for heavy tails.
        'kurtosis': ((X_flat - X_flat.mean()) / (X_flat.std() + 1e-8)).pow(4).mean().item() - 3,
    }


def cayley_menger_vol2(points):
    """Squared volume of the (N-1)-simplex spanned by each (N, D) point set,
    via the Cayley-Menger determinant. points: (B, N, D), returns (B,)."""
    B, N, D = points.shape
    pts = points.double()
    gram = torch.bmm(pts, pts.transpose(1, 2))
    norms = torch.diagonal(gram, dim1=1, dim2=2)
    d2 = F.relu(norms.unsqueeze(2) + norms.unsqueeze(1) - 2 * gram)
    cm = torch.zeros(B, N + 1, N + 1, device=points.device, dtype=torch.float64)
    cm[:, 0, 1:] = 1.0
    cm[:, 1:, 0] = 1.0
    cm[:, 1:, 1:] = d2
    k = N - 1
    sign = (-1.0) ** (k + 1)
    fact = math.factorial(k)
    return sign * torch.linalg.det(cm) / ((2 ** k) * (fact ** 2))
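
# Worked check (hedged): the simplex {0, e1, e2, e3, e4} in R^4 has volume
# 1/4! = 1/24, so squared volume 1/576 ~ 0.00174:
#
#     pts = torch.cat([torch.zeros(1, 4), torch.eye(4)]).unsqueeze(0)
#     cayley_menger_vol2(pts)   # tensor([0.0017], dtype=torch.float64)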


def cv_of(emb, n_samples=200):
    """Coefficient of variation of random 5-point simplex volumes in emb."""
    if emb.dim() != 2 or emb.shape[0] < 5:
        return float('nan')
    N, D = emb.shape
    pool = min(N, 512)
    indices = torch.stack([torch.randperm(pool, device=emb.device)[:5]
                           for _ in range(n_samples)])
    vol2 = cayley_menger_vol2(emb[:pool][indices])
    valid = vol2 > 1e-20
    if valid.sum() < 10:
        return float('nan')
    vols = vol2[valid].sqrt()
    return (vols.std() / (vols.mean() + 1e-8)).item()
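
# Reading the number (hedged): CV near 0 means the sampled simplex volumes
# are nearly uniform (a "round" point cloud); a large CV means volumes vary
# wildly (clustered or filamentary structure). Baseline probe:
#
#     cv_of(torch.randn(512, 16))   # CV of an isotropic Gaussian cloud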


# ----------------------------------------------------------------------
#  Attention analysis
# ----------------------------------------------------------------------


def analyze_attention(qkv, n_heads, D):
    """Analyze attention statistics from a fused QKV tensor (last dim 3*D).

    n_heads is accepted for API symmetry; the stats here are head-agnostic.
    """
    if qkv.dim() == 3:
        qkv = qkv.reshape(-1, qkv.shape[-1])

    # Split the fused [q | k | v] projection back into its three parts.
    qkv_r = qkv.reshape(-1, 3, D)
    q, k, v = qkv_r[:, 0], qkv_r[:, 1], qkv_r[:, 2]

    return {
        'q_norm_mean': q.norm(dim=-1).mean().item(),
        'k_norm_mean': k.norm(dim=-1).mean().item(),
        'v_norm_mean': v.norm(dim=-1).mean().item(),
        'q_std': q.std().item(),
        'k_std': k.std().item(),
        'v_std': v.std().item(),
        'qk_cosine': F.cosine_similarity(q, k, dim=-1).mean().item(),
        'qv_cosine': F.cosine_similarity(q, v, dim=-1).mean().item(),
        'kv_cosine': F.cosine_similarity(k, v, dim=-1).mean().item(),
    }


def analyze_cross_attn_delta(S_in, S_out, alpha_logits, max_alpha=0.2):
    """Analyze what cross-attention actually changed about the S spectrum."""
    delta = S_out - S_in
    alpha = max_alpha * torch.sigmoid(alpha_logits)

    return {
        'delta_abs_mean': delta.abs().mean().item(),
        'delta_abs_max': delta.abs().max().item(),
        'delta_std': delta.std().item(),
        'delta_per_mode': delta.abs().mean(dim=(0, 1)).tolist(),
        'alpha_values': alpha.tolist(),
        'alpha_mean': alpha.mean().item(),
        'relative_change': (delta.abs().mean() / (S_in.abs().mean() + 1e-8)).item(),
        'sign_agreement': (delta.sign() == S_in.sign()).float().mean().item(),
    }
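
# Scale intuition (hedged): alpha = max_alpha * sigmoid(alpha_logits), so a
# zero logit yields alpha = 0.2 * 0.5 = 0.1 and alpha stays in (0, max_alpha)
# no matter how the logits move; 'relative_change' measures how much of that
# budget the layer actually uses.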


# ----------------------------------------------------------------------
#  SVD bottleneck analysis
# ----------------------------------------------------------------------


def analyze_svd_bottleneck(U, S, Vt, M):
    """Deep analysis of the SVD bottleneck."""
    B, N, V, D = U.shape

    # Spectrum statistics across all patches.
    S_flat = S.reshape(-1, D)
    S_mean = S_flat.mean(0)
    S_std = S_flat.std(0)

    # Reconstruction check: M should equal U diag(S) Vt.
    M_recon = torch.bmm(
        U.reshape(B * N, V, D) * S.reshape(B * N, D).unsqueeze(1),
        Vt.reshape(B * N, D, D))
    M_flat = M.reshape(B * N, V, D)
    recon_error = (M_recon - M_flat).pow(2).mean().item()

    # Orthogonality of U columns and Vt rows.
    U_flat = U.reshape(B * N, V, D)
    UtU = torch.bmm(U_flat.transpose(1, 2), U_flat)
    eye = torch.eye(D, device=U.device).unsqueeze(0)
    ortho_error = (UtU - eye).pow(2).mean().item()

    Vt_flat = Vt.reshape(B * N, D, D)
    VtVt = torch.bmm(Vt_flat, Vt_flat.transpose(1, 2))
    vt_ortho_error = (VtVt - eye).pow(2).mean().item()

    # Conditioning and effective rank of the spectrum.
    cond = (S_mean[0] / (S_mean[-1] + 1e-8)).item()
    erank = effective_rank(S_flat).mean().item()

    # Energy distribution across the D modes.
    S2 = S_flat.pow(2)
    energy = S2 / (S2.sum(-1, keepdim=True) + 1e-8)
    energy_mean = energy.mean(0)

    # Row norms of M (the "sphere radius" of patch embeddings).
    M_norms = M_flat.reshape(B * N * V, D).norm(dim=-1)

    return {
        'S_mean': S_mean.tolist(),
        'S_std': S_std.tolist(),
        'S_ratio': (S_mean[0] / (S_mean[-1] + 1e-8)).item(),
        'condition_number': cond,
        'effective_rank': erank,
        'recon_error': recon_error,
        'U_orthogonality_error': ortho_error,
        'Vt_orthogonality_error': vt_ortho_error,
        'energy_per_mode': energy_mean.tolist(),
        'sphere_radius_mean': M_norms.mean().item(),
        'sphere_radius_std': M_norms.std().item(),
    }
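
# Sanity bounds (hedged): for an exact SVD, recon_error and both
# orthogonality errors should sit near floating-point noise; values well
# above that suggest the bottleneck tensors were re-parameterized or
# perturbed between the decomposition and this check.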


# ----------------------------------------------------------------------
#  Encoder/decoder symmetry
# ----------------------------------------------------------------------


def analyze_enc_dec_symmetry(interceptor):
    """Compare encoder and decoder representations at mirrored depths."""
    caps = interceptor.captures
    results = {}
    model = interceptor.model
    depth = len(model.enc_blocks)

    for i in range(depth):
        # Encoder block i is paired with its mirror, decoder block depth-1-i.
        enc_key = f'enc_block_{i}'
        dec_key = f'dec_block_{depth - 1 - i}'

        if enc_key in caps and dec_key in caps:
            enc = caps[enc_key].reshape(-1, caps[enc_key].shape[-1])
            dec = caps[dec_key].reshape(-1, caps[dec_key].shape[-1])

            n = min(enc.shape[0], dec.shape[0], 10000)
            enc_s, dec_s = enc[:n], dec[:n]

            results[f'block_{i}_spearman'] = spearman_rank(enc_s, dec_s)
            results[f'block_{i}_cosine'] = F.cosine_similarity(
                enc_s.mean(0, keepdim=True), dec_s.mean(0, keepdim=True)).item()

            proc = procrustes_alignment(enc_s[:500].cpu(), dec_s[:500].cpu())
            results[f'block_{i}_procrustes_error'] = proc['error']
            results[f'block_{i}_procrustes_alignment'] = proc['alignment']

    return results
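
# Reading these (hedged): if the decoder approximately inverts the encoder,
# mirrored pairs should show high Spearman / Procrustes alignment; uniformly
# low alignment would suggest the decoder rebuilds from the S spectrum rather
# than undoing encoder features layer by layer.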


# ----------------------------------------------------------------------
#  Information flow
# ----------------------------------------------------------------------


def analyze_information_flow(interceptor):
    """Track how information transforms through the pipeline."""
    caps = interceptor.captures
    results = {}

    # Input -> encoder output.
    if 'input' in caps and 'enc_out_raw' in caps:
        inp = caps['input'].reshape(caps['input'].shape[0], -1)
        enc = caps['enc_out_raw'].reshape(caps['enc_out_raw'].shape[0], -1)
        n = min(inp.shape[0], enc.shape[0])
        results['input_to_enc_spearman'] = spearman_rank(inp[:n], enc[:n])

    # M -> S compression at the bottleneck.
    if 'svd_M' in caps and 'svd_S_orig' in caps:
        M = caps['svd_M'].reshape(-1, caps['svd_M'].shape[-1])
        S = caps['svd_S_orig'].reshape(-1, caps['svd_S_orig'].shape[-1])
        results['M_to_S_compression'] = M.shape[-1] / S.shape[-1]

        M_var = M.var().item()
        S_var = S.var().item()
        results['M_variance'] = M_var
        results['S_variance'] = S_var
        results['variance_retention'] = S_var / (M_var + 1e-8)

    # What cross-attention changed in the spectrum.
    if 'svd_S_orig' in caps and 'svd_S' in caps:
        S_orig = caps['svd_S_orig']
        S_coord = caps['svd_S']
        delta = (S_coord - S_orig).abs()
        results['cross_attn_total_delta'] = delta.mean().item()
        results['cross_attn_max_delta'] = delta.max().item()
        results['cross_attn_relative_delta'] = (delta.mean() / (S_orig.abs().mean() + 1e-8)).item()

    # Decoder output -> stitched reconstruction.
    if 'dec_out' in caps and 'recon' in caps:
        dec = caps['dec_out'].reshape(-1)
        recon = caps['recon'].reshape(-1)
        n = min(len(dec), len(recon), 100000)
        results['dec_to_recon_spearman'] = spearman_rank(dec[:n], recon[:n])

    # End to end.
    if 'input' in caps and 'recon' in caps:
        inp = caps['input'].reshape(-1)
        recon = caps['recon'].reshape(-1)
        n = min(len(inp), len(recon), 100000)
        results['end_to_end_spearman'] = spearman_rank(inp[:n], recon[:n])
        results['end_to_end_mse'] = F.mse_loss(
            caps['recon'], caps['input']).item()

    # Boundary smoothing delta.
    if 'boundary_in' in caps and 'boundary_out' in caps:
        b_in = caps['boundary_in']
        b_out = caps['boundary_out']
        b_delta = (b_out - b_in).abs()
        results['boundary_delta_mean'] = b_delta.mean().item()
        results['boundary_delta_max'] = b_delta.max().item()
        results['boundary_relative'] = (b_delta.mean() / (b_in.abs().mean() + 1e-8)).item()

    return results


# ----------------------------------------------------------------------
#  CV at pipeline stages
# ----------------------------------------------------------------------


def analyze_cv_at_stages(interceptor):
    """Compute CV at key representation stages."""
    caps = interceptor.captures
    results = {}

    stages = [
        ('enc_in', 'Encoder input projection'),
        ('enc_block_0', 'Encoder block 0'),
        ('enc_block_1', 'Encoder block 1'),
        ('enc_block_2', 'Encoder block 2'),
        ('enc_block_3', 'Encoder block 3'),
        ('svd_S_orig', 'SVD S (pre cross-attn)'),
        ('svd_S', 'SVD S (post cross-attn)'),
        ('dec_in', 'Decoder input projection'),
        ('dec_block_0', 'Decoder block 0'),
        ('dec_block_3', 'Decoder block 3'),
    ]

    for key, name in stages:
        if key not in caps:
            continue
        X = caps[key]
        # Flatten leading dims so X is (samples, features).
        if X.dim() >= 3:
            X = X.reshape(-1, X.shape[-1])
        elif X.dim() != 2:
            continue

        # High-dim stages: project onto the top-16 principal directions
        # first, so the 5-point simplex volumes don't vanish numerically.
        if X.shape[-1] > 50:
            X_c = X - X.mean(0, keepdim=True)
            _, _, Vh = torch.linalg.svd(X_c[:min(1000, len(X_c))].float(),
                                        full_matrices=False)
            X = X.float() @ Vh[:16].T

        try:
            cv = cv_of(X[:500].float())
            results[key] = {'cv': cv, 'name': name, 'dim': X.shape[-1]}
        except Exception:
            results[key] = {'cv': float('nan'), 'name': name, 'dim': X.shape[-1]}

    return results


# ----------------------------------------------------------------------
#  Weight analysis
# ----------------------------------------------------------------------


def analyze_weights(model):
    """Analyze model weight properties."""
    results = {}

    for name, param in model.named_parameters():
        p = param.data.float()
        entry = {
            'shape': list(p.shape),
            'norm': p.norm().item(),
            'mean': p.mean().item(),
            'std': p.std().item(),
            'abs_mean': p.abs().mean().item(),
            'sparsity': (p.abs() < 1e-6).float().mean().item(),
        }
        if p.dim() == 2:
            entry['condition'] = condition_number(p)
            entry['erank'] = effective_rank(p).item()
        results[name] = entry

    return results


# ----------------------------------------------------------------------
#  Noise fingerprints
# ----------------------------------------------------------------------


def _gen_noise(t, s):
    """Minimal noise generator for fingerprint analysis.

    Type ids: 0 gaussian, 1 uniform, 4 pink (1/f), 6 salt & pepper,
    13 cauchy (see type_names in analyze_noise_fingerprints); any other id
    falls back to gaussian.
    """
    if t == 0:
        return torch.randn(3, s, s)
    elif t == 1:
        return torch.rand(3, s, s) * 2 - 1
    elif t == 4:
        # Pink noise: shape a white spectrum by 1/f. Clamp the frequency at
        # the lowest nonzero bin (1/s) rather than a tiny epsilon, so the DC
        # term is not amplified by a factor of ~1e8.
        w = torch.randn(3, s, s)
        spec = torch.fft.rfft2(w)
        fy = torch.fft.fftfreq(s).unsqueeze(-1).expand(-1, s // 2 + 1)
        fx = torch.fft.rfftfreq(s).unsqueeze(0).expand(s, -1)
        freq = torch.sqrt(fx ** 2 + fy ** 2).clamp(min=1.0 / s)
        img = torch.fft.irfft2(spec / freq, s=(s, s))
        return img / (img.std() + 1e-8)
    elif t == 6:
        return torch.where(torch.rand(3, s, s) > 0.5,
                           torch.ones(3, s, s) * 2,
                           -torch.ones(3, s, s) * 2) + torch.randn(3, s, s) * 0.1
    elif t == 13:
        # Cauchy via inverse CDF, clamped to tame the tails.
        return torch.tan(math.pi * (torch.rand(3, s, s) - 0.5)).clamp(-3, 3)
    return torch.randn(3, s, s)
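
# Quick probe (hedged): generate one image per implemented type and compare
# dispersion after generation:
#
#     batch = torch.stack([_gen_noise(t, 64) for t in (0, 1, 4, 6, 13)])
#     print(batch.std(dim=(1, 2, 3)))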


def analyze_noise_fingerprints(interceptor, device, n_per=16):
    """How does each noise type look at each pipeline stage?"""
    results = {}
    type_names = {0: 'gaussian', 1: 'uniform', 4: 'pink', 6: 'salt_pepper', 13: 'cauchy'}

    for t, name in type_names.items():
        imgs = torch.stack([_gen_noise(t, 64).clamp(-4, 4) for _ in range(n_per)]).to(device)
        interceptor.run(imgs)
        caps = interceptor.captures

        entry = {}
        # Spectrum fingerprint.
        if 'svd_S' in caps:
            S = caps['svd_S']
            entry['S_mean'] = S.mean(dim=(0, 1)).tolist()
            entry['S_std'] = S.std(dim=(0, 1)).tolist()
            entry['erank'] = effective_rank(S.reshape(-1, S.shape[-1])).mean().item()

        # Encoder saturation.
        if 'enc_block_3' in caps:
            h = caps['enc_block_3']
            entry['enc_final_abs_mean'] = h.abs().mean().item()
            entry['enc_final_dead_frac'] = (h.abs() < 1e-6).float().mean().item()

        # Cross-attention activity.
        if 'svd_S_orig' in caps and 'svd_S' in caps:
            delta = (caps['svd_S'] - caps['svd_S_orig']).abs()
            entry['cross_attn_delta'] = delta.mean().item()

        # Reconstruction quality.
        if 'recon' in caps:
            entry['recon_mse'] = F.mse_loss(caps['recon'], imgs).item()

        results[name] = entry

    return results


# ----------------------------------------------------------------------
#  Full battery
# ----------------------------------------------------------------------


@torch.no_grad()
def run_full_battery(model, device, img_size=64, n_samples=64):
    """Run the complete center-mass interception battery."""

    print("\n" + "=" * 70)
    print("FRECKLES CENTER-MASS INTERCEPTION")
    print("Full Geometric Alignment Battery")
    print("=" * 70)

    interceptor = PipelineInterceptor(model)
    D = model.D

    # Mixed test batch: cycle through the five implemented noise types.
    noise_ids = [0, 1, 4, 6, 13]
    imgs = []
    for t in range(16):
        for _ in range(n_samples // 16):
            imgs.append(_gen_noise(noise_ids[t % len(noise_ids)], img_size).clamp(-4, 4))
    imgs = torch.stack(imgs[:n_samples]).to(device)

    t0 = time.time()
    all_results = {}

    # [1/8] Forward pass with interception.
    print("\n  [1/8] Forward pass with interception...")
    interceptor.run(imgs)
    print(f"    Captured {len(interceptor.captures)} activation tensors")
    for k, v in interceptor.captures.items():
        print(f"      {k:30s} {str(list(v.shape)):>30s}")

    # [2/8] Activation statistics at every captured stage.
    print("\n  [2/8] Activation statistics...")
    act_stats = {}
    for key, tensor in interceptor.captures.items():
        if tensor.numel() > 0:
            act_stats[key] = activation_stats(tensor)
    all_results['activation_stats'] = act_stats

    for key in ['enc_in', 'enc_block_3', 'svd_S_orig', 'svd_S', 'dec_in', 'dec_block_3', 'dec_out']:
        if key in act_stats:
            s = act_stats[key]
            print(f"    {key:25s} mean={s['mean']:+.4f} std={s['std']:.4f} "
                  f"dead={s['dead_frac']:.3f} kurt={s['kurtosis']:.2f}")

    # [3/8] SVD bottleneck.
    print("\n  [3/8] SVD bottleneck analysis...")
    caps = interceptor.captures
    svd_analysis = analyze_svd_bottleneck(
        caps['svd_U'], caps['svd_S_orig'], caps['svd_Vt'], caps['svd_M'])
    all_results['svd_bottleneck'] = svd_analysis
    print(f"    S spectrum: {['%.3f' % s for s in svd_analysis['S_mean']]}")
    print(f"    S ratio (S0/SD): {svd_analysis['S_ratio']:.3f}")
    print(f"    Effective rank: {svd_analysis['effective_rank']:.3f}")
    print(f"    U orthogonality error: {svd_analysis['U_orthogonality_error']:.6f}")
    print(f"    Vt orthogonality error: {svd_analysis['Vt_orthogonality_error']:.6f}")
    print(f"    Energy per mode: {['%.3f' % e for e in svd_analysis['energy_per_mode']]}")
    print(f"    Sphere radius: {svd_analysis['sphere_radius_mean']:.4f} ± {svd_analysis['sphere_radius_std']:.4f}")

    # [4/8] Cross-attention.
    print("\n  [4/8] Cross-attention analysis...")
    ca_results = {}
    for i in range(len(model.cross_attn)):
        in_key = f'cross_attn_{i}_in'
        out_key = f'cross_attn_{i}_out'
        qkv_key = f'cross_attn_{i}_qkv'

        if in_key in caps and out_key in caps:
            delta = analyze_cross_attn_delta(
                caps[in_key], caps[out_key],
                model.cross_attn[i].alpha_logits)
            ca_results[f'layer_{i}_delta'] = delta
            print(f"    Layer {i}: delta={delta['delta_abs_mean']:.6f} "
                  f"relative={delta['relative_change']:.4f} "
                  f"alpha={delta['alpha_values']}")

        if qkv_key in caps:
            attn = analyze_attention(caps[qkv_key], model.cross_attn[i].n_heads, D)
            ca_results[f'layer_{i}_attention'] = attn
            print(f"    Layer {i} QKV: q_norm={attn['q_norm_mean']:.3f} "
                  f"qk_cos={attn['qk_cosine']:.3f} "
                  f"kv_cos={attn['kv_cosine']:.3f}")

    all_results['cross_attention'] = ca_results

    # [5/8] Encoder/decoder symmetry.
    print("\n  [5/8] Encoder/decoder symmetry...")
    sym = analyze_enc_dec_symmetry(interceptor)
    all_results['enc_dec_symmetry'] = sym
    for key, val in sorted(sym.items()):
        print(f"    {key}: {val:.4f}")

    # [6/8] Information flow.
    print("\n  [6/8] Information flow analysis...")
    flow = analyze_information_flow(interceptor)
    all_results['information_flow'] = flow
    for key, val in sorted(flow.items()):
        print(f"    {key}: {val:.6f}")

    # [7/8] CV at pipeline stages.
    print("\n  [7/8] CV at pipeline stages...")
    cv_stages = analyze_cv_at_stages(interceptor)
    all_results['cv_stages'] = cv_stages
    for key, data in cv_stages.items():
        print(f"    {data['name']:30s} CV={data['cv']:.4f} dim={data['dim']}")

    # [8/8] Noise type fingerprints (re-runs the interceptor per type).
    print("\n  [8/8] Noise type fingerprints...")
    fingerprints = analyze_noise_fingerprints(interceptor, device)
    all_results['noise_fingerprints'] = fingerprints
    for name, fp in fingerprints.items():
        print(f"    {name:15s} S={['%.2f' % s for s in fp.get('S_mean', [])]}"
              f" er={fp.get('erank', 0):.2f}"
              f" Δca={fp.get('cross_attn_delta', 0):.5f}"
              f" mse={fp.get('recon_mse', 0):.6f}")

    # [BONUS] Weight analysis.
    print("\n  [BONUS] Weight analysis...")
    weights = analyze_weights(model)
    all_results['weights'] = weights
    total_params = sum(p.numel() for p in model.parameters())
    total_sparse = sum(v['sparsity'] * np.prod(v['shape']) for v in weights.values())
    print(f"    Total params: {total_params:,}")
    print(f"    Effective sparsity: {total_sparse / total_params:.4f}")

    for name in ['enc_out.weight', 'dec_in.weight', 'dec_out.weight']:
        if name in weights:
            w = weights[name]
            print(f"    {name:25s} norm={w['norm']:.3f} cond={w.get('condition', 'N/A')}"
                  f" erank={w.get('erank', 'N/A')}")

    elapsed = time.time() - t0
    interceptor.remove_hooks()

    print(f"\n{'=' * 70}")
    print(f"BATTERY COMPLETE in {elapsed:.1f}s")
    print(f"{'=' * 70}")

    return all_results
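
# Result layout for downstream scripts (hedged examples):
#
#     results['svd_bottleneck']['effective_rank']          # float
#     results['cv_stages']['svd_S']['cv']                  # float
#     results['noise_fingerprints']['pink']['recon_mse']   # float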


# ----------------------------------------------------------------------
#  Entry point
# ----------------------------------------------------------------------


def load_freckles(model_path=None, hf_version=None, device='cuda'):
    from geolip_svae import load_model
    if hf_version:
        return load_model(hf_version=hf_version, device=device)
    return load_model(checkpoint_path=model_path, device=device)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('--model', default='v41_freckles_256',
                        help='HF model version to intercept')
    args = parser.parse_args()
    MODEL = args.model

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model, cfg = load_freckles(hf_version=MODEL, device=device)

    results = run_full_battery(model, device, img_size=cfg.get('img_size', 64))

    # JSON-safe conversion: tensors/arrays -> (nested) lists, NaN/Inf -> strings.
    def to_json(obj):
        if isinstance(obj, (torch.Tensor, np.ndarray)):
            return to_json(obj.tolist())
        if isinstance(obj, dict):
            return {str(k): to_json(v) for k, v in obj.items()}
        if isinstance(obj, (list, tuple)):
            return [to_json(v) for v in obj]
        if isinstance(obj, float) and (math.isnan(obj) or math.isinf(obj)):
            return str(obj)
        return obj

    out_path = f'freckles_observer_{MODEL}.json'
    with open(out_path, 'w') as f:
        json.dump(to_json(results), f, indent=2)
    print(f"  Saved: {out_path}")