Qwen2.5-0.5B Layer 19 TopK SAE (16x, k=32)

A sparse autoencoder trained on the layer-19 residual stream of Qwen/Qwen2.5-0.5B, using activations collected from 50K ARC-AGI tasks (2.82B tokens).

Summary

  • Base model: Qwen/Qwen2.5-0.5B
  • Hook point: blocks.19.hook_resid_post
  • Architecture: TopK
  • d_in: 896
  • d_sae: 14336 (16x expansion)
  • Sparsity: k=32 active features per token (0.22% of 14336; see the sketch below)

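For reference, the TopK constraint amounts to keeping the 32 largest pre-activations per token and zeroing the rest. A minimal PyTorch sketch (illustrative only; the function and tensor names here are not the stored parameter names):

import torch

def topk_encode(x, W_enc, b_enc, k=32):
    # Pre-activations for all 14336 features, then keep only the k largest
    # per token and zero everything else.
    pre = x @ W_enc + b_enc
    vals, idx = pre.topk(k, dim=-1)
    return torch.zeros_like(pre).scatter_(-1, idx, vals)

# Example with random stand-ins for the encoder weights:
acts = topk_encode(torch.randn(4, 896), torch.randn(896, 14336), torch.zeros(14336))
print((acts != 0).sum(dim=-1))  # tensor([32, 32, 32, 32])
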
This checkpoint set is the 16x-expansion arm of an A/B comparison against v2, trained on the same data with the same training budget.

Checkpoints

This repo contains 5 training checkpoints as subfolders:

Step            d_sae
step_00110000   14336
step_00125000   14336
step_00135000   14336
step_00145000   14336
step_00150000   14336

Each subfolder is a self-contained SAELens-format checkpoint (cfg.json + sae_weights.safetensors).

Usage with SAELens

from sae_lens import SAE

# Load a specific training step
sae = SAE.from_pretrained_with_cfg_and_sparsity(
    release="KathirKs/qwen2.5-0.5b-l19-sae-topk-16x",
    sae_id="step_00125000",
)[0]
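# (from_pretrained_with_cfg_and_sparsity returns a (sae, cfg_dict, sparsity) tuple; [0] keeps just the SAE)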

# Or load from a local clone
# sae = SAE.load_from_disk("./step_00125000")
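
Once loaded, the SAE can be run on layer-19 activations collected with TransformerLens. A minimal sketch (the prompt is a placeholder; assumes transformer_lens is installed and supports Qwen/Qwen2.5-0.5B):

from transformer_lens import HookedTransformer
from sae_lens import SAE

model = HookedTransformer.from_pretrained("Qwen/Qwen2.5-0.5B")
sae = SAE.load_from_disk("./step_00125000")

tokens = model.to_tokens("An example prompt.")
_, cache = model.run_with_cache(tokens)

acts = cache["blocks.19.hook_resid_post"]   # (batch, seq, 896) residual stream
feats = sae.encode(acts)                    # (batch, seq, 14336) sparse feature activations
recon = sae.decode(feats)                   # reconstruction in the residual stream basis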

Usage with sae_vis

from sae_vis import SaeVisConfig, SaeVisData
# pair with HookedTransformer("Qwen/Qwen2.5-0.5B") + this SAE at blocks.19.hook_resid_post
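
A rough end-to-end sketch for a feature-centric dashboard. This is an assumption-laden example, not the canonical sae_vis recipe: the keyword for passing the SAE (encoder vs. sae) and the exact SaeVisConfig fields differ across sae_vis releases, so check the version you have installed:

from sae_lens import SAE
from sae_vis import SaeVisConfig, SaeVisData
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("Qwen/Qwen2.5-0.5B")
sae = SAE.load_from_disk("./step_00125000")
tokens = model.to_tokens(["An example prompt.", "Another example prompt."])

cfg = SaeVisConfig(hook_point="blocks.19.hook_resid_post", features=range(64))
# Older sae_vis releases take encoder=...; newer ones may expect sae=... instead.
data = SaeVisData.create(encoder=sae, model=model, tokens=tokens, cfg=cfg)
data.save_feature_centric_vis(filename="feature_vis.html")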

Training

  • Global batch: 4096 tokens
  • Steps: 150,000
  • Learning rate: 3e-4 cosine, 1000-step warmup
  • Optimizer: Adam
  • Dead-neuron resampling: every 25K steps, stopped at step 75K
  • Decoder init norm: 0.1; decoder weights kept unit-norm during training (see the sketch after this list)
  • Hardware: v5litepod-64 TPU (16 hosts, data-parallel)
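
The unit-norm decoder constraint is typically enforced by renormalizing the decoder after every optimizer step. A minimal PyTorch sketch, assuming the SAELens weight layout (W_dec of shape (d_sae, d_in), one decoder direction per row); this is an illustration, not the original JAX training code:

import torch

@torch.no_grad()
def renormalize_decoder(W_dec: torch.Tensor) -> None:
    # Rescale each feature's decoder direction (each row of W_dec) back to
    # unit L2 norm, so feature magnitude lives in the encoder activations.
    W_dec /= W_dec.norm(dim=-1, keepdim=True)

# Call after every optimizer.step() to keep decoder rows unit-norm.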

Conversion

Converted from the original JAX/Flax checkpoints with scripts/convert_sae_to_saelens.py. The round-trip was numerically verified against a numpy reference (encode() max-abs diff < 1e-4, L0 exact match).
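
To spot-check a converted checkpoint yourself, a quick sanity test on a local clone (paths and inputs are placeholders; this is not the original verification script) is to confirm the expected per-token sparsity:

import torch
from sae_lens import SAE

sae = SAE.load_from_disk("./step_00150000")
x = torch.randn(16, 896)            # stand-in for layer-19 residual activations
acts = sae.encode(x)

print(acts.shape)                   # torch.Size([16, 14336])
print((acts != 0).sum(dim=-1))      # expect 32 active features per row (or fewer if
                                    # the TopK step zeroes negative pre-activations)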
