# Qwen2.5-0.5B Layer 19 TopK SAE (16x, k=32)
A sparse autoencoder trained on the layer-19 residual stream of Qwen/Qwen2.5-0.5B, using activations collected from 50K ARC-AGI tasks (2.82B tokens).
## Summary
- Base model: Qwen/Qwen2.5-0.5B
- Hook point: `blocks.19.hook_resid_post`
- Architecture: TopK
- d_in: 896
- d_sae: 14336 (16x expansion)
- Sparsity: k=32 active features per token (0.22% of 14336; see the sketch below)
This is the 16x-expansion arm of an A/B comparison against v2, trained on the same data with the same training budget.
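For reference, the TopK activation rule is a hard top-k selection per token. A minimal PyTorch sketch, illustrative only and not the training code (the ReLU on the kept values and the bias handling are assumptions; some implementations differ):

```python
import torch

def topk_encode(x: torch.Tensor, W_enc: torch.Tensor, b_enc: torch.Tensor, k: int = 32) -> torch.Tensor:
    """Keep the k largest pre-activations per token; zero everything else."""
    pre = x @ W_enc + b_enc                    # [batch, d_sae]
    vals, idx = torch.topk(pre, k=k, dim=-1)   # k winners per token
    acts = torch.zeros_like(pre)
    acts.scatter_(-1, idx, torch.relu(vals))   # assumption: ReLU on kept values
    return acts

# Shapes for this SAE: d_in=896, d_sae=14336, so at most 32/14336 ≈ 0.22% active.
x = torch.randn(4, 896)
acts = topk_encode(x, torch.randn(896, 14336) * 0.01, torch.zeros(14336))
assert (acts != 0).sum(dim=-1).max() <= 32
```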
## Checkpoints
This repo contains 5 training checkpoints as subfolders:
| Step | d_sae | Notes |
|---|---|---|
| step_00110000 | 14336 | |
| step_00125000 | 14336 | |
| step_00135000 | 14336 | |
| step_00145000 | 14336 | |
| step_00150000 | 14336 | |
Each subfolder is a self-contained SAELens-format checkpoint (`cfg.json` + `sae_weights.safetensors`).
## Usage with SAELens
```python
from sae_lens import SAE

# Load a specific training step from the Hub
sae = SAE.from_pretrained_with_cfg_and_sparsity(
    release="KathirKs/qwen2.5-0.5b-l19-sae-topk-16x",
    sae_id="step_00125000",
)[0]

# Or load from a local clone
# sae = SAE.load_from_disk("./step_00125000")
```
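Once loaded, the SAE can be paired with the base model to get feature activations. A sketch using TransformerLens (assumes your installed transformer_lens version supports Qwen2.5-0.5B; the prompt is a placeholder):

```python
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("Qwen/Qwen2.5-0.5B")
tokens = model.to_tokens("An example prompt")

# Cache only the hook point this SAE was trained on.
_, cache = model.run_with_cache(tokens, names_filter="blocks.19.hook_resid_post")
resid = cache["blocks.19.hook_resid_post"]   # [batch, seq, 896]
feats = sae.encode(resid)                    # [batch, seq, 14336], ≤ 32 nonzero per token
```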
## Usage with sae_vis
```python
from sae_vis import SaeVisConfig, SaeVisData

# Pair with HookedTransformer("Qwen/Qwen2.5-0.5B") + this SAE at blocks.19.hook_resid_post
```
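A minimal end-to-end sketch, reusing `model`, `sae`, and `tokens` from the SAELens section above. Argument names vary between sae_vis versions, so treat this as an assumption to check against your installed version:

```python
# Visualize the first 64 features at this SAE's hook point (assumed API).
cfg = SaeVisConfig(hook_point="blocks.19.hook_resid_post", features=range(64))
data = SaeVisData.create(sae=sae, model=model, tokens=tokens, cfg=cfg)
data.save_feature_centric_vis(filename="feature_vis.html")
```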
## Training
- Global batch: 4096 tokens
- Steps: 150,000
- Learning rate: 3e-4 cosine, 1000-step warmup
- Optimizer: Adam
- Dead-neuron resampling: every 25K steps, stopped at step 75K
- Decoder init norm: 0.1; decoder weights kept unit-norm during training (see the sketch after this list)
- Hardware: v5litepod-64 TPU (16 hosts, data-parallel)
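The unit-norm constraint on the decoder is typically enforced by renormalizing after each optimizer step. A minimal sketch, assuming `W_dec` holds one feature direction per row with shape [d_sae, d_in] (the actual training code is JAX/Flax; this is illustrative PyTorch):

```python
import torch

@torch.no_grad()
def renorm_decoder(W_dec: torch.Tensor, eps: float = 1e-8) -> None:
    """Project each decoder row back onto the unit sphere, in place."""
    W_dec /= W_dec.norm(dim=-1, keepdim=True).clamp_min(eps)
```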
## Conversion
Converted from the original JAX/Flax checkpoints with `scripts/convert_sae_to_saelens.py`. The round-trip was numerically verified against a NumPy reference (`encode()` max-abs diff < 1e-4, L0 exact match).
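The verification amounts to a tolerance check on `encode()` outputs plus an exact L0 comparison. A sketch of the idea (`reference_numpy_encode` is a hypothetical stand-in for the NumPy reference, not a function in this repo):

```python
import numpy as np
import torch

x = torch.randn(64, 896)                     # fake layer-19 residual batch
acts = sae.encode(x).detach().cpu().numpy()  # converted SAELens checkpoint
ref = reference_numpy_encode(x.numpy())      # hypothetical NumPy reference

assert np.abs(acts - ref).max() < 1e-4                    # encode() max-abs diff
assert ((acts != 0).sum(-1) == (ref != 0).sum(-1)).all()  # L0 exact match
```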