| import torch |
| import torch.nn.functional as F |
| from src.model import RippleGPT |
| from src.config import RippleConfig |
| import os |
| import pickle |
|
|
# Use the MPS backend when PyTorch reports it as available; otherwise stay on CPU.
device = 'cpu'
if torch.backends.mps.is_available():
    device = 'mps'
|
|
def load_model(ckpt_path):
    """Rebuild a RippleGPT model from a checkpoint and put it in eval mode.

    The checkpoint is read with ``torch.load`` onto the module-level
    ``device``. Its ``'model_args'`` entry is expanded into a ``RippleConfig``
    and its ``'model'`` entry is used as the state dict.

    Args:
        ckpt_path: Path of the checkpoint file to load.

    Returns:
        The ``RippleGPT`` instance with the checkpoint weights loaded, moved
        to ``device`` and switched to evaluation mode.
    """
    # NOTE: weights_only=False allows torch.load to unpickle arbitrary
    # objects; only use this with checkpoints from a trusted source.
    ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)

    cfg = RippleConfig(**ckpt['model_args'])
    # Override the stored block size before building the model -- presumably
    # so evaluation can run on contexts longer than the trained size.
    cfg.block_size = 1024
    model = RippleGPT(cfg)

    # Rename, in place, every key that carries the '_orig_mod.' prefix
    # (presumably left by a torch.compile wrapper -- confirm against training).
    prefix = '_orig_mod.'
    weights = ckpt['model']
    for name in list(weights):
        if not name.startswith(prefix):
            continue
        weights[name[len(prefix):]] = weights.pop(name)

    model.load_state_dict(weights)
    model.to(device)
    model.eval()
    return model
|
|
def measure_perplexity(model, data_tensor, context_len, max_batches=10):
    """Measure how surprised the model is on held-out data. Lower is better.

    The data is cut into consecutive, non-overlapping windows of
    ``context_len`` tokens starting at index 0. Each window is fed to the
    model with targets shifted by one position, and the returned losses are
    averaged. Intended to be called with a context length larger than the
    one used in training.

    Args:
        model: Callable as ``model(x, y)`` returning a pair whose second
            element is a scalar loss tensor (presumably mean cross-entropy --
            confirm against the model).
        data_tensor: 1-D sequence of token ids supporting ``len`` and slicing.
        context_len: Number of tokens per evaluation window.
        max_batches: Upper bound on the number of windows evaluated.

    Returns:
        ``(avg_loss, perplexity)`` where ``perplexity = exp(avg_loss)``.

    Raises:
        ValueError: If no complete window (``context_len + 1`` tokens) could
            be evaluated, so no meaningful average exists.
    """
    total_loss = 0.0
    evaluated = 0
    with torch.no_grad():
        for i in range(max_batches):
            start = i * context_len
            stop = start + context_len
            # One extra token is needed for the shifted targets.
            if stop + 1 > len(data_tensor):
                break

            x = data_tensor[start:stop].unsqueeze(0).to(device)
            y = data_tensor[start + 1:stop + 1].unsqueeze(0).to(device)

            _, loss = model(x, y)
            total_loss += loss.item()
            evaluated += 1

    if evaluated == 0:
        raise ValueError(
            f"not enough data for a single window of {context_len} tokens "
            f"(have {len(data_tensor)}, need at least {context_len + 1})"
        )

    # Average over the windows actually evaluated; the loop may stop early
    # when the data runs out, so dividing by max_batches would under-report.
    avg_loss = total_loss / evaluated
    perplexity = torch.exp(torch.tensor(avg_loss))
    return avg_loss, perplexity.item()
|
|
| |
print("Loading data...")
dataset_dir = 'data'
val_data_path = os.path.join(dataset_dir, 'val.bin')
meta_path = os.path.join(dataset_dir, 'meta.pkl')

# NOTE(review): meta.pkl must exist for this branch to be taken, but it is
# not read here -- confirm whether the vocabulary is needed.
if os.path.exists(val_data_path) and os.path.exists(meta_path):
    print(f"Loading official validation data from {val_data_path}...")
    import numpy as np
    # The file is read as a flat array of uint16 values and widened to int64
    # before being wrapped in a tensor.
    val_data_np = np.fromfile(val_data_path, dtype=np.uint16)
    val_data = torch.from_numpy(val_data_np.astype(np.int64))
else:
    print("Official validation data not found. Downloading tinyshakespeare for demo...")

    import requests
    # A timeout keeps the script from hanging forever on a stalled connection,
    # and raise_for_status stops an HTTP error page from being tokenized.
    response = requests.get(
        "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt",
        timeout=30,
    )
    response.raise_for_status()
    text = response.text

    # Character-level vocabulary built from the downloaded text itself.
    # NOTE(review): these ids may not match the checkpoint's vocabulary --
    # verify before trusting the fallback numbers.
    chars = sorted(set(text))
    stoi = {ch: i for i, ch in enumerate(chars)}

    def encode(s):
        """Map each character of ``s`` to its integer id in ``stoi``."""
        return [stoi[c] for c in s]

    # The last 10% of the text serves as the validation split.
    val_data = torch.tensor(encode(text[int(0.9 * len(text)):]), dtype=torch.long)
|
|
| |
print("Loading RippleGPT...")
# Prefer the "best" checkpoint when it exists; otherwise use the regular one.
if os.path.exists('out/ckpt_best.pt'):
    ckpt_path = 'out/ckpt_best.pt'
else:
    ckpt_path = 'out/ckpt.pt'
print(f"Loading checkpoint from {ckpt_path}")
model = load_model(ckpt_path)
|
|
| |
# Baseline: evaluate at a context of 256 tokens. A failure here is left
# uncaught because the remaining comparisons depend on this baseline.
loss_256, ppl_256 = measure_perplexity(model, val_data, 256)
print(f"Context 256 (Trained size): Loss {loss_256:.4f}, Perplexity {ppl_256:.2f}")

# Extrapolation to a context of 512 tokens. Only the evaluation call is
# guarded; the success messages are printed from the else branch.
try:
    loss_512, ppl_512 = measure_perplexity(model, val_data, 512)
except Exception as e:
    print(f"❌ EXTRAPOLATION FAILED: {e}")
else:
    print(f"Context 512 (2x Training): Loss {loss_512:.4f}, Perplexity {ppl_512:.2f}")
    print("✅ EXTRAPOLATION SUCCESSFUL: Model handled 2x context length!")

# Extrapolation to a context of 1024 tokens. This remains best-effort, but a
# failure is now reported instead of being silently discarded.
try:
    loss_1024, ppl_1024 = measure_perplexity(model, val_data, 1024)
except Exception as e:
    print(f"❌ Context 1024 (4x Training) FAILED: {e}")
else:
    print(f"Context 1024 (4x Training): Loss {loss_1024:.4f}, Perplexity {ppl_1024:.2f}")
|
|