# geolip-SVAE / v41_freckles_256 / freckles_battery_test.py
# (header copied from the Hugging Face file viewer: author AbstractPhil,
#  commit 492b1a0 "Create freckles_battery_test.py", 31.9 kB)
"""
Freckles Center-Mass Interception — Full Geometric Alignment Battery
======================================================================
Intercept every stage of the PatchSVAE pipeline and measure everything.
Pipeline stages intercepted:
1. Raw patches (B, N, 48)
2. Encoder hidden states (B*N, 384) per block
3. Pre-SVD matrix M (B, N, 48, 4)
4. SVD components: U, S_orig, Vt (the bottleneck)
5. Cross-attention: S_in → S_out (the coordination)
6. Cross-attention QKV projections per layer (attention weights themselves are not captured; Q/K/V statistics are reported instead)
7. Reconstructed matrix M_hat (B*N, 192)
8. Decoder hidden states (B*N, 384) per block
9. Decoded patches (B, N, 48)
10. Stitched + boundary smooth (B, 3, H, W)
Metrics at each stage:
- Effective rank, condition number
- Singular value spectrum
- Spearman rank correlation with input/output
- Procrustes alignment between stages
- CV (coefficient of variation of volumes)
- Cosine similarity distributions
- Gradient magnitude (not yet implemented; the battery runs under no_grad)
- Information retention ratio
- Dead neuron / activation statistics
Usage:
python freckles_battery_test.py
(no command-line options: the model version is the MODEL constant in the
__main__ block, currently 'v41_freckles_256')
"""
import os
import math
import json
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from collections import OrderedDict
from scipy import stats as scipy_stats
# ═══════════════════════════════════════════════════════════════
# HOOK-BASED INTERCEPTOR
# ═══════════════════════════════════════════════════════════════
class PipelineInterceptor:
    """Attach forward hooks to every stage of a PatchSVAE and capture activations.

    The model must expose ``enc_in``, ``enc_blocks``, ``enc_out``,
    ``cross_attn`` (each layer with a ``qkv`` submodule), ``dec_in``,
    ``dec_blocks``, ``dec_out`` and ``boundary_smooth``.  Its forward pass
    must return a dict with ``'recon'`` and an ``'svd'`` dict holding
    ``U``, ``S_orig``, ``S``, ``Vt`` and ``M``.

    Hooks are registered on construction.  Call :meth:`remove_hooks` when
    done, or use the interceptor as a context manager so the hooks are
    removed even if analysis raises::

        with PipelineInterceptor(model) as interceptor:
            interceptor.run(images)
    """

    def __init__(self, model):
        self.model = model
        self.captures = OrderedDict()  # capture key -> detached tensor, in execution order
        self.hooks = []  # handles returned by register_forward_hook
        self._attach_hooks()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        # Always detach the hooks; never suppress the exception.
        self.remove_hooks()
        return False

    def _register(self, module, hook):
        """Register ``hook`` as a forward hook on ``module`` and keep its handle."""
        self.hooks.append(module.register_forward_hook(hook))

    def _capture_output(self, key):
        """Build a hook that stores the module's detached output under ``key``."""
        def hook(module, inp, out):
            self.captures[key] = out.detach()
        return hook

    def _capture_input_output(self, in_key, out_key):
        """Build a hook that stores the first positional input and the output."""
        def hook(module, inp, out):
            self.captures[in_key] = inp[0].detach()
            self.captures[out_key] = out.detach()
        return hook

    def _attach_hooks(self):
        m = self.model
        # Encoder: input projection, each block, output projection (pre-SVD).
        self._register(m.enc_in, self._capture_output('enc_in'))
        for i, block in enumerate(m.enc_blocks):
            self._register(block, self._capture_output(f'enc_block_{i}'))
        self._register(m.enc_out, self._capture_output('enc_out_raw'))
        # Cross-attention: layer input/output plus the raw fused QKV projection.
        for i, layer in enumerate(m.cross_attn):
            self._register(layer, self._capture_input_output(
                f'cross_attn_{i}_in', f'cross_attn_{i}_out'))
            self._register(layer.qkv, self._capture_output(f'cross_attn_{i}_qkv'))
        # Decoder: input projection, each block, output projection.
        self._register(m.dec_in, self._capture_output('dec_in'))
        for i, block in enumerate(m.dec_blocks):
            self._register(block, self._capture_output(f'dec_block_{i}'))
        self._register(m.dec_out, self._capture_output('dec_out'))
        # Boundary smoothing: before and after.
        self._register(m.boundary_smooth,
                       self._capture_input_output('boundary_in', 'boundary_out'))

    def clear(self):
        """Drop all captured tensors."""
        self.captures.clear()

    def remove_hooks(self):
        """Remove every registered hook (safe to call more than once)."""
        for h in self.hooks:
            h.remove()
        self.hooks.clear()

    @torch.no_grad()
    def run(self, images):
        """Forward ``images`` with full interception and return the model output.

        After the call, ``self.captures`` holds the hooked activations followed
        by the SVD components (``svd_U``, ``svd_S_orig``, ``svd_S``,
        ``svd_Vt``, ``svd_M``), ``recon`` and ``input``.
        """
        self.clear()
        out = self.model(images)
        svd = out['svd']
        for name in ('U', 'S_orig', 'S', 'Vt', 'M'):
            self.captures[f'svd_{name}'] = svd[name].detach()
        self.captures['recon'] = out['recon'].detach()
        self.captures['input'] = images.detach()
        return out
# ═══════════════════════════════════════════════════════════════
# GEOMETRIC METRICS
# ═══════════════════════════════════════════════════════════════
def effective_rank(X):
    """Return exp(Shannon entropy) of a normalised singular-value spectrum.

    * 1-D input is promoted to a single-row matrix and decomposed (so the
      result is at most 1).
    * 2-D input is decomposed with an SVD; the result is a 0-d tensor.
    * Input with any other rank is NOT decomposed: it is treated as
      precomputed singular values along the last dim, giving one value per
      leading index.
    """
    if X.dim() == 1:
        X = X.unsqueeze(0)
    if X.dim() == 2:
        spectrum = torch.linalg.svd(X.float(), full_matrices=False)[1]
    else:
        spectrum = X.float()
    total = spectrum.sum(-1, keepdim=True) + 1e-8
    probs = (spectrum / total).clamp(min=1e-8)
    entropy = -(probs * probs.log()).sum(-1)
    return entropy.exp()
def condition_number(X):
    """Mean ratio of largest to smallest singular value of ``X``.

    Args:
        X: tensor with at least 2 dims; dims before the last two are treated
            as a batch and the per-matrix condition numbers are averaged.

    Returns:
        A Python float.  Inputs with fewer than 2 dims return 1.0.  If the
        SVD cannot be computed (or the matrix is empty) NaN is returned, so
        the metric stays best-effort.
    """
    if X.dim() < 2:
        # Return a float (not a 0-d tensor) so every branch has the same type.
        return 1.0
    try:
        S = torch.linalg.svdvals(X.float())
        return (S[..., 0] / (S[..., -1] + 1e-10)).mean().item()
    except (RuntimeError, IndexError):
        # torch.linalg errors (e.g. LinAlgError on non-convergence) subclass
        # RuntimeError; IndexError covers empty matrices.  Unlike the previous
        # bare ``except:``, this no longer swallows KeyboardInterrupt.
        return float('nan')
def spearman_rank(a, b):
    """Spearman rank correlation between two tensors, compared element-wise.

    Both tensors are flattened and truncated to a common length of at most
    100000 elements (the leading elements, not a random sample).

    Returns:
        The correlation as a Python float, or NaN if SciPy cannot compute it
        (SciPy itself also yields NaN for constant inputs).
    """
    # detach + double: .numpy() rejects grad-requiring and bfloat16 tensors;
    # float64 holds every float32/16 and small-int value exactly, so ranks
    # are unchanged.
    a_np = a.detach().flatten().double().cpu().numpy()
    b_np = b.detach().flatten().double().cpu().numpy()
    n = min(len(a_np), len(b_np), 100000)
    try:
        r, _ = scipy_stats.spearmanr(a_np[:n], b_np[:n])
    except Exception:
        # Best-effort diagnostic metric: report NaN rather than abort the battery.
        return float('nan')
    return float(r)
def procrustes_alignment(A, B):
    """Orthogonal-Procrustes comparison of point clouds A (N, Da) and B (M, Db).

    Rows are paired by index; if the row counts differ both are truncated to
    the shorter.  Each cloud is mean-centred and scaled to unit Frobenius
    norm, then the orthogonal map R minimising ||A_n - B_n R^T|| is obtained
    from the SVD of A_n^T B_n.  Feature widths may differ (Da != Db).

    Returns:
        dict with
          'error'       mean squared residual after alignment (lower = more aligned),
          'alignment'   nuclear norm of A_n^T B_n, in [0, 1]
                        (1 = identical up to rotation),
          'scale_ratio' ||A_c|| / ||B_c|| of the centred clouds.
    """
    n = min(A.shape[0], B.shape[0])
    # float32 throughout so the SVD factors and the inputs share a dtype.
    A = A[:n].float()
    B = B[:n].float()
    A_c = A - A.mean(0, keepdim=True)
    B_c = B - B.mean(0, keepdim=True)
    A_norm = A_c.norm()
    B_norm = B_c.norm()
    A_n = A_c / (A_norm + 1e-8)
    B_n = B_c / (B_norm + 1e-8)
    # Thin SVD: with full_matrices=True, U (Da x Da) @ Vt (Db x Db) fails
    # whenever Da != Db.
    U, S, Vt = torch.linalg.svd(A_n.T @ B_n, full_matrices=False)
    R = U @ Vt
    aligned = B_n @ R.T
    error = (A_n - aligned).pow(2).mean().item()
    alignment = S.sum().item()
    return {'error': error, 'alignment': alignment,
            'scale_ratio': (A_norm / (B_norm + 1e-8)).item()}
def cosine_sim_distribution(X):
    """Summary statistics of pairwise cosine similarity between rows of X (N, D).

    Only the first 500 rows are used, and self-similarities (the diagonal)
    are excluded.

    Returns:
        dict with 'mean', 'std', 'min', 'max', 'median' as floats.  All are NaN
        when fewer than two rows are available, because no pairs exist.
    """
    n = min(X.shape[0], 500)
    if n < 2:
        nan = float('nan')
        return {'mean': nan, 'std': nan, 'min': nan, 'max': nan, 'median': nan}
    # Slice before normalising so only the rows actually compared are processed.
    X_n = F.normalize(X[:n].float(), dim=-1)
    sim = X_n @ X_n.T
    mask = ~torch.eye(n, dtype=torch.bool, device=sim.device)
    vals = sim[mask]
    return {
        'mean': vals.mean().item(),
        'std': vals.std().item(),
        'min': vals.min().item(),
        'max': vals.max().item(),
        'median': vals.median().item(),
    }
def activation_stats(X):
    """Summarise the value distribution of an activation tensor.

    The tensor is flattened and cast to float32.  Reported keys: mean, std
    (torch default, unbiased), min, max, mean absolute value ('abs_mean'),
    fraction of entries with |x| < 1e-6 ('dead_frac'), fraction exactly zero
    ('sparsity'), and excess kurtosis (mean fourth power of the standardised
    values minus 3).
    """
    flat = X.float().flatten()
    mu = flat.mean()
    sigma = flat.std()
    magnitude = flat.abs()
    standardised = (flat - mu) / (sigma + 1e-8)
    stats = {}
    stats['mean'] = mu.item()
    stats['std'] = sigma.item()
    stats['min'] = flat.min().item()
    stats['max'] = flat.max().item()
    stats['abs_mean'] = magnitude.mean().item()
    stats['dead_frac'] = (magnitude < 1e-6).float().mean().item()
    stats['sparsity'] = flat.eq(0).float().mean().item()
    stats['kurtosis'] = standardised.pow(4).mean().item() - 3
    return stats
def cayley_menger_vol2(points):
    """Squared volume of each simplex in a batch, via the Cayley-Menger determinant.

    Args:
        points: (B, N, D) tensor; each batch entry holds the N vertices of an
            (N-1)-simplex in D dimensions.

    Returns:
        (B,) float64 tensor of squared (N-1)-volumes.  Nearly degenerate
        simplices may come out slightly negative from round-off.
    """
    batch, n_pts, _ = points.shape
    x = points.double()
    gram = x @ x.transpose(1, 2)
    sq_norm = gram.diagonal(dim1=1, dim2=2)
    # Squared pairwise distances ||xi||^2 + ||xj||^2 - 2<xi, xj>, clipped at 0.
    dist2 = F.relu(sq_norm[:, :, None] + sq_norm[:, None, :] - 2 * gram)
    # Bordered matrix [[0, 1^T], [1, dist2]].
    cm = torch.zeros(batch, n_pts + 1, n_pts + 1,
                     device=points.device, dtype=torch.float64)
    cm[:, 0, 1:] = 1.0
    cm[:, 1:, 0] = 1.0
    cm[:, 1:, 1:] = dist2
    dim = n_pts - 1
    denom = (2 ** dim) * (math.factorial(dim) ** 2)
    return (-1.0) ** (dim + 1) * torch.linalg.det(cm) / denom
def cv_of(emb, n_samples=200):
    """Coefficient of variation (std / mean) of random 4-simplex volumes.

    Draws ``n_samples`` random 5-point subsets from the first min(N, 512) rows
    of ``emb`` (N, D), takes each subset's volume via ``cayley_menger_vol2``,
    and returns std/mean of the volumes as a float.

    Returns NaN when ``emb`` is not 2-D, has fewer than 5 rows, or fewer than
    10 subsets have squared volume above 1e-20.  Sampling uses torch's global
    RNG, so results vary unless it is seeded.
    """
    if emb.dim() != 2 or emb.shape[0] < 5:
        return float('nan')
    pool = min(emb.shape[0], 512)
    subsets = [torch.randperm(pool, device=emb.device)[:5] for _ in range(n_samples)]
    vol2 = cayley_menger_vol2(emb[:pool][torch.stack(subsets)])
    keep = vol2 > 1e-20
    if keep.sum() < 10:
        return float('nan')
    volumes = vol2[keep].sqrt()
    return (volumes.std() / (volumes.mean() + 1e-8)).item()
# ═══════════════════════════════════════════════════════════════
# ATTENTION ANALYSIS
# ═══════════════════════════════════════════════════════════════
def analyze_attention(qkv, n_heads, D):
    """Summary statistics of one attention layer's fused Q/K/V projection.

    Args:
        qkv: output of the fused QKV projection, any shape whose last dim is
            3*D, laid out as [q | k | v].  Leading dims are flattened into rows.
        n_heads: number of attention heads.  Currently unused (per-head
            analysis is not implemented); kept for the call signature.
        D: width of each of q, k and v.

    Returns:
        dict with the mean L2 norm and element std of q, k and v, and the mean
        row-wise cosine similarity of each pair.

    Raises:
        ValueError: if the last dim of ``qkv`` is not 3*D.  Without this check
            the reshape could silently re-slice rows into the wrong q/k/v.
    """
    if qkv.shape[-1] != 3 * D:
        raise ValueError(
            f'expected qkv last dim 3*D={3 * D}, got {qkv.shape[-1]}')
    q, k, v = qkv.reshape(-1, 3, D).unbind(1)
    return {
        'q_norm_mean': q.norm(dim=-1).mean().item(),
        'k_norm_mean': k.norm(dim=-1).mean().item(),
        'v_norm_mean': v.norm(dim=-1).mean().item(),
        'q_std': q.std().item(),
        'k_std': k.std().item(),
        'v_std': v.std().item(),
        'qk_cosine': F.cosine_similarity(q, k, dim=-1).mean().item(),
        'qv_cosine': F.cosine_similarity(q, v, dim=-1).mean().item(),
        'kv_cosine': F.cosine_similarity(k, v, dim=-1).mean().item(),
    }
def analyze_cross_attn_delta(S_in, S_out, alpha_logits, max_alpha=0.2):
    """Quantify how much a cross-attention layer changed its input.

    Args:
        S_in, S_out: layer input and output, same shape, with the modes on
            the last dim (e.g. (B, N, D)).  Any number of leading dims works.
        alpha_logits: the layer's gate logits; the reported gate is
            ``max_alpha * sigmoid(alpha_logits)``.
        max_alpha: gate ceiling.  NOTE(review): should match the value the
            model uses internally; confirm against the model code.

    Returns:
        dict with |delta| mean/max, delta std, mean |delta| per mode, gate
        values and their mean, relative change (mean |delta| / mean |S_in|),
        and the fraction of entries where delta has the same sign as S_in.
    """
    delta = S_out - S_in
    abs_delta = delta.abs()
    alpha = max_alpha * torch.sigmoid(alpha_logits.detach())
    return {
        'delta_abs_mean': abs_delta.mean().item(),
        'delta_abs_max': abs_delta.max().item(),
        'delta_std': delta.std().item(),
        # Reduce over every leading dim so the per-mode profile is a list of
        # length D for any input rank (mean(dim=(0, 1)) was tied to 3-D input).
        'delta_per_mode': abs_delta.reshape(-1, delta.shape[-1]).mean(0).tolist(),
        'alpha_values': alpha.tolist(),
        'alpha_mean': alpha.mean().item(),
        'relative_change': (abs_delta.mean() / (S_in.abs().mean() + 1e-8)).item(),
        'sign_agreement': (delta.sign() == S_in.sign()).float().mean().item(),
    }
# ═══════════════════════════════════════════════════════════════
# SVD BOTTLENECK ANALYSIS
# ═══════════════════════════════════════════════════════════════
def analyze_svd_bottleneck(U, S, Vt, M):
    """Statistics of the per-patch SVD bottleneck M = U diag(S) Vt.

    Args:
        U: (B, N, V, D) left singular vectors per patch.
        S: singular values, reshapeable to (B*N, D).
        Vt: right singular vectors, reshapeable to (B*N, D, D).
        M: the decomposed matrices, reshapeable to (B*N, V, D).

    Returns:
        dict with the mean/std singular-value spectrum, the ratio of mean
        largest to mean smallest singular value (reported as both 'S_ratio'
        and 'condition_number'), mean per-patch spectral effective rank,
        reconstruction error of U diag(S) Vt against M, orthogonality errors
        of U and Vt, mean energy share per mode, and mean/std of the row
        norms of M ('sphere_radius_*').
    """
    B, N, V, D = U.shape
    BN = B * N
    S_flat = S.reshape(-1, D)
    S_mean = S_flat.mean(0)
    S_std = S_flat.std(0)
    U_flat = U.reshape(BN, V, D)
    Vt_flat = Vt.reshape(BN, D, D)
    M_flat = M.reshape(BN, V, D)

    # Reconstruction: scaling U's columns by S is U @ diag(S).
    M_recon = torch.bmm(U_flat * S.reshape(BN, D).unsqueeze(1), Vt_flat)
    recon_error = (M_recon - M_flat).pow(2).mean().item()

    # Orthogonality of U's columns and Vt's rows.
    eye = torch.eye(D, device=U.device).unsqueeze(0)
    ortho_error = (torch.bmm(U_flat.transpose(1, 2), U_flat) - eye).pow(2).mean().item()
    vt_ortho_error = (torch.bmm(Vt_flat, Vt_flat.transpose(1, 2)) - eye).pow(2).mean().item()

    # Ratio of the *mean* largest to the *mean* smallest singular value
    # (not the mean of per-patch condition numbers).
    s_ratio = (S_mean[0] / (S_mean[-1] + 1e-8)).item()

    # Per-patch spectral effective rank: exp(entropy) of each patch's
    # normalised singular values, averaged over patches.  Passing the 2-D
    # S_flat to effective_rank() would instead SVD the whole (B*N, D) stack
    # of spectra and return one matrix-level value.
    spectra = S_flat.float()
    p = (spectra / (spectra.sum(-1, keepdim=True) + 1e-8)).clamp(min=1e-8)
    erank = (-(p * p.log()).sum(-1)).exp().mean().item()

    # Energy distribution across modes.
    S2 = S_flat.pow(2)
    energy_mean = (S2 / (S2.sum(-1, keepdim=True) + 1e-8)).mean(0)

    # Norm of every row of M.  NOTE(review): the model is said to normalise
    # these rows (hence "sphere radius"); if so they should be close to 1.
    M_norms = M_flat.reshape(BN * V, D).norm(dim=-1)
    return {
        'S_mean': S_mean.tolist(),
        'S_std': S_std.tolist(),
        'S_ratio': s_ratio,
        'condition_number': s_ratio,
        'effective_rank': erank,
        'recon_error': recon_error,
        'U_orthogonality_error': ortho_error,
        'Vt_orthogonality_error': vt_ortho_error,
        'energy_per_mode': energy_mean.tolist(),
        'sphere_radius_mean': M_norms.mean().item(),
        'sphere_radius_std': M_norms.std().item(),
    }
# ═══════════════════════════════════════════════════════════════
# ENCODER/DECODER SYMMETRY ANALYSIS
# ═══════════════════════════════════════════════════════════════
def analyze_enc_dec_symmetry(interceptor):
    """Compare each encoder block with its mirror-image decoder block.

    Encoder block i is paired with decoder block depth-1-i, where depth is the
    number of encoder blocks.  For each pair present in the captures, both
    activations are flattened to (rows, features) and truncated to a common
    row count (at most 10000), then compared by Spearman rank correlation,
    cosine similarity of the mean feature vectors, and Procrustes alignment
    on the first 500 rows (computed on CPU).  Missing pairs are skipped.
    """
    caps = interceptor.captures
    depth = len(interceptor.model.enc_blocks)
    results = {}
    for i in range(depth):
        enc_key = f'enc_block_{i}'
        dec_key = f'dec_block_{depth - 1 - i}'  # mirror order
        if enc_key not in caps or dec_key not in caps:
            continue
        enc_all = caps[enc_key].reshape(-1, caps[enc_key].shape[-1])
        dec_all = caps[dec_key].reshape(-1, caps[dec_key].shape[-1])
        rows = min(enc_all.shape[0], dec_all.shape[0], 10000)
        enc, dec = enc_all[:rows], dec_all[:rows]
        results[f'block_{i}_spearman'] = spearman_rank(enc, dec)
        enc_centroid = enc.mean(0, keepdim=True)
        dec_centroid = dec.mean(0, keepdim=True)
        results[f'block_{i}_cosine'] = F.cosine_similarity(enc_centroid, dec_centroid).item()
        proc = procrustes_alignment(enc[:500].cpu(), dec[:500].cpu())
        results[f'block_{i}_procrustes_error'] = proc['error']
        results[f'block_{i}_procrustes_alignment'] = proc['alignment']
    return results
# ═══════════════════════════════════════════════════════════════
# INFORMATION FLOW ANALYSIS
# ═══════════════════════════════════════════════════════════════
def analyze_information_flow(interceptor):
    """Track how information transforms through the pipeline.

    Reads ``interceptor.captures`` and, for each pair of stages that was
    captured, reports:
      * input -> encoder output: Spearman correlation,
      * M -> S: compression ratio (elements of M per singular value),
        variances and variance retention,
      * S_orig -> S: mean/max/relative change made by cross-attention,
      * decoder output -> reconstruction: Spearman correlation,
      * input -> reconstruction: Spearman correlation and MSE,
      * boundary smoothing: mean/max/relative change.
    Keys for stages that were not captured are omitted.
    """
    caps = interceptor.captures
    results = {}

    # Input -> encoder output.
    if 'input' in caps and 'enc_out_raw' in caps:
        # NOTE(review): rows are paired by leading dim and spearman_rank then
        # compares flattened elements position-by-position; the two tensors
        # need not share a layout, so treat this as a coarse signal.
        inp = caps['input'].reshape(caps['input'].shape[0], -1)
        enc = caps['enc_out_raw'].reshape(caps['enc_out_raw'].shape[0], -1)
        n = min(inp.shape[0], enc.shape[0])
        results['input_to_enc_spearman'] = spearman_rank(inp[:n], enc[:n])

    # Pre-SVD M -> singular values.
    if 'svd_M' in caps and 'svd_S_orig' in caps:
        M = caps['svd_M']
        S = caps['svd_S_orig']
        # Elements of M per retained singular value; for M (B, N, V, D) and
        # S (B, N, D) this is V.  Comparing last dims always gave 1.0 because
        # both tensors end in the mode dim D.
        results['M_to_S_compression'] = M.numel() / max(S.numel(), 1)
        # How much variance does S capture of M?
        M_var = M.var().item()
        S_var = S.var().item()
        results['M_variance'] = M_var
        results['S_variance'] = S_var
        results['variance_retention'] = S_var / (M_var + 1e-8)

    # S_orig -> S (cross-attention effect).
    if 'svd_S_orig' in caps and 'svd_S' in caps:
        S_orig = caps['svd_S_orig']
        delta = (caps['svd_S'] - S_orig).abs()
        results['cross_attn_total_delta'] = delta.mean().item()
        results['cross_attn_max_delta'] = delta.max().item()
        results['cross_attn_relative_delta'] = (delta.mean() / (S_orig.abs().mean() + 1e-8)).item()

    # Decoder output -> reconstruction.
    if 'dec_out' in caps and 'recon' in caps:
        # NOTE(review): patch-wise decoder output and the stitched image are
        # compared by flat position; the layouts likely differ.
        dec = caps['dec_out'].reshape(-1)
        recon = caps['recon'].reshape(-1)
        n = min(len(dec), len(recon), 100000)
        results['dec_to_recon_spearman'] = spearman_rank(dec[:n], recon[:n])

    # Input -> reconstruction (end-to-end).
    if 'input' in caps and 'recon' in caps:
        inp = caps['input'].reshape(-1)
        recon = caps['recon'].reshape(-1)
        n = min(len(inp), len(recon), 100000)
        results['end_to_end_spearman'] = spearman_rank(inp[:n], recon[:n])
        results['end_to_end_mse'] = F.mse_loss(caps['recon'], caps['input']).item()

    # Boundary smoothing effect.
    if 'boundary_in' in caps and 'boundary_out' in caps:
        b_in = caps['boundary_in']
        b_delta = (caps['boundary_out'] - b_in).abs()
        results['boundary_delta_mean'] = b_delta.mean().item()
        results['boundary_delta_max'] = b_delta.max().item()
        results['boundary_relative'] = (b_delta.mean() / (b_in.abs().mean() + 1e-8)).item()
    return results
# ═══════════════════════════════════════════════════════════════
# CV AT EVERY STAGE
# ═══════════════════════════════════════════════════════════════
def analyze_cv_at_stages(interceptor):
    """Compute the simplex-volume CV (see ``cv_of``) at key pipeline stages.

    Stages: encoder input projection, every encoder block, S before and after
    cross-attention, decoder input projection, and the first and last decoder
    blocks.  Block counts come from ``interceptor.model.enc_blocks`` /
    ``dec_blocks``, falling back to 4 each (the previous fixed stage list).

    Each captured tensor with 2-4 dims is flattened to (rows, features) in
    float32.  Features wider than 50 are projected onto the top-16 principal
    directions, estimated from the first 1000 rows.  CV is then computed on
    the first 500 rows.  Stages that were not captured, or have another rank,
    are skipped.

    Returns:
        dict stage_key -> {'cv': float (NaN on failure), 'name': str, 'dim': int}
    """
    caps = interceptor.captures
    model = getattr(interceptor, 'model', None)
    enc_depth = len(getattr(model, 'enc_blocks', range(4)))
    dec_depth = len(getattr(model, 'dec_blocks', range(4)))
    stages = [('enc_in', 'Encoder input projection')]
    stages += [(f'enc_block_{i}', f'Encoder block {i}') for i in range(enc_depth)]
    stages += [
        ('svd_S_orig', 'SVD S (pre cross-attn)'),
        ('svd_S', 'SVD S (post cross-attn)'),
        ('dec_in', 'Decoder input projection'),
    ]
    stages += [(f'dec_block_{i}', f'Decoder block {i}')
               for i in sorted({0, max(dec_depth - 1, 0)})]

    results = {}
    for key, name in stages:
        if key not in caps:
            continue
        X = caps[key]
        if not 2 <= X.dim() <= 4:
            continue
        # Cast up front so the projection below (float32 SVD factors) also
        # works for half-precision captures.
        X = X.reshape(-1, X.shape[-1]).float()
        if X.shape[-1] > 50:
            # PCA down to 16 dims.  Volumes depend only on pairwise distances,
            # so projecting the uncentred X is equivalent to projecting X_c.
            X_c = X - X.mean(0, keepdim=True)
            _, _, Vh = torch.linalg.svd(X_c[:1000], full_matrices=False)
            X = X @ Vh[:16].T
        try:
            cv = cv_of(X[:500])
        except Exception:
            # Best-effort per stage: one failing stage must not abort the report.
            cv = float('nan')
        results[key] = {'cv': cv, 'name': name, 'dim': X.shape[-1]}
    return results
# ═══════════════════════════════════════════════════════════════
# WEIGHT ANALYSIS
# ═══════════════════════════════════════════════════════════════
def analyze_weights(model):
    """Per-parameter weight statistics, keyed by parameter name.

    Every parameter gets its shape, Frobenius norm, mean, std, mean |w| and
    the fraction of entries with |w| < 1e-6 ('sparsity').  2-D parameters
    additionally get 'condition' (``condition_number``) and 'erank'
    (``effective_rank``).
    """
    report = {}
    for pname, tensor in model.named_parameters():
        w = tensor.data.float()
        magnitude = w.abs()
        info = {
            'shape': list(w.shape),
            'norm': w.norm().item(),
            'mean': w.mean().item(),
            'std': w.std().item(),
            'abs_mean': magnitude.mean().item(),
            'sparsity': (magnitude < 1e-6).float().mean().item(),
        }
        if w.dim() == 2:
            info['condition'] = condition_number(w)
            info['erank'] = effective_rank(w).item()
        report[pname] = info
    return report
# ═══════════════════════════════════════════════════════════════
# NOISE TYPE FINGERPRINT ANALYSIS
# ═══════════════════════════════════════════════════════════════
def _gen_noise(t, s):
"""Minimal noise gen for fingerprint analysis."""
if t == 0: return torch.randn(3, s, s)
elif t == 1: return torch.rand(3, s, s) * 2 - 1
elif t == 4:
w = torch.randn(3, s, s)
S = torch.fft.rfft2(w)
h, ww = s, s
fy = torch.fft.fftfreq(h).unsqueeze(-1).expand(-1, ww // 2 + 1)
fx = torch.fft.rfftfreq(ww).unsqueeze(0).expand(h, -1)
img = torch.fft.irfft2(S / torch.sqrt(fx**2 + fy**2).clamp(min=1e-8), s=(h, ww))
return img / (img.std() + 1e-8)
elif t == 6:
return torch.where(torch.rand(3, s, s) > 0.5,
torch.ones(3, s, s) * 2, -torch.ones(3, s, s) * 2) + torch.randn(3, s, s) * 0.1
elif t == 13:
return torch.tan(math.pi * (torch.rand(3, s, s) - 0.5)).clamp(-3, 3)
return torch.randn(3, s, s)
def analyze_noise_fingerprints(interceptor, device, n_per=16, img_size=64):
    """Profile how each implemented noise type looks at key pipeline stages.

    For each noise type, ``n_per`` images of size ``img_size`` are generated
    with ``_gen_noise``, clamped to [-4, 4] and run through
    ``interceptor.run``.  Per type the result records: mean/std of the
    post-cross-attention S per mode, mean per-patch spectral effective rank
    of S, mean |activation| and dead fraction of the last encoder block,
    mean cross-attention change to S, and reconstruction MSE.  Stages that
    were not captured are omitted.

    Returns:
        dict noise_name -> entry dict.
    """
    type_names = {0: 'gaussian', 1: 'uniform', 4: 'pink', 6: 'salt_pepper', 13: 'cauchy'}
    # Last encoder block, derived from the model (falls back to the previous
    # hard-coded 'enc_block_3' when the depth is unknown).
    enc_depth = len(getattr(interceptor.model, 'enc_blocks', range(4)))
    final_enc_key = f'enc_block_{enc_depth - 1}'
    results = {}
    for t, name in type_names.items():
        imgs = torch.stack([_gen_noise(t, img_size).clamp(-4, 4)
                            for _ in range(n_per)]).to(device)
        interceptor.run(imgs)
        caps = interceptor.captures
        entry = {}
        # S profile per noise type.
        if 'svd_S' in caps:
            S = caps['svd_S']
            entry['S_mean'] = S.mean(dim=(0, 1)).tolist()
            entry['S_std'] = S.std(dim=(0, 1)).tolist()
            # Per-patch spectral effective rank (exp entropy of each patch's
            # normalised S), averaged.  effective_rank() on the 2-D stack
            # would SVD the whole stack instead.
            spectra = S.reshape(-1, S.shape[-1]).float()
            p = (spectra / (spectra.sum(-1, keepdim=True) + 1e-8)).clamp(min=1e-8)
            entry['erank'] = (-(p * p.log()).sum(-1)).exp().mean().item()
        # Final encoder hidden-activation pattern.
        if final_enc_key in caps:
            h = caps[final_enc_key]
            entry['enc_final_abs_mean'] = h.abs().mean().item()
            entry['enc_final_dead_frac'] = (h.abs() < 1e-6).float().mean().item()
        # Cross-attention delta.
        if 'svd_S_orig' in caps and 'svd_S' in caps:
            delta = (caps['svd_S'] - caps['svd_S_orig']).abs()
            entry['cross_attn_delta'] = delta.mean().item()
        # Reconstruction MSE.
        if 'recon' in caps:
            entry['recon_mse'] = F.mse_loss(caps['recon'], imgs).item()
        results[name] = entry
    return results
# ═══════════════════════════════════════════════════════════════
# FULL BATTERY
# ═══════════════════════════════════════════════════════════════
def _battery_images(n_samples, img_size):
    """Return ``n_samples`` (3, img_size, img_size) images cycling through the noise types."""
    # Cycle through the types _gen_noise actually implements (the same set
    # the fingerprint analysis uses).  Other type ids silently fall back to
    # gaussian, which previously made the "mixed" batch mostly gaussian.
    noise_types = (0, 1, 4, 6, 13)
    return torch.stack([
        _gen_noise(noise_types[i % len(noise_types)], img_size).clamp(-4, 4)
        for i in range(n_samples)
    ])


@torch.no_grad()
def run_full_battery(model, device, img_size=64, n_samples=64):
    """Run the complete center-mass interception battery and print a report.

    Args:
        model: PatchSVAE-style model exposing the submodules PipelineInterceptor
            hooks, plus ``D`` and, per cross-attention layer, ``alpha_logits``
            and ``n_heads``.
        device: device the test images are moved to.
        img_size: side length of the square test images.
        n_samples: number of mixed-noise test images (must be >= 1).

    Returns:
        dict of results per analysis stage.  The forward hooks are removed on
        exit even if an analysis step raises.
    """
    print("\n" + "=" * 70)
    print("FRECKLES CENTER-MASS INTERCEPTION")
    print("Full Geometric Alignment Battery")
    print("=" * 70)
    interceptor = PipelineInterceptor(model)
    try:
        D = model.D
        imgs = _battery_images(n_samples, img_size).to(device)
        t0 = time.time()
        all_results = {}

        # ── 1. Forward with interception ──
        print("\n [1/8] Forward pass with interception...")
        interceptor.run(imgs)
        print(f" Captured {len(interceptor.captures)} activation tensors")
        for k, v in interceptor.captures.items():
            print(f" {k:30s} {str(list(v.shape)):>30s}")

        # ── 2. Activation statistics at every stage ──
        print("\n [2/8] Activation statistics...")
        act_stats = {key: activation_stats(tensor)
                     for key, tensor in interceptor.captures.items()
                     if tensor.numel() > 0}
        all_results['activation_stats'] = act_stats
        for key in ['enc_in', 'enc_block_3', 'svd_S_orig', 'svd_S', 'dec_in', 'dec_block_3', 'dec_out']:
            if key in act_stats:
                s = act_stats[key]
                print(f" {key:25s} mean={s['mean']:+.4f} std={s['std']:.4f} "
                      f"dead={s['dead_frac']:.3f} kurt={s['kurtosis']:.2f}")

        # ── 3. SVD bottleneck analysis ──
        print("\n [3/8] SVD bottleneck analysis...")
        caps = interceptor.captures
        svd_analysis = analyze_svd_bottleneck(
            caps['svd_U'], caps['svd_S_orig'], caps['svd_Vt'], caps['svd_M'])
        all_results['svd_bottleneck'] = svd_analysis
        print(f" S spectrum: {['%.3f' % s for s in svd_analysis['S_mean']]}")
        print(f" S ratio (S0/SD): {svd_analysis['S_ratio']:.3f}")
        print(f" Effective rank: {svd_analysis['effective_rank']:.3f}")
        print(f" U orthogonality error: {svd_analysis['U_orthogonality_error']:.6f}")
        print(f" Vt orthogonality error: {svd_analysis['Vt_orthogonality_error']:.6f}")
        print(f" Energy per mode: {['%.3f' % e for e in svd_analysis['energy_per_mode']]}")
        print(f" Sphere radius: {svd_analysis['sphere_radius_mean']:.4f} ± {svd_analysis['sphere_radius_std']:.4f}")

        # ── 4. Cross-attention analysis ──
        print("\n [4/8] Cross-attention analysis...")
        ca_results = {}
        for i, layer in enumerate(model.cross_attn):
            in_key = f'cross_attn_{i}_in'
            out_key = f'cross_attn_{i}_out'
            qkv_key = f'cross_attn_{i}_qkv'
            if in_key in caps and out_key in caps:
                delta = analyze_cross_attn_delta(caps[in_key], caps[out_key], layer.alpha_logits)
                ca_results[f'layer_{i}_delta'] = delta
                print(f" Layer {i}: delta={delta['delta_abs_mean']:.6f} "
                      f"relative={delta['relative_change']:.4f} "
                      f"alpha={delta['alpha_values']}")
            if qkv_key in caps:
                attn = analyze_attention(caps[qkv_key], layer.n_heads, D)
                ca_results[f'layer_{i}_attention'] = attn
                print(f" Layer {i} QKV: q_norm={attn['q_norm_mean']:.3f} "
                      f"qk_cos={attn['qk_cosine']:.3f} "
                      f"kv_cos={attn['kv_cosine']:.3f}")
        all_results['cross_attention'] = ca_results

        # ── 5. Encoder/Decoder symmetry ──
        print("\n [5/8] Encoder/decoder symmetry...")
        sym = analyze_enc_dec_symmetry(interceptor)
        all_results['enc_dec_symmetry'] = sym
        for key, val in sorted(sym.items()):
            print(f" {key}: {val:.4f}")

        # ── 6. Information flow ──
        print("\n [6/8] Information flow analysis...")
        flow = analyze_information_flow(interceptor)
        all_results['information_flow'] = flow
        for key, val in sorted(flow.items()):
            print(f" {key}: {val:.6f}")

        # ── 7. CV at every stage ──
        print("\n [7/8] CV at pipeline stages...")
        cv_stages = analyze_cv_at_stages(interceptor)
        all_results['cv_stages'] = cv_stages
        for data in cv_stages.values():
            print(f" {data['name']:30s} CV={data['cv']:.4f} dim={data['dim']}")

        # ── 8. Noise type fingerprints ──
        print("\n [8/8] Noise type fingerprints...")
        fingerprints = analyze_noise_fingerprints(interceptor, device)
        all_results['noise_fingerprints'] = fingerprints
        for name, fp in fingerprints.items():
            print(f" {name:15s} S={['%.2f' % s for s in fp.get('S_mean', [])]}"
                  f" er={fp.get('erank', 0):.2f}"
                  f" Δca={fp.get('cross_attn_delta', 0):.5f}"
                  f" mse={fp.get('recon_mse', 0):.6f}")

        # ── Weight analysis ──
        print("\n [BONUS] Weight analysis...")
        weights = analyze_weights(model)
        all_results['weights'] = weights
        total_params = sum(p.numel() for p in model.parameters())
        total_sparse = sum(v['sparsity'] * np.prod(v['shape']) for v in weights.values())
        print(f" Total params: {total_params:,}")
        print(f" Effective sparsity: {total_sparse / total_params:.4f}")
        # Key weight matrices.
        for name in ['enc_out.weight', 'dec_in.weight', 'dec_out.weight']:
            if name in weights:
                w = weights[name]
                print(f" {name:25s} norm={w['norm']:.3f} cond={w.get('condition', 'N/A')}"
                      f" erank={w.get('erank', 'N/A')}")
        elapsed = time.time() - t0
    finally:
        # Always detach the hooks, even if an analysis step raised.
        interceptor.remove_hooks()
    print(f"\n{'=' * 70}")
    print(f"BATTERY COMPLETE — {elapsed:.1f}s")
    print(f"{'=' * 70}")
    return all_results
# ═══════════════════════════════════════════════════════════════
# LOAD + RUN
# ═══════════════════════════════════════════════════════════════
def load_freckles(model_path=None, hf_version=None, device='cuda'):
    """Load a Freckles model through ``geolip_svae.load_model``.

    A truthy ``hf_version`` takes precedence and is forwarded as
    ``hf_version``; otherwise ``model_path`` is forwarded as
    ``checkpoint_path``.  Returns whatever ``load_model`` returns (the
    ``__main__`` block unpacks it as ``(model, cfg)``).
    """
    from geolip_svae import load_model  # project package, imported lazily
    if not hf_version:
        return load_model(checkpoint_path=model_path, device=device)
    return load_model(hf_version=hf_version, device=device)
def to_json(obj):
    """Recursively convert analysis results into JSON-serialisable Python objects.

    Tensors and ndarrays become (nested) lists, NumPy scalars become Python
    scalars, dict keys become strings and tuples become lists.  Non-finite
    floats become the strings 'nan' / 'inf' / '-inf', since strict JSON has
    no literal for them.
    """
    if isinstance(obj, (torch.Tensor, np.ndarray)):
        # Recurse so NaN/inf inside arrays are stringified too; they were
        # previously emitted as the invalid JSON token NaN.
        return to_json(obj.tolist())
    if isinstance(obj, np.generic):
        # e.g. np.float32, which json cannot serialise directly.
        return to_json(obj.item())
    if isinstance(obj, dict):
        return {str(k): to_json(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [to_json(v) for v in obj]
    if isinstance(obj, float) and not math.isfinite(obj):
        return str(obj)
    return obj


if __name__ == "__main__":
    MODEL = 'v41_freckles_256'
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model, cfg = load_freckles(hf_version=MODEL, device=device)
    results = run_full_battery(model, device, img_size=cfg.get('img_size', 64))
    # Save
    out_path = f'freckles_observer_{MODEL}.json'
    with open(out_path, 'w', encoding='utf-8') as f:
        json.dump(to_json(results), f, indent=2)
    print(f" Saved: {out_path}")