| """ |
Johanna-128 Omega — Continue the Gaussian-pretrained model on 16 noise types
| ========================================================================== |
| Loads the Gaussian-trained checkpoint (ep200, MSE=0.059) and expands |
the signal vocabulary to all 16 noise distributions at 128×128.
| |
The Gaussian knowledge is the foundation — the MLP already knows how to
| invert the geometric projection for one distribution. Now we teach it |
| the other 15 without destroying what it learned. |
| |
| Strategy: |
| - Moderate lr (1e-4): fast enough to learn new distributions, |
| slow enough to preserve Gaussian knowledge |
| - Gaussian is 1 of 16 types, so it stays in the training mix |
| - Same architecture: V=256, D=16, hidden=768, depth=4, 17M params |
- Batch=128 to stay under the cuSOLVER limit (128 × 64 patches = 8192 calls)
| """ |
|
|
| import os |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import torchvision |
| import torchvision.transforms as T |
| import math |
| import time |
| import numpy as np |
| from tqdm import tqdm |
|
|
| |
# Colab-only convenience: copy HF_TOKEN from Colab secrets into the
# environment and log in to the HuggingFace Hub, so the HF download/upload
# calls in train() can authenticate. Outside Colab the import fails and
# this step is skipped silently.
try:
    from google.colab import userdata
except ImportError:
    pass  # not running on Colab: rely on any existing HF credentials
else:
    try:
        os.environ["HF_TOKEN"] = userdata.get('HF_TOKEN')
        from huggingface_hub import login
        login(token=os.environ["HF_TOKEN"])
    except Exception as e:
        # Best-effort: train() re-checks HF access (api.whoami) and disables
        # uploads on failure, but surface the reason instead of hiding it.
        print(f"  HuggingFace login skipped: {e}")
|
|
| |
|
|
# Optional eigensolver from geolip_core. svd_fp64 uses FLEigh only when this
# import succeeds; otherwise it falls back to gram_eigh_svd_fp64, which uses
# torch.linalg.eigh.
try:
    from geolip_core.linalg.eigh import FLEigh, _FL_MAX_N
    HAS_FL = True
except ImportError:
    HAS_FL = False
|
|
|
|
def gram_eigh_svd_fp64(A):
    """Thin SVD of a batch of matrices via eigendecomposition of A^T A.

    All arithmetic runs in float64 with CUDA autocast disabled; the factors
    are cast back to A's dtype at the end.

    Args:
        A: 3-D tensor of shape (B, M, N) (torch.bmm requires a batch dim).

    Returns:
        (U, S, Vh) with shapes (B, M, N), (B, N) and (B, N, N). Singular
        values are in descending order, and U @ diag(S) @ Vh reconstructs A.
    """
    out_dtype = A.dtype
    with torch.amp.autocast('cuda', enabled=False):
        a64 = A.double()
        gram = torch.bmm(a64.transpose(1, 2), a64)
        # 1e-12 ridge on the diagonal guards against an exactly singular Gram matrix.
        gram.diagonal(dim1=-2, dim2=-1).add_(1e-12)
        evals_asc, evecs_asc = torch.linalg.eigh(gram)

        # eigh sorts ascending; reverse to the usual descending SVD order.
        evals = torch.flip(evals_asc, dims=(-1,))
        right = torch.flip(evecs_asc, dims=(-1,))

        sing = evals.clamp(min=1e-24).sqrt()
        left = torch.bmm(a64, right) / sing[:, None, :].clamp(min=1e-16)
        right_h = right.transpose(-2, -1).contiguous()

    return left.to(out_dtype), sing.to(out_dtype), right_h.to(out_dtype)
|
|
|
|
def svd_fp64(A):
    """Batched thin SVD of a (B, M, N) tensor, preferring FLEigh when usable.

    FLEigh is used when the optional geolip_core import succeeded, N is at
    most _FL_MAX_N, and A is on CUDA. In that path the Gram matrix is built
    in float64 but handed to FLEigh as float32, and the results are promoted
    back to float64. Any other input goes through gram_eigh_svd_fp64.

    Returns:
        (U, S, Vh) cast back to A's dtype, singular values descending.
    """
    _, _, n_cols = A.shape
    if not (HAS_FL and n_cols <= _FL_MAX_N and A.is_cuda):
        return gram_eigh_svd_fp64(A)

    out_dtype = A.dtype
    with torch.amp.autocast('cuda', enabled=False):
        a64 = A.double()
        gram = torch.bmm(a64.transpose(1, 2), a64)
        evals_asc, evecs_asc = FLEigh()(gram.float())
        # Assumes FLEigh, like torch.linalg.eigh, returns ascending eigenvalues
        # (not visible from this file) -- flip to descending SVD order.
        evals = evals_asc.double().flip(-1)
        right = evecs_asc.double().flip(-1)
        sing = torch.sqrt(evals.clamp(min=1e-24))
        left = torch.bmm(a64, right) / sing.unsqueeze(1).clamp(min=1e-16)
        right_h = right.transpose(-2, -1).contiguous()
        return left.to(out_dtype), sing.to(out_dtype), right_h.to(out_dtype)
|
|
|
|
| |
|
|
def cayley_menger_vol2(points):
    """Squared volume of each point set's simplex via the Cayley-Menger determinant.

    Args:
        points: (B, N, D) tensor -- B sets of N points in D dimensions; each set
            spans an (N-1)-simplex.

    Returns:
        float64 tensor of shape (B,) holding the squared (N-1)-volumes.
    """
    n_sets, n_pts, _ = points.shape
    p64 = points.double()
    inner = torch.bmm(p64, p64.transpose(1, 2))
    sq_norm = inner.diagonal(dim1=1, dim2=2)
    # |x_i - x_j|^2 = |x_i|^2 + |x_j|^2 - 2 x_i.x_j ; relu clips rounding negatives.
    dist2 = F.relu(sq_norm[:, :, None] + sq_norm[:, None, :] - 2 * inner)

    # Bordered distance matrix: 0 in the corner, ones along the first row/column.
    bordered = torch.zeros(n_sets, n_pts + 1, n_pts + 1,
                           device=points.device, dtype=torch.float64)
    bordered[:, 0, 1:] = 1.0
    bordered[:, 1:, 0] = 1.0
    bordered[:, 1:, 1:] = dist2

    dim = n_pts - 1
    sign = (-1.0) ** (dim + 1)
    denom = (2 ** dim) * (math.factorial(dim) ** 2)
    return sign * torch.linalg.det(bordered) / denom
|
|
|
|
def cv_of(emb, n_samples=200):
    """Coefficient of variation (std / mean) of random 4-simplex volumes from emb.

    Draws ``n_samples`` random 5-row subsets from the first min(N, 512) rows of
    a 2-D tensor ``emb`` of shape (N, D), computes each subset's simplex volume
    with cayley_menger_vol2, and returns std/mean of those volumes as a float.
    Uses torch's global RNG (randperm), so the result is stochastic.

    Returns 0.0 when emb is not 2-D, has fewer than 5 rows, or fewer than 10
    subsets have a squared volume above 1e-20.
    """
    if emb.dim() != 2 or emb.shape[0] < 5:
        return 0.0

    n_rows, _ = emb.shape
    pool = min(n_rows, 512)
    draws = [torch.randperm(pool, device=emb.device)[:5] for _ in range(n_samples)]
    subsets = emb[:pool][torch.stack(draws)]

    sq_vol = cayley_menger_vol2(subsets)
    keep = sq_vol > 1e-20
    if keep.sum() < 10:
        return 0.0

    vol = sq_vol[keep].sqrt()
    return (vol.std() / (vol.mean() + 1e-8)).item()
|
|
|
|
| |
|
|
class OmegaNoiseDataset(torch.utils.data.Dataset):
    """16 noise types at arbitrary resolution, with periodic seed rotation.

    Item ``idx`` is a float32 tensor of shape (3, img_size, img_size), clamped
    to [-4, 4], paired with its noise-type label ``idx % 16``. Content is
    drawn from torch's global RNG and a per-dataset NumPy RandomState, so the
    same index yields a different image on every access.

    Every ``seed_rotate_every`` calls to ``__getitem__``, both the NumPy
    RandomState and torch's global RNG are reseeded from ``os.urandom``.

    NOTE(review): DataLoader workers each receive a copy of ``_rng`` seeded
    with 42, so until the first rotation all workers draw the same NumPy
    scalars (lam, block size, angle, ...). Torch draws still differ per worker.
    """

    N_TYPES = 16

    def __init__(self, size=1000000, img_size=128, seed_rotate_every=1000):
        self.size = size
        self.img_size = img_size
        self.seed_rotate_every = seed_rotate_every
        self._rng = np.random.RandomState(42)
        self._call_count = 0

    def __len__(self):
        return self.size

    def _rotate_seed(self):
        """Count a call; every seed_rotate_every calls reseed NumPy and torch."""
        self._call_count += 1
        if self._call_count % self.seed_rotate_every == 0:
            new_seed = int.from_bytes(os.urandom(4), 'big')
            self._rng = np.random.RandomState(new_seed)
            torch.manual_seed(new_seed)

    @staticmethod
    def _spectral_noise(shape, exponent):
        """Zero-mean noise whose Fourier amplitudes scale as 1 / f**exponent.

        White Gaussian noise of ``shape`` is shaped over its last two dims.
        The DC bin is zeroed first. Dividing it by a clamped ~1e-8 frequency
        would inflate each channel's mean by ~1e8. After std-normalisation the
        image would then be a flat per-channel colour with no texture.
        """
        white = torch.randn(shape)
        spectrum = torch.fft.rfft2(white)
        h, w = shape[-2], shape[-1]
        fy = torch.fft.fftfreq(h).unsqueeze(-1).expand(-1, w // 2 + 1)
        fx = torch.fft.rfftfreq(w).unsqueeze(0).expand(h, -1)
        # The clamp only affects the DC bin, which is zeroed anyway.
        f = torch.sqrt(fx ** 2 + fy ** 2).clamp(min=1e-8)
        spectrum[..., 0, 0] = 0
        spectrum = spectrum / f ** exponent
        return torch.fft.irfft2(spectrum, s=(h, w))

    def _pink_noise(self, shape):
        """Zero-mean noise with 1/f amplitude falloff."""
        return self._spectral_noise(shape, 1)

    def _brown_noise(self, shape):
        """Zero-mean noise with 1/f**2 amplitude falloff."""
        return self._spectral_noise(shape, 2)

    def __getitem__(self, idx):
        self._rotate_seed()
        s = self.img_size
        noise_type = idx % self.N_TYPES

        if noise_type == 0:      # standard Gaussian
            img = torch.randn(3, s, s)
        elif noise_type == 1:    # uniform on [-1, 1)
            img = torch.rand(3, s, s) * 2 - 1
        elif noise_type == 2:    # uniform on [-2, 2)
            img = (torch.rand(3, s, s) - 0.5) * 4
        elif noise_type == 3:    # Poisson(lam) / lam - 1, random lam
            lam = self._rng.uniform(0.5, 20.0)
            img = torch.poisson(torch.full((3, s, s), lam)) / lam - 1.0
        elif noise_type == 4:    # pink noise, unit std
            img = self._pink_noise((3, s, s))
            img = img / (img.std() + 1e-8)
        elif noise_type == 5:    # brown noise, unit std
            img = self._brown_noise((3, s, s))
            img = img / (img.std() + 1e-8)
        elif noise_type == 6:    # binary +-2 plus small Gaussian jitter
            img = torch.where(torch.rand(3, s, s) > 0.5,
                              torch.ones(3, s, s) * 2, torch.ones(3, s, s) * -2)
            img = img + torch.randn(3, s, s) * 0.1
        elif noise_type == 7:    # sparse: ~10% of pixels Gaussian * 3, rest zero
            mask = torch.rand(3, s, s) > 0.9
            img = torch.randn(3, s, s) * mask.float() * 3
        elif noise_type == 8:    # blocky: nearest-upsampled low-res Gaussian
            block = self._rng.randint(2, 16)
            small = torch.randn(3, s // block + 1, s // block + 1)
            img = F.interpolate(small.unsqueeze(0), size=s, mode='nearest').squeeze(0)
        elif noise_type == 9:    # linear gradient at a random angle plus noise
            gy = torch.linspace(-2, 2, s).unsqueeze(1).expand(s, s)
            gx = torch.linspace(-2, 2, s).unsqueeze(0).expand(s, s)
            angle = self._rng.uniform(0, 2 * math.pi)
            grad = math.cos(angle) * gx + math.sin(angle) * gy
            img = grad.unsqueeze(0).expand(3, -1, -1) + torch.randn(3, s, s) * 0.5
        elif noise_type == 10:   # +-1 checkerboard plus noise
            check_size = self._rng.randint(2, 16)
            coords_y = torch.arange(s) // check_size
            coords_x = torch.arange(s) // check_size
            checker = ((coords_y.unsqueeze(1) + coords_x.unsqueeze(0)) % 2).float() * 2 - 1
            img = checker.unsqueeze(0).expand(3, -1, -1) + torch.randn(3, s, s) * 0.3
        elif noise_type == 11:   # random convex mix of Gaussian and uniform
            a = torch.randn(3, s, s)
            b = torch.rand(3, s, s) * 2 - 1
            alpha = self._rng.uniform(0.2, 0.8)
            img = alpha * a + (1 - alpha) * b
        elif noise_type == 12:   # four-quadrant mosaic of different noises
            # Bottom/right quadrants absorb the extra row/column when s is
            # odd; for even s all four quadrants are (s // 2) x (s // 2).
            top, left = s // 2, s // 2
            bottom, right = s - top, s - left
            img = torch.zeros(3, s, s)
            img[:, :top, :left] = torch.randn(3, top, left)
            img[:, :top, left:] = torch.rand(3, top, right) * 2 - 1
            img[:, top:, :left] = self._pink_noise((3, bottom, left)) / 2
            sp = torch.where(torch.rand(3, bottom, right) > 0.5,
                             torch.ones(3, bottom, right), -torch.ones(3, bottom, right))
            img[:, top:, left:] = sp
        elif noise_type == 13:   # Cauchy via inverse CDF, clamped to [-3, 3]
            u = torch.rand(3, s, s)
            img = torch.tan(math.pi * (u - 0.5))
            img = img.clamp(-3, 3)
        elif noise_type == 14:   # Exponential(1), shifted to mean 0
            img = torch.empty(3, s, s).exponential_(1.0) - 1.0
        elif noise_type == 15:   # Laplace via inverse CDF
            u = torch.rand(3, s, s) - 0.5
            img = -torch.sign(u) * torch.log1p(-2 * u.abs())

        img = img.clamp(-4, 4)
        return img.float(), noise_type
|
|
|
|
| |
|
|
def extract_patches(images, patch_size=16):
    """Split images into non-overlapping square patches, flattened per patch.

    Args:
        images: (B, C, H, W) tensor. H and W are expected to be multiples of
            patch_size; otherwise the reshape below raises.
        patch_size: patch side length p.

    Returns:
        (patches, gh, gw), where gh = H // p and gw = W // p. ``patches`` has
        shape (B, gh * gw, C * p * p), ordered row-major over the (gh, gw) grid.
        Each patch is flattened in (C, p, p) order.
    """
    n, c, height, width = images.shape
    rows = height // patch_size
    cols = width // patch_size
    blocks = images.reshape(n, c, rows, patch_size, cols, patch_size)
    # -> (B, gh, gw, C, p, p): each patch becomes contiguous in the last three dims.
    blocks = blocks.permute(0, 2, 4, 1, 3, 5)
    flat = blocks.reshape(n, rows * cols, c * patch_size * patch_size)
    return flat, rows, cols
|
|
|
|
def stitch_patches(patches, gh, gw, patch_size=16):
    """Reassemble a (gh x gw) grid of flattened patches into images.

    Inverse of the extract_patches layout: patches are in row-major grid order,
    and each patch is flattened in (C, p, p) order.

    Args:
        patches: (B, gh * gw, C * patch_size * patch_size) tensor.
        gh, gw: grid height and width, in patches.
        patch_size: patch side length p.

    Returns:
        (B, C, gh * patch_size, gw * patch_size) tensor. C is inferred from the
        last dimension of ``patches``, so any channel count is supported.
    """
    B = patches.shape[0]
    channels = patches.shape[-1] // (patch_size * patch_size)
    grid = patches.reshape(B, gh, gw, channels, patch_size, patch_size)
    # -> (B, C, gh, p, gw, p) so rows/cols of neighbouring patches interleave.
    grid = grid.permute(0, 3, 1, 4, 2, 5)
    return grid.reshape(B, channels, gh * patch_size, gw * patch_size)
|
|
|
|
class BoundarySmooth(nn.Module):
    """Residual 3x3-conv refinement: returns x + conv(gelu(conv(x))).

    The output conv is zero-initialised, so a freshly built module is an exact
    identity. Both convs use padding=1, so spatial size is preserved. The
    submodule layout (net.0 / net.1 / net.2) must stay fixed, because
    checkpoints are loaded with strict state_dict matching.
    """

    def __init__(self, channels=3, mid=16):
        super().__init__()
        layers = [
            nn.Conv2d(channels, mid, 3, padding=1),
            nn.GELU(),
            nn.Conv2d(mid, channels, 3, padding=1),
        ]
        self.net = nn.Sequential(*layers)
        out_conv = self.net[-1]
        nn.init.zeros_(out_conv.weight)
        nn.init.zeros_(out_conv.bias)

    def forward(self, x):
        residual = self.net(x)
        return x + residual
|
|
|
|
class SpectralCrossAttention(nn.Module):
    """Attention-driven multiplicative modulation of per-token D-vectors.

    The input S has shape (B, N, D). In PatchSVAE, N is the number of patches
    and each token is that patch's D singular values. Multi-head
    self-attention runs across the N tokens on LayerNorm(S). Its output goes
    through a linear layer and tanh to give a gate in (-1, 1), and the result is

        S * (1 + alpha * gate),   alpha = max_alpha * sigmoid(alpha_logits),

    so each channel is rescaled by a factor within (1 - max_alpha, 1 + max_alpha).

    Raises:
        ValueError: if n_heads is not a positive divisor of D.
    """

    def __init__(self, D, n_heads=4, max_alpha=0.2, alpha_init=-2.0):
        super().__init__()
        # Validate explicitly: an assert is stripped under `python -O`.
        if n_heads <= 0 or D % n_heads != 0:
            raise ValueError(f"D={D} must be divisible by n_heads={n_heads}")
        self.n_heads = n_heads
        self.head_dim = D // n_heads
        self.max_alpha = max_alpha
        # Submodule/parameter names and creation order are part of the
        # checkpoint format (state_dict keys) -- keep them stable.
        self.qkv = nn.Linear(D, 3 * D)
        self.out_proj = nn.Linear(D, D)
        self.norm = nn.LayerNorm(D)
        self.scale = self.head_dim ** -0.5
        # float() so an int alpha_init still yields a trainable float tensor.
        self.alpha_logits = nn.Parameter(torch.full((D,), float(alpha_init)))

    @property
    def alpha(self):
        """Per-channel modulation strength in (0, max_alpha)."""
        return self.max_alpha * torch.sigmoid(self.alpha_logits)

    def forward(self, S):
        """Modulate S of shape (B, N, D); the output has the same shape."""
        B, N, D = S.shape
        S_normed = self.norm(S)
        # (B, N, 3D) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(S_normed).reshape(B, N, 3, self.n_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(B, N, D)
        gate = torch.tanh(self.out_proj(out))
        return S * (1.0 + self.alpha.unsqueeze(0).unsqueeze(0) * gate)
|
|
|
|
class PatchSVAE(nn.Module):
    """Patch autoencoder whose bottleneck is the SVD of a per-patch matrix.

    Images are cut into patch_size x patch_size patches (3 channels). Each
    patch is encoded by a residual MLP into a (matrix_v, D) matrix M whose
    rows are L2-normalised. M is factored as U diag(S) Vt by svd_fp64, and S is
    refined across patches by SpectralCrossAttention layers. The decoder MLP
    maps U diag(S') Vt back to pixels, and the stitched image is passed through
    BoundarySmooth.

    NOTE: submodule names and construction order define the state_dict keys,
    the parameter order (which optimizer state dicts index by) and the RNG
    consumption at init. train() loads pretrained weights with strict=True,
    so keep them stable.
    """

    def __init__(self, matrix_v=256, D=16, patch_size=16, hidden=768,
                 depth=4, n_cross_layers=2):
        super().__init__()
        self.matrix_v = matrix_v
        self.D = D
        self.patch_size = patch_size
        self.patch_dim = 3 * patch_size * patch_size  # flattened 3-channel patch (C*p*p)
        self.mat_dim = matrix_v * D  # flattened size of the (matrix_v, D) bottleneck matrix

        # Encoder: patch -> hidden -> `depth` pre-LayerNorm residual MLP blocks -> matrix_v*D
        self.enc_in = nn.Linear(self.patch_dim, hidden)
        self.enc_blocks = nn.ModuleList([
            nn.Sequential(nn.LayerNorm(hidden), nn.Linear(hidden, hidden),
                          nn.GELU(), nn.Linear(hidden, hidden))
            for _ in range(depth)])
        self.enc_out = nn.Linear(hidden, self.mat_dim)

        # Decoder mirrors the encoder: flattened M_hat -> hidden -> blocks -> patch
        self.dec_in = nn.Linear(self.mat_dim, hidden)
        self.dec_blocks = nn.ModuleList([
            nn.Sequential(nn.LayerNorm(hidden), nn.Linear(hidden, hidden),
                          nn.GELU(), nn.Linear(hidden, hidden))
            for _ in range(depth)])
        self.dec_out = nn.Linear(hidden, self.patch_dim)

        nn.init.orthogonal_(self.enc_out.weight)

        # Cross-patch refinement of the singular values, then a residual conv
        # over the stitched image.
        self.cross_attn = nn.ModuleList([
            SpectralCrossAttention(D, n_heads=min(4, D))
            for _ in range(n_cross_layers)])
        self.boundary_smooth = BoundarySmooth(channels=3, mid=16)

    def encode_patches(self, patches):
        """Encode flattened patches and SVD-factor the per-patch matrices.

        Args:
            patches: (B, N, patch_dim) tensor, e.g. from extract_patches.

        Returns:
            dict with
              'U':      (B, N, matrix_v, D) left singular vectors,
              'S_orig': (B, N, D) singular values as returned by svd_fp64,
              'S':      (B, N, D) singular values after the cross-attention layers,
              'Vt':     (B, N, D, D) transposed right singular vectors,
              'M':      (B, N, matrix_v, D) row-normalised encoder matrices.
        """
        B, N, _ = patches.shape
        flat = patches.reshape(B * N, -1)
        h = F.gelu(self.enc_in(flat))
        for block in self.enc_blocks:
            h = h + block(h)
        M = self.enc_out(h).reshape(B * N, self.matrix_v, self.D)
        M = F.normalize(M, dim=-1)  # unit-norm rows (normalised over the D axis)
        U, S, Vt = svd_fp64(M)
        U = U.reshape(B, N, self.matrix_v, self.D)
        S = S.reshape(B, N, self.D)
        Vt = Vt.reshape(B, N, self.D, self.D)
        M = M.reshape(B, N, self.matrix_v, self.D)
        # Attention runs over the N patches of each image.
        S_coord = S
        for layer in self.cross_attn:
            S_coord = layer(S_coord)
        return {'U': U, 'S_orig': S, 'S': S_coord, 'Vt': Vt, 'M': M}

    def decode_patches(self, U, S, Vt):
        """Rebuild M_hat = U diag(S) Vt per patch and decode it to flat pixels.

        Args:
            U: (B, N, V, D); S: (B, N, D); Vt: (B, N, D, D).

        Returns:
            (B, N, patch_dim) tensor of decoded patches.
        """
        B, N, V, D = U.shape
        U_flat = U.reshape(B * N, V, D)
        S_flat = S.reshape(B * N, D)
        Vt_flat = Vt.reshape(B * N, D, D)
        # U * S broadcasts S over rows, i.e. U @ diag(S), before multiplying by Vt.
        M_hat = torch.bmm(U_flat * S_flat.unsqueeze(1), Vt_flat)
        h = F.gelu(self.dec_in(M_hat.reshape(B * N, -1)))
        for block in self.dec_blocks:
            h = h + block(h)
        return self.dec_out(h).reshape(B, N, -1)

    def forward(self, images):
        """Reconstruct images through the patch-SVD bottleneck.

        Args:
            images: (B, 3, H, W) tensor; H and W are presumably multiples of
                patch_size (extract_patches' reshape fails otherwise).

        Returns:
            dict with 'recon' (reconstructed images), 'svd' (the dict returned
            by encode_patches), and 'gh' / 'gw' (patch-grid size).
        """
        patches, gh, gw = extract_patches(images, self.patch_size)
        svd = self.encode_patches(patches)
        # Decoding uses the attention-refined singular values, not S_orig.
        decoded = self.decode_patches(svd['U'], svd['S'], svd['Vt'])
        recon = stitch_patches(decoded, gh, gw, self.patch_size)
        recon = self.boundary_smooth(recon)
        return {'recon': recon, 'svd': svd, 'gh': gh, 'gw': gw}

    @staticmethod
    def effective_rank(S):
        """Entropy-based effective rank: exp(H(p)) with p = S / sum(S) over the last dim."""
        p = S / (S.sum(-1, keepdim=True) + 1e-8)
        p = p.clamp(min=1e-8)  # avoid log(0)
        return (-(p * p.log()).sum(-1)).exp()
|
|
|
|
| |
|
|
def train():
    """Continue training the Gaussian-pretrained PatchSVAE on all 16 noise types.

    Steps:
      1. Set up TensorBoard logging and, if credentials work, HuggingFace uploads.
      2. Download the pretrained checkpoint and load it (strict=True) into a
         PatchSVAE with the same architecture.
      3. Train with Adam and per-epoch cosine annealing on OmegaNoiseDataset.
         Every ``report_every`` global batches, probe one validation batch;
         at the end of each epoch, run the full validation set.
      4. Save ``best.pt`` locally whenever a probe or epoch MSE improves.
         Every ``save_every`` epochs, save ``epoch_XXXX.pt`` and upload it,
         best.pt and the TensorBoard folder when HF is enabled.

    Requires HF Hub access to download the pretrained weights.
    """
    # Architecture values must match the pretrained checkpoint (strict load).
    V, D, patch_size = 256, 16, 16
    hidden, depth = 768, 4
    n_cross_layers = 2
    batch_size = 128
    lr = 1e-4
    epochs = 200
    target_cv = 0.125
    cv_weight, boost, sigma = 0.3, 0.5, 0.15
    img_size = 128

    save_dir = '/content/checkpoints'
    save_every = 10
    report_every = 5000
    hf_repo = 'AbstractPhil/geolip-SVAE'
    hf_version = 'v16_johanna_omega'
    tb_dir = '/content/runs'

    # Gaussian-only checkpoint this run continues from.
    pretrained_repo = 'AbstractPhil/geolip-SVAE'
    pretrained_file = 'v14_noise/checkpoints/epoch_0200.pt'

    os.makedirs(save_dir, exist_ok=True)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    from torch.utils.tensorboard import SummaryWriter
    run_name = f"johanna_omega_V{V}_D{D}_h{hidden}_d{depth}"
    tb_path = os.path.join(tb_dir, run_name)
    writer = SummaryWriter(tb_path)
    print(f" TensorBoard: {tb_path}")

    # HF uploads are optional: any failure here just disables them.
    hf_enabled = False
    try:
        from huggingface_hub import HfApi
        api = HfApi()
        api.whoami()
        hf_enabled = True
        hf_prefix = f"{hf_version}/checkpoints"
        print(f" HuggingFace: {hf_repo}/{hf_prefix}")
    except Exception as e:
        print(f" HuggingFace: disabled ({e})")

    def upload_to_hf(local_path, remote_name):
        """Best-effort upload of one file into the HF repo; no-op when HF is disabled."""
        if not hf_enabled:
            return
        try:
            api.upload_file(path_or_fileobj=local_path,
                            path_in_repo=f"{hf_prefix}/{remote_name}",
                            repo_id=hf_repo, repo_type="model")
            print(f" ☁️ Uploaded: {hf_repo}/{hf_prefix}/{remote_name}")
        except Exception as e:
            print(f" ⚠️ HF upload failed: {e}")

    # The pretrained weights are mandatory, so import the downloader here
    # rather than inside the optional-upload try block above (where a failed
    # import would leave it unbound).
    from huggingface_hub import hf_hub_download
    print(f"\n Loading pretrained: {pretrained_repo}/{pretrained_file}")
    ckpt_path = hf_hub_download(repo_id=pretrained_repo, filename=pretrained_file,
                                repo_type="model")
    # NOTE: weights_only=False unpickles arbitrary objects. This is acceptable
    # only because the file comes from the same repo this script uploads to.
    ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=False)
    print(f" Pretrained epoch: {ckpt['epoch']}, MSE: {ckpt['test_mse']:.6f}")
    print(f" Pretrained config: {ckpt['config']}")

    model = PatchSVAE(matrix_v=V, D=D, patch_size=patch_size,
                      hidden=hidden, depth=depth,
                      n_cross_layers=n_cross_layers).to(device)
    model.load_state_dict(ckpt['model_state_dict'], strict=True)
    print(f" Loaded {sum(p.numel() for p in model.parameters()):,} parameters")

    # Fresh optimizer/scheduler: only model weights come from the checkpoint.
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)

    train_ds = OmegaNoiseDataset(size=1280000, img_size=img_size)
    val_ds = OmegaNoiseDataset(size=10000, img_size=img_size)
    train_loader = torch.utils.data.DataLoader(
        train_ds, batch_size=batch_size, shuffle=True,
        num_workers=4, pin_memory=True, drop_last=True)
    test_loader = torch.utils.data.DataLoader(
        val_ds, batch_size=batch_size, shuffle=False,
        num_workers=4, pin_memory=True)

    n_patches = (img_size // patch_size) ** 2
    batches_per_epoch = len(train_loader)
    total_params = sum(p.numel() for p in model.parameters())

    print(f"\n JOHANNA-128 OMEGA CONTINUATION")
    print(f" Pretrained on: Gaussian N(0,1), 200 epochs, MSE=0.059")
    print(f" Now training on: 16 noise types at {img_size}×{img_size}")
    print(f" {n_patches} patches, ({V},{D}), hidden={hidden}, depth={depth}")
    print(f" Params: {total_params:,}, batch={batch_size}")
    print(f" Batches/epoch: {batches_per_epoch}, lr={lr}")
    print(f" Report every {report_every} batches")
    print("=" * 100)
    print(f" {'ep':>3} {'batch':>8} | {'loss':>7} {'recon':>7} | "
          f"{'S0':>6} {'SD':>6} {'ratio':>5} {'erank':>5} | "
          f"{'row_cv':>7} {'prox':>5} {'rw':>5} | "
          f"{'S_delta':>7}")
    print("-" * 100)

    best_recon = float('inf')
    global_batch = 0

    def save_checkpoint(path, epoch, test_mse, extra=None, upload=True):
        """Save model/optimizer/scheduler state plus run config; optionally upload to HF."""
        ckpt_out = {
            'epoch': epoch, 'test_mse': test_mse,
            'global_batch': global_batch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': opt.state_dict(),
            'scheduler_state_dict': sched.state_dict(),
            'config': {
                'V': V, 'D': D, 'patch_size': patch_size,
                'hidden': hidden, 'depth': depth,
                'n_cross_layers': n_cross_layers,
                'target_cv': target_cv,
                'dataset': 'omega_noise_16types_128',
                'pretrained_from': 'v14_noise/epoch_0200.pt',
                'img_size': img_size, 'lr': lr,
            },
        }
        if extra:
            ckpt_out.update(extra)
        torch.save(ckpt_out, path)
        size_mb = os.path.getsize(path) / (1024 * 1024)
        print(f" 💾 Saved: {path} ({size_mb:.1f}MB, ep{epoch}, MSE={test_mse:.6f})")
        if upload:
            upload_to_hf(path, os.path.basename(path))

    for epoch in range(1, epochs + 1):
        model.train()
        total_loss, total_recon, n = 0, 0, 0
        last_cv, last_prox, recon_w = target_cv, 1.0, 1.0 + boost
        t0 = time.time()

        pbar = tqdm(train_loader, desc=f"Ep {epoch}/{epochs}",
                    bar_format='{l_bar}{bar:20}{r_bar}')
        for batch_idx, (images, _) in enumerate(pbar):
            images = images.to(device)
            opt.zero_grad()
            out = model(images)
            recon_loss = F.mse_loss(out['recon'], images)

            # Every 50 batches, measure the simplex-volume CV of the first
            # patch's M and turn its distance from target_cv into a Gaussian
            # proximity score in (0, 1].
            with torch.no_grad():
                if batch_idx % 50 == 0:
                    current_cv = cv_of(out['svd']['M'][0, 0])
                    if current_cv > 0:
                        last_cv = current_cv
                    delta = last_cv - target_cv
                    last_prox = math.exp(-delta**2 / (2 * sigma**2))

            recon_w = 1.0 + boost * last_prox
            cv_pen = cv_weight * (1.0 - last_prox)
            # NOTE(review): last_cv is a Python float measured under no_grad,
            # so cv_l is a constant. It shifts the reported loss but adds no
            # gradient; the CV signal acts only through recon_w. A real CV
            # penalty would need a differentiable CV estimate.
            cv_l = (last_cv - target_cv) ** 2
            loss = recon_w * recon_loss + cv_pen * cv_l
            loss.backward()

            # Only the cross-attention parameters are gradient-clipped.
            torch.nn.utils.clip_grad_norm_(model.cross_attn.parameters(), max_norm=0.5)
            opt.step()

            total_loss += loss.item() * len(images)
            total_recon += recon_loss.item() * len(images)
            n += len(images)
            global_batch += 1

            pbar.set_postfix_str(
                f"loss={recon_loss.item():.4f} cv={last_cv:.3f} prox={last_prox:.2f}",
                refresh=False)

            # Probe: one batch from a freshly started validation iterator
            # (this restarts its worker processes). It is noisier than the
            # end-of-epoch pass, but best.pt is updated from either.
            if global_batch % report_every == 0:
                model.eval()
                with torch.no_grad():
                    test_imgs, _ = next(iter(test_loader))
                    test_imgs = test_imgs.to(device)
                    test_out = model(test_imgs)
                    test_mse = F.mse_loss(test_out['recon'], test_imgs).item()
                    S_mean = test_out['svd']['S'].mean(dim=(0, 1))
                    S_orig = test_out['svd']['S_orig'].mean(dim=(0, 1))
                    erank = model.effective_rank(
                        test_out['svd']['S'].reshape(-1, D)).mean().item()
                    s_delta = (S_mean - S_orig).abs().mean().item()
                    ratio = (S_mean[0] / (S_mean[-1] + 1e-8)).item()

                writer.add_scalar('train/recon', total_recon / n, global_batch)
                writer.add_scalar('test/recon_mse', test_mse, global_batch)
                writer.add_scalar('geo/row_cv', last_cv, global_batch)
                writer.add_scalar('geo/ratio', ratio, global_batch)
                writer.add_scalar('geo/erank', erank, global_batch)
                writer.add_scalar('geo/S0', S_mean[0].item(), global_batch)
                writer.add_scalar('cross_attn/s_delta', s_delta, global_batch)

                print(f"\n {epoch:3d} {global_batch:8d} | "
                      f"{total_loss/n:7.4f} {total_recon/n:7.4f} | "
                      f"{S_mean[0]:6.3f} {S_mean[-1]:6.3f} {ratio:5.2f} {erank:5.2f} | "
                      f"{last_cv:7.4f} {last_prox:5.3f} {recon_w:5.2f} | "
                      f"{s_delta:7.5f}")

                if test_mse < best_recon:
                    best_recon = test_mse
                    save_checkpoint(os.path.join(save_dir, 'best.pt'),
                                    epoch, test_mse, upload=False)
                model.train()

        pbar.close()
        sched.step()
        epoch_time = time.time() - t0

        writer.add_scalar('train/epoch_time', epoch_time, epoch)

        # Full validation pass (sample-weighted mean MSE).
        model.eval()
        test_recon_total, test_n = 0, 0
        with torch.no_grad():
            for test_imgs, _ in test_loader:
                test_imgs = test_imgs.to(device)
                out = model(test_imgs)
                test_recon_total += F.mse_loss(out['recon'], test_imgs).item() * len(test_imgs)
                test_n += len(test_imgs)
        epoch_test_mse = test_recon_total / test_n

        print(f" Epoch {epoch} done: {epoch_time:.1f}s, test_mse={epoch_test_mse:.6f}, "
              f"best={best_recon:.6f}")

        if epoch_test_mse < best_recon:
            best_recon = epoch_test_mse
            save_checkpoint(os.path.join(save_dir, 'best.pt'),
                            epoch, epoch_test_mse, upload=False)

        if epoch % save_every == 0:
            save_checkpoint(os.path.join(save_dir, f'epoch_{epoch:04d}.pt'),
                            epoch, epoch_test_mse)
            best_path = os.path.join(save_dir, 'best.pt')
            if os.path.exists(best_path):
                upload_to_hf(best_path, 'best.pt')
            writer.flush()
            if hf_enabled:
                try:
                    api.upload_folder(folder_path=tb_path,
                                      path_in_repo=f"{hf_version}/tensorboard/{run_name}",
                                      repo_id=hf_repo, repo_type="model")
                    print(f" ☁️ TB synced")
                except Exception as e:
                    # Best-effort, but report the failure instead of hiding it.
                    print(f" ⚠️ TB sync failed: {e}")

    writer.close()
    print(f"\n JOHANNA-128 OMEGA TRAINING COMPLETE")
    print(f" Best MSE: {best_recon:.6f}")
    print(f" Checkpoints: {save_dir}/")
|
|
|
|
if __name__ == "__main__":
    # 'high' lets PyTorch use faster reduced-precision internal math (e.g. TF32)
    # for float32 matmuls on hardware that supports it.
    torch.set_float32_matmul_precision('high')
    train()