| """ |
| Omega Processor v2 — CIFAR-10 with Encoder Hidden States |
| ========================================================== |
| Freckles (frozen) → grab encoder hidden states (384-dim) |
| + SVD geometric features (64-dim) |
| = 448-dim per patch → Transformer → classify |
| |
| The encoder hidden state is the FULL pre-bottleneck representation. |
| The geometric features are the post-bottleneck spectral structure. |
| Together: understanding + structure. |
| |
| Tests show that compressing the information produced here AT ALL completely destroys it. |
| v3 MUST therefore keep it unabridged. |
| |
| |
| Usage: |
| python omega_cifar10_v2.py |
| """ |
|
|
| import os |
| import math |
| import time |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import numpy as np |
| from tqdm import tqdm |
|
|
# Best-effort Hugging Face authentication for Google Colab sessions.
# Outside Colab the `google.colab` import fails and the login is skipped
# silently. Any other failure (missing secret, rejected token, network error)
# is reported as a warning instead of being swallowed, but never aborts the
# script.
try:
    from google.colab import userdata
except ImportError:
    # Not running inside Colab: nothing to do.
    pass
else:
    try:
        _hf_token = userdata.get('HF_TOKEN')
        if not _hf_token:
            raise ValueError("Colab secret 'HF_TOKEN' is empty")
        os.environ["HF_TOKEN"] = _hf_token
        from huggingface_hub import login
        login(token=os.environ["HF_TOKEN"])
    except Exception as exc:  # deliberate best-effort boundary
        print(f"[warn] Hugging Face login skipped: {type(exc).__name__}: {exc}")
|
|
|
|
| |
| |
| |
|
|
class FrecklesWithHidden:
    """Wrap a frozen Freckles model and capture its last encoder block output.

    A forward hook is registered on ``freckles.enc_blocks[-1]``. Every call
    runs the wrapped model under ``torch.no_grad()`` and returns the model
    output together with the detached activation seen by the hook.

    The instance can be used as a context manager so the hook is always
    released::

        with FrecklesWithHidden(freckles) as wrapper:
            out, hidden = wrapper(images)
    """

    def __init__(self, freckles):
        """Store ``freckles`` and attach the capture hook to its last encoder block."""
        self.model = freckles
        self._hidden = None  # activation captured during the current forward pass
        self._hook = None    # handle returned by register_forward_hook
        self._attach()

    def _attach(self):
        """Register the forward hook on the last entry of ``model.enc_blocks``."""
        last_block = self.model.enc_blocks[-1]

        def hook(module, inp, out):
            # Detach so the stored activation never keeps an autograd graph alive.
            self._hidden = out.detach()

        self._hook = last_block.register_forward_hook(hook)

    @torch.no_grad()
    def __call__(self, images):
        """Run the wrapped model and return ``(out, hidden)``.

        Args:
            images: batch passed unchanged to the wrapped model; only
                ``images.shape[0]`` (the batch size ``B``) is read here.

        Returns:
            A tuple ``(out, hidden)``. ``out`` is whatever the wrapped model
            returns; it must provide ``out['svd']['S']`` whose second
            dimension is used as the patch count ``N``. ``hidden`` is the
            captured activation reshaped to ``(B, N, -1)``.

        Raises:
            RuntimeError: if the hook did not fire during the forward pass
                (for example after ``remove()`` was called).
        """
        self._hidden = None
        out = self.model(images)

        # Take ownership of the capture and drop the wrapper's reference so the
        # activation is not kept alive between batches.
        hidden, self._hidden = self._hidden, None
        if hidden is None:
            raise RuntimeError(
                "FrecklesWithHidden: the forward hook on enc_blocks[-1] did not "
                "fire; was remove() called, or does the model skip that block?"
            )

        B = images.shape[0]
        N = out['svd']['S'].shape[1]
        return out, hidden.reshape(B, N, -1)

    def remove(self):
        """Detach the forward hook. Safe to call more than once."""
        if self._hook is not None:
            self._hook.remove()
            self._hook = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.remove()
        return False
|
|
|
|
| |
| |
| |
|
|
class GeometricFeatureExtractor(nn.Module):
    """Turn an SVD dictionary into one flat feature vector per patch.

    The features are concatenated in three groups: spectral statistics of the
    singular values ``S``, spatial contrasts of ``S`` and its energy over the
    ``gh x gw`` patch grid (plus sin/cos position terms), and descriptors of
    the ``Vt``, ``U`` and ``M`` factors. ``M`` is reduced with a fixed random
    projection stored in the ``m_proj`` buffer (``V x 8``).
    """

    def __init__(self, D=4, V=48):
        super().__init__()
        self.D = D
        self.V = V
        projection = torch.randn(V, 8) / math.sqrt(V)
        self.register_buffer('m_proj', projection)

    @staticmethod
    def _cross_neighbours(grid):
        """Return the previous-row, next-row, previous-column and next-column
        neighbours of every cell of a ``(B, C, H, W)`` grid (reflect padding)."""
        halo = F.pad(grid, (1, 1, 1, 1), mode='reflect')
        prev_row = halo[:, :, :-2, 1:-1]
        next_row = halo[:, :, 2:, 1:-1]
        prev_col = halo[:, :, 1:-1, :-2]
        next_col = halo[:, :, 1:-1, 2:]
        return prev_row, next_row, prev_col, next_col

    def forward(self, svd_dict, gh, gw):
        """Build the per-patch feature tensor.

        Args:
            svd_dict: mapping with keys ``'S'``, ``'S_orig'``, ``'U'``,
                ``'Vt'`` and ``'M'``. ``S`` is ``(B, N, D)``; ``M`` is indexed
                as ``(B, N, V, D)`` by the einsum below. ``U`` is reduced over
                dim 2 (presumably the same ``V`` axis; confirm against the
                Freckles model).
            gh, gw: patch-grid height and width, with ``gh * gw == N``.

        Returns:
            Tensor of shape ``(B, N, F)`` with all features concatenated on
            the last axis.
        """
        S, S_orig = svd_dict['S'], svd_dict['S_orig']
        U, Vt, M = svd_dict['U'], svd_dict['Vt'], svd_dict['M']
        B, N, D = S.shape
        eps = 1e-8

        def as_grid(tokens):
            # (B, N, D) -> (B, D, gh, gw)
            return tokens.reshape(B, gh, gw, D).permute(0, 3, 1, 2)

        def as_tokens(grid):
            # (B, D, gh, gw) -> (B, N, D)
            return grid.permute(0, 2, 3, 1).reshape(B, N, D)

        # --- spectral statistics of the singular values -------------------
        leading, trailing = S[:, :, :-1], S[:, :, 1:]
        squared = S.pow(2)
        energy = squared / (squared.sum(-1, keepdim=True) + eps)
        probs = (S / (S.sum(-1, keepdim=True) + eps)).clamp(min=eps)
        entropy = -(probs * probs.log()).sum(-1, keepdim=True)

        spectral = [
            leading / (trailing + eps),                            # consecutive ratios
            energy,                                                # normalised energy
            entropy.exp() / D,                                     # exp(entropy) / D
            S[:, :, 0:1] / (S[:, :, -1:] + eps) / 10.0,            # first/last ratio, scaled
            S - S_orig,                                            # shift from S_orig
            torch.log(leading + eps) - torch.log(trailing + eps),  # log gaps
        ]

        # --- spatial contrasts over the patch grid --------------------------
        s_grid = as_grid(S)
        s_nb = self._cross_neighbours(s_grid)
        s_nb_mean = (s_nb[0] + s_nb[1] + s_nb[2] + s_nb[3]) / 4
        s_nb_sq_mean = (s_nb[0].pow(2) + s_nb[1].pow(2) +
                        s_nb[2].pow(2) + s_nb[3].pow(2)) / 4
        # E[x^2] - E[x]^2, clamped because rounding can push it below zero.
        s_nb_std = (s_nb_sq_mean - s_nb_mean.pow(2)).clamp(min=0).sqrt()

        e_grid = as_grid(energy)
        e_nb = self._cross_neighbours(e_grid)
        e_nb_mean = (e_nb[0] + e_nb[1] + e_nb[2] + e_nb[3]) / 4

        rows = torch.arange(gh, device=S.device).float() / gh
        cols = torch.arange(gw, device=S.device).float() / gw
        row_pos = rows.unsqueeze(1).expand(gh, gw).reshape(1, N, 1).expand(B, -1, -1)
        col_pos = cols.unsqueeze(0).expand(gh, gw).reshape(1, N, 1).expand(B, -1, -1)

        spatial = [
            as_tokens(s_grid - s_nb_mean),
            as_tokens(s_nb_std),
            as_tokens(e_grid - e_nb_mean),
            torch.sin(row_pos * math.pi),
            torch.cos(col_pos * math.pi),
            torch.sin(row_pos * 2 * math.pi),
            torch.cos(col_pos * 2 * math.pi),
        ]

        # --- descriptors of the Vt, U and M factors -------------------------
        factors = [
            Vt.reshape(B, N, D * D),
            U.mean(dim=2),
            U.std(dim=2),
            torch.einsum('bnvd,vk->bnk', M, self.m_proj),
        ]

        return torch.cat(spectral + spatial + factors, dim=-1)
|
|
|
|
| |
| |
| |
|
|
class HierarchicalOmegaClassifier(nn.Module):
    """Transformer classifier fed by two per-patch input streams.

    Stream A: encoder hidden states (``hidden_dim`` wide).
    Stream B: geometric features (``geo_dim`` wide) computed from the SVD
        dictionary by the embedded ``GeometricFeatureExtractor``.

    Each stream has its own projection to ``d_model``. The two projections
    are mixed by a learned element-wise gate,
    ``alpha * hidden_proj + (1 - alpha) * geo_proj``. A CLS token is then
    prepended, the sequence goes through a transformer encoder, and the CLS
    output is classified by an MLP head.
    """

    def __init__(self, hidden_dim=384, geo_dim=64, d_model=128, n_heads=4,
                 n_layers=4, n_classes=10, dropout=0.1, D=4, V=48):
        super().__init__()
        # Submodules are created in a fixed order so that seeded
        # initialisation and state_dict keys stay stable.
        self.feat_extractor = GeometricFeatureExtractor(D=D, V=V)

        # One LayerNorm + 2-layer MLP projection per stream.
        self.hidden_proj = self._stream_projection(hidden_dim, d_model)
        self.geo_proj = self._stream_projection(geo_dim, d_model)

        # Gate: sigmoid over a linear map of both projected streams.
        self.gate = nn.Sequential(
            nn.Linear(d_model * 2, d_model),
            nn.Sigmoid(),
        )

        self.cls_token = nn.Parameter(torch.randn(1, 1, d_model) * 0.02)

        block = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=d_model * 4,
            dropout=dropout,
            activation='gelu',
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(block, num_layers=n_layers)

        self.head = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model, n_classes),
        )

    @staticmethod
    def _stream_projection(in_dim, d_model):
        """LayerNorm -> Linear -> GELU -> Linear, mapping ``in_dim`` to ``d_model``."""
        return nn.Sequential(
            nn.LayerNorm(in_dim),
            nn.Linear(in_dim, d_model),
            nn.GELU(),
            nn.Linear(d_model, d_model),
        )

    def forward(self, svd_dict, hidden, gh, gw):
        """Classify a batch.

        Args:
            svd_dict: SVD dictionary produced by the frozen Freckles model.
            hidden: ``(B, N, hidden_dim)`` encoder hidden states.
            gh, gw: patch-grid dimensions, forwarded to the feature extractor.

        Returns:
            ``(B, n_classes)`` logits taken from the CLS position.
        """
        stream_a = self.hidden_proj(hidden)
        stream_b = self.geo_proj(self.feat_extractor(svd_dict, gh, gw))

        alpha = self.gate(torch.cat([stream_a, stream_b], dim=-1))
        fused = alpha * stream_a + (1 - alpha) * stream_b

        cls = self.cls_token.expand(fused.shape[0], -1, -1)
        encoded = self.transformer(torch.cat([cls, fused], dim=1))
        return self.head(encoded[:, 0])
|
|
|
|
| |
| |
| |
|
|
class RawPatchClassifier(nn.Module):
    """Baseline transformer classifier operating directly on raw image patches.

    Images are cut into non-overlapping ``patch_size x patch_size`` patches,
    each flattened to ``C * patch_size**2`` values and projected to
    ``d_model``. A CLS token and a learned positional encoding are added, the
    sequence goes through a transformer encoder, and the CLS output is
    classified by an MLP head.

    Args:
        patch_dim: flattened patch length; must equal ``C * patch_size**2``.
        d_model, n_heads, n_layers, dropout: transformer hyper-parameters.
        n_classes: number of output logits.
        n_patches: maximum number of patches per image (sizes ``pos_enc``).
        patch_size: side length of a patch in pixels (default 4, the value
            that used to be hard-coded in ``forward``).
    """

    def __init__(self, patch_dim=48, d_model=128, n_heads=4,
                 n_layers=4, n_classes=10, dropout=0.1, n_patches=256,
                 patch_size=4):
        super().__init__()
        self.patch_size = patch_size
        self.n_patches = n_patches
        self.input_proj = nn.Sequential(
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, d_model),
            nn.GELU(),
            nn.Linear(d_model, d_model),
        )
        # One positional slot per patch plus one for the CLS token.
        self.pos_enc = nn.Parameter(torch.randn(1, n_patches + 1, d_model) * 0.02)
        self.cls_token = nn.Parameter(torch.randn(1, 1, d_model) * 0.02)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads,
            dim_feedforward=d_model * 4,
            dropout=dropout, batch_first=True, activation='gelu')
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
        self.head = nn.Sequential(
            nn.LayerNorm(d_model), nn.Linear(d_model, d_model),
            nn.GELU(), nn.Dropout(dropout), nn.Linear(d_model, n_classes))

    def _patchify(self, images):
        """Cut ``(B, C, H, W)`` images into ``(B, N, C * ps * ps)`` row-major patches."""
        B, C, H, W = images.shape
        ps = self.patch_size
        if H % ps or W % ps:
            raise ValueError(
                f"image size {H}x{W} is not divisible by patch_size={ps}")
        gh, gw = H // ps, W // ps
        patches = images.reshape(B, C, gh, ps, gw, ps).permute(0, 2, 4, 1, 3, 5)
        return patches.reshape(B, gh * gw, C * ps * ps)

    def forward(self, images):
        """Return ``(B, n_classes)`` logits for a ``(B, C, H, W)`` image batch.

        Raises:
            ValueError: if ``H``/``W`` are not multiples of ``patch_size`` or
                the image yields more patches than ``n_patches``.
        """
        patches = self._patchify(images)
        B, N = patches.shape[0], patches.shape[1]
        if N > self.n_patches:
            raise ValueError(
                f"got {N} patches but the positional encoding only covers "
                f"n_patches={self.n_patches}")
        tokens = self.input_proj(patches)
        cls = self.cls_token.expand(B, -1, -1)
        # Slice so smaller images reuse the first N + 1 positional slots.
        tokens = torch.cat([cls, tokens], dim=1) + self.pos_enc[:, :N + 1]
        return self.head(self.transformer(tokens)[:, 0])
|
|
|
|
| |
| |
| |
|
|
# CIFAR-10 class names; position i is the name used for integer label i
# in the per-class accuracy report.
CIFAR_CLASSES = [
    'airplane',
    'automobile',
    'bird',
    'cat',
    'deer',
    'dog',
    'frog',
    'horse',
    'ship',
    'truck',
]
|
|
def get_cifar10_loaders(batch_size=128, img_size=64, root='/content/data',
                        num_workers=4):
    """Build the CIFAR-10 train and test ``DataLoader`` objects.

    Both splits are resized to ``img_size`` (bilinear), converted to tensors
    and normalised per channel. The training split additionally applies a
    random horizontal flip, is shuffled, and drops its last incomplete batch.

    Args:
        batch_size: batch size for both loaders.
        img_size: target size passed to ``Resize``.
        root: dataset directory; downloaded there if missing. Defaults to the
            Colab path used previously.
        num_workers: worker processes per loader.

    Returns:
        ``(train_loader, test_loader)``.
    """
    import torchvision
    import torchvision.transforms as T

    mean = (0.4914, 0.4822, 0.4465)
    std = (0.2470, 0.2435, 0.2616)

    def build_transform(train):
        steps = [T.Resize(img_size, interpolation=T.InterpolationMode.BILINEAR)]
        if train:
            steps.append(T.RandomHorizontalFlip())
        steps += [T.ToTensor(), T.Normalize(mean, std)]
        return T.Compose(steps)

    train_ds = torchvision.datasets.CIFAR10(
        root=root, train=True, download=True, transform=build_transform(True))
    test_ds = torchvision.datasets.CIFAR10(
        root=root, train=False, download=True, transform=build_transform(False))

    # Pinned memory only helps host-to-GPU copies; skip it on CPU-only machines.
    common = dict(batch_size=batch_size, num_workers=num_workers,
                  pin_memory=torch.cuda.is_available())
    train_loader = torch.utils.data.DataLoader(
        train_ds, shuffle=True, drop_last=True, **common)
    test_loader = torch.utils.data.DataLoader(
        test_ds, shuffle=False, **common)
    return train_loader, test_loader
|
|
|
|
| |
| |
| |
|
|
def _classify_batch(mode, classifier, freckles_wrapper, images, gh, gw):
    """Return logits for one batch using the input pipeline selected by ``mode``."""
    if mode == 'omega':
        out, hidden = freckles_wrapper(images)
        return classifier(out['svd'], hidden, gh, gw)
    return classifier(images)


def _evaluate(mode, classifier, freckles_wrapper, loader, device, gh, gw, n_classes):
    """Return ``(accuracy, per_class_correct, per_class_total)`` over ``loader``."""
    classifier.eval()
    correct, total = 0, 0
    per_class_correct = torch.zeros(n_classes)
    per_class_total = torch.zeros(n_classes)

    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)
            preds = _classify_batch(
                mode, classifier, freckles_wrapper, images, gh, gw).argmax(-1)
            hits = preds == labels
            correct += hits.sum().item()
            total += len(labels)
            # One bincount per tensor instead of a Python loop over classes:
            # two device syncs per batch rather than two per class.
            per_class_total += torch.bincount(labels, minlength=n_classes).cpu()
            per_class_correct += torch.bincount(labels[hits], minlength=n_classes).cpu()

    return correct / total, per_class_correct, per_class_total


def train_model(mode='omega', epochs=30, batch_size=128, lr=3e-4,
                d_model=128, n_heads=4, n_layers=4, img_size=64,
                device='cuda'):
    """Train and evaluate a CIFAR-10 classifier.

    Args:
        mode: ``'omega'`` trains ``HierarchicalOmegaClassifier`` on features
            from the frozen Freckles model; any other value trains the
            ``RawPatchClassifier`` baseline.
        epochs, batch_size, lr: optimisation settings (Adam + cosine schedule).
        d_model, n_heads, n_layers: transformer size.
        img_size: side length the CIFAR images are resized to.
        device: requested device; falls back to CPU when CUDA is unavailable.

    Returns:
        ``(classifier, best_acc)`` where ``best_acc`` is the highest test
        accuracy observed after any epoch.
    """
    device = torch.device(device if torch.cuda.is_available() else 'cpu')
    ps = 4  # NOTE(review): assumed to match the Freckles patch size; confirm
    gh, gw = img_size // ps, img_size // ps
    n_classes = len(CIFAR_CLASSES)

    # NOTE(review): several banner strings below contain characters that look
    # mis-encoded; they are kept byte-for-byte. Confirm the intended glyphs.
    print("\n" + "=" * 70)
    if mode == 'omega':
        print("OMEGA PROCESSOR v2 β CIFAR-10 (Hidden + Geometric features)")
    else:
        print("BASELINE β CIFAR-10 (Raw patches, no Freckles)")
    print("=" * 70)

    freckles_wrapper = None
    try:
        if mode == 'omega':
            from geolip_svae import load_model
            freckles, f_cfg = load_model(hf_version='v40_freckles_noise', device=device)
            freckles.eval()
            for p in freckles.parameters():
                p.requires_grad = False
            freckles_wrapper = FrecklesWithHidden(freckles)
            print(f" Freckles: {sum(p.numel() for p in freckles.parameters()):,} params (frozen)")

            # Probe with a dummy image to discover the feature widths.
            with torch.no_grad():
                dummy = torch.randn(1, 3, img_size, img_size).to(device)
                dummy_out, dummy_hidden = freckles_wrapper(dummy)
                feat_ext = GeometricFeatureExtractor(D=f_cfg['D'], V=f_cfg['V']).to(device)
                geo_dim = feat_ext(dummy_out['svd'], gh, gw).shape[-1]
                hidden_dim = dummy_hidden.shape[-1]
                del feat_ext
            print(f" Encoder hidden dim: {hidden_dim}")
            print(f" Geometric feature dim: {geo_dim}")
            print(f" Combined: {hidden_dim} + {geo_dim} = {hidden_dim + geo_dim} per patch")

            classifier = HierarchicalOmegaClassifier(
                hidden_dim=hidden_dim, geo_dim=geo_dim,
                d_model=d_model, n_heads=n_heads, n_layers=n_layers,
                n_classes=n_classes, D=f_cfg['D'], V=f_cfg['V'],
            ).to(device)
        else:
            classifier = RawPatchClassifier(
                patch_dim=3 * ps * ps, d_model=d_model, n_heads=n_heads,
                n_layers=n_layers, n_classes=n_classes, n_patches=gh * gw,
            ).to(device)

        n_params = sum(p.numel() for p in classifier.parameters() if p.requires_grad)
        print(f" Classifier: {n_params:,} params")
        print(f" Architecture: d_model={d_model}, heads={n_heads}, layers={n_layers}")
        print(f" CIFAR-10: 50K train, 10K test, {img_size}Γ{img_size}")
        print(f" Batch: {batch_size}, lr={lr}, epochs={epochs}")
        print("=" * 70)

        train_loader, test_loader = get_cifar10_loaders(batch_size, img_size)
        opt = torch.optim.Adam(classifier.parameters(), lr=lr)
        sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)

        best_acc = 0

        for epoch in range(1, epochs + 1):
            classifier.train()
            total_loss, correct, total = 0, 0, 0
            t0 = time.time()

            pbar = tqdm(train_loader, desc=f"Ep {epoch}/{epochs}",
                        bar_format='{l_bar}{bar:20}{r_bar}')
            for images, labels in pbar:
                images = images.to(device)
                labels = labels.to(device)

                logits = _classify_batch(
                    mode, classifier, freckles_wrapper, images, gh, gw)

                loss = F.cross_entropy(logits, labels, label_smoothing=0.1)
                opt.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(classifier.parameters(), max_norm=1.0)
                opt.step()

                loss_value = loss.item()  # single device sync per step
                total_loss += loss_value * len(labels)
                correct += (logits.argmax(-1) == labels).sum().item()
                total += len(labels)
                pbar.set_postfix_str(f"loss={loss_value:.4f} acc={correct/total:.1%}")

            sched.step()
            train_acc = correct / total
            train_loss = total_loss / total

            test_acc, per_class_correct, per_class_total = _evaluate(
                mode, classifier, freckles_wrapper, test_loader, device,
                gh, gw, n_classes)

            epoch_time = time.time() - t0
            per_class_acc = per_class_correct / (per_class_total + 1e-8)
            worst_class = per_class_acc.argmin().item()
            best_class = per_class_acc.argmax().item()

            print(f" ep{epoch:3d} | loss={train_loss:.4f} train={train_acc:.1%} "
                  f"test={test_acc:.1%} | best={CIFAR_CLASSES[best_class]}={per_class_acc[best_class]:.0%} "
                  f"worst={CIFAR_CLASSES[worst_class]}={per_class_acc[worst_class]:.0%} | {epoch_time:.0f}s")

            if test_acc > best_acc:
                best_acc = test_acc

            if epoch % 5 == 0 or epoch == 1 or epoch == epochs:
                print(f"\n {'class':<14s} {'acc':>6s}")
                print(f" {'-'*22}")
                for c in range(n_classes):
                    bar = 'β' * int(per_class_acc[c] * 20)
                    print(f" {CIFAR_CLASSES[c]:<14s} {per_class_acc[c]:5.1%} {bar}")
                print()
    finally:
        # Always release the forward hook installed on the frozen model.
        if freckles_wrapper is not None:
            freckles_wrapper.remove()

    tag = "OMEGA v2" if mode == 'omega' else "BASELINE"
    print(f"\n{'=' * 70}")
    print(f"{tag} COMPLETE")
    print(f" Best test accuracy: {best_acc:.1%}")
    print(f" Classifier params: {n_params:,}")
    print(f" Random chance: 10.0%")
    print(f"{'=' * 70}")

    return classifier, best_acc
|
|
|
|
if __name__ == "__main__":
    import sys
    torch.set_float32_matmul_precision('high')

    # Optional first CLI argument selects what to run; default is both runs.
    VALID_MODES = ('omega', 'baseline', 'both')
    MODE = sys.argv[1] if len(sys.argv) > 1 else 'both'
    if MODE not in VALID_MODES:
        # Previously an unknown mode ran nothing and exited successfully.
        sys.exit(f"Unknown mode {MODE!r}; expected one of: {', '.join(VALID_MODES)}")

    # Identical hyper-parameters for both runs keep the comparison fair.
    run_kwargs = dict(epochs=30, batch_size=128, lr=3e-4,
                      d_model=128, n_heads=4, n_layers=4)

    results = {}

    if MODE in ('omega', 'both'):
        _, omega_acc = train_model(mode='omega', **run_kwargs)
        results['omega'] = omega_acc

    if MODE in ('baseline', 'both'):
        _, base_acc = train_model(mode='baseline', **run_kwargs)
        results['baseline'] = base_acc

    if len(results) == 2:
        print("\n" + "=" * 70)
        print("HEAD-TO-HEAD COMPARISON")
        print("=" * 70)
        print(f" Omega v2 (hidden + geometric): {results['omega']:.1%}")
        print(f" Baseline (raw patches): {results['baseline']:.1%}")
        print(f" Delta: {results['omega'] - results['baseline']:+.1%}")
        print(f" Random chance: 10.0%")
        print("=" * 70)