Rename 111m_proto_1024_v3_geometrically_cv_aligned.py to 5m_proto_1024_v3_geometrically_cv_aligned.py
"""
SVAE - Structural Binding Constant
==================================

Matrix (V, 24): V rows in D=24 space.
At D=24, CV ≈ 0.29154 BY CONSTRUCTION - no loss needed.

The sweep proved it:
    V=200,  D=24 → CV=0.2914
    V=1024, D=24 → CV=0.2916
    V=1992, D=24 → CV=0.2911

V is irrelevant. D determines CV.

The encoder produces a (V, 24) matrix.
The rows ARE an embedding: V tokens in D=24 space.
Their CV is ~0.29 by the dimensional law.
The SVD decomposes this embedding into its spectral structure.
The decoder reconstructs from the decomposition.

No CV loss. Monitor only. The geometry is inherent.

pip install "git+https://github.com/AbstractEyes/geolip-core.git"
"""
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as T

try:
    from geolip_core.linalg import svd as geolip_svd
    HAS_GEOLIP = True
    print("Using geolip-core SVD (Gram + eigh)")
except ImportError:
    HAS_GEOLIP = False
    print("geolip-core not found, falling back to torch.svd_lowrank")
# ── CM for monitoring (not loss) ──
def cayley_menger_vol2(points):
    B, N, D = points.shape
    gram = torch.bmm(points, points.transpose(1, 2))
    norms = torch.diagonal(gram, dim1=1, dim2=2)
    d2 = F.relu(norms.unsqueeze(2) + norms.unsqueeze(1) - 2 * gram)
    cm = torch.zeros(B, N + 1, N + 1, device=points.device, dtype=points.dtype)
    cm[:, 0, 1:] = 1.0
    cm[:, 1:, 0] = 1.0
    cm[:, 1:, 1:] = d2
    k = N - 1
    sign = (-1.0) ** (k + 1)
    fact = math.factorial(k)
    return sign * torch.linalg.det(cm.float()).to(points.dtype) / ((2 ** k) * (fact ** 2))
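# Sanity sketch (illustrative, not called anywhere in this script): the unit
# right triangle in the plane has area 0.5, so its squared 2-volume from
# cayley_menger_vol2 should come out near 0.25. Handy for checking the
# Cayley-Menger sign and scaling conventions above.
def _cm_triangle_check():
    pts = torch.tensor([[[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]])
    return cayley_menger_vol2(pts)  # expected ≈ tensor([0.25])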
def cv_of(emb, n_samples=200):
    if emb.dim() != 2 or emb.shape[0] < 5:
        return 0.0
    N, D = emb.shape
    pool = min(N, 512)
    indices = torch.stack([torch.randperm(pool, device=emb.device)[:5] for _ in range(n_samples)])
    vol2 = cayley_menger_vol2(emb[:pool][indices])
    valid = vol2 > 1e-20
    if valid.sum() < 10:
        return 0.0
    vols = vol2[valid].sqrt()
    return (vols.std() / (vols.mean() + 1e-8)).item()
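# Optional sketch (an assumption for illustration, not part of the original
# sweep or of training): draw random Gaussian (V, D) matrices and report their
# row CV at D=24 for several V, as one way to probe the docstring's claim that
# CV tracks D rather than V.
def _cv_sanity_check(Vs=(200, 1024, 1992), D=24, seed=0):
    torch.manual_seed(seed)
    for V in Vs:
        emb = torch.randn(V, D)
        print(f"V={V:5d}, D={D} -> row CV {cv_of(emb, n_samples=500):.4f}")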
BINDING_CONSTANT = 0.29154


# ── Data ──
def get_cifar10(batch_size=256):
    transform = T.Compose([
        T.ToTensor(),
        T.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
    ])
    train_ds = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    test_ds = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
    train_loader = torch.utils.data.DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=2)
    test_loader = torch.utils.data.DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=2)
    return train_loader, test_loader
# ── SVAE ──
class SVAE(nn.Module):
    def __init__(self, matrix_v=48, D=24):
        """
        matrix_v: number of rows (vocabulary size of the implicit embedding)
        D: embedding dimension = number of singular values = 24 for binding constant
        """
        super().__init__()
        self.matrix_v = matrix_v  # V: number of embedding rows
        self.D = D                # D: embedding dimension
        self.img_dim = 3 * 32 * 32
        self.mat_dim = matrix_v * D

        self.encoder = nn.Sequential(
            nn.Linear(self.img_dim, 512),
            nn.GELU(),
            nn.Linear(512, 512),
            nn.GELU(),
            nn.Linear(512, self.mat_dim),
        )
        self.decoder = nn.Sequential(
            nn.Linear(self.mat_dim, 512),
            nn.GELU(),
            nn.Linear(512, 512),
            nn.GELU(),
            nn.Linear(512, self.img_dim),
        )

    def encode(self, images):
        B = images.shape[0]
        M = self.encoder(images.reshape(B, -1)).reshape(B, self.matrix_v, self.D)
        if HAS_GEOLIP:
            U, S, Vh = geolip_svd(M)
        else:
            U, S, V = torch.svd_lowrank(M, q=self.D)
            Vh = V.transpose(1, 2)
        return {
            'U': U, 'S': S, 'Vt': Vh,
            'M': M,  # the embedding matrix: rows are V points in D=24
        }

    def decode_from_svd(self, U, S, Vt):
        B = U.shape[0]
        M_hat = torch.bmm(U * S.unsqueeze(1), Vt)
        return self.decoder(M_hat.reshape(B, -1)).reshape(B, 3, 32, 32)

    def forward(self, images):
        svd = self.encode(images)
        recon = self.decode_from_svd(svd['U'], svd['S'], svd['Vt'])
        return {'recon': recon, 'svd': svd}
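# Usage sketch (assumes a constructed `model` and a batch `images` of
# normalized CIFAR tensors; nothing here runs automatically). Keeping only the
# top-k singular modes before decoding gives the progressive reconstructions
# used for the grid at the end of train():
#
#   out = model(images)
#   U, S, Vt = out['svd']['U'], out['svd']['S'], out['svd']['Vt']
#   recon_8 = model.decode_from_svd(U[:, :, :8], S[:, :8], Vt[:, :8, :])  # top 8 modes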
def effective_rank(S):
    p = S / (S.sum(-1, keepdim=True) + 1e-8)
    p = p.clamp(min=1e-8)
    return (-(p * p.log()).sum(-1)).exp()
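# Illustration (not in the original script): effective_rank is the exponential
# of the Shannon entropy of the normalized spectrum, so a flat spectrum of 24
# equal singular values gives ~24 while a single dominant value gives ~1:
#   effective_rank(torch.ones(1, 24))           # ≈ tensor([24.])
#   effective_rank(torch.tensor([[1.0, 0.0]]))  # ≈ tensor([1.])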
# ── Training ──
def train(epochs=50, lr=1e-3, device='cuda'):
    device = torch.device(device if torch.cuda.is_available() else 'cpu')
    train_loader, test_loader = get_cifar10(batch_size=256)

    D = 24
    V = 48
    model = SVAE(matrix_v=V, D=D).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
    total_params = sum(p.numel() for p in model.parameters())

    print(f"SVAE - Structural Binding Constant")
    print(f"  Matrix: ({V}, {D}) → {V} rows in D={D} space")
    print(f"  Expected row CV ≈ {BINDING_CONSTANT} (no loss, by construction)")
    print(f"  SVD: {'geolip-core' if HAS_GEOLIP else 'torch.svd_lowrank'}")
    print(f"  Compression: {model.img_dim} → {D} ({model.img_dim // D}:1)")
    print(f"  Params: {total_params:,}")
    print("=" * 85)
    print(f"{'ep':>3} | {'loss':>7} {'recon':>7} | "
          f"{'t_recon':>7} | "
          f"{'S0':>6} {'SD':>6} {'ratio':>6} {'erank':>6} | "
          f"{'row_cv':>7} {'Δbc':>7}")
    print("-" * 85)

    for epoch in range(1, epochs + 1):
        model.train()
        total_loss, n = 0, 0
        for images, labels in train_loader:
            images = images.to(device)
            opt.zero_grad()
            out = model(images)
            loss = F.mse_loss(out['recon'], images)
            loss.backward()
            opt.step()
            total_loss += loss.item() * len(images)
            n += len(images)
        sched.step()

        if epoch % 2 == 0 or epoch <= 3:
            model.eval()
            test_recon, test_n = 0, 0
            test_S = None
            test_erank = 0
            row_cvs = []
            nb = 0
            with torch.no_grad():
                for images, labels in test_loader:
                    images = images.to(device)
                    out = model(images)
                    test_recon += F.mse_loss(out['recon'], images).item() * len(images)
                    test_n += len(images)
                    test_erank += effective_rank(out['svd']['S']).mean().item()
                    # CV of matrix rows: each M[i] is (V, D) → V points in D=24
                    # Sample a few to keep it fast
                    if nb < 5:
                        for b in range(min(4, len(images))):
                            row_cvs.append(cv_of(out['svd']['M'][b]))
                    if test_S is None:
                        test_S = out['svd']['S'].mean(0).cpu()
                    else:
                        test_S += out['svd']['S'].mean(0).cpu()
                    nb += 1
            test_erank /= nb
            test_S /= nb
            ratio = (test_S[0] / (test_S[-1] + 1e-8)).item()
            mean_row_cv = sum(row_cvs) / len(row_cvs) if row_cvs else 0
            delta_bc = abs(mean_row_cv - BINDING_CONSTANT)
            print(f"{epoch:3d} | {total_loss/n:7.4f} {total_loss/n:7.4f} | "
                  f"{test_recon/test_n:7.4f} | "
                  f"{test_S[0]:6.3f} {test_S[-1]:6.3f} {ratio:6.2f} "
                  f"{test_erank:6.2f} | "
                  f"{mean_row_cv:7.4f} {delta_bc:7.4f}")
    # ── Final Analysis ──
    print()
    print("=" * 85)
    print("FINAL ANALYSIS")
    print("=" * 85)
    model.eval()

    all_S, all_recon_err, all_labels = [], [], []
    all_row_cvs = []
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            out = model(images)
            all_S.append(out['svd']['S'].cpu())
            all_recon_err.append(
                F.mse_loss(out['recon'], images, reduction='none')
                .mean(dim=(1, 2, 3)).cpu())
            all_labels.append(labels.cpu())
            # Row CV for a sample of images
            for b in range(min(8, len(images))):
                all_row_cvs.append(cv_of(out['svd']['M'][b]))
    all_S = torch.cat(all_S)
    all_recon_err = torch.cat(all_recon_err)
    all_labels = torch.cat(all_labels)
    erank = effective_rank(all_S)
    mean_row_cv = sum(all_row_cvs) / len(all_row_cvs)

    print(f"\n  Architecture: ({V}, {D}) → {V} rows × D={D}")
    print(f"  Recon MSE: {all_recon_err.mean():.6f} ± {all_recon_err.std():.6f}")
    print(f"  Effective rank: {erank.mean():.2f} ± {erank.std():.2f}")
    print(f"\n  Row CV (matrix rows as D={D} embedding):")
    print(f"    Measured: {mean_row_cv:.4f}")
    print(f"    Target:   {BINDING_CONSTANT}")
    print(f"    Delta:    {abs(mean_row_cv - BINDING_CONSTANT):.4f}")
    print(f"    {'✓ AT BINDING CONSTANT' if abs(mean_row_cv - BINDING_CONSTANT) < 0.01 else '✗ Not at binding constant'}")

    # Spectrum profile
    S_mean = all_S.mean(0)
    total_energy = (S_mean ** 2).sum()
    print(f"\n  Singular value profile:")
    cumulative = 0
    for i in range(len(S_mean)):
        e = (S_mean[i] ** 2).item()
        cumulative += e
        pct = cumulative / total_energy * 100
        bar = "█" * int(S_mean[i].item() * 30 / (S_mean[0].item() + 1e-8))
        print(f"    S[{i:2d}]: {S_mean[i]:8.3f}  cum={pct:5.1f}%  {bar}")

    # Per-class
    cifar_names = ['plane', 'car', 'bird', 'cat', 'deer',
                   'dog', 'frog', 'horse', 'ship', 'truck']
    print(f"\n  Per-class:")
    print(f"    {'class':>6} {'recon':>8} {'erank':>6} {'S0':>7} {'SD':>7} {'ratio':>6}")
    for c in range(10):
        mask = all_labels == c
        rc = all_recon_err[mask].mean().item()
        er = erank[mask].mean().item()
        s0 = all_S[mask, 0].mean().item()
        sd = all_S[mask, -1].mean().item()
        ratio = s0 / (sd + 1e-8)
        print(f"    {cifar_names[c]:>6} {rc:8.6f} {er:6.2f} {s0:7.3f} {sd:7.3f} {ratio:6.2f}")

    # Cross-class spectral variance
    class_S_means = torch.stack([all_S[all_labels == c].mean(0) for c in range(10)])
    s_var = class_S_means.std(0)
    print(f"\n  Cross-class S variance (top 5 most discriminative):")
    _, top_idx = s_var.topk(5)
    for idx in top_idx:
        i = idx.item()
        print(f"    S[{i:2d}]: var={s_var[i]:.4f}")
    # ── Reconstruction grid ──
    print(f"\n  Saving reconstruction grid...")
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt

    mean_t = torch.tensor([0.4914, 0.4822, 0.4465]).reshape(1, 3, 1, 1).to(device)
    std_t = torch.tensor([0.2470, 0.2435, 0.2616]).reshape(1, 3, 1, 1).to(device)
    model.eval()
    with torch.no_grad():
        images, labels = next(iter(test_loader))
        images = images.to(device)
        out = model(images)

        selected_idx = []
        for c in range(10):
            class_idx = (labels == c).nonzero(as_tuple=True)[0]
            selected_idx.extend(class_idx[:2].tolist())

        orig = images[selected_idx]
        U = out['svd']['U'][selected_idx]
        S = out['svd']['S'][selected_idx]
        Vt = out['svd']['Vt'][selected_idx]

        mode_counts = [1, 4, 8, 16, D]
        mode_counts = list(dict.fromkeys([m for m in mode_counts if m <= D]))
        prog_recons = []
        for n_modes in mode_counts:
            r = model.decode_from_svd(U[:, :, :n_modes], S[:, :n_modes], Vt[:, :n_modes, :])
            prog_recons.append(r)

        def denorm(t):
            return (t * std_t + mean_t).clamp(0, 1).cpu()

        n_samples = len(selected_idx)
        n_cols = 2 + len(mode_counts)
        fig, axes = plt.subplots(n_samples, n_cols, figsize=(n_cols * 1.5, n_samples * 1.5))
        col_titles = ['Original'] + [f'{m} mode{"s" if m > 1 else ""}' for m in mode_counts] + ['|Error|×5']
        for i in range(n_samples):
            axes[i, 0].imshow(denorm(orig[i:i+1])[0].permute(1, 2, 0).numpy())
            for j, r in enumerate(prog_recons):
                axes[i, j+1].imshow(denorm(r[i:i+1])[0].permute(1, 2, 0).numpy())
            err_col = 1 + len(prog_recons)
            diff = (denorm(orig[i:i+1]) - denorm(prog_recons[-1][i:i+1])).abs() * 5
            axes[i, err_col].imshow(diff.clamp(0, 1)[0].permute(1, 2, 0).numpy())
            c = labels[selected_idx[i]].item()
            axes[i, 0].set_ylabel(cifar_names[c], fontsize=8, rotation=0, labelpad=35)
        for j, title in enumerate(col_titles):
            axes[0, j].set_title(title, fontsize=8)
        for ax in axes.flat:
            ax.axis('off')
        plt.tight_layout()
        plt.savefig('/content/svae_recon_grid.png', dpi=200, bbox_inches='tight')
        print(f"  Saved to /content/svae_recon_grid.png")
        try:
            plt.show()
        except Exception:
            pass
        plt.close()


if __name__ == "__main__":
    train()
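# Example invocation (assumption: the hard-coded output path /content/ exists,
# e.g. in a Colab-style environment):
#   python 5m_proto_1024_v3_geometrically_cv_aligned.py
# or, from a notebook/REPL, a shorter run on CPU:
#   train(epochs=10, device='cpu')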