""" Text-as-Image — Universal encoding via byte rasterization ============================================================ Text → raw UTF-8 bytes → pack into (3, H, W) → PatchSVAE → reconstruct → unpack → text The model never knows it's looking at text. Bytes 0-255 are pixel values. Every character, every language, every encoding. Just numbers in a grid. For 64×64: 3 × 64 × 64 = 12,288 bytes ≈ 12KB of text per sample For 128×128: 3 × 128 × 128 = 49,152 bytes ≈ 48KB of text per sample Usage: # Test existing Fresnel/Johanna on text (zero-shot) python text_svae.py --mode test --model AbstractPhil/svae-fresnel-128 # Train a text-specific variant python text_svae.py --mode train --size 128 # Evaluate byte-level accuracy python text_svae.py --mode eval --checkpoint /content/checkpoints/best.pt """ import os import torch import torch.nn.functional as F import numpy as np from datasets import load_dataset # ── Text ↔ Tensor Conversion ──────────────────────────────────── def text_to_tensor(text: str, H: int = 128, W: int = 128, C: int = 3) -> torch.Tensor: """Pack UTF-8 bytes into a (C, H, W) tensor normalized to [-1, 1]. Args: text: Raw string H, W: Spatial dimensions C: Channels (3 for compatibility with existing models) Returns: (C, H, W) float tensor in [-1, 1] int: actual byte count before padding """ raw = text.encode('utf-8') n_bytes = C * H * W actual_len = min(len(raw), n_bytes) # Pad or truncate if len(raw) < n_bytes: raw = raw + b'\x00' * (n_bytes - len(raw)) else: raw = raw[:n_bytes] # To tensor: bytes 0-255 → normalized [-1, 1] arr = np.frombuffer(raw, dtype=np.uint8).copy() tensor = torch.from_numpy(arr).float() tensor = (tensor / 127.5) - 1.0 # [0,255] → [-1, 1] tensor = tensor.reshape(C, H, W) return tensor, actual_len def tensor_to_text(tensor: torch.Tensor, byte_count: int = None) -> str: """Unpack a (C, H, W) tensor back to UTF-8 text. Args: tensor: (C, H, W) float tensor in [-1, 1] byte_count: if known, truncate to actual text length Returns: Decoded string (lossy characters replaced with ?) """ # Denormalize: [-1, 1] → [0, 255] arr = ((tensor.flatten() + 1.0) * 127.5).round().clamp(0, 255).byte().numpy() if byte_count is not None: arr = arr[:byte_count] # Decode with error handling return arr.tobytes().rstrip(b'\x00').decode('utf-8', errors='replace') # ── Text Dataset ───────────────────────────────────────────────── class TextAsImageDataset(torch.utils.data.Dataset): """Stream text from HuggingFace, pack as images. Each sample is a chunk of text rasterized to (3, H, W). The model sees it as a regular image. """ def __init__(self, size=100000, img_size=128, text_field='text', dataset_name='wikitext', dataset_config='wikitext-103-raw-v1', split='train'): self.size = size self.img_size = img_size self.n_bytes = 3 * img_size * img_size self.text_field = text_field # Load and concatenate all text into one big buffer print(f" Loading text corpus: {dataset_name}/{dataset_config}...") ds = load_dataset(dataset_name, dataset_config, split=split) all_text = '\n'.join([x[text_field] for x in ds if x[text_field].strip()]) self.raw_bytes = all_text.encode('utf-8') print(f" Corpus: {len(self.raw_bytes):,} bytes ({len(self.raw_bytes)/1024/1024:.1f}MB)") def __len__(self): return self.size def __getitem__(self, idx): # Random chunk from corpus max_start = len(self.raw_bytes) - self.n_bytes if max_start <= 0: start = 0 else: start = torch.randint(0, max_start, (1,)).item() chunk = self.raw_bytes[start:start + self.n_bytes] if len(chunk) < self.n_bytes: chunk = chunk + b'\x00' * (self.n_bytes - len(chunk)) arr = np.frombuffer(chunk, dtype=np.uint8).copy() tensor = torch.from_numpy(arr).float() tensor = (tensor / 127.5) - 1.0 tensor = tensor.reshape(3, self.img_size, self.img_size) return tensor, 0 # label placeholder def get_text_loaders(batch_size=128, img_size=128, train_size=100000, val_size=5000): """Text-as-image data loaders from WikiText-103.""" train_ds = TextAsImageDataset(size=train_size, img_size=img_size, split='train') val_ds = TextAsImageDataset(size=val_size, img_size=img_size, split='validation') train_loader = torch.utils.data.DataLoader( train_ds, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True) val_loader = torch.utils.data.DataLoader( val_ds, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True) return train_loader, val_loader # ── Zero-shot Test (existing model on text) ────────────────────── def test_zero_shot(model_repo="AbstractPhil/svae-fresnel-128", device='cuda'): """Test an existing Fresnel/Johanna model on text without any text training.""" from transformers import AutoModel print(f"\n{'='*70}") print(f"ZERO-SHOT TEXT TEST: {model_repo}") print(f"{'='*70}") model = AutoModel.from_pretrained(model_repo, trust_remote_code=True).to(device).eval() test_texts = [ "Hello, world! This is a test of the Fresnel geometric encoder.", "The quick brown fox jumps over the lazy dog. 0123456789", "import torch\nmodel = AutoModel.from_pretrained('test')\noutput = model(x)", "To be, or not to be, that is the question.", "E = mc² — Albert Einstein, 1905", ] img_size = model.config.image_size print(f" Image size: {img_size}×{img_size}") print(f" Bytes per sample: {3 * img_size * img_size}") for text in test_texts: tensor, actual_len = text_to_tensor(text, img_size, img_size) tensor = tensor.unsqueeze(0).to(device) with torch.no_grad(): output = model(tensor) recon = output["recon"] mse = F.mse_loss(recon, tensor).item() recovered = tensor_to_text(recon.squeeze(0).cpu(), actual_len) # Byte accuracy orig_bytes = ((tensor.squeeze(0).flatten() + 1.0) * 127.5).round().clamp(0, 255).byte() recon_bytes = ((recon.squeeze(0).cpu().flatten() + 1.0) * 127.5).round().clamp(0, 255).byte() exact_match = (orig_bytes[:actual_len] == recon_bytes[:actual_len]).float().mean().item() print(f"\n Input: '{text[:80]}{'...' if len(text) > 80 else ''}'") print(f" Output: '{recovered[:80]}{'...' if len(recovered) > 80 else ''}'") print(f" MSE: {mse:.6f}") print(f" Byte acc: {exact_match*100:.1f}%") print(f" Exact: {'YES' if recovered.rstrip('\x00') == text else 'NO'}") # ── Evaluation ─────────────────────────────────────────────────── def eval_text_reconstruction(model, test_loader, device='cuda', n_batches=10): """Evaluate byte-level reconstruction accuracy on text.""" model.eval() total_mse, total_acc, n = 0, 0, 0 with torch.no_grad(): for i, (images, _) in enumerate(test_loader): if i >= n_batches: break images = images.to(device) out = model(images) recon = out['recon'] mse = F.mse_loss(recon, images).item() # Byte accuracy orig_bytes = ((images.flatten(1) + 1.0) * 127.5).round().clamp(0, 255).long() recon_bytes = ((recon.flatten(1) + 1.0) * 127.5).round().clamp(0, 255).long() acc = (orig_bytes == recon_bytes).float().mean().item() total_mse += mse * len(images) total_acc += acc * len(images) n += len(images) return { 'mse': total_mse / n, 'byte_accuracy': total_acc / n, 'n_samples': n, } if __name__ == "__main__": import argparse parser = argparse.ArgumentParser() parser.add_argument('--mode', default='test', choices=['test', 'eval']) parser.add_argument('--model', default='AbstractPhil/svae-fresnel-128') parser.add_argument('--device', default='cuda') args = parser.parse_args() if args.mode == 'test': test_zero_shot(args.model, args.device)