""" SVAE-Patch — Patch-based SVD Autoencoder ========================================== 64×64 image → 4 patches of 32×32 → independent SVD per patch → coordinate → decode. Each patch gets the proven (V, D) pipeline: patch → MLP → M ∈ ℝ^(V×D) → normalize → SVD → (U, S, Vt) Cross-patch coordination via a lightweight attention on the spectral representations — each patch's S vector (D-dim) attends to all others. This lets patches share information about relative spectral structure without disrupting the per-patch geometric attractors. Reconstruction: coordinated spectra + per-patch (U, Vt) → MLP → patch → stitch. """ import torch import torch.nn as nn import torch.nn.functional as F import torchvision import torchvision.transforms as T import math import time # ── SVD Backend ────────────────────────────────────────────────── try: from geolip_core.linalg.eigh import FLEigh, _FL_MAX_N HAS_FL = True except ImportError: HAS_FL = False def gram_eigh_svd_fp64(A): """Thin SVD via Gram + eigh in fp64.""" orig_dtype = A.dtype with torch.amp.autocast('cuda', enabled=False): A_d = A.double() G = torch.bmm(A_d.transpose(1, 2), A_d) eigenvalues, V = torch.linalg.eigh(G) eigenvalues = eigenvalues.flip(-1) V = V.flip(-1) S = torch.sqrt(eigenvalues.clamp(min=1e-24)) U = torch.bmm(A_d, V) / S.unsqueeze(1).clamp(min=1e-16) Vh = V.transpose(-2, -1).contiguous() return U.to(orig_dtype), S.to(orig_dtype), Vh.to(orig_dtype) def svd_fp64(A): """Auto-dispatch SVD with fp64 internals.""" B, M, N = A.shape if HAS_FL and N <= _FL_MAX_N and A.is_cuda: orig_dtype = A.dtype with torch.amp.autocast('cuda', enabled=False): A_d = A.double() G = torch.bmm(A_d.transpose(1, 2), A_d) eigenvalues, V = FLEigh()(G.float()) eigenvalues = eigenvalues.double().flip(-1) V = V.double().flip(-1) S = torch.sqrt(eigenvalues.clamp(min=1e-24)) U = torch.bmm(A_d, V) / S.unsqueeze(1).clamp(min=1e-16) Vh = V.transpose(-2, -1).contiguous() return U.to(orig_dtype), S.to(orig_dtype), Vh.to(orig_dtype) else: return gram_eigh_svd_fp64(A) # ── CV Monitoring ──────────────────────────────────────────────── def cayley_menger_vol2(points): """Squared simplex volume via Cayley-Menger determinant, fp64.""" B, N, D = points.shape pts = points.double() gram = torch.bmm(pts, pts.transpose(1, 2)) norms = torch.diagonal(gram, dim1=1, dim2=2) d2 = F.relu(norms.unsqueeze(2) + norms.unsqueeze(1) - 2 * gram) cm = torch.zeros(B, N + 1, N + 1, device=points.device, dtype=torch.float64) cm[:, 0, 1:] = 1.0 cm[:, 1:, 0] = 1.0 cm[:, 1:, 1:] = d2 k = N - 1 sign = (-1.0) ** (k + 1) fact = math.factorial(k) return sign * torch.linalg.det(cm) / ((2 ** k) * (fact ** 2)) def cv_of(emb, n_samples=200): """CV of pentachoron volumes for a (V, D) embedding.""" if emb.dim() != 2 or emb.shape[0] < 5: return 0.0 N, D = emb.shape pool = min(N, 512) indices = torch.stack([torch.randperm(pool, device=emb.device)[:5] for _ in range(n_samples)]) vol2 = cayley_menger_vol2(emb[:pool][indices]) valid = vol2 > 1e-20 if valid.sum() < 10: return 0.0 vols = vol2[valid].sqrt() return (vols.std() / (vols.mean() + 1e-8)).item() # ── Data ───────────────────────────────────────────────────────── def get_tiny_imagenet(batch_size=256): """TinyImageNet via HuggingFace: 200 classes, 64x64.""" from datasets import load_dataset ds = load_dataset('zh-plus/tiny-imagenet') transform = T.Compose([ T.ToTensor(), T.Normalize((0.4802, 0.4481, 0.3975), (0.2770, 0.2691, 0.2821)), ]) class HFImageDataset(torch.utils.data.Dataset): def __init__(self, hf_split, transform): self.data = hf_split self.transform = transform def __len__(self): return len(self.data) def __getitem__(self, idx): item = self.data[idx] img = item['image'] if img.mode != 'RGB': img = img.convert('RGB') return self.transform(img), item['label'] train_ds = HFImageDataset(ds['train'], transform) val_ds = HFImageDataset(ds['valid'], transform) train_loader = torch.utils.data.DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=2) val_loader = torch.utils.data.DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=2) return train_loader, val_loader def get_cifar10(batch_size=256): transform = T.Compose([ T.ToTensor(), T.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), ]) train_ds = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform) test_ds = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform) train_loader = torch.utils.data.DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=2) test_loader = torch.utils.data.DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=2) return train_loader, test_loader def get_imagenet_128(batch_size=128): """ImageNet-1K at 128×128 via HuggingFace: 1000 classes, 1.28M train, 50K val. Requires: pip install datasets Requires: HF auth + ImageNet terms accepted on huggingface.co """ from datasets import load_dataset ds = load_dataset('benjamin-paine/imagenet-1k-128x128') transform = T.Compose([ T.ToTensor(), T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), # ImageNet stats ]) class HFImageDataset(torch.utils.data.Dataset): def __init__(self, hf_split, transform): self.data = hf_split self.transform = transform def __len__(self): return len(self.data) def __getitem__(self, idx): item = self.data[idx] img = item['image'] if img.mode != 'RGB': img = img.convert('RGB') return self.transform(img), item['label'] train_ds = HFImageDataset(ds['train'], transform) val_ds = HFImageDataset(ds['validation'], transform) train_loader = torch.utils.data.DataLoader( train_ds, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True) val_loader = torch.utils.data.DataLoader( val_ds, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True) return train_loader, val_loader # ── Patch Utilities ────────────────────────────────────────────── def extract_patches(images, patch_size=32): """Split (B, 3, H, W) into patches of (B, n_patches, 3, ph, pw). Returns patches and grid dims (gh, gw) for reconstruction. """ B, C, H, W = images.shape ph, pw = patch_size, patch_size gh, gw = H // ph, W // pw # (B, C, gh, ph, gw, pw) → (B, gh, gw, C, ph, pw) → (B, n_patches, C*ph*pw) patches = images.reshape(B, C, gh, ph, gw, pw) patches = patches.permute(0, 2, 4, 1, 3, 5) # (B, gh, gw, C, ph, pw) patches = patches.reshape(B, gh * gw, C * ph * pw) return patches, gh, gw def stitch_patches(patches, gh, gw, patch_size=32): """Reassemble (B, n_patches, C*ph*pw) into (B, C, H, W).""" B = patches.shape[0] C = 3 ph, pw = patch_size, patch_size patches = patches.reshape(B, gh, gw, C, ph, pw) patches = patches.permute(0, 3, 1, 4, 2, 5) # (B, C, gh, ph, gw, pw) return patches.reshape(B, C, gh * ph, gw * pw) class BoundarySmooth(nn.Module): """Lightweight post-stitch boundary refinement. Two 3×3 convs with residual connection. Operates on the full stitched image. The receptive field (5×5) spans patch boundaries without reaching deep into patch interiors. ~600 params. Learns to blend seams without disrupting content. """ def __init__(self, channels=3, mid=16): super().__init__() self.net = nn.Sequential( nn.Conv2d(channels, mid, 3, padding=1), nn.GELU(), nn.Conv2d(mid, channels, 3, padding=1), ) # Init near-zero so it starts as identity nn.init.zeros_(self.net[-1].weight) nn.init.zeros_(self.net[-1].bias) def forward(self, x): return x + self.net(x) # ── Spectral Cross-Attention ──────────────────────────────────── class SpectralCrossAttention(nn.Module): """Multiplicative spectral coordination with learnable per-mode alpha. S_out = S * (1 + α_d * tanh(attention_output_d)) Each spectral mode d learns its own coordination strength α_d. Alpha is parameterized through sigmoid for bounded [0, max_alpha] range: α_d = max_alpha * sigmoid(alpha_logit_d) This lets the model discover: - Which modes need cross-patch coordination (high α) - Which modes should stay independent (low α) - The global coordination budget (sum of alphas, regularizable) The alpha vector is a diagnostic: after training, it tells you which spectral modes carry inter-patch structure. """ def __init__(self, D, n_heads=4, max_alpha=0.2, alpha_init=-2.0): super().__init__() self.n_heads = n_heads self.head_dim = D // n_heads self.max_alpha = max_alpha assert D % n_heads == 0, f"D={D} must be divisible by n_heads={n_heads}" self.qkv = nn.Linear(D, 3 * D) self.out_proj = nn.Linear(D, D) self.norm = nn.LayerNorm(D) self.scale = self.head_dim ** -0.5 # Learnable per-mode alpha: initialized conservative (sigmoid(-2) ≈ 0.12) # so α starts at ~0.024 per mode (0.2 * 0.12) self.alpha_logits = nn.Parameter(torch.full((D,), alpha_init)) @property def alpha(self): """Current per-mode alpha values, bounded [0, max_alpha].""" return self.max_alpha * torch.sigmoid(self.alpha_logits) def forward(self, S): """S: (B, n_patches, D) → coordinated S: (B, n_patches, D)""" B, N, D = S.shape S_normed = self.norm(S) qkv = self.qkv(S_normed).reshape(B, N, 3, self.n_heads, self.head_dim) qkv = qkv.permute(2, 0, 3, 1, 4) # (3, B, heads, N, head_dim) q, k, v = qkv[0], qkv[1], qkv[2] attn = (q @ k.transpose(-2, -1)) * self.scale attn = attn.softmax(dim=-1) out = (attn @ v).transpose(1, 2).reshape(B, N, D) gate = torch.tanh(self.out_proj(out)) # Per-mode multiplicative modulation with learned strength alpha = self.alpha # (D,) return S * (1.0 + alpha.unsqueeze(0).unsqueeze(0) * gate) # ── Patch SVAE ─────────────────────────────────────────────────── class PatchSVAE(nn.Module): """Patch-based SVD Autoencoder. Image → patches → per-patch encode → sphere normalize → SVD → cross-patch spectral attention → per-patch decode → stitch. Each patch has its own geometric attractor determined by (V, D). Cross-patch attention coordinates spectral magnitudes only. U and V (directional structure) remain independent per patch. Args: matrix_v: Rows per patch matrix D: Embedding dimension patch_size: Spatial patch size (default 32 → 4 patches for 64×64) hidden: Per-patch MLP hidden width depth: Number of residual blocks in encoder and decoder (default 2) n_cross_layers: Number of spectral cross-attention layers """ def __init__(self, matrix_v=256, D=24, patch_size=32, hidden=512, depth=2, n_cross_layers=2): super().__init__() self.matrix_v = matrix_v self.D = D self.patch_size = patch_size self.patch_dim = 3 * patch_size * patch_size # 3072 for 32×32 self.mat_dim = matrix_v * D self.n_cross_layers = n_cross_layers # Per-patch encoder: project in → residual blocks → project out self.enc_in = nn.Linear(self.patch_dim, hidden) self.enc_blocks = nn.ModuleList([ nn.Sequential( nn.LayerNorm(hidden), nn.Linear(hidden, hidden), nn.GELU(), nn.Linear(hidden, hidden), ) for _ in range(depth) ]) self.enc_out = nn.Linear(hidden, self.mat_dim) # Per-patch decoder: project in → residual blocks → project out self.dec_in = nn.Linear(self.mat_dim, hidden) self.dec_blocks = nn.ModuleList([ nn.Sequential( nn.LayerNorm(hidden), nn.Linear(hidden, hidden), nn.GELU(), nn.Linear(hidden, hidden), ) for _ in range(depth) ]) self.dec_out = nn.Linear(hidden, self.patch_dim) nn.init.orthogonal_(self.enc_out.weight) # Cross-patch spectral coordination self.cross_attn = nn.ModuleList([ SpectralCrossAttention(D, n_heads=min(4, D)) for _ in range(n_cross_layers) ]) # Post-stitch boundary refinement self.boundary_smooth = BoundarySmooth(channels=3, mid=16) def encode_patches(self, patches): """Encode all patches in parallel. patches: (B, n_patches, patch_dim) Returns: per-patch SVD dicts + coordinated S. """ B, N, _ = patches.shape # Flatten batch and patches for shared encoder flat = patches.reshape(B * N, -1) h = F.gelu(self.enc_in(flat)) for block in self.enc_blocks: h = h + block(h) # residual M = self.enc_out(h).reshape(B * N, self.matrix_v, self.D) M = F.normalize(M, dim=-1) # rows to S^(D-1) U, S, Vt = svd_fp64(M) # Reshape back to (B, N, ...) U = U.reshape(B, N, self.matrix_v, self.D) S = S.reshape(B, N, self.D) Vt = Vt.reshape(B, N, self.D, self.D) M = M.reshape(B, N, self.matrix_v, self.D) # Cross-patch spectral coordination S_coord = S for layer in self.cross_attn: S_coord = layer(S_coord) return { 'U': U, 'S_orig': S, 'S': S_coord, 'Vt': Vt, 'M': M, } def decode_patches(self, U, S, Vt): """Decode from coordinated SVD. U: (B, N, V, D), S: (B, N, D), Vt: (B, N, D, D) Returns: (B, N, patch_dim) """ B, N, V, D = U.shape # Reconstruct per-patch matrices from coordinated S U_flat = U.reshape(B * N, V, D) S_flat = S.reshape(B * N, D) Vt_flat = Vt.reshape(B * N, D, D) M_hat = torch.bmm(U_flat * S_flat.unsqueeze(1), Vt_flat) h = F.gelu(self.dec_in(M_hat.reshape(B * N, -1))) for block in self.dec_blocks: h = h + block(h) # residual patches = self.dec_out(h) return patches.reshape(B, N, -1) def forward(self, images): patches, gh, gw = extract_patches(images, self.patch_size) svd = self.encode_patches(patches) decoded_patches = self.decode_patches(svd['U'], svd['S'], svd['Vt']) recon = stitch_patches(decoded_patches, gh, gw, self.patch_size) recon = self.boundary_smooth(recon) return {'recon': recon, 'svd': svd, 'gh': gh, 'gw': gw} @staticmethod def effective_rank(S): """Shannon entropy effective rank. S: (*, D).""" p = S / (S.sum(-1, keepdim=True) + 1e-8) p = p.clamp(min=1e-8) return (-(p * p.log()).sum(-1)).exp() # ── Training ───────────────────────────────────────────────────── def train(epochs=200, lr=1e-3, V=256, D=24, patch_size=32, hidden=512, depth=2, target_cv=0.125, cv_weight=0.3, boost=0.5, sigma=0.15, n_cross_layers=2, dataset='tiny_imagenet', device='cuda', save_dir='/content/checkpoints', save_every=50, hf_repo='AbstractPhil/geolip-SVAE', hf_version='v11_prod', tb_dir='/content/runs'): """Train the PatchSVAE with sphere normalization + soft hand. Saves checkpoints to local + HuggingFace. Logs to TensorBoard. """ import os os.makedirs(save_dir, exist_ok=True) # ── TensorBoard ── from torch.utils.tensorboard import SummaryWriter run_name = f"patchsvae_V{V}_D{D}_p{patch_size}_h{hidden}_d{depth}" tb_path = os.path.join(tb_dir, run_name) writer = SummaryWriter(tb_path) print(f" TensorBoard: {tb_path}") # ── HuggingFace ── hf_enabled = False try: from huggingface_hub import HfApi api = HfApi() # Test auth api.whoami() hf_enabled = True hf_prefix = f"{hf_version}/checkpoints" print(f" HuggingFace: {hf_repo}/{hf_prefix}") except Exception as e: print(f" HuggingFace: disabled ({e})") def upload_to_hf(local_path, remote_name): """Upload a file to HF repo, non-blocking on failure.""" if not hf_enabled: return try: remote_path = f"{hf_prefix}/{remote_name}" api.upload_file( path_or_fileobj=local_path, path_in_repo=remote_path, repo_id=hf_repo, repo_type="model", ) print(f" ☁️ Uploaded: {hf_repo}/{remote_path}") except Exception as e: print(f" ⚠️ HF upload failed: {e}") device = torch.device(device if torch.cuda.is_available() else 'cpu') if dataset == 'imagenet_128': train_loader, test_loader = get_imagenet_128(batch_size=128) img_h = 128 n_classes = 1000 mean_t = torch.tensor([0.485, 0.456, 0.406]).reshape(1, 3, 1, 1).to(device) std_t = torch.tensor([0.229, 0.224, 0.225]).reshape(1, 3, 1, 1).to(device) elif dataset == 'tiny_imagenet': train_loader, test_loader = get_tiny_imagenet(batch_size=256) img_h = 64 n_classes = 200 mean_t = torch.tensor([0.4802, 0.4481, 0.3975]).reshape(1, 3, 1, 1).to(device) std_t = torch.tensor([0.2770, 0.2691, 0.2821]).reshape(1, 3, 1, 1).to(device) else: train_loader, test_loader = get_cifar10(batch_size=256) img_h = 32 n_classes = 10 mean_t = torch.tensor([0.4914, 0.4822, 0.4465]).reshape(1, 3, 1, 1).to(device) std_t = torch.tensor([0.2470, 0.2435, 0.2616]).reshape(1, 3, 1, 1).to(device) n_patches = (img_h // patch_size) ** 2 model = PatchSVAE(matrix_v=V, D=D, patch_size=patch_size, hidden=hidden, depth=depth, n_cross_layers=n_cross_layers).to(device) opt = torch.optim.Adam(model.parameters(), lr=lr) sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs) total_params = sum(p.numel() for p in model.parameters()) cross_params = sum(p.numel() for p in model.cross_attn.parameters()) svd_backend = f"fp64 Gram+eigh (FL={'available, N<=12' if HAS_FL else 'not available'})" print(f"Using geolip-core SVD ({svd_backend})") print(f"PatchSVAE - {n_patches} patches of {patch_size}×{patch_size}") print(f" Dataset: {dataset} ({img_h}×{img_h}, {n_classes} classes)") print(f" Per-patch: ({V}, {D}) = {V*D} elements, rows on S^{D-1}") print(f" Encoder/Decoder: hidden={hidden}, depth={depth} (residual blocks)") print(f" Cross-attention: {n_cross_layers} layers on S vectors ({cross_params:,} params)") print(f" Soft hand: boost={1+boost:.1f}x near CV={target_cv}, penalty={cv_weight} far") print(f" Total params: {total_params:,}") print(f" Checkpoints: {save_dir} (best + every {save_every} epochs)") print("=" * 95) print(f" {'ep':>3} | {'loss':>7} {'recon':>7} {'t/ep':>5} | " f"{'t_rec':>7} | " f"{'S0':>6} {'SD':>6} {'ratio':>5} {'erank':>5} | " f"{'row_cv':>7} {'prox':>5} {'rw':>5} | " f"{'S_delta':>7}") print("-" * 95) best_recon = float('inf') def save_checkpoint(path, epoch, test_mse, extra=None, upload=True): """Save model + optimizer + config. Optionally upload to HF.""" ckpt = { 'epoch': epoch, 'test_mse': test_mse, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': opt.state_dict(), 'scheduler_state_dict': sched.state_dict(), 'config': { 'V': V, 'D': D, 'patch_size': patch_size, 'hidden': hidden, 'depth': depth, 'n_cross_layers': n_cross_layers, 'target_cv': target_cv, 'cv_weight': cv_weight, 'boost': boost, 'sigma': sigma, 'dataset': dataset, 'lr': lr, }, } if extra: ckpt.update(extra) torch.save(ckpt, path) size_mb = os.path.getsize(path) / (1024 * 1024) print(f" 💾 Saved: {path} ({size_mb:.1f}MB, ep{epoch}, MSE={test_mse:.6f})") if upload: upload_to_hf(path, os.path.basename(path)) for epoch in range(1, epochs + 1): model.train() total_loss, total_recon, n = 0, 0, 0 last_cv = target_cv last_prox = 1.0 recon_w = 1.0 + boost t0 = time.time() for batch_idx, (images, labels) in enumerate(train_loader): images = images.to(device) opt.zero_grad() out = model(images) recon_loss = F.mse_loss(out['recon'], images) # CV from first patch of first image with torch.no_grad(): if batch_idx % 10 == 0: current_cv = cv_of(out['svd']['M'][0, 0]) # first image, first patch if current_cv > 0: last_cv = current_cv delta = last_cv - target_cv last_prox = math.exp(-delta**2 / (2 * sigma**2)) recon_w = 1.0 + boost * last_prox cv_pen = cv_weight * (1.0 - last_prox) cv_l = (last_cv - target_cv) ** 2 loss = recon_w * recon_loss + cv_pen * cv_l loss.backward() # Clip cross-attention gradients only — the cascade source torch.nn.utils.clip_grad_norm_(model.cross_attn.parameters(), max_norm=0.5) opt.step() total_loss += loss.item() * len(images) total_recon += recon_loss.item() * len(images) n += len(images) sched.step() epoch_time = time.time() - t0 # ── TensorBoard: train metrics (every epoch, cheap) ── writer.add_scalar('train/loss', total_loss / n, epoch) writer.add_scalar('train/recon', total_recon / n, epoch) writer.add_scalar('train/lr', sched.get_last_lr()[0], epoch) writer.add_scalar('train/proximity', last_prox, epoch) writer.add_scalar('train/recon_weight', recon_w, epoch) writer.add_scalar('train/epoch_time', epoch_time, epoch) # ── Evaluation ── if epoch % 2 == 0 or epoch <= 3: model.eval() test_recon, test_n = 0, 0 test_S_orig, test_S_coord = None, None test_erank = 0 row_cvs = [] nb = 0 with torch.no_grad(): for images, labels in test_loader: images = images.to(device) out = model(images) test_recon += F.mse_loss(out['recon'], images).item() * len(images) test_n += len(images) # Average over patches and batch S_mean = out['svd']['S'].mean(dim=(0, 1)) # (D,) S_orig_mean = out['svd']['S_orig'].mean(dim=(0, 1)) test_erank += model.effective_rank(out['svd']['S'].reshape(-1, D)).mean().item() if nb < 3: for b in range(min(2, len(images))): for p in range(min(2, out['svd']['M'].shape[1])): row_cvs.append(cv_of(out['svd']['M'][b, p])) if test_S_orig is None: test_S_orig = S_orig_mean.cpu() test_S_coord = S_mean.cpu() else: test_S_orig += S_orig_mean.cpu() test_S_coord += S_mean.cpu() nb += 1 test_erank /= nb test_S_orig /= nb test_S_coord /= nb ratio = (test_S_coord[0] / (test_S_coord[-1] + 1e-8)).item() mean_cv = sum(row_cvs) / len(row_cvs) if row_cvs else 0 # How much did cross-attention change S? s_delta = (test_S_coord - test_S_orig).abs().mean().item() # Alpha diagnostics: mean and max across all cross-attention layers with torch.no_grad(): all_alphas = torch.cat([layer.alpha for layer in model.cross_attn]) alpha_mean = all_alphas.mean().item() alpha_max = all_alphas.max().item() print(f" {epoch:3d} | {total_loss/n:7.4f} {total_recon/n:7.4f} {epoch_time:5.1f} | " f"{test_recon/test_n:7.4f} | " f"{test_S_coord[0]:6.3f} {test_S_coord[-1]:6.3f} {ratio:5.2f} " f"{test_erank:5.2f} | " f"{mean_cv:7.4f} {last_prox:5.3f} {recon_w:5.2f} | " f"{s_delta:7.5f} a:{alpha_mean:.4f}/{alpha_max:.4f}") # ── TensorBoard: eval metrics + geometry ── test_mse = test_recon / test_n writer.add_scalar('test/recon_mse', test_mse, epoch) writer.add_scalar('test/best_mse', min(best_recon, test_mse), epoch) # Geometry writer.add_scalar('geo/row_cv', mean_cv, epoch) writer.add_scalar('geo/ratio', ratio, epoch) writer.add_scalar('geo/erank', test_erank, epoch) writer.add_scalar('geo/S0', test_S_coord[0].item(), epoch) writer.add_scalar('geo/SD', test_S_coord[-1].item(), epoch) # Cross-attention writer.add_scalar('cross_attn/s_delta', s_delta, epoch) writer.add_scalar('cross_attn/alpha_mean', alpha_mean, epoch) writer.add_scalar('cross_attn/alpha_max', alpha_max, epoch) # Soft hand dynamics writer.add_scalar('soft_hand/proximity', last_prox, epoch) writer.add_scalar('soft_hand/recon_weight', recon_w, epoch) # Singular value profile (every 20 epochs — histogram is heavier) if epoch % 20 == 0 or epoch <= 3: writer.add_histogram('spectrum/S_coordinated', test_S_coord, epoch) writer.add_histogram('spectrum/S_original', test_S_orig, epoch) # Per-layer alpha for li, layer in enumerate(model.cross_attn): writer.add_histogram(f'alpha/layer_{li}', layer.alpha.detach().cpu(), epoch) # Recon images (every 50 epochs — image writes are expensive) if epoch % 50 == 0 or epoch == 1: with torch.no_grad(): sample_imgs, _ = next(iter(test_loader)) sample_imgs = sample_imgs[:8].to(device) sample_out = model(sample_imgs) sample_recon = sample_out['recon'] # Denormalize for visualization orig_vis = (sample_imgs * std_t + mean_t).clamp(0, 1) recon_vis = (sample_recon * std_t + mean_t).clamp(0, 1) comparison = torch.cat([orig_vis, recon_vis], dim=0) grid = torchvision.utils.make_grid(comparison, nrow=8, padding=2) writer.add_image('recon/comparison', grid, epoch) # ── Checkpoint saving ── geo_stats = { 'row_cv': mean_cv, 'ratio': ratio, 'erank': test_erank, 'S0': test_S_coord[0].item(), 'SD': test_S_coord[-1].item(), 's_delta': s_delta, 'alpha_mean': alpha_mean, 'alpha_max': alpha_max, } # Best model — save locally always, upload only at periodic intervals if test_mse < best_recon: best_recon = test_mse save_checkpoint( os.path.join(save_dir, 'best.pt'), epoch, test_mse, extra={'geo': geo_stats}, upload=False) # local only # Periodic save + upload best if improved if epoch % save_every == 0: save_checkpoint( os.path.join(save_dir, f'epoch_{epoch:04d}.pt'), epoch, test_mse, extra={'geo': geo_stats}) # Also push current best to HF best_path = os.path.join(save_dir, 'best.pt') if os.path.exists(best_path): upload_to_hf(best_path, 'best.pt') # Flush + upload TB logs writer.flush() if hf_enabled: try: api.upload_folder( folder_path=tb_path, path_in_repo=f"{hf_version}/tensorboard/{run_name}", repo_id=hf_repo, repo_type="model", ) print(f" ☁️ TB logs synced to {hf_repo}") except Exception as e: print(f" ⚠️ TB sync failed: {e}") # ── Final Analysis ── print() print("=" * 90) print("FINAL ANALYSIS") print("=" * 90) model.eval() all_recon_err = [] with torch.no_grad(): for images, labels in test_loader: images = images.to(device) out = model(images) all_recon_err.append( F.mse_loss(out['recon'], images, reduction='none') .mean(dim=(1, 2, 3)).cpu()) all_recon_err = torch.cat(all_recon_err) print(f"\n PatchSVAE: {n_patches} patches × ({V}, {D})") print(f" Target CV: {target_cv}") print(f" Recon MSE: {all_recon_err.mean():.6f} +/- {all_recon_err.std():.6f}") print(f" Row CV: {mean_cv:.4f}") print(f" Cross-attention S delta: {s_delta:.5f}") # Per-mode alpha profile — which modes coordinate between patches? print(f"\n Learned alpha per mode (coordination strength):") for layer_idx, layer in enumerate(model.cross_attn): alpha = layer.alpha.detach().cpu() print(f" Layer {layer_idx}: mean={alpha.mean():.4f} max={alpha.max():.4f} min={alpha.min():.4f}") bar_scale = 40 / (alpha.max().item() + 1e-8) for d in range(len(alpha)): bar = "#" * int(alpha[d].item() * bar_scale) print(f" α[{d:2d}]: {alpha[d]:.4f} {bar}") print(f"\n Coordinated singular value profile:") total_energy = (test_S_coord ** 2).sum() cumulative = 0 for i in range(len(test_S_coord)): e = (test_S_coord[i] ** 2).item() cumulative += e pct = cumulative / total_energy * 100 bar = "#" * int(test_S_coord[i].item() * 30 / (test_S_coord[0].item() + 1e-8)) print(f" S[{i:2d}]: {test_S_coord[i]:8.4f} cum={pct:5.1f}% {bar}") # ── Reconstruction Grid ── print(f"\n Saving reconstruction grid...") import matplotlib matplotlib.use('Agg') import matplotlib.pyplot as plt model.eval() with torch.no_grad(): images, labels = next(iter(test_loader)) images = images[:20].to(device) out = model(images) recon = out['recon'] def denorm(t): return (t * std_t + mean_t).clamp(0, 1).cpu() n_show = min(10, len(images)) fig, axes = plt.subplots(n_show, 3, figsize=(6, n_show * 2)) for i in range(n_show): axes[i, 0].imshow(denorm(images[i:i+1])[0].permute(1, 2, 0).numpy()) axes[i, 1].imshow(denorm(recon[i:i+1])[0].permute(1, 2, 0).numpy()) diff = (denorm(images[i:i+1]) - denorm(recon[i:i+1])).abs() * 5 axes[i, 2].imshow(diff.clamp(0, 1)[0].permute(1, 2, 0).numpy()) axes[0, 0].set_title('Original', fontsize=8) axes[0, 1].set_title('Recon', fontsize=8) axes[0, 2].set_title('|Err|×5', fontsize=8) for ax in axes.flat: ax.axis('off') plt.tight_layout() plt.savefig('/content/svae_patch_recon.png', dpi=200, bbox_inches='tight') print(f" Saved to /content/svae_patch_recon.png") try: plt.show() except: pass plt.close() # ── Final checkpoint ── save_checkpoint( os.path.join(save_dir, 'final.pt'), epochs, all_recon_err.mean().item(), extra={'geo': geo_stats} ) # ── Close TensorBoard + upload logs ── writer.close() if hf_enabled: try: api.upload_folder( folder_path=tb_path, path_in_repo=f"{hf_version}/tensorboard/{run_name}", repo_id=hf_repo, repo_type="model", ) print(f" ☁️ TB logs uploaded to {hf_repo}/{hf_version}/tensorboard/") except Exception as e: print(f" ⚠️ TB upload failed: {e}") # Upload recon grid recon_grid_path = '/content/svae_patch_recon.png' if os.path.exists(recon_grid_path): upload_to_hf(recon_grid_path, 'recon_grid.png') print(f"\n Best MSE: {best_recon:.6f}") print(f" Checkpoints: {save_dir}/") print(f" TensorBoard: {tb_path}") print(f" HuggingFace: {hf_repo}/{hf_version}/") if __name__ == "__main__": # ImageNet-1K 128×128: 64 patches of 16×16, each (256, 16) # 1000 classes, 1.28M train images # depth=4 residual blocks, hidden=768, learned alpha coordination train(epochs=200, lr=1e-4, V=256, D=16, patch_size=16, hidden=768, depth=4, target_cv=0.125, n_cross_layers=2, dataset='imagenet_128', hf_version='v12_imagenet128')