| """ |
| SVAE-Patch β Patch-based SVD Autoencoder |
| ========================================== |
| 64Γ64 image β 4 patches of 32Γ32 β independent SVD per patch β coordinate β decode. |
| |
| Each patch gets the proven (V, D) pipeline: |
| patch β MLP β M β β^(VΓD) β normalize β SVD β (U, S, Vt) |
| |
| Cross-patch coordination via a lightweight attention on the spectral |
| representations β each patch's S vector (D-dim) attends to all others. |
| This lets patches share information about relative spectral structure |
| without disrupting the per-patch geometric attractors. |
| |
| Reconstruction: coordinated spectra + per-patch (U, Vt) β MLP β patch β stitch. |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import torchvision |
| import torchvision.transforms as T |
| import math |
| import time |
|
|
| |
|
|
| try: |
| from geolip_core.linalg.eigh import FLEigh, _FL_MAX_N |
| HAS_FL = True |
| except ImportError: |
| HAS_FL = False |
|
|
|
|
def gram_eigh_svd_fp64(A):
    """Thin SVD of a batched (B, M, N) tensor via the Gram matrix + eigh, in fp64.

    Autocast is disabled and the computation promoted to double internally;
    the (U, S, Vh) factors are cast back to the input dtype on return.
    """
    in_dtype = A.dtype
    with torch.amp.autocast('cuda', enabled=False):
        A64 = A.double()
        gram = torch.bmm(A64.transpose(1, 2), A64)
        evals, evecs = torch.linalg.eigh(gram)
        # eigh yields ascending eigenvalues; flip for descending singular values.
        evals, evecs = evals.flip(-1), evecs.flip(-1)
        sing = torch.sqrt(evals.clamp(min=1e-24))
        # Left singular vectors: U = A V S^{-1}, with the divisor clamped
        # so near-zero singular values do not blow up.
        left = torch.bmm(A64, evecs) / sing.unsqueeze(1).clamp(min=1e-16)
        right_t = evecs.transpose(-2, -1).contiguous()
        return left.to(in_dtype), sing.to(in_dtype), right_t.to(in_dtype)
|
|
|
|
def svd_fp64(A):
    """Batched thin SVD with fp64 internals.

    Dispatches to the fused FLEigh kernel when it is available, the input is
    on CUDA, and the column count fits its size limit; otherwise falls back
    to the plain Gram+eigh path.
    """
    _, _, n_cols = A.shape
    use_fl = HAS_FL and n_cols <= _FL_MAX_N and A.is_cuda
    if not use_fl:
        return gram_eigh_svd_fp64(A)

    in_dtype = A.dtype
    with torch.amp.autocast('cuda', enabled=False):
        A64 = A.double()
        gram = torch.bmm(A64.transpose(1, 2), A64)
        # FLEigh runs in fp32; promote its outputs back to double.
        evals, evecs = FLEigh()(gram.float())
        evals = evals.double().flip(-1)   # ascending -> descending
        evecs = evecs.double().flip(-1)
        sing = torch.sqrt(evals.clamp(min=1e-24))
        left = torch.bmm(A64, evecs) / sing.unsqueeze(1).clamp(min=1e-16)
        right_t = evecs.transpose(-2, -1).contiguous()
        return left.to(in_dtype), sing.to(in_dtype), right_t.to(in_dtype)
|
|
|
|
| |
|
|
def cayley_menger_vol2(points):
    """Squared simplex volume from pairwise distances (Cayley-Menger), in fp64.

    Args:
        points: (B, N, D) batch of N-point simplices embedded in D dims.

    Returns:
        (B,) squared volumes of the (N-1)-simplices.
    """
    B, N, D = points.shape
    p64 = points.double()
    gram = torch.bmm(p64, p64.transpose(1, 2))
    sq_norms = torch.diagonal(gram, dim1=1, dim2=2)
    # Pairwise squared distances; relu guards tiny negatives from round-off.
    dist2 = F.relu(sq_norms.unsqueeze(2) + sq_norms.unsqueeze(1) - 2 * gram)
    # Bordered distance matrix for the Cayley-Menger determinant.
    border = torch.zeros(B, N + 1, N + 1, device=points.device, dtype=torch.float64)
    border[:, 0, 1:] = 1.0
    border[:, 1:, 0] = 1.0
    border[:, 1:, 1:] = dist2
    simplex_dim = N - 1
    sign = (-1.0) ** (simplex_dim + 1)
    fact = math.factorial(simplex_dim)
    return sign * torch.linalg.det(border) / ((2 ** simplex_dim) * (fact ** 2))
|
|
|
|
def cv_of(emb, n_samples=200):
    """Coefficient of variation of sampled pentachoron volumes for a (V, D) embedding.

    Draws `n_samples` random 5-point subsets (pentachora) from the first
    min(V, 512) rows, computes their volumes, and returns std/mean.
    Returns 0.0 for degenerate inputs or when too few volumes are non-zero.
    """
    if emb.dim() != 2 or emb.shape[0] < 5:
        return 0.0
    n_rows = emb.shape[0]
    pool = min(n_rows, 512)
    idx = torch.stack(
        [torch.randperm(pool, device=emb.device)[:5] for _ in range(n_samples)])
    vol2 = cayley_menger_vol2(emb[:pool][idx])
    mask = vol2 > 1e-20
    # Need a minimum of valid simplices for the statistic to mean anything.
    if mask.sum() < 10:
        return 0.0
    vols = vol2[mask].sqrt()
    return (vols.std() / (vols.mean() + 1e-8)).item()
|
|
|
|
| |
|
|
def get_tiny_imagenet(batch_size=256):
    """TinyImageNet loaders via HuggingFace: 200 classes, 64x64."""
    from datasets import load_dataset

    normalize = T.Compose([
        T.ToTensor(),
        T.Normalize((0.4802, 0.4481, 0.3975), (0.2770, 0.2691, 0.2821)),
    ])
    hf_ds = load_dataset('zh-plus/tiny-imagenet')

    class HFImageDataset(torch.utils.data.Dataset):
        """Adapter exposing a HuggingFace split as a torch Dataset."""
        def __init__(self, hf_split, transform):
            self.data = hf_split
            self.transform = transform

        def __len__(self):
            return len(self.data)

        def __getitem__(self, idx):
            record = self.data[idx]
            img = record['image']
            # Some samples are grayscale/palette; force 3-channel RGB.
            if img.mode != 'RGB':
                img = img.convert('RGB')
            return self.transform(img), record['label']

    train_loader = torch.utils.data.DataLoader(
        HFImageDataset(hf_ds['train'], normalize),
        batch_size=batch_size, shuffle=True, num_workers=2)
    val_loader = torch.utils.data.DataLoader(
        HFImageDataset(hf_ds['valid'], normalize),
        batch_size=batch_size, shuffle=False, num_workers=2)
    return train_loader, val_loader
|
|
|
|
def get_cifar10(batch_size=256):
    """CIFAR-10 train/test loaders with standard per-channel normalization."""
    normalize = T.Compose([
        T.ToTensor(),
        T.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
    ])
    loaders = {}
    for is_train in (True, False):
        split = torchvision.datasets.CIFAR10(
            root='./data', train=is_train, download=True, transform=normalize)
        loaders[is_train] = torch.utils.data.DataLoader(
            split, batch_size=batch_size, shuffle=is_train, num_workers=2)
    return loaders[True], loaders[False]
|
|
|
|
def get_imagenet_128(batch_size=128):
    """ImageNet-1K at 128Γ128 via HuggingFace: 1000 classes, 1.28M train, 50K val.
    Requires: pip install datasets
    Requires: HF auth + ImageNet terms accepted on huggingface.co
    """
    from datasets import load_dataset

    normalize = T.Compose([
        T.ToTensor(),
        T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    hf_ds = load_dataset('benjamin-paine/imagenet-1k-128x128')

    class HFImageDataset(torch.utils.data.Dataset):
        """Adapter exposing a HuggingFace split as a torch Dataset."""
        def __init__(self, hf_split, transform):
            self.data = hf_split
            self.transform = transform

        def __len__(self):
            return len(self.data)

        def __getitem__(self, idx):
            record = self.data[idx]
            img = record['image']
            # Some samples are grayscale/palette; force 3-channel RGB.
            if img.mode != 'RGB':
                img = img.convert('RGB')
            return self.transform(img), record['label']

    train_loader = torch.utils.data.DataLoader(
        HFImageDataset(hf_ds['train'], normalize),
        batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
    val_loader = torch.utils.data.DataLoader(
        HFImageDataset(hf_ds['validation'], normalize),
        batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
    return train_loader, val_loader
|
|
|
|
| |
|
|
def extract_patches(images, patch_size=32):
    """Split (B, C, H, W) into a grid of flattened, non-overlapping patches.

    Args:
        images: (B, C, H, W) tensor; H and W must be multiples of patch_size.
        patch_size: Spatial side of each square patch.

    Returns:
        patches: (B, gh*gw, C*patch_size*patch_size) flattened patches,
            ordered row-major over the patch grid.
        gh, gw: Patch-grid height and width, needed for stitching back.
    """
    B, C, H, W = images.shape
    gh = H // patch_size
    gw = W // patch_size
    grid = images.reshape(B, C, gh, patch_size, gw, patch_size)
    grid = grid.permute(0, 2, 4, 1, 3, 5)
    return grid.reshape(B, gh * gw, C * patch_size * patch_size), gh, gw
|
|
|
|
def stitch_patches(patches, gh, gw, patch_size=32):
    """Reassemble flattened patches into a full image (inverse of extract_patches).

    Args:
        patches: (B, n_patches, C*ph*pw) with n_patches == gh * gw.
        gh, gw: Patch-grid height and width.
        patch_size: Spatial side of each square patch.

    Returns:
        (B, C, gh*patch_size, gw*patch_size) image tensor.
    """
    B = patches.shape[0]
    ph, pw = patch_size, patch_size
    # Infer the channel count from the flattened patch dimension instead of
    # hard-coding C=3, so non-RGB tensors round-trip through
    # extract_patches/stitch_patches as well. For 3-channel inputs this is
    # byte-for-byte identical to the previous behavior.
    C = patches.shape[2] // (ph * pw)
    patches = patches.reshape(B, gh, gw, C, ph, pw)
    patches = patches.permute(0, 3, 1, 4, 2, 5)
    return patches.reshape(B, C, gh * ph, gw * pw)
|
|
|
|
class BoundarySmooth(nn.Module):
    """Residual two-conv refinement applied to the stitched image.

    Two 3x3 convolutions with a residual connection. The combined 5x5
    receptive field straddles patch boundaries without reaching deep into
    patch interiors, so the module learns to blend seams while leaving
    patch content essentially untouched.

    The final conv is zero-initialized, making the module an exact
    identity at the start of training.
    """
    def __init__(self, channels=3, mid=16):
        super().__init__()
        conv_in = nn.Conv2d(channels, mid, 3, padding=1)
        conv_out = nn.Conv2d(mid, channels, 3, padding=1)
        self.net = nn.Sequential(conv_in, nn.GELU(), conv_out)
        # Zero-init the last conv so the residual branch starts dead.
        nn.init.zeros_(conv_out.weight)
        nn.init.zeros_(conv_out.bias)

    def forward(self, x):
        return x + self.net(x)
|
|
|
|
| |
|
|
class SpectralCrossAttention(nn.Module):
    """Multiplicative spectral coordination with a learnable per-mode alpha.

        S_out = S * (1 + alpha_d * tanh(attention_output_d))

    Each spectral mode d owns its own coordination strength alpha_d,
    parameterized through a sigmoid so it stays bounded in [0, max_alpha]:

        alpha_d = max_alpha * sigmoid(alpha_logit_d)

    The model can thereby discover which modes need cross-patch
    coordination (high alpha), which should remain independent (low
    alpha), and the overall coordination budget (sum of alphas, which is
    regularizable). After training, the alpha vector doubles as a
    diagnostic of which spectral modes carry inter-patch structure.
    """
    def __init__(self, D, n_heads=4, max_alpha=0.2, alpha_init=-2.0):
        super().__init__()
        assert D % n_heads == 0, f"D={D} must be divisible by n_heads={n_heads}"
        self.n_heads = n_heads
        self.head_dim = D // n_heads
        self.max_alpha = max_alpha

        self.qkv = nn.Linear(D, 3 * D)
        self.out_proj = nn.Linear(D, D)
        self.norm = nn.LayerNorm(D)
        self.scale = self.head_dim ** -0.5

        # Pre-sigmoid logits for the per-mode gate strengths.
        self.alpha_logits = nn.Parameter(torch.full((D,), alpha_init))

    @property
    def alpha(self):
        """Current per-mode alpha values, bounded [0, max_alpha]."""
        return torch.sigmoid(self.alpha_logits) * self.max_alpha

    def forward(self, S):
        """Coordinate spectra across patches.

        S: (B, n_patches, D) -> coordinated S of the same shape.
        """
        B, N, D = S.shape
        projected = self.qkv(self.norm(S))
        heads = projected.reshape(B, N, 3, self.n_heads, self.head_dim)
        q, k, v = heads.permute(2, 0, 3, 1, 4).unbind(0)

        weights = torch.softmax((q @ k.transpose(-2, -1)) * self.scale, dim=-1)
        mixed = (weights @ v).transpose(1, 2).reshape(B, N, D)
        gate = torch.tanh(self.out_proj(mixed))

        # Bounded multiplicative update of the spectra, one alpha per mode.
        return S * (1.0 + self.alpha.view(1, 1, -1) * gate)
|
|
|
|
| |
|
|
class PatchSVAE(nn.Module):
    """Patch-based SVD Autoencoder.

    Image -> patches -> per-patch encode -> sphere normalize -> SVD ->
    cross-patch spectral attention -> per-patch decode -> stitch.

    Each patch has its own geometric attractor determined by (V, D).
    Cross-patch attention coordinates spectral magnitudes only.
    U and V (directional structure) remain independent per patch.

    Args:
        matrix_v: Rows per patch matrix
        D: Embedding dimension
        patch_size: Spatial patch size (default 32 -> 4 patches for 64x64)
        hidden: Per-patch MLP hidden width
        depth: Number of residual blocks in encoder and decoder (default 2)
        n_cross_layers: Number of spectral cross-attention layers
    """
    def __init__(self, matrix_v=256, D=24, patch_size=32, hidden=512,
                 depth=2, n_cross_layers=2):
        super().__init__()
        self.matrix_v = matrix_v
        self.D = D
        self.patch_size = patch_size
        # Flattened patch length; 3 channels (RGB) assumed throughout.
        self.patch_dim = 3 * patch_size * patch_size
        # Flattened size of the per-patch (V, D) matrix.
        self.mat_dim = matrix_v * D
        self.n_cross_layers = n_cross_layers

        # Encoder: flattened patch -> hidden -> V*D matrix entries,
        # with `depth` pre-norm residual MLP blocks in between.
        self.enc_in = nn.Linear(self.patch_dim, hidden)
        self.enc_blocks = nn.ModuleList([
            nn.Sequential(
                nn.LayerNorm(hidden),
                nn.Linear(hidden, hidden),
                nn.GELU(),
                nn.Linear(hidden, hidden),
            ) for _ in range(depth)
        ])
        self.enc_out = nn.Linear(hidden, self.mat_dim)

        # Decoder: mirror of the encoder, V*D -> hidden -> flattened patch.
        self.dec_in = nn.Linear(self.mat_dim, hidden)
        self.dec_blocks = nn.ModuleList([
            nn.Sequential(
                nn.LayerNorm(hidden),
                nn.Linear(hidden, hidden),
                nn.GELU(),
                nn.Linear(hidden, hidden),
            ) for _ in range(depth)
        ])
        self.dec_out = nn.Linear(hidden, self.patch_dim)

        # Orthogonal init on the projection into matrix space, presumably to
        # start with well-conditioned M and a non-degenerate spectrum
        # (TODO confirm rationale).
        nn.init.orthogonal_(self.enc_out.weight)

        # Spectral coordination stack: attention over the D-dim S vectors.
        # n_heads is capped at D so head_dim stays >= 1.
        self.cross_attn = nn.ModuleList([
            SpectralCrossAttention(D, n_heads=min(4, D))
            for _ in range(n_cross_layers)
        ])

        # Post-stitch seam blending (identity at init).
        self.boundary_smooth = BoundarySmooth(channels=3, mid=16)

    def encode_patches(self, patches):
        """Encode all patches in parallel.

        patches: (B, n_patches, patch_dim)
        Returns: dict with per-patch SVD factors plus the coordinated S:
            'U' (B, N, V, D), 'S_orig'/'S' (B, N, D), 'Vt' (B, N, D, D),
            'M' (B, N, V, D) -- the row-normalized pre-SVD matrices.
        """
        B, N, _ = patches.shape

        # Fold the patch axis into the batch so every patch runs through
        # the same MLP weights in one pass.
        flat = patches.reshape(B * N, -1)
        h = F.gelu(self.enc_in(flat))
        for block in self.enc_blocks:
            h = h + block(h)
        M = self.enc_out(h).reshape(B * N, self.matrix_v, self.D)
        # Sphere normalization: every row of M lands on S^{D-1}.
        M = F.normalize(M, dim=-1)

        U, S, Vt = svd_fp64(M)

        # Restore the (B, N, ...) layout.
        U = U.reshape(B, N, self.matrix_v, self.D)
        S = S.reshape(B, N, self.D)
        Vt = Vt.reshape(B, N, self.D, self.D)
        M = M.reshape(B, N, self.matrix_v, self.D)

        # Cross-patch coordination of spectral magnitudes only;
        # U and Vt stay independent per patch.
        S_coord = S
        for layer in self.cross_attn:
            S_coord = layer(S_coord)

        return {
            'U': U, 'S_orig': S, 'S': S_coord, 'Vt': Vt, 'M': M,
        }

    def decode_patches(self, U, S, Vt):
        """Decode from coordinated SVD.

        U: (B, N, V, D), S: (B, N, D), Vt: (B, N, D, D)
        Returns: (B, N, patch_dim)
        """
        B, N, V, D = U.shape

        # Fold patches into the batch for the reconstruction matmuls.
        U_flat = U.reshape(B * N, V, D)
        S_flat = S.reshape(B * N, D)
        Vt_flat = Vt.reshape(B * N, D, D)

        # Rebuild M_hat = U diag(S) Vt, then map back to pixel space.
        M_hat = torch.bmm(U_flat * S_flat.unsqueeze(1), Vt_flat)
        h = F.gelu(self.dec_in(M_hat.reshape(B * N, -1)))
        for block in self.dec_blocks:
            h = h + block(h)
        patches = self.dec_out(h)
        return patches.reshape(B, N, -1)

    def forward(self, images):
        # Full pipeline: patchify -> encode/SVD/coordinate -> decode ->
        # stitch -> seam smoothing.
        patches, gh, gw = extract_patches(images, self.patch_size)
        svd = self.encode_patches(patches)
        decoded_patches = self.decode_patches(svd['U'], svd['S'], svd['Vt'])
        recon = stitch_patches(decoded_patches, gh, gw, self.patch_size)
        recon = self.boundary_smooth(recon)
        return {'recon': recon, 'svd': svd, 'gh': gh, 'gw': gw}

    @staticmethod
    def effective_rank(S):
        """Shannon entropy effective rank. S: (*, D).

        Normalizes S to a probability vector and returns exp(entropy);
        ranges from 1 (rank-1 spectrum) to D (flat spectrum).
        """
        p = S / (S.sum(-1, keepdim=True) + 1e-8)
        p = p.clamp(min=1e-8)
        return (-(p * p.log()).sum(-1)).exp()
|
|
|
|
| |
|
|
def train(epochs=200, lr=1e-3, V=256, D=24, patch_size=32,
          hidden=512, depth=2,
          target_cv=0.125, cv_weight=0.3, boost=0.5, sigma=0.15,
          n_cross_layers=2, dataset='tiny_imagenet', device='cuda',
          save_dir='/content/checkpoints', save_every=50,
          hf_repo='AbstractPhil/geolip-SVAE', hf_version='v11_prod',
          tb_dir='/content/runs'):
    """Train the PatchSVAE with sphere normalization + soft hand.

    Saves checkpoints to local + HuggingFace.
    Logs to TensorBoard.

    Args:
        epochs: Number of training epochs.
        lr: Adam learning rate, cosine-annealed over `epochs`.
        V, D: Per-patch matrix shape (V rows embedded in D dims).
        patch_size: Spatial side of each square patch.
        hidden, depth: Encoder/decoder MLP width and residual-block count.
        target_cv: Target coefficient of variation of pentachoron volumes
            (the geometric attractor the "soft hand" steers toward).
        cv_weight: Penalty weight applied when far from target_cv.
        boost: Extra reconstruction weight applied near target_cv.
        sigma: Width of the Gaussian proximity kernel around target_cv.
        n_cross_layers: Number of spectral cross-attention layers.
        dataset: 'imagenet_128', 'tiny_imagenet', or anything else -> CIFAR-10.
        device: Preferred device string; falls back to CPU without CUDA.
        save_dir, save_every: Checkpoint directory and periodic-save interval.
        hf_repo, hf_version: HuggingFace repo and version prefix for uploads.
        tb_dir: TensorBoard log root.
    """
    import os
    os.makedirs(save_dir, exist_ok=True)

    # --- TensorBoard setup ------------------------------------------------
    from torch.utils.tensorboard import SummaryWriter
    run_name = f"patchsvae_V{V}_D{D}_p{patch_size}_h{hidden}_d{depth}"
    tb_path = os.path.join(tb_dir, run_name)
    writer = SummaryWriter(tb_path)
    print(f" TensorBoard: {tb_path}")

    # --- Optional HuggingFace uploads (best-effort, never blocks training) -
    hf_enabled = False
    try:
        from huggingface_hub import HfApi
        api = HfApi()
        # Raises when not authenticated -> uploads stay disabled.
        api.whoami()
        hf_enabled = True
        hf_prefix = f"{hf_version}/checkpoints"
        print(f" HuggingFace: {hf_repo}/{hf_prefix}")
    except Exception as e:
        print(f" HuggingFace: disabled ({e})")

    def upload_to_hf(local_path, remote_name):
        """Upload a file to HF repo, non-blocking on failure."""
        if not hf_enabled:
            return
        try:
            remote_path = f"{hf_prefix}/{remote_name}"
            api.upload_file(
                path_or_fileobj=local_path,
                path_in_repo=remote_path,
                repo_id=hf_repo,
                repo_type="model",
            )
            print(f" βοΈ Uploaded: {hf_repo}/{remote_path}")
        except Exception as e:
            print(f" β οΈ HF upload failed: {e}")
    device = torch.device(device if torch.cuda.is_available() else 'cpu')

    # --- Dataset selection; (mean_t, std_t) are kept for de-normalization -
    if dataset == 'imagenet_128':
        train_loader, test_loader = get_imagenet_128(batch_size=128)
        img_h = 128
        n_classes = 1000
        mean_t = torch.tensor([0.485, 0.456, 0.406]).reshape(1, 3, 1, 1).to(device)
        std_t = torch.tensor([0.229, 0.224, 0.225]).reshape(1, 3, 1, 1).to(device)
    elif dataset == 'tiny_imagenet':
        train_loader, test_loader = get_tiny_imagenet(batch_size=256)
        img_h = 64
        n_classes = 200
        mean_t = torch.tensor([0.4802, 0.4481, 0.3975]).reshape(1, 3, 1, 1).to(device)
        std_t = torch.tensor([0.2770, 0.2691, 0.2821]).reshape(1, 3, 1, 1).to(device)
    else:
        train_loader, test_loader = get_cifar10(batch_size=256)
        img_h = 32
        n_classes = 10
        mean_t = torch.tensor([0.4914, 0.4822, 0.4465]).reshape(1, 3, 1, 1).to(device)
        std_t = torch.tensor([0.2470, 0.2435, 0.2616]).reshape(1, 3, 1, 1).to(device)

    # --- Model / optimizer ------------------------------------------------
    n_patches = (img_h // patch_size) ** 2
    model = PatchSVAE(matrix_v=V, D=D, patch_size=patch_size,
                      hidden=hidden, depth=depth,
                      n_cross_layers=n_cross_layers).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)

    total_params = sum(p.numel() for p in model.parameters())
    cross_params = sum(p.numel() for p in model.cross_attn.parameters())

    # --- Run header -------------------------------------------------------
    svd_backend = f"fp64 Gram+eigh (FL={'available, N<=12' if HAS_FL else 'not available'})"
    print(f"Using geolip-core SVD ({svd_backend})")
    print(f"PatchSVAE - {n_patches} patches of {patch_size}Γ{patch_size}")
    print(f" Dataset: {dataset} ({img_h}Γ{img_h}, {n_classes} classes)")
    print(f" Per-patch: ({V}, {D}) = {V*D} elements, rows on S^{D-1}")
    print(f" Encoder/Decoder: hidden={hidden}, depth={depth} (residual blocks)")
    print(f" Cross-attention: {n_cross_layers} layers on S vectors ({cross_params:,} params)")
    print(f" Soft hand: boost={1+boost:.1f}x near CV={target_cv}, penalty={cv_weight} far")
    print(f" Total params: {total_params:,}")
    print(f" Checkpoints: {save_dir} (best + every {save_every} epochs)")
    print("=" * 95)
    print(f" {'ep':>3} | {'loss':>7} {'recon':>7} {'t/ep':>5} | "
          f"{'t_rec':>7} | "
          f"{'S0':>6} {'SD':>6} {'ratio':>5} {'erank':>5} | "
          f"{'row_cv':>7} {'prox':>5} {'rw':>5} | "
          f"{'S_delta':>7}")
    print("-" * 95)

    best_recon = float('inf')

    def save_checkpoint(path, epoch, test_mse, extra=None, upload=True):
        """Save model + optimizer + config. Optionally upload to HF."""
        ckpt = {
            'epoch': epoch,
            'test_mse': test_mse,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': opt.state_dict(),
            'scheduler_state_dict': sched.state_dict(),
            'config': {
                'V': V, 'D': D, 'patch_size': patch_size,
                'hidden': hidden, 'depth': depth,
                'n_cross_layers': n_cross_layers,
                'target_cv': target_cv, 'cv_weight': cv_weight,
                'boost': boost, 'sigma': sigma,
                'dataset': dataset, 'lr': lr,
            },
        }
        if extra:
            ckpt.update(extra)
        torch.save(ckpt, path)
        size_mb = os.path.getsize(path) / (1024 * 1024)
        print(f" πΎ Saved: {path} ({size_mb:.1f}MB, ep{epoch}, MSE={test_mse:.6f})")
        if upload:
            upload_to_hf(path, os.path.basename(path))

    # --- Training loop ----------------------------------------------------
    for epoch in range(1, epochs + 1):
        model.train()
        total_loss, total_recon, n = 0, 0, 0
        # Soft-hand state carried between CV probes (CV is measured only
        # every 10 batches to amortize the fp64 determinant cost).
        last_cv = target_cv
        last_prox = 1.0
        recon_w = 1.0 + boost
        t0 = time.time()

        for batch_idx, (images, labels) in enumerate(train_loader):
            images = images.to(device)
            opt.zero_grad()
            out = model(images)
            recon_loss = F.mse_loss(out['recon'], images)

            # Probe the row-embedding CV on one patch of one sample.
            with torch.no_grad():
                if batch_idx % 10 == 0:
                    current_cv = cv_of(out['svd']['M'][0, 0])
                    if current_cv > 0:
                        last_cv = current_cv
                        delta = last_cv - target_cv
                        last_prox = math.exp(-delta**2 / (2 * sigma**2))

            # Soft hand: boost reconstruction near the target CV,
            # penalize deviation when far from it.
            recon_w = 1.0 + boost * last_prox
            cv_pen = cv_weight * (1.0 - last_prox)
            cv_l = (last_cv - target_cv) ** 2

            # NOTE(review): cv_l and cv_pen are Python floats, so the CV term
            # adds a constant to the loss (no gradient); the steering acts
            # only through recon_w scaling -- confirm this is intended.
            loss = recon_w * recon_loss + cv_pen * cv_l
            loss.backward()

            # Keep spectral-attention updates small relative to the MLPs.
            torch.nn.utils.clip_grad_norm_(model.cross_attn.parameters(), max_norm=0.5)

            opt.step()

            total_loss += loss.item() * len(images)
            total_recon += recon_loss.item() * len(images)
            n += len(images)

        sched.step()
        epoch_time = time.time() - t0

        # Per-epoch training scalars.
        writer.add_scalar('train/loss', total_loss / n, epoch)
        writer.add_scalar('train/recon', total_recon / n, epoch)
        writer.add_scalar('train/lr', sched.get_last_lr()[0], epoch)
        writer.add_scalar('train/proximity', last_prox, epoch)
        writer.add_scalar('train/recon_weight', recon_w, epoch)
        writer.add_scalar('train/epoch_time', epoch_time, epoch)

        # --- Periodic evaluation (every 2 epochs, plus the first 3) -------
        if epoch % 2 == 0 or epoch <= 3:
            model.eval()
            test_recon, test_n = 0, 0
            test_S_orig, test_S_coord = None, None
            test_erank = 0
            row_cvs = []
            nb = 0

            with torch.no_grad():
                for images, labels in test_loader:
                    images = images.to(device)
                    out = model(images)
                    test_recon += F.mse_loss(out['recon'], images).item() * len(images)
                    test_n += len(images)

                    # Batch-mean spectra, before and after coordination.
                    S_mean = out['svd']['S'].mean(dim=(0, 1))
                    S_orig_mean = out['svd']['S_orig'].mean(dim=(0, 1))
                    test_erank += model.effective_rank(out['svd']['S'].reshape(-1, D)).mean().item()

                    # Sample a handful of row-embedding CVs from early batches.
                    if nb < 3:
                        for b in range(min(2, len(images))):
                            for p in range(min(2, out['svd']['M'].shape[1])):
                                row_cvs.append(cv_of(out['svd']['M'][b, p]))

                    if test_S_orig is None:
                        test_S_orig = S_orig_mean.cpu()
                        test_S_coord = S_mean.cpu()
                    else:
                        test_S_orig += S_orig_mean.cpu()
                        test_S_coord += S_mean.cpu()
                    nb += 1

            # Average the accumulators over eval batches.
            test_erank /= nb
            test_S_orig /= nb
            test_S_coord /= nb
            ratio = (test_S_coord[0] / (test_S_coord[-1] + 1e-8)).item()
            mean_cv = sum(row_cvs) / len(row_cvs) if row_cvs else 0

            # How much cross-attention moved the spectra on average.
            s_delta = (test_S_coord - test_S_orig).abs().mean().item()

            # Snapshot the learned coordination strengths.
            with torch.no_grad():
                all_alphas = torch.cat([layer.alpha for layer in model.cross_attn])
                alpha_mean = all_alphas.mean().item()
                alpha_max = all_alphas.max().item()

            print(f" {epoch:3d} | {total_loss/n:7.4f} {total_recon/n:7.4f} {epoch_time:5.1f} | "
                  f"{test_recon/test_n:7.4f} | "
                  f"{test_S_coord[0]:6.3f} {test_S_coord[-1]:6.3f} {ratio:5.2f} "
                  f"{test_erank:5.2f} | "
                  f"{mean_cv:7.4f} {last_prox:5.3f} {recon_w:5.2f} | "
                  f"{s_delta:7.5f} a:{alpha_mean:.4f}/{alpha_max:.4f}")

            # Test scalars.
            test_mse = test_recon / test_n
            writer.add_scalar('test/recon_mse', test_mse, epoch)
            writer.add_scalar('test/best_mse', min(best_recon, test_mse), epoch)

            # Geometry scalars.
            writer.add_scalar('geo/row_cv', mean_cv, epoch)
            writer.add_scalar('geo/ratio', ratio, epoch)
            writer.add_scalar('geo/erank', test_erank, epoch)
            writer.add_scalar('geo/S0', test_S_coord[0].item(), epoch)
            writer.add_scalar('geo/SD', test_S_coord[-1].item(), epoch)

            # Cross-attention scalars.
            writer.add_scalar('cross_attn/s_delta', s_delta, epoch)
            writer.add_scalar('cross_attn/alpha_mean', alpha_mean, epoch)
            writer.add_scalar('cross_attn/alpha_max', alpha_max, epoch)

            # Soft-hand scalars.
            writer.add_scalar('soft_hand/proximity', last_prox, epoch)
            writer.add_scalar('soft_hand/recon_weight', recon_w, epoch)

            # Occasional histograms of the spectra and per-layer alphas.
            if epoch % 20 == 0 or epoch <= 3:
                writer.add_histogram('spectrum/S_coordinated', test_S_coord, epoch)
                writer.add_histogram('spectrum/S_original', test_S_orig, epoch)

                for li, layer in enumerate(model.cross_attn):
                    writer.add_histogram(f'alpha/layer_{li}', layer.alpha.detach().cpu(), epoch)

            # Occasional image grid: originals above reconstructions.
            if epoch % 50 == 0 or epoch == 1:
                with torch.no_grad():
                    sample_imgs, _ = next(iter(test_loader))
                    sample_imgs = sample_imgs[:8].to(device)
                    sample_out = model(sample_imgs)
                    sample_recon = sample_out['recon']

                    orig_vis = (sample_imgs * std_t + mean_t).clamp(0, 1)
                    recon_vis = (sample_recon * std_t + mean_t).clamp(0, 1)
                    comparison = torch.cat([orig_vis, recon_vis], dim=0)
                    grid = torchvision.utils.make_grid(comparison, nrow=8, padding=2)
                    writer.add_image('recon/comparison', grid, epoch)

            # Geometry snapshot stored inside checkpoints.
            geo_stats = {
                'row_cv': mean_cv, 'ratio': ratio, 'erank': test_erank,
                'S0': test_S_coord[0].item(), 'SD': test_S_coord[-1].item(),
                's_delta': s_delta, 'alpha_mean': alpha_mean, 'alpha_max': alpha_max,
            }

            # Track the best model locally (uploaded with periodic saves).
            if test_mse < best_recon:
                best_recon = test_mse
                save_checkpoint(
                    os.path.join(save_dir, 'best.pt'),
                    epoch, test_mse, extra={'geo': geo_stats},
                    upload=False)

            # Periodic checkpoint + best upload + TB sync.
            if epoch % save_every == 0:
                save_checkpoint(
                    os.path.join(save_dir, f'epoch_{epoch:04d}.pt'),
                    epoch, test_mse, extra={'geo': geo_stats})

                best_path = os.path.join(save_dir, 'best.pt')
                if os.path.exists(best_path):
                    upload_to_hf(best_path, 'best.pt')

                writer.flush()
                if hf_enabled:
                    try:
                        api.upload_folder(
                            folder_path=tb_path,
                            path_in_repo=f"{hf_version}/tensorboard/{run_name}",
                            repo_id=hf_repo, repo_type="model",
                        )
                        print(f" βοΈ TB logs synced to {hf_repo}")
                    except Exception as e:
                        print(f" β οΈ TB sync failed: {e}")

    # --- Final analysis ---------------------------------------------------
    # NOTE(review): the variables from the last eval pass (test_S_coord,
    # mean_cv, s_delta, geo_stats) are reused below; they are always defined
    # because epoch 1 triggers an eval.
    print()
    print("=" * 90)
    print("FINAL ANALYSIS")
    print("=" * 90)

    model.eval()
    all_recon_err = []

    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            out = model(images)
            all_recon_err.append(
                F.mse_loss(out['recon'], images, reduction='none')
                .mean(dim=(1, 2, 3)).cpu())

    all_recon_err = torch.cat(all_recon_err)

    print(f"\n PatchSVAE: {n_patches} patches Γ ({V}, {D})")
    print(f" Target CV: {target_cv}")
    print(f" Recon MSE: {all_recon_err.mean():.6f} +/- {all_recon_err.std():.6f}")
    print(f" Row CV: {mean_cv:.4f}")
    print(f" Cross-attention S delta: {s_delta:.5f}")

    # Per-mode alpha report with ASCII bar charts.
    print(f"\n Learned alpha per mode (coordination strength):")
    for layer_idx, layer in enumerate(model.cross_attn):
        alpha = layer.alpha.detach().cpu()
        print(f" Layer {layer_idx}: mean={alpha.mean():.4f} max={alpha.max():.4f} min={alpha.min():.4f}")
        bar_scale = 40 / (alpha.max().item() + 1e-8)
        for d in range(len(alpha)):
            bar = "#" * int(alpha[d].item() * bar_scale)
            print(f" Ξ±[{d:2d}]: {alpha[d]:.4f} {bar}")

    # Coordinated spectrum profile with cumulative energy fraction.
    print(f"\n Coordinated singular value profile:")
    total_energy = (test_S_coord ** 2).sum()
    cumulative = 0
    for i in range(len(test_S_coord)):
        e = (test_S_coord[i] ** 2).item()
        cumulative += e
        pct = cumulative / total_energy * 100
        bar = "#" * int(test_S_coord[i].item() * 30 / (test_S_coord[0].item() + 1e-8))
        print(f" S[{i:2d}]: {test_S_coord[i]:8.4f} cum={pct:5.1f}% {bar}")

    # Side-by-side original / reconstruction / error figure.
    print(f"\n Saving reconstruction grid...")
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt

    model.eval()
    with torch.no_grad():
        images, labels = next(iter(test_loader))
        images = images[:20].to(device)
        out = model(images)
        recon = out['recon']

    def denorm(t):
        # Undo the dataset normalization for display.
        return (t * std_t + mean_t).clamp(0, 1).cpu()

    n_show = min(10, len(images))
    fig, axes = plt.subplots(n_show, 3, figsize=(6, n_show * 2))
    for i in range(n_show):
        axes[i, 0].imshow(denorm(images[i:i+1])[0].permute(1, 2, 0).numpy())
        axes[i, 1].imshow(denorm(recon[i:i+1])[0].permute(1, 2, 0).numpy())
        diff = (denorm(images[i:i+1]) - denorm(recon[i:i+1])).abs() * 5
        axes[i, 2].imshow(diff.clamp(0, 1)[0].permute(1, 2, 0).numpy())
    axes[0, 0].set_title('Original', fontsize=8)
    axes[0, 1].set_title('Recon', fontsize=8)
    axes[0, 2].set_title('|Err|Γ5', fontsize=8)
    for ax in axes.flat:
        ax.axis('off')
    plt.tight_layout()
    plt.savefig('/content/svae_patch_recon.png', dpi=200, bbox_inches='tight')
    print(f" Saved to /content/svae_patch_recon.png")
    try:
        plt.show()
    # NOTE(review): bare except; plt.show is best-effort in headless runs.
    except:
        pass
    plt.close()

    # Final checkpoint uses the full-test-set mean error, not the last eval.
    save_checkpoint(
        os.path.join(save_dir, 'final.pt'),
        epochs, all_recon_err.mean().item(),
        extra={'geo': geo_stats}
    )

    # Final TB flush + upload.
    writer.close()
    if hf_enabled:
        try:
            api.upload_folder(
                folder_path=tb_path,
                path_in_repo=f"{hf_version}/tensorboard/{run_name}",
                repo_id=hf_repo,
                repo_type="model",
            )
            print(f" βοΈ TB logs uploaded to {hf_repo}/{hf_version}/tensorboard/")
        except Exception as e:
            print(f" β οΈ TB upload failed: {e}")

    # Upload the reconstruction figure when it exists.
    recon_grid_path = '/content/svae_patch_recon.png'
    if os.path.exists(recon_grid_path):
        upload_to_hf(recon_grid_path, 'recon_grid.png')

    print(f"\n Best MSE: {best_recon:.6f}")
    print(f" Checkpoints: {save_dir}/")
    print(f" TensorBoard: {tb_path}")
    print(f" HuggingFace: {hf_repo}/{hf_version}/")
|
|
|
|
| if __name__ == "__main__": |
| |
| |
| |
| train(epochs=200, lr=1e-4, V=256, D=16, patch_size=16, |
| hidden=768, depth=4, |
| target_cv=0.125, n_cross_layers=2, |
| dataset='imagenet_128', |
| hf_version='v12_imagenet128') |