| """ |
Text-as-Image — Universal encoding via byte rasterization
============================================================
Text → raw UTF-8 bytes → pack into (3, H, W) → PatchSVAE → reconstruct → unpack → text

The model never knows it's looking at text. Bytes 0-255 are pixel values.
Every character, every language, every encoding. Just numbers in a grid.

For 64×64: 3 × 64 × 64 = 12,288 bytes = 12KB of text per sample
For 128×128: 3 × 128 × 128 = 49,152 bytes = 48KB of text per sample
| |
| Usage: |
| # Test existing Fresnel/Johanna on text (zero-shot) |
| python text_svae.py --mode test --model AbstractPhil/svae-fresnel-128 |
| |
| # Train a text-specific variant |
| python text_svae.py --mode train --size 128 |
| |
| # Evaluate byte-level accuracy |
| python text_svae.py --mode eval --checkpoint /content/checkpoints/best.pt |
| """ |
|
|
import os
from typing import Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from datasets import load_dataset
|
|
|
|
| |
|
|
def text_to_tensor(text: str, H: int = 128, W: int = 128, C: int = 3) -> Tuple[torch.Tensor, int]:
    """Pack UTF-8 bytes into a (C, H, W) tensor normalized to [-1, 1].

    Text longer than C*H*W bytes is truncated; shorter text is padded
    with NUL bytes (which map to -1.0 after normalization).

    Args:
        text: Raw string
        H, W: Spatial dimensions
        C: Channels (3 for compatibility with existing models)

    Returns:
        tensor: (C, H, W) float tensor in [-1, 1]
        actual_len: byte count of the text before padding/truncation

    Note: the original annotation claimed a bare torch.Tensor return,
    but the function has always returned a (tensor, length) pair.
    """
    raw = text.encode('utf-8')
    n_bytes = C * H * W
    actual_len = min(len(raw), n_bytes)

    # Pad with NUL or truncate so the buffer is exactly n_bytes long.
    if len(raw) < n_bytes:
        raw = raw + b'\x00' * (n_bytes - len(raw))
    else:
        raw = raw[:n_bytes]

    # .copy() because np.frombuffer returns a read-only view of `raw`.
    arr = np.frombuffer(raw, dtype=np.uint8).copy()
    tensor = torch.from_numpy(arr).float()
    tensor = (tensor / 127.5) - 1.0  # [0, 255] -> [-1, 1]
    tensor = tensor.reshape(C, H, W)

    return tensor, actual_len
|
|
|
|
def tensor_to_text(tensor: torch.Tensor, byte_count: Optional[int] = None) -> str:
    """Unpack a (C, H, W) tensor back to UTF-8 text.

    Args:
        tensor: (C, H, W) float tensor in [-1, 1]
        byte_count: if known, truncate to the actual text length before
            decoding (trailing NUL padding is stripped either way)

    Returns:
        Decoded string; invalid UTF-8 sequences become U+FFFD
        (the `errors='replace'` replacement character, not '?').
    """
    # [-1, 1] -> [0, 255], rounded to the nearest byte value.
    arr = ((tensor.flatten() + 1.0) * 127.5).round().clamp(0, 255).byte().numpy()

    if byte_count is not None:
        arr = arr[:byte_count]

    # Strip NUL padding; decode leniently since reconstruction is lossy.
    return arr.tobytes().rstrip(b'\x00').decode('utf-8', errors='replace')
|
|
|
|
| |
|
|
class TextAsImageDataset(torch.utils.data.Dataset):
    """Stream text from HuggingFace, pack as images.

    Each sample is a random window into the concatenated corpus,
    rasterized to (3, H, W). The model sees it as a regular image.
    """

    def __init__(self, size=100000, img_size=128, text_field='text',
                 dataset_name='wikitext', dataset_config='wikitext-103-raw-v1',
                 split='train'):
        self.size = size
        self.img_size = img_size
        self.n_bytes = 3 * img_size * img_size
        self.text_field = text_field

        # Join the whole split into one long byte string; __getitem__
        # slices random windows out of it.
        print(f" Loading text corpus: {dataset_name}/{dataset_config}...")
        corpus = load_dataset(dataset_name, dataset_config, split=split)
        joined = '\n'.join(row[text_field] for row in corpus if row[text_field].strip())
        self.raw_bytes = joined.encode('utf-8')
        print(f" Corpus: {len(self.raw_bytes):,} bytes ({len(self.raw_bytes)/1024/1024:.1f}MB)")

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        # idx is ignored: every call samples a fresh random offset.
        last_start = len(self.raw_bytes) - self.n_bytes
        start = 0 if last_start <= 0 else torch.randint(0, last_start, (1,)).item()

        window = self.raw_bytes[start:start + self.n_bytes]
        shortfall = self.n_bytes - len(window)
        if shortfall > 0:
            window = window + b'\x00' * shortfall

        # uint8 bytes -> float pixels in [-1, 1].
        pixels = torch.from_numpy(np.frombuffer(window, dtype=np.uint8).copy()).float()
        pixels = (pixels / 127.5) - 1.0

        return pixels.reshape(3, self.img_size, self.img_size), 0
|
|
|
|
def get_text_loaders(batch_size=128, img_size=128,
                     train_size=100000, val_size=5000):
    """Text-as-image data loaders from WikiText-103."""
    # Build datasets first (train, then validation) so corpus-loading
    # messages print in the expected order.
    splits = {
        'train': TextAsImageDataset(size=train_size, img_size=img_size,
                                    split='train'),
        'validation': TextAsImageDataset(size=val_size, img_size=img_size,
                                         split='validation'),
    }

    def _wrap(ds, shuffle):
        # Same worker/pinning settings for both loaders.
        return torch.utils.data.DataLoader(
            ds, batch_size=batch_size, shuffle=shuffle,
            num_workers=2, pin_memory=True)

    return _wrap(splits['train'], True), _wrap(splits['validation'], False)
|
|
|
|
| |
|
|
def test_zero_shot(model_repo="AbstractPhil/svae-fresnel-128", device='cuda'):
    """Test an existing Fresnel/Johanna model on text without any text training.

    Rasterizes a few sample strings, runs them through the pretrained
    model, and prints MSE, byte-level accuracy, and exact-match status.

    Args:
        model_repo: HuggingFace repo id of the pretrained SVAE model.
        device: Torch device for inference.
    """
    from transformers import AutoModel

    print(f"\n{'='*70}")
    print(f"ZERO-SHOT TEXT TEST: {model_repo}")
    print(f"{'='*70}")

    model = AutoModel.from_pretrained(model_repo, trust_remote_code=True).to(device).eval()

    test_texts = [
        "Hello, world! This is a test of the Fresnel geometric encoder.",
        "The quick brown fox jumps over the lazy dog. 0123456789",
        "import torch\nmodel = AutoModel.from_pretrained('test')\noutput = model(x)",
        "To be, or not to be, that is the question.",
        # Multibyte UTF-8 stress case (was mojibake'd in a bad encoding round-trip).
        "E = mc² — Albert Einstein, 1905",
    ]

    # assumes model.config.image_size exists on the remote-code model — TODO confirm
    img_size = model.config.image_size
    print(f" Image size: {img_size}×{img_size}")
    print(f" Bytes per sample: {3 * img_size * img_size}")

    for text in test_texts:
        tensor, actual_len = text_to_tensor(text, img_size, img_size)
        tensor = tensor.unsqueeze(0).to(device)

        with torch.no_grad():
            output = model(tensor)
            recon = output["recon"]
            mse = F.mse_loss(recon, tensor).item()

        recovered = tensor_to_text(recon.squeeze(0).cpu(), actual_len)

        # Byte-exact accuracy over the meaningful (unpadded) prefix.
        orig_bytes = ((tensor.squeeze(0).flatten() + 1.0) * 127.5).round().clamp(0, 255).byte()
        recon_bytes = ((recon.squeeze(0).cpu().flatten() + 1.0) * 127.5).round().clamp(0, 255).byte()
        exact_match = (orig_bytes[:actual_len] == recon_bytes[:actual_len]).float().mean().item()

        # BUG FIX: a backslash ('\x00') inside an f-string expression is a
        # SyntaxError before Python 3.12 — compute the flag outside the f-string.
        is_exact = recovered.rstrip('\x00') == text

        print(f"\n Input: '{text[:80]}{'...' if len(text) > 80 else ''}'")
        print(f" Output: '{recovered[:80]}{'...' if len(recovered) > 80 else ''}'")
        print(f" MSE: {mse:.6f}")
        print(f" Byte acc: {exact_match*100:.1f}%")
        print(f" Exact: {'YES' if is_exact else 'NO'}")
|
|
|
|
| |
|
|
def eval_text_reconstruction(model, test_loader, device='cuda', n_batches=10):
    """Evaluate byte-level reconstruction accuracy on text.

    Args:
        model: Callable returning a dict with a 'recon' tensor; must
            also expose .eval().
        test_loader: Iterable yielding (images, labels) batches.
        device: Torch device for inference.
        n_batches: Maximum number of batches to evaluate.

    Returns:
        dict with 'mse', 'byte_accuracy', and 'n_samples'. When no
        batches are evaluated, mse and byte_accuracy are reported as 0.0.
    """
    model.eval()
    total_mse, total_acc, n = 0, 0, 0

    with torch.no_grad():
        for i, (images, _) in enumerate(test_loader):
            if i >= n_batches:
                break
            images = images.to(device)
            out = model(images)
            recon = out['recon']

            mse = F.mse_loss(recon, images).item()

            # Quantize both back to byte space and compare per position.
            orig_bytes = ((images.flatten(1) + 1.0) * 127.5).round().clamp(0, 255).long()
            recon_bytes = ((recon.flatten(1) + 1.0) * 127.5).round().clamp(0, 255).long()
            acc = (orig_bytes == recon_bytes).float().mean().item()

            total_mse += mse * len(images)
            total_acc += acc * len(images)
            n += len(images)

    # BUG FIX: guard against ZeroDivisionError when the loader is empty
    # or n_batches <= 0 (previously crashed with n == 0).
    denom = max(n, 1)
    return {
        'mse': total_mse / denom,
        'byte_accuracy': total_acc / denom,
        'n_samples': n,
    }
|
|
|
|
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', default='test', choices=['test', 'eval'])
    parser.add_argument('--model', default='AbstractPhil/svae-fresnel-128')
    parser.add_argument('--device', default='cuda')
    args = parser.parse_args()

    if args.mode == 'test':
        test_zero_shot(args.model, args.device)
    else:
        # BUG FIX: '--mode eval' was an accepted choice but silently did
        # nothing; fail loudly with guidance instead.
        parser.error("--mode eval is not wired up here; load a checkpoint and "
                     "call eval_text_reconstruction() directly.")