File size: 8,863 Bytes
3c68cf6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
"""
Text-as-Image — Universal encoding via byte rasterization
============================================================
Text → raw UTF-8 bytes → pack into (3, H, W) → PatchSVAE → reconstruct → unpack → text

The model never knows it's looking at text. Bytes 0-255 are pixel values.
Every character in every language (anything UTF-8 can encode) becomes just
numbers in a grid.

For 64×64:   3 × 64 × 64   = 12,288 bytes ≈ 12KB of text per sample
For 128×128: 3 × 128 × 128 = 49,152 bytes ≈ 48KB of text per sample

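Usage:
    # Test existing Fresnel/Johanna on text (zero-shot)
    python text_svae.py --mode test --model AbstractPhil/svae-fresnel-128

    # Train a text-specific variant
    python text_svae.py --mode train --size 128

    # Evaluate byte-level accuracy
    python text_svae.py --mode eval --checkpoint /content/checkpoints/best.pt

Note: the command-line parser at the bottom of this file accepts only
``--mode test`` and ``--mode eval``. The ``train`` mode and the ``--size`` /
``--checkpoint`` options shown above are not implemented in this script.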
"""

import os
import torch
import torch.nn.functional as F
import numpy as np
from datasets import load_dataset


# ── Text ↔ Tensor Conversion ────────────────────────────────────

def text_to_tensor(text: str, H: int = 128, W: int = 128, C: int = 3) -> tuple[torch.Tensor, int]:
    """Pack the UTF-8 bytes of *text* into a (C, H, W) tensor scaled to [-1, 1].

    The bytes fill the (C, H, W) grid in row-major order (a plain reshape of
    the flat byte buffer). Text longer than ``C * H * W`` bytes is truncated,
    backing off to the nearest UTF-8 character boundary so that no multi-byte
    character is split. Shorter text is right-padded with zero bytes.

    Args:
        text: Raw string.
        H, W: Spatial dimensions.
        C: Channels (3 for compatibility with existing models).

    Returns:
        A ``(tensor, byte_count)`` pair: the (C, H, W) float tensor in
        [-1, 1], and the number of text bytes stored before padding.
    """
    raw = text.encode('utf-8')
    n_bytes = C * H * W

    if len(raw) > n_bytes:
        cut = n_bytes
        # raw[cut] is the first byte that does not fit. If it is a UTF-8
        # continuation byte (0b10xxxxxx), the character it belongs to began
        # inside the kept range, so back up to that character's lead byte
        # and drop the whole character instead of keeping a broken prefix.
        while cut > 0 and (raw[cut] & 0xC0) == 0x80:
            cut -= 1
        raw = raw[:cut]

    actual_len = len(raw)
    raw = raw.ljust(n_bytes, b'\x00')  # zero-pad up to the full grid

    # Bytes 0-255 → float pixels in [-1, 1]
    arr = np.frombuffer(raw, dtype=np.uint8).copy()
    tensor = torch.from_numpy(arr).float() / 127.5 - 1.0
    return tensor.reshape(C, H, W), actual_len


def tensor_to_text(tensor: torch.Tensor, byte_count: int = None) -> str:
    """Unpack a tensor of byte "pixels" back into UTF-8 text.

    Each value is mapped from [-1, 1] back to [0, 255] (rounded, then
    clamped). The values are flattened in row-major order and optionally cut
    to *byte_count*. Trailing zero bytes (the padding) are stripped before
    decoding as UTF-8.

    Args:
        tensor: Float tensor in [-1, 1], normally shaped (C, H, W). It may
            live on any device and may require grad.
        byte_count: If known, the number of leading bytes that hold real text.

    Returns:
        Decoded string. Invalid UTF-8 sequences become U+FFFD, the Unicode
        replacement character.
    """
    # detach()/cpu() so GPU outputs and autograd tensors can reach .numpy().
    scaled = (tensor.detach().cpu().flatten() + 1.0) * 127.5
    arr = scaled.round().clamp(0, 255).byte().numpy()

    if byte_count is not None:
        arr = arr[:byte_count]

    return arr.tobytes().rstrip(b'\x00').decode('utf-8', errors='replace')


# ── Text Dataset ─────────────────────────────────────────────────

class TextAsImageDataset(torch.utils.data.Dataset):
    """Serve random fixed-size byte windows of a HuggingFace text corpus as images.

    The whole split is loaded once. Its non-blank rows are joined with
    newlines and UTF-8 encoded into a single byte buffer. Each item is a
    window of ``3 * img_size * img_size`` bytes starting at a uniformly random
    offset (``idx`` is ignored). The window is zero-padded if the corpus is
    shorter, scaled from [0, 255] to [-1, 1] and reshaped to
    (3, img_size, img_size). Windows may start or end in the middle of a
    multi-byte UTF-8 character. The model just sees a regular image.
    """
    def __init__(self, size=100000, img_size=128, text_field='text',
                 dataset_name='wikitext', dataset_config='wikitext-103-raw-v1',
                 split='train'):
        self.size = size  # value reported by __len__
        self.img_size = img_size
        self.n_bytes = 3 * img_size * img_size
        self.text_field = text_field

        # Load and concatenate all text into one big buffer
        print(f"  Loading text corpus: {dataset_name}/{dataset_config}...")
        ds = load_dataset(dataset_name, dataset_config, split=split)
        all_text = '\n'.join([x[text_field] for x in ds if x[text_field].strip()])
        self.raw_bytes = all_text.encode('utf-8')
        print(f"  Corpus: {len(self.raw_bytes):,} bytes ({len(self.raw_bytes)/1024/1024:.1f}MB)")

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        # Random window; `idx` is ignored. Valid start offsets are
        # 0..max_start inclusive, and torch.randint's upper bound is
        # exclusive, so pass max_start + 1 to make the last window reachable.
        max_start = len(self.raw_bytes) - self.n_bytes
        if max_start <= 0:
            start = 0
        else:
            start = torch.randint(0, max_start + 1, (1,)).item()

        chunk = self.raw_bytes[start:start + self.n_bytes].ljust(self.n_bytes, b'\x00')

        arr = np.frombuffer(chunk, dtype=np.uint8).copy()
        tensor = torch.from_numpy(arr).float() / 127.5 - 1.0  # [0, 255] → [-1, 1]
        tensor = tensor.reshape(3, self.img_size, self.img_size)

        return tensor, 0  # label placeholder


def get_text_loaders(batch_size=128, img_size=128,
                     train_size=100000, val_size=5000,
                     num_workers=2, pin_memory=True):
    """Build text-as-image train/validation DataLoaders from WikiText-103.

    Args:
        batch_size: Samples per batch for both loaders.
        img_size: Side length; each sample is (3, img_size, img_size).
        train_size, val_size: Lengths reported by the two datasets.
        num_workers: DataLoader worker processes for both loaders.
        pin_memory: Whether both loaders pin host memory.

    Returns:
        ``(train_loader, val_loader)``. Only the training loader shuffles.
    """
    train_ds = TextAsImageDataset(size=train_size, img_size=img_size,
                                  split='train')
    val_ds = TextAsImageDataset(size=val_size, img_size=img_size,
                                split='validation')

    # Settings shared by both loaders.
    loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers,
                         pin_memory=pin_memory)
    train_loader = torch.utils.data.DataLoader(train_ds, shuffle=True, **loader_kwargs)
    val_loader = torch.utils.data.DataLoader(val_ds, shuffle=False, **loader_kwargs)
    return train_loader, val_loader


# ── Zero-shot Test (existing model on text) ──────────────────────

def test_zero_shot(model_repo="AbstractPhil/svae-fresnel-128", device='cuda'):
    """Run an existing Fresnel/Johanna model on text without any text training.

    Loads *model_repo* with ``AutoModel.from_pretrained(...,
    trust_remote_code=True)``. It packs a few fixed strings with
    ``text_to_tensor`` at the model's ``config.image_size`` and reconstructs
    them from ``output["recon"]``. For each string it prints:

    - the input and the decoded output,
    - the MSE,
    - the exact-byte accuracy over the real (unpadded) text bytes,
    - whether the decoded text equals the input.

    Args:
        model_repo: HuggingFace model id.
        device: Torch device for the model and the inputs.
    """
    from transformers import AutoModel

    def to_bytes(pixels):
        # (1, C, H, W) pixels in [-1, 1] → flat uint8 byte values on the CPU.
        # Both sides of the comparison must be on the CPU: `==` between a CUDA
        # tensor and a CPU tensor raises a device-mismatch error.
        return ((pixels.squeeze(0).cpu().flatten() + 1.0) * 127.5).round().clamp(0, 255).byte()

    print(f"\n{'='*70}")
    print(f"ZERO-SHOT TEXT TEST: {model_repo}")
    print(f"{'='*70}")

    model = AutoModel.from_pretrained(model_repo, trust_remote_code=True).to(device).eval()

    test_texts = [
        "Hello, world! This is a test of the Fresnel geometric encoder.",
        "The quick brown fox jumps over the lazy dog. 0123456789",
        "import torch\nmodel = AutoModel.from_pretrained('test')\noutput = model(x)",
        "To be, or not to be, that is the question.",
        "E = mc² — Albert Einstein, 1905",  # includes multi-byte UTF-8 characters
    ]

    img_size = model.config.image_size
    print(f"  Image size: {img_size}×{img_size}")
    print(f"  Bytes per sample: {3 * img_size * img_size}")

    for text in test_texts:
        tensor, actual_len = text_to_tensor(text, img_size, img_size)
        tensor = tensor.unsqueeze(0).to(device)

        with torch.no_grad():
            output = model(tensor)
            recon = output["recon"]
            mse = F.mse_loss(recon, tensor).item()

        recovered = tensor_to_text(recon.squeeze(0).cpu(), actual_len)

        # Byte accuracy over the real text bytes only (padding excluded).
        orig_bytes = to_bytes(tensor)
        recon_bytes = to_bytes(recon)
        exact_match = (orig_bytes[:actual_len] == recon_bytes[:actual_len]).float().mean().item()

        # Computed outside the f-string: a backslash inside an f-string
        # expression is a SyntaxError before Python 3.12.
        is_exact = recovered.rstrip('\x00') == text

        print(f"\n  Input:    '{text[:80]}{'...' if len(text) > 80 else ''}'")
        print(f"  Output:   '{recovered[:80]}{'...' if len(recovered) > 80 else ''}'")
        print(f"  MSE:      {mse:.6f}")
        print(f"  Byte acc: {exact_match*100:.1f}%")
        print(f"  Exact:    {'YES' if is_exact else 'NO'}")


# ── Evaluation ───────────────────────────────────────────────────

def eval_text_reconstruction(model, test_loader, device='cuda', n_batches=10):
    """Measure reconstruction MSE and exact-byte accuracy of *model* on text images.

    Runs at most *n_batches* ``(images, _)`` batches from *test_loader*
    through the model under ``torch.no_grad()`` and reads ``out['recon']``.
    Both metrics are per-batch means, weighted by batch size. Byte accuracy
    maps every element from [-1, 1] back to rounded [0, 255] values and
    compares them all, including any zero padding. The model is put into
    (and left in) eval mode.

    Args:
        model: Callable module returning a dict with a ``'recon'`` tensor
            shaped like its input.
        test_loader: Iterable of ``(images, label)`` pairs.
        device: Device the images are moved to.
        n_batches: Maximum number of batches to evaluate.

    Returns:
        dict with ``'mse'``, ``'byte_accuracy'`` and ``'n_samples'``. When no
        samples are seen (empty loader or ``n_batches <= 0``), both metrics
        are NaN rather than raising ZeroDivisionError.
    """
    def to_bytes(pixels):
        # Per-sample flattened [-1, 1] pixels → integer byte values 0..255.
        return ((pixels.flatten(1) + 1.0) * 127.5).round().clamp(0, 255).long()

    model.eval()
    total_mse, total_acc, n = 0.0, 0.0, 0

    with torch.no_grad():
        for i, (images, _) in enumerate(test_loader):
            if i >= n_batches:
                break
            images = images.to(device)
            recon = model(images)['recon']

            mse = F.mse_loss(recon, images).item()
            acc = (to_bytes(images) == to_bytes(recon)).float().mean().item()

            batch = len(images)
            total_mse += mse * batch
            total_acc += acc * batch
            n += batch

    if n == 0:
        nan = float('nan')
        return {'mse': nan, 'byte_accuracy': nan, 'n_samples': 0}

    return {
        'mse': total_mse / n,
        'byte_accuracy': total_acc / n,
        'n_samples': n,
    }


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', default='test', choices=['test', 'eval'])
    parser.add_argument('--model', default='AbstractPhil/svae-fresnel-128')
    parser.add_argument('--device', default='cuda')
    parser.add_argument('--batch-size', type=int, default=32,
                        help="eval mode: samples per batch")
    parser.add_argument('--n-batches', type=int, default=10,
                        help="eval mode: number of batches to evaluate")
    args = parser.parse_args()

    if args.mode == 'test':
        test_zero_shot(args.model, args.device)
    elif args.mode == 'eval':
        # Byte-level reconstruction accuracy of the pretrained model on the
        # WikiText-103 validation split, packed at the model's image size.
        from transformers import AutoModel
        model = AutoModel.from_pretrained(args.model, trust_remote_code=True).to(args.device)
        img_size = model.config.image_size
        val_ds = TextAsImageDataset(size=args.batch_size * args.n_batches,
                                    img_size=img_size, split='validation')
        val_loader = torch.utils.data.DataLoader(val_ds, batch_size=args.batch_size)
        metrics = eval_text_reconstruction(model, val_loader, args.device, args.n_batches)
        print(f"  MSE:      {metrics['mse']:.6f}")
        print(f"  Byte acc: {metrics['byte_accuracy']*100:.2f}%")
        print(f"  Samples:  {metrics['n_samples']}")