# geolip-SVAE / omega_processer_transformer_only_ablation.py
# (Hugging Face listing residue preserved as a comment so the file parses:
#  "AbstractPhil's picture" / "Create omega_processer_transformer_only_ablation.py"
#  / "677a4b4 verified" / "raw / history / blame" / "12.6 kB")
"""
Baseline β€” Raw Patch Transformer Noise Classifier
====================================================
Same transformer architecture as Omega Processor.
No Freckles. No SVD. No geometric features.
Raw 4Γ—4 patches (48 dims) β†’ Linear β†’ Transformer β†’ 16 classes.
The control experiment, no SVD trained models to assist.
Usage:
python omega_baseline.py
"""
import os
import math
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
try:
from google.colab import userdata
os.environ["HF_TOKEN"] = userdata.get('HF_TOKEN')
from huggingface_hub import login
login(token=os.environ["HF_TOKEN"])
except Exception:
pass
# ═══════════════════════════════════════════════════════════════
# NOISE GENERATORS (same as omega_processor.py)
# ═══════════════════════════════════════════════════════════════
def _pink(shape):
w = torch.randn(shape)
S = torch.fft.rfft2(w)
h, ww = shape[-2], shape[-1]
fy = torch.fft.fftfreq(h).unsqueeze(-1).expand(-1, ww // 2 + 1)
fx = torch.fft.rfftfreq(ww).unsqueeze(0).expand(h, -1)
return torch.fft.irfft2(S / torch.sqrt(fx**2 + fy**2).clamp(min=1e-8), s=(h, ww))
def _brown(shape):
w = torch.randn(shape)
S = torch.fft.rfft2(w)
h, ww = shape[-2], shape[-1]
fy = torch.fft.fftfreq(h).unsqueeze(-1).expand(-1, ww // 2 + 1)
fx = torch.fft.rfftfreq(ww).unsqueeze(0).expand(h, -1)
return torch.fft.irfft2(S / (fx**2 + fy**2).clamp(min=1e-8), s=(h, ww))
def _gen_noise(t, s, rng):
if t == 0: return torch.randn(3, s, s)
elif t == 1: return torch.rand(3, s, s) * 2 - 1
elif t == 2: return (torch.rand(3, s, s) - 0.5) * 4
elif t == 3:
lam = rng.uniform(0.5, 20.0)
return torch.poisson(torch.full((3, s, s), lam)) / lam - 1.0
elif t == 4:
img = _pink((3, s, s)); return img / (img.std() + 1e-8)
elif t == 5:
img = _brown((3, s, s)); return img / (img.std() + 1e-8)
elif t == 6:
return torch.where(torch.rand(3, s, s) > 0.5,
torch.ones(3, s, s) * 2, -torch.ones(3, s, s) * 2) + torch.randn(3, s, s) * 0.1
elif t == 7:
return torch.randn(3, s, s) * (torch.rand(3, s, s) > 0.9).float() * 3
elif t == 8:
b = rng.randint(2, max(3, s // 4))
sm = torch.randn(3, s // b + 1, s // b + 1)
return F.interpolate(sm.unsqueeze(0), size=s, mode='nearest').squeeze(0)
elif t == 9:
gy = torch.linspace(-2, 2, s).unsqueeze(1).expand(s, s)
gx = torch.linspace(-2, 2, s).unsqueeze(0).expand(s, s)
a = rng.uniform(0, 2 * math.pi)
return (math.cos(a) * gx + math.sin(a) * gy).unsqueeze(0).expand(3, -1, -1) + torch.randn(3, s, s) * 0.5
elif t == 10:
cs = rng.randint(2, max(3, s // 4))
cy = torch.arange(s) // cs; cx = torch.arange(s) // cs
return ((cy.unsqueeze(1) + cx.unsqueeze(0)) % 2).float().unsqueeze(0).expand(3, -1, -1) * 2 - 1 + torch.randn(3, s, s) * 0.3
elif t == 11:
alpha = rng.uniform(0.2, 0.8)
return alpha * torch.randn(3, s, s) + (1 - alpha) * (torch.rand(3, s, s) * 2 - 1)
elif t == 12:
img = torch.zeros(3, s, s); h2 = s // 2; w2 = s // 2
img[:, :h2, :w2] = torch.randn(3, h2, w2)
img[:, :h2, w2:s] = torch.rand(3, h2, s - w2) * 2 - 1
img[:, h2:s, :w2] = _pink((3, s - h2, w2)) / 2
img[:, h2:s, w2:s] = torch.where(torch.rand(3, s - h2, s - w2) > 0.5,
torch.ones(3, s - h2, s - w2), -torch.ones(3, s - h2, s - w2))
return img
elif t == 13:
return torch.tan(math.pi * (torch.rand(3, s, s) - 0.5)).clamp(-3, 3)
elif t == 14:
return torch.empty(3, s, s).exponential_(1.0) - 1.0
elif t == 15:
u = torch.rand(3, s, s) - 0.5
return -torch.sign(u) * torch.log1p(-2 * u.abs())
return torch.randn(3, s, s)
NOISE_NAMES = {
0: 'gaussian', 1: 'uniform', 2: 'uniform_sc', 3: 'poisson',
4: 'pink', 5: 'brown', 6: 'salt_pepper', 7: 'sparse',
8: 'block', 9: 'gradient', 10: 'checker', 11: 'mixed',
12: 'structural', 13: 'cauchy', 14: 'exponential', 15: 'laplace',
}
class NoiseDataset(torch.utils.data.Dataset):
def __init__(self, size=160000, img_size=64):
self.size = size
self.img_size = img_size
self._rng = np.random.RandomState(42)
self._call_count = 0
def __len__(self):
return self.size
def __getitem__(self, idx):
self._call_count += 1
if self._call_count % 1000 == 0:
self._rng = np.random.RandomState(int.from_bytes(os.urandom(4), 'big'))
torch.manual_seed(int.from_bytes(os.urandom(4), 'big'))
noise_type = idx % 16
img = _gen_noise(noise_type, self.img_size, self._rng).clamp(-4, 4)
return img.float(), noise_type
# ═══════════════════════════════════════════════════════════════
# RAW PATCH TRANSFORMER CLASSIFIER
# ═══════════════════════════════════════════════════════════════
class RawPatchClassifier(nn.Module):
"""Same transformer, raw patches instead of geometric features.
Raw 4Γ—4 patches (3Γ—4Γ—4 = 48 dims) β†’ Linear β†’ Transformer β†’ 16 classes.
Same d_model, n_heads, n_layers as OmegaTransformerClassifier.
"""
def __init__(self, patch_dim=48, d_model=128, n_heads=4,
n_layers=4, n_classes=16, dropout=0.1):
super().__init__()
# Project raw patches to transformer dim
self.input_proj = nn.Sequential(
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, d_model),
nn.GELU(),
nn.Linear(d_model, d_model),
)
# Learnable position encoding
self.pos_enc = nn.Parameter(torch.randn(1, 256 + 1, d_model) * 0.02)
# CLS token
self.cls_token = nn.Parameter(torch.randn(1, 1, d_model) * 0.02)
# Transformer encoder
encoder_layer = nn.TransformerEncoderLayer(
d_model=d_model, nhead=n_heads,
dim_feedforward=d_model * 4,
dropout=dropout, batch_first=True,
activation='gelu',
)
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
# Classification head
self.head = nn.Sequential(
nn.LayerNorm(d_model),
nn.Linear(d_model, d_model),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(d_model, n_classes),
)
def forward(self, images):
B, C, H, W = images.shape
ps = 4
gh, gw = H // ps, W // ps
N = gh * gw
# Extract raw patches
patches = images.reshape(B, C, gh, ps, gw, ps)
patches = patches.permute(0, 2, 4, 1, 3, 5).reshape(B, N, C * ps * ps)
# Project
tokens = self.input_proj(patches)
# CLS + position
cls = self.cls_token.expand(B, -1, -1)
tokens = torch.cat([cls, tokens], dim=1)
tokens = tokens + self.pos_enc[:, :N + 1]
# Transformer
out = self.transformer(tokens)
# CLS β†’ classify
return self.head(out[:, 0])
# ═══════════════════════════════════════════════════════════════
# TRAINING
# ═══════════════════════════════════════════════════════════════
def train(epochs=20, batch_size=128, lr=3e-4, d_model=128,
n_heads=4, n_layers=4, img_size=64, device='cuda'):
device = torch.device(device if torch.cuda.is_available() else 'cpu')
print("\n" + "=" * 70)
print("BASELINE β€” Raw Patch Transformer (No Freckles, No SVD)")
print("=" * 70)
patch_dim = 3 * 4 * 4 # 48
classifier = RawPatchClassifier(
patch_dim=patch_dim, d_model=d_model, n_heads=n_heads,
n_layers=n_layers, n_classes=16,
).to(device)
n_params = sum(p.numel() for p in classifier.parameters())
print(f" Classifier: {n_params:,} params")
print(f" Input: raw 4Γ—4 patches ({patch_dim} dims)")
print(f" Architecture: d_model={d_model}, heads={n_heads}, layers={n_layers}")
print(f" Batch: {batch_size}, lr={lr}, epochs={epochs}")
print(f" No expert. No SVD. No geometric features.")
print("=" * 70)
train_ds = NoiseDataset(size=160000, img_size=img_size)
val_ds = NoiseDataset(size=16000, img_size=img_size)
train_loader = torch.utils.data.DataLoader(
train_ds, batch_size=batch_size, shuffle=True,
num_workers=4, pin_memory=True, drop_last=True)
val_loader = torch.utils.data.DataLoader(
val_ds, batch_size=batch_size, shuffle=False,
num_workers=4, pin_memory=True)
opt = torch.optim.Adam(classifier.parameters(), lr=lr)
sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
best_acc = 0
for epoch in range(1, epochs + 1):
classifier.train()
total_loss, correct, total = 0, 0, 0
t0 = time.time()
pbar = tqdm(train_loader, desc=f"Ep {epoch}/{epochs}",
bar_format='{l_bar}{bar:20}{r_bar}')
for images, labels in pbar:
images = images.to(device)
labels = labels.to(device)
logits = classifier(images)
loss = F.cross_entropy(logits, labels)
opt.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(classifier.parameters(), max_norm=1.0)
opt.step()
total_loss += loss.item() * len(labels)
correct += (logits.argmax(-1) == labels).sum().item()
total += len(labels)
pbar.set_postfix_str(f"loss={loss.item():.4f} acc={correct/total:.1%}")
sched.step()
train_acc = correct / total
train_loss = total_loss / total
# Eval
classifier.eval()
val_correct, val_total = 0, 0
per_class_correct = torch.zeros(16)
per_class_total = torch.zeros(16)
with torch.no_grad():
for images, labels in val_loader:
images = images.to(device)
labels = labels.to(device)
logits = classifier(images)
preds = logits.argmax(-1)
val_correct += (preds == labels).sum().item()
val_total += len(labels)
for c in range(16):
mask = labels == c
per_class_correct[c] += (preds[mask] == labels[mask]).sum().item()
per_class_total[c] += mask.sum().item()
val_acc = val_correct / val_total
epoch_time = time.time() - t0
per_class_acc = per_class_correct / (per_class_total + 1e-8)
worst_class = per_class_acc.argmin().item()
best_class = per_class_acc.argmax().item()
print(f" ep{epoch:3d} | loss={train_loss:.4f} train={train_acc:.1%} "
f"val={val_acc:.1%} | best={NOISE_NAMES[best_class]}={per_class_acc[best_class]:.0%} "
f"worst={NOISE_NAMES[worst_class]}={per_class_acc[worst_class]:.0%} | {epoch_time:.0f}s")
if val_acc > best_acc:
best_acc = val_acc
if epoch % 5 == 0 or epoch == 1:
print(f"\n {'type':<14s} {'acc':>6s}")
print(f" {'-'*22}")
for c in range(16):
bar = 'β–ˆ' * int(per_class_acc[c] * 20)
print(f" {NOISE_NAMES[c]:<14s} {per_class_acc[c]:5.1%} {bar}")
print()
print(f"\n{'=' * 70}")
print(f"BASELINE COMPLETE")
print(f" Best val accuracy: {best_acc:.1%}")
print(f" Params: {n_params:,}")
print(f" Random chance: {1/16:.1%}")
print(f"{'=' * 70}")
return classifier
if __name__ == "__main__":
torch.set_float32_matmul_precision('high')
train(
epochs=20,
batch_size=128,
lr=3e-4,
d_model=128,
n_heads=4,
n_layers=4,
img_size=64,
)