"""
Johanna-Tiny Curriculum — Tiered Noise Introduction
=====================================================
Start with Gaussian. Introduce harder noise types only when the current tier
converges. Track per-type MSE to identify which distributions break the
geometry.

Tiers:
  0: Gaussian                                            (foundation)
  1: + Pink, Brown, Block-structured, Gradient           (correlated)
  2: + Uniform, Scaled uniform, Checkerboard, Mixed      (bounded)
  3: + Poisson, Exponential, Laplace, Sparse             (adversarial)
  4: + Cauchy, Salt-and-pepper, Structural inconsistency (hostile)

Promotion: when the validation MSE has improved by less than 1% (relative to
the best MSE seen in the current tier) for 10 consecutive epochs, the next
tier is unlocked and plateau tracking restarts for the new tier.
"""

import os
import math
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    from google.colab import userdata
    os.environ["HF_TOKEN"] = userdata.get('HF_TOKEN')
    from huggingface_hub import login
    login(token=os.environ["HF_TOKEN"])
except Exception:
    # Best-effort: only succeeds inside Colab with an HF_TOKEN secret stored.
    pass


# ── SVD Backend ──────────────────────────────────────────────────

try:
    from geolip_core.linalg.eigh import FLEigh, _FL_MAX_N
    HAS_FL = True
except ImportError:
    HAS_FL = False


def gram_eigh_svd_fp64(A):
    """Thin SVD of a batch of matrices via eigendecomposition of A^T A.

    All work is done in float64 with autocast disabled. Singular values are
    returned in descending order.

    Args:
        A: batched matrices of shape (B, M, N).

    Returns:
        (U, S, Vh) of shapes (B, M, N), (B, N), (B, N, N), cast back to
        A's dtype.
    """
    orig_dtype = A.dtype
    with torch.amp.autocast('cuda', enabled=False):
        A_d = A.double()
        G = torch.bmm(A_d.transpose(1, 2), A_d)
        # Tiny ridge on the diagonal so eigh stays stable for rank-deficient A.
        G.diagonal(dim1=-2, dim2=-1).add_(1e-12)
        eigenvalues, V = torch.linalg.eigh(G)
        # eigh returns ascending eigenvalues; flip to descending SVD order.
        eigenvalues = eigenvalues.flip(-1)
        V = V.flip(-1)
        S = torch.sqrt(eigenvalues.clamp(min=1e-24))
        U = torch.bmm(A_d, V) / S.unsqueeze(1).clamp(min=1e-16)
        Vh = V.transpose(-2, -1).contiguous()
    return U.to(orig_dtype), S.to(orig_dtype), Vh.to(orig_dtype)


def svd_fp64(A):
    """Batched thin SVD, using geolip_core's FLEigh kernel when possible.

    The FLEigh path is taken only when geolip_core imported, A is on CUDA and
    N <= _FL_MAX_N; otherwise this defers to gram_eigh_svd_fp64.

    Args:
        A: batched matrices of shape (B, M, N).

    Returns:
        (U, S, Vh) as in gram_eigh_svd_fp64.
    """
    _, _, N = A.shape
    if HAS_FL and N <= _FL_MAX_N and A.is_cuda:
        orig_dtype = A.dtype
        with torch.amp.autocast('cuda', enabled=False):
            A_d = A.double()
            G = torch.bmm(A_d.transpose(1, 2), A_d)
            # NOTE(review): this path solves in float32 and adds no diagonal
            # ridge; the flip below presumes FLEigh returns ascending
            # eigenvalues like torch.linalg.eigh — confirm in geolip_core.
            eigenvalues, V = FLEigh()(G.float())
            eigenvalues = eigenvalues.double().flip(-1)
            V = V.double().flip(-1)
            S = torch.sqrt(eigenvalues.clamp(min=1e-24))
            U = torch.bmm(A_d, V) / S.unsqueeze(1).clamp(min=1e-16)
            Vh = V.transpose(-2, -1).contiguous()
        return U.to(orig_dtype), S.to(orig_dtype), Vh.to(orig_dtype)
    return gram_eigh_svd_fp64(A)


def cayley_menger_vol2(points):
    """Squared volume of the simplex spanned by each batch of points.

    Uses the Cayley–Menger determinant in float64:
    vol^2 = (-1)^(k+1) det(CM) / (2^k (k!)^2) with k = N - 1.

    Args:
        points: tensor of shape (B, N, D); each batch item is N vertices.

    Returns:
        float64 tensor of shape (B,) with squared (N-1)-simplex volumes.
    """
    B, N, _ = points.shape
    pts = points.double()
    gram = torch.bmm(pts, pts.transpose(1, 2))
    norms = torch.diagonal(gram, dim1=1, dim2=2)
    # Pairwise squared distances; relu removes tiny negative rounding errors.
    d2 = F.relu(norms.unsqueeze(2) + norms.unsqueeze(1) - 2 * gram)
    cm = torch.zeros(B, N + 1, N + 1, device=points.device, dtype=torch.float64)
    cm[:, 0, 1:] = 1.0
    cm[:, 1:, 0] = 1.0
    cm[:, 1:, 1:] = d2
    k = N - 1
    sign = (-1.0) ** (k + 1)
    fact = math.factorial(k)
    return sign * torch.linalg.det(cm) / ((2 ** k) * (fact ** 2))


def cv_of(emb, n_samples=200):
    """Coefficient of variation of random 4-simplex volumes over rows of emb.

    Draws n_samples random 5-row subsets from the first min(N, 512) rows of a
    2-D tensor, computes their simplex volumes and returns std / mean.

    Returns 0.0 when emb is not 2-D, has fewer than 5 rows, or fewer than 10
    sampled simplices are non-degenerate.
    """
    if emb.dim() != 2 or emb.shape[0] < 5:
        return 0.0
    N, _ = emb.shape
    pool = min(N, 512)
    indices = torch.stack([torch.randperm(pool, device=emb.device)[:5]
                           for _ in range(n_samples)])
    vol2 = cayley_menger_vol2(emb[:pool][indices])
    valid = vol2 > 1e-20
    if valid.sum() < 10:
        return 0.0
    vols = vol2[valid].sqrt()
    return (vols.std() / (vols.mean() + 1e-8)).item()


# ── Noise Type Registry ─────────────────────────────────────────

NOISE_NAMES = {
    0: 'gaussian', 1: 'uniform', 2: 'uniform_scaled', 3: 'poisson',
    4: 'pink', 5: 'brown', 6: 'salt_pepper', 7: 'sparse',
    8: 'block', 9: 'gradient', 10: 'checkerboard', 11: 'mixed',
    12: 'structural', 13: 'cauchy', 14: 'exponential', 15: 'laplace',
}

TIERS = {
    0: [0],               # Gaussian (foundation)
    1: [4, 5, 8, 9],      # Pink, Brown, Block, Gradient (correlated)
    2: [1, 2, 10, 11],    # Uniform, Scaled, Checkerboard, Mixed (bounded)
    3: [3, 14, 15, 7],    # Poisson, Exponential, Laplace, Sparse (adversarial)
    4: [13, 6, 12],       # Cauchy, Salt-pepper, Structural (hostile)
}


def seed_worker(worker_id):
    """DataLoader ``worker_init_fn`` giving each worker its own NumPy RNG.

    Each worker gets a copy of the dataset whose ``_rng`` was created in the
    parent process, so without this every worker would draw identical
    NumPy-side noise parameters (Poisson rate, block size, gradient angle,
    ...). PyTorch seeds each worker's torch RNG with ``base_seed + worker_id``;
    that seed is reused here. No-op when called outside a worker.
    """
    info = torch.utils.data.get_worker_info()
    if info is None or not hasattr(info.dataset, '_rng'):
        return
    info.dataset._rng = np.random.RandomState(torch.initial_seed() % 2**32)


# ── Curriculum Noise Dataset ─────────────────────────────────────

class CurriculumNoiseDataset(torch.utils.data.Dataset):
    """Noise dataset with tier-based type activation.

    Only generates noise types that are currently unlocked. Types are
    activated by tier — call unlock_tier(n) to enable. Each item is a
    (3, img_size, img_size) float tensor clamped to [-4, 4] plus its type id.

    NOTE: unlock_tier mutates this object in the calling process. DataLoader
    workers only see the change when they are re-created, which happens every
    epoch with the default ``persistent_workers=False`` — keep it that way.
    """

    def __init__(self, size=500000, img_size=64, seed_rotate_every=1000):
        self.size = size
        self.img_size = img_size
        self.seed_rotate_every = seed_rotate_every
        self._rng = np.random.RandomState(42)
        self._call_count = 0
        self.active_types = list(TIERS[0])  # start with Gaussian only
        self.current_tier = 0

    def unlock_tier(self, tier):
        """Add tier's noise types to the active set; unknown tiers are ignored.

        Only the given tier's types are added; lower tiers are not unlocked
        implicitly.
        """
        if tier not in TIERS:
            return
        for t in TIERS[tier]:
            if t not in self.active_types:
                self.active_types.append(t)
        self.current_tier = tier

    def __len__(self):
        return self.size

    def _rotate_seed(self):
        """Reseed the NumPy and torch RNGs from os.urandom every N calls."""
        self._call_count += 1
        if self._call_count % self.seed_rotate_every == 0:
            new_seed = int.from_bytes(os.urandom(4), 'big')
            self._rng = np.random.RandomState(new_seed)
            torch.manual_seed(new_seed)

    @staticmethod
    def _spectral_noise(shape, exponent):
        """White noise filtered by amplitude gain |f|^-exponent over (H, W).

        The DC bin gets zero gain. Dividing it by a clamped ~0 frequency (as
        before) amplified each channel's mean by ~1e8, so after std
        normalisation the "correlated" noise was a near-constant colour per
        channel.
        """
        white = torch.randn(shape)
        spectrum = torch.fft.rfft2(white)
        h, w = shape[-2], shape[-1]
        fy = torch.fft.fftfreq(h).unsqueeze(-1)
        fx = torch.fft.rfftfreq(w).unsqueeze(0)
        f = torch.sqrt(fx ** 2 + fy ** 2)  # radial frequency, (h, w // 2 + 1)
        gain = torch.zeros_like(f)
        nonzero = f > 0
        gain[nonzero] = f[nonzero].pow(-exponent)
        return torch.fft.irfft2(spectrum * gain, s=(h, w))

    def _pink_noise(self, shape):
        # Amplitude ∝ 1/f. NOTE(review): that is power ∝ 1/f², which is
        # conventionally "brown"; kept as the original design chose it.
        return self._spectral_noise(shape, 1)

    def _brown_noise(self, shape):
        # Amplitude ∝ 1/f² (power ∝ 1/f⁴).
        return self._spectral_noise(shape, 2)

    def _generate(self, noise_type):
        """Draw one (3, s, s) noise image of the given type (unclamped)."""
        s = self.img_size
        if noise_type == 0:
            return torch.randn(3, s, s)
        elif noise_type == 1:
            return torch.rand(3, s, s) * 2 - 1
        elif noise_type == 2:
            return (torch.rand(3, s, s) - 0.5) * 4
        elif noise_type == 3:
            lam = self._rng.uniform(0.5, 20.0)
            return torch.poisson(torch.full((3, s, s), lam)) / lam - 1.0
        elif noise_type == 4:
            img = self._pink_noise((3, s, s))
            return img / (img.std() + 1e-8)
        elif noise_type == 5:
            img = self._brown_noise((3, s, s))
            return img / (img.std() + 1e-8)
        elif noise_type == 6:
            img = torch.where(torch.rand(3, s, s) > 0.5,
                              torch.ones(3, s, s) * 2, -torch.ones(3, s, s) * 2)
            return img + torch.randn(3, s, s) * 0.1
        elif noise_type == 7:
            return torch.randn(3, s, s) * (torch.rand(3, s, s) > 0.9).float() * 3
        elif noise_type == 8:
            b = self._rng.randint(2, 16)
            small = torch.randn(3, s // b + 1, s // b + 1)
            return F.interpolate(small.unsqueeze(0), size=s, mode='nearest').squeeze(0)
        elif noise_type == 9:
            gy = torch.linspace(-2, 2, s).unsqueeze(1).expand(s, s)
            gx = torch.linspace(-2, 2, s).unsqueeze(0).expand(s, s)
            a = self._rng.uniform(0, 2 * math.pi)
            ramp = (math.cos(a) * gx + math.sin(a) * gy).unsqueeze(0).expand(3, -1, -1)
            return ramp + torch.randn(3, s, s) * 0.5
        elif noise_type == 10:
            cs = self._rng.randint(2, 16)
            cy = torch.arange(s) // cs
            cx = torch.arange(s) // cs
            checker = ((cy.unsqueeze(1) + cx.unsqueeze(0)) % 2).float() * 2 - 1
            return checker.unsqueeze(0).expand(3, -1, -1) + torch.randn(3, s, s) * 0.3
        elif noise_type == 11:
            alpha = self._rng.uniform(0.2, 0.8)
            return alpha * torch.randn(3, s, s) + (1 - alpha) * (torch.rand(3, s, s) * 2 - 1)
        elif noise_type == 12:
            # Four quadrants with different distributions.
            img = torch.zeros(3, s, s)
            h2 = s // 2
            img[:, :h2, :h2] = torch.randn(3, h2, h2)
            img[:, :h2, h2:] = torch.rand(3, h2, h2) * 2 - 1
            img[:, h2:, :h2] = self._pink_noise((3, h2, h2)) / 2
            img[:, h2:, h2:] = torch.where(torch.rand(3, h2, h2) > 0.5,
                                           torch.ones(3, h2, h2), -torch.ones(3, h2, h2))
            return img
        elif noise_type == 13:
            return torch.tan(math.pi * (torch.rand(3, s, s) - 0.5)).clamp(-3, 3)
        elif noise_type == 14:
            return torch.empty(3, s, s).exponential_(1.0) - 1.0
        elif noise_type == 15:
            u = torch.rand(3, s, s) - 0.5
            return -torch.sign(u) * torch.log1p(-2 * u.abs())
        return torch.randn(3, s, s)

    def __getitem__(self, idx):
        self._rotate_seed()
        # Round-robin over active types so each type gets an equal share.
        noise_type = self.active_types[idx % len(self.active_types)]
        img = self._generate(noise_type).clamp(-4, 4)
        return img.float(), noise_type


# ── Model (identical to proven architecture) ─────────────────────

def extract_patches(images, patch_size=16):
    """Split (B, C, H, W) images into (B, gh*gw, C*p*p) flattened patches."""
    B, C, H, W = images.shape
    gh, gw = H // patch_size, W // patch_size
    p = images.reshape(B, C, gh, patch_size, gw, patch_size)
    patches = p.permute(0, 2, 4, 1, 3, 5).reshape(B, gh * gw, C * patch_size * patch_size)
    return patches, gh, gw


def stitch_patches(patches, gh, gw, patch_size=16):
    """Inverse of extract_patches for 3-channel images."""
    B = patches.shape[0]
    p = patches.reshape(B, gh, gw, 3, patch_size, patch_size)
    return p.permute(0, 3, 1, 4, 2, 5).reshape(B, 3, gh * patch_size, gw * patch_size)


class BoundarySmooth(nn.Module):
    """Residual 3x3 conv refinement; last conv is zero-initialised (identity at start)."""

    def __init__(self, channels=3, mid=16):
        super().__init__()
        self.net = nn.Sequential(nn.Conv2d(channels, mid, 3, padding=1), nn.GELU(),
                                 nn.Conv2d(mid, channels, 3, padding=1))
        nn.init.zeros_(self.net[-1].weight)
        nn.init.zeros_(self.net[-1].bias)

    def forward(self, x):
        return x + self.net(x)


class SpectralCrossAttention(nn.Module):
    """Multi-head self-attention over patches that multiplicatively modulates S.

    Output is S * (1 + alpha * tanh(attn)), with per-dimension alpha bounded
    in (0, max_alpha) by a sigmoid.
    """

    def __init__(self, D, n_heads=4, max_alpha=0.2, alpha_init=-2.0):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = D // n_heads
        self.max_alpha = max_alpha
        assert D % n_heads == 0
        self.qkv = nn.Linear(D, 3 * D)
        self.out_proj = nn.Linear(D, D)
        self.norm = nn.LayerNorm(D)
        self.scale = self.head_dim ** -0.5
        self.alpha_logits = nn.Parameter(torch.full((D,), alpha_init))

    @property
    def alpha(self):
        return self.max_alpha * torch.sigmoid(self.alpha_logits)

    def forward(self, S):
        B, N, D = S.shape
        S_n = self.norm(S)
        qkv = self.qkv(S_n).reshape(B, N, 3, self.n_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = ((q @ k.transpose(-2, -1)) * self.scale).softmax(-1)
        out = (attn @ v).transpose(1, 2).reshape(B, N, D)
        return S * (1.0 + self.alpha.unsqueeze(0).unsqueeze(0) * torch.tanh(self.out_proj(out)))


class PatchSVAE(nn.Module):
    """Patch autoencoder with an SVD bottleneck.

    Each patch is encoded to a (matrix_v, D) matrix with unit-norm rows,
    decomposed with svd_fp64, its singular values mixed across patches by
    SpectralCrossAttention, then U diag(S) Vt is decoded back to pixels.
    """

    def __init__(self, matrix_v=256, D=16, patch_size=16, hidden=768, depth=4, n_cross_layers=2):
        super().__init__()
        self.matrix_v, self.D, self.patch_size = matrix_v, D, patch_size
        self.patch_dim = 3 * patch_size * patch_size
        self.mat_dim = matrix_v * D
        # NOTE: module creation order is kept as-is so seeded initialisation
        # and state_dict key order match earlier checkpoints.
        self.enc_in = nn.Linear(self.patch_dim, hidden)
        self.enc_blocks = nn.ModuleList([nn.Sequential(
            nn.LayerNorm(hidden), nn.Linear(hidden, hidden), nn.GELU(), nn.Linear(hidden, hidden))
            for _ in range(depth)])
        self.enc_out = nn.Linear(hidden, self.mat_dim)
        self.dec_in = nn.Linear(self.mat_dim, hidden)
        self.dec_blocks = nn.ModuleList([nn.Sequential(
            nn.LayerNorm(hidden), nn.Linear(hidden, hidden), nn.GELU(), nn.Linear(hidden, hidden))
            for _ in range(depth)])
        self.dec_out = nn.Linear(hidden, self.patch_dim)
        nn.init.orthogonal_(self.enc_out.weight)
        self.cross_attn = nn.ModuleList([
            SpectralCrossAttention(D, n_heads=min(4, D)) for _ in range(n_cross_layers)])
        self.boundary_smooth = BoundarySmooth(channels=3, mid=16)

    def encode_patches(self, patches):
        """Encode (B, N, patch_dim) patches into per-patch SVD factors.

        Returns a dict with U (B,N,V,D), S_orig and S (B,N,D; S after cross
        attention), Vt (B,N,D,D) and the row-normalised matrices M (B,N,V,D).
        """
        B, N, _ = patches.shape
        h = F.gelu(self.enc_in(patches.reshape(B * N, -1)))
        for block in self.enc_blocks:
            h = h + block(h)
        M = F.normalize(self.enc_out(h).reshape(B * N, self.matrix_v, self.D), dim=-1)
        U, S, Vt = svd_fp64(M)
        U = U.reshape(B, N, self.matrix_v, self.D)
        S = S.reshape(B, N, self.D)
        Vt = Vt.reshape(B, N, self.D, self.D)
        M = M.reshape(B, N, self.matrix_v, self.D)
        S_c = S
        for layer in self.cross_attn:
            S_c = layer(S_c)
        return {'U': U, 'S_orig': S, 'S': S_c, 'Vt': Vt, 'M': M}

    def decode_patches(self, U, S, Vt):
        """Rebuild U diag(S) Vt per patch and decode to (B, N, patch_dim)."""
        B, N, V, D = U.shape
        M_hat = torch.bmm(U.reshape(B * N, V, D) * S.reshape(B * N, D).unsqueeze(1),
                          Vt.reshape(B * N, D, D))
        h = F.gelu(self.dec_in(M_hat.reshape(B * N, -1)))
        for block in self.dec_blocks:
            h = h + block(h)
        return self.dec_out(h).reshape(B, N, -1)

    def forward(self, images):
        patches, gh, gw = extract_patches(images, self.patch_size)
        svd = self.encode_patches(patches)
        recon = stitch_patches(self.decode_patches(svd['U'], svd['S'], svd['Vt']),
                               gh, gw, self.patch_size)
        return {'recon': self.boundary_smooth(recon), 'svd': svd}

    @staticmethod
    def effective_rank(S):
        """exp(entropy) of the singular values normalised to sum to 1 (last dim)."""
        p = S / (S.sum(-1, keepdim=True) + 1e-8)
        p = p.clamp(min=1e-8)
        return (-(p * p.log()).sum(-1)).exp()


# ── Per-Type MSE Evaluation ──────────────────────────────────────

def eval_per_type(model, dataset, device, n_per_type=64):
    """Evaluate MSE for each active noise type independently.

    Draws n_per_type fresh samples per active type in the calling process
    (using the dataset's RNG there) and returns {type_id: mse}. Leaves the
    model in eval mode.
    """
    model.eval()
    type_mse = {}
    with torch.no_grad():
        for t in dataset.active_types:
            imgs = torch.stack([dataset._generate(t).clamp(-4, 4)
                                for _ in range(n_per_type)]).to(device)
            out = model(imgs)
            type_mse[t] = F.mse_loss(out['recon'], imgs).item()
    return type_mse


# ── Curriculum promotion bookkeeping ─────────────────────────────

class PlateauTracker:
    """Count consecutive epochs whose relative MSE improvement is below a threshold.

    Improvement is measured against the best MSE seen since the last reset().
    The first update after a reset always counts as progress and only
    establishes the baseline.
    """

    def __init__(self, patience, threshold):
        self.patience = patience
        self.threshold = threshold
        self.reset()

    def reset(self):
        """Forget the baseline (call when the evaluation distribution changes)."""
        self.best = float('inf')
        self.stale = 0

    def update(self, mse):
        """Record one epoch's MSE; return True once `patience` stale epochs accumulate."""
        if math.isfinite(self.best):
            improvement = (self.best - mse) / (self.best + 1e-8)
        else:
            improvement = float('inf')  # first epoch of a tier
        if mse < self.best:
            self.best = mse
        if improvement < self.threshold:
            self.stale += 1
        else:
            self.stale = 0
        return self.stale >= self.patience


# ── Training ─────────────────────────────────────────────────────

def train():
    """Run the tiered-noise curriculum training loop (Colab-oriented script).

    Writes checkpoints under /content/checkpoints, TensorBoard logs under
    /content/runs, and uploads to the Hugging Face Hub when authenticated.
    """
    # Training-only dependency; imported here (like SummaryWriter / HfApi
    # below) so the module's datasets and model import without it.
    from tqdm import tqdm

    V, D, patch_size = 256, 16, 16
    hidden, depth = 768, 4
    n_cross_layers = 2
    batch_size = 512
    lr = 3e-4
    epochs = 300
    target_cv = 0.125
    cv_weight, boost, sigma = 0.3, 0.5, 0.15
    img_size = 64
    n_patches = (img_size // patch_size) ** 2

    # Curriculum config
    promote_patience = 10      # consecutive epochs of <1% improvement before promoting
    promote_threshold = 0.01   # relative improvement threshold

    save_dir = '/content/checkpoints'
    save_every = 25
    hf_repo = 'AbstractPhil/geolip-SVAE'
    hf_version = 'v18_johanna_curriculum'
    tb_dir = '/content/runs'
    os.makedirs(save_dir, exist_ok=True)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    from torch.utils.tensorboard import SummaryWriter
    run_name = f"johanna_tiny_curriculum_{img_size}x{img_size}_h{hidden}_d{depth}_lr{lr}"
    tb_path = os.path.join(tb_dir, run_name)
    writer = SummaryWriter(tb_path)
    print(f" TensorBoard: {tb_path}")

    hf_enabled = False
    hf_prefix = f"{hf_version}/checkpoints"
    try:
        from huggingface_hub import HfApi
        api = HfApi()
        api.whoami()
        hf_enabled = True
        print(f" HuggingFace: {hf_repo}/{hf_prefix}")
    except Exception as e:
        print(f" HuggingFace: disabled ({e})")

    def upload_to_hf(local_path, remote_name):
        if not hf_enabled:
            return
        try:
            api.upload_file(path_or_fileobj=local_path,
                            path_in_repo=f"{hf_prefix}/{remote_name}",
                            repo_id=hf_repo, repo_type="model")
            print(f" ☁️ Uploaded: {hf_repo}/{hf_prefix}/{remote_name}")
        except Exception as e:
            print(f" ⚠️ HF upload: {e}")

    # ── Data: Curriculum noise ──
    # persistent_workers stays False so workers are re-created each epoch and
    # pick up tiers unlocked in this process; seed_worker de-duplicates the
    # per-worker NumPy RNG.
    train_ds = CurriculumNoiseDataset(size=500000, img_size=img_size)
    val_ds = CurriculumNoiseDataset(size=10000, img_size=img_size)
    train_loader = torch.utils.data.DataLoader(
        train_ds, batch_size=batch_size, shuffle=True, num_workers=4,
        pin_memory=True, drop_last=True, worker_init_fn=seed_worker)
    test_loader = torch.utils.data.DataLoader(
        val_ds, batch_size=batch_size, shuffle=False, num_workers=4,
        pin_memory=True, worker_init_fn=seed_worker)

    # ── Model: fresh init ──
    model = PatchSVAE(matrix_v=V, D=D, patch_size=patch_size, hidden=hidden,
                      depth=depth, n_cross_layers=n_cross_layers).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
    total_params = sum(p.numel() for p in model.parameters())

    print(f"\n JOHANNA-TINY CURRICULUM TRAINER")
    print(f" {img_size}×{img_size}, {n_patches} patches, ({V},{D}), {total_params:,} params")
    print(f" Batch={batch_size}, lr={lr}, epochs={epochs}")
    print(f" Tiers: {len(TIERS)} tiers, promote after {promote_patience} epochs of <{promote_threshold*100:.0f}% improvement")
    for tier_id, types in sorted(TIERS.items()):
        names = [NOISE_NAMES[t] for t in types]
        print(f" Tier {tier_id}: {', '.join(names)}")
    print("=" * 110)
    print(f" {'ep':>3} {'tier':>4} {'types':>5} | {'loss':>7} {'recon':>7} | "
          f"{'S0':>6} {'SD':>6} {'ratio':>5} {'erank':>5} | "
          f"{'row_cv':>7} {'prox':>5} | {'per-type MSE':>40}")
    print("-" * 110)

    best_recon = float('inf')  # best val MSE within the current tier
    plateau = PlateauTracker(promote_patience, promote_threshold)
    max_tier = max(TIERS.keys())

    def save_checkpoint(path, epoch, test_mse, extra=None, upload=True):
        ckpt = {
            'epoch': epoch,
            'test_mse': test_mse,
            'current_tier': train_ds.current_tier,
            'active_types': train_ds.active_types,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': opt.state_dict(),
            'scheduler_state_dict': sched.state_dict(),
            'config': {
                'V': V, 'D': D, 'patch_size': patch_size,
                'hidden': hidden, 'depth': depth,
                'n_cross_layers': n_cross_layers,
                'target_cv': target_cv,
                'dataset': 'curriculum_noise',
                'img_size': img_size, 'lr': lr,
            },
        }
        if extra:
            ckpt.update(extra)
        torch.save(ckpt, path)
        size_mb = os.path.getsize(path) / (1024 * 1024)
        print(f" 💾 Saved: {path} ({size_mb:.1f}MB, ep{epoch}, tier{train_ds.current_tier}, MSE={test_mse:.6f})")
        if upload:
            upload_to_hf(path, os.path.basename(path))

    for epoch in range(1, epochs + 1):
        model.train()
        total_loss, total_recon, n = 0, 0, 0
        last_cv, last_prox = target_cv, 1.0
        t0 = time.time()
        pbar = tqdm(train_loader,
                    desc=f"Ep {epoch} T{train_ds.current_tier}({len(train_ds.active_types)})",
                    bar_format='{l_bar}{bar:20}{r_bar}')
        for batch_idx, (images, _) in enumerate(pbar):
            images = images.to(device)
            opt.zero_grad()
            out = model(images)
            recon_loss = F.mse_loss(out['recon'], images)
            with torch.no_grad():
                # CV of one patch's row geometry, refreshed every 50 batches.
                if batch_idx % 50 == 0:
                    current_cv = cv_of(out['svd']['M'][0, 0])
                    if current_cv > 0:
                        last_cv = current_cv
                delta = last_cv - target_cv
                last_prox = math.exp(-delta ** 2 / (2 * sigma ** 2))
            recon_w = 1.0 + boost * last_prox
            cv_pen = cv_weight * (1.0 - last_prox)
            # NOTE(review): the CV term is a Python float (no gradient); it
            # only shifts the reported loss. The CV signal acts via recon_w.
            loss = recon_w * recon_loss + cv_pen * (last_cv - target_cv) ** 2
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.cross_attn.parameters(), max_norm=0.5)
            opt.step()
            total_loss += loss.item() * len(images)
            total_recon += recon_loss.item() * len(images)
            n += len(images)
            pbar.set_postfix_str(f"mse={recon_loss.item():.4f} cv={last_cv:.3f} prox={last_prox:.2f}")
        pbar.close()
        sched.step()
        epoch_time = time.time() - t0

        # ── Evaluation: overall + per-type ──
        model.eval()
        test_mse_total, test_n = 0, 0
        geo_batch = None
        with torch.no_grad():
            for imgs, _ in test_loader:
                imgs = imgs.to(device)
                if geo_batch is None:
                    # Reused for geometry stats instead of spinning up a new
                    # 4-worker iterator just to fetch one batch.
                    geo_batch = imgs[:64]
                out = model(imgs)
                test_mse_total += F.mse_loss(out['recon'], imgs).item() * len(imgs)
                test_n += len(imgs)
        test_mse = test_mse_total / test_n

        # Per-type MSE
        type_mse = eval_per_type(model, train_ds, device, n_per_type=64)
        type_str = " ".join(f"{NOISE_NAMES[t][:4]}={v:.3f}" for t, v in sorted(type_mse.items()))

        # Geometry
        with torch.no_grad():
            out = model(geo_batch)
            S_mean = out['svd']['S'].mean(dim=(0, 1))
            ratio = (S_mean[0] / (S_mean[-1] + 1e-8)).item()
            erank = model.effective_rank(out['svd']['S'].reshape(-1, D)).mean().item()

        # TB logging
        writer.add_scalar('train/recon', total_recon / n, epoch)
        writer.add_scalar('train/epoch_time_s', epoch_time, epoch)
        writer.add_scalar('test/mse', test_mse, epoch)
        writer.add_scalar('curriculum/tier', train_ds.current_tier, epoch)
        writer.add_scalar('curriculum/n_types', len(train_ds.active_types), epoch)
        writer.add_scalar('geo/cv', last_cv, epoch)
        writer.add_scalar('geo/S0', S_mean[0].item(), epoch)
        writer.add_scalar('geo/ratio', ratio, epoch)
        for t, mse in type_mse.items():
            writer.add_scalar(f'per_type/{NOISE_NAMES[t]}', mse, epoch)

        print(f" {epoch:3d} T{train_ds.current_tier:>2} {len(train_ds.active_types):>3}t | "
              f"{total_loss/n:7.4f} {total_recon/n:7.4f} | "
              f"{S_mean[0]:6.3f} {S_mean[-1]:6.3f} {ratio:5.2f} {erank:5.2f} | "
              f"{last_cv:7.4f} {last_prox:5.3f} | {type_str}")

        # ── Best checkpoint (compared only within the current tier) ──
        # Done before promotion so best.pt records the tier its MSE came from.
        if test_mse < best_recon:
            best_recon = test_mse
            save_checkpoint(os.path.join(save_dir, 'best.pt'), epoch, test_mse, upload=False)

        # ── Tier promotion logic ──
        if plateau.update(test_mse) and train_ds.current_tier < max_tier:
            next_tier = train_ds.current_tier + 1
            train_ds.unlock_tier(next_tier)
            val_ds.unlock_tier(next_tier)
            new_names = [NOISE_NAMES[t] for t in TIERS[next_tier]]
            print(f"\n ★ PROMOTED TO TIER {next_tier}: +{', '.join(new_names)}")
            print(f" Active types: {[NOISE_NAMES[t] for t in train_ds.active_types]}")
            print(f" Tier MSE was: {plateau.best:.6f}\n")
            # The validation mixture just got harder, so earlier MSEs are not
            # comparable with upcoming ones. Previously the baseline was reset
            # to this (easier-tier) MSE, which made every new-tier epoch count
            # as stale and forced promotion after exactly `patience` epochs.
            # Let the first new-tier epoch establish both baselines instead.
            plateau.reset()
            best_recon = float('inf')
            # Save promotion checkpoint
            save_checkpoint(os.path.join(save_dir, f'tier{next_tier}_start.pt'),
                            epoch, test_mse, upload=True)

        # ── Periodic checkpoints ──
        if epoch % save_every == 0:
            save_checkpoint(os.path.join(save_dir, f'epoch_{epoch:04d}.pt'), epoch, test_mse)
            best_path = os.path.join(save_dir, 'best.pt')
            if os.path.exists(best_path):
                upload_to_hf(best_path, 'best.pt')
            writer.flush()
            if hf_enabled:
                try:
                    api.upload_folder(folder_path=tb_path,
                                      path_in_repo=f"{hf_version}/tensorboard/{run_name}",
                                      repo_id=hf_repo, repo_type="model")
                    print(f" ☁️ TB synced")
                except Exception as e:
                    print(f" ⚠️ TB sync: {e}")

    writer.close()
    print(f"\n CURRICULUM TRAINING COMPLETE")
    print(f" Final tier: {train_ds.current_tier}")
    print(f" Active types: {[NOISE_NAMES[t] for t in train_ds.active_types]}")
    print(f" Best MSE: {best_recon:.6f}")


if __name__ == "__main__":
    torch.set_float32_matmul_precision('high')
    train()