geolip-SVAE / prototype_v12_128x128_patch16.py
AbstractPhil's picture
Create prototype_v12_128x128_patch16.py
cbf8c7e verified
raw
history blame
34.8 kB
"""
SVAE-Patch β€” Patch-based SVD Autoencoder
==========================================
64Γ—64 image β†’ 4 patches of 32Γ—32 β†’ independent SVD per patch β†’ coordinate β†’ decode.
Each patch gets the proven (V, D) pipeline:
patch β†’ MLP β†’ M ∈ ℝ^(VΓ—D) β†’ normalize β†’ SVD β†’ (U, S, Vt)
Cross-patch coordination via a lightweight attention on the spectral
representations β€” each patch's S vector (D-dim) attends to all others.
This lets patches share information about relative spectral structure
without disrupting the per-patch geometric attractors.
Reconstruction: coordinated spectra + per-patch (U, Vt) β†’ MLP β†’ patch β†’ stitch.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as T
import math
import time
# ── SVD Backend ──────────────────────────────────────────────────
try:
from geolip_core.linalg.eigh import FLEigh, _FL_MAX_N
HAS_FL = True
except ImportError:
HAS_FL = False
def gram_eigh_svd_fp64(A):
    """Compute a batched thin SVD through the Gram matrix, in float64.

    The input is promoted to float64, ``A^T A`` is eigendecomposed with
    ``torch.linalg.eigh`` and the factors are rebuilt from the eigenpairs.
    Autocast is disabled for the duration so the float64 arithmetic is not
    downcast.

    Args:
        A: Batched matrices. ``torch.bmm`` is applied directly, so a 3-D
            tensor ``(batch, rows, cols)`` is required.

    Returns:
        Tuple ``(U, S, Vh)`` cast back to ``A.dtype``, with the singular
        values in descending order.
    """
    in_dtype = A.dtype
    with torch.amp.autocast('cuda', enabled=False):
        A64 = A.double()
        gram = torch.bmm(A64.transpose(1, 2), A64)
        evals, evecs = torch.linalg.eigh(gram)
        # eigh yields ascending eigenvalues; reverse both to get the
        # conventional descending singular-value ordering.
        evals = torch.flip(evals, dims=(-1,))
        evecs = torch.flip(evecs, dims=(-1,))
        # Clamp before the sqrt so tiny negative eigenvalues from rounding
        # cannot produce NaNs.
        sing = evals.clamp(min=1e-24).sqrt()
        left = torch.bmm(A64, evecs) / sing.unsqueeze(1).clamp(min=1e-16)
        right_h = evecs.transpose(-2, -1).contiguous()
    return left.to(in_dtype), sing.to(in_dtype), right_h.to(in_dtype)
def svd_fp64(A):
    """Dispatch a batched thin SVD whose core arithmetic runs in float64.

    The optional ``FLEigh`` kernel is used when it was importable, the
    column count is at most ``_FL_MAX_N`` and ``A`` lives on a CUDA device.
    In that path the float64 Gram matrix is handed to ``FLEigh`` as float32
    and the eigenpairs are promoted back to float64. Every other case is
    delegated to ``gram_eigh_svd_fp64``.

    Args:
        A: 3-D tensor ``(batch, rows, cols)``.

    Returns:
        Tuple ``(U, S, Vh)`` in the dtype of ``A``.
    """
    _, _, n_cols = A.shape  # unpacking also rejects inputs that are not 3-D
    # Short-circuit order matters: _FL_MAX_N only exists when HAS_FL is true.
    if not (HAS_FL and n_cols <= _FL_MAX_N and A.is_cuda):
        return gram_eigh_svd_fp64(A)

    in_dtype = A.dtype
    with torch.amp.autocast('cuda', enabled=False):
        A64 = A.double()
        gram = torch.bmm(A64.transpose(1, 2), A64)
        evals, evecs = FLEigh()(gram.float())
        # The results are reversed exactly like the torch.linalg.eigh path;
        # presumably FLEigh also returns ascending eigenvalues (confirm
        # against geolip_core).
        evals = torch.flip(evals.double(), dims=(-1,))
        evecs = torch.flip(evecs.double(), dims=(-1,))
        sing = evals.clamp(min=1e-24).sqrt()
        left = torch.bmm(A64, evecs) / sing.unsqueeze(1).clamp(min=1e-16)
        right_h = evecs.transpose(-2, -1).contiguous()
    return left.to(in_dtype), sing.to(in_dtype), right_h.to(in_dtype)
# ── CV Monitoring ────────────────────────────────────────────────
def cayley_menger_vol2(points):
    """Return squared simplex volumes from the Cayley-Menger determinant.

    Args:
        points: Tensor ``(B, N, D)`` holding ``B`` simplices, each with
            ``N`` vertices in ``D`` dimensions.

    Returns:
        float64 tensor of shape ``(B,)`` with the squared volume of each
        ``(N - 1)``-simplex.
    """
    batch, n_pts, _ = points.shape
    pts64 = points.double()

    # Pairwise squared distances from the Gram matrix:
    # |p_i - p_j|^2 = |p_i|^2 + |p_j|^2 - 2 <p_i, p_j>.
    gram = torch.bmm(pts64, pts64.transpose(1, 2))
    sq_norm = torch.diagonal(gram, dim1=1, dim2=2)
    sq_dist = torch.relu(sq_norm.unsqueeze(2) + sq_norm.unsqueeze(1) - 2 * gram)

    # Bordered matrix: zero corner, ones along the first row and column,
    # squared distances in the remaining block.
    size = n_pts + 1
    cm = torch.ones(batch, size, size, device=points.device, dtype=torch.float64)
    cm[:, 0, 0] = 0.0
    cm[:, 1:, 1:] = sq_dist

    simplex_dim = n_pts - 1
    sign = 1.0 if (simplex_dim + 1) % 2 == 0 else -1.0
    denom = (2 ** simplex_dim) * (math.factorial(simplex_dim) ** 2)
    return sign * torch.linalg.det(cm) / denom
def cv_of(emb, n_samples=200):
    """Estimate the coefficient of variation of random 5-point simplex volumes.

    ``n_samples`` random 5-row subsets are drawn from the first
    ``min(rows, 512)`` rows of ``emb``. Their volumes come from
    ``cayley_menger_vol2`` and the result is ``std / mean`` over the
    non-degenerate ones.

    Args:
        emb: 2-D tensor ``(rows, dim)``.
        n_samples: Number of random 5-point subsets to draw.

    Returns:
        The CV as a Python float. ``0.0`` is returned when ``emb`` is not
        2-D, has fewer than 5 rows, or yields fewer than 10 subsets with a
        squared volume above ``1e-20``.
    """
    if emb.dim() != 2 or emb.shape[0] < 5:
        return 0.0

    pool = min(emb.shape[0], 512)
    draws = []
    for _ in range(n_samples):
        draws.append(torch.randperm(pool, device=emb.device)[:5])
    simplices = emb[:pool][torch.stack(draws)]

    vol2 = cayley_menger_vol2(simplices)
    usable = vol2[vol2 > 1e-20]
    if usable.numel() < 10:
        return 0.0

    vols = usable.sqrt()
    return (vols.std() / (vols.mean() + 1e-8)).item()
# ── Data ─────────────────────────────────────────────────────────
class HFImageDataset(torch.utils.data.Dataset):
    """Map-style dataset over a HuggingFace split of ``image``/``label`` rows.

    Defined once at module level instead of inside each loader factory.
    DataLoader workers started with the ``spawn`` method must pickle the
    dataset, and function-local classes cannot be pickled.

    Args:
        hf_split: Indexable split whose items expose ``'image'`` (an object
            with ``.mode`` and ``.convert``) and ``'label'``.
        transform: Callable applied to the RGB image.
    """

    def __init__(self, hf_split, transform):
        self.data = hf_split
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        img = item['image']
        # Force three channels for any image that is not already RGB.
        if img.mode != 'RGB':
            img = img.convert('RGB')
        return self.transform(img), item['label']


def get_tiny_imagenet(batch_size=256):
    """Build TinyImageNet loaders from the HuggingFace hub (200 classes, 64x64).

    Args:
        batch_size: Batch size for both loaders.

    Returns:
        ``(train_loader, val_loader)``; only the training loader shuffles.
    """
    from datasets import load_dataset
    ds = load_dataset('zh-plus/tiny-imagenet')
    transform = T.Compose([
        T.ToTensor(),
        T.Normalize((0.4802, 0.4481, 0.3975), (0.2770, 0.2691, 0.2821)),
    ])
    train_ds = HFImageDataset(ds['train'], transform)
    val_ds = HFImageDataset(ds['valid'], transform)
    train_loader = torch.utils.data.DataLoader(
        train_ds, batch_size=batch_size, shuffle=True, num_workers=2)
    val_loader = torch.utils.data.DataLoader(
        val_ds, batch_size=batch_size, shuffle=False, num_workers=2)
    return train_loader, val_loader


def get_cifar10(batch_size=256):
    """Build CIFAR-10 loaders via torchvision, downloading to ``./data``.

    Args:
        batch_size: Batch size for both loaders.

    Returns:
        ``(train_loader, test_loader)``; only the training loader shuffles.
    """
    transform = T.Compose([
        T.ToTensor(),
        T.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
    ])
    train_ds = torchvision.datasets.CIFAR10(
        root='./data', train=True, download=True, transform=transform)
    test_ds = torchvision.datasets.CIFAR10(
        root='./data', train=False, download=True, transform=transform)
    train_loader = torch.utils.data.DataLoader(
        train_ds, batch_size=batch_size, shuffle=True, num_workers=2)
    test_loader = torch.utils.data.DataLoader(
        test_ds, batch_size=batch_size, shuffle=False, num_workers=2)
    return train_loader, test_loader


def get_imagenet_128(batch_size=128):
    """Build ImageNet-1K 128×128 loaders from the HuggingFace hub.

    1000 classes, 1.28M train, 50K val.
    Requires: pip install datasets
    Requires: HF auth + ImageNet terms accepted on huggingface.co

    Args:
        batch_size: Batch size for both loaders.

    Returns:
        ``(train_loader, val_loader)``; only the training loader shuffles.
    """
    from datasets import load_dataset
    ds = load_dataset('benjamin-paine/imagenet-1k-128x128')
    transform = T.Compose([
        T.ToTensor(),
        T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),  # ImageNet stats
    ])
    train_ds = HFImageDataset(ds['train'], transform)
    val_ds = HFImageDataset(ds['validation'], transform)
    train_loader = torch.utils.data.DataLoader(
        train_ds, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
    val_loader = torch.utils.data.DataLoader(
        val_ds, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
    return train_loader, val_loader
# ── Patch Utilities ──────────────────────────────────────────────
def extract_patches(images, patch_size=32):
    """Split a batch of images into flattened, non-overlapping square patches.

    Args:
        images: Tensor of shape ``(B, C, H, W)``.
        patch_size: Side length of each square patch; must divide both
            ``H`` and ``W``.

    Returns:
        Tuple ``(patches, gh, gw)``. ``patches`` has shape
        ``(B, gh * gw, C * patch_size * patch_size)``: patches are listed in
        row-major grid order and each one is flattened as ``(C, ph, pw)``.
        ``gh`` and ``gw`` are the grid dimensions needed by
        ``stitch_patches``.

    Raises:
        ValueError: If ``H`` or ``W`` is not a multiple of ``patch_size``.
    """
    B, C, H, W = images.shape
    ph = pw = patch_size
    # Fail with a readable message instead of an opaque reshape error.
    if H % ph != 0 or W % pw != 0:
        raise ValueError(
            f"image size {H}x{W} is not divisible by patch_size={patch_size}")
    gh, gw = H // ph, W // pw
    # (B, C, gh, ph, gw, pw) -> (B, gh, gw, C, ph, pw) -> (B, gh*gw, C*ph*pw)
    patches = images.reshape(B, C, gh, ph, gw, pw)
    patches = patches.permute(0, 2, 4, 1, 3, 5)
    patches = patches.reshape(B, gh * gw, C * ph * pw)
    return patches, gh, gw
def stitch_patches(patches, gh, gw, patch_size=32):
    """Reassemble flattened patches into full images.

    This is the inverse of ``extract_patches``. The channel count is
    inferred from the number of elements rather than fixed at 3, so any
    channel count round-trips.

    Args:
        patches: Tensor ``(B, gh * gw, C * patch_size * patch_size)`` with
            patches in row-major grid order, each flattened as
            ``(C, ph, pw)``.
        gh: Number of patch rows in the grid.
        gw: Number of patch columns in the grid.
        patch_size: Side length of each square patch.

    Returns:
        Tensor of shape ``(B, C, gh * patch_size, gw * patch_size)``.
    """
    B = patches.shape[0]
    ph = pw = patch_size
    per_channel = B * gh * gw * ph * pw
    # Empty batches carry no elements to infer from; keep the RGB default.
    C = patches.numel() // per_channel if per_channel else 3
    grid = patches.reshape(B, gh, gw, C, ph, pw)
    grid = grid.permute(0, 3, 1, 4, 2, 5)  # (B, C, gh, ph, gw, pw)
    return grid.reshape(B, C, gh * ph, gw * pw)
class BoundarySmooth(nn.Module):
    """Residual two-layer 3×3 convolutional refiner for the stitched image.

    The stack is ``conv3x3 -> GELU -> conv3x3`` and its output is added back
    to the input. Two stacked 3×3 kernels give a 5×5 receptive field, enough
    to look across a patch seam. The final convolution is zero-initialised,
    so a freshly built module is an exact identity. With the defaults
    (``channels=3, mid=16``) the module holds 883 parameters.

    Args:
        channels: Number of input and output channels.
        mid: Number of hidden channels between the two convolutions.
    """

    def __init__(self, channels=3, mid=16):
        super().__init__()
        expand = nn.Conv2d(channels, mid, kernel_size=3, padding=1)
        project = nn.Conv2d(mid, channels, kernel_size=3, padding=1)
        # Zero the last layer so the residual branch contributes nothing
        # until training moves it.
        for param in (project.weight, project.bias):
            nn.init.zeros_(param)
        self.net = nn.Sequential(expand, nn.GELU(), project)

    def forward(self, x):
        """Return ``x`` plus the learned correction; the shape is unchanged."""
        correction = self.net(x)
        return x + correction
# ── Spectral Cross-Attention ────────────────────────────────────
class SpectralCrossAttention(nn.Module):
    """Multiplicative cross-patch coordination of singular-value vectors.

    Each patch's spectrum attends to the spectra of all patches of the same
    image through multi-head self-attention. The attention output is squashed
    with ``tanh`` and used as a bounded multiplicative gate::

        S_out = S * (1 + alpha_d * tanh(attention_output_d))

    Every spectral mode ``d`` owns a learnable strength
    ``alpha_d = max_alpha * sigmoid(alpha_logit_d)``, which lies in
    ``[0, max_alpha]``. With the default ``alpha_init=-2.0`` (sigmoid ≈ 0.12)
    and ``max_alpha=0.2`` each mode starts near 0.024. After training the
    alpha vector can be read as a diagnostic of which modes carry
    inter-patch structure.

    Args:
        D: Length of each spectrum; must be divisible by ``n_heads``.
        n_heads: Number of attention heads.
        max_alpha: Upper bound on the per-mode coordination strength.
        alpha_init: Initial value of every alpha logit.
    """

    def __init__(self, D, n_heads=4, max_alpha=0.2, alpha_init=-2.0):
        super().__init__()
        assert D % n_heads == 0, f"D={D} must be divisible by n_heads={n_heads}"
        self.n_heads = n_heads
        self.head_dim = D // n_heads
        self.max_alpha = max_alpha
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(D, 3 * D)
        self.out_proj = nn.Linear(D, D)
        self.norm = nn.LayerNorm(D)
        self.alpha_logits = nn.Parameter(torch.full((D,), alpha_init))

    @property
    def alpha(self):
        """Per-mode coordination strengths, shape ``(D,)``, in ``[0, max_alpha]``."""
        return self.max_alpha * torch.sigmoid(self.alpha_logits)

    def forward(self, S):
        """Coordinate spectra across patches.

        Args:
            S: Tensor ``(B, n_patches, D)``.

        Returns:
            Tensor of the same shape with each entry rescaled by a factor in
            ``[1 - alpha_d, 1 + alpha_d]``.
        """
        batch, n_tok, dim = S.shape
        packed = self.qkv(self.norm(S))
        packed = packed.reshape(batch, n_tok, 3, self.n_heads, self.head_dim)
        # (3, B, heads, N, head_dim) split into query, key and value.
        q, k, v = packed.permute(2, 0, 3, 1, 4).unbind(0)

        weights = torch.softmax((q @ k.transpose(-2, -1)) * self.scale, dim=-1)
        mixed = (weights @ v).transpose(1, 2).reshape(batch, n_tok, dim)
        gate = torch.tanh(self.out_proj(mixed))

        # alpha has shape (D,) and broadcasts over batch and patch axes.
        return S * (1.0 + self.alpha * gate)
# ── Patch SVAE ───────────────────────────────────────────────────
class PatchSVAE(nn.Module):
    """Patch-based SVD autoencoder.

    Pipeline: image → patches → shared per-patch encoder → row-normalised
    ``(matrix_v, D)`` matrix → SVD → cross-patch attention on the singular
    values → shared per-patch decoder → stitch → boundary smoothing.

    Only the singular values are exchanged between patches; ``U`` and ``Vt``
    of each patch are passed to the decoder untouched.

    Args:
        matrix_v: Rows per patch matrix.
        D: Columns per patch matrix (embedding dimension, spectrum length).
        patch_size: Side length of the square patches (32 gives 4 patches
            for a 64×64 image).
        hidden: Width of the encoder/decoder MLPs.
        depth: Number of residual blocks in the encoder and in the decoder.
        n_cross_layers: Number of spectral cross-attention layers.
    """

    def __init__(self, matrix_v=256, D=24, patch_size=32, hidden=512,
                 depth=2, n_cross_layers=2):
        super().__init__()
        self.matrix_v = matrix_v
        self.D = D
        self.patch_size = patch_size
        self.patch_dim = 3 * patch_size * patch_size  # flattened RGB patch
        self.mat_dim = matrix_v * D
        self.n_cross_layers = n_cross_layers

        def residual_stack():
            # Pre-norm two-layer MLP blocks; the skip connection is applied
            # by the caller.
            return nn.ModuleList([
                nn.Sequential(
                    nn.LayerNorm(hidden),
                    nn.Linear(hidden, hidden),
                    nn.GELU(),
                    nn.Linear(hidden, hidden),
                )
                for _ in range(depth)
            ])

        # Encoder shared by all patches: in-projection, residual blocks,
        # projection to the flattened (matrix_v, D) matrix.
        self.enc_in = nn.Linear(self.patch_dim, hidden)
        self.enc_blocks = residual_stack()
        self.enc_out = nn.Linear(hidden, self.mat_dim)

        # Decoder shared by all patches, mirroring the encoder.
        self.dec_in = nn.Linear(self.mat_dim, hidden)
        self.dec_blocks = residual_stack()
        self.dec_out = nn.Linear(hidden, self.patch_dim)

        nn.init.orthogonal_(self.enc_out.weight)

        # Cross-patch coordination acting on singular values only.
        self.cross_attn = nn.ModuleList([
            SpectralCrossAttention(D, n_heads=min(4, D))
            for _ in range(n_cross_layers)
        ])

        # Seam refinement applied after stitching.
        self.boundary_smooth = BoundarySmooth(channels=3, mid=16)

    def encode_patches(self, patches):
        """Encode every patch and coordinate the resulting spectra.

        Args:
            patches: Tensor ``(B, n_patches, patch_dim)``.

        Returns:
            Dict with ``'U'`` ``(B, N, V, D)``, ``'S_orig'`` ``(B, N, D)``
            (raw singular values), ``'S'`` ``(B, N, D)`` (after
            cross-attention), ``'Vt'`` ``(B, N, D, D)`` and ``'M'``
            ``(B, N, V, D)`` (the row-normalised encoder matrices).
        """
        batch, n_patches, _ = patches.shape
        rows = batch * n_patches

        # Fold batch and patch axes together so one encoder serves all patches.
        hid = F.gelu(self.enc_in(patches.reshape(rows, -1)))
        for blk in self.enc_blocks:
            hid = hid + blk(hid)

        mat = self.enc_out(hid).reshape(rows, self.matrix_v, self.D)
        mat = F.normalize(mat, dim=-1)  # every row onto the unit sphere S^(D-1)
        left, sing, right_t = svd_fp64(mat)

        lead = (batch, n_patches)
        spectra = sing.reshape(*lead, self.D)
        coordinated = spectra
        for attn in self.cross_attn:
            coordinated = attn(coordinated)

        return {
            'U': left.reshape(*lead, self.matrix_v, self.D),
            'S_orig': spectra,
            'S': coordinated,
            'Vt': right_t.reshape(*lead, self.D, self.D),
            'M': mat.reshape(*lead, self.matrix_v, self.D),
        }

    def decode_patches(self, U, S, Vt):
        """Rebuild each patch matrix from its SVD factors and decode it.

        Args:
            U: Tensor ``(B, N, V, D)``.
            S: Tensor ``(B, N, D)`` of (coordinated) singular values.
            Vt: Tensor ``(B, N, D, D)``.

        Returns:
            Tensor ``(B, N, patch_dim)`` of flattened patches.
        """
        batch, n_patches, n_rows, dim = U.shape
        rows = batch * n_patches

        # M_hat = U diag(S) Vt, done as a column scaling followed by bmm.
        scaled_left = U.reshape(rows, n_rows, dim) * S.reshape(rows, dim).unsqueeze(1)
        rebuilt = torch.bmm(scaled_left, Vt.reshape(rows, dim, dim))

        hid = F.gelu(self.dec_in(rebuilt.reshape(rows, -1)))
        for blk in self.dec_blocks:
            hid = hid + blk(hid)
        return self.dec_out(hid).reshape(batch, n_patches, -1)

    def forward(self, images):
        """Autoencode a batch of images.

        Args:
            images: Tensor ``(B, 3, H, W)`` with ``H`` and ``W`` multiples of
                ``patch_size``.

        Returns:
            Dict with ``'recon'`` (reconstruction, same shape as the input),
            ``'svd'`` (the dict from ``encode_patches``) and the patch grid
            size ``'gh'``, ``'gw'``.
        """
        patches, gh, gw = extract_patches(images, self.patch_size)
        svd = self.encode_patches(patches)
        decoded = self.decode_patches(svd['U'], svd['S'], svd['Vt'])
        stitched = stitch_patches(decoded, gh, gw, self.patch_size)
        return {
            'recon': self.boundary_smooth(stitched),
            'svd': svd,
            'gh': gh,
            'gw': gw,
        }

    @staticmethod
    def effective_rank(S):
        """Return the Shannon-entropy effective rank along the last axis.

        Args:
            S: Tensor ``(*, D)`` of non-negative singular values.

        Returns:
            Tensor ``(*)`` equal to ``exp(H(p))`` where ``p`` is ``S``
            normalised to sum to one (floored at ``1e-8``).
        """
        total = S.sum(-1, keepdim=True) + 1e-8
        probs = (S / total).clamp(min=1e-8)
        entropy = -(probs * probs.log()).sum(-1)
        return entropy.exp()
# ── Training ─────────────────────────────────────────────────────
def train(epochs=200, lr=1e-3, V=256, D=24, patch_size=32,
          hidden=512, depth=2,
          target_cv=0.125, cv_weight=0.3, boost=0.5, sigma=0.15,
          n_cross_layers=2, dataset='tiny_imagenet', device='cuda',
          save_dir='/content/checkpoints', save_every=50,
          hf_repo='AbstractPhil/geolip-SVAE', hf_version='v11_prod',
          tb_dir='/content/runs'):
    """Train a PatchSVAE with sphere normalization and the "soft hand" weighting.

    Logs to TensorBoard, writes checkpoints to ``save_dir`` and, when
    HuggingFace authentication succeeds, mirrors checkpoints and logs to
    ``hf_repo``. After training, a per-image error summary, the learned
    alpha profile, the singular-value profile and a reconstruction grid
    image are produced.

    Args:
        epochs: Number of training epochs; must be at least 1.
        lr: Adam learning rate (cosine-annealed over ``epochs``).
        V: Rows per patch matrix (``matrix_v`` of the model).
        D: Columns per patch matrix / spectrum length.
        patch_size: Side length of the square patches.
        hidden: Width of the encoder/decoder MLPs.
        depth: Residual blocks per encoder/decoder.
        target_cv: CV value the soft hand steers towards.
        cv_weight: Scale of the CV penalty when far from ``target_cv``.
        boost: Extra reconstruction weight applied when the measured CV is
            close to ``target_cv`` (weight = ``1 + boost * proximity``).
        sigma: Width of the Gaussian proximity
            ``exp(-(cv - target_cv)^2 / (2 sigma^2))``.
        n_cross_layers: Number of spectral cross-attention layers.
        dataset: ``'imagenet_128'`` or ``'tiny_imagenet'``; any other value
            selects CIFAR-10.
        device: Requested device; falls back to CPU when CUDA is unavailable.
        save_dir: Directory for checkpoints (created if missing).
        save_every: Period, in epochs, of the numbered checkpoints and
            uploads; must be non-zero.
        hf_repo: HuggingFace model repo that receives uploads.
        hf_version: Path prefix inside ``hf_repo``.
        tb_dir: Root directory for TensorBoard runs.

    Returns:
        None.

    Raises:
        ValueError: If ``epochs < 1`` or ``save_every == 0``.
    """
    import os

    # Fail before any data loading: the final analysis needs at least one
    # evaluated epoch, and ``epoch % save_every`` must be well defined.
    if epochs < 1:
        raise ValueError(f"epochs must be >= 1, got {epochs}")
    if save_every == 0:
        raise ValueError("save_every must be non-zero")

    os.makedirs(save_dir, exist_ok=True)

    # --- TensorBoard ---
    from torch.utils.tensorboard import SummaryWriter
    run_name = f"patchsvae_V{V}_D{D}_p{patch_size}_h{hidden}_d{depth}"
    tb_path = os.path.join(tb_dir, run_name)
    writer = SummaryWriter(tb_path)
    print(f" TensorBoard: {tb_path}")

    # --- HuggingFace (best effort: any failure just disables uploads) ---
    hf_enabled = False
    try:
        from huggingface_hub import HfApi
        api = HfApi()
        # Test auth
        api.whoami()
        hf_enabled = True
        hf_prefix = f"{hf_version}/checkpoints"
        print(f" HuggingFace: {hf_repo}/{hf_prefix}")
    except Exception as e:
        print(f" HuggingFace: disabled ({e})")

    def upload_to_hf(local_path, remote_name):
        """Upload one file to the HF repo; failures are reported, never raised."""
        if not hf_enabled:
            return
        try:
            remote_path = f"{hf_prefix}/{remote_name}"
            api.upload_file(
                path_or_fileobj=local_path,
                path_in_repo=remote_path,
                repo_id=hf_repo,
                repo_type="model",
            )
            print(f" ☁️ Uploaded: {hf_repo}/{remote_path}")
        except Exception as e:
            print(f" ⚠️ HF upload failed: {e}")

    device = torch.device(device if torch.cuda.is_available() else 'cpu')

    # Dataset selection; mean/std are kept to denormalize images for display.
    if dataset == 'imagenet_128':
        train_loader, test_loader = get_imagenet_128(batch_size=128)
        img_h = 128
        n_classes = 1000
        mean_t = torch.tensor([0.485, 0.456, 0.406]).reshape(1, 3, 1, 1).to(device)
        std_t = torch.tensor([0.229, 0.224, 0.225]).reshape(1, 3, 1, 1).to(device)
    elif dataset == 'tiny_imagenet':
        train_loader, test_loader = get_tiny_imagenet(batch_size=256)
        img_h = 64
        n_classes = 200
        mean_t = torch.tensor([0.4802, 0.4481, 0.3975]).reshape(1, 3, 1, 1).to(device)
        std_t = torch.tensor([0.2770, 0.2691, 0.2821]).reshape(1, 3, 1, 1).to(device)
    else:
        train_loader, test_loader = get_cifar10(batch_size=256)
        img_h = 32
        n_classes = 10
        mean_t = torch.tensor([0.4914, 0.4822, 0.4465]).reshape(1, 3, 1, 1).to(device)
        std_t = torch.tensor([0.2470, 0.2435, 0.2616]).reshape(1, 3, 1, 1).to(device)

    n_patches = (img_h // patch_size) ** 2
    model = PatchSVAE(matrix_v=V, D=D, patch_size=patch_size,
                      hidden=hidden, depth=depth,
                      n_cross_layers=n_cross_layers).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)

    total_params = sum(p.numel() for p in model.parameters())
    cross_params = sum(p.numel() for p in model.cross_attn.parameters())
    svd_backend = f"fp64 Gram+eigh (FL={'available, N<=12' if HAS_FL else 'not available'})"
    print(f"Using geolip-core SVD ({svd_backend})")
    print(f"PatchSVAE - {n_patches} patches of {patch_size}×{patch_size}")
    print(f" Dataset: {dataset} ({img_h}×{img_h}, {n_classes} classes)")
    print(f" Per-patch: ({V}, {D}) = {V*D} elements, rows on S^{D-1}")
    print(f" Encoder/Decoder: hidden={hidden}, depth={depth} (residual blocks)")
    print(f" Cross-attention: {n_cross_layers} layers on S vectors ({cross_params:,} params)")
    print(f" Soft hand: boost={1+boost:.1f}x near CV={target_cv}, penalty={cv_weight} far")
    print(f" Total params: {total_params:,}")
    print(f" Checkpoints: {save_dir} (best + every {save_every} epochs)")
    print("=" * 95)
    print(f" {'ep':>3} | {'loss':>7} {'recon':>7} {'t/ep':>5} | "
          f"{'t_rec':>7} | "
          f"{'S0':>6} {'SD':>6} {'ratio':>5} {'erank':>5} | "
          f"{'row_cv':>7} {'prox':>5} {'rw':>5} | "
          f"{'S_delta':>7}")
    print("-" * 95)

    best_recon = float('inf')

    def save_checkpoint(path, epoch, test_mse, extra=None, upload=True):
        """Save model, optimizer, scheduler and config; optionally upload to HF."""
        ckpt = {
            'epoch': epoch,
            'test_mse': test_mse,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': opt.state_dict(),
            'scheduler_state_dict': sched.state_dict(),
            'config': {
                'V': V, 'D': D, 'patch_size': patch_size,
                'hidden': hidden, 'depth': depth,
                'n_cross_layers': n_cross_layers,
                'target_cv': target_cv, 'cv_weight': cv_weight,
                'boost': boost, 'sigma': sigma,
                'dataset': dataset, 'lr': lr,
            },
        }
        if extra:
            ckpt.update(extra)
        torch.save(ckpt, path)
        size_mb = os.path.getsize(path) / (1024 * 1024)
        print(f" 💾 Saved: {path} ({size_mb:.1f}MB, ep{epoch}, MSE={test_mse:.6f})")
        if upload:
            upload_to_hf(path, os.path.basename(path))

    for epoch in range(1, epochs + 1):
        model.train()
        total_loss, total_recon, n = 0, 0, 0
        # Soft-hand state restarts every epoch at "exactly on target".
        last_cv = target_cv
        last_prox = 1.0
        recon_w = 1.0 + boost
        t0 = time.time()

        for batch_idx, (images, _) in enumerate(train_loader):
            images = images.to(device)
            opt.zero_grad()
            out = model(images)
            recon_loss = F.mse_loss(out['recon'], images)

            # CV is sampled every 10th batch from the first patch of the
            # first image; a zero result (degenerate sample) keeps the
            # previous value.
            with torch.no_grad():
                if batch_idx % 10 == 0:
                    current_cv = cv_of(out['svd']['M'][0, 0])
                    if current_cv > 0:
                        last_cv = current_cv
                    delta = last_cv - target_cv
                    last_prox = math.exp(-delta**2 / (2 * sigma**2))
                    recon_w = 1.0 + boost * last_prox

            # NOTE: last_cv is a Python float, so the CV term shifts the loss
            # value but carries no gradient; only recon_w affects the update.
            cv_pen = cv_weight * (1.0 - last_prox)
            cv_l = (last_cv - target_cv) ** 2
            loss = recon_w * recon_loss + cv_pen * cv_l
            loss.backward()
            # Clip cross-attention gradients only: the cascade source.
            torch.nn.utils.clip_grad_norm_(model.cross_attn.parameters(), max_norm=0.5)
            opt.step()

            total_loss += loss.item() * len(images)
            total_recon += recon_loss.item() * len(images)
            n += len(images)

        sched.step()
        epoch_time = time.time() - t0

        # --- TensorBoard: train metrics (every epoch, cheap) ---
        writer.add_scalar('train/loss', total_loss / n, epoch)
        writer.add_scalar('train/recon', total_recon / n, epoch)
        writer.add_scalar('train/lr', sched.get_last_lr()[0], epoch)
        writer.add_scalar('train/proximity', last_prox, epoch)
        writer.add_scalar('train/recon_weight', recon_w, epoch)
        writer.add_scalar('train/epoch_time', epoch_time, epoch)

        # --- Evaluation (first three epochs, then every second epoch) ---
        if epoch % 2 == 0 or epoch <= 3:
            model.eval()
            test_recon, test_n = 0, 0
            test_S_orig, test_S_coord = None, None
            test_erank = 0
            row_cvs = []
            nb = 0
            with torch.no_grad():
                for images, _ in test_loader:
                    images = images.to(device)
                    out = model(images)
                    test_recon += F.mse_loss(out['recon'], images).item() * len(images)
                    test_n += len(images)

                    # Average over patches and batch
                    S_mean = out['svd']['S'].mean(dim=(0, 1))  # (D,)
                    S_orig_mean = out['svd']['S_orig'].mean(dim=(0, 1))
                    test_erank += model.effective_rank(
                        out['svd']['S'].reshape(-1, D)).mean().item()

                    # Row CV is sampled from a few patches of the first batches only.
                    if nb < 3:
                        for b in range(min(2, len(images))):
                            for p in range(min(2, out['svd']['M'].shape[1])):
                                row_cvs.append(cv_of(out['svd']['M'][b, p]))

                    if test_S_orig is None:
                        test_S_orig = S_orig_mean.cpu()
                        test_S_coord = S_mean.cpu()
                    else:
                        test_S_orig += S_orig_mean.cpu()
                        test_S_coord += S_mean.cpu()
                    nb += 1

            test_erank /= nb
            test_S_orig /= nb
            test_S_coord /= nb
            ratio = (test_S_coord[0] / (test_S_coord[-1] + 1e-8)).item()
            mean_cv = sum(row_cvs) / len(row_cvs) if row_cvs else 0
            # How much did cross-attention change S?
            s_delta = (test_S_coord - test_S_orig).abs().mean().item()

            # Alpha diagnostics: mean and max across all cross-attention layers
            with torch.no_grad():
                all_alphas = torch.cat([layer.alpha for layer in model.cross_attn])
                alpha_mean = all_alphas.mean().item()
                alpha_max = all_alphas.max().item()

            print(f" {epoch:3d} | {total_loss/n:7.4f} {total_recon/n:7.4f} {epoch_time:5.1f} | "
                  f"{test_recon/test_n:7.4f} | "
                  f"{test_S_coord[0]:6.3f} {test_S_coord[-1]:6.3f} {ratio:5.2f} "
                  f"{test_erank:5.2f} | "
                  f"{mean_cv:7.4f} {last_prox:5.3f} {recon_w:5.2f} | "
                  f"{s_delta:7.5f} a:{alpha_mean:.4f}/{alpha_max:.4f}")

            # --- TensorBoard: eval metrics + geometry ---
            test_mse = test_recon / test_n
            writer.add_scalar('test/recon_mse', test_mse, epoch)
            writer.add_scalar('test/best_mse', min(best_recon, test_mse), epoch)
            # Geometry
            writer.add_scalar('geo/row_cv', mean_cv, epoch)
            writer.add_scalar('geo/ratio', ratio, epoch)
            writer.add_scalar('geo/erank', test_erank, epoch)
            writer.add_scalar('geo/S0', test_S_coord[0].item(), epoch)
            writer.add_scalar('geo/SD', test_S_coord[-1].item(), epoch)
            # Cross-attention
            writer.add_scalar('cross_attn/s_delta', s_delta, epoch)
            writer.add_scalar('cross_attn/alpha_mean', alpha_mean, epoch)
            writer.add_scalar('cross_attn/alpha_max', alpha_max, epoch)
            # Soft hand dynamics
            writer.add_scalar('soft_hand/proximity', last_prox, epoch)
            writer.add_scalar('soft_hand/recon_weight', recon_w, epoch)

            # Singular value profile (every 20 epochs: histograms are heavier)
            if epoch % 20 == 0 or epoch <= 3:
                writer.add_histogram('spectrum/S_coordinated', test_S_coord, epoch)
                writer.add_histogram('spectrum/S_original', test_S_orig, epoch)
                # Per-layer alpha
                for li, layer in enumerate(model.cross_attn):
                    writer.add_histogram(f'alpha/layer_{li}', layer.alpha.detach().cpu(), epoch)

            # Recon images (every 50 epochs: image writes are expensive)
            if epoch % 50 == 0 or epoch == 1:
                with torch.no_grad():
                    sample_imgs, _ = next(iter(test_loader))
                    sample_imgs = sample_imgs[:8].to(device)
                    sample_out = model(sample_imgs)
                    sample_recon = sample_out['recon']
                    # Denormalize for visualization
                    orig_vis = (sample_imgs * std_t + mean_t).clamp(0, 1)
                    recon_vis = (sample_recon * std_t + mean_t).clamp(0, 1)
                    comparison = torch.cat([orig_vis, recon_vis], dim=0)
                    grid = torchvision.utils.make_grid(comparison, nrow=8, padding=2)
                    writer.add_image('recon/comparison', grid, epoch)

            # --- Checkpoint saving ---
            geo_stats = {
                'row_cv': mean_cv, 'ratio': ratio, 'erank': test_erank,
                'S0': test_S_coord[0].item(), 'SD': test_S_coord[-1].item(),
                's_delta': s_delta, 'alpha_mean': alpha_mean, 'alpha_max': alpha_max,
            }
            # Best model: save locally always, upload only at periodic intervals
            if test_mse < best_recon:
                best_recon = test_mse
                save_checkpoint(
                    os.path.join(save_dir, 'best.pt'),
                    epoch, test_mse, extra={'geo': geo_stats},
                    upload=False)  # local only

        # Periodic save + upload. test_mse / geo_stats come from the most
        # recent evaluation (epoch 1 always evaluates, so they exist).
        # NOTE(review): block nesting was reconstructed from flattened
        # source; confirm this belongs at epoch level, not inside the
        # evaluation branch (identical for even save_every, e.g. 50).
        if epoch % save_every == 0:
            save_checkpoint(
                os.path.join(save_dir, f'epoch_{epoch:04d}.pt'),
                epoch, test_mse, extra={'geo': geo_stats})
            # Also push current best to HF
            best_path = os.path.join(save_dir, 'best.pt')
            if os.path.exists(best_path):
                upload_to_hf(best_path, 'best.pt')
            # Flush + upload TB logs
            writer.flush()
            if hf_enabled:
                try:
                    api.upload_folder(
                        folder_path=tb_path,
                        path_in_repo=f"{hf_version}/tensorboard/{run_name}",
                        repo_id=hf_repo, repo_type="model",
                    )
                    print(f" ☁️ TB logs synced to {hf_repo}")
                except Exception as e:
                    print(f" ⚠️ TB sync failed: {e}")

    # --- Final Analysis ---
    print()
    print("=" * 90)
    print("FINAL ANALYSIS")
    print("=" * 90)
    model.eval()
    all_recon_err = []
    with torch.no_grad():
        for images, _ in test_loader:
            images = images.to(device)
            out = model(images)
            all_recon_err.append(
                F.mse_loss(out['recon'], images, reduction='none')
                .mean(dim=(1, 2, 3)).cpu())
    all_recon_err = torch.cat(all_recon_err)

    print(f"\n PatchSVAE: {n_patches} patches × ({V}, {D})")
    print(f" Target CV: {target_cv}")
    print(f" Recon MSE: {all_recon_err.mean():.6f} +/- {all_recon_err.std():.6f}")
    print(f" Row CV: {mean_cv:.4f}")
    print(f" Cross-attention S delta: {s_delta:.5f}")

    # Per-mode alpha profile: which modes coordinate between patches?
    print(f"\n Learned alpha per mode (coordination strength):")
    for layer_idx, layer in enumerate(model.cross_attn):
        alpha = layer.alpha.detach().cpu()
        print(f" Layer {layer_idx}: mean={alpha.mean():.4f} max={alpha.max():.4f} min={alpha.min():.4f}")
        bar_scale = 40 / (alpha.max().item() + 1e-8)
        for d, a in enumerate(alpha):
            bar = "#" * int(a.item() * bar_scale)
            print(f" α[{d:2d}]: {a:.4f} {bar}")

    print(f"\n Coordinated singular value profile:")
    total_energy = (test_S_coord ** 2).sum()
    cumulative = 0
    for i, s in enumerate(test_S_coord):
        e = (s ** 2).item()
        cumulative += e
        pct = cumulative / total_energy * 100
        bar = "#" * int(s.item() * 30 / (test_S_coord[0].item() + 1e-8))
        print(f" S[{i:2d}]: {s:8.4f} cum={pct:5.1f}% {bar}")

    # --- Reconstruction Grid ---
    print(f"\n Saving reconstruction grid...")
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt

    model.eval()
    with torch.no_grad():
        images, _ = next(iter(test_loader))
        images = images[:20].to(device)
        out = model(images)
        recon = out['recon']

    def denorm(t):
        return (t * std_t + mean_t).clamp(0, 1).cpu()

    n_show = min(10, len(images))
    # squeeze=False keeps ``axes`` 2-D even when a single row is shown.
    fig, axes = plt.subplots(n_show, 3, figsize=(6, n_show * 2), squeeze=False)
    for i in range(n_show):
        axes[i, 0].imshow(denorm(images[i:i+1])[0].permute(1, 2, 0).numpy())
        axes[i, 1].imshow(denorm(recon[i:i+1])[0].permute(1, 2, 0).numpy())
        diff = (denorm(images[i:i+1]) - denorm(recon[i:i+1])).abs() * 5
        axes[i, 2].imshow(diff.clamp(0, 1)[0].permute(1, 2, 0).numpy())
    axes[0, 0].set_title('Original', fontsize=8)
    axes[0, 1].set_title('Recon', fontsize=8)
    axes[0, 2].set_title('|Err|×5', fontsize=8)
    for ax in axes.flat:
        ax.axis('off')
    plt.tight_layout()

    # Outside Colab ``/content`` may not exist; fall back to save_dir so a
    # failed savefig cannot abort the run before the final checkpoint.
    recon_grid_dir = '/content' if os.path.isdir('/content') else save_dir
    recon_grid_path = os.path.join(recon_grid_dir, 'svae_patch_recon.png')
    plt.savefig(recon_grid_path, dpi=200, bbox_inches='tight')
    print(f" Saved to {recon_grid_path}")
    try:
        plt.show()
    except Exception:
        # Displaying is optional (headless backends); the file is already saved.
        pass
    plt.close()

    # --- Final checkpoint ---
    save_checkpoint(
        os.path.join(save_dir, 'final.pt'),
        epochs, all_recon_err.mean().item(),
        extra={'geo': geo_stats}
    )

    # --- Close TensorBoard + upload logs ---
    writer.close()
    if hf_enabled:
        try:
            api.upload_folder(
                folder_path=tb_path,
                path_in_repo=f"{hf_version}/tensorboard/{run_name}",
                repo_id=hf_repo,
                repo_type="model",
            )
            print(f" ☁️ TB logs uploaded to {hf_repo}/{hf_version}/tensorboard/")
        except Exception as e:
            print(f" ⚠️ TB upload failed: {e}")

    # Upload recon grid
    if os.path.exists(recon_grid_path):
        upload_to_hf(recon_grid_path, 'recon_grid.png')

    print(f"\n Best MSE: {best_recon:.6f}")
    print(f" Checkpoints: {save_dir}/")
    print(f" TensorBoard: {tb_path}")
    print(f" HuggingFace: {hf_repo}/{hf_version}/")
if __name__ == "__main__":
    # ImageNet-1K at 128×128: 64 patches of 16×16 per image, each encoded as
    # a (256, 16) matrix. 1000 classes, 1.28M training images.
    # Four residual blocks of width 768, learned alpha coordination.
    run_config = dict(
        epochs=200,
        lr=1e-4,
        V=256,
        D=16,
        patch_size=16,
        hidden=768,
        depth=4,
        target_cv=0.125,
        n_cross_layers=2,
        dataset='imagenet_128',
        hf_version='v12_imagenet128',
    )
    train(**run_config)