DakshBeniwal111 committed on
Commit
694ce87
Β·
verified Β·
1 Parent(s): 9cf446e

Delete streamlit_app.py

Browse files
Files changed (1) hide show
  1. streamlit_app.py +0 -761
streamlit_app.py DELETED
@@ -1,761 +0,0 @@
1
- import streamlit as st
2
-
3
- st.set_page_config(
4
- page_title="BDH Sparse Brain",
5
- page_icon="πŸ‰",
6
- layout="wide",
7
- initial_sidebar_state="collapsed",
8
- )
9
-
10
- import torch
11
- import torch.nn.functional as F
12
- import numpy as np
13
- import matplotlib
14
- matplotlib.use("Agg")
15
- import matplotlib.pyplot as plt
16
- import matplotlib.gridspec as gridspec
17
- from bdh_core import BDHModel, BDHConfig, TransformerModel
18
-
19
- # ══════════════════════════════════════════════════════════════════════════════
20
- # GLOBAL CSS β€” cinematic dark-lab aesthetic
21
- # ══════════════════════════════════════════════════════════════════════════════
22
- st.markdown("""
23
- <style>
24
- @import url('https://fonts.googleapis.com/css2?family=Space+Mono:wght@400;700&family=Outfit:wght@300;400;600;800&display=swap');
25
-
26
- *, *::before, *::after { box-sizing: border-box; margin: 0; padding: 0; }
27
-
28
- html, body, .stApp {
29
- background: #05080f !important;
30
- color: #d4dce8;
31
- font-family: 'Outfit', sans-serif;
32
- }
33
-
34
- /* kill streamlit chrome */
35
- #MainMenu, footer, header { visibility: hidden !important; }
36
- .block-container {
37
- padding: 2rem 2.5rem !important;
38
- max-width: 1300px !important;
39
- }
40
-
41
- /* ── sidebar ── */
42
- section[data-testid="stSidebar"] {
43
- background: #080c18 !important;
44
- border-right: 1px solid #131c30;
45
- }
46
- section[data-testid="stSidebar"] * { color: #a8b8cc !important; }
47
- section[data-testid="stSidebar"] h1,
48
- section[data-testid="stSidebar"] h2,
49
- section[data-testid="stSidebar"] h3 { color: #e8734a !important; }
50
-
51
- /* ── typography ── */
52
- h1, h2, h3 {
53
- font-family: 'Outfit', sans-serif !important;
54
- font-weight: 800 !important;
55
- color: #f0f4fa !important;
56
- letter-spacing: -0.02em;
57
- }
58
-
59
- /* ── inputs ── */
60
- textarea, .stTextArea textarea {
61
- background: #0c1220 !important;
62
- color: #d4dce8 !important;
63
- border: 1px solid #1e2d45 !important;
64
- border-radius: 10px !important;
65
- font-family: 'Space Mono', monospace !important;
66
- font-size: 0.82rem !important;
67
- resize: none !important;
68
- }
69
- textarea:focus { border-color: #e8734a !important; outline: none !important; box-shadow: 0 0 0 3px rgba(232,115,74,0.15) !important; }
70
-
71
- /* ── sliders ── */
72
- .stSlider [data-baseweb="slider"] { padding: 0.3rem 0; }
73
-
74
- /* ── tabs ── */
75
- .stTabs [data-baseweb="tab-list"] {
76
- background: transparent !important;
77
- border-bottom: 1px solid #131c30;
78
- gap: 0;
79
- }
80
- .stTabs [data-baseweb="tab"] {
81
- background: transparent !important;
82
- color: #5a7a99 !important;
83
- font-family: 'Outfit', sans-serif !important;
84
- font-weight: 600 !important;
85
- font-size: 0.88rem !important;
86
- padding: 0.6rem 1.2rem !important;
87
- border: none !important;
88
- border-bottom: 2px solid transparent !important;
89
- }
90
- .stTabs [aria-selected="true"] {
91
- color: #e8734a !important;
92
- border-bottom: 2px solid #e8734a !important;
93
- background: transparent !important;
94
- }
95
- .stTabs [data-baseweb="tab-highlight"] { display: none !important; }
96
- .stTabs [data-baseweb="tab-panel"] { padding-top: 1.5rem !important; }
97
-
98
- /* ── buttons ── */
99
- .stButton > button {
100
- background: linear-gradient(135deg, #e8734a, #c94f2a) !important;
101
- color: white !important;
102
- border: none !important;
103
- border-radius: 10px !important;
104
- font-family: 'Outfit', sans-serif !important;
105
- font-weight: 600 !important;
106
- padding: 0.6rem 1.6rem !important;
107
- letter-spacing: 0.02em;
108
- transition: opacity 0.2s !important;
109
- }
110
- .stButton > button:hover { opacity: 0.88 !important; }
111
-
112
- /* ── custom components ── */
113
- .page-header {
114
- padding: 2.5rem 0 2rem;
115
- border-bottom: 1px solid #131c30;
116
- margin-bottom: 2rem;
117
- }
118
- .page-header .eyebrow {
119
- font-family: 'Space Mono', monospace;
120
- font-size: 0.72rem;
121
- color: #e8734a;
122
- letter-spacing: 0.18em;
123
- text-transform: uppercase;
124
- margin-bottom: 0.5rem;
125
- }
126
- .page-header h1 {
127
- font-size: 2.8rem !important;
128
- line-height: 1.0 !important;
129
- background: linear-gradient(135deg, #f0f4fa 0%, #e8734a 100%);
130
- -webkit-background-clip: text;
131
- background-clip: text;
132
- color: transparent !important;
133
- margin-bottom: 0.6rem;
134
- }
135
- .page-header .sub {
136
- color: #5a7a99;
137
- font-size: 1rem;
138
- font-weight: 300;
139
- max-width: 620px;
140
- }
141
-
142
- .stat-grid { display: grid; grid-template-columns: repeat(4, 1fr); gap: 1rem; margin: 1.8rem 0; }
143
- .stat-card {
144
- background: #080c18;
145
- border: 1px solid #131c30;
146
- border-radius: 14px;
147
- padding: 1.2rem 1rem;
148
- position: relative;
149
- overflow: hidden;
150
- }
151
- .stat-card::before {
152
- content: '';
153
- position: absolute;
154
- top: 0; left: 0; right: 0;
155
- height: 2px;
156
- background: linear-gradient(90deg, #e8734a, transparent);
157
- }
158
- .stat-card.blue::before { background: linear-gradient(90deg, #3b7dd8, transparent); }
159
- .stat-card .val {
160
- font-family: 'Space Mono', monospace;
161
- font-size: 1.9rem;
162
- font-weight: 700;
163
- color: #e8734a;
164
- line-height: 1.1;
165
- }
166
- .stat-card.blue .val { color: #3b7dd8; }
167
- .stat-card .lbl {
168
- font-size: 0.76rem;
169
- color: #5a7a99;
170
- margin-top: 0.4rem;
171
- font-weight: 400;
172
- letter-spacing: 0.02em;
173
- }
174
- .stat-card .icon { font-size: 1.1rem; margin-bottom: 0.4rem; }
175
-
176
- .insight {
177
- background: #080c18;
178
- border-left: 3px solid #e8734a;
179
- border-radius: 0 10px 10px 0;
180
- padding: 1rem 1.2rem;
181
- margin: 0.8rem 0;
182
- font-size: 0.88rem;
183
- color: #a8b8cc;
184
- line-height: 1.6;
185
- }
186
- .insight b { color: #f0f4fa; }
187
-
188
- .section-label {
189
- font-family: 'Space Mono', monospace;
190
- font-size: 0.68rem;
191
- color: #e8734a;
192
- letter-spacing: 0.15em;
193
- text-transform: uppercase;
194
- margin-bottom: 0.8rem;
195
- }
196
-
197
- .badge {
198
- display: inline-block;
199
- padding: 3px 12px;
200
- border-radius: 999px;
201
- font-size: 0.75rem;
202
- font-weight: 600;
203
- font-family: 'Space Mono', monospace;
204
- margin-bottom: 0.6rem;
205
- }
206
- .badge-orange { background: rgba(232,115,74,0.15); color: #e8734a; border: 1px solid rgba(232,115,74,0.3); }
207
- .badge-blue { background: rgba(59,125,216,0.15); color: #3b7dd8; border: 1px solid rgba(59,125,216,0.3); }
208
-
209
- .output-box {
210
- background: #080c18;
211
- border: 1px solid #131c30;
212
- border-radius: 12px;
213
- padding: 1rem 1.2rem;
214
- font-family: 'Space Mono', monospace;
215
- font-size: 0.78rem;
216
- color: #a8b8cc;
217
- min-height: 60px;
218
- word-break: break-all;
219
- line-height: 1.6;
220
- }
221
- .loss-tag {
222
- font-family: 'Space Mono', monospace;
223
- font-size: 0.8rem;
224
- color: #5a7a99;
225
- margin-top: 0.5rem;
226
- }
227
- .loss-tag span { color: #e8734a; }
228
-
229
- .divider { border: none; border-top: 1px solid #131c30; margin: 2rem 0; }
230
- </style>
231
- """, unsafe_allow_html=True)
232
-
233
- # ── Plot theme constants ──────────────────────────────────────────────────────
234
- BG = "#05080f"
235
- CARD = "#080c18"
236
- GRID = "#131c30"
237
- TICK = "#3a5070"
238
- ORNG = "#e8734a"
239
- BLUE = "#3b7dd8"
240
- TEXT = "#d4dce8"
241
- MUTE = "#5a7a99"
242
-
243
- def _ax(fig, axes):
244
- fig.patch.set_facecolor(BG)
245
- for ax in (axes if hasattr(axes, '__iter__') else [axes]):
246
- ax.set_facecolor(CARD)
247
- ax.tick_params(colors=TICK, labelsize=8)
248
- for s in ax.spines.values():
249
- s.set_color(GRID)
250
- ax.xaxis.label.set_color(MUTE)
251
- ax.yaxis.label.set_color(MUTE)
252
-
253
- # ── Model loading ─────────────────────────────────────────────────────────────
254
- @st.cache_resource(show_spinner=False)
255
- def load_models():
256
- cfg = BDHConfig(vocab_size=256, n_layer=4, n_head=4, n_embd=128)
257
- bdh = BDHModel(cfg).eval()
258
- tf = TransformerModel(cfg).eval()
259
- return bdh, tf, cfg
260
-
261
- def tokenise(text, max_len=64):
262
- t = [min(b, 255) for b in text.encode()][:max_len]
263
- if len(t) < 2: t += [32] * (2 - len(t))
264
- return torch.tensor([t], dtype=torch.long)
265
-
266
- # ── Chart builders ────────────────────────────────────────────────────────────
267
- @st.cache_data(show_spinner=False)
268
- def chart_bar(bdh_vals, tf_vals):
269
- n = len(bdh_vals)
270
- x = np.arange(n)
271
- w = 0.32
272
- fig, ax = plt.subplots(figsize=(8, 3.4), facecolor=BG)
273
- b1 = ax.bar(x - w/2, bdh_vals, w, color=ORNG, alpha=0.9, zorder=3, label="BDH (ReLU)")
274
- b2 = ax.bar(x + w/2, tf_vals, w, color=BLUE, alpha=0.9, zorder=3, label="Transformer (GELU)")
275
- ax.axhline(5, color=ORNG, ls="--", lw=1.1, alpha=0.45)
276
- ax.axhline(100, color=BLUE, ls=":", lw=1.1, alpha=0.25)
277
- ax.set_xticks(x); ax.set_xticklabels([f"L{i}" for i in x], color=TICK)
278
- ax.set_ylim(0, 115); ax.yaxis.grid(True, color=GRID, zorder=0); ax.set_axisbelow(True)
279
- ax.set_title("Active Neurons per Layer (%)", color=TEXT, fontsize=10, fontweight="bold", pad=10, fontfamily="monospace")
280
- _ax(fig, [ax])
281
- ax.legend(facecolor=CARD, edgecolor=GRID, labelcolor=TEXT, fontsize=8, framealpha=0.9)
282
- for bar, c in [(b1, ORNG), (b2, BLUE)]:
283
- for b in bar:
284
- ax.text(b.get_x()+b.get_width()/2, b.get_height()+1.8,
285
- f"{b.get_height():.0f}%", ha="center", va="bottom",
286
- color=c, fontsize=7.5, fontweight="bold", fontfamily="monospace")
287
- fig.tight_layout(pad=1.2)
288
- return fig
289
-
290
- @st.cache_data(show_spinner=False)
291
- def chart_heatmap(data_bytes, title, cmap):
292
- data = np.frombuffer(data_bytes, dtype=np.float32).reshape(-1, 64)
293
- fig, ax = plt.subplots(figsize=(7, 2.8), facecolor=BG)
294
- vmin, vmax = float(np.min(data)), float(np.max(data))
295
- if np.isclose(vmin, vmax): vmax = vmin + 1e-6
296
- im = ax.imshow(data.T, aspect="auto", cmap=cmap, vmin=vmin, vmax=vmax, interpolation="nearest")
297
- ax.set_xlabel("Token β†’", color=MUTE, fontsize=8)
298
- ax.set_ylabel("Neuron β†’", color=MUTE, fontsize=8)
299
- ax.set_title(title, color=TEXT, fontsize=9, fontweight="bold", pad=8, fontfamily="monospace")
300
- _ax(fig, [ax])
301
- cb = fig.colorbar(im, ax=ax, fraction=0.022, pad=0.02)
302
- cb.ax.tick_params(colors=TICK, labelsize=7)
303
- plt.setp(cb.ax.get_yticklabels(), color=TICK)
304
- fig.tight_layout(pad=1.2)
305
- return fig
306
-
307
- @st.cache_data(show_spinner=False)
308
- def chart_memory():
309
- T = np.arange(0, 110_000, 400)
310
- hs, nh, nl, db = 32, 4, 4, 2
311
- bdh_m = np.full(len(T), nl*nh*hs**2*db/1e6, dtype=float)
312
- tf_m = T * 2*nh*hs*db / 1e6
313
- fig, ax = plt.subplots(figsize=(9, 3.4), facecolor=BG)
314
- ax.fill_between(T/1000, bdh_m, alpha=0.10, color=ORNG)
315
- ax.fill_between(T/1000, tf_m, alpha=0.10, color=BLUE)
316
- ax.plot(T/1000, bdh_m, color=ORNG, lw=2.2, label="BDH β€” O(1) Hebbian state")
317
- ax.plot(T/1000, tf_m, color=BLUE, lw=2.2, label="Transformer β€” O(T) KV-cache")
318
- ax.axvline(12, color="#e05252", ls="--", lw=1.4)
319
- ax.text(13.5, tf_m.max()*0.62, "⚠ OOM\n~12k", color="#e05252", fontsize=8, fontweight="bold", fontfamily="monospace")
320
- ax.annotate("BDH flat\nat 50k+ βœ“", xy=(50, bdh_m[0]), xytext=(60, bdh_m[0]+0.07),
321
- color=ORNG, fontsize=8, fontweight="bold", fontfamily="monospace",
322
- arrowprops=dict(arrowstyle="->", color=ORNG, lw=1.2))
323
- ax.set_xlabel("Sequence length (k tokens)", color=MUTE, fontsize=9)
324
- ax.set_ylabel("Memory (MB)", color=MUTE, fontsize=9)
325
- ax.set_title("Memory Scaling: O(1) vs O(T)", color=TEXT, fontsize=10, fontweight="bold", pad=10, fontfamily="monospace")
326
- _ax(fig, [ax]); ax.yaxis.grid(True, color=GRID); ax.set_axisbelow(True)
327
- ax.legend(facecolor=CARD, edgecolor=GRID, labelcolor=TEXT, fontsize=9, framealpha=0.9)
328
- fig.tight_layout(pad=1.2)
329
- return fig
330
-
331
- def chart_hebbian(sigma_list, layer):
332
- if not sigma_list or layer >= len(sigma_list):
333
- return None
334
- sigma = sigma_list[layer]
335
- H = sigma.shape[0]
336
- fig, axes = plt.subplots(1, H, figsize=(10, 2.6), facecolor=BG)
337
- if H == 1: axes = [axes]
338
- for h, ax in enumerate(axes):
339
- m = sigma[h]; vabs = np.abs(m).max()+1e-8
340
- im = ax.imshow(m, cmap="RdBu_r", vmin=-vabs, vmax=vabs, interpolation="nearest")
341
- ax.set_title(f"Head {h}", color="#fdba74", fontsize=9, fontfamily="monospace")
342
- ax.set_facecolor(BG)
343
- ax.tick_params(colors=TICK, labelsize=6)
344
- for s in ax.spines.values(): s.set_color(GRID)
345
- fig.suptitle(f"Hebbian Synaptic State Οƒ β€” Layer {layer}",
346
- color=TEXT, fontsize=9, fontweight="bold", fontfamily="monospace")
347
- fig.tight_layout(pad=1.0)
348
- return fig
349
-
350
- def chart_topology(bdh_model):
351
- w = bdh_model.blocks[0].attn.qkv.weight.detach().cpu().numpy()
352
- fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(11, 3.4), facecolor=BG)
353
- im = ax0.imshow(np.abs(w[:64,:64]), cmap="inferno", interpolation="nearest")
354
- ax0.set_title("QKV Weight |W| β€” hub structure", color=TEXT, fontsize=9, fontweight="bold", pad=8, fontfamily="monospace")
355
- fig.colorbar(im, ax=ax0, fraction=0.04)
356
- norms = np.linalg.norm(w, axis=0)
357
- ax1.hist(norms, bins=40, color=ORNG, alpha=0.88, edgecolor=BG)
358
- ax1.set_xlabel("Column norm (hub-ness)", color=MUTE, fontsize=8)
359
- ax1.set_ylabel("Count", color=MUTE, fontsize=8)
360
- ax1.set_title("Hub Degree Distribution\n(heavy tail = scale-free)", color=TEXT, fontsize=9, fontweight="bold", pad=8, fontfamily="monospace")
361
- ax1.yaxis.grid(True, color=GRID); ax1.set_axisbelow(True)
362
- _ax(fig, [ax0, ax1]); fig.tight_layout(pad=1.2)
363
- return fig
364
-
365
- def chart_neuron_bar(acts, top_idx):
366
- top_val = acts[top_idx]
367
- colors = [ORNG if v > 0 else BLUE for v in top_val]
368
- fig, ax = plt.subplots(figsize=(8, 2.8), facecolor=BG)
369
- ax.bar([f"N{n}" for n in top_idx], top_val, color=colors, zorder=3)
370
- ax.axhline(0, color=GRID, lw=0.8)
371
- ax.set_title("Top Neuron Activations β€” BDH (sparse β†’ interpretable)", color=TEXT,
372
- fontsize=9, fontweight="bold", pad=8, fontfamily="monospace")
373
- ax.tick_params(colors=TICK, labelrotation=40, labelsize=8)
374
- ax.yaxis.grid(True, color=GRID); ax.set_axisbelow(True)
375
- _ax(fig, [ax]); fig.tight_layout(pad=1.2)
376
- return fig
377
-
378
- # ── Generate text helper ──────────────────────────────────────────────────────
379
- @torch.no_grad()
380
- def generate(model, idx, n=35, temp=1.0, top_k=10, is_bdh=False):
381
- out = idx.clone()
382
- for _ in range(n):
383
- logits = model(out)[0] if is_bdh else model(out)
384
- logits = logits[:, -1, :] / temp
385
- v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
386
- logits[logits < v[:, [-1]]] = float('-inf')
387
- out = torch.cat([out, torch.multinomial(F.softmax(logits, dim=-1), 1)], dim=1)
388
- return out
389
-
390
- # ══════════════════════════════════════════════════════════════════════════════
391
- # MAIN
392
- # ════��═════════════════════════════════════════════════════════════════════════
393
- def main():
394
- bdh_model, tf_model, cfg = load_models()
395
-
396
- # ── Header ──
397
- st.markdown("""
398
- <div class="page-header">
399
- <div class="eyebrow">Post-Transformer Hackathon Β· Pathway Γ— IIT Ropar</div>
400
- <h1>πŸ‰ BDH Sparse Brain Visualizer</h1>
401
- <div class="sub">
402
- Interactive exploration of sparse neural computation, Hebbian memory &amp;
403
- interpretable activations in the Dragon Hatchling architecture.
404
- </div>
405
- </div>
406
- """, unsafe_allow_html=True)
407
-
408
- # ── Controls row ──
409
- col_in, col_layer = st.columns([3, 1])
410
- with col_in:
411
- input_text = st.text_area(
412
- "Input text",
413
- value="The dragon hatchling thinks with sparse neurons that fire together and wire together.",
414
- height=90, label_visibility="collapsed"
415
- )
416
- with col_layer:
417
- st.markdown("<div style='height:0.3rem'></div>", unsafe_allow_html=True)
418
- layer_idx = st.slider("Hebbian layer", 0, cfg.n_layer - 1, 0)
419
- st.markdown(f"<div style='font-family:Space Mono;font-size:0.7rem;color:{MUTE};margin-top:0.3rem'>layer {layer_idx} selected</div>", unsafe_allow_html=True)
420
-
421
- # ── Run models (cached via session state to prevent shaking) ──
422
- tok_key = input_text[:80]
423
- if "last_tok_key" not in st.session_state or st.session_state.last_tok_key != tok_key:
424
- tokens = tokenise(input_text)
425
- T = tokens.shape[1]
426
- with torch.no_grad():
427
- bdh_stats = bdh_model.get_activation_stats(tokens)
428
- tf_stats = tf_model.get_activation_stats(tokens)
429
- sigma_list = bdh_model.get_hebbian_state(tokens)
430
- bdh_logits, _ = bdh_model(tokens)
431
- tf_logits = tf_model(tokens)
432
- tgt = torch.cat([tokens[:, 1:], tokens[:, -1:]], dim=1)
433
- bdh_loss = F.cross_entropy(bdh_logits.reshape(-1, cfg.vocab_size), tgt.reshape(-1)).item()
434
- tf_loss = F.cross_entropy(tf_logits.reshape(-1, cfg.vocab_size), tgt.reshape(-1)).item()
435
- bdh_out = generate(bdh_model, tokens, is_bdh=True)
436
- tf_out = generate(tf_model, tokens, is_bdh=False)
437
- st.session_state.update({
438
- "last_tok_key": tok_key,
439
- "bdh_stats": bdh_stats, "tf_stats": tf_stats,
440
- "sigma_list": sigma_list, "T": T,
441
- "bdh_loss": bdh_loss, "tf_loss": tf_loss,
442
- "bdh_text": bytes(bdh_out.squeeze(0).tolist()).decode(errors="replace"),
443
- "tf_text": bytes(tf_out.squeeze(0).tolist()).decode(errors="replace"),
444
- })
445
-
446
- ss = st.session_state
447
- bdh_stats = ss["bdh_stats"]
448
- tf_stats = ss["tf_stats"]
449
- sigma_list = ss["sigma_list"]
450
- T = ss["T"]
451
-
452
- avg_bdh = np.mean([s["frac_active"] for s in bdh_stats]) * 100
453
- avg_tf = np.mean([s["frac_active"] for s in tf_stats]) * 100
454
- hebb_kb = (cfg.n_layer * cfg.n_head * cfg.head_size**2 * 2) / 1024
455
- kv_kb = (T * 2 * cfg.n_head * cfg.head_size * 2) / 1024
456
-
457
- # ── Stat cards ──
458
- st.markdown(f"""
459
- <div class="stat-grid">
460
- <div class="stat-card">
461
- <div class="icon">πŸ‰</div>
462
- <div class="val">{avg_bdh:.1f}%</div>
463
- <div class="lbl">BDH Neurons Active</div>
464
- </div>
465
- <div class="stat-card blue">
466
- <div class="icon">πŸ€–</div>
467
- <div class="val">{avg_tf:.1f}%</div>
468
- <div class="lbl">Transformer Neurons Active</div>
469
- </div>
470
- <div class="stat-card">
471
- <div class="icon">⚑</div>
472
- <div class="val">{hebb_kb:.0f} KB</div>
473
- <div class="lbl">BDH Memory (constant)</div>
474
- </div>
475
- <div class="stat-card blue">
476
- <div class="icon">πŸ“ˆ</div>
477
- <div class="val">{kv_kb:.0f} KB</div>
478
- <div class="lbl">Transformer KV-Cache (grows)</div>
479
- </div>
480
- </div>
481
- """, unsafe_allow_html=True)
482
-
483
- st.markdown(f"""
484
- <div style="text-align:center;padding:0.6rem 0 1.4rem;font-family:'Space Mono',monospace;font-size:0.82rem;color:{MUTE};">
485
- Processing <span style="color:{TEXT};font-weight:700">{T} tokens</span> &nbsp;Β·&nbsp;
486
- BDH <span style="color:{ORNG};font-weight:700">{avg_bdh:.1f}%</span> active
487
- &nbsp;vs&nbsp;
488
- Transformer <span style="color:{BLUE};font-weight:700">{avg_tf:.1f}%</span> active
489
- &nbsp;Β·&nbsp; <span style="color:{MUTE}">untrained model β€” sparsity increases after training</span>
490
- </div>
491
- """, unsafe_allow_html=True)
492
-
493
- # ── Output comparison ──
494
- st.markdown("<hr class='divider'>", unsafe_allow_html=True)
495
- st.markdown("<div class='section-label'>Model Output Comparison</div>", unsafe_allow_html=True)
496
- oc1, oc2 = st.columns(2)
497
- with oc1:
498
- st.markdown("<div class='badge badge-orange'>πŸ‰ BDH Output</div>", unsafe_allow_html=True)
499
- st.markdown(f"<div class='output-box'>{ss['bdh_text'][:300]}</div>", unsafe_allow_html=True)
500
- st.markdown(f"<div class='loss-tag'>cross-entropy loss: <span>{ss['bdh_loss']:.4f}</span></div>", unsafe_allow_html=True)
501
- with oc2:
502
- st.markdown("<div class='badge badge-blue'>πŸ€– Transformer Output</div>", unsafe_allow_html=True)
503
- st.markdown(f"<div class='output-box'>{ss['tf_text'][:300]}</div>", unsafe_allow_html=True)
504
- st.markdown(f"<div class='loss-tag'>cross-entropy loss: <span style='color:{BLUE}'>{ss['tf_loss']:.4f}</span></div>", unsafe_allow_html=True)
505
-
506
- st.markdown("<hr class='divider'>", unsafe_allow_html=True)
507
-
508
- # ── Tabs ──
509
- tab1, tab2, tab3, tab4, tab5 = st.tabs([
510
- "⚑ Activation Sparsity",
511
- "🧠 Hebbian Memory",
512
- "πŸ“ˆ Memory Scaling",
513
- "🌐 Graph Topology",
514
- "πŸ”₯ Live Training",
515
- ])
516
-
517
- # ─────────────────────────────────── TAB 1 ───
518
- with tab1:
519
- st.markdown("""
520
- <div class="insight">
521
- <b>Core BDH insight:</b> BDH uses <b>ReLU</b> activations β€” hard-zeros all negative values β†’ natural ~5% sparsity.
522
- Transformers use <b>GELU</b> which never outputs exactly zero β†’ ~100% active. Same input. Dramatically different neural behaviour.
523
- </div>""", unsafe_allow_html=True)
524
-
525
- bdh_vals = tuple(s["frac_active"]*100 for s in bdh_stats)
526
- tf_vals = tuple(s["frac_active"]*100 for s in tf_stats)
527
- fig = chart_bar(bdh_vals, tf_vals)
528
- st.pyplot(fig, use_container_width=True); plt.close(fig)
529
-
530
- st.markdown("<div class='section-label' style='margin-top:1.5rem'>Activation Heatmaps β€” Layer 0</div>", unsafe_allow_html=True)
531
- hc1, hc2 = st.columns(2)
532
-
533
- acts_bdh = bdh_stats[0]["activations"]
534
- acts_tf = tf_stats[0]["activations"]
535
- data_bdh = acts_bdh[:, :64].astype(np.float32)
536
- data_tf = acts_tf[:, :64].astype(np.float32)
537
-
538
- with hc1:
539
- st.markdown("<div class='badge badge-orange'>πŸ‰ BDH β€” ReLU sparse</div>", unsafe_allow_html=True)
540
- fig = chart_heatmap(data_bdh.tobytes(),
541
- f"BDH L0 β€” {bdh_stats[0]['frac_active']*100:.1f}% active", "hot")
542
- st.pyplot(fig, use_container_width=True); plt.close(fig)
543
- with hc2:
544
- st.markdown("<div class='badge badge-blue'>πŸ€– Transformer β€” GELU dense</div>", unsafe_allow_html=True)
545
- fig = chart_heatmap(data_tf.tobytes(),
546
- f"Transformer L0 β€” {tf_stats[0]['frac_active']*100:.1f}% active", "Blues")
547
- st.pyplot(fig, use_container_width=True); plt.close(fig)
548
-
549
- st.markdown("<div class='section-label' style='margin-top:1.5rem'>Per-Layer Summary</div>", unsafe_allow_html=True)
550
- cols = st.columns(len(bdh_stats))
551
- for i, (bs, ts) in enumerate(zip(bdh_stats, tf_stats)):
552
- with cols[i]:
553
- st.metric(f"Layer {i}",
554
- f"BDH {bs['frac_active']*100:.1f}%",
555
- delta=f"TF {ts['frac_active']*100:.1f}%")
556
-
557
- st.markdown("<hr class='divider'>", unsafe_allow_html=True)
558
- st.markdown("<div class='section-label'>Neuron Inspector</div>", unsafe_allow_html=True)
559
- nc1, nc2 = st.columns(2)
560
- with nc1:
561
- l_sel = st.select_slider("Layer", options=list(range(len(bdh_stats))), value=0, key="ni_l")
562
- with nc2:
563
- max_tok = bdh_stats[0]["activations"].shape[0] - 1
564
- t_sel = st.select_slider("Token position", options=list(range(max_tok+1)), value=0, key="ni_t")
565
-
566
- acts = bdh_stats[l_sel]["activations"][t_sel]
567
- top_idx = np.argsort(np.abs(acts))[-12:]
568
- toks_list = list(input_text.encode("utf-8"))
569
- byte_val = toks_list[t_sel] if t_sel < len(toks_list) else 63
570
- char_repr = chr(byte_val) if 32 <= byte_val < 127 else "Β·"
571
- st.markdown(f"""
572
- <div style="font-family:'Space Mono',monospace;font-size:0.78rem;color:{MUTE};margin-bottom:0.8rem">
573
- token <span style="color:{TEXT}">{t_sel}</span> β†’
574
- byte <span style="color:{ORNG}">{byte_val}</span>
575
- (<span style="color:{TEXT}">{char_repr!r}</span>)
576
- &nbsp;Β·&nbsp; {(acts>0).sum()} / {len(acts)} neurons firing
577
- </div>""", unsafe_allow_html=True)
578
- fig = chart_neuron_bar(acts, top_idx)
579
- st.pyplot(fig, use_container_width=True); plt.close(fig)
580
- st.markdown("""
581
- <div class="insight" style="margin-top:0.8rem">
582
- Because BDH activates only ~5% of neurons per token, you can point to exactly which neurons matter for each prediction.
583
- This is <b>built-in interpretability</b> β€” transformer dense activations make this kind of inspection practically impossible.
584
- </div>""", unsafe_allow_html=True)
585
-
586
- # ─────────────────────────────────── TAB 2 ───
587
- with tab2:
588
- st.markdown("""
589
- <div class="insight">
590
- <b>"Neurons that fire together, wire together."</b> β€” Hebb's rule<br><br>
591
- BDH maintains a fixed-size synaptic state matrix <b>Οƒ</b> that strengthens when neurons co-activate.
592
- Memory size is <b>constant</b> β€” O(n_head Γ— head_sizeΒ²) β€” regardless of sequence length.
593
- </div>""", unsafe_allow_html=True)
594
-
595
- fig = chart_hebbian(sigma_list, layer=layer_idx)
596
- if fig:
597
- st.pyplot(fig, use_container_width=True); plt.close(fig)
598
-
599
- hb1, hb2 = st.columns(2)
600
- with hb1:
601
- st.markdown(f"""
602
- <div style="font-family:'Space Mono',monospace;font-size:0.8rem;line-height:1.8;color:{MUTE}">
603
- <span style="color:{TEXT}">Each cell (i,j)</span> = synapse between neuron i and j<br>
604
- πŸ”΄ Red = excitatory connection<br>
605
- πŸ”΅ Blue = inhibitory connection<br>
606
- βšͺ White = weak / no connection
607
- </div>""", unsafe_allow_html=True)
608
- with hb2:
609
- st.markdown(f"""
610
- <div style="font-family:'Space Mono',monospace;font-size:0.8rem;line-height:1.8;color:{MUTE}">
611
- BDH Hebbian state: <span style="color:{ORNG}">{hebb_kb:.0f} KB</span> (fixed forever)<br>
612
- Transformer at {T} tokens: <span style="color:{BLUE}">{kv_kb:.0f} KB</span><br>
613
- Transformer at 50k tokens: <span style="color:#e05252">{(50000*2*cfg.n_head*cfg.head_size*2)//1024} KB</span>
614
- </div>""", unsafe_allow_html=True)
615
-
616
- st.markdown("<div class='section-label' style='margin-top:1.5rem'>All Layers</div>", unsafe_allow_html=True)
617
- for li in range(len(sigma_list)):
618
- with st.expander(f"Layer {li}"):
619
- fig = chart_hebbian(sigma_list, layer=li)
620
- if fig:
621
- st.pyplot(fig, use_container_width=True); plt.close(fig)
622
-
623
- # ─────────────────────────────────── TAB 3 ───
624
- with tab3:
625
- st.markdown("""
626
- <div class="insight">
627
- Transformer KV-caches grow linearly with every token β€” eventually crashing the GPU.
628
- BDH's Hebbian state is <b>constant size forever</b>. Community experiments confirm BDH running 50k+ tokens
629
- with flat memory while transformers OOM at ~12k on identical hardware.
630
- </div>""", unsafe_allow_html=True)
631
-
632
- fig = chart_memory()
633
- st.pyplot(fig, use_container_width=True); plt.close(fig)
634
-
635
- mc1, mc2, mc3 = st.columns(3)
636
- for col, v, l in [(mc1,"O(1)","BDH complexity"), (mc2,"O(T)","Transformer complexity"), (mc3,"50k+","Max tokens (BDH)")]:
637
- with col:
638
- st.markdown(f"""<div class="stat-card" style="text-align:center">
639
- <div class="val">{v}</div><div class="lbl">{l}</div></div>""", unsafe_allow_html=True)
640
-
641
- st.markdown(f"""
642
- <div style="margin-top:1.5rem;font-family:'Space Mono',monospace;font-size:0.8rem;color:{MUTE};line-height:2">
643
- Applications unlocked:<br>
644
- Healthcare β€” full patient history in context &nbsp;Β·&nbsp;
645
- Legal β€” entire contracts reasoned at once &nbsp;Β·&nbsp;
646
- Research β€” thousands of papers synthesised &nbsp;Β·&nbsp;
647
- Code β€” large codebases in one pass
648
- </div>""", unsafe_allow_html=True)
649
-
650
- # ─────────────────────────────────── TAB 4 ───
651
- with tab4:
652
- st.markdown("""
653
- <div class="insight">
654
- BDH weight matrices form <b>scale-free networks</b> β€” a few hub neurons connect broadly (like brain hubs),
655
- most connect sparsely. This structure emerges from ReLU-lowrank dynamics and is the architectural
656
- basis for monosemantic synapses.
657
- </div>""", unsafe_allow_html=True)
658
-
659
- fig = chart_topology(bdh_model)
660
- st.pyplot(fig, use_container_width=True); plt.close(fig)
661
-
662
- tc1, tc2 = st.columns(2)
663
- with tc1:
664
- st.markdown(f"""<div class="insight">
665
- <b>In neuroscience:</b> biological neural connectivity follows power-law distributions with hub nodes.
666
- BDH replicates this naturally β€” transformers do not.
667
- </div>""", unsafe_allow_html=True)
668
- with tc2:
669
- st.markdown(f"""<div class="insight">
670
- <b>Why it matters:</b> Hub neurons act as concept anchors.
671
- This is the basis for BDH's monosemantic synapses β€” neurons that consistently encode
672
- specific concepts (e.g. "currency synapse", "country synapse").
673
- </div>""", unsafe_allow_html=True)
674
-
675
- # ─────────────────────────────────── TAB 5 ───
676
- with tab5:
677
- st.markdown("""
678
- <div class="insight">
679
- Train tiny BDH and Transformer from scratch on random sequences.
680
- Watch BDH's activation rate converge toward ~5% as ReLU neurons learn selectivity.
681
- Transformer neurons stay dense throughout training.
682
- </div>""", unsafe_allow_html=True)
683
-
684
- n_steps = st.slider("Training steps", 50, 300, 150, step=50)
685
-
686
- if st.button("β–Ά Start Training", type="primary"):
687
- tcfg = BDHConfig(vocab_size=128, n_layer=2, n_head=4, n_embd=64)
688
- b_m = BDHModel(tcfg).eval()
689
- t_m = TransformerModel(tcfg).eval()
690
- ob = torch.optim.AdamW(b_m.parameters(), lr=3e-4)
691
- ot = torch.optim.AdamW(t_m.parameters(), lr=3e-4)
692
-
693
- b_log, t_log, b_loss_log, t_loss_log, xs = [], [], [], [], []
694
- prog = st.progress(0)
695
- ph = st.empty()
696
-
697
- def batch(V=128, B=2, T=24):
698
- x = torch.randint(0, V, (B, T))
699
- return x, torch.cat([x[:, 1:], x[:, :1]], dim=1)
700
-
701
- for step in range(n_steps):
702
- x, y = batch()
703
- b_m.train()
704
- lg, _ = b_m(x)
705
- lb = F.cross_entropy(lg.view(-1,128), y.view(-1))
706
- ob.zero_grad(); lb.backward(); ob.step()
707
-
708
- t_m.train()
709
- lt = F.cross_entropy(t_m(x).view(-1,128), y.view(-1))
710
- ot.zero_grad(); lt.backward(); ot.step()
711
-
712
- if step % 10 == 0 or step == n_steps-1:
713
- b_m.eval(); t_m.eval()
714
- tx = torch.randint(0, 128, (1, 24))
715
- ab = np.mean([s["frac_active"] for s in b_m.get_activation_stats(tx)]) * 100
716
- at = np.mean([s["frac_active"] for s in t_m.get_activation_stats(tx)]) * 100
717
- b_log.append(ab); t_log.append(at)
718
- b_loss_log.append(float(lb)); t_loss_log.append(float(lt))
719
- xs.append(step)
720
-
721
- fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 3.8), facecolor=BG)
722
- _ax(fig, [ax1, ax2])
723
- ax1.plot(xs, b_log, "o-", color=ORNG, lw=2, ms=4, label="BDH (ReLU)")
724
- ax1.plot(xs, t_log, "s-", color=BLUE, lw=2, ms=4, label="Transformer (GELU)")
725
- ax1.axhline(5, color=ORNG, ls="--", lw=1, alpha=0.5)
726
- ax1.axhline(100, color=BLUE, ls=":", lw=1, alpha=0.3)
727
- ax1.set_xlabel("Training step", color=MUTE); ax1.set_ylabel("% Active", color=MUTE)
728
- ax1.set_title("Activation Rate", color=TEXT, fontweight="bold", fontfamily="monospace")
729
- ax1.set_ylim(0, 110); ax1.yaxis.grid(True, color=GRID); ax1.set_axisbelow(True)
730
- ax1.legend(facecolor=CARD, edgecolor=GRID, labelcolor=TEXT, fontsize=8)
731
-
732
- ax2.plot(xs, b_loss_log, "-", color=ORNG, lw=2, label="BDH loss")
733
- ax2.plot(xs, t_loss_log, "-", color=BLUE, lw=2, label="Transformer loss")
734
- ax2.set_xlabel("Training step", color=MUTE); ax2.set_ylabel("Loss", color=MUTE)
735
- ax2.set_title("Training Loss", color=TEXT, fontweight="bold", fontfamily="monospace")
736
- ax2.yaxis.grid(True, color=GRID); ax2.set_axisbelow(True)
737
- ax2.legend(facecolor=CARD, edgecolor=GRID, labelcolor=TEXT, fontsize=8)
738
-
739
- fig.tight_layout(pad=1.2)
740
- ph.pyplot(fig, use_container_width=True); plt.close(fig)
741
- prog.progress((step+1)/n_steps)
742
-
743
- st.success(f"Done β€” BDH: **{b_log[-1]:.1f}%** active Β· Transformer: **{t_log[-1]:.1f}%** active")
744
- st.markdown("""
745
- <div class="insight" style="margin-top:0.8rem">
746
- BDH's ReLU neurons learned <b>selectivity</b> during training β€” firing only for strongly relevant inputs.
747
- Transformer GELU neurons stayed dense. This selectivity is the foundation of BDH's interpretability.
748
- </div>""", unsafe_allow_html=True)
749
-
750
- # ── Footer ──
751
- st.markdown("<hr class='divider'>", unsafe_allow_html=True)
752
- st.markdown(f"""
753
- <div style="text-align:center;font-family:'Space Mono',monospace;font-size:0.72rem;color:{MUTE};padding-bottom:1rem">
754
- Built for the Beyond Transformers Hackathon Β· Pathway Γ— IIT Ropar E-Summit '26 &nbsp;Β·&nbsp;
755
- <a href="https://arxiv.org/abs/2509.26507" style="color:{ORNG};text-decoration:none">arXiv:2509.26507</a> &nbsp;Β·&nbsp;
756
- <a href="https://github.com/pathwaycom/bdh" style="color:{ORNG};text-decoration:none">github.com/pathwaycom/bdh</a>
757
- </div>""", unsafe_allow_html=True)
758
-
759
-
760
- if __name__ == "__main__":
761
- main()