""" BDH Core Model — faithful implementation of BDH-GPU architecture Based on: https://arxiv.org/abs/2509.26507 Official repo: https://github.com/pathwaycom/bdh Key architectural features implemented: - ReLU sparse activations (~5% neurons fire) - Hebbian synaptic state (constant-size, not KV-cache) - Linear O(T) attention - Scale-free graph topology emerging from ReLU-lowrank structure """ import torch import torch.nn as nn import torch.nn.functional as F import math class BDHConfig: def __init__( self, vocab_size=256, # byte-level n_layer=4, n_head=4, n_embd=128, block_size=256, dropout=0.0, ): self.vocab_size = vocab_size self.n_layer = n_layer self.n_head = n_head self.n_embd = n_embd self.block_size = block_size self.dropout = dropout self.head_size = n_embd // n_head class RoPE(nn.Module): """Rotary Position Embedding.""" def __init__(self, head_size, max_seq=4096): super().__init__() inv_freq = 1.0 / (10000 ** (torch.arange(0, head_size, 2).float() / head_size)) self.register_buffer("inv_freq", inv_freq) self.max_seq = max_seq def forward(self, x, seq_len): t = torch.arange(seq_len, device=x.device).float() freqs = torch.outer(t, self.inv_freq) emb = torch.cat([freqs, freqs], dim=-1) # (T, head_size) cos = emb.cos()[None, None, :, :] # (1,1,T,hs) sin = emb.sin()[None, None, :, :] return cos, sin def rotate_half(x): x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] return torch.cat([-x2, x1], dim=-1) def apply_rope(q, cos, sin): return q * cos + rotate_half(q) * sin class BDHAttention(nn.Module): """ BDH linear attention with Hebbian synaptic state. Core equation (BDH-GPU, eq. 8 from paper): σ_{t+1} = σ_t + η * (relu(Q_t)^T relu(K_t)) [Hebbian update] out_t = relu(Q_t) @ σ_t [read from state] Memory is O(n_embd * head_size) — CONSTANT regardless of sequence length. Compare to transformer KV-cache: O(T * head_size) — GROWS with T. """ def __init__(self, config: BDHConfig): super().__init__() self.n_head = config.n_head self.head_size = config.head_size self.n_embd = config.n_embd self.qkv = nn.Linear(config.n_embd, 3 * config.n_embd, bias=False) self.proj = nn.Linear(config.n_embd, config.n_embd, bias=False) self.rope = RoPE(config.head_size) # Hebbian learning rate (eta) self.eta = nn.Parameter(torch.ones(1) * 0.1) # Storage for activation captures (used by visualizer) self.last_q_activations = None self.last_k_activations = None self.last_hebbian_state = None def forward(self, x, sigma=None, capture=False): """ x: (B, T, C) sigma: (B, n_head, head_size, head_size) — Hebbian synaptic state Returns: out (B,T,C), new_sigma """ B, T, C = x.shape qkv = self.qkv(x) # (B,T,3C) q, k, v = qkv.split(self.n_embd, dim=-1) # Reshape to multi-head q = q.view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B,H,T,hs) k = k.view(B, T, self.n_head, self.head_size).transpose(1, 2) v = v.view(B, T, self.n_head, self.head_size).transpose(1, 2) # RoPE cos, sin = self.rope(q, T) q = apply_rope(q, cos, sin) k = apply_rope(k, cos, sin) # BDH sparse activation — THIS is the key: ReLU creates ~5% sparsity q_sparse = F.relu(q) # (B,H,T,hs) k_sparse = F.relu(k) if capture: self.last_q_activations = q_sparse.detach().cpu() self.last_k_activations = k_sparse.detach().cpu() # Initialise Hebbian state if not provided if sigma is None: sigma = torch.zeros(B, self.n_head, self.head_size, self.head_size, device=x.device) # Linear (causal) accumulation + Hebbian update # For each position t: out_t = q_t @ sigma_t; sigma_{t+1} += eta * k_t^T v_t outs = [] for t in range(T): qt = q_sparse[:, :, t:t+1, :] # (B,H,1,hs) kt = k_sparse[:, :, t:t+1, :].transpose(-1, -2) # (B,H,hs,1) vt = v[:, :, t:t+1, :] # (B,H,1,hs) out_t = torch.matmul(qt, sigma) # (B,H,1,hs) — read outs.append(out_t) # Hebbian: strengthen co-active synapses sigma = sigma + self.eta * torch.matmul(kt, vt) # (B,H,hs,hs) if capture: self.last_hebbian_state = sigma.detach().cpu() out = torch.cat(outs, dim=2) # (B,H,T,hs) out = out.transpose(1, 2).contiguous().view(B, T, C) out = self.proj(out) return out, sigma class BDHBlock(nn.Module): """Single BDH layer: attention + MLP with sparse (ReLU) activations.""" def __init__(self, config: BDHConfig): super().__init__() self.ln1 = nn.LayerNorm(config.n_embd, elementwise_affine=False) self.ln2 = nn.LayerNorm(config.n_embd, elementwise_affine=False) self.attn = BDHAttention(config) # MLP — note ReLU (sparse), not GELU (dense) self.mlp = nn.Sequential( nn.Linear(config.n_embd, 4 * config.n_embd, bias=False), nn.ReLU(), # ← sparse! ~5% fire nn.Linear(4 * config.n_embd, config.n_embd, bias=False), ) # Storage for MLP activation captures self.last_mlp_activations = None def forward(self, x, sigma=None, capture=False): attn_out, sigma = self.attn(self.ln1(x), sigma, capture=capture) x = x + attn_out h = self.ln2(x) # Capture intermediate MLP activations (after first linear + ReLU) mid = F.relu(self.mlp[0](h)) if capture: self.last_mlp_activations = mid.detach().cpu() out = self.mlp[2](mid) x = x + out return x, sigma class BDHModel(nn.Module): """Full BDH language model.""" def __init__(self, config: BDHConfig): super().__init__() self.config = config self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd) self.blocks = nn.ModuleList([BDHBlock(config) for _ in range(config.n_layer)]) self.ln_f = nn.LayerNorm(config.n_embd, elementwise_affine=False) self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) self.tok_emb.weight = self.lm_head.weight # weight tying self.apply(self._init_weights) def _init_weights(self, m): if isinstance(m, nn.Linear): nn.init.normal_(m.weight, std=0.02) elif isinstance(m, nn.Embedding): nn.init.normal_(m.weight, std=0.02) def forward(self, idx, sigma_list=None, capture=False): B, T = idx.shape x = self.tok_emb(idx) if sigma_list is None: sigma_list = [None] * len(self.blocks) new_sigmas = [] for i, block in enumerate(self.blocks): x, sigma = block(x, sigma_list[i], capture=capture) new_sigmas.append(sigma) x = self.ln_f(x) logits = self.lm_head(x) return logits, new_sigmas def get_activation_stats(self, idx): """Run forward pass and collect activation sparsity per layer. KEY: We measure neurons that are STRICTLY NON-ZERO (|act| > 0). ReLU creates exact hard zeros → true sparsity. GELU never outputs exactly 0 → always ~100% non-zero. """ with torch.no_grad(): self.forward(idx, capture=True) stats = [] for i, block in enumerate(self.blocks): mlp_acts = block.last_mlp_activations # (B, T, 4*n_embd) if mlp_acts is not None: # Correct metric: fraction of neurons with non-zero output frac_active = (mlp_acts != 0).float().mean().item() stats.append({ "layer": i, "sparsity": 1.0 - frac_active, "frac_active": frac_active, "activations": mlp_acts[0].numpy(), # (T, 4*n_embd) }) return stats def get_hebbian_state(self, idx): """Run and return Hebbian synaptic states after processing idx.""" sigmas = [] with torch.no_grad(): _, sigma_list = self.forward(idx, capture=True) for i, block in enumerate(self.blocks): if block.attn.last_hebbian_state is not None: sigmas.append(block.attn.last_hebbian_state[0].numpy()) # (H, hs, hs) return sigmas # --------------------------------------------------------------------------- # Transformer baseline (for comparison) # --------------------------------------------------------------------------- class TransformerBlock(nn.Module): """Standard GPT-style transformer block with GELU (dense activations).""" def __init__(self, config: BDHConfig): super().__init__() self.ln1 = nn.LayerNorm(config.n_embd) self.ln2 = nn.LayerNorm(config.n_embd) self.attn_qkv = nn.Linear(config.n_embd, 3 * config.n_embd, bias=False) self.attn_proj = nn.Linear(config.n_embd, config.n_embd, bias=False) self.mlp = nn.Sequential( nn.Linear(config.n_embd, 4 * config.n_embd, bias=False), nn.GELU(), # ← dense! ~100% neurons have non-zero output nn.Linear(4 * config.n_embd, config.n_embd, bias=False), ) self.last_mlp_activations = None def forward(self, x, capture=False): B, T, C = x.shape n_head = 4 head_size = C // n_head qkv = self.attn_qkv(self.ln1(x)) q, k, v = qkv.split(C, dim=-1) q = q.view(B, T, n_head, head_size).transpose(1, 2) k = k.view(B, T, n_head, head_size).transpose(1, 2) v = v.view(B, T, n_head, head_size).transpose(1, 2) # Standard O(T²) attention att = (q @ k.transpose(-2, -1)) * (head_size ** -0.5) att = F.softmax(att, dim=-1) out = att @ v out = out.transpose(1, 2).contiguous().view(B, T, C) out = self.attn_proj(out) x = x + out h = self.ln2(x) mid = self.mlp[1](self.mlp[0](h)) # after GELU if capture: self.last_mlp_activations = mid.detach().cpu() out = self.mlp[2](mid) x = x + out return x class TransformerModel(nn.Module): def __init__(self, config: BDHConfig): super().__init__() self.config = config self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd) self.pos_emb = nn.Embedding(config.block_size, config.n_embd) self.blocks = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layer)]) self.ln_f = nn.LayerNorm(config.n_embd) self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) self.apply(self._init_weights) def _init_weights(self, m): if isinstance(m, nn.Linear): nn.init.normal_(m.weight, std=0.02) elif isinstance(m, nn.Embedding): nn.init.normal_(m.weight, std=0.02) def forward(self, idx, capture=False): B, T = idx.shape pos = torch.arange(T, device=idx.device) x = self.tok_emb(idx) + self.pos_emb(pos) for block in self.blocks: x = block(x, capture=capture) x = self.ln_f(x) return self.lm_head(x) def get_activation_stats(self, idx): """GELU never outputs exactly 0 → ~100% of neurons always non-zero.""" with torch.no_grad(): self.forward(idx, capture=True) stats = [] for i, block in enumerate(self.blocks): acts = block.last_mlp_activations if acts is not None: # GELU: non-zero fraction should be ~100% frac_active = (acts != 0).float().mean().item() stats.append({ "layer": i, "sparsity": 1.0 - frac_active, "frac_active": frac_active, "activations": acts[0].numpy(), }) return stats