"""
SVAE - V=48, D=24 (recon + CV)
================================
Matrix (48, 24): 48 rows in D=24 space.
Validated CV=0.3668 for this V,D combo (these are the ``SVAE`` class defaults).

NOTE: ``train()`` defaults to V=96, D=24 with target CV=0.2992. Pass
``matrix_v=48, target_cv=0.3668`` to train the configuration named above.

No KL. No reparameterization. Just:
  - Reconstruction loss (MSE)
  - Row CV loss (encourage toward the target CV)

geolip-core Gram+eigh SVD on 24x24 Gram matrix, falling back to
``torch.svd_lowrank`` when geolip-core is not installed.

  pip install "git+https://github.com/AbstractEyes/geolip-core.git"
"""

import math
import os
import time

import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    from geolip_core.linalg import svd as geolip_svd
    HAS_GEOLIP = True
    print("Using geolip-core SVD (Gram + eigh)")
except ImportError:
    HAS_GEOLIP = False
    print("geolip-core not found, fallback to torch.svd_lowrank")


# Per-channel normalisation statistics applied to CIFAR-10 images; shared by
# the data pipeline and by the de-normalisation used when plotting.
CIFAR_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR_STD = (0.2470, 0.2435, 0.2616)
CIFAR_NAMES = ['plane', 'car', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']


# -- CM monitoring + loss --

def cayley_menger_vol2(points):
    """Return the squared simplex volume of each point set via Cayley-Menger.

    Args:
        points: tensor of shape (B, N, D); each batch entry holds the N
            vertices of a (N-1)-simplex embedded in D dimensions.

    Returns:
        Tensor of shape (B,) with the squared volumes, in ``points.dtype``.
        The determinant itself is evaluated in float32.
    """
    B, N, D = points.shape
    gram = torch.bmm(points, points.transpose(1, 2))
    norms = torch.diagonal(gram, dim1=1, dim2=2)
    # Squared pairwise distances; relu clips tiny negatives from round-off.
    d2 = F.relu(norms.unsqueeze(2) + norms.unsqueeze(1) - 2 * gram)
    # Bordered distance matrix: first row/column are ones with a zero corner.
    cm = torch.zeros(B, N + 1, N + 1, device=points.device, dtype=points.dtype)
    cm[:, 0, 1:] = 1.0
    cm[:, 1:, 0] = 1.0
    cm[:, 1:, 1:] = d2
    k = N - 1
    sign = (-1.0) ** (k + 1)
    fact = math.factorial(k)
    return sign * torch.linalg.det(cm.float()).to(points.dtype) / ((2 ** k) * (fact ** 2))


def _cv_tensor(emb, n_samples=200):
    """Differentiable coefficient of variation of random 5-point simplex volumes.

    Draws ``n_samples`` random 5-row subsets from the first ``min(N, 512)``
    rows of ``emb``, computes each subset's simplex volume and returns
    ``std / mean`` of the non-degenerate volumes as a 0-dim tensor that stays
    attached to the autograd graph.

    Returns ``None`` when ``emb`` is not 2-D, has fewer than 5 rows, or fewer
    than 10 sampled simplices have a usable (> 1e-20) squared volume.
    """
    if emb.dim() != 2 or emb.shape[0] < 5:
        return None
    pool = min(emb.shape[0], 512)
    # One batched draw instead of n_samples separate randperm calls: the
    # argsort of i.i.d. uniforms is a uniform random permutation per row.
    indices = torch.rand(n_samples, pool, device=emb.device).argsort(dim=1)[:, :5]
    vol2 = cayley_menger_vol2(emb[:pool][indices])
    valid = vol2 > 1e-20
    if valid.sum() < 10:
        return None
    vols = vol2[valid].sqrt()
    return vols.std() / (vols.mean() + 1e-8)


def cv_of(emb, n_samples=200):
    """Return the simplex-volume CV of ``emb`` as a Python float (monitoring).

    Returns 0.0 when the CV cannot be estimated (see ``_cv_tensor``).
    """
    cv = _cv_tensor(emb, n_samples)
    return 0.0 if cv is None else cv.item()


# -- Data --

def get_cifar10(batch_size=256):
    """Build normalised CIFAR-10 train/test loaders (downloads to ./data).

    Returns:
        (train_loader, test_loader); only the train loader is shuffled.
    """
    # Imported lazily, like matplotlib in the plotting code, so the model and
    # metric helpers above/below can be imported without torchvision.
    import torchvision
    import torchvision.transforms as T

    transform = T.Compose([
        T.ToTensor(),
        T.Normalize(CIFAR_MEAN, CIFAR_STD),
    ])
    train_ds = torchvision.datasets.CIFAR10(
        root='./data', train=True, download=True, transform=transform)
    test_ds = torchvision.datasets.CIFAR10(
        root='./data', train=False, download=True, transform=transform)
    train_loader = torch.utils.data.DataLoader(
        train_ds, batch_size=batch_size, shuffle=True, num_workers=2)
    test_loader = torch.utils.data.DataLoader(
        test_ds, batch_size=batch_size, shuffle=False, num_workers=2)
    return train_loader, test_loader


# -- SVAE --

class SVAE(nn.Module):
    """MLP autoencoder whose bottleneck is a (matrix_v, D) matrix passed through an SVD.

    The encoder maps a flattened 3x32x32 image to a matrix M, M is factorised
    as U diag(S) Vt, and the decoder reconstructs the image from the
    re-assembled matrix. ``target_cv`` is the value the row-CV loss pulls the
    simplex-volume CV of each M towards.
    """

    def __init__(self, matrix_v=48, D=24, target_cv=0.3668):
        super().__init__()
        self.matrix_v = matrix_v
        self.D = D
        self.target_cv = target_cv
        self.img_dim = 3 * 32 * 32
        self.mat_dim = matrix_v * D
        self.encoder = nn.Sequential(
            nn.Linear(self.img_dim, 512), nn.GELU(),
            nn.Linear(512, 512), nn.GELU(),
            nn.Linear(512, self.mat_dim),
        )
        self.decoder = nn.Sequential(
            nn.Linear(self.mat_dim, 512), nn.GELU(),
            nn.Linear(512, 512), nn.GELU(),
            nn.Linear(512, self.img_dim),
        )
        nn.init.orthogonal_(self.encoder[-1].weight)

    def encode(self, images):
        """Encode images to a matrix M and its SVD factors.

        Returns:
            dict with keys 'U', 'S', 'Vt' (SVD factors) and 'M' (the raw
            (B, matrix_v, D) encoder output).
        """
        B = images.shape[0]
        M = self.encoder(images.reshape(B, -1)).reshape(B, self.matrix_v, self.D)
        if HAS_GEOLIP:
            U, S, Vh = geolip_svd(M)
        else:
            # svd_lowrank returns V (not V^T), so transpose to match geolip.
            U, S, V = torch.svd_lowrank(M.float(), q=self.D, niter=4)
            Vh = V.transpose(1, 2)
        return {'U': U, 'S': S, 'Vt': Vh, 'M': M}

    def decode_from_svd(self, U, S, Vt):
        """Re-assemble U diag(S) Vt and decode it to (B, 3, 32, 32) images.

        Truncated factors (fewer modes) are accepted as long as the product
        still has shape (B, matrix_v, D).
        """
        B = U.shape[0]
        M_hat = torch.bmm(U * S.unsqueeze(1), Vt)
        return self.decoder(M_hat.reshape(B, -1)).reshape(B, 3, 32, 32)

    def row_cv_loss(self, M):
        """Mean squared deviation of the per-sample row CV from ``target_cv``.

        Evaluated on the first ``min(B, 8)`` matrices of the batch. Samples
        whose CV cannot be estimated contribute zero but still count in the
        denominator.

        The CV is kept as a tensor so this term back-propagates into the
        encoder. Using the float returned by ``cv_of`` here would make the
        loss a constant with no gradient.
        """
        B = M.shape[0]
        n_sample = min(B, 8)
        loss = torch.tensor(0.0, device=M.device)
        for i in range(n_sample):
            cv = _cv_tensor(M[i])
            if cv is not None and cv > 0:
                loss = loss + (cv - self.target_cv) ** 2
        return loss / n_sample

    def forward(self, images):
        """Return {'recon', 'svd', 'cv_loss'} for a batch of images."""
        svd = self.encode(images)
        recon = self.decode_from_svd(svd['U'], svd['S'], svd['Vt'])
        cv_l = self.row_cv_loss(svd['M'])
        return {'recon': recon, 'svd': svd, 'cv_loss': cv_l}

    @staticmethod
    def effective_rank(S):
        """exp(entropy) of the normalised singular values along the last dim."""
        p = S / (S.sum(-1, keepdim=True) + 1e-8)
        p = p.clamp(min=1e-8)
        return (-(p * p.log()).sum(-1)).exp()


# -- Training --

def _evaluate(model, test_loader, device):
    """Run one pass over ``test_loader`` and return summary statistics.

    Returns a dict with the sample-weighted recon MSE ('recon'), the mean
    singular-value profile ('S', CPU tensor), the batch-averaged effective
    rank ('erank') and the mean row CV over up to 4 samples from each of the
    first 5 batches ('row_cv').
    """
    model.eval()
    recon_sum, n_images = 0.0, 0
    erank_sum, s_sum = 0.0, None
    row_cvs = []
    n_batches = 0
    with torch.no_grad():
        for images, _ in test_loader:
            images = images.to(device)
            out = model(images)
            svd = out['svd']
            recon_sum += F.mse_loss(out['recon'], images).item() * len(images)
            n_images += len(images)
            erank_sum += model.effective_rank(svd['S']).mean().item()
            if n_batches < 5:
                row_cvs.extend(cv_of(svd['M'][b])
                               for b in range(min(4, len(images))))
            batch_s = svd['S'].mean(0).cpu()
            s_sum = batch_s if s_sum is None else s_sum + batch_s
            n_batches += 1
    return {
        'recon': recon_sum / n_images,
        'S': s_sum / n_batches,
        'erank': erank_sum / n_batches,
        'row_cv': sum(row_cvs) / len(row_cvs) if row_cvs else 0,
    }


def _final_analysis(model, test_loader, device):
    """Print test-set reconstruction, spectrum and per-class statistics."""
    print()
    print("=" * 85)
    print("FINAL ANALYSIS")
    print("=" * 85)

    model.eval()
    all_S, all_recon_err, all_labels = [], [], []
    all_row_cvs = []
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            out = model(images)
            all_S.append(out['svd']['S'].cpu())
            all_recon_err.append(
                F.mse_loss(out['recon'], images, reduction='none')
                .mean(dim=(1, 2, 3)).cpu())
            all_labels.append(labels.cpu())
            for b in range(min(8, len(images))):
                all_row_cvs.append(cv_of(out['svd']['M'][b]))

    all_S = torch.cat(all_S)
    all_recon_err = torch.cat(all_recon_err)
    all_labels = torch.cat(all_labels)
    erank = model.effective_rank(all_S)
    mean_cv = sum(all_row_cvs) / len(all_row_cvs) if all_row_cvs else 0

    print(f"\n V={model.matrix_v}, D={model.D}")
    print(f" Recon MSE: {all_recon_err.mean():.6f} +/- {all_recon_err.std():.6f}")
    print(f" Effective rank: {erank.mean():.2f} +/- {erank.std():.2f}")
    print(f" Row CV: {mean_cv:.4f} (target: {model.target_cv}, delta: {abs(mean_cv - model.target_cv):.4f})")

    # Spectrum
    S_mean = all_S.mean(0)
    total_energy = (S_mean ** 2).sum().item()
    s_top = S_mean[0].item()
    print(f"\n Singular value profile:")
    cumulative = 0
    for i, s in enumerate(S_mean.tolist()):
        cumulative += s ** 2
        pct = cumulative / total_energy * 100
        bar = "#" * int(s * 30 / (s_top + 1e-8))
        print(f" S[{i:2d}]: {s:8.3f} cum={pct:5.1f}% {bar}")

    # Per-class
    print(f"\n Per-class:")
    print(f" {'cls':>6} {'recon':>8} {'erank':>6} {'S0':>7} {'SD':>7} {'ratio':>6}")
    for c, name in enumerate(CIFAR_NAMES):
        mask = all_labels == c
        rc = all_recon_err[mask].mean().item()
        er = erank[mask].mean().item()
        s0 = all_S[mask, 0].mean().item()
        sd = all_S[mask, -1].mean().item()
        r = s0 / (sd + 1e-8)
        print(f" {name:>6} {rc:8.6f} {er:6.2f} {s0:7.3f} {sd:7.3f} {r:6.2f}")


def _save_recon_grid(model, test_loader, device, grid_path):
    """Save a grid of originals vs. progressive-mode reconstructions.

    Uses up to two images per class from the first test batch. Columns are
    the original, one reconstruction per truncated mode count, and the
    absolute error (x5) of the full-mode reconstruction.
    """
    print(f"\n Saving reconstruction grid...")
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt

    mean_t = torch.tensor(CIFAR_MEAN).reshape(1, 3, 1, 1).to(device)
    std_t = torch.tensor(CIFAR_STD).reshape(1, 3, 1, 1).to(device)

    def denorm(t):
        return (t * std_t + mean_t).clamp(0, 1).cpu()

    D = model.D
    model.eval()
    with torch.no_grad():
        images, labels = next(iter(test_loader))
        images = images.to(device)
        out = model(images)

        selected_idx = []
        for c in range(10):
            class_idx = (labels == c).nonzero(as_tuple=True)[0]
            selected_idx.extend(class_idx[:2].tolist())

        orig = images[selected_idx]
        U = out['svd']['U'][selected_idx]
        S = out['svd']['S'][selected_idx]
        Vt = out['svd']['Vt'][selected_idx]

        # Drop counts above D and de-duplicate while keeping order.
        mode_counts = list(dict.fromkeys(m for m in (1, 4, 8, 16, D) if m <= D))
        prog_recons = [
            model.decode_from_svd(U[:, :, :nm], S[:, :nm], Vt[:, :nm, :])
            for nm in mode_counts
        ]

    n_samples = len(selected_idx)
    n_cols = 2 + len(mode_counts)
    # squeeze=False keeps `axes` 2-D even when only one row is selected.
    fig, axes = plt.subplots(n_samples, n_cols,
                             figsize=(n_cols * 1.5, n_samples * 1.5),
                             squeeze=False)
    col_titles = ['Original'] + [f'{m} modes' for m in mode_counts] + ['|Err|x5']
    err_col = 1 + len(prog_recons)

    for i in range(n_samples):
        axes[i, 0].imshow(denorm(orig[i:i+1])[0].permute(1, 2, 0).numpy())
        for j, r in enumerate(prog_recons):
            axes[i, j+1].imshow(denorm(r[i:i+1])[0].permute(1, 2, 0).numpy())
        diff = (denorm(orig[i:i+1]) - denorm(prog_recons[-1][i:i+1])).abs() * 5
        axes[i, err_col].imshow(diff.clamp(0, 1)[0].permute(1, 2, 0).numpy())
        c = labels[selected_idx[i]].item()
        axes[i, 0].set_ylabel(CIFAR_NAMES[c], fontsize=8, rotation=0, labelpad=35)

    if n_samples:
        for j, title in enumerate(col_titles):
            axes[0, j].set_title(title, fontsize=8)

    # Hide ticks and frames rather than calling axis('off'), which would also
    # hide the class-name y-labels set above.
    for ax in axes.flat:
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)

    plt.tight_layout()
    out_dir = os.path.dirname(grid_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    plt.savefig(grid_path, dpi=200, bbox_inches='tight')
    print(f" Saved to {grid_path}")
    try:
        plt.show()
    except Exception as exc:
        # Displaying is best-effort; the figure has already been saved.
        print(f" (could not display figure: {exc})")
    plt.close()


def train(epochs=100, lr=1e-3, cv_weight=0.1, device='cuda',
          matrix_v=96, D=24, target_cv=0.2992, batch_size=256,
          grid_path='/content/svae_recon_grid.png'):
    """Train an SVAE on CIFAR-10, then print an analysis and save a recon grid.

    Args:
        epochs: number of training epochs (also the cosine schedule length).
        lr: Adam learning rate.
        cv_weight: weight of the row-CV loss added to the MSE recon loss.
        device: requested device; falls back to CPU when CUDA is unavailable.
        matrix_v, D: shape of the bottleneck matrix.
        target_cv: CV value the row-CV loss pulls towards.
        batch_size: batch size for both data loaders.
        grid_path: where the reconstruction grid PNG is written.
    """
    device = torch.device(device if torch.cuda.is_available() else 'cpu')
    train_loader, test_loader = get_cifar10(batch_size=batch_size)

    V = matrix_v
    model = SVAE(matrix_v=V, D=D, target_cv=target_cv).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)

    total_params = sum(p.numel() for p in model.parameters())
    # Report the SVD backend actually in use, not always geolip-core.
    svd_backend = "geolip-core Gram+eigh" if HAS_GEOLIP else "torch.svd_lowrank"
    print(f"SVAE - V={V}, D={D} (Validated: CV={target_cv})")
    print(f" Matrix: ({V}, {D}) = {V*D} elements")
    print(f" SVD: {svd_backend}")
    print(f" Losses: recon + CV(w={cv_weight}, target={target_cv})")
    print(f" Params: {total_params:,}")
    print("=" * 85)
    print(f"{'ep':>3} | {'loss':>7} {'recon':>7} {'cv_l':>7} {'t/ep':>5} | "
          f"{'t_rec':>7} | "
          f"{'S0':>6} {'SD':>6} {'ratio':>5} {'erank':>5} | "
          f"{'row_cv':>7}")
    print("-" * 85)

    for epoch in range(1, epochs + 1):
        model.train()
        total_loss, total_recon, total_cv, n = 0.0, 0.0, 0.0, 0
        t0 = time.time()

        for images, _ in train_loader:
            images = images.to(device)
            opt.zero_grad()
            out = model(images)
            recon_loss = F.mse_loss(out['recon'], images)
            cv_l = out['cv_loss']
            loss = recon_loss + cv_weight * cv_l
            loss.backward()
            opt.step()

            total_loss += loss.item() * len(images)
            total_recon += recon_loss.item() * len(images)
            # Epoch mean, consistent with the loss/recon columns.
            total_cv += cv_l.item() * len(images)
            n += len(images)

        sched.step()
        epoch_time = time.time() - t0

        if epoch % 2 == 0 or epoch <= 3:
            stats = _evaluate(model, test_loader, device)
            s_first = stats['S'][0].item()
            s_last = stats['S'][-1].item()
            ratio = s_first / (s_last + 1e-8)
            print(f"{epoch:3d} | {total_loss/n:7.4f} {total_recon/n:7.4f} "
                  f"{total_cv/n:7.4f} {epoch_time:5.1f} | "
                  f"{stats['recon']:7.4f} | "
                  f"{s_first:6.2f} {s_last:6.3f} {ratio:5.2f} "
                  f"{stats['erank']:5.2f} | "
                  f"{stats['row_cv']:7.4f}")

    # -- Final Analysis --
    _final_analysis(model, test_loader, device)

    # -- Recon grid --
    _save_recon_grid(model, test_loader, device, grid_path)


if __name__ == "__main__":
    train()