# geolip-SVAE / svae_text_extraction_tester.py
# (renamed from svae_alexandria_text_trainer.py — commit 8ed68e7)
"""
Text-as-Image β€” Universal encoding via byte rasterization
============================================================
Text β†’ raw UTF-8 bytes β†’ pack into (3, H, W) β†’ PatchSVAE β†’ reconstruct β†’ unpack β†’ text
The model never knows it's looking at text. Bytes 0-255 are pixel values.
Every character, every language, every encoding. Just numbers in a grid.
For 64Γ—64: 3 Γ— 64 Γ— 64 = 12,288 bytes β‰ˆ 12KB of text per sample
For 128Γ—128: 3 Γ— 128 Γ— 128 = 49,152 bytes β‰ˆ 48KB of text per sample
Usage:
# Test existing Fresnel/Johanna on text (zero-shot)
python text_svae.py --mode test --model AbstractPhil/svae-fresnel-128
# Train a text-specific variant
python text_svae.py --mode train --size 128
# Evaluate byte-level accuracy
python text_svae.py --mode eval --checkpoint /content/checkpoints/best.pt
"""
import os
from typing import Optional

import numpy as np
import torch
import torch.nn.functional as F

from datasets import load_dataset
# ── Text ↔ Tensor Conversion ────────────────────────────────────
def text_to_tensor(text: str, H: int = 128, W: int = 128, C: int = 3) -> tuple[torch.Tensor, int]:
    """Pack UTF-8 bytes into a (C, H, W) tensor normalized to [-1, 1].

    Args:
        text: Raw string.
        H, W: Spatial dimensions.
        C: Channels (3 for compatibility with existing models).

    Returns:
        tuple of:
            (C, H, W) float tensor in [-1, 1]
            int: actual byte count before padding/truncation
    """
    raw = text.encode('utf-8')
    n_bytes = C * H * W
    actual_len = min(len(raw), n_bytes)
    # Pad with NULs (or truncate) so the buffer fills the grid exactly.
    if len(raw) < n_bytes:
        raw = raw + b'\x00' * (n_bytes - len(raw))
    else:
        raw = raw[:n_bytes]
    # Bytes 0-255 -> normalized [-1, 1]. The .copy() detaches from the
    # read-only buffer np.frombuffer returns over the bytes object.
    arr = np.frombuffer(raw, dtype=np.uint8).copy()
    tensor = torch.from_numpy(arr).float()
    tensor = (tensor / 127.5) - 1.0  # [0,255] -> [-1, 1]
    tensor = tensor.reshape(C, H, W)
    return tensor, actual_len
def tensor_to_text(tensor: torch.Tensor, byte_count: Optional[int] = None) -> str:
    """Unpack a (C, H, W) tensor back to UTF-8 text.

    Args:
        tensor: (C, H, W) float tensor in [-1, 1]; moved to CPU if needed.
        byte_count: if known, truncate to the actual text length first.

    Returns:
        Decoded string (undecodable byte sequences replaced with U+FFFD).
    """
    # Denormalize: [-1, 1] -> [0, 255]. detach().cpu() lets CUDA tensors
    # flow through; .numpy() would otherwise raise on a GPU tensor.
    arr = ((tensor.detach().cpu().flatten() + 1.0) * 127.5).round().clamp(0, 255).byte().numpy()
    if byte_count is not None:
        arr = arr[:byte_count]
    # Strip NUL padding, then decode with replacement for lossy bytes.
    return arr.tobytes().rstrip(b'\x00').decode('utf-8', errors='replace')
# ── Text Dataset ─────────────────────────────────────────────────
class TextAsImageDataset(torch.utils.data.Dataset):
    """Stream text from HuggingFace, pack as images.

    Each sample is a random window of the corpus rasterized to (3, H, W);
    the model consumes it exactly like an ordinary image.
    """

    def __init__(self, size=100000, img_size=128, text_field='text',
                 dataset_name='wikitext', dataset_config='wikitext-103-raw-v1',
                 split='train'):
        self.size = size
        self.img_size = img_size
        self.n_bytes = 3 * img_size * img_size
        self.text_field = text_field
        # Join every non-blank record into one flat UTF-8 byte buffer.
        print(f" Loading text corpus: {dataset_name}/{dataset_config}...")
        corpus = load_dataset(dataset_name, dataset_config, split=split)
        rows = (rec[text_field] for rec in corpus if rec[text_field].strip())
        self.raw_bytes = '\n'.join(rows).encode('utf-8')
        print(f" Corpus: {len(self.raw_bytes):,} bytes ({len(self.raw_bytes)/1024/1024:.1f}MB)")

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        # Draw a random window from the corpus; idx is ignored on purpose
        # so every epoch sees fresh chunks.
        span = len(self.raw_bytes) - self.n_bytes
        start = torch.randint(0, span, (1,)).item() if span > 0 else 0
        window = self.raw_bytes[start:start + self.n_bytes]
        # NUL-pad short windows (tiny corpora) up to the full grid.
        window = window.ljust(self.n_bytes, b'\x00')
        buf = np.frombuffer(window, dtype=np.uint8).copy()
        img = (torch.from_numpy(buf).float() / 127.5) - 1.0
        return img.reshape(3, self.img_size, self.img_size), 0  # label placeholder
def get_text_loaders(batch_size=128, img_size=128,
                     train_size=100000, val_size=5000):
    """Text-as-image data loaders from WikiText-103."""
    # Both loaders share batch size, worker count, and pinning; only the
    # dataset size, split, and shuffling differ.
    def _build(n_samples, split, shuffle):
        ds = TextAsImageDataset(size=n_samples, img_size=img_size, split=split)
        return torch.utils.data.DataLoader(
            ds, batch_size=batch_size, shuffle=shuffle,
            num_workers=2, pin_memory=True)

    train_loader = _build(train_size, 'train', True)
    val_loader = _build(val_size, 'validation', False)
    return train_loader, val_loader
# ── Zero-shot Test (existing model on text) ──────────────────────
def test_zero_shot(model_repo="AbstractPhil/svae-fresnel-128", device='cuda'):
    """Test an existing Fresnel/Johanna model on text without any text training.

    Downloads the model from the Hub, rasterizes a handful of sample strings
    with text_to_tensor, reconstructs them, and prints per-sample MSE,
    byte-level accuracy, and whether the round-trip was exact.
    """
    from transformers import AutoModel
    print(f"\n{'='*70}")
    print(f"ZERO-SHOT TEXT TEST: {model_repo}")
    print(f"{'='*70}")
    model = AutoModel.from_pretrained(model_repo, trust_remote_code=True).to(device).eval()
    test_texts = [
        "Hello, world! This is a test of the Fresnel geometric encoder.",
        "The quick brown fox jumps over the lazy dog. 0123456789",
        "import torch\nmodel = AutoModel.from_pretrained('test')\noutput = model(x)",
        "To be, or not to be, that is the question.",
        "E = mc² — Albert Einstein, 1905",
    ]
    img_size = model.config.image_size
    print(f" Image size: {img_size}×{img_size}")
    print(f" Bytes per sample: {3 * img_size * img_size}")
    for text in test_texts:
        tensor, actual_len = text_to_tensor(text, img_size, img_size)
        tensor = tensor.unsqueeze(0).to(device)
        with torch.no_grad():
            output = model(tensor)
        recon = output["recon"]
        mse = F.mse_loss(recon, tensor).item()
        recovered = tensor_to_text(recon.squeeze(0).cpu(), actual_len)
        # Byte accuracy: quantize both back to byte space and compare the
        # prefix that held real text.
        orig_bytes = ((tensor.squeeze(0).flatten() + 1.0) * 127.5).round().clamp(0, 255).byte()
        recon_bytes = ((recon.squeeze(0).cpu().flatten() + 1.0) * 127.5).round().clamp(0, 255).byte()
        exact_match = (orig_bytes[:actual_len] == recon_bytes[:actual_len]).float().mean().item()
        # Hoisted out of the f-string below: a backslash escape ('\x00')
        # inside an f-string expression is a SyntaxError before Python 3.12.
        is_exact = recovered.rstrip('\x00') == text
        print(f"\n Input: '{text[:80]}{'...' if len(text) > 80 else ''}'")
        print(f" Output: '{recovered[:80]}{'...' if len(recovered) > 80 else ''}'")
        print(f" MSE: {mse:.6f}")
        print(f" Byte acc: {exact_match*100:.1f}%")
        print(f" Exact: {'YES' if is_exact else 'NO'}")
# ── Evaluation ───────────────────────────────────────────────────
def eval_text_reconstruction(model, test_loader, device='cuda', n_batches=10):
    """Evaluate byte-level reconstruction accuracy on text.

    Args:
        model: module whose forward returns a dict containing key 'recon'.
        test_loader: iterable yielding (images, labels) batches.
        device: torch device string used for inference.
        n_batches: cap on the number of batches evaluated.

    Returns:
        dict with 'mse', 'byte_accuracy', 'n_samples'. If no batches were
        consumed, metrics are NaN and 'n_samples' is 0.
    """
    model.eval()
    total_mse, total_acc, n = 0.0, 0.0, 0
    with torch.no_grad():
        for i, (images, _) in enumerate(test_loader):
            if i >= n_batches:
                break
            images = images.to(device)
            out = model(images)
            recon = out['recon']
            mse = F.mse_loss(recon, images).item()
            # Byte accuracy: quantize both tensors back to [0, 255] and
            # count exact byte matches.
            orig_bytes = ((images.flatten(1) + 1.0) * 127.5).round().clamp(0, 255).long()
            recon_bytes = ((recon.flatten(1) + 1.0) * 127.5).round().clamp(0, 255).long()
            acc = (orig_bytes == recon_bytes).float().mean().item()
            batch = len(images)
            total_mse += mse * batch
            total_acc += acc * batch
            n += batch
    # Guard: an empty loader (or n_batches=0) would otherwise raise
    # ZeroDivisionError below.
    if n == 0:
        return {'mse': float('nan'), 'byte_accuracy': float('nan'), 'n_samples': 0}
    return {
        'mse': total_mse / n,
        'byte_accuracy': total_acc / n,
        'n_samples': n,
    }
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--mode', default='test', choices=['test', 'eval'])
parser.add_argument('--model', default='AbstractPhil/svae-fresnel-128')
parser.add_argument('--device', default='cuda')
args = parser.parse_args()
if args.mode == 'test':
test_zero_shot(args.model, args.device)