# geolip-SVAE / prototype_v9_tiny_imagenet.py
# Author: AbstractPhil — commit d9cb834 (verified), "Update prototype_v9_tiny_imagenet.py"
"""
SVAE — SVD Autoencoder with Geometric Attractors
===================================================
A matrix-valued autoencoder where the latent space is a (V, D) matrix
decomposed by SVD. Rows are normalized to S^(D-1), making the geometric
structure architectural rather than loss-dependent.
Two key mechanisms:
1. Sphere normalization: F.normalize(M, dim=-1) constrains rows to unit
vectors on S^(D-1). This bounds the Gram matrix, eliminates training
instabilities, and makes the CV (coefficient of variation, std/mean, of the
volumes of random simplices spanned by the rows) a structural property of (V, D).
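2. Soft hand: A Gaussian-proximity counterweight that boosts reconstruction
gradients when geometry is near target, and penalizes CV drift when
geometry is far from target. Provides positive momentum, not just penalty.
(Note: the measured CV is a detached Python float, so the CV penalty term
changes the reported loss but contributes no gradient; in practice only the
reconstruction re-weighting influences training.)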
Architecture: Image → MLP → M ∈ ℝ^(V×D) → normalize → SVD → MLP → Recon
Repository: AbstractEyes/geolip-core
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as T
import math
import time
# ── SVD Backend ──────────────────────────────────────────────────
try:
    # Optional eigensolver from geolip-core. _FL_MAX_N presumably is the
    # largest matrix size it supports (svd_fp64 only routes to FLEigh when
    # N <= _FL_MAX_N).
    from geolip_core.linalg.eigh import FLEigh, _FL_MAX_N
    HAS_FL = True
except ImportError:
    # geolip-core is not installed: svd_fp64 always falls back to
    # gram_eigh_svd_fp64 (torch.linalg.eigh in fp64).
    HAS_FL = False
def gram_eigh_svd_fp64(A):
    """Thin SVD via an fp64 Gram-matrix eigendecomposition.

    Forms G = AᵀA, eigendecomposes it with torch.linalg.eigh, and derives
    S = sqrt(eigenvalues) and U = A V / S. All of this runs in float64:
    Gram entries scale as S₀², so in fp32 (~7 significant digits) small
    singular values are lost once the condition number grows past roughly
    100, while fp64 (~15 digits) leaves far more headroom.

    Args:
        A: (B, M, N) tensor, M >= N.

    Returns:
        U (B, M, N), S (B, N), Vh (B, N, N), cast back to A's dtype, with
        singular values in descending order.
    """
    in_dtype = A.dtype
    with torch.amp.autocast('cuda', enabled=False):
        A64 = A.double()
        gram = torch.bmm(A64.mT, A64)
        # eigh yields ascending eigenvalues; flip to get descending order.
        evals_asc, evecs_asc = torch.linalg.eigh(gram)
        evals = torch.flip(evals_asc, dims=(-1,))
        evecs = torch.flip(evecs_asc, dims=(-1,))
        # Clamp absorbs tiny negative eigenvalues produced by round-off.
        S = evals.clamp(min=1e-24).sqrt()
        U = torch.bmm(A64, evecs) / S[:, None, :].clamp(min=1e-16)
        Vh = evecs.mT.contiguous()
    return U.to(in_dtype), S.to(in_dtype), Vh.to(in_dtype)
def svd_fp64(A):
    """Thin SVD of a batch of matrices with fp64 internals, auto-dispatched.

    Dispatch:
      * FLEigh imported, N <= _FL_MAX_N and A on CUDA: the Gram matrix is
        formed in fp64, handed to FLEigh as fp32 (FLEigh requires fp32
        input), and its eigenpairs are promoted back to fp64.
      * otherwise (no geolip-core, larger N, or CPU): gram_eigh_svd_fp64.
    A Triton path is intentionally not used (noted as fp32-only hardware,
    incompatible with fp64).

    Args:
        A: (B, M, N) tensor.

    Returns:
        U (B, M, N), S (B, N), Vh (B, N, N) in A's dtype, singular values
        descending.
    """
    _, _, n_cols = A.shape
    if not (HAS_FL and n_cols <= _FL_MAX_N and A.is_cuda):
        return gram_eigh_svd_fp64(A)

    in_dtype = A.dtype
    with torch.amp.autocast('cuda', enabled=False):
        A64 = A.double()
        gram = torch.bmm(A64.transpose(1, 2), A64)
        # NOTE(review): the fp32 cast discards the fp64 Gram precision that
        # gram_eigh_svd_fp64's docstring describes as essential.
        evals, evecs = FLEigh()(gram.float())
        evals = torch.flip(evals.double(), dims=(-1,))
        evecs = torch.flip(evecs.double(), dims=(-1,))
        S = evals.clamp(min=1e-24).sqrt()
        U = torch.bmm(A64, evecs) / S[:, None, :].clamp(min=1e-16)
        Vh = evecs.mT.contiguous()
    return U.to(in_dtype), S.to(in_dtype), Vh.to(in_dtype)
# ── Cayley-Menger CV Monitoring ──────────────────────────────────
def cayley_menger_vol2(points):
    """Squared volumes of a batch of simplices via the Cayley-Menger determinant.

    For a k-simplex with N = k + 1 vertices,
        vol² = (-1)^(k+1) / (2^k (k!)²) · det(CM),
    where CM is the (N+1)x(N+1) matrix of squared pairwise distances
    bordered by a row and column of ones (zero in the corner). Computed
    in float64.

    Args:
        points: (B, N, D) tensor — B simplices, each with N vertices in D dims.

    Returns:
        (B,) float64 tensor of squared volumes.
    """
    n_batch, n_vert, _ = points.shape
    P = points.double()
    inner = P @ P.transpose(-1, -2)
    sq_norms = inner.diagonal(dim1=-2, dim2=-1)
    # ||a - b||² = ||a||² + ||b||² - 2<a, b>; clamp away round-off negatives.
    dist2 = (sq_norms[:, :, None] + sq_norms[:, None, :] - 2 * inner).clamp_min(0.0)

    bordered = torch.zeros(n_batch, n_vert + 1, n_vert + 1,
                           dtype=torch.float64, device=points.device)
    bordered[:, 0, 1:] = 1.0
    bordered[:, 1:, 0] = 1.0
    bordered[:, 1:, 1:] = dist2

    k = n_vert - 1
    sign = -1.0 if k % 2 == 0 else 1.0
    scale = (2 ** k) * math.factorial(k) ** 2
    return sign * torch.linalg.det(bordered) / scale
def cv_of(emb, n_samples=200):
    """Coefficient of variation (std / mean) of random pentachoron volumes.

    Draws ``n_samples`` random 5-row subsets from the first ``min(V, 512)``
    rows of ``emb``, computes each 4-simplex (pentachoron) volume via the
    Cayley-Menger determinant, and returns std/mean over the non-degenerate
    ones. Low CV = regular geometry, high CV = irregular.

    Args:
        emb: (V, D) tensor.
        n_samples: Number of random 5-point subsets to draw.

    Returns:
        Python float CV, or 0.0 if ``emb`` is not 2-D, has fewer than 5 rows,
        or fewer than 10 sampled simplices are non-degenerate (squared
        volume <= 1e-20). Note this means n_samples < 10 always yields 0.0.
    """
    if emb.dim() != 2 or emb.shape[0] < 5:
        return 0.0
    pool = min(emb.shape[0], 512)
    # One batched draw instead of n_samples separate randperm launches:
    # argsort of i.i.d. uniforms is a random permutation per row, so the first
    # 5 columns are 5 distinct row indices (ties among the uniforms are rare
    # and only add negligible bias). Also handles n_samples == 0 cleanly.
    indices = torch.rand(n_samples, pool, device=emb.device).argsort(dim=1)[:, :5]
    vol2 = cayley_menger_vol2(emb[:pool][indices])  # (n_samples,)
    valid = vol2 > 1e-20
    if valid.sum() < 10:
        return 0.0
    vols = vol2[valid].sqrt()
    return (vols.std() / (vols.mean() + 1e-8)).item()
# ── Data ─────────────────────────────────────────────────────────
def get_cifar10(batch_size=256):
    """Build CIFAR-10 train/test DataLoaders.

    Images are converted to tensors and normalized with per-channel CIFAR-10
    statistics; the dataset is downloaded into ./data on first use.

    Returns:
        (train_loader, test_loader, flattened image size 3*32*32,
        number of classes 10, list of class names)
    """
    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.2470, 0.2435, 0.2616)
    transform = T.Compose([T.ToTensor(), T.Normalize(channel_mean, channel_std)])

    def make_split(is_train):
        return torchvision.datasets.CIFAR10(
            root='./data', train=is_train, download=True, transform=transform)

    def make_loader(split, shuffle):
        return torch.utils.data.DataLoader(
            split, batch_size=batch_size, shuffle=shuffle, num_workers=2)

    train_split = make_split(True)
    test_split = make_split(False)
    names = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
    return (make_loader(train_split, True), make_loader(test_split, False),
            3 * 32 * 32, 10, names)
class HFImageDataset(torch.utils.data.Dataset):
    """Map-style adapter over a HuggingFace split of {'image', 'label'} records.

    Defined at module level (not inside get_tiny_imagenet) so DataLoader
    worker processes can pickle it: a function-local class cannot be
    pickled, which breaks num_workers > 0 under the 'spawn'/'forkserver'
    multiprocessing start methods.
    """

    def __init__(self, hf_split, transform):
        self.data = hf_split
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        img = item['image']
        # Force 3 channels; some records may not be stored as RGB.
        if img.mode != 'RGB':
            img = img.convert('RGB')
        return self.transform(img), item['label']


def get_tiny_imagenet(batch_size=256):
    """TinyImageNet via HuggingFace: 200 classes, 64x64, 100K train / 10K val.

    Loads the 'zh-plus/tiny-imagenet' dataset (requires the ``datasets``
    package) and wraps its 'train' and 'valid' splits with per-channel
    normalization.

    Returns:
        (train_loader, val_loader, flattened image size 3*64*64,
        number of classes 200, placeholder class names 'c000'..'c199')
    """
    from datasets import load_dataset
    ds = load_dataset('zh-plus/tiny-imagenet')
    mean = (0.4802, 0.4481, 0.3975)
    std = (0.2770, 0.2691, 0.2821)
    transform = T.Compose([
        T.ToTensor(),
        T.Normalize(mean, std),
    ])
    train_ds = HFImageDataset(ds['train'], transform)
    val_ds = HFImageDataset(ds['valid'], transform)
    train_loader = torch.utils.data.DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=2)
    val_loader = torch.utils.data.DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=2)
    class_names = [f'c{i:03d}' for i in range(200)]
    return train_loader, val_loader, 3 * 64 * 64, 200, class_names
# ── SVAE Model ───────────────────────────────────────────────────
class SVAE(nn.Module):
    """SVD autoencoder whose latent code is a row-normalized (V, D) matrix.

    The encoder MLP maps a flattened image to a ``(matrix_v, D)`` matrix M
    and projects every row onto the unit sphere S^(D-1). M is decomposed as
    M = U diag(S) Vᵀ, separating alignment structure (U, V) from spectral
    magnitudes (S). The decoder MLP reconstructs the image from the
    recomposed matrix M̂ = U diag(S) Vᵀ.

    Args:
        matrix_v: Number of rows V (vocabulary size / overcomplete factor).
        D: Row dimension, which is also the number of singular values.
        img_dim: Flattened image size 3*H*W; decoding assumes H == W.
        hidden: Hidden MLP width; when falsy, max(512, min(2048, img_dim // 4)).
    """

    def __init__(self, matrix_v=96, D=24, img_dim=3072, hidden=None):
        super().__init__()
        self.matrix_v = matrix_v
        self.D = D
        self.img_dim = img_dim
        self.mat_dim = matrix_v * D
        width = hidden or max(512, min(2048, img_dim // 4))

        def mlp(n_in, n_out):
            # Linear -> GELU -> Linear -> GELU -> Linear with two hidden layers.
            return nn.Sequential(
                nn.Linear(n_in, width), nn.GELU(),
                nn.Linear(width, width), nn.GELU(),
                nn.Linear(width, n_out),
            )

        self.encoder = mlp(self.img_dim, self.mat_dim)
        self.decoder = mlp(self.mat_dim, self.img_dim)
        nn.init.orthogonal_(self.encoder[-1].weight)

    def encode(self, images):
        """Encode images to a sphere-normalized matrix and its SVD.

        Returns a dict with 'U', 'S', 'Vt' (from svd_fp64) and the
        row-normalized matrix 'M' of shape (B, matrix_v, D).
        """
        batch = images.shape[0]
        flat = images.reshape(batch, -1)
        M = self.encoder(flat).reshape(batch, self.matrix_v, self.D)
        M = F.normalize(M, dim=-1)  # each row -> unit vector on S^(D-1)
        U, S, Vh = svd_fp64(M)
        return {'U': U, 'S': S, 'Vt': Vh, 'M': M}

    def decode_from_svd(self, U, S, Vt):
        """Recompose M̂ = U diag(S) Vt and decode it to a (B, 3, H, H) image.

        U, S, Vt may be truncated to the leading k modes; M̂ keeps its
        (B, matrix_v, D) shape either way.
        """
        batch = U.shape[0]
        scaled_U = U * S[:, None, :]
        M_hat = torch.bmm(scaled_U, Vt)
        flat = self.decoder(M_hat.reshape(batch, -1))
        side = int((self.img_dim // 3) ** 0.5)  # square image: img_dim = 3*H*H
        return flat.reshape(batch, 3, side, side)

    def forward(self, images):
        svd = self.encode(images)
        return {'recon': self.decode_from_svd(svd['U'], svd['S'], svd['Vt']), 'svd': svd}

    @staticmethod
    def effective_rank(S):
        """exp(Shannon entropy) of the normalized singular-value distribution."""
        total = S.sum(dim=-1, keepdim=True) + 1e-8
        probs = torch.clamp(S / total, min=1e-8)
        entropy = -torch.sum(probs * torch.log(probs), dim=-1)
        return torch.exp(entropy)
# ── Training ─────────────────────────────────────────────────────
def train(epochs=100, lr=1e-3, V=256, D=24, target_cv=0.125,
          cv_weight=0.3, boost=0.5, sigma=0.15,
          dataset='cifar10', device='cuda',
          grid_path='/content/svae_recon_grid.png'):
    """Train the SVAE with sphere normalization + soft hand.

    Trains with Adam + cosine LR, evaluates on the held-out split every even
    epoch (and the first three), prints a final spectral / per-class report,
    and saves a progressive-reconstruction grid image.

    Args:
        epochs: Training epochs
        lr: Learning rate for Adam
        V: Matrix rows (vocabulary size)
        D: Embedding dimension
        target_cv: CV attractor target for soft hand
        cv_weight: Maximum CV penalty weight (far from target)
        boost: Maximum reconstruction boost factor (near target)
        sigma: Gaussian transition width for proximity
        dataset: 'tiny_imagenet'; any other value loads CIFAR-10
        device: Training device (falls back to CPU if CUDA is unavailable)
        grid_path: Output path of the reconstruction-grid PNG

    Returns:
        The trained SVAE model.
    """
    device = torch.device(device if torch.cuda.is_available() else 'cpu')
    if dataset == 'tiny_imagenet':
        loaders = get_tiny_imagenet(batch_size=256)
    else:
        loaders = get_cifar10(batch_size=256)
    train_loader, test_loader, img_dim, n_classes, class_names = loaders
    img_h = int((img_dim // 3) ** 0.5)

    model = SVAE(matrix_v=V, D=D, img_dim=img_dim).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
    total_params = sum(p.numel() for p in model.parameters())

    _print_train_header(V, D, dataset, img_h, n_classes, boost, target_cv,
                        cv_weight, total_params)

    for epoch in range(1, epochs + 1):
        epoch_stats = _train_one_epoch(model, train_loader, opt, device,
                                       target_cv, cv_weight, boost, sigma)
        sched.step()
        # Evaluate every even epoch plus the first three.
        if epoch % 2 == 0 or epoch <= 3:
            _report_epoch(model, test_loader, device, epoch, epoch_stats)

    _final_analysis(model, test_loader, device, V, D, target_cv,
                    n_classes, class_names)
    _save_recon_grid(model, test_loader, device, dataset, n_classes,
                     class_names, D, grid_path)
    return model


def _print_train_header(V, D, dataset, img_h, n_classes, boost, target_cv,
                        cv_weight, total_params):
    """Print the run configuration and the per-epoch table header."""
    svd_backend = f"fp64 Gram+eigh (FL={'available, N<=12' if HAS_FL else 'not available'})"
    print(f"Using geolip-core SVD ({svd_backend})")
    print(f"SVAE - V={V}, D={D}, rows on S^{D-1} + soft hand")
    print(f" Dataset: {dataset} ({img_h}x{img_h}, {n_classes} classes)")
    print(f" Matrix: ({V}, {D}) = {V*D} elements, rows normalized")
    print(" SVD: fp64 Gram+eigh")
    print(f" Sphere: rows on S^{D-1} (structural geometry)")
    print(f" Soft hand: boost={1+boost:.1f}x near CV={target_cv}, penalty={cv_weight} far")
    print(f" Params: {total_params:,}")
    print("=" * 90)
    print(f" {'ep':>3} | {'loss':>7} {'recon':>7} {'t/ep':>5} | "
          f"{'t_rec':>7} | "
          f"{'S0':>6} {'SD':>6} {'ratio':>5} {'erank':>5} | "
          f"{'row_cv':>7} {'prox':>5} {'rw':>5}")
    print("-" * 90)


def _train_one_epoch(model, loader, opt, device, target_cv, cv_weight,
                     boost, sigma):
    """Run one optimization pass; return mean losses and final soft-hand state."""
    model.train()
    total_loss, total_recon, n = 0.0, 0.0, 0
    # Soft-hand state is reset each epoch and refreshed every 10th batch.
    last_cv = target_cv
    last_prox = 1.0
    recon_w = 1.0 + boost
    t0 = time.time()
    for batch_idx, (images, _labels) in enumerate(loader):
        images = images.to(device)
        opt.zero_grad()
        out = model(images)
        recon_loss = F.mse_loss(out['recon'], images)

        # Measure CV of the first sample's matrix and update proximity.
        with torch.no_grad():
            if batch_idx % 10 == 0:
                current_cv = cv_of(out['svd']['M'][0])
                if current_cv > 0:  # cv_of returns 0.0 when it cannot measure
                    last_cv = current_cv
                    delta = last_cv - target_cv
                    last_prox = math.exp(-delta ** 2 / (2 * sigma ** 2))

        # Soft hand: boost recon near target, penalize CV far from target.
        recon_w = 1.0 + boost * last_prox
        cv_pen = cv_weight * (1.0 - last_prox)
        # NOTE(review): cv_l is a plain Python float (cv_of runs under no_grad
        # and returns .item()), so cv_pen * cv_l is constant w.r.t. the model
        # parameters: it shifts the reported loss but contributes no gradient.
        # Only the recon_w re-weighting actually affects training.
        cv_l = (last_cv - target_cv) ** 2
        loss = recon_w * recon_loss + cv_pen * cv_l
        loss.backward()
        opt.step()

        batch = len(images)
        total_loss += loss.item() * batch
        total_recon += recon_loss.item() * batch
        n += batch

    denom = max(n, 1)
    return {
        'loss': total_loss / denom,
        'recon': total_recon / denom,
        'time': time.time() - t0,
        'prox': last_prox,
        'recon_w': recon_w,
    }


def _report_epoch(model, loader, device, epoch, stats):
    """Evaluate on the held-out split and print one row of the epoch table."""
    model.eval()
    test_recon, test_n = 0.0, 0
    test_S, test_erank = None, 0.0
    row_cvs = []
    nb = 0
    with torch.no_grad():
        for images, _labels in loader:
            images = images.to(device)
            out = model(images)
            batch = len(images)
            test_recon += F.mse_loss(out['recon'], images).item() * batch
            test_n += batch
            # Effective rank and spectrum are averaged per batch, not per sample.
            test_erank += model.effective_rank(out['svd']['S']).mean().item()
            if nb < 5:  # CV is sampled from the first 4 items of 5 batches only
                row_cvs.extend(cv_of(out['svd']['M'][b]) for b in range(min(4, batch)))
            batch_S = out['svd']['S'].mean(0).cpu()
            test_S = batch_S if test_S is None else test_S + batch_S
            nb += 1
    if test_S is None:
        raise ValueError("evaluation loader yielded no batches")
    test_erank /= nb
    test_S = test_S / nb
    ratio = (test_S[0] / (test_S[-1] + 1e-8)).item()
    mean_cv = sum(row_cvs) / len(row_cvs) if row_cvs else 0
    print(f" {epoch:3d} | {stats['loss']:7.4f} {stats['recon']:7.4f} {stats['time']:5.1f} | "
          f"{test_recon / test_n:7.4f} | "
          f"{test_S[0].item():6.3f} {test_S[-1].item():6.3f} {ratio:5.2f} "
          f"{test_erank:5.2f} | "
          f"{mean_cv:7.4f} {stats['prox']:5.3f} {stats['recon_w']:5.2f}")


def _final_analysis(model, loader, device, V, D, target_cv, n_classes,
                    class_names):
    """Print test-set reconstruction, spectrum and per-class summaries."""
    print()
    print("=" * 85)
    print("FINAL ANALYSIS")
    print("=" * 85)
    model.eval()
    all_S, all_recon_err, all_labels = [], [], []
    all_row_cvs = []
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            out = model(images)
            all_S.append(out['svd']['S'].cpu())
            # Per-sample MSE (mean over channels and pixels).
            all_recon_err.append(
                F.mse_loss(out['recon'], images, reduction='none')
                .mean(dim=(1, 2, 3)).cpu())
            all_labels.append(labels.cpu())
            all_row_cvs.extend(cv_of(out['svd']['M'][b])
                               for b in range(min(8, len(images))))
    all_S = torch.cat(all_S)
    all_recon_err = torch.cat(all_recon_err)
    all_labels = torch.cat(all_labels)
    erank = model.effective_rank(all_S)
    mean_cv = sum(all_row_cvs) / len(all_row_cvs) if all_row_cvs else 0.0

    print(f"\n V={V}, D={D}, rows on S^{D-1}")
    print(f" Target CV: {target_cv}")
    print(f" Recon MSE: {all_recon_err.mean():.6f} +/- {all_recon_err.std():.6f}")
    print(f" Effective rank: {erank.mean():.2f} +/- {erank.std():.2f}")
    print(f" Row CV: {mean_cv:.4f}")

    # Cumulative energy (sum of S²) captured by the leading modes.
    S_mean = all_S.mean(0)
    total_energy = (S_mean ** 2).sum().item()
    s_top = S_mean[0].item()
    print("\n Singular value profile:")
    cumulative = 0.0
    for i, s in enumerate(S_mean.tolist()):
        cumulative += s ** 2
        pct = cumulative / total_energy * 100 if total_energy > 0 else 0.0
        bar = "#" * int(s * 30 / (s_top + 1e-8))
        print(f" S[{i:2d}]: {s:8.4f} cum={pct:5.1f}% {bar}")

    # Per-class summary (cap at 20 classes for readability).
    show_classes = min(n_classes, 20)
    print(f"\n Per-class (showing {show_classes}/{n_classes}):")
    print(f" {'cls':>8} {'recon':>8} {'erank':>6} {'S0':>7} {'SD':>7} {'ratio':>6}")
    for c in range(show_classes):
        mask = all_labels == c
        if not mask.any():
            continue
        rc = all_recon_err[mask].mean().item()
        er = erank[mask].mean().item()
        s0 = all_S[mask, 0].mean().item()
        sd = all_S[mask, -1].mean().item()
        r = s0 / (sd + 1e-8)
        name = class_names[c] if c < len(class_names) else f'cls_{c}'
        print(f" {name:>8} {rc:8.6f} {er:6.2f} {s0:7.4f} {sd:7.4f} {r:6.2f}")


def _save_recon_grid(model, loader, device, dataset, n_classes, class_names,
                     D, path):
    """Save originals vs. reconstructions from the leading k singular modes."""
    print(f"\n Saving reconstruction grid...")
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt

    # Must match the Normalize statistics used by the dataset loaders.
    if dataset == 'tiny_imagenet':
        mean_vals, std_vals = (0.4802, 0.4481, 0.3975), (0.2770, 0.2691, 0.2821)
    else:
        mean_vals, std_vals = (0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)
    mean_t = torch.tensor(mean_vals).reshape(1, 3, 1, 1).to(device)
    std_t = torch.tensor(std_vals).reshape(1, 3, 1, 1).to(device)

    def denorm(t):
        return (t * std_t + mean_t).clamp(0, 1).cpu()

    # Truncation levels, capped at D so no column is a duplicate of the full SVD.
    mode_counts = [m for m in (1, 4, 8, 16) if m < D] + [D]
    # Select up to 2 samples from up to 10 classes of the first batch.
    grid_classes = min(n_classes, 10)
    model.eval()
    with torch.no_grad():
        images, labels = next(iter(loader))
        images = images.to(device)
        out = model(images)
        selected_idx = []
        for c in range(grid_classes):
            class_idx = (labels == c).nonzero(as_tuple=True)[0]
            selected_idx.extend(class_idx[:2].tolist())
        if not selected_idx:
            selected_idx = list(range(min(20, len(images))))
        orig = images[selected_idx]
        U = out['svd']['U'][selected_idx]
        S = out['svd']['S'][selected_idx]
        Vt = out['svd']['Vt'][selected_idx]
        prog_recons = [
            model.decode_from_svd(U[:, :, :nm], S[:, :nm], Vt[:, :nm, :])
            for nm in mode_counts
        ]

    n_samples = len(selected_idx)
    n_cols = 2 + len(mode_counts)
    # squeeze=False keeps `axes` 2-D even when only one sample is selected.
    fig, axes = plt.subplots(n_samples, n_cols,
                             figsize=(n_cols * 1.5, n_samples * 1.5),
                             squeeze=False)
    col_titles = ['Original'] + [f'{m} modes' for m in mode_counts] + ['|Err|x5']
    err_col = 1 + len(prog_recons)
    for i in range(n_samples):
        orig_i = denorm(orig[i:i+1])
        axes[i, 0].imshow(orig_i[0].permute(1, 2, 0).numpy())
        for j, r in enumerate(prog_recons):
            axes[i, j + 1].imshow(denorm(r[i:i+1])[0].permute(1, 2, 0).numpy())
        diff = (orig_i - denorm(prog_recons[-1][i:i+1])).abs() * 5
        axes[i, err_col].imshow(diff.clamp(0, 1)[0].permute(1, 2, 0).numpy())
        c = labels[selected_idx[i]].item()
        name = class_names[c] if c < len(class_names) else f'{c}'
        axes[i, 0].set_ylabel(name, fontsize=7, rotation=0, labelpad=40)
    for j, title in enumerate(col_titles):
        axes[0, j].set_title(title, fontsize=8)
    # ax.axis('off') would also hide the class-name ylabels set above, so
    # only the ticks and spines are hidden.
    for ax in axes.flat:
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)
    plt.tight_layout()
    plt.savefig(path, dpi=200, bbox_inches='tight')
    print(f" Saved to {path}")
    try:
        plt.show()  # best effort: no-op/warning on the non-interactive Agg backend
    except Exception:
        pass
    plt.close(fig)
if __name__ == "__main__":
    # Default experiment: TinyImageNet, 256x24 latent matrix, CV target 0.45.
    train(
        epochs=200,
        V=256,
        D=24,
        target_cv=0.45,
        dataset='tiny_imagenet',
    )