"""
SVAE v1 — Clean SVD Autoencoder
==================================
The version that works. Uses geolip-core's optimized SVD kernel.

Image → encoder MLP → RECTANGULAR matrix (H, k) → SVD → decode → image

Rectangular matrix means SVD returns exactly k values (for k ≤ H) — no truncation.
For k≤12, geolip-core uses the FL eigh kernel (fully compilable, zero graph breaks).
For k>12, Gram + torch.linalg.eigh (still better backward than linalg.svd).

pip install "git+https://github.com/AbstractEyes/geolip-core.git"
"""
|
|
import math
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as T
|
|
# Optional dependency: prefer geolip-core's SVD when it is installed and
# record the outcome in HAS_GEOLIP so the model can pick its SVD backend.
try:
    from geolip_core.linalg import svd as geolip_svd
except ImportError:
    # Keep the name bound so a stray reference fails on a clear None check
    # instead of a NameError.
    geolip_svd = None
    HAS_GEOLIP = False
    print("geolip-core not found, using torch.svd_lowrank fallback")
    print('Install: pip install "git+https://github.com/AbstractEyes/geolip-core.git"')
else:
    HAS_GEOLIP = True
    print("Using geolip-core SVD (Gram + eigh, compilable for k≤12)")
|
|
|
|
| |
|
|
def get_cifar10(batch_size=256, root='./data', num_workers=2):
    """Build the CIFAR-10 train and test data loaders.

    Both splits are converted to tensors and normalized per channel with the
    mean/std constants below. The dataset is downloaded into ``root`` if it
    is not already there.

    Args:
        batch_size: Number of images per batch for both loaders.
        root: Directory where the dataset is stored / downloaded.
        num_workers: Worker processes used by each ``DataLoader``.

    Returns:
        ``(train_loader, test_loader)``; only the training loader shuffles.
    """
    transform = T.Compose([
        T.ToTensor(),
        T.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
    ])

    def make_loader(is_train):
        dataset = torchvision.datasets.CIFAR10(
            root=root, train=is_train, download=True, transform=transform)
        # Shuffle only the training split so evaluation order stays fixed.
        return torch.utils.data.DataLoader(
            dataset, batch_size=batch_size, shuffle=is_train,
            num_workers=num_workers)

    train_loader = make_loader(True)
    test_loader = make_loader(False)
    return train_loader, test_loader
|
|
|
|
| |
|
|
class SVAE(nn.Module):
    """MLP autoencoder whose latent is a rectangular matrix passed through an SVD.

    The encoder maps a flattened image to a ``(matrix_h, keep_k)`` matrix,
    the matrix is factorized into ``U``, ``S``, ``Vt``, and the decoder
    reconstructs the image from the product ``U @ diag(S) @ Vt``.

    Args:
        matrix_h: Number of rows of the latent matrix.
        keep_k: Number of columns of the latent matrix.
        img_shape: Shape of one image, ``(channels, height, width)``.
        hidden_dim: Width of the hidden layers in encoder and decoder.
    """

    def __init__(self, matrix_h=64, keep_k=8, img_shape=(3, 32, 32), hidden_dim=512):
        super().__init__()
        self.matrix_h = matrix_h
        self.matrix_k = keep_k
        self.keep_k = keep_k
        self.img_shape = tuple(img_shape)
        self.img_dim = math.prod(self.img_shape)
        self.mat_dim = matrix_h * keep_k

        self.encoder = nn.Sequential(
            nn.Linear(self.img_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, self.mat_dim),
        )
        self.decoder = nn.Sequential(
            nn.Linear(self.mat_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, self.img_dim),
        )

    def encode(self, images):
        """Encode a batch of images and factorize the latent matrix.

        Returns:
            Dict with keys ``'U'``, ``'S'``, ``'Vt'`` (the SVD factors) and
            ``'M'`` (the un-factorized encoder output reshaped to
            ``(B, matrix_h, matrix_k)``).
        """
        B = images.shape[0]
        M = self.encoder(images.reshape(B, -1)).reshape(B, self.matrix_h, self.matrix_k)

        if HAS_GEOLIP:
            U, S, Vh = geolip_svd(M)
        else:
            # Never request more components than the matrix's smaller side.
            q = min(self.matrix_h, self.matrix_k)
            U, S, V = torch.svd_lowrank(M, q=q)
            # svd_lowrank returns V, not its transpose.
            Vh = V.transpose(1, 2)

        return {
            'U': U, 'S': S, 'Vt': Vh,
            'M': M,
        }

    def decode_from_svd(self, U, S, Vt):
        """Rebuild the latent matrix from SVD factors and decode it to images.

        The factors may be truncated to fewer modes, as long as the product
        ``U @ diag(S) @ Vt`` still has ``matrix_h * keep_k`` entries per sample.

        Returns:
            Tensor of shape ``(B, *img_shape)``.
        """
        B = U.shape[0]
        # Scale the columns of U by S, then multiply by Vt: U @ diag(S) @ Vt.
        M_hat = torch.bmm(U * S.unsqueeze(1), Vt)
        return self.decoder(M_hat.reshape(B, -1)).reshape(B, *self.img_shape)

    def forward(self, images):
        """Encode then decode; returns ``{'recon': images_hat, 'svd': factors}``."""
        svd = self.encode(images)
        recon = self.decode_from_svd(svd['U'], svd['S'], svd['Vt'])
        return {'recon': recon, 'svd': svd}

    @staticmethod
    def effective_rank(S):
        """Return exp(entropy) of the singular values normalized to sum to 1.

        Computed along the last dimension; the small constants keep the
        division and the logarithm finite when values are zero.
        """
        p = S / (S.sum(-1, keepdim=True) + 1e-8)
        p = p.clamp(min=1e-8)
        return (-(p * p.log()).sum(-1)).exp()
|
|
|
|
| |
|
|
_CIFAR_NAMES = ('plane', 'car', 'bird', 'cat', 'deer',
                'dog', 'frog', 'horse', 'ship', 'truck')
_CIFAR_MEAN = (0.4914, 0.4822, 0.4465)
_CIFAR_STD = (0.2470, 0.2435, 0.2616)


def _evaluate(model, loader, device):
    """Return (per-sample mean recon MSE, per-sample mean singular values) over loader."""
    model.eval()
    recon_sum, count = 0.0, 0
    s_sum = None
    with torch.no_grad():
        for images, _ in loader:
            images = images.to(device)
            out = model(images)
            batch = len(images)
            recon_sum += F.mse_loss(out['recon'], images).item() * batch
            # Sum over samples (not a mean of batch means) so a short final
            # batch is not over-weighted.
            batch_s = out['svd']['S'].sum(0)
            s_sum = batch_s if s_sum is None else s_sum + batch_s
            count += batch
    if count == 0:
        raise ValueError("evaluation loader produced no samples")
    return recon_sum / count, (s_sum / count).cpu()


def _print_final_analysis(model, loader, device, keep_k):
    """Print effective rank, recon error, singular-value profile and per-class stats."""
    model.eval()
    s_chunks, err_chunks, label_chunks = [], [], []
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            out = model(images)
            s_chunks.append(out['svd']['S'].cpu())
            err_chunks.append(
                F.mse_loss(out['recon'], images, reduction='none')
                .mean(dim=(1, 2, 3)).cpu())
            label_chunks.append(labels.cpu())

    all_S = torch.cat(s_chunks)
    all_recon_err = torch.cat(err_chunks)
    all_labels = torch.cat(label_chunks)
    erank = model.effective_rank(all_S)

    print(f"\n Bottleneck: {keep_k} singular values (truncated SVD)")
    print(f" Effective rank: {erank.mean():.2f} ± {erank.std():.2f}")
    print(f" Recon MSE: {all_recon_err.mean():.6f} ± {all_recon_err.std():.6f}")

    # Energy profile of the mean singular values (energy = S**2).
    S_mean = all_S.mean(0).tolist()
    total_energy = max(sum(s * s for s in S_mean), 1e-12)
    print("\n Singular value profile:")
    cumulative = 0.0
    for i, s in enumerate(S_mean):
        cumulative += s * s
        pct = cumulative / total_energy * 100
        # Bar length is relative to the largest (first) singular value.
        bar = "█" * int(s * 30 / (S_mean[0] + 1e-8))
        print(f" S[{i:2d}]: {s:8.3f} cum_energy={pct:5.1f}% {bar}")

    print("\n Per-class:")
    print(f" {'class':>6} {'recon':>8} {'S0':>7} {'S1':>7} {'erank':>6}")
    for c, name in enumerate(_CIFAR_NAMES):
        mask = all_labels == c
        rc = all_recon_err[mask].mean().item()
        s0 = all_S[mask, 0].mean().item()
        s1 = all_S[mask, 1].mean().item()
        er = erank[mask].mean().item()
        print(f" {name:>6} {rc:8.6f} {s0:7.3f} {s1:7.3f} {er:6.2f}")


def _save_recon_grid(model, loader, device, keep_k, save_path):
    """Save a grid of originals, progressive-mode reconstructions and error maps."""
    print("\n Saving reconstruction grid...")
    # Imported lazily so matplotlib is only required for this final step.
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt

    mean = torch.tensor(_CIFAR_MEAN, device=device).reshape(1, 3, 1, 1)
    std = torch.tensor(_CIFAR_STD, device=device).reshape(1, 3, 1, 1)

    def to_hwc(t):
        # Undo the dataset normalization and move channels last for imshow.
        return (t * std + mean).clamp(0, 1).cpu().permute(0, 2, 3, 1).numpy()

    model.eval()
    with torch.no_grad():
        images, labels = next(iter(loader))
        images = images.to(device)
        out = model(images)

        # Up to two samples of every class from the first test batch.
        selected_idx = []
        for c in range(len(_CIFAR_NAMES)):
            class_idx = (labels == c).nonzero(as_tuple=True)[0]
            selected_idx.extend(class_idx[:2].tolist())
        if not selected_idx:
            print(" No samples available for the reconstruction grid; skipping.")
            return

        orig = images[selected_idx]
        U = out['svd']['U'][selected_idx]
        S = out['svd']['S'][selected_idx]
        Vt = out['svd']['Vt'][selected_idx]

        # Clamp to the available modes BEFORE de-duplicating so every column
        # title matches what is drawn; the last entry is the full reconstruction.
        n_avail = S.shape[1]
        mode_counts = sorted({min(m, n_avail) for m in (1, 4, 8, keep_k)})
        prog_recons = [
            model.decode_from_svd(U[:, :, :m], S[:, :m], Vt[:, :m, :])
            for m in mode_counts
        ]

        orig_np = to_hwc(orig)
        recon_np = [to_hwc(r) for r in prog_recons]
        err_np = ((orig * std + mean).clamp(0, 1)
                  - (prog_recons[-1] * std + mean).clamp(0, 1)).abs().mul(5)
        err_np = err_np.clamp(0, 1).cpu().permute(0, 2, 3, 1).numpy()

    n_samples = len(selected_idx)
    n_cols = 2 + len(mode_counts)
    # squeeze=False keeps `axes` 2-D even for a single row.
    fig, axes = plt.subplots(n_samples, n_cols,
                             figsize=(n_cols * 1.5, n_samples * 1.5),
                             squeeze=False)

    col_titles = (['Original']
                  + [f'{m} mode{"s" if m > 1 else ""}' for m in mode_counts]
                  + ['|Error|×5'])
    err_col = 1 + len(prog_recons)

    for i in range(n_samples):
        axes[i, 0].imshow(orig_np[i])
        for j, recon in enumerate(recon_np):
            axes[i, j + 1].imshow(recon[i])
        axes[i, err_col].imshow(err_np[i])

        c = labels[selected_idx[i]].item()
        axes[i, 0].set_ylabel(_CIFAR_NAMES[c], fontsize=8, rotation=0, labelpad=35)

    for j, title in enumerate(col_titles):
        axes[0, j].set_title(title, fontsize=8)

    # Hide ticks and frames but NOT the whole axis: axis('off') would also
    # hide the class-name y-labels set above.
    for ax in axes.flat:
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)

    Path(save_path).parent.mkdir(parents=True, exist_ok=True)
    plt.tight_layout()
    plt.savefig(save_path, dpi=200, bbox_inches='tight')
    print(f" Saved to {save_path}")
    try:
        plt.show()
    except Exception as exc:  # displaying is best-effort; the file is already saved
        print(f" (could not display figure: {exc})")
    plt.close(fig)


def train(epochs=50, lr=1e-3, keep_k=16, device='cuda',
          save_path='/content/svae_recon_grid.png'):
    """Train an SVAE on CIFAR-10, then print an analysis and save a recon grid.

    Args:
        epochs: Number of training epochs (also the cosine schedule length).
        lr: Initial Adam learning rate.
        keep_k: Number of columns of the latent matrix (singular values kept).
            Must be at least 2 because the reports print ``S0`` and ``S1``.
        device: Requested device; falls back to CPU when CUDA is unavailable.
        save_path: Where the reconstruction grid image is written. Missing
            parent directories are created.

    Raises:
        ValueError: If ``keep_k`` is smaller than 2.
    """
    if keep_k < 2:
        # Fail before training rather than at the first evaluation print.
        raise ValueError(f"keep_k must be >= 2, got {keep_k}")

    device = torch.device(device if torch.cuda.is_available() else 'cpu')
    train_loader, test_loader = get_cifar10(batch_size=256)

    model = SVAE(matrix_h=64, keep_k=keep_k).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)

    total_params = sum(p.numel() for p in model.parameters())
    print("SVAE v1 — Clean SVD Autoencoder")
    print(f" Matrix: ({model.matrix_h}, {model.matrix_k}) — rectangular, exact {keep_k} singular values")
    print(f" SVD: {'geolip-core Gram+eigh' if HAS_GEOLIP else 'torch.svd_lowrank'}"
          f"{' (FL compilable)' if HAS_GEOLIP and keep_k <= 12 else ''}")
    print(f" Compression: {model.img_dim} → {keep_k} ({model.img_dim // keep_k}:1)")
    print(f" Params: {total_params:,}")
    print(f" Device: {device}")
    print("=" * 70)
    print(f"{'ep':>3} | {'loss':>7} {'recon':>7} | "
          f"{'t_recon':>7} | "
          f"{'S0':>7} {'S1':>7} {'Sk-1':>7}")
    print("-" * 70)

    for epoch in range(1, epochs + 1):
        model.train()
        total_loss, total_recon, n = 0.0, 0.0, 0

        for images, _ in train_loader:
            images = images.to(device)
            opt.zero_grad()
            out = model(images)

            recon_loss = F.mse_loss(out['recon'], images)
            # The objective is reconstruction only; kept as a separate name so
            # the "loss" and "recon" columns stay distinct if terms are added.
            loss = recon_loss
            loss.backward()
            opt.step()

            batch = len(images)
            total_loss += loss.item() * batch
            total_recon += recon_loss.item() * batch
            n += batch

        sched.step()

        # Evaluate on the first three epochs and every second epoch after.
        if epoch % 2 == 0 or epoch <= 3:
            test_recon, test_S = _evaluate(model, test_loader, device)
            print(f"{epoch:3d} | {total_loss/n:7.4f} {total_recon/n:7.4f} | "
                  f"{test_recon:7.4f} | "
                  f"{test_S[0]:7.3f} {test_S[1]:7.3f} {test_S[-1]:7.3f}")

    print()
    print("=" * 85)
    print("FINAL ANALYSIS")
    print("=" * 85)

    _print_final_analysis(model, test_loader, device, keep_k)
    _save_recon_grid(model, test_loader, device, keep_k, save_path)
|
|
|
|
# Script entry point: run the default experiment with a 16-column latent matrix.
if __name__ == "__main__":
    train(keep_k=16)