# Uploaded to the Hugging Face Hub by eastlondoner
# Commit: "Super-squash branch 'main' using huggingface_hub" (a34f874)
"""
Reference PyTorch implementation for loading and running the WASM Interpreter Transformer.
Usage:
python load_model.py
This loads the safetensors weights and verifies the model architecture.
The model uses three attention modes:
- Hard-max for retrieval
- Sum-mode for depth accumulation
- Cross-attention for filesystem reads
"""
import json
import torch
import torch.nn.functional as F
from safetensors.torch import load_file
class WasmTransformerConfig:
    """Hyper-parameters for the WASM Interpreter Transformer.

    Loaded from a Hugging-Face-style ``config.json``; optional fields fall
    back to the reference model's layout.
    """

    def __init__(self, config_path="config.json"):
        with open(config_path) as fh:
            raw = json.load(fh)
        # Core dimensions (HF field names on disk, short names in memory).
        self.vocab_size = raw["vocab_size"]
        self.d_model = raw["hidden_size"]
        self.n_layers = raw["num_hidden_layers"]
        self.d_ffn = raw["intermediate_size"]
        self.d_h = raw["head_dim"]
        # Per-layer head counts plus the sets of specialised heads, keyed "layer_<i>".
        self.heads_per_layer: list[int] = raw.get("heads_per_layer", [13, 1, 0, 8, 2, 1, 1, 4])
        self.sum_attention_heads: dict[str, list[int]] = raw.get("sum_attention_heads", {})
        self.cross_attention_heads: dict[str, list[int]] = raw.get("cross_attention_heads", {})

    def is_sum_head(self, layer: int, head: int) -> bool:
        """True when this (layer, head) runs in sum (value-accumulation) mode."""
        return head in self.sum_attention_heads.get(f"layer_{layer}", [])

    def is_cross_attention_head(self, layer: int, head: int) -> bool:
        """True when this (layer, head) attends over the external filesystem store."""
        return head in self.cross_attention_heads.get(f"layer_{layer}", [])
class WasmTransformer:
    """Minimal inference implementation matching the TypeScript reference.

    Supports three attention modes:
    - Hard-max (argmax): selects the best-matching key-value pair
    - Sum-mode: accumulates all past value vectors (for cumulative depth tracking)
    - Cross-attention: attends over external filesystem key-value store
    """

    def __init__(self, config: WasmTransformerConfig, weights: dict[str, torch.Tensor]):
        # Weight tensors are looked up by name, e.g. "layers.0.attention.heads.0.wq".
        self.config = config
        self.weights = weights
        # Per-layer, per-head KV caches: kv_keys[layer][head] is the list of key
        # vectors appended so far (one per processed token); kv_values holds the
        # matching value vectors in the same order.
        self.kv_keys: list[list[list[torch.Tensor]]] = [
            [[] for _ in range(config.heads_per_layer[l])] for l in range(config.n_layers)
        ]
        self.kv_values: list[list[list[torch.Tensor]]] = [
            [[] for _ in range(config.heads_per_layer[l])] for l in range(config.n_layers)
        ]

    def reset(self) -> None:
        # Drop every head's cached keys/values so a new token stream starts fresh.
        cfg = self.config
        for l in range(cfg.n_layers):
            for h in range(cfg.heads_per_layer[l]):
                self.kv_keys[l][h] = []
                self.kv_values[l][h] = []

    def forward(self, token_id: int, pe: torch.Tensor, filesystem_kv: dict | None = None) -> int:
        """Run one decode step and return the greedy (argmax) next token id.

        Args:
            token_id: row index into the embedding table for the current token.
            pe: positional-encoding vector added to the token embedding
                (assumed shape (d_model,) — TODO confirm against the caller's
                positional-encoding pipeline).
            filesystem_kv: optional external store read by cross-attention
                heads; expected schema is {"keys": [...], "values": [...]}
                with parallel lists of tensors (presumably 1-D, sizes d_h —
                verify against the TS reference).

        Side effects: appends this token's key/value to every self-attention
        head's cache (cross-attention heads do not write to the cache).
        """
        cfg = self.config
        w = self.weights
        # Token embedding plus externally supplied positional encoding.
        x = w["embedding.weight"][token_id].clone() + pe
        for l in range(cfg.n_layers):
            # Contributions from all heads in this layer are summed, then
            # added to the residual stream in one step.
            attn_out = torch.zeros(cfg.d_model)
            n_heads = cfg.heads_per_layer[l]
            for h in range(n_heads):
                prefix = f"layers.{l}.attention.heads.{h}"
                wq = w[f"{prefix}.wq"]
                wk = w[f"{prefix}.wk"]
                wv = w[f"{prefix}.wv"]
                wo = w[f"{prefix}.wo"]
                q = wq @ x
                k = wk @ x
                v = wv @ x
                if cfg.is_cross_attention_head(l, h):
                    # Cross-attention: hard-max over the external filesystem KV
                    # store (no softmax; nothing appended to the local cache).
                    # Missing/empty store means the head contributes nothing.
                    if filesystem_kv is not None:
                        keys = filesystem_kv.get("keys", [])
                        values = filesystem_kv.get("values", [])
                        if keys:
                            best_score = float("-inf")
                            best_idx = -1
                            for j, kj in enumerate(keys):
                                score = (q * kj).sum().item()
                                # `>=` means ties resolve to the LAST best-scoring entry.
                                if score >= best_score:
                                    best_score = score
                                    best_idx = j
                            if best_idx >= 0:
                                attn_out += wo @ values[best_idx]
                else:
                    # Self-attention: record this token's key/value first.
                    self.kv_keys[l][h].append(k.clone())
                    self.kv_values[l][h].append(v.clone())
                    num_keys = len(self.kv_keys[l][h])
                    if cfg.is_sum_head(l, h):
                        # Sum mode: accumulate the value vectors of all PAST
                        # tokens; range(num_keys - 1) excludes the entry just
                        # appended for the current token, matching "past value
                        # vectors" in the class docstring — confirm against the
                        # TS reference if this ever looks off by one.
                        acc = torch.zeros(cfg.d_h)
                        for j in range(num_keys - 1):
                            acc += self.kv_values[l][h][j]
                        attn_out += wo @ acc
                    else:
                        # Hard-max retrieval over the cache (INCLUDES the
                        # current token); `>=` resolves ties to the latest entry.
                        best_score = float("-inf")
                        best_idx = -1
                        for j, kj in enumerate(self.kv_keys[l][h]):
                            score = (q * kj).sum().item()
                            if score >= best_score:
                                best_score = score
                                best_idx = j
                        if best_idx >= 0:
                            best_v = self.kv_values[l][h][best_idx]
                            attn_out += wo @ best_v
            x = x + attn_out
            # Gated FFN: labelled "SwiGLU" in the reference, but NOTE the gate
            # activation here is ReLU, not SiLU (so strictly a ReGLU form).
            gate_w = w[f"layers.{l}.ffn.gate_weight"]
            val_w = w[f"layers.{l}.ffn.value_weight"]
            out_w = w[f"layers.{l}.ffn.output_weight"]
            gate = gate_w @ x
            val = val_w @ x
            activated = F.relu(gate) * val
            x = x + out_w @ activated
        # Greedy decode: unembed and take the argmax logit.
        logits = w["unembed.weight"] @ x
        return logits.argmax().item()
if __name__ == "__main__":
    # Load config + weights and print an architecture summary; actual program
    # execution lives in the TypeScript reference implementation.
    print("Loading WASM Interpreter Transformer...")
    config = WasmTransformerConfig()
    weights = load_file("model.safetensors")
    model = WasmTransformer(config, weights)

    print(f" Architecture: d_model={config.d_model}, n_layers={config.n_layers}")
    print(f" Heads per layer: {config.heads_per_layer}")
    head_total = sum(config.heads_per_layer)
    print(f" Total attention heads: {head_total}")
    print(f" Loaded {len(weights)} weight tensors")

    # Describe any specialised (non-hard-max) heads declared in the config.
    sum_heads_info = [
        f" {layer_key}: heads {heads} (sum-mode)"
        for layer_key, heads in config.sum_attention_heads.items()
    ]
    if sum_heads_info:
        print(" Sum-attention heads:")
        for info in sum_heads_info:
            print(f" {info}")

    cross_attn_info = [
        f" {layer_key}: heads {heads} (cross-attention / filesystem)"
        for layer_key, heads in config.cross_attention_heads.items()
    ]
    if cross_attn_info:
        print(" Cross-attention heads:")
        for info in cross_attn_info:
            print(f" {info}")

    print()
    print("Model loaded successfully!")
    print("To run programs, use the TypeScript reference implementation")
    print("which includes the program analysis and positional encoding pipeline.")