AbstractPhil committed on
Commit
a768141
·
verified ·
1 Parent(s): 14efb7e

Create prototype_v9_prod.py

Browse files
Files changed (1) hide show
  1. prototype_v9_prod.py +443 -0
prototype_v9_prod.py ADDED
@@ -0,0 +1,443 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SVAE — SVD Autoencoder with Geometric Attractors
3
+ ===================================================
4
+ A matrix-valued autoencoder where the latent space is a (V, D) matrix
5
+ decomposed by SVD. Rows are normalized to S^(D-1), making the geometric
6
+ structure architectural rather than loss-dependent.
7
+
8
+ Two key mechanisms:
9
+ 1. Sphere normalization: F.normalize(M, dim=-1) constrains rows to unit
10
+ vectors on S^(D-1). This bounds the Gram matrix, eliminates training
11
+ instabilities, and makes the CV a structural property of (V, D).
12
+ 2. Soft hand: An oscillatory counterweight that boosts reconstruction
13
+ gradients when geometry is near target, and penalizes CV drift when
14
+ geometry is far from target. Provides positive momentum, not just penalty.
15
+
16
+ Architecture: Image → MLP → M ∈ ℝ^(V×D) → normalize → SVD → MLP → Recon
17
+
18
+ Repository: AbstractEyes/geolip-core
19
+ """
20
+
21
+ import torch
22
+ import torch.nn as nn
23
+ import torch.nn.functional as F
24
+ import torchvision
25
+ import torchvision.transforms as T
26
+ import math
27
+ import time
28
+
29
+ # ── SVD Backend ──────────────────────────────────────────────────
30
+
31
+ try:
32
+ from geolip_core.linalg.eigh import FLEigh, _FL_MAX_N
33
+ HAS_FL = True
34
+ except ImportError:
35
+ HAS_FL = False
36
+
37
+
38
def gram_eigh_svd_fp64(A):
    """Thin SVD via the Gram matrix and ``eigh``, computed entirely in fp64.

    The Gram matrix ``G = A^T A`` is eigen-decomposed; its eigenvectors are
    the right singular vectors of ``A`` and the square roots of its
    eigenvalues are the singular values. The left singular vectors are then
    recovered as ``U = A V / S``.

    Everything runs in float64 because Gram entries scale with the square of
    the singular values. Per the original author's note, doing this in fp32
    led to collapses once the condition number grew beyond roughly 100.

    Args:
        A: (..., M, N) tensor with M >= N. Any number of leading batch
            dimensions is accepted, including none (a plain 2-D matrix).

    Returns:
        Tuple ``(U, S, Vh)`` with shapes (..., M, N), (..., N), (..., N, N),
        singular values in descending order, cast back to ``A.dtype``.

    Raises:
        ValueError: if ``A`` has fewer than two dimensions.
    """
    if A.dim() < 2:
        raise ValueError(
            f"expected a tensor with at least 2 dimensions, got shape {tuple(A.shape)}"
        )

    orig_dtype = A.dtype
    # Autocast is switched off so the fp64 work below is never downcast.
    with torch.amp.autocast('cuda', enabled=False):
        A_d = A.double()
        G = A_d.transpose(-2, -1) @ A_d
        eigenvalues, V = torch.linalg.eigh(G)
        # eigh returns ascending eigenvalues; SVD convention is descending.
        eigenvalues = eigenvalues.flip(-1)
        V = V.flip(-1)
        # Clamp guards against tiny negative eigenvalues from round-off and
        # bounds S below by 1e-12, so the division needs no further clamp.
        S = torch.sqrt(eigenvalues.clamp(min=1e-24))
        U = (A_d @ V) / S.unsqueeze(-2)
        Vh = V.transpose(-2, -1).contiguous()
    return U.to(orig_dtype), S.to(orig_dtype), Vh.to(orig_dtype)
61
+
62
+
63
def svd_fp64(A):
    """Dispatch a thin SVD of a batched matrix to one of two backends.

    The FL path is taken only when all three hold: the optional ``FLEigh``
    import succeeded, the column count is at most ``_FL_MAX_N``, and ``A``
    lives on a CUDA device. Otherwise the call is forwarded to
    ``gram_eigh_svd_fp64``.

    On the FL path the Gram matrix is formed in float64, handed to
    ``FLEigh`` as a float32 copy, and the results are promoted back to
    float64 before the singular values and left vectors are computed.

    NOTE(review): because the eigen-solver itself receives fp32 data on the
    FL path, that path is not fp64 end to end; confirm this is acceptable.

    Args:
        A: (B, M, N) tensor.

    Returns:
        Tuple ``(U, S, Vh)`` cast back to ``A.dtype``.
    """
    _, _, n_cols = A.shape

    fl_usable = HAS_FL and n_cols <= _FL_MAX_N and A.is_cuda
    if not fl_usable:
        return gram_eigh_svd_fp64(A)

    in_dtype = A.dtype
    with torch.amp.autocast('cuda', enabled=False):
        A64 = A.double()
        gram = torch.bmm(A64.transpose(1, 2), A64)
        evals, evecs = FLEigh()(gram.float())  # FL needs fp32 input
        # Reverse the solver's ordering and return to fp64.
        evals = evals.double().flip(-1)
        evecs = evecs.double().flip(-1)
        sing = torch.sqrt(evals.clamp(min=1e-24))
        left = torch.bmm(A64, evecs) / sing.unsqueeze(1).clamp(min=1e-16)
        right_h = evecs.transpose(-2, -1).contiguous()
    return left.to(in_dtype), sing.to(in_dtype), right_h.to(in_dtype)
85
+
86
+
87
+ # ── Cayley-Menger CV Monitoring ──────────────────────────────────
88
+
89
def cayley_menger_vol2(points):
    """Squared simplex volume from the Cayley-Menger determinant, in fp64.

    For a simplex with N vertices (dimension k = N - 1) the squared volume is
    ``(-1)^(k+1) * det(CM) / (2^k * (k!)^2)``, where CM is the matrix of
    pairwise squared distances bordered by a row and column of ones with a
    zero in the corner.

    Args:
        points: (B, N, D) tensor holding B simplices of N vertices each.

    Returns:
        (B,) float64 tensor of squared volumes.
    """
    n_batch, n_vert, _ = points.shape
    coords = points.double()

    # Pairwise squared distances from the Gram matrix; relu removes the small
    # negative values that round-off can produce.
    inner = torch.bmm(coords, coords.transpose(1, 2))
    sq_norm = torch.diagonal(inner, dim1=1, dim2=2)
    sq_dist = F.relu(sq_norm.unsqueeze(2) + sq_norm.unsqueeze(1) - 2 * inner)

    bordered = torch.zeros(
        n_batch, n_vert + 1, n_vert + 1,
        device=points.device, dtype=torch.float64,
    )
    bordered[:, 0, 1:] = 1.0
    bordered[:, 1:, 0] = 1.0
    bordered[:, 1:, 1:] = sq_dist

    simplex_dim = n_vert - 1
    parity = (-1.0) ** (simplex_dim + 1)
    k_factorial = math.factorial(simplex_dim)
    return parity * torch.linalg.det(bordered) / ((2 ** simplex_dim) * (k_factorial ** 2))
107
+
108
+
109
def cv_of(emb, n_samples=200):
    """Coefficient of variation of random pentachoron volumes in ``emb``.

    Draws ``n_samples`` random 5-row subsets, measures each subset's simplex
    volume with ``cayley_menger_vol2`` and returns ``std / mean`` of the
    non-degenerate volumes. Only the first 512 rows of ``emb`` are ever
    sampled.

    Args:
        emb: (V, D) tensor.
        n_samples: Number of 5-vertex subsets to draw.

    Returns:
        The CV as a Python float. Returns 0.0 as an "insufficient data"
        sentinel when ``emb`` is not 2-D, has fewer than 5 rows, fewer than
        10 subsets are requested, or fewer than 10 subsets have a squared
        volume above 1e-20.
    """
    min_valid = 10
    if emb.dim() != 2 or emb.shape[0] < 5:
        return 0.0
    if n_samples < min_valid:
        # Fewer draws than the validity threshold can never pass it.
        return 0.0

    pool = min(emb.shape[0], 512)
    # One random score per (subset, row): the five top-scoring rows of each
    # subset are five distinct random indices. This replaces a Python loop
    # of n_samples separate randperm calls with a single draw.
    scores = torch.rand(n_samples, pool, device=emb.device)
    indices = scores.topk(5, dim=1).indices

    vol2 = cayley_menger_vol2(emb[:pool][indices])
    valid = vol2 > 1e-20
    if valid.sum() < min_valid:
        return 0.0
    vols = vol2[valid].sqrt()
    return (vols.std() / (vols.mean() + 1e-8)).item()
126
+
127
+
128
+ # ── Data ─────────────────────────────────────────────────────────
129
+
130
def get_cifar10(batch_size=256):
    """Build the CIFAR-10 train and test DataLoaders.

    Images are converted to tensors and normalized per channel with fixed
    mean/std constants. Data is stored under ``./data`` and downloaded if
    missing. Only the training loader shuffles.

    Args:
        batch_size: Batch size used for both loaders.

    Returns:
        Tuple ``(train_loader, test_loader)``.
    """
    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.2470, 0.2435, 0.2616)
    preprocess = T.Compose([
        T.ToTensor(),
        T.Normalize(channel_mean, channel_std),
    ])

    def _make_loader(is_train):
        # The training split is shuffled; the test split keeps its order.
        dataset = torchvision.datasets.CIFAR10(
            root='./data', train=is_train, download=True, transform=preprocess)
        return torch.utils.data.DataLoader(
            dataset, batch_size=batch_size, shuffle=is_train, num_workers=2)

    return _make_loader(True), _make_loader(False)
140
+
141
+
142
+ # ── SVAE Model ───────────────────────────────────────────────────
143
+
144
class SVAE(nn.Module):
    """SVD autoencoder with a sphere-normalized matrix latent space.

    The encoder MLP maps a flattened image to a (V, D) matrix whose rows are
    L2-normalized, i.e. placed on the unit sphere S^(D-1). That matrix is
    decomposed with ``svd_fp64`` and the decoder MLP reconstructs the image
    from the recomposed matrix ``M_hat = U diag(S) Vt``.

    Args:
        matrix_v: Number of rows V of the latent matrix.
        D: Number of columns of the latent matrix (and of singular values).
        img_shape: (C, H, W) shape of the images to encode and reconstruct.
            Defaults to the CIFAR-10 shape (3, 32, 32).
    """

    def __init__(self, matrix_v=96, D=24, img_shape=(3, 32, 32)):
        super().__init__()
        self.matrix_v = matrix_v
        self.D = D
        self.img_shape = tuple(img_shape)
        self.img_dim = math.prod(self.img_shape)
        self.mat_dim = matrix_v * D

        self.encoder = nn.Sequential(
            nn.Linear(self.img_dim, 512),
            nn.GELU(),
            nn.Linear(512, 512),
            nn.GELU(),
            nn.Linear(512, self.mat_dim),
        )
        self.decoder = nn.Sequential(
            nn.Linear(self.mat_dim, 512),
            nn.GELU(),
            nn.Linear(512, 512),
            nn.GELU(),
            nn.Linear(512, self.img_dim),
        )
        # Orthogonal init on the layer that emits the latent matrix.
        nn.init.orthogonal_(self.encoder[-1].weight)

    def encode(self, images):
        """Encode images into the SVD of their normalized latent matrix.

        Args:
            images: Tensor with batch as the first dimension; the remaining
                dimensions are flattened before entering the encoder.

        Returns:
            Dict with keys ``'U'``, ``'S'``, ``'Vt'`` (the factors returned
            by ``svd_fp64``) and ``'M'`` (the row-normalized latent matrix
            of shape (B, V, D)).
        """
        B = images.shape[0]
        M = self.encoder(images.reshape(B, -1)).reshape(B, self.matrix_v, self.D)
        M = F.normalize(M, dim=-1)  # rows to S^(D-1)
        U, S, Vh = svd_fp64(M)
        return {'U': U, 'S': S, 'Vt': Vh, 'M': M}

    def decode_from_svd(self, U, S, Vt):
        """Reconstruct images from (possibly truncated) SVD factors.

        The factors may keep only the leading k modes, as long as their
        product ``U diag(S) Vt`` still has shape (B, V, D).

        Args:
            U: (B, V, k) left singular vectors.
            S: (B, k) singular values.
            Vt: (B, k, D) right singular vectors, transposed.

        Returns:
            (B, *img_shape) reconstructed images.
        """
        B = U.shape[0]
        M_hat = torch.bmm(U * S.unsqueeze(1), Vt)
        return self.decoder(M_hat.reshape(B, -1)).reshape(B, *self.img_shape)

    def forward(self, images):
        """Encode then decode; returns ``{'recon': ..., 'svd': ...}``."""
        svd = self.encode(images)
        recon = self.decode_from_svd(svd['U'], svd['S'], svd['Vt'])
        return {'recon': recon, 'svd': svd}

    @staticmethod
    def effective_rank(S):
        """Shannon-entropy effective rank of a singular value spectrum.

        The spectrum is normalized to a probability vector along the last
        dimension and the exponential of its entropy is returned, so a flat
        spectrum of length k gives about k and a single non-zero value gives
        about 1.
        """
        p = S / (S.sum(-1, keepdim=True) + 1e-8)
        p = p.clamp(min=1e-8)  # keeps log() finite for zero entries
        return (-(p * p.log()).sum(-1)).exp()
201
+
202
+
203
+ # ── Training ─────────────────────────────────────────────────────
204
+
205
def train(epochs=100, lr=1e-3, V=256, D=24, target_cv=0.125,
          cv_weight=0.3, boost=0.5, sigma=0.15, device='cuda',
          save_path='/content/svae_recon_grid.png'):
    """Train the SVAE on CIFAR-10, report statistics, and save a recon grid.

    Runs a training loop with periodic evaluation, prints a final analysis
    of the test set, and saves a grid of progressive reconstructions.

    Args:
        epochs: Number of training epochs.
        lr: Learning rate for Adam.
        V: Rows of the latent matrix.
        D: Columns of the latent matrix.
        target_cv: CV value the soft hand treats as its target.
        cv_weight: Maximum weight of the CV term (far from target).
        boost: Maximum extra reconstruction weight (near target).
        sigma: Width of the Gaussian that maps CV distance to proximity.
        device: Requested device; CPU is used whenever CUDA is unavailable.
        save_path: Where the reconstruction grid PNG is written. Missing
            parent directories are created.

    NOTE(review): the CV term ``cv_pen * cv_l`` is built from Python floats
    measured under ``torch.no_grad()``, so it changes the reported loss value
    but contributes no gradient. Only ``recon_w`` affects the update.
    """
    from pathlib import Path

    device = torch.device(device if torch.cuda.is_available() else 'cpu')
    train_loader, test_loader = get_cifar10(batch_size=256)

    model = SVAE(matrix_v=V, D=D).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)

    total_params = sum(p.numel() for p in model.parameters())

    # ── Header ──
    svd_backend = f"fp64 Gram+eigh (FL={'available, N<=12' if HAS_FL else 'not available'})"
    print(f"Using geolip-core SVD ({svd_backend})")
    print(f"SVAE - V={V}, D={D}, rows on S^{D-1} + soft hand")
    print(f" Matrix: ({V}, {D}) = {V*D} elements, rows normalized")
    print(f" SVD: fp64 Gram+eigh")
    print(f" Sphere: rows on S^{D-1} (structural geometry)")
    print(f" Soft hand: boost={1+boost:.1f}x near CV={target_cv}, penalty={cv_weight} far")
    print(f" Params: {total_params:,}")
    print("=" * 90)
    print(f" {'ep':>3} | {'loss':>7} {'recon':>7} {'t/ep':>5} | "
          f"{'t_rec':>7} | "
          f"{'S0':>6} {'SD':>6} {'ratio':>5} {'erank':>5} | "
          f"{'row_cv':>7} {'prox':>5} {'rw':>5}")
    print("-" * 90)

    # ── Training loop ──
    for epoch in range(1, epochs + 1):
        model.train()
        total_loss, total_recon, n = 0, 0, 0
        # Soft-hand state restarts "on target" at the top of every epoch.
        last_cv = target_cv
        last_prox = 1.0
        recon_w = 1.0 + boost
        t0 = time.time()

        for batch_idx, (images, _) in enumerate(train_loader):
            images = images.to(device)
            opt.zero_grad()
            out = model(images)
            recon_loss = F.mse_loss(out['recon'], images)

            # Measure CV on the first sample and refresh proximity
            # (every 10th batch). A 0.0 result means "no measurement" and
            # keeps the previous CV.
            with torch.no_grad():
                if batch_idx % 10 == 0:
                    current_cv = cv_of(out['svd']['M'][0])
                    if current_cv > 0:
                        last_cv = current_cv
                    delta = last_cv - target_cv
                    last_prox = math.exp(-delta**2 / (2 * sigma**2))

            # Soft hand: boost recon near target, weight the CV term far
            # from target. The CV term is a constant w.r.t. the parameters.
            recon_w = 1.0 + boost * last_prox
            cv_pen = cv_weight * (1.0 - last_prox)
            cv_l = (last_cv - target_cv) ** 2

            loss = recon_w * recon_loss + cv_pen * cv_l
            loss.backward()
            opt.step()

            total_loss += loss.item() * len(images)
            total_recon += recon_loss.item() * len(images)
            n += len(images)

        sched.step()
        epoch_time = time.time() - t0

        # ── Evaluation (every 2 epochs + first 3) ──
        if epoch % 2 == 0 or epoch <= 3:
            model.eval()
            test_recon, test_n = 0, 0
            test_S, test_erank = None, 0
            row_cvs = []
            nb = 0

            with torch.no_grad():
                for images, _ in test_loader:
                    images = images.to(device)
                    out = model(images)
                    test_recon += F.mse_loss(out['recon'], images).item() * len(images)
                    test_n += len(images)
                    test_erank += model.effective_rank(out['svd']['S']).mean().item()
                    # CV is only sampled on the first few batches.
                    if nb < 5:
                        for b in range(min(4, len(images))):
                            row_cvs.append(cv_of(out['svd']['M'][b]))
                    if test_S is None:
                        test_S = out['svd']['S'].mean(0).cpu()
                    else:
                        test_S += out['svd']['S'].mean(0).cpu()
                    nb += 1

            test_erank /= nb
            test_S /= nb
            ratio = (test_S[0] / (test_S[-1] + 1e-8)).item()
            mean_cv = sum(row_cvs) / len(row_cvs) if row_cvs else 0

            print(f" {epoch:3d} | {total_loss/n:7.4f} {total_recon/n:7.4f} {epoch_time:5.1f} | "
                  f"{test_recon/test_n:7.4f} | "
                  f"{test_S[0]:6.3f} {test_S[-1]:6.3f} {ratio:5.2f} "
                  f"{test_erank:5.2f} | "
                  f"{mean_cv:7.4f} {last_prox:5.3f} {recon_w:5.2f}")

    # ── Final Analysis ──
    print()
    print("=" * 85)
    print("FINAL ANALYSIS")
    print("=" * 85)

    model.eval()
    all_S, all_recon_err, all_labels = [], [], []
    all_row_cvs = []

    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            out = model(images)
            all_S.append(out['svd']['S'].cpu())
            all_recon_err.append(
                F.mse_loss(out['recon'], images, reduction='none')
                .mean(dim=(1, 2, 3)).cpu())
            all_labels.append(labels.cpu())
            for b in range(min(8, len(images))):
                all_row_cvs.append(cv_of(out['svd']['M'][b]))

    all_S = torch.cat(all_S)
    all_recon_err = torch.cat(all_recon_err)
    all_labels = torch.cat(all_labels)
    erank = model.effective_rank(all_S)
    mean_cv = sum(all_row_cvs) / len(all_row_cvs)

    print(f"\n V={V}, D={D}, rows on S^{D-1}")
    print(f" Target CV: {target_cv}")
    print(f" Recon MSE: {all_recon_err.mean():.6f} +/- {all_recon_err.std():.6f}")
    print(f" Effective rank: {erank.mean():.2f} +/- {erank.std():.2f}")
    print(f" Row CV: {mean_cv:.4f}")

    S_mean = all_S.mean(0)
    total_energy = (S_mean ** 2).sum().item()
    print(f"\n Singular value profile:")
    cumulative = 0
    for i, s_val in enumerate(S_mean):
        cumulative += (s_val ** 2).item()
        pct = cumulative / total_energy * 100
        bar = "#" * int(s_val.item() * 30 / (S_mean[0].item() + 1e-8))
        print(f" S[{i:2d}]: {s_val:8.4f} cum={pct:5.1f}% {bar}")

    cifar_names = ['plane', 'car', 'bird', 'cat', 'deer',
                   'dog', 'frog', 'horse', 'ship', 'truck']
    print(f"\n Per-class:")
    print(f" {'cls':>6} {'recon':>8} {'erank':>6} {'S0':>7} {'SD':>7} {'ratio':>6}")
    for c in range(10):
        mask = all_labels == c
        rc = all_recon_err[mask].mean().item()
        er = erank[mask].mean().item()
        s0 = all_S[mask, 0].mean().item()
        sd = all_S[mask, -1].mean().item()
        r = s0 / (sd + 1e-8)
        print(f" {cifar_names[c]:>6} {rc:8.6f} {er:6.2f} {s0:7.4f} {sd:7.4f} {r:6.2f}")

    # ── Reconstruction Grid ──
    print(f"\n Saving reconstruction grid...")
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt

    mean_t = torch.tensor([0.4914, 0.4822, 0.4465]).reshape(1, 3, 1, 1).to(device)
    std_t = torch.tensor([0.2470, 0.2435, 0.2616]).reshape(1, 3, 1, 1).to(device)

    def denorm(t):
        # Undo the dataset normalization and clip to displayable range.
        return (t * std_t + mean_t).clamp(0, 1).cpu()

    # Mode counts are clamped to D and de-duplicated so that a small D does
    # not yield repeated or mislabeled columns.
    mode_counts = sorted({min(m, D) for m in (1, 4, 8, 16, D)})

    model.eval()
    with torch.no_grad():
        images, labels = next(iter(test_loader))
        images = images.to(device)
        out = model(images)

        # Up to two examples per class from the first test batch.
        selected_idx = []
        for c in range(10):
            class_idx = (labels == c).nonzero(as_tuple=True)[0]
            selected_idx.extend(class_idx[:2].tolist())

        orig = images[selected_idx]
        U = out['svd']['U'][selected_idx]
        S = out['svd']['S'][selected_idx]
        Vt = out['svd']['Vt'][selected_idx]

        # Progressive reconstructions from the leading `nm` singular modes.
        prog_recons = [
            model.decode_from_svd(U[:, :, :nm], S[:, :nm], Vt[:, :nm, :])
            for nm in mode_counts
        ]

    orig_img = denorm(orig)
    recon_imgs = [denorm(r) for r in prog_recons]

    n_samples = len(selected_idx)
    n_cols = 2 + len(mode_counts)
    # squeeze=False keeps `axes` two-dimensional even for a single row.
    fig, axes = plt.subplots(n_samples, n_cols,
                             figsize=(n_cols * 1.5, n_samples * 1.5),
                             squeeze=False)
    col_titles = ['Original'] + [f'{m} modes' for m in mode_counts] + ['|Err|x5']
    err_col = 1 + len(recon_imgs)

    for i in range(n_samples):
        axes[i, 0].imshow(orig_img[i].permute(1, 2, 0).numpy())
        for j, r in enumerate(recon_imgs):
            axes[i, j + 1].imshow(r[i].permute(1, 2, 0).numpy())
        diff = (orig_img[i] - recon_imgs[-1][i]).abs() * 5
        axes[i, err_col].imshow(diff.clamp(0, 1).permute(1, 2, 0).numpy())
        c = labels[selected_idx[i]].item()
        axes[i, 0].set_ylabel(cifar_names[c], fontsize=8, rotation=0, labelpad=35)

    for j, title in enumerate(col_titles):
        axes[0, j].set_title(title, fontsize=8)
    for ax in axes.flat:
        ax.axis('off')

    plt.tight_layout()
    # Create the target directory so a missing folder does not raise here.
    Path(save_path).parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(save_path, dpi=200, bbox_inches='tight')
    print(f" Saved to {save_path}")
    try:
        plt.show()
    except Exception as exc:
        # Displaying is best-effort (headless backends); report, don't fail.
        print(f" Could not display figure: {exc}")
    plt.close()
441
+
442
# Script entry point: run one training session with the default
# hyperparameters when the file is executed directly (not on import).
if __name__ == "__main__":
    train()