"""
SVAE-Patch — Patch-based SVD Autoencoder
==========================================
Image → patches → per-patch encode → sphere normalize → SVD →
cross-patch spectral attention → per-patch decode → stitch → boundary smooth.

Proven configurations:
    TinyImageNet 64×64:   16 patches of 16×16, (256,16), 0.000478 MSE, 17M params
    ImageNet-1K 128×128:  64 patches of 16×16, (256,16), 0.000206 MSE, 17M params
    ImageNet-1K 256×256: 256 patches of 16×16, (256,16), 17M params
"""

import math
import os
import time

import torch
import torch.nn as nn
import torch.nn.functional as F

# torchvision and tqdm are only needed by the data loaders and train(); the
# model, SVD and patch utilities stay importable without them.
try:
    import torchvision
    import torchvision.transforms as T
except ImportError:
    torchvision = None
    T = None

try:
    from tqdm import tqdm
except ImportError:
    tqdm = None


# ── SVD Backend ──────────────────────────────────────────────────

try:
    from geolip_core.linalg.eigh import FLEigh, _FL_MAX_N
    HAS_FL = True
except ImportError:
    HAS_FL = False


def _thin_svd_from_gram_eig(A_d, eigenvalues, V, orig_dtype):
    """Build (U, S, Vh) for ``A_d`` from an eigendecomposition of ``A_dᵀ A_d``.

    The eigenpairs are reversed along the last axis (the callers pass them in
    the order returned by their eigensolver), eigenvalues are clamped before
    the square root so tiny negative values cannot produce NaNs, and the
    results are cast back to ``orig_dtype``.
    """
    eigenvalues = eigenvalues.flip(-1)
    V = V.flip(-1)
    S = torch.sqrt(eigenvalues.clamp(min=1e-24))
    # U = A V / S; the clamp guards the division for (near-)zero singular values.
    U = torch.bmm(A_d, V) / S.unsqueeze(1).clamp(min=1e-16)
    Vh = V.transpose(-2, -1).contiguous()
    return U.to(orig_dtype), S.to(orig_dtype), Vh.to(orig_dtype)


def gram_eigh_svd_fp64(A):
    """Thin SVD via Gram + eigh in fp64.

    Args:
        A: batched matrices of shape (B, M, N).

    Returns:
        ``(U, S, Vh)`` in ``A``'s dtype with shapes (B, M, N), (B, N), (B, N, N),
        singular values in descending order.
    """
    orig_dtype = A.dtype
    # Autocast is disabled so the Gram matrix and eigh really run in fp64.
    with torch.amp.autocast('cuda', enabled=False):
        A_d = A.double()
        G = torch.bmm(A_d.transpose(1, 2), A_d)
        eigenvalues, V = torch.linalg.eigh(G)
        return _thin_svd_from_gram_eig(A_d, eigenvalues, V, orig_dtype)


def svd_fp64(A):
    """Auto-dispatch SVD with fp64 internals.

    Uses the optional ``FLEigh`` solver when it is importable, ``A`` is on CUDA
    and ``N <= _FL_MAX_N``; otherwise falls back to :func:`gram_eigh_svd_fp64`.
    On the FLEigh path the Gram matrix is handed to the solver as float32 and
    the eigenpairs are promoted back to fp64 for the U/S reconstruction.

    Args:
        A: batched matrices of shape (B, M, N).

    Returns:
        ``(U, S, Vh)`` in ``A``'s dtype.
    """
    _, _, N = A.shape
    if not (HAS_FL and N <= _FL_MAX_N and A.is_cuda):
        return gram_eigh_svd_fp64(A)

    orig_dtype = A.dtype
    with torch.amp.autocast('cuda', enabled=False):
        A_d = A.double()
        G = torch.bmm(A_d.transpose(1, 2), A_d)
        eigenvalues, V = FLEigh()(G.float())
        return _thin_svd_from_gram_eig(A_d, eigenvalues.double(), V.double(), orig_dtype)


# ── CV Monitoring ────────────────────────────────────────────────

def cayley_menger_vol2(points):
    """Squared simplex volume via Cayley-Menger determinant, fp64.

    Args:
        points: tensor of shape (B, N, D) — B simplices with N vertices each.

    Returns:
        fp64 tensor of shape (B,) with the squared volume of each simplex.
    """
    B, N, D = points.shape
    pts = points.double()
    gram = torch.bmm(pts, pts.transpose(1, 2))
    norms = torch.diagonal(gram, dim1=1, dim2=2)
    # Pairwise squared distances; relu removes tiny negatives from round-off.
    d2 = F.relu(norms.unsqueeze(2) + norms.unsqueeze(1) - 2 * gram)

    # Bordered distance matrix: first row/column of ones, zero corner.
    cm = torch.zeros(B, N + 1, N + 1, device=points.device, dtype=torch.float64)
    cm[:, 0, 1:] = 1.0
    cm[:, 1:, 0] = 1.0
    cm[:, 1:, 1:] = d2

    k = N - 1  # simplex dimension
    sign = (-1.0) ** (k + 1)
    fact = math.factorial(k)
    return sign * torch.linalg.det(cm) / ((2 ** k) * (fact ** 2))


def cv_of(emb, n_samples=200):
    """CV of pentachoron volumes for a (V, D) embedding.

    Draws ``n_samples`` random 5-point subsets from the first
    ``min(V, 512)`` rows, computes each simplex volume and returns
    std / mean of the non-degenerate ones.

    Returns:
        A Python float; 0.0 when ``emb`` is not 2-D, has fewer than 5 rows, or
        fewer than 10 sampled simplices have a usable volume.
    """
    if emb.dim() != 2 or emb.shape[0] < 5:
        return 0.0
    N, D = emb.shape
    pool = min(N, 512)
    indices = torch.stack([torch.randperm(pool, device=emb.device)[:5]
                           for _ in range(n_samples)])
    vol2 = cayley_menger_vol2(emb[:pool][indices])
    valid = vol2 > 1e-20
    if valid.sum() < 10:
        return 0.0
    vols = vol2[valid].sqrt()
    return (vols.std() / (vols.mean() + 1e-8)).item()


# ── Data ─────────────────────────────────────────────────────────

# Per-channel (mean, std) used both for input normalization and for
# de-normalizing reconstructions before visualisation.
_NORM_STATS = {
    'imagenet': ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    'tiny_imagenet': ((0.4802, 0.4481, 0.3975), (0.2770, 0.2691, 0.2821)),
    'cifar10': ((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
}


def _require_torchvision():
    """Raise a clear error when a torchvision-dependent path is used without it."""
    if torchvision is None:
        raise ImportError("torchvision is required for dataset loading and image logging")


def _normalizing_transform(stats_key):
    """ToTensor followed by Normalize with the statistics stored under ``stats_key``."""
    _require_torchvision()
    mean, std = _NORM_STATS[stats_key]
    return T.Compose([T.ToTensor(), T.Normalize(mean, std)])


def _make_loaders(train_ds, eval_ds, batch_size, num_workers, pin_memory=False):
    """Return (shuffled train loader, unshuffled eval loader) with shared settings."""
    common = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory)
    return (torch.utils.data.DataLoader(train_ds, shuffle=True, **common),
            torch.utils.data.DataLoader(eval_ds, shuffle=False, **common))


class HFImageDataset(torch.utils.data.Dataset):
    """Adapter exposing a HuggingFace split with 'image'/'label' items as (tensor, label)."""

    def __init__(self, hf_split, transform):
        self.data = hf_split
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        img = item['image']
        # Some samples are not RGB (img.mode differs); force 3 channels.
        if img.mode != 'RGB':
            img = img.convert('RGB')
        return self.transform(img), item['label']


def get_tiny_imagenet(batch_size=256):
    """Return (train_loader, val_loader) for 'zh-plus/tiny-imagenet'."""
    from datasets import load_dataset
    ds = load_dataset('zh-plus/tiny-imagenet')
    transform = _normalizing_transform('tiny_imagenet')
    return _make_loaders(HFImageDataset(ds['train'], transform),
                         HFImageDataset(ds['valid'], transform),
                         batch_size, num_workers=2)


def get_cifar10(batch_size=256):
    """Return (train_loader, test_loader) for CIFAR-10, downloaded to ./data."""
    transform = _normalizing_transform('cifar10')
    train_ds = torchvision.datasets.CIFAR10(root='./data', train=True,
                                            download=True, transform=transform)
    test_ds = torchvision.datasets.CIFAR10(root='./data', train=False,
                                           download=True, transform=transform)
    return _make_loaders(train_ds, test_ds, batch_size, num_workers=2)


def get_imagenet_128(batch_size=128):
    """Return (train_loader, val_loader) for 'benjamin-paine/imagenet-1k-128x128'."""
    from datasets import load_dataset
    ds = load_dataset('benjamin-paine/imagenet-1k-128x128')
    transform = _normalizing_transform('imagenet')
    return _make_loaders(HFImageDataset(ds['train'], transform),
                         HFImageDataset(ds['validation'], transform),
                         batch_size, num_workers=4, pin_memory=True)


def get_imagenet_256(batch_size=64):
    """Return (train_loader, val_loader) for 'benjamin-paine/imagenet-1k-256x256'."""
    from datasets import load_dataset
    ds = load_dataset('benjamin-paine/imagenet-1k-256x256')
    transform = _normalizing_transform('imagenet')
    return _make_loaders(HFImageDataset(ds['train'], transform),
                         HFImageDataset(ds['validation'], transform),
                         batch_size, num_workers=4, pin_memory=True)


# name → (loader factory, default batch size, image side, classes, _NORM_STATS key)
_DATASETS = {
    'imagenet_256': (get_imagenet_256, 64, 256, 1000, 'imagenet'),
    'imagenet_128': (get_imagenet_128, 128, 128, 1000, 'imagenet'),
    'tiny_imagenet': (get_tiny_imagenet, 256, 64, 200, 'tiny_imagenet'),
    'cifar10': (get_cifar10, 256, 32, 10, 'cifar10'),
}


# ── Patch Utilities ──────────────────────────────────────────────

def extract_patches(images, patch_size=16):
    """Split images into non-overlapping square patches.

    Args:
        images: tensor of shape (B, C, H, W); H and W must be multiples of
            ``patch_size`` (the reshape fails otherwise).
        patch_size: patch side length in pixels.

    Returns:
        ``(patches, gh, gw)`` where ``patches`` has shape
        (B, gh * gw, C * patch_size * patch_size), patches are ordered row-major
        over the grid, and ``gh``/``gw`` are the grid height/width.
    """
    B, C, H, W = images.shape
    ph, pw = patch_size, patch_size
    gh, gw = H // ph, W // pw
    patches = images.reshape(B, C, gh, ph, gw, pw)
    patches = patches.permute(0, 2, 4, 1, 3, 5)  # (B, gh, gw, C, ph, pw)
    patches = patches.reshape(B, gh * gw, C * ph * pw)
    return patches, gh, gw


def stitch_patches(patches, gh, gw, patch_size=16):
    """Inverse of :func:`extract_patches`: reassemble patches into images.

    The channel count is inferred from the number of elements, so the function
    works for any C (it yields C = 3 for the RGB patches used in this file).

    Args:
        patches: tensor whose leading dimension is the batch and which holds
            gh * gw * C * patch_size² values per sample.
        gh, gw: patch-grid height and width.
        patch_size: patch side length in pixels.

    Returns:
        Tensor of shape (B, C, gh * patch_size, gw * patch_size).
    """
    B = patches.shape[0]
    ph, pw = patch_size, patch_size
    per_channel = B * gh * gw * ph * pw
    # Empty batches carry no information about C; keep the RGB default there.
    C = patches.numel() // per_channel if per_channel else 3
    patches = patches.reshape(B, gh, gw, C, ph, pw)
    patches = patches.permute(0, 3, 1, 4, 2, 5)  # (B, C, gh, ph, gw, pw)
    return patches.reshape(B, C, gh * ph, gw * pw)


class BoundarySmooth(nn.Module):
    """Post-stitch boundary refinement, zero-initialized.

    A two-layer 3×3 conv residual branch (883 parameters with the defaults).
    The last conv starts at zero, so the module is the identity at
    initialization.
    """

    def __init__(self, channels=3, mid=16):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(channels, mid, 3, padding=1),
            nn.GELU(),
            nn.Conv2d(mid, channels, 3, padding=1),
        )
        nn.init.zeros_(self.net[-1].weight)
        nn.init.zeros_(self.net[-1].bias)

    def forward(self, x):
        return x + self.net(x)


# ── Spectral Cross-Attention ────────────────────────────────────

class SpectralCrossAttention(nn.Module):
    """Multiplicative spectral coordination with learnable per-mode alpha.

    S_out = S * (1 + α_d * tanh(attention_output_d))

    Self-attention runs across the patch axis of the singular-value vectors;
    ``α = max_alpha * sigmoid(alpha_logits)`` bounds how far each mode can be
    rescaled.

    Raises:
        ValueError: if ``D`` is not divisible by ``n_heads``.
    """

    def __init__(self, D, n_heads=4, max_alpha=0.2, alpha_init=-2.0):
        super().__init__()
        if D % n_heads != 0:
            raise ValueError(f"D={D} must be divisible by n_heads={n_heads}")
        self.n_heads = n_heads
        self.head_dim = D // n_heads
        self.max_alpha = max_alpha
        self.qkv = nn.Linear(D, 3 * D)
        self.out_proj = nn.Linear(D, D)
        self.norm = nn.LayerNorm(D)
        self.scale = self.head_dim ** -0.5
        self.alpha_logits = nn.Parameter(torch.full((D,), alpha_init))

    @property
    def alpha(self):
        """Per-mode gate strength in (0, max_alpha)."""
        return self.max_alpha * torch.sigmoid(self.alpha_logits)

    def forward(self, S):
        """Rescale singular values ``S`` of shape (B, N, D); returns the same shape."""
        B, N, D = S.shape
        S_normed = self.norm(S)
        qkv = self.qkv(S_normed).reshape(B, N, 3, self.n_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B, heads, N, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(B, N, D)
        gate = torch.tanh(self.out_proj(out))
        alpha = self.alpha
        return S * (1.0 + alpha.unsqueeze(0).unsqueeze(0) * gate)


# ── Patch SVAE ───────────────────────────────────────────────────

class PatchSVAE(nn.Module):
    """Per-patch MLP autoencoder whose latent is the SVD of a (matrix_v, D) matrix.

    Each RGB patch is encoded to a ``matrix_v × D`` matrix with unit-norm rows,
    decomposed by SVD, its singular values are coordinated across patches by
    :class:`SpectralCrossAttention` layers, and the recomposed matrix is decoded
    back to a patch. Decoded patches are stitched and passed through
    :class:`BoundarySmooth`.
    """

    def __init__(self, matrix_v=256, D=16, patch_size=16, hidden=768,
                 depth=4, n_cross_layers=2):
        super().__init__()
        self.matrix_v = matrix_v
        self.D = D
        self.patch_size = patch_size
        self.patch_dim = 3 * patch_size * patch_size
        self.mat_dim = matrix_v * D

        self.enc_in = nn.Linear(self.patch_dim, hidden)
        self.enc_blocks = nn.ModuleList([
            nn.Sequential(nn.LayerNorm(hidden), nn.Linear(hidden, hidden),
                          nn.GELU(), nn.Linear(hidden, hidden))
            for _ in range(depth)
        ])
        self.enc_out = nn.Linear(hidden, self.mat_dim)

        self.dec_in = nn.Linear(self.mat_dim, hidden)
        self.dec_blocks = nn.ModuleList([
            nn.Sequential(nn.LayerNorm(hidden), nn.Linear(hidden, hidden),
                          nn.GELU(), nn.Linear(hidden, hidden))
            for _ in range(depth)
        ])
        self.dec_out = nn.Linear(hidden, self.patch_dim)

        nn.init.orthogonal_(self.enc_out.weight)

        # Largest head count ≤ 4 that divides D. Equals min(4, D) whenever that
        # divides D, and keeps other values of D (5, 6, 7, ...) constructible.
        n_heads = next(h for h in (4, 3, 2, 1) if D % h == 0)
        self.cross_attn = nn.ModuleList([
            SpectralCrossAttention(D, n_heads=n_heads)
            for _ in range(n_cross_layers)
        ])
        self.boundary_smooth = BoundarySmooth(channels=3, mid=16)

    def encode_patches(self, patches):
        """Encode patches of shape (B, N, patch_dim) to SVD factors.

        Returns:
            dict with 'U' (B, N, matrix_v, D), 'S_orig' (B, N, D) raw singular
            values, 'S' (B, N, D) after cross-patch attention, 'Vt' (B, N, D, D)
            and 'M' (B, N, matrix_v, D), the row-normalized encoder matrix.
        """
        B, N, _ = patches.shape
        flat = patches.reshape(B * N, -1)
        h = F.gelu(self.enc_in(flat))
        for block in self.enc_blocks:
            h = h + block(h)
        M = self.enc_out(h).reshape(B * N, self.matrix_v, self.D)
        M = F.normalize(M, dim=-1)  # every row on the unit sphere

        U, S, Vt = svd_fp64(M)
        U = U.reshape(B, N, self.matrix_v, self.D)
        S = S.reshape(B, N, self.D)
        Vt = Vt.reshape(B, N, self.D, self.D)
        M = M.reshape(B, N, self.matrix_v, self.D)

        S_coord = S
        for layer in self.cross_attn:
            S_coord = layer(S_coord)
        return {'U': U, 'S_orig': S, 'S': S_coord, 'Vt': Vt, 'M': M}

    def decode_patches(self, U, S, Vt):
        """Recompose U·diag(S)·Vt per patch and decode to (B, N, patch_dim)."""
        B, N, V, D = U.shape
        U_flat = U.reshape(B * N, V, D)
        S_flat = S.reshape(B * N, D)
        Vt_flat = Vt.reshape(B * N, D, D)
        M_hat = torch.bmm(U_flat * S_flat.unsqueeze(1), Vt_flat)
        h = F.gelu(self.dec_in(M_hat.reshape(B * N, -1)))
        for block in self.dec_blocks:
            h = h + block(h)
        patches = self.dec_out(h)
        return patches.reshape(B, N, -1)

    def forward(self, images):
        """Reconstruct ``images`` (B, 3, H, W).

        Returns:
            dict with 'recon' (same shape as ``images``), 'svd' (the dict from
            :meth:`encode_patches`) and the patch grid size 'gh', 'gw'.
        """
        patches, gh, gw = extract_patches(images, self.patch_size)
        svd = self.encode_patches(patches)
        decoded_patches = self.decode_patches(svd['U'], svd['S'], svd['Vt'])
        recon = stitch_patches(decoded_patches, gh, gw, self.patch_size)
        recon = self.boundary_smooth(recon)
        return {'recon': recon, 'svd': svd, 'gh': gh, 'gw': gw}

    @staticmethod
    def effective_rank(S):
        """exp(entropy) of the sum-normalized values along the last axis of ``S``."""
        p = S / (S.sum(-1, keepdim=True) + 1e-8)
        p = p.clamp(min=1e-8)
        return (-(p * p.log()).sum(-1)).exp()


# ── Training ─────────────────────────────────────────────────────

def train(epochs=200, lr=1e-4, V=256, D=16, patch_size=16, hidden=768, depth=4,
          batch_size=None, target_cv=0.125, cv_weight=0.3, boost=0.5, sigma=0.15,
          n_cross_layers=2, dataset='tiny_imagenet', device='cuda',
          save_dir='/content/checkpoints', save_every=1,
          hf_repo='AbstractPhil/geolip-SVAE', hf_version='v11_prod',
          tb_dir='/content/runs'):
    """Train a :class:`PatchSVAE`, logging to TensorBoard and optionally HuggingFace.

    Args:
        epochs, lr: training length and Adam learning rate (cosine-annealed).
        V, D, patch_size, hidden, depth, n_cross_layers: model configuration.
        batch_size: overrides the per-dataset default when truthy.
        target_cv, cv_weight, boost, sigma: "soft hand" schedule. The row CV of
            one encoded patch is measured every 10 batches; its Gaussian
            proximity to ``target_cv`` scales the reconstruction weight
            (up to ``1 + boost``) and the CV penalty weight.
        dataset: 'imagenet_256', 'imagenet_128', 'tiny_imagenet'; any other
            value selects CIFAR-10.
        device: requested device; CPU is used when CUDA is unavailable.
        save_dir, save_every: checkpoint directory and period in epochs
            (periodic checkpoints are written on evaluation epochs only).
        hf_repo, hf_version: upload target; uploads are skipped when the
            HuggingFace login check fails.
        tb_dir: TensorBoard root directory.

    Raises:
        ImportError: if tqdm or torchvision is not installed.
    """
    if tqdm is None:
        raise ImportError("tqdm is required for train()")
    _require_torchvision()

    os.makedirs(save_dir, exist_ok=True)
    device = torch.device(device if torch.cuda.is_available() else 'cpu')

    # ── TensorBoard ──
    from torch.utils.tensorboard import SummaryWriter
    run_name = f"patchsvae_V{V}_D{D}_p{patch_size}_h{hidden}_d{depth}"
    tb_path = os.path.join(tb_dir, run_name)
    writer = SummaryWriter(tb_path)
    print(f" TensorBoard: {tb_path}")

    # ── HuggingFace ──
    hf_enabled = False
    api = None
    hf_prefix = f"{hf_version}/checkpoints"
    try:
        from huggingface_hub import HfApi
        api = HfApi()
        api.whoami()  # fails when not logged in → uploads stay disabled
        hf_enabled = True
        print(f" HuggingFace: {hf_repo}/{hf_prefix}")
    except Exception as e:
        print(f" HuggingFace: disabled ({e})")

    def upload_to_hf(local_path, remote_name):
        """Best-effort upload; failures are reported but never abort training."""
        if not hf_enabled:
            return
        try:
            remote_path = f"{hf_prefix}/{remote_name}"
            api.upload_file(path_or_fileobj=local_path, path_in_repo=remote_path,
                            repo_id=hf_repo, repo_type="model")
            print(f" ☁️ Uploaded: {hf_repo}/{remote_path}")
        except Exception as e:
            print(f" ⚠️ HF upload failed: {e}")

    def sync_tensorboard():
        """Best-effort upload of the TensorBoard run folder."""
        if not hf_enabled:
            return False
        try:
            api.upload_folder(folder_path=tb_path,
                              path_in_repo=f"{hf_version}/tensorboard/{run_name}",
                              repo_id=hf_repo, repo_type="model")
            return True
        except Exception as e:
            print(f" ⚠️ TB sync failed: {e}")
            return False

    # ── Dataset ──
    # Unknown dataset names fall back to CIFAR-10.
    make_loaders, default_bs, img_h, n_classes, stats_key = _DATASETS.get(
        dataset, _DATASETS['cifar10'])
    bs = batch_size or default_bs
    train_loader, test_loader = make_loaders(batch_size=bs)
    mean, std = _NORM_STATS[stats_key]
    mean_t = torch.tensor(mean).reshape(1, 3, 1, 1).to(device)
    std_t = torch.tensor(std).reshape(1, 3, 1, 1).to(device)

    # ── Model ──
    n_patches = (img_h // patch_size) ** 2
    model = PatchSVAE(matrix_v=V, D=D, patch_size=patch_size, hidden=hidden,
                      depth=depth, n_cross_layers=n_cross_layers).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)

    total_params = sum(p.numel() for p in model.parameters())
    cross_params = sum(p.numel() for p in model.cross_attn.parameters())
    fl_status = f"available, N<={_FL_MAX_N}" if HAS_FL else "not available"
    svd_info = f"fp64 Gram+eigh (FL={fl_status})"

    print(f"Using geolip-core SVD ({svd_info})")
    print(f"PatchSVAE - {n_patches} patches of {patch_size}×{patch_size}")
    print(f" Dataset: {dataset} ({img_h}×{img_h}, {n_classes} classes, batch={bs})")
    print(f" Per-patch: ({V}, {D}) = {V*D} elements, rows on S^{D-1}")
    print(f" Encoder/Decoder: hidden={hidden}, depth={depth} (residual blocks)")
    print(f" Cross-attention: {n_cross_layers} layers on S vectors ({cross_params:,} params)")
    print(f" Soft hand: boost={1+boost:.1f}x near CV={target_cv}, penalty={cv_weight} far")
    print(f" Total params: {total_params:,}")
    print(f" Checkpoints: {save_dir} (best + every {save_every} epochs)")
    print("=" * 100)
    print(f" {'ep':>3} | {'loss':>7} {'recon':>7} {'t/ep':>5} | "
          f"{'t_rec':>7} | "
          f"{'S0':>6} {'SD':>6} {'ratio':>5} {'erank':>5} | "
          f"{'row_cv':>7} {'prox':>5} {'rw':>5} | "
          f"{'S_delta':>7}")
    print("-" * 100)

    best_recon = float('inf')
    # Filled in by the evaluation pass; initialised so the final analysis and
    # final checkpoint still work if no evaluation ever ran (epochs < 1).
    geo_stats = {}
    test_S_coord = None

    def save_checkpoint(path, epoch, test_mse, extra=None, upload=True):
        """Write model/optimizer/scheduler state plus the run config to ``path``."""
        ckpt = {
            'epoch': epoch,
            'test_mse': test_mse,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': opt.state_dict(),
            'scheduler_state_dict': sched.state_dict(),
            'config': {
                'V': V, 'D': D, 'patch_size': patch_size,
                'hidden': hidden, 'depth': depth,
                'n_cross_layers': n_cross_layers,
                'target_cv': target_cv, 'cv_weight': cv_weight,
                'boost': boost, 'sigma': sigma,
                'dataset': dataset, 'lr': lr,
            },
        }
        if extra:
            ckpt.update(extra)
        torch.save(ckpt, path)
        size_mb = os.path.getsize(path) / (1024 * 1024)
        print(f" 💾 Saved: {path} ({size_mb:.1f}MB, ep{epoch}, MSE={test_mse:.6f})")
        if upload:
            upload_to_hf(path, os.path.basename(path))

    # ── Training Loop ──
    for epoch in range(1, epochs + 1):
        model.train()
        total_loss, total_recon, n = 0, 0, 0
        last_cv, last_prox, recon_w = target_cv, 1.0, 1.0 + boost
        t0 = time.time()

        pbar = tqdm(train_loader, desc=f"Ep {epoch}/{epochs}", leave=False,
                    bar_format='{l_bar}{bar:20}{r_bar}')
        for batch_idx, (images, _) in enumerate(pbar):
            images = images.to(device)
            opt.zero_grad()
            out = model(images)
            recon_loss = F.mse_loss(out['recon'], images)

            with torch.no_grad():
                # CV is measured on one patch of one sample, every 10 batches.
                if batch_idx % 10 == 0:
                    current_cv = cv_of(out['svd']['M'][0, 0])
                    if current_cv > 0:
                        last_cv = current_cv
                    delta = last_cv - target_cv
                    last_prox = math.exp(-delta**2 / (2 * sigma**2))
                    pbar.set_postfix_str(
                        f"loss={recon_loss.item():.4f} cv={last_cv:.3f} prox={last_prox:.2f}",
                        refresh=False)

            # last_cv is a plain float, so the CV term adds no gradient of its
            # own; the measured CV acts through the reconstruction weight.
            recon_w = 1.0 + boost * last_prox
            cv_pen = cv_weight * (1.0 - last_prox)
            cv_l = (last_cv - target_cv) ** 2
            loss = recon_w * recon_loss + cv_pen * cv_l

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.cross_attn.parameters(), max_norm=0.5)
            opt.step()

            total_loss += loss.item() * len(images)
            total_recon += recon_loss.item() * len(images)
            n += len(images)
        pbar.close()
        sched.step()
        epoch_time = time.time() - t0

        # ── TensorBoard: train ──
        writer.add_scalar('train/loss', total_loss / n, epoch)
        writer.add_scalar('train/recon', total_recon / n, epoch)
        writer.add_scalar('train/lr', sched.get_last_lr()[0], epoch)
        writer.add_scalar('train/epoch_time', epoch_time, epoch)

        # ── Evaluation ──
        if epoch % 2 == 0 or epoch <= 3:
            model.eval()
            test_recon, test_n = 0, 0
            test_S_orig, test_S_coord = None, None
            test_erank, row_cvs, nb = 0, [], 0

            with torch.no_grad():
                for images, _ in test_loader:
                    images = images.to(device)
                    out = model(images)
                    test_recon += F.mse_loss(out['recon'], images).item() * len(images)
                    test_n += len(images)

                    S_mean = out['svd']['S'].mean(dim=(0, 1))
                    S_orig_mean = out['svd']['S_orig'].mean(dim=(0, 1))
                    test_erank += model.effective_rank(
                        out['svd']['S'].reshape(-1, D)).mean().item()

                    # Row CV is sampled from the first 3 batches only
                    # (up to 2 images × 2 patches each).
                    if nb < 3:
                        for b in range(min(2, len(images))):
                            for p in range(min(2, out['svd']['M'].shape[1])):
                                row_cvs.append(cv_of(out['svd']['M'][b, p]))

                    if test_S_orig is None:
                        test_S_orig = S_orig_mean.cpu()
                        test_S_coord = S_mean.cpu()
                    else:
                        test_S_orig += S_orig_mean.cpu()
                        test_S_coord += S_mean.cpu()
                    nb += 1

            test_erank /= nb
            test_S_orig /= nb
            test_S_coord /= nb
            ratio = (test_S_coord[0] / (test_S_coord[-1] + 1e-8)).item()
            mean_cv = sum(row_cvs) / len(row_cvs) if row_cvs else 0
            s_delta = (test_S_coord - test_S_orig).abs().mean().item()

            with torch.no_grad():
                layer_alphas = [layer.alpha for layer in model.cross_attn]
                if layer_alphas:
                    all_alphas = torch.cat(layer_alphas)
                    alpha_mean = all_alphas.mean().item()
                    alpha_max = all_alphas.max().item()
                else:
                    # No cross-attention layers: nothing to summarise.
                    alpha_mean, alpha_max = 0.0, 0.0

            print(f" {epoch:3d} | {total_loss/n:7.4f} {total_recon/n:7.4f} {epoch_time:5.1f} | "
                  f"{test_recon/test_n:7.4f} | "
                  f"{test_S_coord[0]:6.3f} {test_S_coord[-1]:6.3f} {ratio:5.2f} "
                  f"{test_erank:5.2f} | "
                  f"{mean_cv:7.4f} {last_prox:5.3f} {recon_w:5.2f} | "
                  f"{s_delta:7.5f} a:{alpha_mean:.4f}/{alpha_max:.4f}")

            # ── TensorBoard: eval ──
            test_mse = test_recon / test_n
            writer.add_scalar('test/recon_mse', test_mse, epoch)
            writer.add_scalar('geo/row_cv', mean_cv, epoch)
            writer.add_scalar('geo/ratio', ratio, epoch)
            writer.add_scalar('geo/erank', test_erank, epoch)
            writer.add_scalar('geo/S0', test_S_coord[0].item(), epoch)
            writer.add_scalar('geo/SD', test_S_coord[-1].item(), epoch)
            writer.add_scalar('cross_attn/s_delta', s_delta, epoch)
            writer.add_scalar('cross_attn/alpha_mean', alpha_mean, epoch)

            if epoch % 20 == 0 or epoch <= 3:
                writer.add_histogram('spectrum/S_coordinated', test_S_coord, epoch)
                for li, layer in enumerate(model.cross_attn):
                    writer.add_histogram(f'alpha/layer_{li}',
                                         layer.alpha.detach().cpu(), epoch)

            if epoch % 50 == 0 or epoch == 1:
                with torch.no_grad():
                    sample_imgs, _ = next(iter(test_loader))
                    sample_imgs = sample_imgs[:8].to(device)
                    sample_out = model(sample_imgs)
                    orig_vis = (sample_imgs * std_t + mean_t).clamp(0, 1)
                    recon_vis = (sample_out['recon'] * std_t + mean_t).clamp(0, 1)
                    comparison = torch.cat([orig_vis, recon_vis], dim=0)
                    grid = torchvision.utils.make_grid(comparison, nrow=8, padding=2)
                    writer.add_image('recon/comparison', grid, epoch)

            # ── Checkpoints ──
            geo_stats = {
                'row_cv': mean_cv, 'ratio': ratio, 'erank': test_erank,
                'S0': test_S_coord[0].item(), 'SD': test_S_coord[-1].item(),
                's_delta': s_delta, 'alpha_mean': alpha_mean, 'alpha_max': alpha_max,
            }

            if test_mse < best_recon:
                best_recon = test_mse
                # best.pt is uploaded together with the periodic checkpoint below.
                save_checkpoint(os.path.join(save_dir, 'best.pt'), epoch, test_mse,
                                extra={'geo': geo_stats}, upload=False)

            if epoch % save_every == 0:
                save_checkpoint(os.path.join(save_dir, f'epoch_{epoch:04d}.pt'),
                                epoch, test_mse, extra={'geo': geo_stats})
                best_path = os.path.join(save_dir, 'best.pt')
                if os.path.exists(best_path):
                    upload_to_hf(best_path, 'best.pt')
                writer.flush()
                if sync_tensorboard():
                    print(f" ☁️ TB logs synced")

    # ── Final Analysis ──
    print("\n" + "=" * 90)
    print("FINAL ANALYSIS")
    print("=" * 90)

    model.eval()
    all_recon_err = []
    with torch.no_grad():
        for images, _ in test_loader:
            images = images.to(device)
            out = model(images)
            all_recon_err.append(
                F.mse_loss(out['recon'], images, reduction='none')
                .mean(dim=(1, 2, 3)).cpu())
    all_recon_err = torch.cat(all_recon_err)

    print(f"\n PatchSVAE: {n_patches} patches × ({V}, {D})")
    print(f" Target CV: {target_cv}")
    print(f" Recon MSE: {all_recon_err.mean():.6f} +/- {all_recon_err.std():.6f}")

    print(f"\n Learned alpha per mode:")
    for li, layer in enumerate(model.cross_attn):
        alpha = layer.alpha.detach().cpu()
        print(f" Layer {li}: mean={alpha.mean():.4f} max={alpha.max():.4f} min={alpha.min():.4f}")
        bar_scale = 40 / (alpha.max().item() + 1e-8)
        for d, a in enumerate(alpha):
            bar = "#" * int(a.item() * bar_scale)
            print(f" α[{d:2d}]: {a:.4f} {bar}")

    if test_S_coord is not None:
        print(f"\n Singular value profile:")
        total_energy = (test_S_coord ** 2).sum()
        cumulative = 0
        for i, s in enumerate(test_S_coord):
            cumulative += (s ** 2).item()
            pct = cumulative / total_energy * 100
            bar = "#" * int(s.item() * 30 / (test_S_coord[0].item() + 1e-8))
            print(f" S[{i:2d}]: {s:8.4f} cum={pct:5.1f}% {bar}")

    # ── Recon grid ──
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt

    with torch.no_grad():
        images, _ = next(iter(test_loader))
        images = images[:20].to(device)
        recon = model(images)['recon']

    def denorm(t):
        """Undo dataset normalization and move to CPU for plotting."""
        return (t * std_t + mean_t).clamp(0, 1).cpu()

    n_show = min(10, len(images))
    # squeeze=False keeps `axes` 2-D even when only one row is shown.
    fig, axes = plt.subplots(n_show, 3, figsize=(6, n_show * 2), squeeze=False)
    for i in range(n_show):
        original = denorm(images[i:i+1])
        rebuilt = denorm(recon[i:i+1])
        axes[i, 0].imshow(original[0].permute(1, 2, 0).numpy())
        axes[i, 1].imshow(rebuilt[0].permute(1, 2, 0).numpy())
        diff = (original - rebuilt).abs() * 5
        axes[i, 2].imshow(diff.clamp(0, 1)[0].permute(1, 2, 0).numpy())
    axes[0, 0].set_title('Original', fontsize=8)
    axes[0, 1].set_title('Recon', fontsize=8)
    axes[0, 2].set_title('|Err|×5', fontsize=8)
    for ax in axes.flat:
        ax.axis('off')
    plt.tight_layout()

    recon_grid_path = '/content/svae_patch_recon.png'
    os.makedirs(os.path.dirname(recon_grid_path), exist_ok=True)
    plt.savefig(recon_grid_path, dpi=200, bbox_inches='tight')
    print(f"\n Saved to {recon_grid_path}")
    try:
        plt.show()
    except Exception:
        # Displaying is optional; the figure is already saved to disk.
        pass
    plt.close()

    # ── Final save ──
    save_checkpoint(os.path.join(save_dir, 'final.pt'), epochs,
                    all_recon_err.mean().item(), extra={'geo': geo_stats})

    writer.close()
    sync_tensorboard()
    if os.path.exists(recon_grid_path):
        upload_to_hf(recon_grid_path, 'recon_grid.png')

    print(f"\n Best MSE: {best_recon:.6f}")
    print(f" Checkpoints: {save_dir}/")
    print(f" TensorBoard: {tb_path}")
    print(f" HuggingFace: {hf_repo}/{hf_version}/")


if __name__ == "__main__":
    torch.set_float32_matmul_precision('high')

    # ImageNet-1K 256×256: 256 patches of 16×16, each (256, 16)
    # Latent: (16, 16, 16) = 4096 values — diffusion-ready
    train(epochs=20, lr=1e-4, V=256, D=16, patch_size=16, hidden=768, depth=4,
          batch_size=64, target_cv=0.2915, n_cross_layers=2,
          dataset='imagenet_256', hf_version='v13_imagenet256', save_every=1)