AbstractPhil committed on
Commit
a45b0f0
·
verified ·
1 Parent(s): 5fb8de8

Create omega_processor_test_cifar10_noise_model.py

Browse files
omega_processor_test_cifar10_noise_model.py ADDED
@@ -0,0 +1,445 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""
Omega Processor — CIFAR-10 Image Classification
==================================================
Freckles (trained on NOISE, frozen) → SVD → Geometric Features → Transformer → 10 classes

The ultimate test: can a noise-trained spectral decomposition
produce useful features for real image classification?

CIFAR-10 32×32 → bilinear resize to 64×64 → Freckles → features → classify

Usage:
    python omega_processor_test_cifar10_noise_model.py [omega|baseline|both]

    With no argument both models are trained and compared head-to-head.
"""
14
+
15
+ import os
16
+ import math
17
+ import time
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ import numpy as np
22
+ from tqdm import tqdm
23
+
24
# Best-effort Hugging Face login when running inside Google Colab.
# Outside Colab the `google.colab` import fails and this block is a no-op.
try:
    from google.colab import userdata
    _hf_token = userdata.get('HF_TOKEN')
    if _hf_token:
        os.environ["HF_TOKEN"] = _hf_token
        from huggingface_hub import login
        login(token=_hf_token)
except ImportError:
    # Not in Colab (or huggingface_hub is not installed): nothing to do.
    pass
except Exception as exc:
    # Authentication is optional; report the problem instead of hiding it,
    # but never stop the script because of it.
    print(f"[warn] Hugging Face login skipped: {exc}")
31
+
32
+
33
# ═══════════════════════════════════════════════════════════════
# GEOMETRIC FEATURE EXTRACTOR (same as omega_processor.py)
# ═══════════════════════════════════════════════════════════════
36
+
37
class GeometricFeatureExtractor(nn.Module):
    """Build a per-patch feature vector from a dictionary of SVD outputs.

    The extractor has no trainable parameters. Its only state is ``m_proj``,
    a fixed random ``(V, 8)`` projection stored as a buffer.

    The features are concatenated along the last axis in three tiers:

    * Tier 1 (scalar): ratios, energy shares, effective rank, condition
      number, ``S - S_orig`` and log-gaps of the singular values.
    * Tier 2 (relational): differences to / spread of the four cardinal
      neighbours on the ``gh x gw`` patch grid, plus sin/cos position codes.
    * Tier 3 (basis): flattened ``Vt``, mean/std of ``U`` over its axis 2,
      and an 8-dim random sketch of ``M``.

    Args:
        D: Number of singular values per patch (stored; ``forward`` reads the
            actual value from ``S.shape``).
        V: Size of the axis of ``M`` that is contracted with ``m_proj``.
    """

    def __init__(self, D=4, V=48):
        super().__init__()
        self.D = D
        self.V = V
        # Fixed random projection: a buffer, so it moves with .to(device) and
        # is saved in the state dict, but is never trained.
        self.register_buffer('m_proj', torch.randn(V, 8) / math.sqrt(V))

    @staticmethod
    def _neighbor_mean(padded):
        """Average the up/down/left/right neighbours of a 1-pixel padded grid."""
        return (padded[:, :, :-2, 1:-1] + padded[:, :, 2:, 1:-1] +
                padded[:, :, 1:-1, :-2] + padded[:, :, 1:-1, 2:]) / 4

    def forward(self, svd_dict, gh, gw):
        """Compute the feature tensor.

        Args:
            svd_dict: Mapping with keys ``'S'``, ``'S_orig'``, ``'U'``,
                ``'Vt'`` and ``'M'``. ``S`` is ``(B, N, D)``; ``Vt`` must
                reshape to ``(B, N, D * D)``; ``M`` is 4-D ``(B, N, V, d)``.
                ``U`` is presumably ``(B, N, V, D)`` — TODO confirm against
                the model that produces it.
            gh: Patch-grid height.
            gw: Patch-grid width; ``gh * gw`` must equal ``N``.

        Returns:
            Tensor of shape ``(B, N, feat_dim)``. With ``D == 4`` and a
            trailing ``U`` axis of 4 this is 64 features.

        Raises:
            ValueError: If ``gh * gw != N`` or the grid is smaller than 2x2
                (reflect padding by one needs at least two cells per axis).
        """
        S = svd_dict['S']
        S_orig = svd_dict['S_orig']
        U = svd_dict['U']
        Vt = svd_dict['Vt']
        M = svd_dict['M']

        B, N, D = S.shape
        if gh * gw != N:
            raise ValueError(
                f"grid {gh}x{gw} does not match the {N} patch tokens in S")
        if gh < 2 or gw < 2:
            raise ValueError(
                f"grid {gh}x{gw} is too small for reflect padding (need >= 2x2)")

        features = []

        # Tier 1: Scalar (16 dims when D == 4)
        S_ratios = S[:, :, :-1] / (S[:, :, 1:] + 1e-8)
        features.append(S_ratios)

        S2 = S.pow(2)
        energy = S2 / (S2.sum(-1, keepdim=True) + 1e-8)
        features.append(energy)

        # Effective rank: exp of the entropy of the normalised singular values.
        p = S / (S.sum(-1, keepdim=True) + 1e-8)
        p = p.clamp(min=1e-8)
        erank = (-(p * p.log()).sum(-1, keepdim=True)).exp()
        features.append(erank / D)

        cond = (S[:, :, 0:1] / (S[:, :, -1:] + 1e-8))
        features.append(cond / 10.0)

        S_delta = S - S_orig
        features.append(S_delta)

        S_log = torch.log(S[:, :, :-1] + 1e-8) - torch.log(S[:, :, 1:] + 1e-8)
        features.append(S_log)

        # Tier 2: Relational (16 dims when D == 4)
        S_grid = S.reshape(B, gh, gw, D)
        S_center = S_grid.permute(0, 3, 1, 2)
        padded = F.pad(S_center, (1, 1, 1, 1), mode='reflect')
        neighbor_mean = self._neighbor_mean(padded)
        delta_card = (S_center - neighbor_mean).permute(0, 2, 3, 1).reshape(B, N, D)
        features.append(delta_card)

        # Std of the four neighbours via E[x^2] - E[x]^2; the clamp guards
        # against tiny negative values from rounding.
        neighbor_sq = (padded[:, :, :-2, 1:-1].pow(2) + padded[:, :, 2:, 1:-1].pow(2) +
                       padded[:, :, 1:-1, :-2].pow(2) + padded[:, :, 1:-1, 2:].pow(2)) / 4
        neighbor_var = (neighbor_sq - neighbor_mean.pow(2)).clamp(min=0)
        neighbor_std = neighbor_var.sqrt().permute(0, 2, 3, 1).reshape(B, N, D)
        features.append(neighbor_std)

        energy_grid = energy.reshape(B, gh, gw, D).permute(0, 3, 1, 2)
        e_padded = F.pad(energy_grid, (1, 1, 1, 1), mode='reflect')
        e_neighbor = self._neighbor_mean(e_padded)
        e_delta = (energy_grid - e_neighbor).permute(0, 2, 3, 1).reshape(B, N, D)
        features.append(e_delta)

        # Position codes on the grid, with coordinates normalised to [0, 1).
        rows = torch.arange(gh, device=S.device).float() / gh
        cols = torch.arange(gw, device=S.device).float() / gw
        row_grid = rows.unsqueeze(1).expand(gh, gw).reshape(1, N, 1).expand(B, -1, -1)
        col_grid = cols.unsqueeze(0).expand(gh, gw).reshape(1, N, 1).expand(B, -1, -1)
        features.append(torch.sin(row_grid * math.pi))
        features.append(torch.cos(col_grid * math.pi))
        features.append(torch.sin(row_grid * 2 * math.pi))
        features.append(torch.cos(col_grid * 2 * math.pi))

        # Tier 3: Basis (32 dims when D == 4 and U's last axis is 4)
        Vt_flat = Vt.reshape(B, N, D * D)
        features.append(Vt_flat)

        U_col_mean = U.mean(dim=2)
        U_col_std = U.std(dim=2)
        features.append(U_col_mean)
        features.append(U_col_std)

        # Contract M over both of its trailing axes against the fixed sketch.
        M_sketch = torch.einsum('bnvd,vk->bnk', M, self.m_proj)
        features.append(M_sketch)

        return torch.cat(features, dim=-1)
120
+
121
+
122
# ═══════════════════════════════════════════════════════════════
# OMEGA TRANSFORMER CLASSIFIER
# ═══════════════════════════════════════════════════════════════
125
+
126
class OmegaTransformerClassifier(nn.Module):
    """Transformer classifier over geometric SVD features.

    Pipeline: ``GeometricFeatureExtractor`` -> per-token MLP projection ->
    prepend a learned CLS token -> ``nn.TransformerEncoder`` -> MLP head
    applied to the CLS output. No learned positional embedding is added here;
    the extractor's output already contains sin/cos grid-position features.

    Args:
        feat_dim: Size of the extractor's last axis (must match its output).
        d_model: Transformer width.
        n_heads: Attention heads per layer.
        n_layers: Number of encoder layers.
        n_classes: Number of output logits.
        dropout: Dropout used in the encoder layers and the head.
        D: Forwarded to ``GeometricFeatureExtractor``.
        V: Forwarded to ``GeometricFeatureExtractor``.
    """

    def __init__(self, feat_dim=64, d_model=128, n_heads=4,
                 n_layers=4, n_classes=10, dropout=0.1, D=4, V=48):
        super().__init__()
        self.feat_extractor = GeometricFeatureExtractor(D=D, V=V)

        self.input_proj = nn.Sequential(
            nn.LayerNorm(feat_dim),
            nn.Linear(feat_dim, d_model),
            nn.GELU(),
            nn.Linear(d_model, d_model),
        )

        self.cls_token = nn.Parameter(torch.randn(1, 1, d_model) * 0.02)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads,
            dim_feedforward=d_model * 4,
            dropout=dropout, batch_first=True,
            activation='gelu',
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)

        self.head = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model, n_classes),
        )

    def forward(self, svd_dict, gh, gw):
        """Return class logits of shape ``(B, n_classes)``.

        Args:
            svd_dict: Dictionary passed unchanged to the feature extractor.
            gh: Patch-grid height.
            gw: Patch-grid width.
        """
        features = self.feat_extractor(svd_dict, gh, gw)
        # Only the batch size is needed; unpacking the shape into `F` would
        # shadow the torch.nn.functional alias used elsewhere in this file.
        batch_size = features.shape[0]
        tokens = self.input_proj(features)
        cls = self.cls_token.expand(batch_size, -1, -1)
        tokens = torch.cat([cls, tokens], dim=1)
        out = self.transformer(tokens)
        # Classify from the CLS position (token 0).
        return self.head(out[:, 0])
165
+
166
+
167
# ═══════════════════════════════════════════════════════════════
# RAW PATCH BASELINE
# ═══════════════════════════════════════════════════════════════
170
+
171
class RawPatchClassifier(nn.Module):
    """ViT-style baseline that classifies directly from raw pixel patches.

    Images are cut into non-overlapping ``patch_size x patch_size`` patches,
    each patch is flattened and projected to ``d_model``, a learned CLS token
    and learned positional embeddings are added, and the CLS output of a
    transformer encoder is fed to an MLP head.

    Args:
        patch_dim: Flattened patch size; must equal ``C * patch_size ** 2``.
        d_model: Transformer width.
        n_heads: Attention heads per layer.
        n_layers: Number of encoder layers.
        n_classes: Number of output logits.
        dropout: Dropout used in the encoder layers and the head.
        n_patches: Maximum number of patches (size of the positional table).
        patch_size: Side length of a patch in pixels (default 4, the value
            that used to be hard-coded in ``forward``).
    """

    def __init__(self, patch_dim=48, d_model=128, n_heads=4,
                 n_layers=4, n_classes=10, dropout=0.1, n_patches=256,
                 patch_size=4):
        super().__init__()
        self.patch_size = patch_size
        self.input_proj = nn.Sequential(
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, d_model),
            nn.GELU(),
            nn.Linear(d_model, d_model),
        )
        # One positional slot per patch plus one for the CLS token.
        self.pos_enc = nn.Parameter(torch.randn(1, n_patches + 1, d_model) * 0.02)
        self.cls_token = nn.Parameter(torch.randn(1, 1, d_model) * 0.02)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads,
            dim_feedforward=d_model * 4,
            dropout=dropout, batch_first=True,
            activation='gelu',
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
        self.head = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model, n_classes),
        )

    def forward(self, images):
        """Return class logits of shape ``(B, n_classes)``.

        Args:
            images: Tensor of shape ``(B, C, H, W)`` with ``H`` and ``W``
                divisible by ``patch_size``.

        Raises:
            ValueError: If the image size is not divisible by the patch size
                or yields more patches than the positional table holds.
        """
        B, C, H, W = images.shape
        ps = self.patch_size
        if H % ps or W % ps:
            raise ValueError(
                f"image size {H}x{W} is not divisible by patch_size={ps}")
        gh, gw = H // ps, W // ps
        N = gh * gw
        if N + 1 > self.pos_enc.shape[1]:
            raise ValueError(
                f"{N} patches exceed the positional table "
                f"({self.pos_enc.shape[1] - 1} patches)")

        # (B, C, gh, ps, gw, ps) -> (B, gh, gw, C, ps, ps) -> (B, N, C*ps*ps)
        patches = images.reshape(B, C, gh, ps, gw, ps)
        patches = patches.permute(0, 2, 4, 1, 3, 5).reshape(B, N, C * ps * ps)
        tokens = self.input_proj(patches)
        cls = self.cls_token.expand(B, -1, -1)
        tokens = torch.cat([cls, tokens], dim=1)
        tokens = tokens + self.pos_enc[:, :N + 1]
        out = self.transformer(tokens)
        # Classify from the CLS position (token 0).
        return self.head(out[:, 0])
212
+
213
+
214
# ═══════════════════════════════════════════════════════════════
# CIFAR-10 DATASET
# ═══════════════════════════════════════════════════════════════
217
+
218
# Class names, indexed by the integer label.
CIFAR_CLASSES = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                 'dog', 'frog', 'horse', 'ship', 'truck']

# Per-channel mean / std passed to the Normalize transform.
IMG_MEAN = (0.4914, 0.4822, 0.4465)
IMG_STD = (0.2470, 0.2435, 0.2616)
223
+
224
+
225
def get_cifar10_loaders(batch_size=128, img_size=64, root='/content/data',
                        num_workers=4):
    """Load CIFAR-10, resize to img_size, normalize.

    Args:
        batch_size: Batch size for both loaders.
        img_size: Target size given to the bilinear ``Resize`` transform.
        root: Directory where the dataset is stored / downloaded. Defaults to
            the Colab path this script originally hard-coded.
        num_workers: Worker processes per DataLoader.

    Returns:
        ``(train_loader, test_loader)``. The train loader shuffles, applies a
        random horizontal flip and drops the last incomplete batch; the test
        loader is deterministic and keeps every sample.
    """
    import torchvision
    import torchvision.transforms as T

    resize = T.Resize(img_size, interpolation=T.InterpolationMode.BILINEAR)
    normalize = T.Normalize(IMG_MEAN, IMG_STD)

    transform_train = T.Compose([
        resize,
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        normalize,
    ])
    transform_test = T.Compose([
        resize,
        T.ToTensor(),
        normalize,
    ])

    train_ds = torchvision.datasets.CIFAR10(
        root=root, train=True, download=True, transform=transform_train)
    test_ds = torchvision.datasets.CIFAR10(
        root=root, train=False, download=True, transform=transform_test)

    # Pinned host memory only helps host-to-GPU copies; skip it on CPU-only
    # machines, where requesting it just produces a warning.
    pin_memory = torch.cuda.is_available()

    train_loader = torch.utils.data.DataLoader(
        train_ds, batch_size=batch_size, shuffle=True,
        num_workers=num_workers, pin_memory=pin_memory, drop_last=True)
    test_loader = torch.utils.data.DataLoader(
        test_ds, batch_size=batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=pin_memory)

    return train_loader, test_loader
255
+
256
+
257
# ═══════════════════════════════════════════════════════════════
# TRAINING
# ═══════════════════════════════════════════════════════════════
260
+
261
def train_model(mode='omega', epochs=30, batch_size=128, lr=3e-4,
                d_model=128, n_heads=4, n_layers=4, img_size=64,
                device='cuda'):
    """Train and evaluate one CIFAR-10 classifier.

    Args:
        mode: 'omega' (Freckles + features) or 'baseline' (raw patches).
        epochs: Number of training epochs (also the cosine schedule length).
        batch_size: Batch size for training and evaluation.
        lr: Initial Adam learning rate.
        d_model: Transformer width.
        n_heads: Attention heads per layer.
        n_layers: Number of encoder layers.
        img_size: Side length images are resized to.
        device: Requested device; falls back to 'cpu' when CUDA is unavailable.

    Returns:
        ``(classifier, best_acc)`` where ``best_acc`` is the highest test
        accuracy seen after any epoch. NOTE(review): this is selected on the
        test split itself, so it is an optimistic estimate.

    Raises:
        ValueError: If ``mode`` is neither 'omega' nor 'baseline'.
    """
    if mode not in ('omega', 'baseline'):
        raise ValueError(f"mode must be 'omega' or 'baseline', got {mode!r}")

    device = torch.device(device if torch.cuda.is_available() else 'cpu')

    print("\n" + "=" * 70)
    if mode == 'omega':
        print("OMEGA PROCESSOR — CIFAR-10 (Freckles features)")
    else:
        print("BASELINE — CIFAR-10 (Raw patches, no Freckles)")
    print("=" * 70)

    # Patch grid. NOTE(review): omega mode assumes Freckles also uses 4x4
    # patches, so its token count equals gh * gw — confirm against f_cfg.
    ps = 4
    gh, gw = img_size // ps, img_size // ps
    n_patches = gh * gw

    # Load Freckles for omega mode
    freckles = None
    if mode == 'omega':
        from geolip_svae import load_model
        freckles, f_cfg = load_model(hf_version='v40_freckles_noise', device=device)
        freckles.eval()
        for p in freckles.parameters():
            p.requires_grad = False
        print(f" Freckles: {sum(p.numel() for p in freckles.parameters()):,} params (frozen)")

        # Determine feature dim by pushing one dummy image through the stack.
        with torch.no_grad():
            dummy = torch.randn(1, 3, img_size, img_size).to(device)
            dummy_out = freckles(dummy)
            feat_ext = GeometricFeatureExtractor(D=f_cfg['D'], V=f_cfg['V']).to(device)
            feat_dim = feat_ext(dummy_out['svd'], gh, gw).shape[-1]
            del feat_ext
        print(f" Feature dim: {feat_dim}")

        classifier = OmegaTransformerClassifier(
            feat_dim=feat_dim, d_model=d_model, n_heads=n_heads,
            n_layers=n_layers, n_classes=10, D=f_cfg['D'], V=f_cfg['V'],
        ).to(device)
    else:
        classifier = RawPatchClassifier(
            patch_dim=3 * ps * ps, d_model=d_model, n_heads=n_heads,
            n_layers=n_layers, n_classes=10, n_patches=n_patches,
        ).to(device)

    n_params = sum(p.numel() for p in classifier.parameters() if p.requires_grad)
    print(f" Classifier: {n_params:,} params")
    print(f" Architecture: d_model={d_model}, heads={n_heads}, layers={n_layers}")
    print(f" CIFAR-10: 50K train, 10K test, {img_size}×{img_size}")
    print(f" Batch: {batch_size}, lr={lr}, epochs={epochs}")
    print("=" * 70)

    train_loader, test_loader = get_cifar10_loaders(batch_size, img_size)

    opt = torch.optim.Adam(classifier.parameters(), lr=lr)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)

    def compute_logits(images):
        """Run the (frozen) front-end if any, then the trainable classifier."""
        if freckles is None:
            return classifier(images)
        # Freckles is frozen: never build a graph through it.
        with torch.no_grad():
            out = freckles(images)
        return classifier(out['svd'], gh, gw)

    best_acc = 0

    for epoch in range(1, epochs + 1):
        classifier.train()
        total_loss, correct, total = 0, 0, 0
        t0 = time.time()

        pbar = tqdm(train_loader, desc=f"Ep {epoch}/{epochs}",
                    bar_format='{l_bar}{bar:20}{r_bar}')
        for images, labels in pbar:
            images = images.to(device)
            labels = labels.to(device)

            logits = compute_logits(images)

            loss = F.cross_entropy(logits, labels)
            opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(classifier.parameters(), max_norm=1.0)
            opt.step()

            loss_value = loss.item()
            total_loss += loss_value * len(labels)
            correct += (logits.argmax(-1) == labels).sum().item()
            total += len(labels)
            pbar.set_postfix_str(f"loss={loss_value:.4f} acc={correct/total:.1%}")

        sched.step()
        train_acc = correct / total
        train_loss = total_loss / total

        # Test
        classifier.eval()
        test_correct, test_total = 0, 0
        per_class_correct = torch.zeros(10)
        per_class_total = torch.zeros(10)

        with torch.no_grad():
            for images, labels in test_loader:
                images = images.to(device)
                labels = labels.to(device)

                logits = compute_logits(images)

                preds = logits.argmax(-1)
                hits = preds == labels
                test_correct += hits.sum().item()
                test_total += len(labels)

                # One histogram per batch instead of a Python loop over the
                # classes with a device sync per class.
                per_class_total += torch.bincount(labels, minlength=10).cpu()
                per_class_correct += torch.bincount(labels[hits], minlength=10).cpu()

        test_acc = test_correct / test_total
        # Covers both the training pass and the evaluation pass.
        epoch_time = time.time() - t0

        per_class_acc = per_class_correct / (per_class_total + 1e-8)
        worst_class = per_class_acc.argmin().item()
        best_class = per_class_acc.argmax().item()

        print(f" ep{epoch:3d} | loss={train_loss:.4f} train={train_acc:.1%} "
              f"test={test_acc:.1%} | best={CIFAR_CLASSES[best_class]}={per_class_acc[best_class]:.0%} "
              f"worst={CIFAR_CLASSES[worst_class]}={per_class_acc[worst_class]:.0%} | {epoch_time:.0f}s")

        if test_acc > best_acc:
            best_acc = test_acc

        # Full per-class table on the first, every fifth and the last epoch.
        if epoch % 5 == 0 or epoch == 1 or epoch == epochs:
            print(f"\n {'class':<14s} {'acc':>6s}")
            print(f" {'-'*22}")
            for c in range(10):
                bar = '█' * int(per_class_acc[c] * 20)
                print(f" {CIFAR_CLASSES[c]:<14s} {per_class_acc[c]:5.1%} {bar}")
            print()

    tag = "OMEGA PROCESSOR" if mode == 'omega' else "BASELINE"
    print(f"\n{'=' * 70}")
    print(f"{tag} COMPLETE")
    print(f" Best test accuracy: {best_acc:.1%}")
    print(f" Classifier params: {n_params:,}")
    print(f" Random chance: 10.0%")
    print(f"{'=' * 70}")

    return classifier, best_acc
413
+
414
+
415
if __name__ == "__main__":
    import sys
    torch.set_float32_matmul_precision('high')

    VALID_MODES = ('omega', 'baseline', 'both')
    MODE = 'both'  # 'omega', 'baseline', or 'both'
    if len(sys.argv) > 1:
        MODE = sys.argv[1]
    if MODE not in VALID_MODES:
        # E.g. a notebook kernel passes "-f <file>" as argv; without this
        # fallback an unrecognised value matched no branch and nothing ran.
        print(f"[warn] unrecognised mode {MODE!r}; expected one of "
              f"{VALID_MODES}. Falling back to 'both'.")
        MODE = 'both'

    results = {}

    # Same hyper-parameters for both runs so the comparison is like-for-like.
    for run_mode in ('omega', 'baseline'):
        if MODE in (run_mode, 'both'):
            _, run_acc = train_model(
                mode=run_mode, epochs=30, batch_size=128,
                lr=3e-4, d_model=128, n_heads=4, n_layers=4)
            results[run_mode] = run_acc

    if len(results) == 2:
        print("\n" + "=" * 70)
        print("HEAD-TO-HEAD COMPARISON")
        print("=" * 70)
        print(f" Omega Processor (Freckles features): {results['omega']:.1%}")
        print(f" Baseline (raw patches): {results['baseline']:.1%}")
        print(f" Delta: {results['omega'] - results['baseline']:+.1%}")
        print(f" Random chance: 10.0%")
        print("=" * 70)