| """ |
| Reference PyTorch implementation for loading and running the WASM Interpreter Transformer. |
| |
| Usage: |
| python load_model.py |
| |
| This loads the safetensors weights and verifies the model architecture. |
| The model uses three attention modes: |
| - Hard-max for retrieval |
| - Sum-mode for depth accumulation |
| - Cross-attention for filesystem reads |
| """ |
|
|
| import json |
| import torch |
| import torch.nn.functional as F |
| from safetensors.torch import load_file |
|
|
|
|
class WasmTransformerConfig:
    """Architecture hyperparameters loaded from a HuggingFace-style JSON config."""

    def __init__(self, config_path="config.json"):
        """Parse *config_path* and expose the fields needed for inference."""
        with open(config_path) as fh:
            raw = json.load(fh)
        self.vocab_size = raw["vocab_size"]
        self.d_model = raw["hidden_size"]
        self.n_layers = raw["num_hidden_layers"]
        self.d_ffn = raw["intermediate_size"]
        self.d_h = raw["head_dim"]
        self.heads_per_layer: list[int] = raw.get("heads_per_layer", [13, 1, 0, 8, 2, 1, 1, 4])
        self.sum_attention_heads: dict[str, list[int]] = raw.get("sum_attention_heads", {})
        self.cross_attention_heads: dict[str, list[int]] = raw.get("cross_attention_heads", {})

    def is_sum_head(self, layer: int, head: int) -> bool:
        """True when head *head* of layer *layer* runs in sum-accumulation mode."""
        return head in self.sum_attention_heads.get(f"layer_{layer}", [])

    def is_cross_attention_head(self, layer: int, head: int) -> bool:
        """True when head *head* of layer *layer* attends over the filesystem KV store."""
        return head in self.cross_attention_heads.get(f"layer_{layer}", [])
|
|
|
|
| class WasmTransformer: |
| """Minimal inference implementation matching the TypeScript reference. |
| |
| Supports three attention modes: |
| - Hard-max (argmax): selects the best-matching key-value pair |
| - Sum-mode: accumulates all past value vectors (for cumulative depth tracking) |
| - Cross-attention: attends over external filesystem key-value store |
| """ |
|
|
| def __init__(self, config: WasmTransformerConfig, weights: dict[str, torch.Tensor]): |
| self.config = config |
| self.weights = weights |
| self.kv_keys: list[list[list[torch.Tensor]]] = [ |
| [[] for _ in range(config.heads_per_layer[l])] for l in range(config.n_layers) |
| ] |
| self.kv_values: list[list[list[torch.Tensor]]] = [ |
| [[] for _ in range(config.heads_per_layer[l])] for l in range(config.n_layers) |
| ] |
|
|
| def reset(self): |
| cfg = self.config |
| for l in range(cfg.n_layers): |
| for h in range(cfg.heads_per_layer[l]): |
| self.kv_keys[l][h] = [] |
| self.kv_values[l][h] = [] |
|
|
| def forward(self, token_id: int, pe: torch.Tensor, filesystem_kv: dict | None = None) -> int: |
| cfg = self.config |
| w = self.weights |
|
|
| x = w["embedding.weight"][token_id].clone() + pe |
|
|
| for l in range(cfg.n_layers): |
| attn_out = torch.zeros(cfg.d_model) |
| n_heads = cfg.heads_per_layer[l] |
|
|
| for h in range(n_heads): |
| prefix = f"layers.{l}.attention.heads.{h}" |
| wq = w[f"{prefix}.wq"] |
| wk = w[f"{prefix}.wk"] |
| wv = w[f"{prefix}.wv"] |
| wo = w[f"{prefix}.wo"] |
|
|
| q = wq @ x |
| k = wk @ x |
| v = wv @ x |
|
|
| if cfg.is_cross_attention_head(l, h): |
| |
| if filesystem_kv is not None: |
| keys = filesystem_kv.get("keys", []) |
| values = filesystem_kv.get("values", []) |
| if keys: |
| best_score = float("-inf") |
| best_idx = -1 |
| for j, kj in enumerate(keys): |
| score = (q * kj).sum().item() |
| if score >= best_score: |
| best_score = score |
| best_idx = j |
| if best_idx >= 0: |
| attn_out += wo @ values[best_idx] |
| else: |
| self.kv_keys[l][h].append(k.clone()) |
| self.kv_values[l][h].append(v.clone()) |
|
|
| num_keys = len(self.kv_keys[l][h]) |
|
|
| if cfg.is_sum_head(l, h): |
| acc = torch.zeros(cfg.d_h) |
| for j in range(num_keys - 1): |
| acc += self.kv_values[l][h][j] |
| attn_out += wo @ acc |
| else: |
| best_score = float("-inf") |
| best_idx = -1 |
| for j, kj in enumerate(self.kv_keys[l][h]): |
| score = (q * kj).sum().item() |
| if score >= best_score: |
| best_score = score |
| best_idx = j |
|
|
| if best_idx >= 0: |
| best_v = self.kv_values[l][h][best_idx] |
| attn_out += wo @ best_v |
|
|
| x = x + attn_out |
|
|
| |
| gate_w = w[f"layers.{l}.ffn.gate_weight"] |
| val_w = w[f"layers.{l}.ffn.value_weight"] |
| out_w = w[f"layers.{l}.ffn.output_weight"] |
|
|
| gate = gate_w @ x |
| val = val_w @ x |
| activated = F.relu(gate) * val |
| x = x + out_w @ activated |
|
|
| logits = w["unembed.weight"] @ x |
| return logits.argmax().item() |
|
|
|
|
def main() -> None:
    """Load config.json and model.safetensors, then print an architecture summary.

    Pure smoke check: verifies the config and weight file parse and reports the
    per-layer head layout (including sum-mode and cross-attention heads). No
    tokens are executed here — inference lives in the TypeScript reference.
    """
    print("Loading WASM Interpreter Transformer...")
    config = WasmTransformerConfig()
    weights = load_file("model.safetensors")
    # Constructing the model also allocates the per-head KV cache, which
    # exercises the head-count configuration.
    model = WasmTransformer(config, weights)

    print(f" Architecture: d_model={config.d_model}, n_layers={config.n_layers}")
    print(f" Heads per layer: {config.heads_per_layer}")
    total = sum(config.heads_per_layer)
    print(f" Total attention heads: {total}")
    print(f" Loaded {len(weights)} weight tensors")

    sum_heads_info = [
        f" {layer_key}: heads {heads} (sum-mode)"
        for layer_key, heads in config.sum_attention_heads.items()
    ]
    if sum_heads_info:
        print(" Sum-attention heads:")
        for info in sum_heads_info:
            print(f" {info}")

    cross_attn_info = [
        f" {layer_key}: heads {heads} (cross-attention / filesystem)"
        for layer_key, heads in config.cross_attention_heads.items()
    ]
    if cross_attn_info:
        print(" Cross-attention heads:")
        for info in cross_attn_info:
            print(f" {info}")
    print()

    print("Model loaded successfully!")
    print("To run programs, use the TypeScript reference implementation")
    print("which includes the program analysis and positional encoding pipeline.")


if __name__ == "__main__":
    main()
|
|