""" Baseline — Raw Patch Transformer Noise Classifier ==================================================== Same transformer architecture as Omega Processor. No Freckles. No SVD. No geometric features. Raw 4×4 patches (48 dims) → Linear → Transformer → 16 classes. The control experiment, no SVD trained models to assist. Usage: python omega_baseline.py """ import os import math import time import torch import torch.nn as nn import torch.nn.functional as F import numpy as np from tqdm import tqdm try: from google.colab import userdata os.environ["HF_TOKEN"] = userdata.get('HF_TOKEN') from huggingface_hub import login login(token=os.environ["HF_TOKEN"]) except Exception: pass # ═══════════════════════════════════════════════════════════════ # NOISE GENERATORS (same as omega_processor.py) # ═══════════════════════════════════════════════════════════════ def _pink(shape): w = torch.randn(shape) S = torch.fft.rfft2(w) h, ww = shape[-2], shape[-1] fy = torch.fft.fftfreq(h).unsqueeze(-1).expand(-1, ww // 2 + 1) fx = torch.fft.rfftfreq(ww).unsqueeze(0).expand(h, -1) return torch.fft.irfft2(S / torch.sqrt(fx**2 + fy**2).clamp(min=1e-8), s=(h, ww)) def _brown(shape): w = torch.randn(shape) S = torch.fft.rfft2(w) h, ww = shape[-2], shape[-1] fy = torch.fft.fftfreq(h).unsqueeze(-1).expand(-1, ww // 2 + 1) fx = torch.fft.rfftfreq(ww).unsqueeze(0).expand(h, -1) return torch.fft.irfft2(S / (fx**2 + fy**2).clamp(min=1e-8), s=(h, ww)) def _gen_noise(t, s, rng): if t == 0: return torch.randn(3, s, s) elif t == 1: return torch.rand(3, s, s) * 2 - 1 elif t == 2: return (torch.rand(3, s, s) - 0.5) * 4 elif t == 3: lam = rng.uniform(0.5, 20.0) return torch.poisson(torch.full((3, s, s), lam)) / lam - 1.0 elif t == 4: img = _pink((3, s, s)); return img / (img.std() + 1e-8) elif t == 5: img = _brown((3, s, s)); return img / (img.std() + 1e-8) elif t == 6: return torch.where(torch.rand(3, s, s) > 0.5, torch.ones(3, s, s) * 2, -torch.ones(3, s, s) * 2) + torch.randn(3, s, s) * 0.1 elif t == 7: return torch.randn(3, s, s) * (torch.rand(3, s, s) > 0.9).float() * 3 elif t == 8: b = rng.randint(2, max(3, s // 4)) sm = torch.randn(3, s // b + 1, s // b + 1) return F.interpolate(sm.unsqueeze(0), size=s, mode='nearest').squeeze(0) elif t == 9: gy = torch.linspace(-2, 2, s).unsqueeze(1).expand(s, s) gx = torch.linspace(-2, 2, s).unsqueeze(0).expand(s, s) a = rng.uniform(0, 2 * math.pi) return (math.cos(a) * gx + math.sin(a) * gy).unsqueeze(0).expand(3, -1, -1) + torch.randn(3, s, s) * 0.5 elif t == 10: cs = rng.randint(2, max(3, s // 4)) cy = torch.arange(s) // cs; cx = torch.arange(s) // cs return ((cy.unsqueeze(1) + cx.unsqueeze(0)) % 2).float().unsqueeze(0).expand(3, -1, -1) * 2 - 1 + torch.randn(3, s, s) * 0.3 elif t == 11: alpha = rng.uniform(0.2, 0.8) return alpha * torch.randn(3, s, s) + (1 - alpha) * (torch.rand(3, s, s) * 2 - 1) elif t == 12: img = torch.zeros(3, s, s); h2 = s // 2; w2 = s // 2 img[:, :h2, :w2] = torch.randn(3, h2, w2) img[:, :h2, w2:s] = torch.rand(3, h2, s - w2) * 2 - 1 img[:, h2:s, :w2] = _pink((3, s - h2, w2)) / 2 img[:, h2:s, w2:s] = torch.where(torch.rand(3, s - h2, s - w2) > 0.5, torch.ones(3, s - h2, s - w2), -torch.ones(3, s - h2, s - w2)) return img elif t == 13: return torch.tan(math.pi * (torch.rand(3, s, s) - 0.5)).clamp(-3, 3) elif t == 14: return torch.empty(3, s, s).exponential_(1.0) - 1.0 elif t == 15: u = torch.rand(3, s, s) - 0.5 return -torch.sign(u) * torch.log1p(-2 * u.abs()) return torch.randn(3, s, s) NOISE_NAMES = { 0: 'gaussian', 1: 'uniform', 2: 'uniform_sc', 3: 'poisson', 4: 'pink', 5: 'brown', 6: 'salt_pepper', 7: 'sparse', 8: 'block', 9: 'gradient', 10: 'checker', 11: 'mixed', 12: 'structural', 13: 'cauchy', 14: 'exponential', 15: 'laplace', } class NoiseDataset(torch.utils.data.Dataset): def __init__(self, size=160000, img_size=64): self.size = size self.img_size = img_size self._rng = np.random.RandomState(42) self._call_count = 0 def __len__(self): return self.size def __getitem__(self, idx): self._call_count += 1 if self._call_count % 1000 == 0: self._rng = np.random.RandomState(int.from_bytes(os.urandom(4), 'big')) torch.manual_seed(int.from_bytes(os.urandom(4), 'big')) noise_type = idx % 16 img = _gen_noise(noise_type, self.img_size, self._rng).clamp(-4, 4) return img.float(), noise_type # ═══════════════════════════════════════════════════════════════ # RAW PATCH TRANSFORMER CLASSIFIER # ═══════════════════════════════════════════════════════════════ class RawPatchClassifier(nn.Module): """Same transformer, raw patches instead of geometric features. Raw 4×4 patches (3×4×4 = 48 dims) → Linear → Transformer → 16 classes. Same d_model, n_heads, n_layers as OmegaTransformerClassifier. """ def __init__(self, patch_dim=48, d_model=128, n_heads=4, n_layers=4, n_classes=16, dropout=0.1): super().__init__() # Project raw patches to transformer dim self.input_proj = nn.Sequential( nn.LayerNorm(patch_dim), nn.Linear(patch_dim, d_model), nn.GELU(), nn.Linear(d_model, d_model), ) # Learnable position encoding self.pos_enc = nn.Parameter(torch.randn(1, 256 + 1, d_model) * 0.02) # CLS token self.cls_token = nn.Parameter(torch.randn(1, 1, d_model) * 0.02) # Transformer encoder encoder_layer = nn.TransformerEncoderLayer( d_model=d_model, nhead=n_heads, dim_feedforward=d_model * 4, dropout=dropout, batch_first=True, activation='gelu', ) self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers) # Classification head self.head = nn.Sequential( nn.LayerNorm(d_model), nn.Linear(d_model, d_model), nn.GELU(), nn.Dropout(dropout), nn.Linear(d_model, n_classes), ) def forward(self, images): B, C, H, W = images.shape ps = 4 gh, gw = H // ps, W // ps N = gh * gw # Extract raw patches patches = images.reshape(B, C, gh, ps, gw, ps) patches = patches.permute(0, 2, 4, 1, 3, 5).reshape(B, N, C * ps * ps) # Project tokens = self.input_proj(patches) # CLS + position cls = self.cls_token.expand(B, -1, -1) tokens = torch.cat([cls, tokens], dim=1) tokens = tokens + self.pos_enc[:, :N + 1] # Transformer out = self.transformer(tokens) # CLS → classify return self.head(out[:, 0]) # ═══════════════════════════════════════════════════════════════ # TRAINING # ═══════════════════════════════════════════════════════════════ def train(epochs=20, batch_size=128, lr=3e-4, d_model=128, n_heads=4, n_layers=4, img_size=64, device='cuda'): device = torch.device(device if torch.cuda.is_available() else 'cpu') print("\n" + "=" * 70) print("BASELINE — Raw Patch Transformer (No Freckles, No SVD)") print("=" * 70) patch_dim = 3 * 4 * 4 # 48 classifier = RawPatchClassifier( patch_dim=patch_dim, d_model=d_model, n_heads=n_heads, n_layers=n_layers, n_classes=16, ).to(device) n_params = sum(p.numel() for p in classifier.parameters()) print(f" Classifier: {n_params:,} params") print(f" Input: raw 4×4 patches ({patch_dim} dims)") print(f" Architecture: d_model={d_model}, heads={n_heads}, layers={n_layers}") print(f" Batch: {batch_size}, lr={lr}, epochs={epochs}") print(f" No expert. No SVD. No geometric features.") print("=" * 70) train_ds = NoiseDataset(size=160000, img_size=img_size) val_ds = NoiseDataset(size=16000, img_size=img_size) train_loader = torch.utils.data.DataLoader( train_ds, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True, drop_last=True) val_loader = torch.utils.data.DataLoader( val_ds, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True) opt = torch.optim.Adam(classifier.parameters(), lr=lr) sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs) best_acc = 0 for epoch in range(1, epochs + 1): classifier.train() total_loss, correct, total = 0, 0, 0 t0 = time.time() pbar = tqdm(train_loader, desc=f"Ep {epoch}/{epochs}", bar_format='{l_bar}{bar:20}{r_bar}') for images, labels in pbar: images = images.to(device) labels = labels.to(device) logits = classifier(images) loss = F.cross_entropy(logits, labels) opt.zero_grad() loss.backward() torch.nn.utils.clip_grad_norm_(classifier.parameters(), max_norm=1.0) opt.step() total_loss += loss.item() * len(labels) correct += (logits.argmax(-1) == labels).sum().item() total += len(labels) pbar.set_postfix_str(f"loss={loss.item():.4f} acc={correct/total:.1%}") sched.step() train_acc = correct / total train_loss = total_loss / total # Eval classifier.eval() val_correct, val_total = 0, 0 per_class_correct = torch.zeros(16) per_class_total = torch.zeros(16) with torch.no_grad(): for images, labels in val_loader: images = images.to(device) labels = labels.to(device) logits = classifier(images) preds = logits.argmax(-1) val_correct += (preds == labels).sum().item() val_total += len(labels) for c in range(16): mask = labels == c per_class_correct[c] += (preds[mask] == labels[mask]).sum().item() per_class_total[c] += mask.sum().item() val_acc = val_correct / val_total epoch_time = time.time() - t0 per_class_acc = per_class_correct / (per_class_total + 1e-8) worst_class = per_class_acc.argmin().item() best_class = per_class_acc.argmax().item() print(f" ep{epoch:3d} | loss={train_loss:.4f} train={train_acc:.1%} " f"val={val_acc:.1%} | best={NOISE_NAMES[best_class]}={per_class_acc[best_class]:.0%} " f"worst={NOISE_NAMES[worst_class]}={per_class_acc[worst_class]:.0%} | {epoch_time:.0f}s") if val_acc > best_acc: best_acc = val_acc if epoch % 5 == 0 or epoch == 1: print(f"\n {'type':<14s} {'acc':>6s}") print(f" {'-'*22}") for c in range(16): bar = '█' * int(per_class_acc[c] * 20) print(f" {NOISE_NAMES[c]:<14s} {per_class_acc[c]:5.1%} {bar}") print() print(f"\n{'=' * 70}") print(f"BASELINE COMPLETE") print(f" Best val accuracy: {best_acc:.1%}") print(f" Params: {n_params:,}") print(f" Random chance: {1/16:.1%}") print(f"{'=' * 70}") return classifier if __name__ == "__main__": torch.set_float32_matmul_precision('high') train( epochs=20, batch_size=128, lr=3e-4, d_model=128, n_heads=4, n_layers=4, img_size=64, )