| """ |
SVAE - SVD Autoencoder with Geometric Attractors
| =================================================== |
| A matrix-valued autoencoder where the latent space is a (V, D) matrix |
| decomposed by SVD. Rows are normalized to S^(D-1), making the geometric |
| structure architectural rather than loss-dependent. |
| |
| Two key mechanisms: |
| 1. Sphere normalization: F.normalize(M, dim=-1) constrains rows to unit |
| vectors on S^(D-1). This bounds the Gram matrix, eliminates training |
| instabilities, and makes the CV a structural property of (V, D). |
| 2. Soft hand: An oscillatory counterweight that boosts reconstruction |
| gradients when geometry is near target, and penalizes CV drift when |
| geometry is far from target. Provides positive momentum, not just penalty. |
| |
Architecture: Image → MLP → M ∈ ℝ^(V×D) → normalize → SVD → MLP → Recon
| |
| Repository: AbstractEyes/geolip-core |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import torchvision |
| import torchvision.transforms as T |
| import math |
| import time |
|
|
| |
|
|
try:
    # Optional compiled eigh kernel from geolip-core. _FL_MAX_N is the
    # largest matrix size the kernel supports; svd_fp64 uses it to decide
    # whether to dispatch to FLEigh or fall back to torch.linalg.eigh.
    from geolip_core.linalg.eigh import FLEigh, _FL_MAX_N
    HAS_FL = True
except ImportError:
    # geolip-core not installed: svd_fp64 always uses the torch fallback.
    HAS_FL = False
|
|
|
|
def gram_eigh_svd_fp64(A):
    """Thin SVD of a batch of tall matrices via the Gram-matrix eigh, in fp64.

    All internal arithmetic runs in float64: Gram entries scale like the
    squared singular values, so fp32's ~7 significant digits cause
    catastrophic collapses once the condition number passes ~100, while
    fp64's ~15 digits eliminate that failure mode entirely.

    Args:
        A: (B, M, N) tensor with M >= N.

    Returns:
        Tuple (U, S, Vh) of shapes (B, M, N), (B, N), (B, N, N), with
        singular values in descending order and dtype matching the input.
    """
    in_dtype = A.dtype
    # Force full precision even when CUDA autocast is active.
    with torch.amp.autocast('cuda', enabled=False):
        Ad = A.double()
        gram = torch.bmm(Ad.transpose(1, 2), Ad)
        evals, evecs = torch.linalg.eigh(gram)
        # eigh returns ascending eigenvalues; reverse both to descending.
        evals = evals.flip(-1)
        evecs = evecs.flip(-1)
        sing = evals.clamp(min=1e-24).sqrt()
        # Left singular vectors: columns of A·V rescaled by 1/sigma,
        # with the divisor clamped to avoid blow-ups at zero modes.
        left = torch.bmm(Ad, evecs) / sing.unsqueeze(1).clamp(min=1e-16)
        Vh = evecs.transpose(-2, -1).contiguous()
    return left.to(in_dtype), sing.to(in_dtype), Vh.to(in_dtype)
|
|
|
|
def svd_fp64(A):
    """Dispatching thin SVD with fp64 internals.

    Routing:
      * CUDA input with N <= _FL_MAX_N and FLEigh available: fp64 Gram,
        eigendecomposed by the compilable FL kernel (fp32 eigh, fp64 rest).
      * Otherwise (larger N, or CPU): plain fp64 Gram + torch.linalg.eigh.
    Triton kernels are bypassed entirely -- they target fp32-only hardware
    and cannot deliver the fp64 precision required here.
    """
    _, _, N = A.shape
    use_fl = HAS_FL and A.is_cuda and N <= _FL_MAX_N
    if not use_fl:
        return gram_eigh_svd_fp64(A)

    in_dtype = A.dtype
    with torch.amp.autocast('cuda', enabled=False):
        Ad = A.double()
        gram = torch.bmm(Ad.transpose(1, 2), Ad)
        # FL eigh runs in fp32; promote its outputs back to fp64 and
        # reverse to descending order.
        evals, evecs = FLEigh()(gram.float())
        evals = evals.double().flip(-1)
        evecs = evecs.double().flip(-1)
        sing = evals.clamp(min=1e-24).sqrt()
        left = torch.bmm(Ad, evecs) / sing.unsqueeze(1).clamp(min=1e-16)
        Vh = evecs.transpose(-2, -1).contiguous()
        return left.to(in_dtype), sing.to(in_dtype), Vh.to(in_dtype)
|
|
|
|
| |
|
|
def cayley_menger_vol2(points):
    """Squared simplex volumes via the Cayley-Menger determinant, in fp64.

    Args:
        points: (B, N, D) tensor -- B simplices, each N vertices in D dims.

    Returns:
        (B,) tensor of squared (N-1)-simplex volumes.
    """
    B, N, _ = points.shape
    P = points.double()
    # Pairwise squared distances from the Gram matrix; relu clips the tiny
    # negatives that floating-point cancellation can produce.
    G = torch.bmm(P, P.transpose(1, 2))
    sq_norms = torch.diagonal(G, dim1=1, dim2=2)
    dist2 = F.relu(sq_norms.unsqueeze(2) + sq_norms.unsqueeze(1) - 2 * G)
    # Bordered (N+1)x(N+1) Cayley-Menger matrix: zero corner, unit border,
    # squared distances in the lower-right block.
    cm = torch.zeros(B, N + 1, N + 1, device=points.device, dtype=torch.float64)
    cm[:, 0, 1:] = 1.0
    cm[:, 1:, 0] = 1.0
    cm[:, 1:, 1:] = dist2
    dim = N - 1  # simplex dimension
    coeff = ((-1.0) ** (dim + 1)) / ((2 ** dim) * (math.factorial(dim) ** 2))
    return coeff * torch.linalg.det(cm)
|
|
|
|
def cv_of(emb, n_samples=200):
    """Coefficient of variation of sampled pentachoron volumes.

    Draws random 5-vertex subsets of the embedding rows and measures how
    uniform their simplex volumes are: low CV means regular geometry, high
    CV means irregular.

    Args:
        emb: (V, D) embedding matrix.
        n_samples: number of random 5-tuples to evaluate.

    Returns:
        CV as a Python float, or 0.0 when there is insufficient data.
    """
    if emb.dim() != 2 or emb.shape[0] < 5:
        return 0.0
    rows, _ = emb.shape
    pool = min(rows, 512)  # cap candidate rows for speed
    picks = torch.stack(
        [torch.randperm(pool, device=emb.device)[:5] for _ in range(n_samples)])
    sq_vols = cayley_menger_vol2(emb[:pool][picks])
    keep = sq_vols > 1e-20  # discard numerically degenerate simplices
    if keep.sum() < 10:
        return 0.0
    vols = sq_vols[keep].sqrt()
    return (vols.std() / (vols.mean() + 1e-8)).item()
|
|
|
|
| |
|
|
def get_cifar10(batch_size=256):
    """CIFAR-10 loaders with standard per-channel normalization.

    Returns:
        (train_loader, test_loader, flat_img_dim, n_classes, class_names).
    """
    tfm = T.Compose([
        T.ToTensor(),
        T.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
    ])
    loaders = []
    for is_train in (True, False):
        ds = torchvision.datasets.CIFAR10(
            root='./data', train=is_train, download=True, transform=tfm)
        # Shuffle only the training split.
        loaders.append(torch.utils.data.DataLoader(
            ds, batch_size=batch_size, shuffle=is_train, num_workers=2))
    names = ['plane', 'car', 'bird', 'cat', 'deer',
             'dog', 'frog', 'horse', 'ship', 'truck']
    return loaders[0], loaders[1], 3 * 32 * 32, 10, names
|
|
|
|
def get_tiny_imagenet(batch_size=256):
    """TinyImageNet via HuggingFace: 200 classes, 64x64, 100K train / 10K val.

    Returns:
        (train_loader, val_loader, flat_img_dim, n_classes, class_names).
    """
    from datasets import load_dataset

    hf = load_dataset('zh-plus/tiny-imagenet')

    tfm = T.Compose([
        T.ToTensor(),
        T.Normalize((0.4802, 0.4481, 0.3975), (0.2770, 0.2691, 0.2821)),
    ])

    class _HFWrapper(torch.utils.data.Dataset):
        """Adapts a HuggingFace split to the torch Dataset protocol."""

        def __init__(self, split, tfm):
            self.data = split
            self.transform = tfm

        def __len__(self):
            return len(self.data)

        def __getitem__(self, idx):
            row = self.data[idx]
            img = row['image']
            # Some TinyImageNet images are grayscale; force 3 channels.
            if img.mode != 'RGB':
                img = img.convert('RGB')
            return self.transform(img), row['label']

    train_loader = torch.utils.data.DataLoader(
        _HFWrapper(hf['train'], tfm),
        batch_size=batch_size, shuffle=True, num_workers=2)
    val_loader = torch.utils.data.DataLoader(
        _HFWrapper(hf['valid'], tfm),
        batch_size=batch_size, shuffle=False, num_workers=2)

    names = [f'c{i:03d}' for i in range(200)]
    return train_loader, val_loader, 3 * 64 * 64, 200, names
|
|
|
|
| |
|
|
class SVAE(nn.Module):
    """SVD autoencoder with a sphere-normalized matrix latent space.

    The encoder maps an image to a (V, D) matrix whose rows are normalized
    onto S^(D-1). An SVD separates alignment structure (U, V) from spectral
    magnitudes (S); the decoder reconstructs from the recomposed matrix
    M_hat = U diag(S) V^T.

    Args:
        matrix_v: Number of rows V (vocabulary size / overcomplete factor).
        D: Embedding dimension (number of singular values).
        img_dim: Flattened image dimension (3*H*W).
        hidden: Hidden layer width (auto-scaled from img_dim if None).
    """

    def __init__(self, matrix_v=96, D=24, img_dim=3072, hidden=None):
        super().__init__()
        self.matrix_v = matrix_v
        self.D = D
        self.img_dim = img_dim
        self.mat_dim = matrix_v * D
        width = hidden or max(512, min(2048, img_dim // 4))

        self.encoder = nn.Sequential(
            nn.Linear(self.img_dim, width),
            nn.GELU(),
            nn.Linear(width, width),
            nn.GELU(),
            nn.Linear(width, self.mat_dim),
        )
        self.decoder = nn.Sequential(
            nn.Linear(self.mat_dim, width),
            nn.GELU(),
            nn.Linear(width, width),
            nn.GELU(),
            nn.Linear(width, self.img_dim),
        )
        # Orthogonal init on the final projection keeps the initial latent
        # matrix well-conditioned.
        nn.init.orthogonal_(self.encoder[-1].weight)

    def encode(self, images):
        """Image batch -> normalized latent matrix and its SVD factors."""
        batch = images.shape[0]
        M = self.encoder(images.reshape(batch, -1))
        M = F.normalize(M.reshape(batch, self.matrix_v, self.D), dim=-1)
        U, S, Vh = svd_fp64(M)
        return {'U': U, 'S': S, 'Vt': Vh, 'M': M}

    def decode_from_svd(self, U, S, Vt):
        """Recompose M_hat = U diag(S) V^T and decode to an image batch."""
        batch = U.shape[0]
        M_hat = torch.bmm(U * S.unsqueeze(1), Vt)
        flat = self.decoder(M_hat.reshape(batch, -1))
        # Assumes a square, 3-channel image layout.
        side = int((self.img_dim // 3) ** 0.5)
        return flat.reshape(batch, 3, side, side)

    def forward(self, images):
        svd = self.encode(images)
        return {'recon': self.decode_from_svd(svd['U'], svd['S'], svd['Vt']),
                'svd': svd}

    @staticmethod
    def effective_rank(S):
        """Shannon-entropy effective rank of a singular value spectrum."""
        p = (S / (S.sum(-1, keepdim=True) + 1e-8)).clamp(min=1e-8)
        return torch.exp(-(p * p.log()).sum(-1))
|
|
|
|
| |
|
|
def train(epochs=100, lr=1e-3, V=256, D=24, target_cv=0.125,
          cv_weight=0.3, boost=0.5, sigma=0.15,
          dataset='cifar10', device='cuda'):
    """Train the SVAE with sphere normalization + soft hand.

    Args:
        epochs: Training epochs
        lr: Learning rate for Adam
        V: Matrix rows (vocabulary size)
        D: Embedding dimension
        target_cv: CV attractor target for soft hand
        cv_weight: Maximum CV penalty weight (far from target)
        boost: Maximum reconstruction boost factor (near target)
        sigma: Gaussian transition width for proximity
        dataset: 'cifar10' or 'tiny_imagenet'
        device: Training device

    Side effects:
        Prints progress tables and a final analysis to stdout and saves a
        reconstruction grid PNG to /content/ (a Colab-specific path).

    NOTE(review): the CV "penalty" term (cv_pen * cv_l) is built from
    Python floats computed under torch.no_grad(), so it is a constant
    w.r.t. model parameters and contributes no gradient -- only the
    recon_w scaling of the reconstruction loss steers the optimizer.
    Confirm this is the intended "soft hand" semantics.
    """
    # Fall back to CPU when CUDA is unavailable, regardless of the request.
    device = torch.device(device if torch.cuda.is_available() else 'cpu')

    if dataset == 'tiny_imagenet':
        train_loader, test_loader, img_dim, n_classes, class_names = get_tiny_imagenet(batch_size=256)
    else:
        train_loader, test_loader, img_dim, n_classes, class_names = get_cifar10(batch_size=256)

    # Side length of the (assumed square, 3-channel) images.
    img_h = int((img_dim // 3) ** 0.5)

    model = SVAE(matrix_v=V, D=D, img_dim=img_dim).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)

    total_params = sum(p.numel() for p in model.parameters())

    # Header: configuration summary and column legend for the epoch table.
    svd_backend = f"fp64 Gram+eigh (FL={'available, N<=12' if HAS_FL else 'not available'})"
    print(f"Using geolip-core SVD ({svd_backend})")
    print(f"SVAE - V={V}, D={D}, rows on S^{D-1} + soft hand")
    print(f" Dataset: {dataset} ({img_h}x{img_h}, {n_classes} classes)")
    print(f" Matrix: ({V}, {D}) = {V*D} elements, rows normalized")
    print(f" SVD: fp64 Gram+eigh")
    print(f" Sphere: rows on S^{D-1} (structural geometry)")
    print(f" Soft hand: boost={1+boost:.1f}x near CV={target_cv}, penalty={cv_weight} far")
    print(f" Params: {total_params:,}")
    print("=" * 90)
    print(f" {'ep':>3} | {'loss':>7} {'recon':>7} {'t/ep':>5} | "
          f"{'t_rec':>7} | "
          f"{'S0':>6} {'SD':>6} {'ratio':>5} {'erank':>5} | "
          f"{'row_cv':>7} {'prox':>5} {'rw':>5}")
    print("-" * 90)

    for epoch in range(1, epochs + 1):
        model.train()
        total_loss, total_recon, n = 0, 0, 0
        # Soft-hand state persists across batches between the periodic
        # (every 10th batch) CV measurements below.
        last_cv = target_cv
        last_prox = 1.0
        recon_w = 1.0 + boost
        t0 = time.time()

        for batch_idx, (images, labels) in enumerate(train_loader):
            images = images.to(device)
            opt.zero_grad()
            out = model(images)
            recon_loss = F.mse_loss(out['recon'], images)

            # Sample the geometric CV of the first latent matrix every 10
            # batches; proximity is a Gaussian in (CV - target) with width
            # sigma. cv_of returns 0.0 on failure, which keeps last_cv.
            with torch.no_grad():
                if batch_idx % 10 == 0:
                    current_cv = cv_of(out['svd']['M'][0])
                    if current_cv > 0:
                        last_cv = current_cv
                    delta = last_cv - target_cv
                    last_prox = math.exp(-delta**2 / (2 * sigma**2))

            # Near target: boost reconstruction gradients (up to 1+boost).
            # Far from target: weight the CV penalty term instead.
            recon_w = 1.0 + boost * last_prox
            cv_pen = cv_weight * (1.0 - last_prox)
            cv_l = (last_cv - target_cv) ** 2  # plain float -- carries no gradient

            loss = recon_w * recon_loss + cv_pen * cv_l
            loss.backward()
            opt.step()

            # Accumulate sample-weighted running sums for the epoch means.
            total_loss += loss.item() * len(images)
            total_recon += recon_loss.item() * len(images)
            n += len(images)

        sched.step()
        epoch_time = time.time() - t0

        # Periodic evaluation: test reconstruction, singular-value stats,
        # effective rank, and row CVs on a few test samples.
        if epoch % 2 == 0 or epoch <= 3:
            model.eval()
            test_recon, test_n = 0, 0
            test_S, test_erank = None, 0
            row_cvs = []
            nb = 0

            with torch.no_grad():
                for images, labels in test_loader:
                    images = images.to(device)
                    out = model(images)
                    test_recon += F.mse_loss(out['recon'], images).item() * len(images)
                    test_n += len(images)
                    test_erank += model.effective_rank(out['svd']['S']).mean().item()
                    # Row CVs are expensive; only sample the first 5 batches.
                    if nb < 5:
                        for b in range(min(4, len(images))):
                            row_cvs.append(cv_of(out['svd']['M'][b]))
                    if test_S is None:
                        test_S = out['svd']['S'].mean(0).cpu()
                    else:
                        test_S += out['svd']['S'].mean(0).cpu()
                    nb += 1

            test_erank /= nb
            test_S /= nb
            # Spectral spread: largest over smallest mean singular value.
            ratio = (test_S[0] / (test_S[-1] + 1e-8)).item()
            mean_cv = sum(row_cvs) / len(row_cvs) if row_cvs else 0

            print(f" {epoch:3d} | {total_loss/n:7.4f} {total_recon/n:7.4f} {epoch_time:5.1f} | "
                  f"{test_recon/test_n:7.4f} | "
                  f"{test_S[0]:6.3f} {test_S[-1]:6.3f} {ratio:5.2f} "
                  f"{test_erank:5.2f} | "
                  f"{mean_cv:7.4f} {last_prox:5.3f} {recon_w:5.2f}")

    print()
    print("=" * 85)
    print("FINAL ANALYSIS")
    print("=" * 85)

    # Full test-set pass: collect spectra, per-sample recon error, labels.
    model.eval()
    all_S, all_recon_err, all_labels = [], [], []
    all_row_cvs = []

    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            out = model(images)
            all_S.append(out['svd']['S'].cpu())
            all_recon_err.append(
                F.mse_loss(out['recon'], images, reduction='none')
                .mean(dim=(1, 2, 3)).cpu())
            all_labels.append(labels.cpu())
            for b in range(min(8, len(images))):
                all_row_cvs.append(cv_of(out['svd']['M'][b]))

    all_S = torch.cat(all_S)
    all_recon_err = torch.cat(all_recon_err)
    all_labels = torch.cat(all_labels)
    erank = model.effective_rank(all_S)
    mean_cv = sum(all_row_cvs) / len(all_row_cvs)

    print(f"\n V={V}, D={D}, rows on S^{D-1}")
    print(f" Target CV: {target_cv}")
    print(f" Recon MSE: {all_recon_err.mean():.6f} +/- {all_recon_err.std():.6f}")
    print(f" Effective rank: {erank.mean():.2f} +/- {erank.std():.2f}")
    print(f" Row CV: {mean_cv:.4f}")

    # Singular value profile: per-mode energy and cumulative percentage,
    # with an ASCII bar scaled relative to the largest singular value.
    S_mean = all_S.mean(0)
    total_energy = (S_mean ** 2).sum()
    print(f"\n Singular value profile:")
    cumulative = 0
    for i in range(len(S_mean)):
        e = (S_mean[i] ** 2).item()
        cumulative += e
        pct = cumulative / total_energy * 100
        bar = "#" * int(S_mean[i].item() * 30 / (S_mean[0].item() + 1e-8))
        print(f" S[{i:2d}]: {S_mean[i]:8.4f} cum={pct:5.1f}% {bar}")

    # Per-class breakdown (capped at 20 classes for readability).
    show_classes = min(n_classes, 20)
    print(f"\n Per-class (showing {show_classes}/{n_classes}):")
    print(f" {'cls':>8} {'recon':>8} {'erank':>6} {'S0':>7} {'SD':>7} {'ratio':>6}")
    for c in range(show_classes):
        mask = all_labels == c
        if mask.sum() == 0:
            continue
        rc = all_recon_err[mask].mean().item()
        er = erank[mask].mean().item()
        s0 = all_S[mask, 0].mean().item()
        sd = all_S[mask, -1].mean().item()
        r = s0 / (sd + 1e-8)
        name = class_names[c] if c < len(class_names) else f'cls_{c}'
        print(f" {name:>8} {rc:8.6f} {er:6.2f} {s0:7.4f} {sd:7.4f} {r:6.2f}")

    # Reconstruction grid: original vs. progressive-mode reconstructions.
    print(f"\n Saving reconstruction grid...")
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt

    # De-normalization statistics must match the dataset's transform above.
    if dataset == 'tiny_imagenet':
        mean_t = torch.tensor([0.4802, 0.4481, 0.3975]).reshape(1, 3, 1, 1).to(device)
        std_t = torch.tensor([0.2770, 0.2691, 0.2821]).reshape(1, 3, 1, 1).to(device)
    else:
        mean_t = torch.tensor([0.4914, 0.4822, 0.4465]).reshape(1, 3, 1, 1).to(device)
        std_t = torch.tensor([0.2470, 0.2435, 0.2616]).reshape(1, 3, 1, 1).to(device)

    grid_classes = min(n_classes, 10)
    model.eval()
    with torch.no_grad():
        images, labels = next(iter(test_loader))
        images = images.to(device)
        out = model(images)

        # Up to 2 samples per class from the first test batch.
        selected_idx = []
        for c in range(grid_classes):
            class_idx = (labels == c).nonzero(as_tuple=True)[0]
            selected_idx.extend(class_idx[:2].tolist())

        if not selected_idx:
            selected_idx = list(range(min(20, len(images))))

        orig = images[selected_idx]
        U = out['svd']['U'][selected_idx]
        S = out['svd']['S'][selected_idx]
        Vt = out['svd']['Vt'][selected_idx]

        # Progressive reconstructions from truncated SVDs (1..D modes).
        mode_counts = [1, 4, 8, 16, D]
        prog_recons = []
        for nm in mode_counts:
            r = model.decode_from_svd(U[:, :, :nm], S[:, :nm], Vt[:, :nm, :])
            prog_recons.append(r)

        def denorm(t):
            # Undo the dataset normalization and clamp into display range.
            return (t * std_t + mean_t).clamp(0, 1).cpu()

        n_samples = len(selected_idx)
        n_cols = 2 + len(mode_counts)
        fig, axes = plt.subplots(n_samples, n_cols, figsize=(n_cols * 1.5, n_samples * 1.5))
        col_titles = ['Original'] + [f'{m} modes' for m in mode_counts] + ['|Err|x5']

        for i in range(n_samples):
            axes[i, 0].imshow(denorm(orig[i:i+1])[0].permute(1, 2, 0).numpy())
            for j, r in enumerate(prog_recons):
                axes[i, j+1].imshow(denorm(r[i:i+1])[0].permute(1, 2, 0).numpy())
            err_col = 1 + len(prog_recons)
            # Error map amplified 5x for visibility.
            diff = (denorm(orig[i:i+1]) - denorm(prog_recons[-1][i:i+1])).abs() * 5
            axes[i, err_col].imshow(diff.clamp(0, 1)[0].permute(1, 2, 0).numpy())
            c = labels[selected_idx[i]].item()
            name = class_names[c] if c < len(class_names) else f'{c}'
            axes[i, 0].set_ylabel(name, fontsize=7, rotation=0, labelpad=40)

        for j, title in enumerate(col_titles):
            axes[0, j].set_title(title, fontsize=8)
        for ax in axes.flat:
            ax.axis('off')

        plt.tight_layout()
        plt.savefig('/content/svae_recon_grid.png', dpi=200, bbox_inches='tight')
        print(f" Saved to /content/svae_recon_grid.png")
        try:
            plt.show()
        # Best-effort: plt.show() can fail under the headless Agg backend.
        except:
            pass
        plt.close()
|
|
|
|
| if __name__ == "__main__": |
| train(epochs=200, V=256, D=24, target_cv=0.45, dataset='tiny_imagenet') |