AbstractPhil commited on
Commit
763a738
·
verified ·
1 Parent(s): f0aa560

Create svae_johanna_noise_trainer.py

Browse files
Files changed (1) hide show
  1. svae_johanna_noise_trainer.py +620 -0
svae_johanna_noise_trainer.py ADDED
@@ -0,0 +1,620 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Johanna-128 Omega β€” Continue Gaussian-pretrained model on 16 noise types
3
+ ==========================================================================
4
+ Loads the Gaussian-trained checkpoint (ep200, MSE=0.059) and expands
5
+ the signal vocabulary to all 16 noise distributions at 128Γ—128.
6
+
7
+ The Gaussian knowledge is the foundation β€” the MLP already knows how to
8
+ invert the geometric projection for one distribution. Now we teach it
9
+ the other 15 without destroying what it learned.
10
+
11
+ Strategy:
12
+ - Moderate lr (1e-4): fast enough to learn new distributions,
13
+ slow enough to preserve Gaussian knowledge
14
+ - Gaussian is 1 of 16 types, so it stays in the training mix
15
+ - Same architecture: V=256, D=16, hidden=768, depth=4, 17M params
16
+ - Batch=128 to stay under cusolver limit (128 Γ— 64 patches = 8192 calls)
17
+ """
18
+
19
+ import os
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch.nn.functional as F
23
+ import torchvision
24
+ import torchvision.transforms as T
25
+ import math
26
+ import time
27
+ import numpy as np
28
+ from tqdm import tqdm
29
+
30
# ── HuggingFace auth from Colab secrets ──
# Best-effort login: pull HF_TOKEN from Colab's secret store and authenticate.
# Outside Colab (or with no secret configured) any step raises and we fall
# through silently — training still runs, only the HF mirroring is disabled
# later via the api.whoami() probe in train().
try:
    from google.colab import userdata
    os.environ["HF_TOKEN"] = userdata.get('HF_TOKEN')
    from huggingface_hub import login
    login(token=os.environ["HF_TOKEN"])
except Exception:
    pass
38
+
39
# ── SVD Backend ──────────────────────────────────────────────────
# Optional fast eigendecomposition kernel from geolip_core; _FL_MAX_N is the
# largest matrix side it supports. When the package is absent, HAS_FL gates
# svd_fp64 onto the pure-torch gram_eigh_svd_fp64 fallback.
try:
    from geolip_core.linalg.eigh import FLEigh, _FL_MAX_N
    HAS_FL = True
except ImportError:
    HAS_FL = False
47
+
48
def gram_eigh_svd_fp64(A):
    """Batched thin SVD of A via eigendecomposition of the Gram matrix.

    Works in float64 with autocast disabled for numerical stability.
    Returns (U, S, Vh) cast back to A's original dtype, with singular
    values in descending order.
    """
    in_dtype = A.dtype
    with torch.amp.autocast('cuda', enabled=False):
        mat = A.double()
        gram = mat.transpose(1, 2) @ mat
        # Tiny diagonal jitter keeps eigh well-conditioned on rank-deficient input.
        gram.diagonal(dim1=-2, dim2=-1).add_(1e-12)
        evals, evecs = torch.linalg.eigh(gram)
        # eigh yields ascending eigenvalues; flip to the SVD convention.
        evals, evecs = evals.flip(-1), evecs.flip(-1)
        sing = evals.clamp(min=1e-24).sqrt()
        # Recover the left factor; clamp avoids division by ~zero singular values.
        left = (mat @ evecs) / sing.unsqueeze(1).clamp(min=1e-16)
        right_t = evecs.transpose(-2, -1).contiguous()
    return left.to(in_dtype), sing.to(in_dtype), right_t.to(in_dtype)
61
+
62
+
63
def svd_fp64(A):
    """Batched SVD dispatcher.

    Uses the FLEigh fast path when the extension is present, the input is
    on CUDA, and the inner dimension fits its limit; otherwise falls back
    to the pure-torch fp64 Gram-eigh route.
    """
    _, _, inner = A.shape
    if not (HAS_FL and inner <= _FL_MAX_N and A.is_cuda):
        return gram_eigh_svd_fp64(A)
    in_dtype = A.dtype
    with torch.amp.autocast('cuda', enabled=False):
        mat = A.double()
        gram = torch.bmm(mat.transpose(1, 2), mat)
        # FLEigh runs in fp32; lift its results back to fp64 before the sqrt.
        evals, evecs = FLEigh()(gram.float())
        evals = evals.double().flip(-1)
        evecs = evecs.double().flip(-1)
        sing = torch.sqrt(evals.clamp(min=1e-24))
        left = torch.bmm(mat, evecs) / sing.unsqueeze(1).clamp(min=1e-16)
        right_t = evecs.transpose(-2, -1).contiguous()
    return left.to(in_dtype), sing.to(in_dtype), right_t.to(in_dtype)
79
+
80
+
81
# ── CV Monitoring ────────────────────────────────────────────────

def cayley_menger_vol2(points):
    """Squared volume of the simplex spanned by each batch of N points,
    via the Cayley-Menger determinant (computed in float64)."""
    n_batch, n_pts, _ = points.shape
    pts64 = points.double()
    inner = pts64 @ pts64.transpose(1, 2)
    sq_norms = torch.diagonal(inner, dim1=1, dim2=2)
    # Pairwise squared distances; relu guards tiny negatives from rounding.
    dist2 = F.relu(sq_norms.unsqueeze(2) + sq_norms.unsqueeze(1) - 2.0 * inner)
    cm = torch.zeros(n_batch, n_pts + 1, n_pts + 1,
                     device=points.device, dtype=torch.float64)
    cm[:, 0, 1:] = 1.0
    cm[:, 1:, 0] = 1.0
    cm[:, 1:, 1:] = dist2
    dim = n_pts - 1  # simplex dimension for N points
    prefactor = ((-1.0) ** (dim + 1)) / ((2 ** dim) * (math.factorial(dim) ** 2))
    return prefactor * torch.linalg.det(cm)
97
+
98
+
99
def cv_of(emb, n_samples=200):
    """Coefficient of variation of random 5-point simplex volumes drawn
    from the rows of a 2-D embedding.

    Returns 0.0 when the input is not 2-D, has fewer than 5 rows, or too
    few sampled simplices are non-degenerate.
    """
    if emb.dim() != 2 or emb.shape[0] < 5:
        return 0.0
    pool = min(emb.shape[0], 512)
    # n_samples independent 5-row subsets (no repeats within a subset).
    picks = torch.stack(
        [torch.randperm(pool, device=emb.device)[:5] for _ in range(n_samples)])
    sq_vols = cayley_menger_vol2(emb[:pool][picks])
    keep = sq_vols > 1e-20
    if keep.sum() < 10:
        return 0.0  # not enough usable simplices for a stable estimate
    vols = sq_vols[keep].sqrt()
    return (vols.std() / (vols.mean() + 1e-8)).item()
111
+
112
+
113
# ── Comprehensive Noise Dataset (128x128) ────────────────────────

class OmegaNoiseDataset(torch.utils.data.Dataset):
    """16 synthetic noise families at arbitrary resolution, with periodic
    seed rotation so long runs do not recycle identical samples.

    __getitem__ maps idx % 16 to a fixed noise family and returns
    (image, noise_type), where image is a float32 (3, img_size, img_size)
    tensor clamped to [-4, 4].
    """

    N_TYPES = 16  # number of distinct noise families

    def __init__(self, size=1000000, img_size=128, seed_rotate_every=1000):
        self.size = size
        self.img_size = img_size
        self.seed_rotate_every = seed_rotate_every
        # Local numpy RNG drives scalar parameters (lambda, block sizes,
        # angles); the torch global RNG drives the tensors themselves.
        self._rng = np.random.RandomState(42)
        self._call_count = 0

    def __len__(self):
        return self.size

    def _rotate_seed(self):
        # Every seed_rotate_every calls, reseed numpy and torch from OS
        # entropy. NOTE(review): torch.manual_seed mutates *global* torch RNG
        # state, so this also affects any other torch.rand* user in the same
        # process, and each DataLoader worker rotates independently.
        self._call_count += 1
        if self._call_count % self.seed_rotate_every == 0:
            new_seed = int.from_bytes(os.urandom(4), 'big')
            self._rng = np.random.RandomState(new_seed)
            torch.manual_seed(new_seed)

    def _pink_noise(self, shape):
        # Spectrally shaped noise: divide the rFFT amplitude of white noise
        # by |f|. NOTE(review): dividing *amplitude* by f yields a 1/f^2
        # power spectrum (conventionally "brown"); kept as-is because the
        # pretrained checkpoint was trained against this exact distribution.
        white = torch.randn(shape)
        S = torch.fft.rfft2(white)
        h, w = shape[-2], shape[-1]
        fy = torch.fft.fftfreq(h).unsqueeze(-1).expand(-1, w // 2 + 1)
        fx = torch.fft.rfftfreq(w).unsqueeze(0).expand(h, -1)
        f = torch.sqrt(fx**2 + fy**2).clamp(min=1e-8)
        S = S / f
        return torch.fft.irfft2(S, s=(h, w))

    def _brown_noise(self, shape):
        # Steeper spectral falloff than _pink_noise: amplitude divided by f^2.
        white = torch.randn(shape)
        S = torch.fft.rfft2(white)
        h, w = shape[-2], shape[-1]
        fy = torch.fft.fftfreq(h).unsqueeze(-1).expand(-1, w // 2 + 1)
        fx = torch.fft.rfftfreq(w).unsqueeze(0).expand(h, -1)
        f = (fx**2 + fy**2).clamp(min=1e-8)
        S = S / f
        return torch.fft.irfft2(S, s=(h, w))

    def __getitem__(self, idx):
        self._rotate_seed()
        s = self.img_size
        noise_type = idx % self.N_TYPES

        if noise_type == 0:
            # Gaussian N(0,1) — the original pretraining distribution.
            img = torch.randn(3, s, s)
        elif noise_type == 1:
            # Uniform on [-1, 1].
            img = torch.rand(3, s, s) * 2 - 1
        elif noise_type == 2:
            # Uniform on [-2, 2].
            img = (torch.rand(3, s, s) - 0.5) * 4
        elif noise_type == 3:
            # Poisson with random rate, rescaled to mean ~0.
            lam = self._rng.uniform(0.5, 20.0)
            img = torch.poisson(torch.full((3, s, s), lam)) / lam - 1.0
        elif noise_type == 4:
            # Spectral ("pink"-labelled) noise, normalized to unit std.
            img = self._pink_noise((3, s, s))
            img = img / (img.std() + 1e-8)
        elif noise_type == 5:
            # Steeper spectral ("brown"-labelled) noise, unit std.
            img = self._brown_noise((3, s, s))
            img = img / (img.std() + 1e-8)
        elif noise_type == 6:
            # Binary +/-2 field with light Gaussian jitter.
            img = torch.where(torch.rand(3, s, s) > 0.5,
                              torch.ones(3, s, s) * 2, torch.ones(3, s, s) * -2)
            img = img + torch.randn(3, s, s) * 0.1
        elif noise_type == 7:
            # Sparse spikes: ~10% of pixels carry amplified Gaussian values.
            mask = torch.rand(3, s, s) > 0.9
            img = torch.randn(3, s, s) * mask.float() * 3
        elif noise_type == 8:
            # Blocky noise: coarse Gaussian grid upsampled nearest-neighbor.
            block = self._rng.randint(2, 16)
            small = torch.randn(3, s // block + 1, s // block + 1)
            img = F.interpolate(small.unsqueeze(0), size=s, mode='nearest').squeeze(0)
        elif noise_type == 9:
            # Randomly oriented linear gradient plus Gaussian noise.
            gy = torch.linspace(-2, 2, s).unsqueeze(1).expand(s, s)
            gx = torch.linspace(-2, 2, s).unsqueeze(0).expand(s, s)
            angle = self._rng.uniform(0, 2 * math.pi)
            grad = math.cos(angle) * gx + math.sin(angle) * gy
            img = grad.unsqueeze(0).expand(3, -1, -1) + torch.randn(3, s, s) * 0.5
        elif noise_type == 10:
            # Checkerboard of random cell size plus Gaussian noise.
            check_size = self._rng.randint(2, 16)
            coords_y = torch.arange(s) // check_size
            coords_x = torch.arange(s) // check_size
            checker = ((coords_y.unsqueeze(1) + coords_x.unsqueeze(0)) % 2).float() * 2 - 1
            img = checker.unsqueeze(0).expand(3, -1, -1) + torch.randn(3, s, s) * 0.3
        elif noise_type == 11:
            # Random convex mixture of Gaussian and uniform noise.
            a = torch.randn(3, s, s)
            b = torch.rand(3, s, s) * 2 - 1
            alpha = self._rng.uniform(0.2, 0.8)
            img = alpha * a + (1 - alpha) * b
        elif noise_type == 12:
            # Four quadrants, each drawn from a different family.
            img = torch.zeros(3, s, s)
            h2, w2 = s // 2, s // 2
            img[:, :h2, :w2] = torch.randn(3, h2, w2)
            img[:, :h2, w2:] = torch.rand(3, h2, w2) * 2 - 1
            img[:, h2:, :w2] = self._pink_noise((3, h2, w2)) / 2
            sp = torch.where(torch.rand(3, h2, w2) > 0.5,
                             torch.ones(3, h2, w2), -torch.ones(3, h2, w2))
            img[:, h2:, w2:] = sp
        elif noise_type == 13:
            # Cauchy via inverse CDF (tan of uniform), clamped to tame tails.
            u = torch.rand(3, s, s)
            img = torch.tan(math.pi * (u - 0.5))
            img = img.clamp(-3, 3)
        elif noise_type == 14:
            # Exponential(1), shifted to mean 0.
            img = torch.empty(3, s, s).exponential_(1.0) - 1.0
        elif noise_type == 15:
            # Laplace via inverse CDF.
            u = torch.rand(3, s, s) - 0.5
            img = -torch.sign(u) * torch.log1p(-2 * u.abs())

        # Hard clamp shared by every family.
        img = img.clamp(-4, 4)
        return img.float(), noise_type
226
+
227
+
228
# ── Patch Utilities ──────────────────────────────────────────────

def extract_patches(images, patch_size=16):
    """Split (B, C, H, W) images into non-overlapping flattened patches.

    Returns (patches, gh, gw): patches has shape (B, gh*gw, C*p*p) and
    gh, gw are the patch-grid height and width.
    """
    n, ch, height, width = images.shape
    gh, gw = height // patch_size, width // patch_size
    tiled = images.reshape(n, ch, gh, patch_size, gw, patch_size)
    # Move the grid axes ahead of (channel, patch-row, patch-col) so each
    # patch flattens contiguously.
    tiled = tiled.permute(0, 2, 4, 1, 3, 5)
    return tiled.reshape(n, gh * gw, ch * patch_size * patch_size), gh, gw
236
+
237
+
238
def stitch_patches(patches, gh, gw, patch_size=16):
    """Inverse of extract_patches for 3-channel images: reassemble
    (B, gh*gw, 3*p*p) flattened patches into (B, 3, gh*p, gw*p)."""
    n = patches.shape[0]
    grid = patches.reshape(n, gh, gw, 3, patch_size, patch_size)
    # Interleave grid rows/cols with intra-patch rows/cols, channel first.
    grid = grid.permute(0, 3, 1, 4, 2, 5)
    return grid.reshape(n, 3, gh * patch_size, gw * patch_size)
243
+
244
+
245
class BoundarySmooth(nn.Module):
    """Residual 3x3 conv stack that smooths patch-boundary seams.

    The final conv is zero-initialized, so the module is an exact
    identity at the start of training.
    """

    def __init__(self, channels=3, mid=16):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(channels, mid, 3, padding=1), nn.GELU(),
            nn.Conv2d(mid, channels, 3, padding=1))
        # Zero the last layer so the residual branch starts at exactly 0.
        final = self.net[-1]
        nn.init.zeros_(final.weight)
        nn.init.zeros_(final.bias)

    def forward(self, x):
        return x + self.net(x)
255
+
256
+
257
class SpectralCrossAttention(nn.Module):
    """Multi-head self-attention over per-patch singular-value vectors,
    applied as a bounded multiplicative gate.

    Output is S * (1 + alpha * tanh(proj(attn(S)))), where alpha is a
    learned per-dimension scale confined to (0, max_alpha); alpha_init
    keeps it small so the layer starts near the identity.
    """

    def __init__(self, D, n_heads=4, max_alpha=0.2, alpha_init=-2.0):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = D // n_heads
        self.max_alpha = max_alpha
        assert D % n_heads == 0
        self.qkv = nn.Linear(D, 3 * D)
        self.out_proj = nn.Linear(D, D)
        self.norm = nn.LayerNorm(D)
        self.scale = self.head_dim ** -0.5
        self.alpha_logits = nn.Parameter(torch.full((D,), alpha_init))

    @property
    def alpha(self):
        # Sigmoid keeps the gate strength strictly inside (0, max_alpha).
        return self.max_alpha * torch.sigmoid(self.alpha_logits)

    def forward(self, S):
        bsz, seq, dim = S.shape
        projected = self.qkv(self.norm(S))
        heads = projected.reshape(bsz, seq, 3, self.n_heads, self.head_dim)
        q, k, v = heads.permute(2, 0, 3, 1, 4).unbind(0)
        weights = torch.softmax((q @ k.transpose(-2, -1)) * self.scale, dim=-1)
        mixed = (weights @ v).transpose(1, 2).reshape(bsz, seq, dim)
        gate = torch.tanh(self.out_proj(mixed))
        return S * (1.0 + self.alpha.view(1, 1, -1) * gate)
285
+
286
+
287
class PatchSVAE(nn.Module):
    """Patch-wise spectral autoencoder.

    Each image patch is encoded to a (matrix_v x D) matrix, factored by
    SVD, its singular values adjusted by SpectralCrossAttention layers,
    then reconstructed, decoded back to pixels, stitched, and smoothed
    across patch boundaries.
    """

    def __init__(self, matrix_v=256, D=16, patch_size=16, hidden=768,
                 depth=4, n_cross_layers=2):
        super().__init__()
        self.matrix_v = matrix_v
        self.D = D
        self.patch_size = patch_size
        self.patch_dim = 3 * patch_size * patch_size  # flattened RGB patch size
        self.mat_dim = matrix_v * D  # flattened encoded-matrix size

        # Encoder: patch -> hidden (residual MLP blocks) -> matrix.
        self.enc_in = nn.Linear(self.patch_dim, hidden)
        self.enc_blocks = nn.ModuleList([
            nn.Sequential(nn.LayerNorm(hidden), nn.Linear(hidden, hidden),
                          nn.GELU(), nn.Linear(hidden, hidden))
            for _ in range(depth)])
        self.enc_out = nn.Linear(hidden, self.mat_dim)

        # Decoder mirrors the encoder: matrix -> hidden -> patch.
        self.dec_in = nn.Linear(self.mat_dim, hidden)
        self.dec_blocks = nn.ModuleList([
            nn.Sequential(nn.LayerNorm(hidden), nn.Linear(hidden, hidden),
                          nn.GELU(), nn.Linear(hidden, hidden))
            for _ in range(depth)])
        self.dec_out = nn.Linear(hidden, self.patch_dim)

        # Orthogonal init keeps the encoder's matrix output well-conditioned
        # for the SVD that immediately follows it.
        nn.init.orthogonal_(self.enc_out.weight)

        self.cross_attn = nn.ModuleList([
            SpectralCrossAttention(D, n_heads=min(4, D))
            for _ in range(n_cross_layers)])
        self.boundary_smooth = BoundarySmooth(channels=3, mid=16)

    def encode_patches(self, patches):
        """Encode (B, N, patch_dim) patches.

        Returns a dict with SVD factors 'U', 'Vt', the raw spectrum
        'S_orig', the attention-adjusted spectrum 'S', and the normalized
        encoded matrix 'M'.
        """
        B, N, _ = patches.shape
        flat = patches.reshape(B * N, -1)
        h = F.gelu(self.enc_in(flat))
        for block in self.enc_blocks:
            h = h + block(h)  # residual MLP blocks
        M = self.enc_out(h).reshape(B * N, self.matrix_v, self.D)
        # Row-normalize so each of the matrix_v rows lies on the unit sphere.
        M = F.normalize(M, dim=-1)
        U, S, Vt = svd_fp64(M)
        U = U.reshape(B, N, self.matrix_v, self.D)
        S = S.reshape(B, N, self.D)
        Vt = Vt.reshape(B, N, self.D, self.D)
        M = M.reshape(B, N, self.matrix_v, self.D)
        # Cross-attention refines the spectrum; the raw S is kept as
        # 'S_orig' for monitoring the attention's displacement.
        S_coord = S
        for layer in self.cross_attn:
            S_coord = layer(S_coord)
        return {'U': U, 'S_orig': S, 'S': S_coord, 'Vt': Vt, 'M': M}

    def decode_patches(self, U, S, Vt):
        """Rebuild M_hat = U * diag(S) * Vt per patch and decode to
        flattened pixel patches of shape (B, N, patch_dim)."""
        B, N, V, D = U.shape
        U_flat = U.reshape(B * N, V, D)
        S_flat = S.reshape(B * N, D)
        Vt_flat = Vt.reshape(B * N, D, D)
        # Scaling U's columns by S then multiplying by Vt is the rank-D
        # reconstruction U diag(S) Vt without materializing diag(S).
        M_hat = torch.bmm(U_flat * S_flat.unsqueeze(1), Vt_flat)
        h = F.gelu(self.dec_in(M_hat.reshape(B * N, -1)))
        for block in self.dec_blocks:
            h = h + block(h)
        return self.dec_out(h).reshape(B, N, -1)

    def forward(self, images):
        """Full round trip: patchify -> encode/SVD -> decode -> stitch ->
        boundary smooth. Returns dict with 'recon', 'svd', 'gh', 'gw'."""
        patches, gh, gw = extract_patches(images, self.patch_size)
        svd = self.encode_patches(patches)
        decoded = self.decode_patches(svd['U'], svd['S'], svd['Vt'])
        recon = stitch_patches(decoded, gh, gw, self.patch_size)
        recon = self.boundary_smooth(recon)
        return {'recon': recon, 'svd': svd, 'gh': gh, 'gw': gw}

    @staticmethod
    def effective_rank(S):
        """Exponential of the Shannon entropy of the normalized spectrum —
        a soft count of significant singular values in [1, D]."""
        p = S / (S.sum(-1, keepdim=True) + 1e-8)
        p = p.clamp(min=1e-8)
        return (-(p * p.log()).sum(-1)).exp()
360
+
361
+
362
# ── Training ─────────────────────────────────────────────────────

def train():
    """Continue the Gaussian-pretrained Johanna-128 SVAE on the 16-type
    omega noise mix.

    Loads the v14 Gaussian checkpoint from HuggingFace, trains with a
    reconstruction loss whose weight is boosted when the monitored
    row-CV is near target_cv, checkpoints locally, and mirrors the best /
    periodic checkpoints plus TensorBoard logs to the HF repo.
    """
    # ── Config ──
    V, D, patch_size = 256, 16, 16
    hidden, depth = 768, 4
    n_cross_layers = 2
    batch_size = 128  # 128 x 64 patches = 8192 eigh calls (stays under cusolver limit)
    lr = 1e-4
    epochs = 200
    target_cv = 0.125
    cv_weight, boost, sigma = 0.3, 0.5, 0.15
    img_size = 128

    save_dir = '/content/checkpoints'
    save_every = 10       # epochs between uploaded periodic checkpoints
    report_every = 5000   # global batches between mid-epoch readouts
    hf_repo = 'AbstractPhil/geolip-SVAE'
    hf_version = 'v16_johanna_omega'
    tb_dir = '/content/runs'

    # ── Pretrained checkpoint ──
    pretrained_repo = 'AbstractPhil/geolip-SVAE'
    pretrained_file = 'v14_noise/checkpoints/epoch_0200.pt'

    os.makedirs(save_dir, exist_ok=True)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # ── TensorBoard ──
    from torch.utils.tensorboard import SummaryWriter
    run_name = f"johanna_omega_V{V}_D{D}_h{hidden}_d{depth}"
    tb_path = os.path.join(tb_dir, run_name)
    writer = SummaryWriter(tb_path)
    print(f" TensorBoard: {tb_path}")

    # ── HuggingFace ──
    # Probe auth once; all later uploads are gated on hf_enabled.
    hf_enabled = False
    try:
        from huggingface_hub import HfApi, hf_hub_download
        api = HfApi()
        api.whoami()  # raises when not authenticated
        hf_enabled = True
        hf_prefix = f"{hf_version}/checkpoints"
        print(f" HuggingFace: {hf_repo}/{hf_prefix}")
    except Exception as e:
        print(f" HuggingFace: disabled ({e})")

    def upload_to_hf(local_path, remote_name):
        # Best-effort upload; failures are logged but never stop training.
        if not hf_enabled:
            return
        try:
            api.upload_file(path_or_fileobj=local_path,
                            path_in_repo=f"{hf_prefix}/{remote_name}",
                            repo_id=hf_repo, repo_type="model")
            print(f" ☁️ Uploaded: {hf_repo}/{hf_prefix}/{remote_name}")
        except Exception as e:
            print(f" ⚠️ HF upload failed: {e}")

    # ── Load pretrained Johanna-128 Gaussian ──
    # NOTE(review): hf_hub_download is imported inside the try above; if HF
    # auth failed, this line raises NameError — confirm intended fail-fast.
    print(f"\n Loading pretrained: {pretrained_repo}/{pretrained_file}")
    ckpt_path = hf_hub_download(repo_id=pretrained_repo, filename=pretrained_file,
                                repo_type="model")
    # NOTE(review): weights_only=False unpickles arbitrary objects; only
    # acceptable because the checkpoint comes from our own repo.
    ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=False)
    print(f" Pretrained epoch: {ckpt['epoch']}, MSE: {ckpt['test_mse']:.6f}")
    print(f" Pretrained config: {ckpt['config']}")

    # ── Model ──
    model = PatchSVAE(matrix_v=V, D=D, patch_size=patch_size,
                      hidden=hidden, depth=depth,
                      n_cross_layers=n_cross_layers).to(device)
    model.load_state_dict(ckpt['model_state_dict'], strict=True)
    print(f" Loaded {sum(p.numel() for p in model.parameters()):,} parameters")

    # Fresh optimizer — don't carry Gaussian momentum into omega training
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)

    # ── Data: 16-type omega noise at 128x128 ──
    train_ds = OmegaNoiseDataset(size=1280000, img_size=img_size)
    val_ds = OmegaNoiseDataset(size=10000, img_size=img_size)
    train_loader = torch.utils.data.DataLoader(
        train_ds, batch_size=batch_size, shuffle=True,
        num_workers=4, pin_memory=True, drop_last=True)
    test_loader = torch.utils.data.DataLoader(
        val_ds, batch_size=batch_size, shuffle=False,
        num_workers=4, pin_memory=True)

    n_patches = (img_size // patch_size) ** 2
    batches_per_epoch = len(train_loader)
    total_params = sum(p.numel() for p in model.parameters())

    print(f"\n JOHANNA-128 OMEGA CONTINUATION")
    print(f" Pretrained on: Gaussian N(0,1), 200 epochs, MSE=0.059")
    print(f" Now training on: 16 noise types at {img_size}Γ—{img_size}")
    print(f" {n_patches} patches, ({V},{D}), hidden={hidden}, depth={depth}")
    print(f" Params: {total_params:,}, batch={batch_size}")
    print(f" Batches/epoch: {batches_per_epoch}, lr={lr}")
    print(f" Report every {report_every} batches")
    print("=" * 100)
    print(f" {'ep':>3} {'batch':>8} | {'loss':>7} {'recon':>7} | "
          f"{'S0':>6} {'SD':>6} {'ratio':>5} {'erank':>5} | "
          f"{'row_cv':>7} {'prox':>5} {'rw':>5} | "
          f"{'S_delta':>7}")
    print("-" * 100)

    best_recon = float('inf')
    global_batch = 0

    def save_checkpoint(path, epoch, test_mse, extra=None, upload=True):
        # Serialize model + optimizer + scheduler plus the run config so a
        # later continuation can reconstruct the exact setup.
        ckpt_out = {
            'epoch': epoch, 'test_mse': test_mse,
            'global_batch': global_batch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': opt.state_dict(),
            'scheduler_state_dict': sched.state_dict(),
            'config': {
                'V': V, 'D': D, 'patch_size': patch_size,
                'hidden': hidden, 'depth': depth,
                'n_cross_layers': n_cross_layers,
                'target_cv': target_cv,
                'dataset': 'omega_noise_16types_128',
                'pretrained_from': 'v14_noise/epoch_0200.pt',
                'img_size': img_size, 'lr': lr,
            },
        }
        if extra:
            ckpt_out.update(extra)
        torch.save(ckpt_out, path)
        size_mb = os.path.getsize(path) / (1024 * 1024)
        print(f" πŸ’Ύ Saved: {path} ({size_mb:.1f}MB, ep{epoch}, MSE={test_mse:.6f})")
        if upload:
            upload_to_hf(path, os.path.basename(path))

    # ── Training Loop ──
    for epoch in range(1, epochs + 1):
        model.train()
        total_loss, total_recon, n = 0, 0, 0
        last_cv, last_prox, recon_w = target_cv, 1.0, 1.0 + boost
        t0 = time.time()

        pbar = tqdm(train_loader, desc=f"Ep {epoch}/{epochs}",
                    bar_format='{l_bar}{bar:20}{r_bar}')
        for batch_idx, (images, noise_types) in enumerate(pbar):
            images = images.to(device)
            opt.zero_grad()
            out = model(images)
            recon_loss = F.mse_loss(out['recon'], images)

            # CV is sampled every 50 batches (it is expensive) and only on
            # the first patch of the first image.
            with torch.no_grad():
                if batch_idx % 50 == 0:
                    current_cv = cv_of(out['svd']['M'][0, 0])
                    if current_cv > 0:
                        last_cv = current_cv
                    delta = last_cv - target_cv
                    # Gaussian proximity of the measured CV to the target.
                    last_prox = math.exp(-delta**2 / (2 * sigma**2))

            # NOTE(review): last_cv/last_prox are Python floats computed under
            # no_grad, so cv_pen * cv_l is a constant w.r.t. the parameters —
            # it shifts the reported loss but contributes no gradient; only
            # recon_w actually rescales the reconstruction gradient. Confirm
            # this is the intended "monitoring-only" CV penalty.
            recon_w = 1.0 + boost * last_prox
            cv_pen = cv_weight * (1.0 - last_prox)
            cv_l = (last_cv - target_cv) ** 2
            loss = recon_w * recon_loss + cv_pen * cv_l
            loss.backward()

            # Only the cross-attention layers are gradient-clipped, keeping
            # the spectrum adjustment conservative.
            torch.nn.utils.clip_grad_norm_(model.cross_attn.parameters(), max_norm=0.5)
            opt.step()

            total_loss += loss.item() * len(images)
            total_recon += recon_loss.item() * len(images)
            n += len(images)
            global_batch += 1

            pbar.set_postfix_str(
                f"loss={recon_loss.item():.4f} cv={last_cv:.3f} prox={last_prox:.2f}",
                refresh=False)

            # ── Readout ──
            # Mid-epoch eval on a single validation batch + TB logging.
            if global_batch % report_every == 0:
                model.eval()
                with torch.no_grad():
                    test_imgs, _ = next(iter(test_loader))
                    test_imgs = test_imgs.to(device)
                    test_out = model(test_imgs)
                    test_mse = F.mse_loss(test_out['recon'], test_imgs).item()
                    S_mean = test_out['svd']['S'].mean(dim=(0, 1))
                    S_orig = test_out['svd']['S_orig'].mean(dim=(0, 1))
                    erank = model.effective_rank(
                        test_out['svd']['S'].reshape(-1, D)).mean().item()
                    # Mean displacement the cross-attention applies to S.
                    s_delta = (S_mean - S_orig).abs().mean().item()
                    ratio = (S_mean[0] / (S_mean[-1] + 1e-8)).item()

                writer.add_scalar('train/recon', total_recon / n, global_batch)
                writer.add_scalar('test/recon_mse', test_mse, global_batch)
                writer.add_scalar('geo/row_cv', last_cv, global_batch)
                writer.add_scalar('geo/ratio', ratio, global_batch)
                writer.add_scalar('geo/erank', erank, global_batch)
                writer.add_scalar('geo/S0', S_mean[0].item(), global_batch)
                writer.add_scalar('cross_attn/s_delta', s_delta, global_batch)

                print(f"\n {epoch:3d} {global_batch:8d} | "
                      f"{total_loss/n:7.4f} {total_recon/n:7.4f} | "
                      f"{S_mean[0]:6.3f} {S_mean[-1]:6.3f} {ratio:5.2f} {erank:5.2f} | "
                      f"{last_cv:7.4f} {last_prox:5.3f} {recon_w:5.2f} | "
                      f"{s_delta:7.5f}")

                # NOTE(review): this compares a single-batch MSE against
                # best_recon, which the epoch-end path updates with a
                # full-set MSE — the two metrics are mixed in one "best".
                if test_mse < best_recon:
                    best_recon = test_mse
                    save_checkpoint(os.path.join(save_dir, 'best.pt'),
                                    epoch, test_mse, upload=False)
                model.train()

        pbar.close()
        sched.step()
        epoch_time = time.time() - t0

        writer.add_scalar('train/epoch_time', epoch_time, epoch)

        # ── Epoch eval ──
        # Full pass over the validation set for the per-epoch MSE.
        model.eval()
        test_recon_total, test_n = 0, 0
        with torch.no_grad():
            for test_imgs, _ in test_loader:
                test_imgs = test_imgs.to(device)
                out = model(test_imgs)
                test_recon_total += F.mse_loss(out['recon'], test_imgs).item() * len(test_imgs)
                test_n += len(test_imgs)
        epoch_test_mse = test_recon_total / test_n

        print(f" Epoch {epoch} done: {epoch_time:.1f}s, test_mse={epoch_test_mse:.6f}, "
              f"best={best_recon:.6f}")

        if epoch_test_mse < best_recon:
            best_recon = epoch_test_mse
            save_checkpoint(os.path.join(save_dir, 'best.pt'),
                            epoch, epoch_test_mse, upload=False)

        # Periodic checkpoint + sync of best.pt and TensorBoard logs to HF.
        if epoch % save_every == 0:
            save_checkpoint(os.path.join(save_dir, f'epoch_{epoch:04d}.pt'),
                            epoch, epoch_test_mse)
            best_path = os.path.join(save_dir, 'best.pt')
            if os.path.exists(best_path):
                upload_to_hf(best_path, 'best.pt')
            writer.flush()
            if hf_enabled:
                try:
                    api.upload_folder(folder_path=tb_path,
                                      path_in_repo=f"{hf_version}/tensorboard/{run_name}",
                                      repo_id=hf_repo, repo_type="model")
                    print(f" ☁️ TB synced")
                # NOTE(review): bare except also swallows KeyboardInterrupt /
                # SystemExit — `except Exception:` would be safer here.
                except:
                    pass

    writer.close()
    print(f"\n JOHANNA-128 OMEGA TRAINING COMPLETE")
    print(f" Best MSE: {best_recon:.6f}")
    print(f" Checkpoints: {save_dir}/")
616
+
617
+
618
if __name__ == "__main__":
    # 'high' permits reduced-precision (TF32) float32 matmuls on supported
    # GPUs for speed; the fp64-critical SVD helpers disable autocast
    # explicitly, so they are unaffected.
    torch.set_float32_matmul_precision('high')
    train()