# geolip-SVAE / prototype_v2.py
# Author: AbstractPhil
# Commit 4d07d66 (verified): "Create prototype_v2.py"
"""
SVAE v1 β€” Clean SVD Autoencoder
==================================
The version that works. Uses geolip-core's optimized SVD kernel.
Image β†’ encoder MLP β†’ RECTANGULAR matrix (H, k) β†’ SVD β†’ decode β†’ image
Rectangular matrix means SVD returns exactly k values β€” no truncation.
For k≀12, geolip-core uses the FL eigh kernel (fully compilable, zero graph breaks).
For k>12, Gram + torch.linalg.eigh (still better backward than linalg.svd).
pip install "git+https://github.com/AbstractEyes/geolip-core.git"
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as T
import math
# Optional accelerated SVD backend. HAS_GEOLIP selects the code path in
# SVAE.encode and in the training banner.
try:
    from geolip_core.linalg import svd as geolip_svd
except ImportError:
    HAS_GEOLIP = False
    print("geolip-core not found, using torch.svd_lowrank fallback")
    print('Install: pip install "git+https://github.com/AbstractEyes/geolip-core.git"')
else:
    # Success path kept out of the try body so only the import itself is guarded.
    HAS_GEOLIP = True
    print("Using geolip-core SVD (Gram + eigh, compilable for k≤12)")
# ── Data ──
def get_cifar10(batch_size=256, root='./data', num_workers=2):
    """Build normalized CIFAR-10 train/test DataLoaders, downloading if needed.

    Args:
        batch_size: Samples per batch for both loaders.
        root: Directory where torchvision stores / downloads the dataset.
        num_workers: Worker processes used by each DataLoader.

    Returns:
        (train_loader, test_loader). Only the training loader shuffles.
    """
    transform = T.Compose([
        T.ToTensor(),
        # Per-channel CIFAR-10 mean / std. Any denormalization elsewhere in
        # this file must use the same constants.
        T.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
    ])
    train_ds = torchvision.datasets.CIFAR10(root=root, train=True, download=True, transform=transform)
    test_ds = torchvision.datasets.CIFAR10(root=root, train=False, download=True, transform=transform)
    loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers)
    train_loader = torch.utils.data.DataLoader(train_ds, shuffle=True, **loader_kwargs)
    test_loader = torch.utils.data.DataLoader(test_ds, shuffle=False, **loader_kwargs)
    return train_loader, test_loader
# ── SVAE ──
class SVAE(nn.Module):
    """MLP autoencoder whose bottleneck is the SVD of a rectangular matrix.

    The encoder maps a flattened 3x32x32 image to a (matrix_h, keep_k) matrix M,
    which is factored as M = U diag(S) Vt. The decoder consumes the flattened
    reconstruction U diag(S) Vt (optionally built from only the leading modes).

    Args:
        matrix_h: Number of rows H of the bottleneck matrix.
        keep_k: Number of columns k of the bottleneck matrix. This is also the
            number of singular values; it must satisfy 1 <= keep_k <= matrix_h.

    Raises:
        ValueError: If keep_k is outside [1, matrix_h].
    """

    def __init__(self, matrix_h=64, keep_k=8):
        super().__init__()
        if not 1 <= keep_k <= matrix_h:
            # A thin SVD of an (H, k) matrix has min(H, k) singular values, so
            # k > H would silently break the "exactly k values" contract.
            raise ValueError(
                f"keep_k must be in [1, matrix_h={matrix_h}], got {keep_k}")
        self.matrix_h = matrix_h
        self.matrix_k = keep_k  # rectangular: (H, k) — SVD returns exactly k values
        self.keep_k = keep_k
        self.img_dim = 3 * 32 * 32
        self.mat_dim = matrix_h * keep_k
        self.encoder = nn.Sequential(
            nn.Linear(self.img_dim, 512),
            nn.GELU(),
            nn.Linear(512, 512),
            nn.GELU(),
            nn.Linear(512, self.mat_dim),
        )
        self.decoder = nn.Sequential(
            nn.Linear(self.mat_dim, 512),
            nn.GELU(),
            nn.Linear(512, 512),
            nn.GELU(),
            nn.Linear(512, self.img_dim),
        )

    def encode(self, images):
        """Encode images into a (B, H, k) matrix and factor it with an SVD.

        Args:
            images: Batch tensor (leading dim B) whose remaining dims flatten to
                3*32*32 values.

        Returns:
            dict with 'U', 'S', 'Vt' from the SVD backend, and the raw matrix 'M'
            of shape (B, matrix_h, matrix_k).
            NOTE(review): for the geolip backend the factor shapes are assumed
            to match a thin SVD (U: (B, H, k), S: (B, k), Vt: (B, k, k)).
        """
        B = images.shape[0]
        M = self.encoder(images.reshape(B, -1)).reshape(B, self.matrix_h, self.matrix_k)
        if HAS_GEOLIP:
            # geolip-core: Gram + eigh, compilable for k≤12
            U, S, Vh = geolip_svd(M)
        else:
            # Fallback: randomized SVD with q equal to M's full column count, so
            # no modes are dropped (exact up to numerical error).
            U, S, V = torch.svd_lowrank(M, q=self.keep_k)
            Vh = V.transpose(1, 2)
        return {
            'U': U, 'S': S, 'Vt': Vh,
            'M': M,
        }

    def decode_from_svd(self, U, S, Vt):
        """Rebuild M_hat = U diag(S) Vt and decode it to images (B, 3, 32, 32).

        Passing only the leading n columns of U, entries of S and rows of Vt
        gives a rank-n M_hat. Its shape stays (B, H, Vt.shape[-1]), so the
        decoder input size is unchanged.
        """
        B = U.shape[0]
        # Scaling U's columns by S applies diag(S) without materializing it.
        M_hat = torch.bmm(U * S.unsqueeze(1), Vt)
        return self.decoder(M_hat.reshape(B, -1)).reshape(B, 3, 32, 32)

    def forward(self, images):
        """Encode, factor and decode. Returns {'recon': images, 'svd': encode() dict}."""
        svd = self.encode(images)
        recon = self.decode_from_svd(svd['U'], svd['S'], svd['Vt'])
        return {'recon': recon, 'svd': svd}

    @staticmethod
    def effective_rank(S):
        """Return exp(Shannon entropy) of p = S / sum(S) along the last dim.

        The result is k for k equal singular values and about 1 when a single
        value dominates.
        """
        p = S / (S.sum(-1, keepdim=True) + 1e-8)
        p = p.clamp(min=1e-8)  # avoid log(0) for zero singular values
        return (-(p * p.log()).sum(-1)).exp()
# ── Training ──
_CIFAR10_CLASSES = ('plane', 'car', 'bird', 'cat', 'deer',
                    'dog', 'frog', 'horse', 'ship', 'truck')


def _evaluate(model, loader, device):
    """Return (sample-weighted mean MSE, batch-averaged mean singular values) over loader."""
    model.eval()
    sq_err_sum, count = 0.0, 0
    s_sum, n_batches = None, 0
    with torch.no_grad():
        for images, _ in loader:
            images = images.to(device)
            out = model(images)
            sq_err_sum += F.mse_loss(out['recon'], images).item() * len(images)
            count += len(images)
            batch_s = out['svd']['S'].mean(0).cpu()
            s_sum = batch_s if s_sum is None else s_sum + batch_s
            n_batches += 1
    if n_batches == 0:
        raise ValueError("evaluation loader yielded no batches")
    return sq_err_sum / count, s_sum / n_batches


def _print_final_analysis(model, loader, device, keep_k):
    """Print effective rank, reconstruction error, singular-value energy profile and per-class stats."""
    model.eval()
    all_S, all_recon_err, all_labels = [], [], []
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            out = model(images)
            all_S.append(out['svd']['S'].cpu())
            # Per-image MSE: element-wise error averaged over (C, H, W).
            all_recon_err.append(
                F.mse_loss(out['recon'], images, reduction='none')
                .mean(dim=(1, 2, 3)).cpu())
            all_labels.append(labels.cpu())
    all_S = torch.cat(all_S)
    all_recon_err = torch.cat(all_recon_err)
    all_labels = torch.cat(all_labels)
    erank = model.effective_rank(all_S)

    print(f"\n Bottleneck: {keep_k} singular values (rectangular SVD, no truncation)")
    print(f" Effective rank: {erank.mean():.2f} ± {erank.std():.2f}")
    print(f" Recon MSE: {all_recon_err.mean():.6f} ± {all_recon_err.std():.6f}")

    # Singular value profile: cumulative share of sum(S_i^2) carried by the top modes.
    S_mean = all_S.mean(0)
    energy = S_mean ** 2
    cum_pct = (energy.cumsum(0) / energy.sum() * 100).tolist()
    s_top = S_mean[0].item()
    print("\n Singular value profile:")
    for i, (s, pct) in enumerate(zip(S_mean.tolist(), cum_pct)):
        bar = "█" * int(s * 30 / (s_top + 1e-8))
        print(f" S[{i:2d}]: {s:8.3f} cum_energy={pct:5.1f}% {bar}")

    print("\n Per-class:")
    print(f" {'class':>6} {'recon':>8} {'S0':>7} {'S1':>7} {'erank':>6}")
    for c, name in enumerate(_CIFAR10_CLASSES):
        mask = all_labels == c
        rc = all_recon_err[mask].mean().item()
        s0 = all_S[mask, 0].mean().item()
        s1 = all_S[mask, 1].mean().item()
        er = erank[mask].mean().item()
        print(f" {name:>6} {rc:8.6f} {s0:7.3f} {s1:7.3f} {er:6.2f}")


def _save_recon_grid(model, loader, device, keep_k, save_path):
    """Save a PNG grid of originals, progressive k-mode reconstructions and amplified error maps."""
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt

    # CIFAR-10 per-channel stats; must match the Normalize transform in get_cifar10.
    mean = torch.tensor([0.4914, 0.4822, 0.4465], device=device).reshape(1, 3, 1, 1)
    std = torch.tensor([0.2470, 0.2435, 0.2616], device=device).reshape(1, 3, 1, 1)

    def denorm(t):
        return (t * std + mean).clamp(0, 1).cpu()

    def to_hwc(t):
        # (1, 3, H, W) -> (H, W, 3) numpy array for imshow.
        return t[0].permute(1, 2, 0).numpy()

    model.eval()
    with torch.no_grad():
        images, labels = next(iter(loader))
        images = images.to(device)
        out = model(images)
        # Up to 2 samples per class from the first batch (at most 20 images).
        selected_idx = []
        for c in range(len(_CIFAR10_CLASSES)):
            class_idx = (labels == c).nonzero(as_tuple=True)[0]
            selected_idx.extend(class_idx[:2].tolist())
        if not selected_idx:
            print(" No samples available; skipping reconstruction grid.")
            return
        orig = images[selected_idx]
        U = out['svd']['U'][selected_idx]
        S = out['svd']['S'][selected_idx]
        Vt = out['svd']['Vt'][selected_idx]
        # Progressive reconstructions with 1, 4, 8 and all modes. Clamp BEFORE
        # de-duplicating so the column titles match the modes actually used.
        n_avail = S.shape[1]
        mode_counts = list(dict.fromkeys(min(m, n_avail) for m in (1, 4, 8, keep_k)))
        prog_recons = [
            model.decode_from_svd(U[:, :, :m], S[:, :m], Vt[:, :m, :])
            for m in mode_counts
        ]

    n_samples = len(selected_idx)
    n_cols = 2 + len(mode_counts)  # original + progressives + error
    err_col = n_cols - 1
    # squeeze=False keeps `axes` 2-D even when only one sample row is plotted.
    fig, axes = plt.subplots(n_samples, n_cols,
                             figsize=(n_cols * 1.5, n_samples * 1.5), squeeze=False)
    col_titles = (['Original']
                  + [f'{m} mode{"s" if m > 1 else ""}' for m in mode_counts]
                  + ['|Error|×5'])
    for i in range(n_samples):
        orig_i = denorm(orig[i:i + 1])
        axes[i, 0].imshow(to_hwc(orig_i))
        for j, r in enumerate(prog_recons, start=1):
            axes[i, j].imshow(to_hwc(denorm(r[i:i + 1])))
        # Error map against the full reconstruction (last column), amplified 5x.
        diff = ((orig_i - denorm(prog_recons[-1][i:i + 1])).abs() * 5).clamp(0, 1)
        axes[i, err_col].imshow(to_hwc(diff))
        c = labels[selected_idx[i]].item()
        axes[i, 0].set_ylabel(_CIFAR10_CLASSES[c], fontsize=8, rotation=0, labelpad=35)
    for j, title in enumerate(col_titles):
        axes[0, j].set_title(title, fontsize=8)
    # Hide ticks and frames but keep the axes on. ax.axis('off') would also
    # hide the class ylabels set above.
    for ax in axes.flat:
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)
    plt.tight_layout()
    plt.savefig(save_path, dpi=200, bbox_inches='tight')
    print(f" Saved to {save_path}")
    try:
        plt.show()
    except Exception:
        pass  # best effort: the non-interactive Agg backend may not display anything
    plt.close(fig)


def train(epochs=50, lr=1e-3, keep_k=16, device='cuda', save_path=None):
    """Train an SVAE on CIFAR-10, print metrics, and save a reconstruction grid.

    Args:
        epochs: Number of training epochs (also the cosine schedule length).
        lr: Initial Adam learning rate.
        keep_k: Columns k of the (64, k) bottleneck matrix, i.e. number of singular values.
        device: Requested device. Falls back to 'cpu' when CUDA is unavailable.
        save_path: Output PNG path. Defaults to '/content/svae_recon_grid.png'
            when '/content' exists (Colab), else 'svae_recon_grid.png' in the
            working directory.

    Returns:
        The trained SVAE model.
    """
    import os  # local: only used to choose the default save path

    device = torch.device(device if torch.cuda.is_available() else 'cpu')
    train_loader, test_loader = get_cifar10(batch_size=256)
    model = SVAE(matrix_h=64, keep_k=keep_k).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
    total_params = sum(p.numel() for p in model.parameters())

    print("SVAE v1 — Clean SVD Autoencoder")
    print(f" Matrix: ({model.matrix_h}, {model.matrix_k}) → rectangular, exact {keep_k} singular values")
    print(f" SVD: {'geolip-core Gram+eigh' if HAS_GEOLIP else 'torch.svd_lowrank'}"
          f"{' (FL compilable)' if HAS_GEOLIP and keep_k <= 12 else ''}")
    # NOTE(review): the decoder actually receives the full (H, k) matrix
    # U diag(S) Vt, so this ratio counts only the singular values.
    print(f" Compression: {model.img_dim} → {keep_k} ({model.img_dim // keep_k}:1)")
    print(f" Params: {total_params:,}")
    print(f" Device: {device}")
    print("=" * 70)
    print(f"{'ep':>3} | {'loss':>7} {'recon':>7} | "
          f"{'t_recon':>7} | "
          f"{'S0':>7} {'S1':>7} {'Sk-1':>7}")
    print("-" * 70)

    for epoch in range(1, epochs + 1):
        model.train()
        total_loss, total_recon, n = 0.0, 0.0, 0
        for images, _ in train_loader:
            images = images.to(device)
            opt.zero_grad()
            out = model(images)
            recon_loss = F.mse_loss(out['recon'], images)
            loss = recon_loss  # reconstruction is currently the only objective
            loss.backward()
            opt.step()
            batch = len(images)
            total_loss += loss.item() * batch
            total_recon += recon_loss.item() * batch
            n += batch
        sched.step()

        # Evaluate on epochs 1-3, then on every even epoch.
        if epoch % 2 == 0 or epoch <= 3:
            test_recon, test_S = _evaluate(model, test_loader, device)
            print(f"{epoch:3d} | {total_loss/n:7.4f} {total_recon/n:7.4f} | "
                  f"{test_recon:7.4f} | "
                  f"{test_S[0]:7.3f} {test_S[1]:7.3f} {test_S[-1]:7.3f}")

    # ── Final Analysis ──
    print()
    print("=" * 85)
    print("FINAL ANALYSIS")
    print("=" * 85)
    _print_final_analysis(model, test_loader, device, keep_k)

    # ── Save reconstruction grid ──
    print("\n Saving reconstruction grid...")
    if save_path is None:
        save_path = ('/content/svae_recon_grid.png' if os.path.isdir('/content')
                     else 'svae_recon_grid.png')
    _save_recon_grid(model, test_loader, device, keep_k, save_path)
    return model
if __name__ == "__main__":
train(keep_k=16)