# Source: geolip-SVAE / svae_alexandria_text_encoder_trainer.py
# Author: AbstractPhil — commit 1cf04f9 ("Create svae_alexandria_text_encoder_trainer.py"), 24.8 kB
"""
Alexandria — Text Reconstruction via Geometric Encoding
=========================================================
Wikipedia → UTF-8 bytes → (3, H, W) → PatchSVAE → reconstruct → bytes → text
The Library of Alexandria, rebuilt in geometry.
Text bytes are a structured subset of noise. Johanna already knows
how to invert the projection for arbitrary byte patterns. Alexandria
fine-tunes that knowledge specifically for text.
Byte accuracy is the metric that matters. A single wrong byte is
a wrong character. Text demands perfection.
"""
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import time
import numpy as np
from tqdm import tqdm
# ── HuggingFace auth from Colab secrets ──
# Best-effort: outside Colab (google.colab missing) or without huggingface_hub
# this is a silent no-op. Any other failure (missing secret, bad token,
# network) is reported instead of being swallowed, so later "HuggingFace:
# disabled" messages have a visible cause. It never raises.
try:
    from google.colab import userdata
    os.environ["HF_TOKEN"] = userdata.get('HF_TOKEN')
    from huggingface_hub import login
    login(token=os.environ["HF_TOKEN"])
except ImportError:
    pass  # not running in Colab / huggingface_hub not installed
except Exception as e:
    print(f"  HF login skipped ({e})")
# ── SVD Backend ──────────────────────────────────────────────────
# Optional fused eigensolver from geolip_core. When it is unavailable,
# HAS_FL is False and SVDs go through the torch.linalg.eigh Gram path.
try:
    from geolip_core.linalg.eigh import FLEigh, _FL_MAX_N
except ImportError:
    HAS_FL = False
else:
    HAS_FL = True
def gram_eigh_svd_fp64(A):
    """Thin SVD of a batch of matrices through the Gram matrix, in float64.

    Args:
        A: (B, M, N) tensor.

    Returns:
        ``(U, S, Vh)`` with shapes (B, M, N), (B, N), (B, N, N). They are cast
        back to ``A``'s dtype, and ``S`` is sorted in descending order.

    A 1e-12 ridge is added to the diagonal of ``AᵀA`` before ``eigh``. Singular
    values are ``sqrt(max(λ, 1e-24))``, and ``U = A V / max(S, 1e-16)``. Columns
    of ``U`` belonging to (near-)zero singular values are therefore not
    meaningful.
    """
    in_dtype = A.dtype
    with torch.amp.autocast('cuda', enabled=False):
        X = A.double()
        gram = torch.bmm(X.transpose(1, 2), X)
        gram.diagonal(dim1=-2, dim2=-1).add_(1e-12)
        evals, evecs = torch.linalg.eigh(gram)
        # eigh returns ascending eigenvalues; SVD convention is descending.
        evals = torch.flip(evals, dims=(-1,))
        evecs = torch.flip(evecs, dims=(-1,))
        sing = evals.clamp(min=1e-24).sqrt()
        floor = sing.unsqueeze(1).clamp(min=1e-16)
        left = torch.bmm(X, evecs) / floor
        right_t = evecs.transpose(-2, -1).contiguous()
    return left.to(in_dtype), sing.to(in_dtype), right_t.to(in_dtype)
def svd_fp64(A):
    """Batched thin SVD ``A = U diag(S) Vh`` for a (B, M, N) tensor.

    Takes the FLEigh path only when geolip_core is importable (``HAS_FL``),
    ``N <= _FL_MAX_N`` and ``A`` is on CUDA. Otherwise it delegates to
    ``gram_eigh_svd_fp64``. Results are cast back to ``A``'s dtype.
    """
    _, _, n_cols = A.shape
    fast_path = HAS_FL and n_cols <= _FL_MAX_N and A.is_cuda
    if not fast_path:
        return gram_eigh_svd_fp64(A)

    in_dtype = A.dtype
    with torch.amp.autocast('cuda', enabled=False):
        X = A.double()
        gram = torch.bmm(X.transpose(1, 2), X)
        # NOTE(review): the Gram matrix is formed in float64, but FLEigh is fed
        # a float32 copy. Unlike the fallback path, no 1e-12 diagonal ridge is
        # added here. The flips below presumably turn FLEigh's ascending order
        # into descending, as in the fallback. Confirm against FLEigh.
        evals, evecs = FLEigh()(gram.float())
        evals = evals.double().flip(-1)
        evecs = evecs.double().flip(-1)
        sing = torch.sqrt(evals.clamp(min=1e-24))
        left = torch.bmm(X, evecs) / sing.unsqueeze(1).clamp(min=1e-16)
        right_t = evecs.transpose(-2, -1).contiguous()
    return left.to(in_dtype), sing.to(in_dtype), right_t.to(in_dtype)
# ── CV Monitoring ────────────────────────────────────────────────
def cayley_menger_vol2(points):
    """Squared volume of the simplex spanned by each point set (Cayley–Menger).

    Args:
        points: (B, N, D) tensor holding N vertices in D dimensions, i.e. an
            (N-1)-simplex per batch entry.

    Returns:
        (B,) float64 tensor of squared volumes. Degenerate simplices give
        values near zero, possibly slightly negative from round-off.
    """
    n_pts = points.shape[1]
    P = points.double()
    G = torch.bmm(P, P.transpose(1, 2))
    sq = torch.diagonal(G, dim1=1, dim2=2)
    # Pairwise squared distances; relu removes tiny negatives from round-off.
    dist2 = F.relu(sq.unsqueeze(2) + sq.unsqueeze(1) - 2 * G)
    # Bordered matrix [[0, 1ᵀ], [1, dist2]]: pad with ones, then zero the corner.
    cm = F.pad(dist2, (1, 0, 1, 0), value=1.0)
    cm[:, 0, 0] = 0.0
    dim = n_pts - 1
    return ((-1.0) ** (dim + 1)) * torch.linalg.det(cm) / ((2 ** dim) * (math.factorial(dim) ** 2))
def cv_of(emb, n_samples=200):
    """Coefficient of variation (std / mean) of random 4-simplex volumes.

    Draws ``n_samples`` random 5-row subsets from the first ``min(rows, 512)``
    rows of ``emb``. For each subset it computes the simplex volume and returns
    ``std / (mean + 1e-8)`` over the non-degenerate ones (squared volume
    > 1e-20).

    Returns 0.0 when ``emb`` is not 2-D, has fewer than 5 rows, or yields
    fewer than 10 non-degenerate samples.
    """
    if emb.dim() != 2 or emb.shape[0] < 5:
        return 0.0
    pool = min(emb.shape[0], 512)
    candidates = emb[:pool]
    picks = torch.stack([
        torch.randperm(pool, device=emb.device)[:5]
        for _ in range(n_samples)
    ])
    sq_vols = cayley_menger_vol2(candidates[picks])
    keep = sq_vols > 1e-20
    if keep.sum() < 10:
        return 0.0
    vols = sq_vols[keep].sqrt()
    spread = vols.std()
    centre = vols.mean() + 1e-8
    return (spread / centre).item()
# ── Wikipedia Text Dataset ───────────────────────────────────────
class WikiTextAsImage(torch.utils.data.Dataset):
    """Wikipedia text packed as (3, H, W) byte tensors.

    Streams English Wikipedia and joins the non-blank articles with newlines
    into one UTF-8 byte buffer. Fixed-size windows of that buffer are served
    as "images". The model never knows it's reading; it just sees numbers in
    a grid.

    Byte normalization: [0, 255] → [-1, 1] via ``b / 127.5 - 1``.

    NOTE: ``__getitem__`` ignores ``idx``. Every call returns a window at a
    uniformly random offset; ``size`` only sets the epoch length.

    Args:
        size: Samples per epoch (``len(dataset)``).
        img_size: Side length H = W. Each sample holds ``3 * img_size**2`` bytes.
        split: Split passed to ``datasets.load_dataset``.
        skip_articles: Leading articles of the stream to skip before collecting
            text, e.g. to build a corpus disjoint from another instance's.
            Default 0, the original behaviour.
    """

    # Upper bound on the collected corpus, in UTF-8 bytes.
    MAX_CORPUS_BYTES = 500_000_000

    def __init__(self, size=200000, img_size=128, split='train', skip_articles=0):
        import itertools

        self.size = size
        self.img_size = img_size
        self.n_bytes = 3 * img_size * img_size
        print(f" Loading Wikipedia ({split})...")
        from datasets import load_dataset
        ds = load_dataset('wikipedia', '20220301.en', split=split,
                          streaming=True)
        # Aim for size * n_bytes bytes of text, capped at MAX_CORPUS_BYTES.
        target_bytes = min(size * self.n_bytes, self.MAX_CORPUS_BYTES)
        articles = itertools.islice(ds, skip_articles, None)
        self.raw_bytes = self._build_corpus(articles, target_bytes)
        print(f" Corpus: {len(self.raw_bytes):,} bytes ({len(self.raw_bytes)/1024/1024:.1f}MB)")
        print(f" Samples: {size:,} × {self.n_bytes:,} bytes/sample")

    @staticmethod
    def _build_corpus(articles, target_bytes):
        """Join non-blank article texts with b'\\n' into one UTF-8 buffer.

        Collection stops once at least ``target_bytes`` bytes have been
        gathered. Bytes are counted in UTF-8, not characters. Each article is
        encoded on its own, so no corpus-sized intermediate ``str`` is built.

        Args:
            articles: Iterable of mappings with a ``'text'`` string.
            target_bytes: Stop after this many encoded bytes (separators
                excluded).
        """
        pieces = []
        total = 0
        for article in articles:
            text = article['text']
            if not text.strip():
                continue
            encoded = text.encode('utf-8')
            pieces.append(encoded)
            total += len(encoded)
            if total >= target_bytes:
                break
        return b'\n'.join(pieces)

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        """Return a random ``(3, img_size, img_size)`` window and a dummy label 0.

        ``idx`` is ignored. The window start is uniform over the corpus, and the
        window is zero-padded if the corpus is shorter than ``n_bytes``.
        """
        max_start = max(0, len(self.raw_bytes) - self.n_bytes)
        start = torch.randint(0, max_start + 1, (1,)).item()
        chunk = self.raw_bytes[start:start + self.n_bytes]
        if len(chunk) < self.n_bytes:
            chunk = chunk + b'\x00' * (self.n_bytes - len(chunk))
        arr = np.frombuffer(chunk, dtype=np.uint8).copy()
        tensor = torch.from_numpy(arr).float()
        tensor = (tensor / 127.5) - 1.0  # [0,255] → [-1, 1]
        tensor = tensor.reshape(3, self.img_size, self.img_size)
        return tensor, 0
# ── Patch Utilities ──────────────────────────────────────────────
def extract_patches(images, patch_size=16):
    """Split images into non-overlapping square patches.

    Args:
        images: (B, C, H, W) tensor. H and W are assumed to be multiples of
            ``patch_size``.
        patch_size: Side length of each square patch.

    Returns:
        ``(patches, gh, gw)``. ``patches`` is (B, gh * gw, C * patch_size**2),
        with patches in row-major grid order, each flattened as
        (C, patch_size, patch_size). ``gh`` and ``gw`` are the grid height and
        width.
    """
    batch, chans, height, width = images.shape
    rows = height // patch_size
    cols = width // patch_size
    grid = images.reshape(batch, chans, rows, patch_size, cols, patch_size)
    tiles = grid.permute(0, 2, 4, 1, 3, 5)
    flat = tiles.reshape(batch, rows * cols, chans * patch_size * patch_size)
    return flat, rows, cols
def stitch_patches(patches, gh, gw, patch_size=16):
    """Reassemble flattened patches into images (inverse of ``extract_patches``).

    Args:
        patches: (B, gh * gw, C * patch_size**2) tensor, with patches in
            row-major grid order, each flattened as (C, patch_size, patch_size).
        gh, gw: Patch-grid height and width.
        patch_size: Side length of a square patch.

    Returns:
        (B, C, gh * patch_size, gw * patch_size) tensor. The channel count C is
        inferred from the patch width, so any channel count round-trips with
        ``extract_patches``. 3-channel inputs behave exactly as before.
    """
    B = patches.shape[0]
    channels = patches.shape[-1] // (patch_size * patch_size)
    patches = patches.reshape(B, gh, gw, channels, patch_size, patch_size)
    patches = patches.permute(0, 3, 1, 4, 2, 5)
    return patches.reshape(B, channels, gh * patch_size, gw * patch_size)
class BoundarySmooth(nn.Module):
    """Residual 3x3-conv refiner applied to the stitched image.

    ``forward(x) = x + conv(gelu(conv(x)))``. The last conv's weight and bias
    start at zero, so a freshly built module is an exact identity map.
    """

    def __init__(self, channels=3, mid=16):
        super().__init__()
        layers = [
            nn.Conv2d(channels, mid, 3, padding=1),
            nn.GELU(),
            nn.Conv2d(mid, channels, 3, padding=1),
        ]
        self.net = nn.Sequential(*layers)
        last_conv = self.net[2]
        nn.init.zeros_(last_conv.weight)
        nn.init.zeros_(last_conv.bias)

    def forward(self, x):
        correction = self.net(x)
        return x + correction
class SpectralCrossAttention(nn.Module):
    """Multi-head self-attention across tokens that gently rescales its input.

    For ``S`` of shape (B, N, D), attention runs across the N tokens over the
    layer-normed ``S``. The result is projected to a per-channel gate
    ``tanh(...)`` in (-1, 1), and the output is ``S * (1 + alpha * gate)``.
    ``alpha = max_alpha * sigmoid(alpha_logits)`` is a learned per-channel
    strength in (0, max_alpha), so every output element stays within a
    factor ``1 ± max_alpha`` of the input.

    Args:
        D: Feature width (must be divisible by ``n_heads``).
        n_heads: Number of attention heads.
        max_alpha: Upper bound on the modulation strength.
        alpha_init: Initial value of every ``alpha_logits`` entry.

    Raises:
        ValueError: If ``n_heads`` is not a positive divisor of ``D``.
    """

    def __init__(self, D, n_heads=4, max_alpha=0.2, alpha_init=-2.0):
        super().__init__()
        # Validate before first use; an ``assert`` is stripped under ``python -O``.
        if n_heads <= 0 or D % n_heads != 0:
            raise ValueError(
                f"D={D} must be divisible by a positive n_heads (got {n_heads})")
        self.n_heads = n_heads
        self.head_dim = D // n_heads
        self.max_alpha = max_alpha
        # Submodule creation order is kept: it fixes init RNG draws and
        # parameter order for checkpoint/optimizer compatibility.
        self.qkv = nn.Linear(D, 3 * D)
        self.out_proj = nn.Linear(D, D)
        self.norm = nn.LayerNorm(D)
        self.scale = self.head_dim ** -0.5
        # float() keeps the parameter floating-point even if an int is passed.
        self.alpha_logits = nn.Parameter(torch.full((D,), float(alpha_init)))

    @property
    def alpha(self):
        """Per-channel modulation strength in (0, max_alpha)."""
        return self.max_alpha * torch.sigmoid(self.alpha_logits)

    def forward(self, S):
        B, N, D = S.shape
        S_normed = self.norm(S)
        qkv = self.qkv(S_normed).reshape(B, N, 3, self.n_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B, heads, N, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(B, N, D)
        gate = torch.tanh(self.out_proj(out))
        return S * (1.0 + self.alpha.unsqueeze(0).unsqueeze(0) * gate)
class PatchSVAE(nn.Module):
    """Patch-wise autoencoder with an SVD bottleneck.

    Each ``patch_size x patch_size`` patch of a 3-channel image is flattened to
    ``3 * patch_size**2`` values. A residual MLP encodes it into a
    ``(matrix_v, D)`` matrix whose rows are L2-normalised, and ``svd_fp64``
    factorises that matrix into ``U, S, Vt``. The per-patch singular-value
    vectors are then modulated by ``n_cross_layers`` ``SpectralCrossAttention``
    layers that attend across the patches of the same image.

    The decoder rebuilds ``U diag(S) Vt`` with the modulated ``S`` and maps it
    back to patch values with a mirrored residual MLP. The patches are then
    stitched and passed through a ``BoundarySmooth`` residual.

    NOTE: submodule construction order matters. It fixes the RNG draws used
    for initialisation, the order of ``parameters()`` (optimizer state) and the
    ``state_dict`` keys that pretrained checkpoints are loaded against with
    ``strict=True``.
    """
    def __init__(self, matrix_v=256, D=16, patch_size=16, hidden=768,
                 depth=4, n_cross_layers=2):
        super().__init__()
        self.matrix_v = matrix_v
        self.D = D
        self.patch_size = patch_size
        # Patches are 3-channel: 3 * patch_size * patch_size values each.
        self.patch_dim = 3 * patch_size * patch_size
        self.mat_dim = matrix_v * D
        # Encoder: patch → hidden → residual MLP blocks → flattened (matrix_v, D).
        self.enc_in = nn.Linear(self.patch_dim, hidden)
        self.enc_blocks = nn.ModuleList([
            nn.Sequential(nn.LayerNorm(hidden), nn.Linear(hidden, hidden),
                          nn.GELU(), nn.Linear(hidden, hidden))
            for _ in range(depth)])
        self.enc_out = nn.Linear(hidden, self.mat_dim)
        # Decoder mirrors the encoder: flattened matrix → hidden → blocks → patch.
        self.dec_in = nn.Linear(self.mat_dim, hidden)
        self.dec_blocks = nn.ModuleList([
            nn.Sequential(nn.LayerNorm(hidden), nn.Linear(hidden, hidden),
                          nn.GELU(), nn.Linear(hidden, hidden))
            for _ in range(depth)])
        self.dec_out = nn.Linear(hidden, self.patch_dim)
        nn.init.orthogonal_(self.enc_out.weight)
        # Cross-patch attention over singular-value vectors (at most 4 heads).
        self.cross_attn = nn.ModuleList([
            SpectralCrossAttention(D, n_heads=min(4, D))
            for _ in range(n_cross_layers)])
        self.boundary_smooth = BoundarySmooth(channels=3, mid=16)

    def encode_patches(self, patches):
        """Encode (B, N, patch_dim) patches and factorise them with SVD.

        Returns a dict with:
            'U':      (B, N, matrix_v, D) left singular vectors.
            'S_orig': (B, N, D) singular values straight from the SVD.
            'S':      (B, N, D) singular values after the cross-attention layers.
            'Vt':     (B, N, D, D) transposed right singular vectors.
            'M':      (B, N, matrix_v, D) row-normalised matrices that were factorised.
        """
        B, N, _ = patches.shape
        flat = patches.reshape(B * N, -1)
        h = F.gelu(self.enc_in(flat))
        for block in self.enc_blocks:
            h = h + block(h)
        M = self.enc_out(h).reshape(B * N, self.matrix_v, self.D)
        # Each of the matrix_v rows is scaled to unit L2 norm before the SVD.
        M = F.normalize(M, dim=-1)
        U, S, Vt = svd_fp64(M)
        U = U.reshape(B, N, self.matrix_v, self.D)
        S = S.reshape(B, N, self.D)
        Vt = Vt.reshape(B, N, self.D, self.D)
        M = M.reshape(B, N, self.matrix_v, self.D)
        # Modulate singular values using context from the other patches.
        S_coord = S
        for layer in self.cross_attn:
            S_coord = layer(S_coord)
        return {'U': U, 'S_orig': S, 'S': S_coord, 'Vt': Vt, 'M': M}

    def decode_patches(self, U, S, Vt):
        """Rebuild ``U diag(S) Vt`` per patch and decode to (B, N, patch_dim)."""
        B, N, V, D = U.shape
        U_flat = U.reshape(B * N, V, D)
        S_flat = S.reshape(B * N, D)
        Vt_flat = Vt.reshape(B * N, D, D)
        # U * S broadcasts S across rows, i.e. U @ diag(S), then @ Vt.
        M_hat = torch.bmm(U_flat * S_flat.unsqueeze(1), Vt_flat)
        h = F.gelu(self.dec_in(M_hat.reshape(B * N, -1)))
        for block in self.dec_blocks:
            h = h + block(h)
        return self.dec_out(h).reshape(B, N, -1)

    def forward(self, images):
        """Reconstruct (B, 3, H, W) images.

        H and W are assumed to be multiples of patch_size (TODO confirm at
        call sites). Returns ``{'recon', 'svd', 'gh', 'gw'}``, where 'svd' is
        the dict from ``encode_patches``. Decoding uses the modulated 'S'.
        """
        patches, gh, gw = extract_patches(images, self.patch_size)
        svd = self.encode_patches(patches)
        decoded = self.decode_patches(svd['U'], svd['S'], svd['Vt'])
        recon = stitch_patches(decoded, gh, gw, self.patch_size)
        recon = self.boundary_smooth(recon)
        return {'recon': recon, 'svd': svd, 'gh': gh, 'gw': gw}

    @staticmethod
    def effective_rank(S):
        """exp(Shannon entropy) of ``S`` normalised to sum to 1 along the last dim."""
        p = S / (S.sum(-1, keepdim=True) + 1e-8)
        p = p.clamp(min=1e-8)
        return (-(p * p.log()).sum(-1)).exp()
# ── Byte Accuracy ────────────────────────────────────────────────
def byte_accuracy(recon, target):
    """Fraction of bytes recovered exactly.

    Both tensors are mapped from [-1, 1] back to integer bytes with
    ``round((x + 1) * 127.5)`` clamped to [0, 255], flattened from dim 1 on,
    and compared element-wise. The result is a Python float in [0, 1].
    """
    def to_bytes(x):
        return ((x.flatten(1) + 1.0) * 127.5).round().clamp(0, 255).long()

    matches = to_bytes(target) == to_bytes(recon)
    return matches.float().mean().item()
def sample_text_reconstruction(model, dataset, device, n=3):
    """Print original vs. reconstructed text for ``n`` fixed windows.

    When ``dataset`` exposes ``raw_bytes``, ``n_bytes`` and ``img_size`` (as
    ``WikiTextAsImage`` does), the windows start at evenly spaced offsets of
    the corpus. The same text is then shown on every call, so reconstructions
    are comparable across epochs. ``dataset[i]`` cannot provide this, because
    ``WikiTextAsImage.__getitem__`` ignores its index and returns a random
    window. Other datasets fall back to ``dataset[i * 1000]``.

    The model's train/eval mode is restored on exit.
    """
    def to_bytes(t):
        # Same quantisation as byte_accuracy: round((x + 1) * 127.5), clamp to [0, 255].
        return ((t.detach().cpu().flatten() + 1.0) * 127.5).round().clamp(0, 255).byte().numpy()

    raw = getattr(dataset, 'raw_bytes', None)
    was_training = model.training
    model.eval()
    try:
        for i in range(n):
            if raw is not None:
                n_bytes = dataset.n_bytes
                max_start = max(0, len(raw) - n_bytes)
                start = (i * max_start) // max(n, 1)
                chunk = raw[start:start + n_bytes].ljust(n_bytes, b'\x00')
                arr = np.frombuffer(chunk, dtype=np.uint8).copy()
                tensor = torch.from_numpy(arr).float() / 127.5 - 1.0
                tensor = tensor.reshape(3, dataset.img_size, dataset.img_size)
            else:
                tensor, _ = dataset[i * 1000]
            tensor = tensor.unsqueeze(0).to(device)
            with torch.no_grad():
                recon = model(tensor)['recon']
            orig_bytes = to_bytes(tensor)
            recon_bytes = to_bytes(recon)
            orig_text = orig_bytes.tobytes().decode('utf-8', errors='replace')
            recon_text = recon_bytes.tobytes().decode('utf-8', errors='replace')
            acc = float((orig_bytes == recon_bytes).mean())
            print(f"\n Sample {i+1}:")
            print(f" Original: {repr(orig_text[:100])}")
            print(f" Recon: {repr(recon_text[:100])}")
            print(f" Byte acc: {acc*100:.1f}%")
    finally:
        model.train(was_training)
# ── Training ─────────────────────────────────────────────────────
def train():
    """Fine-tune a pretrained PatchSVAE to reconstruct Wikipedia UTF-8 bytes.

    Steps:
      1. Load pretrained weights from the Hub (primary file, then fallback).
      2. Collect one Wikipedia corpus and hold out its tail for validation.
      3. Train with an MSE loss whose weight is modulated by a CV-proximity
         factor.
      4. Log to TensorBoard, checkpoint locally, and upload to the HuggingFace
         Hub when authenticated.
    """
    import copy

    # ── Config ──
    V, D, patch_size = 256, 16, 16
    hidden, depth = 768, 4
    n_cross_layers = 2
    batch_size = 128
    lr = 1e-4
    epochs = 100
    target_cv = 0.125
    cv_weight, boost, sigma = 0.3, 0.5, 0.15
    img_size = 128
    save_dir = '/content/checkpoints'
    save_every = 10
    report_every = 2000
    hf_repo = 'AbstractPhil/geolip-SVAE'
    hf_version = 'v17_alexandria'
    tb_dir = '/content/runs'
    # Fraction of the collected corpus (taken from its tail) held out for validation.
    val_fraction = 0.05

    # ── Pretrained: load from Johanna omega or Fresnel ──
    # Johanna omega knows arbitrary bytes. Fresnel knows images.
    # Johanna is the better starting point for text.
    pretrained_repo = 'AbstractPhil/geolip-SVAE'
    pretrained_file = 'v16_johanna_omega/checkpoints/best.pt'
    # Fallback: Gaussian Johanna if omega not ready yet
    pretrained_fallback = 'v14_noise/checkpoints/epoch_0200.pt'
    os.makedirs(save_dir, exist_ok=True)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # ── TensorBoard ──
    from torch.utils.tensorboard import SummaryWriter
    run_name = f"alexandria_V{V}_D{D}_h{hidden}_d{depth}"
    tb_path = os.path.join(tb_dir, run_name)
    writer = SummaryWriter(tb_path)
    print(f" TensorBoard: {tb_path}")

    # ── HuggingFace ──
    hf_enabled = False
    try:
        from huggingface_hub import HfApi, hf_hub_download
        api = HfApi()
        api.whoami()
        hf_enabled = True
        hf_prefix = f"{hf_version}/checkpoints"
        print(f" HuggingFace: {hf_repo}/{hf_prefix}")
    except Exception as e:
        print(f" HuggingFace: disabled ({e})")

    def upload_to_hf(local_path, remote_name):
        if not hf_enabled:
            return
        try:
            api.upload_file(path_or_fileobj=local_path,
                            path_in_repo=f"{hf_prefix}/{remote_name}",
                            repo_id=hf_repo, repo_type="model")
            print(f" ☁️ Uploaded: {hf_repo}/{hf_prefix}/{remote_name}")
        except Exception as e:
            print(f" ⚠️ HF upload failed: {e}")

    # ── Load pretrained ──
    print(f"\n Loading pretrained weights...")
    ckpt = None
    loaded_from = None  # file that actually supplied the weights (saved in checkpoints)
    for fname in [pretrained_file, pretrained_fallback]:
        try:
            ckpt_path = hf_hub_download(repo_id=pretrained_repo,
                                        filename=fname, repo_type="model")
            # NOTE: weights_only=False unpickles arbitrary objects; only point
            # pretrained_repo at a trusted repository.
            ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=False)
            loaded_from = fname
            print(f" Loaded: {fname}")
            print(f" Epoch: {ckpt['epoch']}, MSE: {ckpt['test_mse']:.6f}")
            break
        except Exception as e:
            print(f" {fname}: {e}")

    # ── Model ──
    model = PatchSVAE(matrix_v=V, D=D, patch_size=patch_size,
                      hidden=hidden, depth=depth,
                      n_cross_layers=n_cross_layers).to(device)
    if ckpt is not None:
        model.load_state_dict(ckpt['model_state_dict'], strict=True)
        print(f" Loaded pretrained weights into model")
    else:
        print(f" ⚠️ No pretrained weights β€” training from scratch")
    total_params = sum(p.numel() for p in model.parameters())
    print(f" Params: {total_params:,}")
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)

    # ── Data: Wikipedia ──
    print(f"\n Loading Wikipedia corpus...")
    train_ds = WikiTextAsImage(size=200000, img_size=img_size, split='train')
    # Wikipedia ships only a 'train' split. A second WikiTextAsImage(split='train')
    # would stream the same leading articles, making its corpus a prefix of the
    # training corpus (validation == memorised training text). Instead, carve a
    # disjoint validation slice off the tail of the single collected corpus.
    val_ds = copy.copy(train_ds)
    val_ds.size = 5000
    corpus = train_ds.raw_bytes
    n_val_bytes = max(int(len(corpus) * val_fraction), train_ds.n_bytes)
    if len(corpus) >= 2 * n_val_bytes:
        train_ds.raw_bytes = corpus[:-n_val_bytes]
        val_ds.raw_bytes = corpus[-n_val_bytes:]
    else:
        print(" ⚠️ Corpus too small for a held-out split; validation overlaps training")
    del corpus
    print(f" Held-out validation: {len(val_ds.raw_bytes):,} bytes "
          f"(train: {len(train_ds.raw_bytes):,} bytes)")
    train_loader = torch.utils.data.DataLoader(
        train_ds, batch_size=batch_size, shuffle=True,
        num_workers=4, pin_memory=True, drop_last=True)
    test_loader = torch.utils.data.DataLoader(
        val_ds, batch_size=batch_size, shuffle=False,
        num_workers=4, pin_memory=True)
    n_patches = (img_size // patch_size) ** 2
    batches_per_epoch = len(train_loader)

    print(f"\n ALEXANDRIA β€” The Library in Geometry")
    print(f" Wikipedia β†’ UTF-8 bytes β†’ (3, {img_size}, {img_size}) β†’ PatchSVAE")
    print(f" {n_patches} patches, ({V},{D}), hidden={hidden}, depth={depth}")
    print(f" Batch={batch_size}, batches/epoch={batches_per_epoch}")
    print(f" Bytes per sample: {3 * img_size * img_size:,}")
    print(f" Text per sample: ~{3 * img_size * img_size // 5:,} words")
    print("=" * 100)
    print(f" {'ep':>3} {'batch':>7} | {'loss':>7} {'recon':>7} {'byteacc':>8} | "
          f"{'S0':>6} {'SD':>6} {'ratio':>5} {'erank':>5} | "
          f"{'row_cv':>7} {'prox':>5} | {'S_delta':>7}")
    print("-" * 100)

    best_recon = float('inf')
    global_batch = 0
    # Fixed held-out batch for periodic readouts, fetched once on first use.
    readout_imgs = None

    def save_checkpoint(path, epoch, test_mse, extra=None, upload=True):
        ckpt_out = {
            'epoch': epoch, 'test_mse': test_mse,
            'global_batch': global_batch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': opt.state_dict(),
            'scheduler_state_dict': sched.state_dict(),
            'config': {
                'V': V, 'D': D, 'patch_size': patch_size,
                'hidden': hidden, 'depth': depth,
                'n_cross_layers': n_cross_layers, 'target_cv': target_cv,
                'dataset': 'wikipedia_en', 'modality': 'text',
                'pretrained_from': loaded_from,
                'img_size': img_size, 'lr': lr,
            },
        }
        if extra:
            ckpt_out.update(extra)
        torch.save(ckpt_out, path)
        size_mb = os.path.getsize(path) / (1024 * 1024)
        print(f" πŸ’Ύ Saved: {path} ({size_mb:.1f}MB, ep{epoch}, MSE={test_mse:.6f})")
        if upload:
            upload_to_hf(path, os.path.basename(path))

    # ── Training Loop ──
    for epoch in range(1, epochs + 1):
        model.train()
        total_loss, total_recon, n = 0, 0, 0
        last_cv, last_prox, recon_w = target_cv, 1.0, 1.0 + boost
        t0 = time.time()
        pbar = tqdm(train_loader, desc=f"Ep {epoch}/{epochs}",
                    bar_format='{l_bar}{bar:20}{r_bar}')
        for batch_idx, (images, _) in enumerate(pbar):
            images = images.to(device)
            opt.zero_grad()
            out = model(images)
            recon_loss = F.mse_loss(out['recon'], images)
            with torch.no_grad():
                if batch_idx % 50 == 0:
                    current_cv = cv_of(out['svd']['M'][0, 0])
                    if current_cv > 0:
                        last_cv = current_cv
                    delta = last_cv - target_cv
                    last_prox = math.exp(-delta**2 / (2 * sigma**2))
                # Byte accuracy every 100 batches
                if batch_idx % 100 == 0:
                    batch_acc = byte_accuracy(out['recon'], images)
                    pbar.set_postfix_str(
                        f"mse={recon_loss.item():.4f} bytes={batch_acc*100:.0f}% cv={last_cv:.3f}",
                        refresh=False)
            recon_w = 1.0 + boost * last_prox
            cv_pen = cv_weight * (1.0 - last_prox)
            # NOTE(review): last_cv is a Python float computed under no_grad, so
            # cv_l carries no gradient. This term only shifts the reported loss;
            # geometric pressure acts solely through recon_w. Confirm intended.
            cv_l = (last_cv - target_cv) ** 2
            loss = recon_w * recon_loss + cv_pen * cv_l
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.cross_attn.parameters(), max_norm=0.5)
            opt.step()
            total_loss += loss.item() * len(images)
            total_recon += recon_loss.item() * len(images)
            n += len(images)
            global_batch += 1

            # ── Readout ──
            if global_batch % report_every == 0:
                model.eval()
                with torch.no_grad():
                    if readout_imgs is None:
                        # One fixed held-out batch, fetched once: readouts stay
                        # comparable over time and DataLoader workers are not
                        # respawned at every report.
                        readout_imgs = next(iter(test_loader))[0].to(device)
                    test_imgs = readout_imgs
                    test_out = model(test_imgs)
                    test_mse = F.mse_loss(test_out['recon'], test_imgs).item()
                    test_acc = byte_accuracy(test_out['recon'], test_imgs)
                    S_mean = test_out['svd']['S'].mean(dim=(0, 1))
                    S_orig = test_out['svd']['S_orig'].mean(dim=(0, 1))
                    erank = model.effective_rank(
                        test_out['svd']['S'].reshape(-1, D)).mean().item()
                    s_delta = (S_mean - S_orig).abs().mean().item()
                    ratio = (S_mean[0] / (S_mean[-1] + 1e-8)).item()
                writer.add_scalar('train/recon', total_recon / n, global_batch)
                writer.add_scalar('test/recon_mse', test_mse, global_batch)
                writer.add_scalar('test/byte_accuracy', test_acc, global_batch)
                writer.add_scalar('geo/row_cv', last_cv, global_batch)
                writer.add_scalar('geo/ratio', ratio, global_batch)
                writer.add_scalar('geo/erank', erank, global_batch)
                writer.add_scalar('geo/S0', S_mean[0].item(), global_batch)
                writer.add_scalar('cross_attn/s_delta', s_delta, global_batch)
                print(f"\n {epoch:3d} {global_batch:7d} | "
                      f"{total_loss/n:7.4f} {total_recon/n:7.4f} {test_acc*100:7.1f}% | "
                      f"{S_mean[0]:6.3f} {S_mean[-1]:6.3f} {ratio:5.2f} {erank:5.2f} | "
                      f"{last_cv:7.4f} {last_prox:5.3f} | "
                      f"{s_delta:7.5f}")
                # Single-batch MSE is noisy and not comparable with the full
                # held-out evaluation below, so it does not update best_recon
                # or best.pt. A lucky batch could otherwise block every later
                # genuine improvement.
                model.train()
        pbar.close()
        sched.step()
        epoch_time = time.time() - t0

        # ── Epoch eval (full held-out set) ──
        model.eval()
        test_recon_total, test_acc_total, test_n = 0, 0, 0
        with torch.no_grad():
            for test_imgs, _ in test_loader:
                test_imgs = test_imgs.to(device)
                out = model(test_imgs)
                test_recon_total += F.mse_loss(out['recon'], test_imgs).item() * len(test_imgs)
                test_acc_total += byte_accuracy(out['recon'], test_imgs) * len(test_imgs)
                test_n += len(test_imgs)
        epoch_mse = test_recon_total / test_n
        epoch_acc = test_acc_total / test_n
        print(f" Epoch {epoch}: {epoch_time:.1f}s, MSE={epoch_mse:.6f}, "
              f"bytes={epoch_acc*100:.1f}%, best={best_recon:.6f}")

        # Text samples every 10 epochs (from held-out text)
        if epoch % 10 == 0 or epoch == 1:
            print(f"\n ── Text Reconstruction Samples ──")
            sample_text_reconstruction(model, val_ds, device, n=3)

        if epoch_mse < best_recon:
            best_recon = epoch_mse
            save_checkpoint(os.path.join(save_dir, 'best.pt'),
                            epoch, epoch_mse,
                            extra={'byte_accuracy': epoch_acc},
                            upload=False)
        if epoch % save_every == 0:
            save_checkpoint(os.path.join(save_dir, f'epoch_{epoch:04d}.pt'),
                            epoch, epoch_mse,
                            extra={'byte_accuracy': epoch_acc})
            best_path = os.path.join(save_dir, 'best.pt')
            if os.path.exists(best_path):
                upload_to_hf(best_path, 'best.pt')
            writer.flush()
            if hf_enabled:
                try:
                    api.upload_folder(folder_path=tb_path,
                                      path_in_repo=f"{hf_version}/tensorboard/{run_name}",
                                      repo_id=hf_repo, repo_type="model")
                    print(f" ☁️ TB synced")
                except Exception as e:
                    # Best-effort sync: report and keep training.
                    print(f" ⚠️ TB sync failed: {e}")

    writer.close()
    print(f"\n ALEXANDRIA TRAINING COMPLETE")
    print(f" Best MSE: {best_recon:.6f}")
    print(f" Checkpoints: {save_dir}/")
if __name__ == "__main__":
    # Allow reduced-internal-precision (e.g. TF32) float32 matmuls where the
    # hardware supports them: faster, at a small cost in last-bit precision.
    torch.set_float32_matmul_precision('high')
    train()