# geolip-SVAE / 5m_svae_proto_1024_v3_kl_divergence.py
# Committed by AbstractPhil: "Create 5m_svae_proto_1024_v3_kl_divergence.py (#1)"
# Commit: b6c4e22 (15.1 kB)
"""
Known fault: the KL-divergence term destabilises (corrupts) the SVD / eigh decomposition.
SVAE - V=1024, D=24 (Validated Binding Constant)
==================================================
V=1024, D=24 -> CV=0.2916 (from sweep, confirmed)
Deep encoder/decoder for 1024x24 = 24,576 matrix.
Light KL on spectral shape (don't constrain magnitude).
Row CV should be ~0.29 by dimensional law.
pip install "git+https://github.com/AbstractEyes/geolip-core.git"
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as T
import math
# Optional accelerated SVD backend. SVAE.encode calls geolip_svd when
# HAS_GEOLIP is True (falling back to safe_gram_svd if it raises), and
# safe_gram_svd otherwise.
try:
    from geolip_core.linalg import svd as geolip_svd
except ImportError:
    HAS_GEOLIP = False
    # The fallback is this file's safe_gram_svd, not torch.svd_lowrank.
    print("geolip-core not found, falling back to safe_gram_svd (Gram + eigh)")
else:
    HAS_GEOLIP = True
    print("Using geolip-core SVD (Gram + eigh)")
# -- CM monitoring --
def cayley_menger_vol2(points):
    """Squared simplex volume of each point set via the Cayley-Menger determinant.

    Args:
        points: tensor of shape (B, N, D): B independent sets of N points in
            R^D. Each set spans an (N-1)-simplex.

    Returns:
        Tensor of shape (B,) holding the squared (N-1)-volume of each simplex,
        in ``points.dtype``. Degenerate simplices give values near zero, which
        may be slightly negative from round-off.
    """
    B, N, _ = points.shape
    # Pairwise squared distances from the Gram matrix. relu removes tiny
    # negative values produced by floating-point cancellation.
    gram = torch.bmm(points, points.transpose(1, 2))
    norms = torch.diagonal(gram, dim1=1, dim2=2)
    d2 = F.relu(norms.unsqueeze(2) + norms.unsqueeze(1) - 2 * gram)

    # Bordered distance matrix [[0, 1^T], [1, D2]].
    cm = torch.zeros(B, N + 1, N + 1, device=points.device, dtype=points.dtype)
    cm[:, 0, 1:] = 1.0
    cm[:, 1:, 0] = 1.0
    cm[:, 1:, 1:] = d2

    k = N - 1  # simplex dimension
    sign = (-1.0) ** (k + 1)
    denom = (2 ** k) * (math.factorial(k) ** 2)
    # Take the determinant in at least float32 (half precision is too coarse),
    # but never downcast float64 inputs.
    det_dtype = torch.promote_types(cm.dtype, torch.float32)
    det = torch.linalg.det(cm.to(det_dtype))
    return sign * det.to(points.dtype) / denom
def cv_of(emb, n_samples=200):
    """Coefficient of variation of random 4-simplex volumes spanned by rows of ``emb``.

    Draws ``n_samples`` random 5-row subsets from the first ``min(rows, 512)``
    rows. Each subset's simplex volume comes from ``cayley_menger_vol2``, and
    the result is std / mean over the non-degenerate volumes.

    Returns 0.0 when ``emb`` is not 2-D, has fewer than 5 rows, or when fewer
    than 10 subsets have a squared volume above 1e-20.
    """
    too_small = emb.dim() != 2 or emb.shape[0] < 5
    if too_small:
        return 0.0

    row_pool = min(emb.shape[0], 512)
    candidates = emb[:row_pool]
    subsets = [torch.randperm(row_pool, device=emb.device)[:5]
               for _ in range(n_samples)]
    sq_volumes = cayley_menger_vol2(candidates[torch.stack(subsets)])

    keep = sq_volumes > 1e-20
    if keep.sum() < 10:
        return 0.0

    volumes = torch.sqrt(sq_volumes[keep])
    spread = volumes.std()
    center = volumes.mean() + 1e-8
    return (spread / center).item()
def safe_gram_svd(M, eps=1e-6):
    """Thin SVD of a batch of matrices from the ridge-regularised Gram matrix.

    Eigendecomposes ``M^T M + eps * I`` in float32, with CUDA autocast
    disabled. ``eigh`` returns eigenpairs in ascending order; they are flipped
    so singular values come out descending.

    Args:
        M: batched matrix of shape (B, m, n).
        eps: diagonal ridge added for conditioning. As a result each returned
            value is sqrt(sigma_i^2 + eps) rather than sigma_i.

    Returns:
        (U, S, Vh) with shapes (B, m, n), (B, n) and (B, n, n), such that
        ``U @ diag(S) @ Vh`` reproduces ``M`` (as float32).
    """
    with torch.amp.autocast('cuda', enabled=False):
        mat = M.float()
        n_cols = mat.shape[-1]
        ridge = eps * torch.eye(n_cols, device=mat.device, dtype=mat.dtype)
        gram = torch.bmm(mat.transpose(1, 2), mat) + ridge.unsqueeze(0)

        evals_ascending, evecs_ascending = torch.linalg.eigh(gram)
        evals = torch.flip(evals_ascending, dims=[-1])
        right = torch.flip(evecs_ascending, dims=[-1])

        sing_vals = evals.clamp(min=1e-12).sqrt()
        # U = M V S^-1. The clamp guards the division for (near-)zero modes.
        left = torch.bmm(mat, right) / sing_vals.clamp(min=1e-8).unsqueeze(1)
        right_h = right.transpose(-2, -1).contiguous()
    return left, sing_vals, right_h
# Target row CV for the V=1024, D=24 configuration (the module docstring quotes
# CV=0.2916 from the sweep). The final analysis in train() reports the measured
# mean row CV and its absolute distance from this value.
BINDING_CONSTANT = 0.29154
# -- Data --
def get_cifar10(batch_size=256):
    """Build CIFAR-10 train and test DataLoaders with per-channel normalisation.

    Downloads the dataset into ./data if needed. The train loader shuffles and
    the test loader does not. Both use 2 worker processes.

    Returns:
        (train_loader, test_loader)
    """
    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.2470, 0.2435, 0.2616)
    preprocess = T.Compose([
        T.ToTensor(),
        T.Normalize(channel_mean, channel_std),
    ])

    def _make_loader(is_train):
        dataset = torchvision.datasets.CIFAR10(
            root='./data', train=is_train, download=True, transform=preprocess)
        return torch.utils.data.DataLoader(
            dataset, batch_size=batch_size, shuffle=is_train, num_workers=2)

    return _make_loader(True), _make_loader(False)
# -- SVAE --
class SVAE(nn.Module):
    """Spectral VAE: image -> (matrix_v x D) latent matrix -> SVD -> image.

    The encoder maps a flattened 3x32x32 image to a ``matrix_v x D`` matrix M,
    which is factorised as M = U diag(S) Vt. During training the spectrum is
    perturbed in log space on its S[0]-normalised shape, so its magnitude is
    untouched. The decoder reconstructs the image from U diag(S_sampled) Vt. A
    light KL term pulls the normalised log-spectrum toward a smoothly decaying
    prior.
    """

    def __init__(self, matrix_v=1024, D=24):
        super().__init__()
        self.matrix_v = matrix_v
        self.D = D
        self.img_dim = 3 * 32 * 32
        self.mat_dim = matrix_v * D

        # Deep encoder: img_dim -> 1024 -> 2048 -> matrix_v * D.
        self.encoder = nn.Sequential(
            nn.Linear(self.img_dim, 1024),
            nn.GELU(),
            nn.Linear(1024, 2048),
            nn.GELU(),
            nn.Linear(2048, self.mat_dim),
        )
        # Mirror decoder: matrix_v * D -> 2048 -> 1024 -> img_dim.
        self.decoder = nn.Sequential(
            nn.Linear(self.mat_dim, 2048),
            nn.GELU(),
            nn.Linear(2048, 1024),
            nn.GELU(),
            nn.Linear(1024, self.img_dim),
        )
        # Per-mode log-variance of the normalised spectrum. It is predicted from
        # the 2048-wide encoder hidden state rather than the full mat_dim output.
        self.logvar_head = nn.Sequential(
            nn.Linear(2048, 128),
            nn.GELU(),
            nn.Linear(128, D),
        )
        # Prior on the SHAPE of the normalised log-spectrum, not its magnitude.
        # Mean decays linearly from 0 to -2 (S/S[0] from 1.0 to ~0.14).
        # Log-variance 1 gives a wide prior (var = e ~ 2.7).
        self.register_buffer('prior_log_mu', torch.linspace(0, -2, D))
        self.register_buffer('prior_log_var', torch.ones(D))

    def encode(self, images):
        """Encode images into an SVD of their latent matrix.

        Args:
            images: batch reshaped to (B, 3*32*32) internally, presumably
                (B, 3, 32, 32) normalised images.

        Returns:
            dict with keys:
              'M' (B, matrix_v, D): the latent matrix.
              'U', 'S', 'Vt': its factors (from geolip_svd or safe_gram_svd).
              'S_norm': S / S[:, :1].
              'S_sampled': the perturbed spectrum in training mode, otherwise S.
              'log_var' (B, D): predicted log-variance.
        """
        B = images.shape[0]
        flat = images.reshape(B, -1)

        # Run the encoder layer by layer to tap the 2048-wide hidden state.
        h1 = F.gelu(self.encoder[0](flat))  # img_dim -> 1024
        h2 = F.gelu(self.encoder[2](h1))    # 1024 -> 2048
        M = self.encoder[4](h2).reshape(B, self.matrix_v, self.D)

        if HAS_GEOLIP:
            try:
                U, S, Vh = geolip_svd(M)
            except Exception:
                # Best-effort: fall back to the local Gram + eigh solver.
                U, S, Vh = safe_gram_svd(M)
        else:
            U, S, Vh = safe_gram_svd(M)

        log_var = self.logvar_head(h2)
        # Spectrum shape: normalise by the leading singular value.
        S_norm = S / (S[:, 0:1] + 1e-8)

        if self.training:
            # Reparameterise in log space on the normalised spectrum, so the
            # noise perturbs shape but not magnitude.
            log_S_norm = torch.log(S_norm.clamp(min=1e-8))
            std = torch.exp(0.5 * log_var)
            eps = torch.randn_like(std)
            S_norm_sampled = torch.exp(log_S_norm + std * eps)
            # Keep the sampled spectrum in descending order.
            S_norm_sampled, _ = S_norm_sampled.sort(dim=-1, descending=True)
            # Restore the original magnitude.
            S_sampled = S_norm_sampled * S[:, 0:1]
        else:
            S_sampled = S

        return {
            'U': U, 'S': S, 'S_sampled': S_sampled, 'Vt': Vh,
            'S_norm': S_norm,
            'M': M, 'log_var': log_var,
        }

    def decode_from_svd(self, U, S, Vt):
        """Rebuild M_hat = U diag(S) Vt and decode it to a (B, 3, 32, 32) image.

        Truncated factors also work: U[:, :, :k], S[:, :k] and Vt[:, :k, :]
        still yield an M_hat with Vt's column count.
        """
        B = U.shape[0]
        M_hat = torch.bmm(U * S.unsqueeze(1), Vt)
        return self.decoder(M_hat.reshape(B, -1)).reshape(B, 3, 32, 32)

    def spectral_kl(self, S_norm, log_var):
        """KL(q || p) between diagonal Gaussians over the normalised log-spectrum.

        q has mean log(S_norm) and log-variance ``log_var``. p is the fixed
        prior (``prior_log_mu``, ``prior_log_var``). The KL is summed over modes
        and averaged over the batch. It is magnitude-free because S_norm is
        S / S[0]. mu_q depends on S_norm, so this term backpropagates through
        the SVD / eigh.
        """
        mu_q = torch.log(S_norm.clamp(min=1e-8))
        mu_p = self.prior_log_mu.unsqueeze(0)
        log_var_p = self.prior_log_var.unsqueeze(0)
        var_q = torch.exp(log_var)
        var_p = torch.exp(log_var_p)
        # log(var_p / var_q) is taken directly from the log-variances. An
        # exp/log round-trip with an additive epsilon would cap this term near
        # 19.4 once log_var < ~-18, removing the gradient that keeps log_var
        # from collapsing.
        kl = 0.5 * (var_q / var_p
                    + (mu_p - mu_q).pow(2) / var_p
                    - 1.0
                    + log_var_p - log_var)
        return kl.sum(dim=-1).mean()

    def forward(self, images):
        """Encode, decode from the (possibly sampled) spectrum and compute the KL.

        Returns:
            dict with 'recon' (B, 3, 32, 32), 'svd' (the encode() dict) and
            'kl' (scalar).
        """
        svd = self.encode(images)
        recon = self.decode_from_svd(svd['U'], svd['S_sampled'], svd['Vt'])
        kl = self.spectral_kl(svd['S_norm'], svd['log_var'])
        return {'recon': recon, 'svd': svd, 'kl': kl}

    @staticmethod
    def effective_rank(S):
        """Entropy-based effective rank: exp(H(p)), where p = S / sum(S) along the last dim."""
        p = S / (S.sum(-1, keepdim=True) + 1e-8)
        p = p.clamp(min=1e-8)
        return (-(p * p.log()).sum(-1)).exp()
# -- Training --
_CIFAR_NAMES = ['plane', 'car', 'bird', 'cat', 'deer',
                'dog', 'frog', 'horse', 'ship', 'truck']
_CIFAR_MEAN = (0.4914, 0.4822, 0.4465)
_CIFAR_STD = (0.2470, 0.2435, 0.2616)


def _train_epoch(model, train_loader, opt, kl_weight, device):
    """Run one optimisation pass and return per-sample means of (loss, recon, kl)."""
    model.train()
    total_loss = total_recon = total_kl = 0.0
    n = 0
    for images, _ in train_loader:
        images = images.to(device)
        opt.zero_grad()
        out = model(images)
        recon_loss = F.mse_loss(out['recon'], images)
        kl = out['kl']
        loss = recon_loss + kl_weight * kl
        loss.backward()
        opt.step()
        bs = len(images)
        total_loss += loss.item() * bs
        total_recon += recon_loss.item() * bs
        total_kl += kl.item() * bs
        n += bs
    return total_loss / n, total_recon / n, total_kl / n


def _evaluate_epoch(model, test_loader, device):
    """Return per-sample test recon MSE, erank and mean spectra, plus a sampled row CV."""
    model.eval()
    recon_sum, erank_sum, n_seen = 0.0, 0.0, 0
    S_sum = S_norm_sum = None
    row_cvs = []
    with torch.no_grad():
        for batch_idx, (images, _) in enumerate(test_loader):
            images = images.to(device)
            out = model(images)
            svd = out['svd']
            bs = len(images)
            recon_sum += F.mse_loss(out['recon'], images).item() * bs
            erank_sum += model.effective_rank(svd['S']).sum().item()
            n_seen += bs
            # cv_of is costly: sample 4 matrices from each of the first 3 batches.
            if batch_idx < 3:
                row_cvs.extend(cv_of(svd['M'][b]) for b in range(min(4, bs)))
            # Accumulate per-sample sums so a short final batch is weighted by
            # its size rather than counted like a full batch.
            batch_S = svd['S'].sum(0).cpu()
            batch_S_norm = svd['S_norm'].sum(0).cpu()
            S_sum = batch_S if S_sum is None else S_sum + batch_S
            S_norm_sum = batch_S_norm if S_norm_sum is None else S_norm_sum + batch_S_norm
    return {
        'recon': recon_sum / n_seen,
        'erank': erank_sum / n_seen,
        'S': S_sum / n_seen,
        'S_norm': S_norm_sum / n_seen,
        'row_cv': sum(row_cvs) / len(row_cvs) if row_cvs else 0.0,
    }


def _final_analysis(model, test_loader, device):
    """Print the test-set summary, the mean singular-value profile and per-class stats."""
    print()
    print("=" * 90)
    print("FINAL ANALYSIS")
    print("=" * 90)
    model.eval()
    all_S, all_recon_err, all_labels, all_row_cvs = [], [], [], []
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            out = model(images)
            all_S.append(out['svd']['S'].cpu())
            all_recon_err.append(
                F.mse_loss(out['recon'], images, reduction='none')
                .mean(dim=(1, 2, 3)).cpu())
            all_labels.append(labels.cpu())
            for b in range(min(4, len(images))):
                all_row_cvs.append(cv_of(out['svd']['M'][b]))
    all_S = torch.cat(all_S)
    all_recon_err = torch.cat(all_recon_err)
    all_labels = torch.cat(all_labels)
    erank = model.effective_rank(all_S)
    mean_cv = sum(all_row_cvs) / len(all_row_cvs) if all_row_cvs else 0.0

    print(f"\n V=1024, D=24 (validated CV=0.2916)")
    print(f" Recon MSE: {all_recon_err.mean():.6f} +/- {all_recon_err.std():.6f}")
    print(f" Effective rank: {erank.mean():.2f} +/- {erank.std():.2f}")
    print(f" Row CV: {mean_cv:.4f} (target: {BINDING_CONSTANT}, delta: {abs(mean_cv - BINDING_CONSTANT):.4f})")

    # Mean spectrum, with each mode's cumulative share of energy (sum of S^2).
    S_mean = all_S.mean(0)
    S_norm = S_mean / (S_mean[0] + 1e-8)
    energy = S_mean ** 2
    cum_pct = energy.cumsum(0) / energy.sum() * 100
    print(f"\n Singular value profile (raw and normalized):")
    for i in range(len(S_mean)):
        bar = "#" * int(S_norm[i].item() * 30)
        print(f" S[{i:2d}]: {S_mean[i].item():8.3f} norm={S_norm[i].item():.4f} "
              f"cum={cum_pct[i].item():5.1f}% {bar}")

    print(f"\n Per-class:")
    print(f" {'cls':>6} {'recon':>8} {'erank':>6} {'S0':>7} {'SD':>7} {'ratio':>6}")
    for c, name in enumerate(_CIFAR_NAMES):
        mask = all_labels == c
        rc = all_recon_err[mask].mean().item()
        er = erank[mask].mean().item()
        s0 = all_S[mask, 0].mean().item()
        sd = all_S[mask, -1].mean().item()
        r = s0 / (sd + 1e-8)
        print(f" {name:>6} {rc:8.6f} {er:6.2f} {s0:7.3f} {sd:7.3f} {r:6.2f}")


def _save_recon_grid(model, test_loader, device, grid_path=None):
    """Save originals, k-mode reconstructions and |error|x5 for up to 2 test images per class."""
    import os

    print(f"\n Saving reconstruction grid...")
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt

    if grid_path is None:
        # Keep the Colab default, but do not fail where /content does not exist.
        grid_path = ('/content/svae_recon_grid.png' if os.path.isdir('/content')
                     else 'svae_recon_grid.png')

    mean_t = torch.tensor(_CIFAR_MEAN).reshape(1, 3, 1, 1).to(device)
    std_t = torch.tensor(_CIFAR_STD).reshape(1, 3, 1, 1).to(device)

    def denorm(t):
        return (t * std_t + mean_t).clamp(0, 1).cpu()

    mode_counts = [1, 4, 8, 16, 24]
    model.eval()
    with torch.no_grad():
        images, labels = next(iter(test_loader))
        images = images.to(device)
        out = model(images)
        selected_idx = []
        for c in range(len(_CIFAR_NAMES)):
            class_idx = (labels == c).nonzero(as_tuple=True)[0]
            selected_idx.extend(class_idx[:2].tolist())
        if not selected_idx:
            print(" No samples selected; skipping reconstruction grid.")
            return
        orig = images[selected_idx]
        U = out['svd']['U'][selected_idx]
        S = out['svd']['S'][selected_idx]
        Vt = out['svd']['Vt'][selected_idx]
        # Progressive reconstructions from only the leading nm singular modes.
        prog_recons = [model.decode_from_svd(U[:, :, :nm], S[:, :nm], Vt[:, :nm, :])
                       for nm in mode_counts]

    n_samples = len(selected_idx)
    n_cols = 2 + len(mode_counts)
    # squeeze=False keeps axes 2-D even if only one row is plotted.
    fig, axes = plt.subplots(n_samples, n_cols,
                             figsize=(n_cols * 1.5, n_samples * 1.5), squeeze=False)
    col_titles = ['Original'] + [f'{m} modes' for m in mode_counts] + ['|Err|x5']
    err_col = 1 + len(prog_recons)
    for i in range(n_samples):
        axes[i, 0].imshow(denorm(orig[i:i+1])[0].permute(1, 2, 0).numpy())
        for j, r in enumerate(prog_recons):
            axes[i, j + 1].imshow(denorm(r[i:i+1])[0].permute(1, 2, 0).numpy())
        diff = (denorm(orig[i:i+1]) - denorm(prog_recons[-1][i:i+1])).abs() * 5
        axes[i, err_col].imshow(diff.clamp(0, 1)[0].permute(1, 2, 0).numpy())
        c = labels[selected_idx[i]].item()
        axes[i, 0].set_ylabel(_CIFAR_NAMES[c], fontsize=8, rotation=0, labelpad=35)
    for j, title in enumerate(col_titles):
        axes[0, j].set_title(title, fontsize=8)
    for ax in axes.flat:
        ax.axis('off')
    plt.tight_layout()
    parent = os.path.dirname(grid_path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    plt.savefig(grid_path, dpi=200, bbox_inches='tight')
    print(f" Saved to {grid_path}")
    try:
        plt.show()
    except Exception:
        pass  # display is best-effort; the PNG has already been written
    plt.close(fig)


def train(epochs=50, lr=1e-3, kl_weight=0.001, device='cuda', grid_path=None):
    """Train the V=1024, D=24 SVAE on CIFAR-10, then print an analysis and save a recon grid.

    Args:
        epochs: number of training epochs (also the cosine schedule length).
        lr: Adam learning rate.
        kl_weight: weight of the spectral KL term in the loss.
        device: requested device. Falls back to CPU when CUDA is unavailable.
        grid_path: output PNG path for the reconstruction grid. None means
            '/content/svae_recon_grid.png' when /content exists, otherwise
            'svae_recon_grid.png' in the working directory.
    """
    device = torch.device(device if torch.cuda.is_available() else 'cpu')
    train_loader, test_loader = get_cifar10(batch_size=256)
    model = SVAE(matrix_v=1024, D=24).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
    total_params = sum(p.numel() for p in model.parameters())

    print(f"SVAE - V=1024, D=24 (Validated: CV=0.2916)")
    print(f" Matrix: (1024, 24) = 24,576 elements")
    print(f" KL on normalized spectrum shape (weight={kl_weight})")
    print(f" Wide prior (var=e, shape-only)")
    print(f" Params: {total_params:,}")
    print("=" * 90)
    print(f"{'ep':>3} | {'loss':>7} {'recon':>7} {'kl':>7} | "
          f"{'t_rec':>7} | "
          f"{'S0':>7} {'SD':>6} {'ratio':>5} {'erank':>5} | "
          f"{'row_cv':>7} {'Sn_shape':>20}")
    print("-" * 90)

    for epoch in range(1, epochs + 1):
        avg_loss, avg_recon, avg_kl = _train_epoch(
            model, train_loader, opt, kl_weight, device)
        sched.step()
        # Evaluate on the first 3 epochs, then every other epoch.
        if epoch % 2 == 0 or epoch <= 3:
            stats = _evaluate_epoch(model, test_loader, device)
            test_S = stats['S']
            ratio = (test_S[0] / (test_S[-1] + 1e-8)).item()
            # Normalised spectrum shape (first 5 modes).
            shape_str = " ".join(f"{s:.3f}" for s in stats['S_norm'][:5].tolist())
            print(f"{epoch:3d} | {avg_loss:7.4f} {avg_recon:7.4f} "
                  f"{avg_kl:7.3f} | "
                  f"{stats['recon']:7.4f} | "
                  f"{test_S[0].item():7.2f} {test_S[-1].item():6.3f} {ratio:5.2f} "
                  f"{stats['erank']:5.2f} | "
                  f"{stats['row_cv']:7.4f} [{shape_str}]")

    _final_analysis(model, test_loader, device)
    _save_recon_grid(model, test_loader, device, grid_path)
if __name__ == "__main__":
    # Script entry point: run train() with its default arguments.
    train()