"""
SVAE v1 — Clean SVD Autoencoder
==================================
The version that works. 0.071 MSE at 384:1 compression.

Image → encoder MLP → matrix (32×32) → real SVD → keep 8 → decode → image
Spectral concentration pushes energy into the kept 8.
No extra machinery. The SVD IS the bottleneck.

Note: the figures above describe the ``SVAE`` class defaults (32×32 matrix,
keep 8). ``train()`` builds a 64×64 matrix, and the ``__main__`` entry point
runs it with ``keep_k=16``.
"""
import math
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    import torchvision
    import torchvision.transforms as T
except ImportError:
    # Only get_cifar10() needs torchvision; keep the model importable without it.
    torchvision = None
    T = None


# Per-channel statistics shared by the input normalization and the
# de-normalization used when rendering reconstructions.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2470, 0.2435, 0.2616)
CIFAR10_CLASSES = ['plane', 'car', 'bird', 'cat', 'deer',
                   'dog', 'frog', 'horse', 'ship', 'truck']


# ── Data ──

def get_cifar10(batch_size=256):
    """Build the CIFAR-10 train and test data loaders.

    Images are converted to tensors and normalized per channel with
    ``CIFAR10_MEAN`` / ``CIFAR10_STD``. The dataset is downloaded to
    ``./data`` if it is not already there.

    Args:
        batch_size: Batch size used for both loaders.

    Returns:
        ``(train_loader, test_loader)``; only the train loader shuffles.

    Raises:
        ImportError: If torchvision is not installed.
    """
    if torchvision is None:
        raise ImportError("get_cifar10() requires torchvision to be installed")

    transform = T.Compose([
        T.ToTensor(),
        T.Normalize(CIFAR10_MEAN, CIFAR10_STD),
    ])
    train_ds = torchvision.datasets.CIFAR10(
        root='./data', train=True, download=True, transform=transform)
    test_ds = torchvision.datasets.CIFAR10(
        root='./data', train=False, download=True, transform=transform)
    train_loader = torch.utils.data.DataLoader(
        train_ds, batch_size=batch_size, shuffle=True, num_workers=2)
    test_loader = torch.utils.data.DataLoader(
        test_ds, batch_size=batch_size, shuffle=False, num_workers=2)
    return train_loader, test_loader


# ── SVAE ──

class SVAE(nn.Module):
    """Autoencoder whose bottleneck is a truncated SVD.

    The encoder MLP maps a flattened 3×32×32 image to a
    ``matrix_h × matrix_k`` matrix. That matrix is decomposed with
    ``torch.linalg.svd`` and only the leading ``keep_k`` singular triplets are
    kept. The decoder MLP maps the resulting rank-limited matrix back to an
    image.

    Args:
        matrix_h: Number of rows of the latent matrix.
        matrix_k: Number of columns of the latent matrix.
        keep_k: Number of leading singular triplets passed to the decoder.
    """

    def __init__(self, matrix_h=32, matrix_k=32, keep_k=8):
        super().__init__()
        self.matrix_h = matrix_h
        self.matrix_k = matrix_k
        self.keep_k = keep_k
        self.img_dim = 3 * 32 * 32
        self.mat_dim = matrix_h * matrix_k

        self.encoder = nn.Sequential(
            nn.Linear(self.img_dim, 512), nn.GELU(),
            nn.Linear(512, 512), nn.GELU(),
            nn.Linear(512, self.mat_dim),
        )
        self.decoder = nn.Sequential(
            nn.Linear(self.mat_dim, 512), nn.GELU(),
            nn.Linear(512, 512), nn.GELU(),
            nn.Linear(512, self.img_dim),
        )

    def encode(self, images):
        """Encode a batch of images into truncated SVD factors.

        Args:
            images: Batch whose per-sample size is ``img_dim``; it is
                flattened to ``(B, -1)`` before entering the encoder.

        Returns:
            Dict with the truncated factors ``'U'`` (B, h, k), ``'S'`` (B, k)
            and ``'Vt'`` (B, k, w), the untruncated spectrum ``'S_full'``, and
            the raw encoder matrix ``'M'`` (B, h, w).
        """
        B = images.shape[0]
        M = self.encoder(images.reshape(B, -1)).reshape(
            B, self.matrix_h, self.matrix_k)
        U, S, Vt = torch.linalg.svd(M, full_matrices=False)
        k = self.keep_k
        return {
            'U': U[:, :, :k],
            'S': S[:, :k],
            'Vt': Vt[:, :k, :],
            'S_full': S,
            'M': M,
        }

    def decode_from_svd(self, U, S, Vt):
        """Rebuild ``U · diag(S) · Vt`` and decode it to images (B, 3, 32, 32).

        Any number of modes may be passed, as long as the three factors agree
        on it; the reconstructed matrix always has the full ``h × w`` shape.
        """
        B = U.shape[0]
        # Scaling U's columns by S is equivalent to U @ diag(S).
        M_hat = torch.bmm(U * S.unsqueeze(1), Vt)
        return self.decoder(M_hat.reshape(B, -1)).reshape(B, 3, 32, 32)

    def forward(self, images):
        """Encode, truncate and decode; return ``{'recon', 'svd'}``."""
        svd = self.encode(images)
        recon = self.decode_from_svd(svd['U'], svd['S'], svd['Vt'])
        return {'recon': recon, 'svd': svd}

    @staticmethod
    def effective_rank(S):
        """Return exp(entropy) of the sum-normalized singular values, per row."""
        p = S / (S.sum(-1, keepdim=True) + 1e-8)
        p = p.clamp(min=1e-8)  # keep log() finite for zero singular values
        return (-(p * p.log()).sum(-1)).exp()

    @staticmethod
    def energy_ratio(S_kept, S_full):
        """Return the fraction of squared-spectrum energy held by ``S_kept``."""
        return (S_kept ** 2).sum(-1) / ((S_full ** 2).sum(-1) + 1e-8)

    @staticmethod
    def spectral_concentration_loss(S_full, keep_k):
        """Return the batch mean of tail energy divided by head energy.

        The head is the first ``keep_k`` singular values and the tail is the
        rest, so minimizing this pushes energy into the kept modes.
        """
        tail = (S_full[:, keep_k:] ** 2).sum(-1)
        head = (S_full[:, :keep_k] ** 2).sum(-1)
        return (tail / (head + 1e-8)).mean()


# ── Training ──

def progressive_mode_counts(keep_k, candidates=(1, 4, 8)):
    """Return the ascending mode counts shown in the reconstruction grid.

    Each candidate is clipped to ``keep_k`` and duplicates are dropped, so
    every returned count names the number of modes actually used. ``keep_k``
    itself is always the last entry.

    >>> progressive_mode_counts(16)
    [1, 4, 8, 16]
    >>> progressive_mode_counts(4)
    [1, 4]
    """
    return sorted({min(c, keep_k) for c in (*candidates, keep_k)})


def _evaluate(model, loader, device):
    """Return sample-weighted recon MSE, erank, energy and mean spectrum."""
    model.eval()
    recon_sum = erank_sum = energy_sum = 0.0
    spectrum_sum = None
    count = 0
    with torch.no_grad():
        for images, _ in loader:
            images = images.to(device)
            out = model(images)
            s_full = out['svd']['S_full']
            batch = len(images)

            # Accumulate per-sample sums so a short final batch is not
            # over-weighted in the averages.
            recon_sum += F.mse_loss(out['recon'], images).item() * batch
            erank_sum += model.effective_rank(s_full).sum().item()
            energy_sum += model.energy_ratio(out['svd']['S'], s_full).sum().item()
            batch_spectrum = s_full.sum(0).cpu()
            spectrum_sum = (batch_spectrum if spectrum_sum is None
                            else spectrum_sum + batch_spectrum)
            count += batch
    return {
        'recon': recon_sum / count,
        'erank': erank_sum / count,
        'energy': energy_sum / count,
        'S_mean': spectrum_sum / count,
    }


def _final_analysis(model, loader, device, keep_k):
    """Print the spectrum profile, per-class and per-quartile statistics."""
    print()
    print("=" * 85)
    print("FINAL ANALYSIS")
    print("=" * 85)

    model.eval()
    all_S_full, all_erank, all_energy = [], [], []
    all_recon_err, all_labels = [], []
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            out = model(images)
            s_full = out['svd']['S_full']
            all_S_full.append(s_full.cpu())
            all_erank.append(model.effective_rank(s_full).cpu())
            all_energy.append(model.energy_ratio(out['svd']['S'], s_full).cpu())
            all_recon_err.append(
                F.mse_loss(out['recon'], images, reduction='none')
                .mean(dim=(1, 2, 3)).cpu())
            all_labels.append(labels.cpu())

    all_S_full = torch.cat(all_S_full)
    all_erank = torch.cat(all_erank)
    all_energy = torch.cat(all_energy)
    all_recon_err = torch.cat(all_recon_err)
    all_labels = torch.cat(all_labels)

    print(f"\n Bottleneck: {keep_k} / {min(model.matrix_h, model.matrix_k)} singular values")
    print(f" Energy captured: {all_energy.mean():.3f} ± {all_energy.std():.3f}")
    print(f" Effective rank (full): {all_erank.mean():.2f} ± {all_erank.std():.2f}")
    print(f" Recon MSE: {all_recon_err.mean():.6f} ± {all_recon_err.std():.6f}")

    # Singular value profile
    S_mean = all_S_full.mean(0)
    total_energy = (S_mean ** 2).sum()
    top_sv = S_mean[0].item()
    print("\n Singular value profile:")
    cumulative = 0
    for i, sv in enumerate(S_mean[:32].tolist()):
        cumulative += sv ** 2
        pct = cumulative / total_energy * 100
        bar = "█" * int(sv * 30 / (top_sv + 1e-8))
        if i == keep_k - 1:
            marker = " ← k"
        elif i == keep_k:
            marker = " ┄ tail"
        else:
            marker = ""
        print(f" S[{i:2d}]: {sv:8.3f} cum_energy={pct:5.1f}% {bar}{marker}")

    # Per-class
    print("\n Per-class:")
    print(f" {'class':>6} {'erank':>6} {'energy':>6} {'recon':>8} {'S0':>7} {'S1':>7}")
    for c, name in enumerate(CIFAR10_CLASSES):
        mask = all_labels == c
        er = all_erank[mask].mean().item()
        en = all_energy[mask].mean().item()
        rc = all_recon_err[mask].mean().item()
        s0 = all_S_full[mask, 0].mean().item()
        s1 = all_S_full[mask, 1].mean().item()
        print(f" {name:>6} {er:6.2f} {en:6.3f} {rc:8.6f} {s0:7.3f} {s1:7.3f}")

    # Recon by energy quartile
    print("\n Recon quality by energy quartile:")
    q25, q50, q75 = all_energy.quantile(torch.tensor([0.25, 0.50, 0.75]))
    # The energy ratio is at most 1, so 1.1 is an open upper bound for Q4.
    buckets = [("Q1 (low energy)", 0, q25), ("Q2", q25, q50),
               ("Q3", q50, q75), ("Q4 (high energy)", q75, 1.1)]
    for label, lo, hi in buckets:
        mask = (all_energy >= lo) & (all_energy < hi)
        if mask.sum() > 0:
            recon = all_recon_err[mask].mean().item()
            er = all_erank[mask].mean().item()
            print(f" {label:>17}: n={int(mask.sum()):5d} recon={recon:.6f} erank={er:.2f}")


def _save_recon_grid(model, loader, device, save_path):
    """Render originals, progressive reconstructions and an error map to a PNG."""
    print("\n Saving reconstruction grid...")
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt

    # CIFAR-10 denormalization
    mean = torch.tensor(CIFAR10_MEAN).reshape(1, 3, 1, 1).to(device)
    std = torch.tensor(CIFAR10_STD).reshape(1, 3, 1, 1).to(device)

    def to_hwc(t):
        """Undo the normalization and return an (N, H, W, C) array in [0, 1]."""
        return (t * std + mean).clamp(0, 1).cpu().permute(0, 2, 3, 1).numpy()

    model.eval()
    with torch.no_grad():
        # Grab one batch
        images, labels = next(iter(loader))
        images = images.to(device)
        out = model(images)

        # Pick up to 2 samples per class
        selected_idx = []
        for c in range(len(CIFAR10_CLASSES)):
            class_idx = (labels == c).nonzero(as_tuple=True)[0]
            selected_idx.extend(class_idx[:2].tolist())
        if not selected_idx:
            print(" No samples available for the reconstruction grid; skipping.")
            return

        orig = images[selected_idx]
        U = out['svd']['U'][selected_idx]
        S = out['svd']['S'][selected_idx]
        Vt = out['svd']['Vt'][selected_idx]

        # Progressive reconstructions, clipped to the modes actually kept so
        # every column title matches what was decoded.
        mode_counts = progressive_mode_counts(S.shape[1])
        prog_recons = [
            model.decode_from_svd(U[:, :, :m], S[:, :m], Vt[:, :m, :])
            for m in mode_counts
        ]

    orig_imgs = to_hwc(orig)
    recon_imgs = [to_hwc(r) for r in prog_recons]
    # Error map of the full keep_k reconstruction, amplified 5×.
    err_imgs = (abs(orig_imgs - recon_imgs[-1]) * 5).clip(0, 1)

    n_samples = len(selected_idx)
    n_cols = 2 + len(mode_counts)  # original + progressives + error
    fig, axes = plt.subplots(n_samples, n_cols,
                             figsize=(n_cols * 1.5, n_samples * 1.5),
                             squeeze=False)
    col_titles = (['Original']
                  + [f'{m} mode{"s" if m > 1 else ""}' for m in mode_counts]
                  + ['|Error|×5'])

    for i in range(n_samples):
        axes[i, 0].imshow(orig_imgs[i])
        for j, imgs in enumerate(recon_imgs):
            axes[i, j + 1].imshow(imgs[i])
        axes[i, n_cols - 1].imshow(err_imgs[i])

        # Class label
        c = labels[selected_idx[i]].item()
        axes[i, 0].set_ylabel(CIFAR10_CLASSES[c], fontsize=8, rotation=0, labelpad=35)

    for j, title in enumerate(col_titles):
        axes[0, j].set_title(title, fontsize=8)

    # Hide ticks and frames only: ax.axis('off') would also hide the class
    # labels set through set_ylabel above.
    for ax in axes.flat:
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)

    plt.tight_layout()
    out_path = Path(save_path)
    try:
        out_path.parent.mkdir(parents=True, exist_ok=True)
        plt.savefig(out_path, dpi=200, bbox_inches='tight')
    except OSError as err:
        # Losing the figure must not discard a finished training run.
        print(f" Could not save reconstruction grid to {save_path}: {err}")
    else:
        print(f" Saved to {save_path}")

    try:
        plt.show()
    except Exception:
        # Displaying is best-effort; the file on disk is the real output.
        pass
    plt.close(fig)


def train(epochs=50, lr=1e-3, keep_k=8, conc_weight=0.5, device='cuda',
          save_path='/content/svae_recon_grid.png'):
    """Train an SVAE on CIFAR-10, then print an analysis and save a grid.

    The loss is ``MSE(recon, image) + conc_weight * spectral_concentration``.
    Test metrics are logged for the first three epochs and every even epoch.

    Args:
        epochs: Number of training epochs (also the cosine schedule length).
        lr: Initial Adam learning rate.
        keep_k: Number of singular triplets kept by the bottleneck.
        conc_weight: Weight of the spectral concentration loss.
        device: Requested device. A CUDA request falls back to CPU when CUDA
            is unavailable; other devices are used as given.
        save_path: Where the reconstruction grid PNG is written.
    """
    if str(device).startswith('cuda') and not torch.cuda.is_available():
        device = 'cpu'
    device = torch.device(device)

    train_loader, test_loader = get_cifar10(batch_size=256)
    model = SVAE(matrix_h=64, matrix_k=64, keep_k=keep_k).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)

    total_params = sum(p.numel() for p in model.parameters())
    max_rank = min(model.matrix_h, model.matrix_k)

    print("SVAE v1 — Clean SVD Autoencoder")
    print(f" Matrix: ({model.matrix_h}, {model.matrix_k}) → max rank {max_rank}, keep {keep_k}")
    # NOTE: this ratio counts only the k singular values; decode_from_svd
    # also receives the truncated U and Vt factors.
    print(f" Compression: {model.img_dim} → {keep_k} singular values ({model.img_dim // keep_k}:1)")
    print(f" Params: {total_params:,}")
    print(f" Device: {device}")
    print("=" * 85)
    print(f"{'ep':>3} | {'loss':>7} {'recon':>7} {'conc':>7} | "
          f"{'t_recon':>7} | "
          f"{'erank':>6} {'energy':>6} {'S0':>7} {'Sk-1':>7} {'tail':>7}")
    print("-" * 85)

    for epoch in range(1, epochs + 1):
        model.train()
        total_loss = total_recon = total_conc = 0.0
        n = 0
        for images, _ in train_loader:
            images = images.to(device)
            opt.zero_grad()
            out = model(images)
            recon_loss = F.mse_loss(out['recon'], images)
            conc_loss = model.spectral_concentration_loss(out['svd']['S_full'], keep_k)
            loss = recon_loss + conc_weight * conc_loss
            loss.backward()
            opt.step()

            batch = len(images)
            total_loss += loss.item() * batch
            total_recon += recon_loss.item() * batch
            total_conc += conc_loss.item() * batch
            n += batch
        sched.step()

        if epoch % 2 == 0 or epoch <= 3:
            stats = _evaluate(model, test_loader, device)
            s_mean = stats['S_mean']
            # Guard the spectrum indices: there is no tail value when keep_k
            # reaches the matrix rank.
            kept = min(keep_k, len(s_mean))
            tail_sv = s_mean[kept].item() if kept < len(s_mean) else 0.0
            print(f"{epoch:3d} | {total_loss/n:7.4f} {total_recon/n:7.4f} "
                  f"{total_conc/n:7.4f} | "
                  f"{stats['recon']:7.4f} | "
                  f"{stats['erank']:6.2f} {stats['energy']:6.3f} "
                  f"{s_mean[0]:7.3f} {s_mean[kept-1]:7.3f} {tail_sv:7.4f}")

    # ── Final Analysis ──
    _final_analysis(model, test_loader, device, keep_k)

    # ── Save reconstruction grid ──
    _save_recon_grid(model, test_loader, device, save_path)


if __name__ == "__main__":
    train(keep_k=16)