# geolip-SVAE / omega_processor_test_cifar10_noise_model.py
# Author: AbstractPhil — commit a45b0f0 (verified):
#   "Create omega_processor_test_cifar10_noise_model.py"
# File size: 17.4 kB
"""
Omega Processor — CIFAR-10 Image Classification
==================================================
Freckles (trained on NOISE, frozen) → SVD → Geometric Features → Transformer → 10 classes
The ultimate test: can a noise-trained spectral decomposition
produce useful features for real image classification?
CIFAR-10 32×32 → bilinear resize to 64×64 → Freckles → features → classify
Usage:
python omega_processor_test_cifar10_noise_model.py [omega|baseline|both]   (default: both)
"""
import os
import math
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
# Best-effort Hugging Face login when running inside Google Colab.
# Outside Colab the import fails and ambient credentials (if any) are used.
try:
    from google.colab import userdata
    os.environ["HF_TOKEN"] = userdata.get('HF_TOKEN')
    from huggingface_hub import login
    login(token=os.environ["HF_TOKEN"])
except ImportError:
    # Not in Colab (or huggingface_hub is absent): nothing to do here.
    pass
except Exception as err:
    # Keep going (public downloads may still work), but don't hide the reason
    # a later authenticated download might fail.
    print(f"[warn] Hugging Face login skipped: {err}")
# ═══════════════════════════════════════════════════════════════
# GEOMETRIC FEATURE EXTRACTOR (same as omega_processor.py)
# ═══════════════════════════════════════════════════════════════
class GeometricFeatureExtractor(nn.Module):
    """Turn per-patch SVD outputs into a flat geometric feature vector per token.

    Three tiers of features are concatenated, in a fixed order, along the last
    axis (widths in terms of D = number of singular values; 64 total for D=4):

    * Tier 1, scalar (4D): consecutive singular-value ratios, normalised
      energy, effective rank / D, condition number / 10, ``S - S_orig`` and
      consecutive log-gaps.
    * Tier 2, relational (3D + 4): deviation of S from the mean of its four
      reflect-padded grid neighbours, the neighbours' spread, the same
      deviation for the energy map, and sin/cos encodings of row/column.
    * Tier 3, basis (D*D + 2D + 8, assuming U's last axis has size D):
      flattened Vt, mean/std of U over dim 2, and a fixed random 8-dim
      projection of M.

    Args:
        D: number of singular values per patch. Only stored; ``forward`` reads
            the actual size from ``S``.
        V: size of dim 2 of ``M``; sets the shape (V, 8) of the fixed
            ``m_proj`` projection buffer.
    """

    def __init__(self, D=4, V=48):
        super().__init__()
        self.D, self.V = D, V
        # Fixed (non-trainable) random sketch of M; registered as a buffer so
        # it moves with the module and is saved in the state_dict.
        self.register_buffer('m_proj', torch.randn(V, 8) / math.sqrt(V))

    @staticmethod
    def _cross_mean(grid):
        """Mean of the up/down/left/right neighbours of every cell of a 4-D grid.

        The grid is reflect-padded by one cell, so its last two dims must each
        be at least 2.
        """
        padded = F.pad(grid, (1, 1, 1, 1), mode='reflect')
        up = padded[:, :, :-2, 1:-1]
        down = padded[:, :, 2:, 1:-1]
        left = padded[:, :, 1:-1, :-2]
        right = padded[:, :, 1:-1, 2:]
        return (up + down + left + right) / 4

    def forward(self, svd_dict, gh, gw):
        """Compute the concatenated feature tensor.

        Args:
            svd_dict: mapping with keys 'S', 'S_orig' (both (B, N, D)), 'U',
                'Vt' and 'M'. NOTE(review): U and M are presumably
                (B, N, V, D) and Vt (B, N, D, D), inferred from the reductions
                below — confirm against the producing model.
            gh, gw: patch-grid height and width; N must equal gh * gw.

        Returns:
            Tensor of shape (B, N, n_features), tiers ordered as in the class
            docstring.
        """
        eps = 1e-8
        sv = svd_dict['S']
        sv_orig = svd_dict['S_orig']
        basis_u = svd_dict['U']
        basis_vt = svd_dict['Vt']
        mat = svd_dict['M']
        batch, n_tok, rank = sv.shape

        # ---- Tier 1: per-token scalar spectrum statistics ----
        leading, trailing = sv[:, :, :-1], sv[:, :, 1:]
        sq = sv.pow(2)
        energy = sq / (sq.sum(-1, keepdim=True) + eps)
        probs = (sv / (sv.sum(-1, keepdim=True) + eps)).clamp(min=eps)
        entropy = -(probs * probs.log()).sum(-1, keepdim=True)
        tier1 = [
            leading / (trailing + eps),                            # SV ratios
            energy,                                                # energy share
            entropy.exp() / rank,                                  # effective rank / D
            sv[:, :, 0:1] / (sv[:, :, -1:] + eps) / 10.0,          # condition / 10
            sv - sv_orig,                                          # shift vs S_orig
            torch.log(leading + eps) - torch.log(trailing + eps),  # log-gaps
        ]

        # ---- Tier 2: relations to grid neighbours + position ----
        def to_tokens(grid):
            # (B, D, gh, gw) -> (B, N, D)
            return grid.permute(0, 2, 3, 1).reshape(batch, n_tok, rank)

        sv_chw = sv.reshape(batch, gh, gw, rank).permute(0, 3, 1, 2)
        mean_nb = self._cross_mean(sv_chw)
        mean_sq_nb = self._cross_mean(sv_chw.pow(2))
        # clamp guards against tiny negative variances from rounding.
        std_nb = (mean_sq_nb - mean_nb.pow(2)).clamp(min=0).sqrt()
        energy_chw = energy.reshape(batch, gh, gw, rank).permute(0, 3, 1, 2)

        row_pos, col_pos = torch.meshgrid(
            torch.arange(gh, device=sv.device).float() / gh,
            torch.arange(gw, device=sv.device).float() / gw,
            indexing='ij',
        )
        row_pos = row_pos.reshape(1, n_tok, 1).expand(batch, -1, -1)
        col_pos = col_pos.reshape(1, n_tok, 1).expand(batch, -1, -1)

        tier2 = [
            to_tokens(sv_chw - mean_nb),
            to_tokens(std_nb),
            to_tokens(energy_chw - self._cross_mean(energy_chw)),
            torch.sin(row_pos * math.pi),
            torch.cos(col_pos * math.pi),
            torch.sin(row_pos * 2 * math.pi),
            torch.cos(col_pos * 2 * math.pi),
        ]

        # ---- Tier 3: basis summaries ----
        tier3 = [
            basis_vt.reshape(batch, n_tok, rank * rank),
            basis_u.mean(dim=2),
            basis_u.std(dim=2),
            torch.einsum('bnvd,vk->bnk', mat, self.m_proj),
        ]

        return torch.cat(tier1 + tier2 + tier3, dim=-1)
# ═══════════════════════════════════════════════════════════════
# OMEGA TRANSFORMER CLASSIFIER
# ═══════════════════════════════════════════════════════════════
class OmegaTransformerClassifier(nn.Module):
    """Transformer classifier over per-patch geometric features of SVD outputs.

    A ``GeometricFeatureExtractor`` turns the SVD outputs into one
    ``feat_dim``-wide token per patch. Each token is projected to ``d_model``,
    a learned CLS token is prepended, and the CLS token's final state is
    classified. No learned positional embedding is added; position enters
    only through the sin/cos row/column features the extractor computes.

    Args:
        feat_dim: width of each feature token; must equal the extractor's
            output width (the default 64 corresponds to D=4).
        d_model, n_heads, n_layers, dropout: transformer encoder settings.
        n_classes: number of output logits.
        D, V: forwarded to ``GeometricFeatureExtractor``.
    """

    def __init__(self, feat_dim=64, d_model=128, n_heads=4,
                 n_layers=4, n_classes=10, dropout=0.1, D=4, V=48):
        super().__init__()
        self.feat_extractor = GeometricFeatureExtractor(D=D, V=V)
        self.input_proj = nn.Sequential(
            nn.LayerNorm(feat_dim),
            nn.Linear(feat_dim, d_model),
            nn.GELU(),
            nn.Linear(d_model, d_model),
        )
        self.cls_token = nn.Parameter(torch.randn(1, 1, d_model) * 0.02)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads,
            dim_feedforward=d_model * 4,
            dropout=dropout, batch_first=True,
            activation='gelu',
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
        self.head = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model, n_classes),
        )

    def forward(self, svd_dict, gh, gw):
        """Return class logits of shape (B, n_classes).

        ``svd_dict`` is the ``'svd'`` entry of the frozen model's output, and
        (gh, gw) is the patch grid the tokens are laid out on.
        """
        features = self.feat_extractor(svd_dict, gh, gw)
        # Only the batch size is needed. Unpacking the shape into `F` would
        # shadow the module-level `torch.nn.functional` alias.
        batch = features.shape[0]
        tokens = self.input_proj(features)
        cls = self.cls_token.expand(batch, -1, -1)
        tokens = torch.cat([cls, tokens], dim=1)
        out = self.transformer(tokens)
        return self.head(out[:, 0])
# ═══════════════════════════════════════════════════════════════
# RAW PATCH BASELINE
# ═══════════════════════════════════════════════════════════════
class RawPatchClassifier(nn.Module):
    """Baseline ViT-style classifier over raw, non-overlapping image patches.

    Images are cut into ``patch_size`` x ``patch_size`` patches. Each patch is
    flattened to ``C * patch_size**2`` values and projected to ``d_model``. A
    CLS token is prepended, a learned positional embedding is added, and the
    CLS token's final state is classified.

    Args:
        patch_dim: flattened patch length; must equal C * patch_size**2
            (48 = 3 * 4 * 4 for RGB with the default patch size).
        d_model, n_heads, n_layers, dropout: transformer encoder settings.
        n_classes: number of output logits.
        n_patches: maximum patches per image (size of the positional table).
        patch_size: side length of the square patches (default 4).
    """

    def __init__(self, patch_dim=48, d_model=128, n_heads=4,
                 n_layers=4, n_classes=10, dropout=0.1, n_patches=256,
                 patch_size=4):
        super().__init__()
        self.patch_size = patch_size
        self.input_proj = nn.Sequential(
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, d_model),
            nn.GELU(),
            nn.Linear(d_model, d_model),
        )
        # One extra slot for the CLS token at index 0.
        self.pos_enc = nn.Parameter(torch.randn(1, n_patches + 1, d_model) * 0.02)
        self.cls_token = nn.Parameter(torch.randn(1, 1, d_model) * 0.02)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads,
            dim_feedforward=d_model * 4,
            dropout=dropout, batch_first=True,
            activation='gelu',
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
        self.head = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model, n_classes),
        )

    def forward(self, images):
        """Return class logits of shape (B, n_classes) for (B, C, H, W) images.

        Raises:
            ValueError: if H or W is not divisible by ``patch_size``, or the
                image yields more patches than the positional table holds.
        """
        B, C, H, W = images.shape
        ps = self.patch_size
        if H % ps or W % ps:
            raise ValueError(f"image size {H}x{W} is not divisible by patch_size={ps}")
        gh, gw = H // ps, W // ps
        N = gh * gw
        max_patches = self.pos_enc.shape[1] - 1
        if N > max_patches:
            raise ValueError(f"{N} patches exceed the positional table size ({max_patches})")
        # (B, C, gh, ps, gw, ps) -> (B, gh, gw, C, ps, ps) -> (B, N, C*ps*ps)
        patches = images.reshape(B, C, gh, ps, gw, ps)
        patches = patches.permute(0, 2, 4, 1, 3, 5).reshape(B, N, C * ps * ps)
        tokens = self.input_proj(patches)
        cls = self.cls_token.expand(B, -1, -1)
        tokens = torch.cat([cls, tokens], dim=1)
        tokens = tokens + self.pos_enc[:, :N + 1]
        out = self.transformer(tokens)
        return self.head(out[:, 0])
# ═══════════════════════════════════════════════════════════════
# CIFAR-10 DATASET
# ═══════════════════════════════════════════════════════════════
# Class names, indexed by CIFAR-10 label id (0-9).
CIFAR_CLASSES = [
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
]
# Per-channel (R, G, B) normalisation statistics used by the data loaders.
IMG_MEAN = (0.4914, 0.4822, 0.4465)
IMG_STD = (0.2470, 0.2435, 0.2616)
def get_cifar10_loaders(batch_size=128, img_size=64, root='/content/data',
                        num_workers=4):
    """Load CIFAR-10, resize to img_size, normalize.

    Both splits are bilinearly resized with ``T.Resize(img_size)`` and
    normalised with IMG_MEAN / IMG_STD. The training split also gets random
    horizontal flips. The dataset is downloaded into ``root`` if missing.

    Args:
        batch_size: samples per batch for both loaders.
        img_size: target size passed to ``T.Resize`` (square 32x32 CIFAR
            images become img_size x img_size).
        root: dataset directory (default: the Colab path used previously).
        num_workers: DataLoader worker processes per loader.

    Returns:
        (train_loader, test_loader). The train loader shuffles and drops the
        last incomplete batch; the test loader keeps every sample, in order.
    """
    import torchvision
    import torchvision.transforms as T

    resize = T.Resize(img_size, interpolation=T.InterpolationMode.BILINEAR)
    to_normalized_tensor = [T.ToTensor(), T.Normalize(IMG_MEAN, IMG_STD)]
    transform_train = T.Compose([resize, T.RandomHorizontalFlip(), *to_normalized_tensor])
    transform_test = T.Compose([resize, *to_normalized_tensor])

    train_ds = torchvision.datasets.CIFAR10(
        root=root, train=True, download=True, transform=transform_train)
    test_ds = torchvision.datasets.CIFAR10(
        root=root, train=False, download=True, transform=transform_test)

    # Pinned host memory only helps (and only avoids a warning) with a GPU.
    pin = torch.cuda.is_available()
    train_loader = torch.utils.data.DataLoader(
        train_ds, batch_size=batch_size, shuffle=True,
        num_workers=num_workers, pin_memory=pin, drop_last=True)
    test_loader = torch.utils.data.DataLoader(
        test_ds, batch_size=batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=pin)
    return train_loader, test_loader
# ═══════════════════════════════════════════════════════════════
# TRAINING
# ═══════════════════════════════════════════════════════════════
def _classify_batch(classifier, freckles, images, gh, gw):
    """Return logits for one batch, routing through frozen Freckles when given.

    With ``freckles`` None (baseline mode) the classifier sees raw images.
    Otherwise Freckles runs under ``torch.no_grad()`` so only the classifier
    receives gradients, and its ``'svd'`` output is classified.
    """
    if freckles is None:
        return classifier(images)
    with torch.no_grad():
        out = freckles(images)
    return classifier(out['svd'], gh, gw)


def _evaluate(classifier, freckles, loader, device, gh, gw, n_classes=10):
    """Return (overall accuracy, per-class accuracy tensor) on ``loader``."""
    classifier.eval()
    correct, seen = 0, 0
    per_class_correct = torch.zeros(n_classes)
    per_class_total = torch.zeros(n_classes)
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)
            preds = _classify_batch(classifier, freckles, images, gh, gw).argmax(-1)
            hits = preds == labels
            correct += hits.sum().item()
            seen += len(labels)
            # bincount replaces a Python loop over classes (and its per-class
            # host syncs) with one vectorised count per batch.
            per_class_total += torch.bincount(labels, minlength=n_classes).cpu()
            per_class_correct += torch.bincount(labels[hits], minlength=n_classes).cpu()
    return correct / seen, per_class_correct / (per_class_total + 1e-8)


def _print_per_class(per_class_acc):
    """Print a per-class accuracy table with a 20-character bar per class."""
    print(f"\n {'class':<14s} {'acc':>6s}")
    print(f" {'-'*22}")
    for name, acc in zip(CIFAR_CLASSES, per_class_acc):
        bar = '█' * int(acc * 20)
        print(f" {name:<14s} {acc:5.1%} {bar}")
    print()


def train_model(mode='omega', epochs=30, batch_size=128, lr=3e-4,
                d_model=128, n_heads=4, n_layers=4, img_size=64,
                device='cuda'):
    """
    Train a CIFAR-10 transformer classifier and report test accuracy per epoch.

    mode: 'omega' (Freckles + features) or 'baseline' (raw patches)

    In 'omega' mode the pretrained Freckles model ('v40_freckles_noise') is
    loaded, frozen, and used as a fixed feature source. Only the classifier is
    optimised (Adam + cosine schedule, grad-norm clipping at 1.0).

    Args:
        epochs, batch_size, lr: optimisation settings.
        d_model, n_heads, n_layers: classifier transformer size.
        img_size: CIFAR images are resized to img_size x img_size.
        device: preferred device; falls back to CPU when CUDA is unavailable.

    Returns:
        (classifier, best_acc). ``classifier`` holds the FINAL-epoch weights.
        ``best_acc`` is the highest per-epoch test accuracy, so the returned
        weights need not be the ones that reached it. Because it is selected
        on the test split, it is an optimistic estimate.

    Raises:
        ValueError: if ``mode`` is not 'omega' or 'baseline'.
    """
    if mode not in ('omega', 'baseline'):
        raise ValueError(f"mode must be 'omega' or 'baseline', got {mode!r}")
    device = torch.device(device if torch.cuda.is_available() else 'cpu')

    print("\n" + "=" * 70)
    if mode == 'omega':
        print("OMEGA PROCESSOR — CIFAR-10 (Freckles features)")
    else:
        print("BASELINE — CIFAR-10 (Raw patches, no Freckles)")
    print("=" * 70)

    # One token per 4x4 patch. NOTE(review): in omega mode this assumes
    # Freckles also uses 4x4 patches so its token count equals gh * gw.
    ps = 4
    gh, gw = img_size // ps, img_size // ps
    n_patches = gh * gw

    freckles = None
    if mode == 'omega':
        from geolip_svae import load_model
        freckles, f_cfg = load_model(hf_version='v40_freckles_noise', device=device)
        freckles.eval()
        for p in freckles.parameters():
            p.requires_grad = False
        print(f" Freckles: {sum(p.numel() for p in freckles.parameters()):,} params (frozen)")
        # Determine the feature width with a single dummy forward pass.
        with torch.no_grad():
            dummy = torch.randn(1, 3, img_size, img_size).to(device)
            dummy_out = freckles(dummy)
            feat_ext = GeometricFeatureExtractor(D=f_cfg['D'], V=f_cfg['V']).to(device)
            feat_dim = feat_ext(dummy_out['svd'], gh, gw).shape[-1]
            del feat_ext
        print(f" Feature dim: {feat_dim}")
        classifier = OmegaTransformerClassifier(
            feat_dim=feat_dim, d_model=d_model, n_heads=n_heads,
            n_layers=n_layers, n_classes=10, D=f_cfg['D'], V=f_cfg['V'],
        ).to(device)
    else:
        classifier = RawPatchClassifier(
            patch_dim=3 * ps * ps, d_model=d_model, n_heads=n_heads,
            n_layers=n_layers, n_classes=10, n_patches=n_patches,
        ).to(device)

    n_params = sum(p.numel() for p in classifier.parameters() if p.requires_grad)
    print(f" Classifier: {n_params:,} params")
    print(f" Architecture: d_model={d_model}, heads={n_heads}, layers={n_layers}")
    print(f" CIFAR-10: 50K train, 10K test, {img_size}×{img_size}")
    print(f" Batch: {batch_size}, lr={lr}, epochs={epochs}")
    print("=" * 70)

    train_loader, test_loader = get_cifar10_loaders(batch_size, img_size)
    opt = torch.optim.Adam(classifier.parameters(), lr=lr)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)

    best_acc = 0
    for epoch in range(1, epochs + 1):
        classifier.train()
        total_loss, correct, total = 0, 0, 0
        t0 = time.time()
        pbar = tqdm(train_loader, desc=f"Ep {epoch}/{epochs}",
                    bar_format='{l_bar}{bar:20}{r_bar}')
        for images, labels in pbar:
            images = images.to(device)
            labels = labels.to(device)
            logits = _classify_batch(classifier, freckles, images, gh, gw)
            loss = F.cross_entropy(logits, labels)
            opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(classifier.parameters(), max_norm=1.0)
            opt.step()
            total_loss += loss.item() * len(labels)
            correct += (logits.argmax(-1) == labels).sum().item()
            total += len(labels)
            pbar.set_postfix_str(f"loss={loss.item():.4f} acc={correct/total:.1%}")
        sched.step()
        train_acc = correct / total
        train_loss = total_loss / total

        test_acc, per_class_acc = _evaluate(
            classifier, freckles, test_loader, device, gh, gw)
        epoch_time = time.time() - t0  # includes evaluation time
        worst_class = per_class_acc.argmin().item()
        best_class = per_class_acc.argmax().item()
        print(f" ep{epoch:3d} | loss={train_loss:.4f} train={train_acc:.1%} "
              f"test={test_acc:.1%} | best={CIFAR_CLASSES[best_class]}={per_class_acc[best_class]:.0%} "
              f"worst={CIFAR_CLASSES[worst_class]}={per_class_acc[worst_class]:.0%} | {epoch_time:.0f}s")
        if test_acc > best_acc:
            best_acc = test_acc
        if epoch % 5 == 0 or epoch == 1 or epoch == epochs:
            _print_per_class(per_class_acc)

    tag = "OMEGA PROCESSOR" if mode == 'omega' else "BASELINE"
    print(f"\n{'=' * 70}")
    print(f"{tag} COMPLETE")
    print(f" Best test accuracy: {best_acc:.1%}")
    print(f" Classifier params: {n_params:,}")
    print(" Random chance: 10.0%")
    print(f"{'=' * 70}")
    return classifier, best_acc
if __name__ == "__main__":
import sys
torch.set_float32_matmul_precision('high')
MODE = 'both' # 'omega', 'baseline', or 'both'
if len(sys.argv) > 1:
MODE = sys.argv[1]
results = {}
if MODE in ('omega', 'both'):
_, omega_acc = train_model(
mode='omega', epochs=30, batch_size=128,
lr=3e-4, d_model=128, n_heads=4, n_layers=4)
results['omega'] = omega_acc
if MODE in ('baseline', 'both'):
_, base_acc = train_model(
mode='baseline', epochs=30, batch_size=128,
lr=3e-4, d_model=128, n_heads=4, n_layers=4)
results['baseline'] = base_acc
if len(results) == 2:
print("\n" + "=" * 70)
print("HEAD-TO-HEAD COMPARISON")
print("=" * 70)
print(f" Omega Processor (Freckles features): {results['omega']:.1%}")
print(f" Baseline (raw patches): {results['baseline']:.1%}")
print(f" Delta: {results['omega'] - results['baseline']:+.1%}")
print(f" Random chance: 10.0%")
print("=" * 70)