Nanochat + WASM Coprocessor (Fused Preview)

A coprocessor-augmented language model that embeds a frozen WASM bytecode interpreter transformer inside a 4.9B-parameter language model (NanochatGPT d34), connected via trained cross-attention bridges. The LM and coprocessor execute in parallel within a single forward pass; from LM layer 10 onward, cross-attention lets the LM read the coprocessor's completed state. The LM generates text and WASM instructions; the coprocessor executes them deterministically; results flow back into the LM, enabling the model to think in computation.

Architecture

Component Details
Language Model (trained) NanochatGPT d34: d_model=2176, 34 layers, 17 heads, ~4.9B params (all trained)
WASM Coprocessor (frozen) WASM Interpreter Transformer: d_model=102, 23 layers, variable heads, ~840K params
Bridge (trained) 24 cross-attention layers (L10–L33, 4 heads each), WasmTokenEmbedding (260→2176), logit bias
Total parameters ~4.9B trained LM + ~840K frozen coprocessor ≈ 4.9B total
Vocab 65536 text tokens (tiktoken BPE) + 523 WASM tokens (opcodes + operand bytes + feedback)

How the Coprocessor Works

  1. Text tokens → processed normally by the LM
  2. WASM instruction tokens → the LM emits them, and the frozen coprocessor immediately executes them
  3. Feedback tokens → coprocessor results (REPL_RESULT, BRANCH_TAKEN, BRANCH_NOT_TAKEN) are fed back via cross-attention
  4. Lockstep execution: each WASM instruction is immediately followed by a feedback token, creating instruction-feedback pairs that the LM sees simultaneously
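The lockstep scheme above can be sketched as a simple pairing over the generated WASM token stream. This is an illustrative sketch only; the token names below are placeholders, and the real model works on integer token ids, not strings.

```python
def pair_lockstep(stream):
    """Group a WASM token stream into (instruction, feedback) pairs.

    Assumes each instruction token is immediately followed by exactly one
    feedback token, as in the lockstep scheme described above.
    """
    pairs = []
    it = iter(stream)
    for inst in it:
        feedback = next(it, None)  # feedback token follows immediately
        pairs.append((inst, feedback))
    return pairs

# Hypothetical stream: two instructions, each followed by its feedback token.
stream = ["i32.const", "REPL_RESULT", "br_if", "BRANCH_TAKEN"]
print(pair_lockstep(stream))
# [('i32.const', 'REPL_RESULT'), ('br_if', 'BRANCH_TAKEN')]
```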

The coprocessor is a hand-compiled transformer that executes WASM bytecode via real matrix multiplications. It was not trained: every weight was set by a compiler. It supports arithmetic, comparisons, memory, local variables, filesystem I/O, and loops with conditional branching.

Cross-Attention Bridge

  • Layer 10: Primary injection point where cross-attention reads coprocessor hidden states
  • Layers 11–33: Additional cross-attention heads (gate-initialized near zero) refine the compute signal
  • WasmTokenEmbedding: Learned 260×2176 embedding mapping WASM tokens to LM representation space
  • wasm_logit_bias: Learned bias controlling WASM token generation probability
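The "gate-initialized near zero" pattern means each bridge layer starts as a near-identity, so the frozen coprocessor signal is blended in only as the gate is learned. A minimal conceptual sketch of that residual update (not the repo's actual module, and using plain lists rather than tensors):

```python
import math

def gated_residual(x, cross_attn_out, gate):
    """Residual update with a learned scalar gate.

    With gate == 0 the layer is an exact identity, which is why
    zero-initializing the gates leaves the pretrained LM's behavior
    untouched at the start of bridge training.
    """
    g = math.tanh(gate)
    return [xi + g * ci for xi, ci in zip(x, cross_attn_out)]

x = [1.0, 2.0, 3.0]
coproc_signal = [0.5, -0.5, 0.25]
assert gated_residual(x, coproc_signal, gate=0.0) == x  # zero-init gate: no-op
```

As the gate moves away from zero during training, a growing fraction of the cross-attention output is mixed into the LM's hidden state.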

Training

Parameter Value
GPU NVIDIA B200 (192GB HBM3e)
Optimizer MuonAdamW (Muon for matrix params, AdamW for scalars)
Precision FP8 (Blackwell native) with bf16 master weights
Phase Supervised Fine-Tuning (SFT)
Data WASM programs + text conversations (SmolTalk + MMLU)
Text ratio 30% pure text, 70% WASM conversations
Gradient checkpointing Enabled
Base model karpathy/nanochat-d34

Checkpoints

  • epoch_0/ – End of SFT epoch 0
  • epoch_1/ – End of SFT epoch 1
  • epoch_4_batch_6100/ – Latest SFT checkpoint (epoch 4, batch 6100)

All checkpoints are loadable via AutoModelForCausalLM.from_pretrained with trust_remote_code=True.

Note: This is an early preview. WASM coprocessor triggering reliability varies across checkpoints and prompts. Conversational (non-math) prompts work reliably. Math/WASM execution may not trigger consistently depending on the checkpoint and prompt phrasing.

How to Use

Important: This model was trained with integers in the prompt encoded as 4-byte WASM tokens (not as regular text). You must byte-encode numbers in your input and decode WASM output values back to integers for display.

Quick Start

import re
import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

REPO = "eastlondoner/nanochat-wasm-fused-preview-01"
WASM_OFFSET = 65536
BYTE_OFFSET = 264

# ── 1. Load model + tokenizer ───────────────────────────────────
config = AutoConfig.from_pretrained(REPO, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    REPO, config=config, trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    subfolder="epoch_4_batch_6100",
)
model = model.to("cuda").eval()

tok = AutoTokenizer.from_pretrained(REPO, trust_remote_code=True, use_fast=False)

# ── 2. Encode prompt (byte-encode integers) ─────────────────────
# encode_chat() automatically byte-encodes integers with i32.const prefix
prompt_ids = tok.encode_chat("What is 15 + 27?")

# ── 3. Generate ─────────────────────────────────────────────────
eos_id = tok._special_token_ids["<|assistant_end|>"]
generated, wasm_outputs, trace = model.generate_chat(
    prompt_ids,
    max_new_tokens=1024,
    temperature=0,
    return_outputs=True,
    eos_token_id=eos_id,
)

# ── 4. Decode response ──────────────────────────────────────────
# Filter text-only tokens for human-readable text
text_tokens = [t for t in generated[len(prompt_ids):] if 0 < t < WASM_OFFSET]
response = tok.decode(text_tokens)
print(f"Text: {response}")

# The coprocessor's computed results are in wasm_outputs (list of ints)
if wasm_outputs:
    print(f"Answer: {wasm_outputs[-1]}")  # → 42

Input/Output Encoding

Inputs: All integers in the user prompt must be byte-encoded before tokenization. The tokenizer's encode_chat() method handles this automatically β€” it finds integer literals via regex and replaces each with an i32.const opcode token (65536) followed by 4 big-endian byte tokens in the WASM token range (65800–66055).

For example, "What is 15 + 27?" becomes:

text("What is ") + [i32.const, 0x00, 0x00, 0x00, 0x0F] + text(" + ") + [i32.const, 0x00, 0x00, 0x00, 0x1B] + text("?")

where i32.const is token 65536 and each byte b maps to token 65536 + 264 + b.
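Using the constants stated above (i32.const = 65536, byte b → token 65536 + 264 + b), the byte-encoding step can be reproduced by hand, which is useful for verifying what encode_chat() produces:

```python
I32_CONST = 65536          # i32.const opcode token (per the token contract)
BYTE_BASE = 65536 + 264    # operand byte b maps to token BYTE_BASE + b

def encode_int_tokens(n):
    """Encode a non-negative 32-bit integer as the model expects:
    one i32.const opcode token followed by 4 big-endian byte tokens."""
    b = n.to_bytes(4, "big")
    return [I32_CONST] + [BYTE_BASE + byte for byte in b]

print(encode_int_tokens(15))
# [65536, 65800, 65800, 65800, 65815]
```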

Outputs: The model generates a mix of text tokens (< 65536) and WASM tokens (≥ 65536). The WASM coprocessor executes the WASM trace and returns integer results via wasm_outputs. These are plain Python integers ready for display. The text portion of the response (e.g., "The answer is") can be decoded normally; the actual numeric answer comes from wasm_outputs.
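Should you need to decode raw operand byte tokens yourself (e.g., when inspecting the generated trace rather than using wasm_outputs), the 4-byte big-endian encoding inverts directly:

```python
BYTE_BASE = 65536 + 264  # operand byte tokens start here (per the token contract)

def decode_int_tokens(byte_tokens):
    """Decode 4 big-endian operand byte tokens back to a Python int."""
    raw = bytes(t - BYTE_BASE for t in byte_tokens)
    return int.from_bytes(raw, "big")

print(decode_int_tokens([65800, 65800, 65800, 65842]))
# 42
```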

Conversational (Non-Math) Prompts

For prompts without numbers (e.g., "Hello, how are you?"), standard tokenization works fine; no byte encoding is needed:

# Reuses model, tok, eos_id, and WASM_OFFSET from the Quick Start above
prompt_ids = tok.encode_chat("What is the capital of France?")
generated, _, _ = model.generate_chat(
    prompt_ids, max_new_tokens=1024, temperature=0.8,
    eos_token_id=eos_id,
)
text_tokens = [t for t in generated[len(prompt_ids):] if 0 < t < WASM_OFFSET]
print(tok.decode(text_tokens))

Token Contract

The model uses an extended vocabulary where tokens ≥ 65536 are WASM tokens:

Token Range Meaning
0–65535 Standard text tokens (tiktoken BPE)
65536–65799 WASM opcodes (e.g., i32.const=65536, i32.add=65537, output=65776, halt=65791)
65800–66055 Operand/value bytes: token = 65536 + 264 + byte_value (0–255)
  • Opcodes: i32.const (0x00), i32.add (0x01), output (0xF0), halt (0xFF), etc.
  • Operand bytes: Each integer is encoded as 4 big-endian bytes; each byte maps to token 65536 + 264 + byte_value
  • Feedback tokens: REPL_RESULT (261), BRANCH_TAKEN (262), BRANCH_NOT_TAKEN (263)
  • WASM pad: Used during replay sequences for lockstep token alignment
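The ranges in the table above make token classification a simple comparison chain. A small helper (an illustrative sketch; range boundaries taken directly from the table):

```python
def classify_token(t):
    """Classify a token id per the token-contract ranges above."""
    if t < 65536:
        return "text"              # 0-65535: tiktoken BPE text tokens
    if t < 65800:
        return "wasm_opcode"       # 65536-65799: WASM opcodes + feedback
    if t < 66056:
        return "wasm_operand_byte" # 65800-66055: operand/value bytes
    return "unknown"

assert classify_token(42) == "text"
assert classify_token(65537) == "wasm_opcode"       # i32.add per the table
assert classify_token(65800) == "wasm_operand_byte" # byte value 0
```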

Sequence Format

Input numbers are byte-encoded in the question; WASM output values appear as 4-byte tokens after the output opcode:

[<|bos|>] [<|user_start|>] [text + byte-encoded integers] [<|user_end|>] [<|assistant_start|>]
[WASM inst] [feedback] [WASM inst] [feedback] ...
[output] [4 result bytes]
[halt]
[answer text tokens]
[<|assistant_end|>]

The WASM Coprocessor

The frozen WASM coprocessor is a hand-compiled 23-layer transformer (8 functional layers + 15 FFN-only padding layers):

  • Layer 0 (13 heads): Opcode identification via one-hot matching
  • Layer 1 (1 head): Stack depth accumulation via sum-attention
  • Layer 2 (0 heads): Depth squaring (FFN only)
  • Layer 3 (8 heads): Bit extraction for AND/OR operations
  • Layer 4 (2 heads): Stack retrieval + full arithmetic FFN
  • Layer 5 (4 heads): Filesystem I/O via cross-attention
  • Layer 6 (1 head): Local variable retrieval
  • Layer 7 (4 heads): Memory load/store + branch gate
  • Layers 8–22 (0 heads each): FFN-only identity/padding layers

Supports 25 WASM operations including arithmetic, comparisons, memory, locals, filesystem I/O, and loops with conditional branching. 115/115 compliance tests pass at 100% accuracy.

Supported WASM Operations

Arithmetic & Logic

i32.const, i32.add, i32.sub, i32.mul, i32.and, i32.or

Comparisons

i32.eq, i32.ne, i32.lt_s, i32.gt_s, i32.le_s, i32.ge_s

Memory & Variables

i32.load, i32.store, local.get, local.set, local.tee

Filesystem I/O

fd_open, fd_read, fd_write, fd_close (4 fds, 32 bytes/file)

Control Flow

loop, end_loop, br_if (up to 256 iterations, nested loops)

Output & Termination

output, halt
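Combining the opcode bytes documented in the Token Contract (i32.const = 0x00, i32.add = 0x01, output = 0xF0, halt = 0xFF; opcode token = 65536 + opcode byte), here is a sketch of what a minimal "15 + 27" instruction stream could look like at the token level. The interleaved feedback tokens and the 4 result bytes after output are omitted, since the coprocessor supplies those during generation:

```python
WASM_OFFSET = 65536
BYTE_BASE = WASM_OFFSET + 264
OPCODES = {"i32.const": 0x00, "i32.add": 0x01, "output": 0xF0, "halt": 0xFF}

def op_tok(name):
    """Opcode token = 65536 + opcode byte, per the token contract."""
    return WASM_OFFSET + OPCODES[name]

def const_toks(n):
    """i32.const opcode followed by 4 big-endian operand byte tokens."""
    return [op_tok("i32.const")] + [BYTE_BASE + b for b in n.to_bytes(4, "big")]

# Minimal "15 + 27" instruction stream (feedback tokens omitted).
program = (
    const_toks(15)
    + const_toks(27)
    + [op_tok("i32.add"), op_tok("output")]
    # the 4 result byte tokens (encoding 42) would follow output here
    + [op_tok("halt")]
)
print(program)
```

Note that op_tok("output") comes out to 65776 and op_tok("halt") to 65791, matching the values listed in the Token Contract table.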

Model Files

File Description
config.json Model configuration with auto_map
configuration_nanochat.py PretrainedConfig subclass
modeling_nanochat.py PreTrainedModel wrapping ComposedModel
composed_model.py Coprocessor-augmented ComposedModel with generation
gpt.py NanochatGPT language model (d34 architecture)
wasm_transformer.py Frozen WASM Interpreter Transformer
wasm_ops.py WASM opcodes, constants, instruction encoding
compile_weights.py Coprocessor weight compiler
tokenization_nanochat.py Custom AutoTokenizer-compatible tokenizer class
handler.py HF Inference API handler
tokenizer.pkl Tiktoken BPE tokenizer (65536 tokens)
model.safetensors Model weights in safetensors format

License

MIT
