""" Reference PyTorch implementation for loading and running the WASM Interpreter Transformer. Usage: python load_model.py This loads the safetensors weights and verifies the model architecture. The model uses three attention modes: - Hard-max for retrieval - Sum-mode for depth accumulation - Cross-attention for filesystem reads """ import json import torch import torch.nn.functional as F from safetensors.torch import load_file class WasmTransformerConfig: def __init__(self, config_path="config.json"): with open(config_path) as f: cfg = json.load(f) self.vocab_size = cfg["vocab_size"] self.d_model = cfg["hidden_size"] self.n_layers = cfg["num_hidden_layers"] self.d_ffn = cfg["intermediate_size"] self.d_h = cfg["head_dim"] self.heads_per_layer: list[int] = cfg.get("heads_per_layer", [13, 1, 0, 8, 2, 1, 1, 4]) self.sum_attention_heads: dict[str, list[int]] = cfg.get("sum_attention_heads", {}) self.cross_attention_heads: dict[str, list[int]] = cfg.get("cross_attention_heads", {}) def is_sum_head(self, layer: int, head: int) -> bool: key = f"layer_{layer}" return head in self.sum_attention_heads.get(key, []) def is_cross_attention_head(self, layer: int, head: int) -> bool: key = f"layer_{layer}" return head in self.cross_attention_heads.get(key, []) class WasmTransformer: """Minimal inference implementation matching the TypeScript reference. Supports three attention modes: - Hard-max (argmax): selects the best-matching key-value pair - Sum-mode: accumulates all past value vectors (for cumulative depth tracking) - Cross-attention: attends over external filesystem key-value store """ def __init__(self, config: WasmTransformerConfig, weights: dict[str, torch.Tensor]): self.config = config self.weights = weights self.kv_keys: list[list[list[torch.Tensor]]] = [ [[] for _ in range(config.heads_per_layer[l])] for l in range(config.n_layers) ] self.kv_values: list[list[list[torch.Tensor]]] = [ [[] for _ in range(config.heads_per_layer[l])] for l in range(config.n_layers) ] def reset(self): cfg = self.config for l in range(cfg.n_layers): for h in range(cfg.heads_per_layer[l]): self.kv_keys[l][h] = [] self.kv_values[l][h] = [] def forward(self, token_id: int, pe: torch.Tensor, filesystem_kv: dict | None = None) -> int: cfg = self.config w = self.weights x = w["embedding.weight"][token_id].clone() + pe for l in range(cfg.n_layers): attn_out = torch.zeros(cfg.d_model) n_heads = cfg.heads_per_layer[l] for h in range(n_heads): prefix = f"layers.{l}.attention.heads.{h}" wq = w[f"{prefix}.wq"] wk = w[f"{prefix}.wk"] wv = w[f"{prefix}.wv"] wo = w[f"{prefix}.wo"] q = wq @ x k = wk @ x v = wv @ x if cfg.is_cross_attention_head(l, h): # Cross-attention: attend over external filesystem KV store if filesystem_kv is not None: keys = filesystem_kv.get("keys", []) values = filesystem_kv.get("values", []) if keys: best_score = float("-inf") best_idx = -1 for j, kj in enumerate(keys): score = (q * kj).sum().item() if score >= best_score: best_score = score best_idx = j if best_idx >= 0: attn_out += wo @ values[best_idx] else: self.kv_keys[l][h].append(k.clone()) self.kv_values[l][h].append(v.clone()) num_keys = len(self.kv_keys[l][h]) if cfg.is_sum_head(l, h): acc = torch.zeros(cfg.d_h) for j in range(num_keys - 1): acc += self.kv_values[l][h][j] attn_out += wo @ acc else: best_score = float("-inf") best_idx = -1 for j, kj in enumerate(self.kv_keys[l][h]): score = (q * kj).sum().item() if score >= best_score: best_score = score best_idx = j if best_idx >= 0: best_v = self.kv_values[l][h][best_idx] attn_out += wo @ best_v x = x + attn_out # SwiGLU FFN gate_w = w[f"layers.{l}.ffn.gate_weight"] val_w = w[f"layers.{l}.ffn.value_weight"] out_w = w[f"layers.{l}.ffn.output_weight"] gate = gate_w @ x val = val_w @ x activated = F.relu(gate) * val x = x + out_w @ activated logits = w["unembed.weight"] @ x return logits.argmax().item() if __name__ == "__main__": print("Loading WASM Interpreter Transformer...") config = WasmTransformerConfig() weights = load_file("model.safetensors") model = WasmTransformer(config, weights) print(f" Architecture: d_model={config.d_model}, n_layers={config.n_layers}") print(f" Heads per layer: {config.heads_per_layer}") total = sum(config.heads_per_layer) print(f" Total attention heads: {total}") print(f" Loaded {len(weights)} weight tensors") sum_heads_info = [] for layer_key, heads in config.sum_attention_heads.items(): sum_heads_info.append(f" {layer_key}: heads {heads} (sum-mode)") if sum_heads_info: print(" Sum-attention heads:") for info in sum_heads_info: print(f" {info}") cross_attn_info = [] for layer_key, heads in config.cross_attention_heads.items(): cross_attn_info.append(f" {layer_key}: heads {heads} (cross-attention / filesystem)") if cross_attn_info: print(" Cross-attention heads:") for info in cross_attn_info: print(f" {info}") print() print("Model loaded successfully!") print("To run programs, use the TypeScript reference implementation") print("which includes the program analysis and positional encoding pipeline.")