Breaking the O(N^2) Bottleneck: Implementing High-Performance Block-Sparse Attention with JAX/Pallas
Standard dense attention scales quadratically. This means it requires $O(N^2)$ compute and memory with respect to sequence length. If you double the context size, you multiply the workload by four. Scaling models to handle entire books, large codebases, or long conversational histories requires a fundamental shift in how we approach the algorithm.
One of the most practical solutions we have right now is Block-Sparse Attention. Instead of calculating attention scores between every single token and every other token, we restrict the mechanism. We only evaluate specific blocks of the attention score matrix, such as a sliding local window or a handful of routing tokens. Depending on the sparsity pattern, this brings the cost down to something much more manageable: $O(N \cdot w)$ for a fixed window of width $w$, or something like $O(N \log N)$ or $O(N \sqrt{N})$ for patterns that grow slowly with sequence length.
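To make the scaling concrete, here is a minimal sketch that counts how many blocks a sliding-window pattern touches as the sequence grows. The helper name count_active_blocks, the one-block window, and the single global key block are illustrative assumptions, not part of any library.

```python
def count_active_blocks(num_blocks, window_blocks=1, num_global=1):
    """Count (query_block, key_block) pairs kept by a hypothetical pattern:
    a band of +/- window_blocks around the diagonal, plus the first
    num_global key blocks, which every query block may attend to."""
    active = 0
    for i in range(num_blocks):          # query block row
        for j in range(num_blocks):      # key block column
            in_window = abs(i - j) <= window_blocks
            is_global = j < num_global
            if in_window or is_global:
                active += 1
    return active


for num_blocks in (32, 64):
    active = count_active_blocks(num_blocks)
    dense = num_blocks * num_blocks
    print(f"{num_blocks} blocks per side: {active} of {dense} blocks "
          f"({active / dense:.1%} of dense)")
```

Under these assumptions, doubling the sequence from 32 to 64 blocks roughly doubles the active block count (124 to 252), while the dense count quadruples (1024 to 4096).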
But here's the catch: implementing true block-sparse attention in high-level frameworks like standard JAX or PyTorch is famously frustrating.
Standard compilers like XLA are heavily optimized for static shapes and predictable, dense memory access patterns. If you try to introduce a dynamic, data-dependent sparsity mask, the compiler usually throws its hands up. It falls back to computing the entire dense matrix, and then it just masks out the zeros afterward. You end up doing all the math for the zeros anyway, completely defeating the purpose of sparsity and wasting precious FLOPS and memory bandwidth.
To fix this, we need to drop down to the hardware level. In the JAX ecosystem, we do this using Pallas, an extension that lets us write custom kernels for TPUs (via Mosaic) and GPUs (via Triton). In this post, we’re going to walk through how to use Pallas's Scalar Prefetch capabilities to build a true Block-Sparse Matrix Multiplication kernel—the exact building block you need for Block-Sparse Attention.
The Hardware Reality of Sparse Data
To understand why standard XLA struggles, we have to look at how the physical hardware operates.
Modern ML accelerators, like Google’s TPUs, have staggering compute capabilities. For example, a single TPU v5e can hit 197 TFLOP/s of BF16 compute. However, High-Bandwidth Memory (HBM)—the main memory where your large tensors live—is relatively slow compared to the compute units.
To get good performance, data has to be loaded from HBM into incredibly fast, local SRAM (called VMEM on TPUs, or Shared Memory on GPUs) in large, contiguous chunks. This keeps the Matrix Multiply Units (MXUs) fed with data so they don't sit idle.
If we want to multiply a Query matrix ($Q$) with a Key matrix ($K^T$) but only compute the outputs for specific blocks dictated by a sparse routing mask, we can't just issue random, on-the-fly memory fetches. We need a way to tell the accelerator's DMA (Direct Memory Access) engine exactly which blocks to grab ahead of time. If we do this right, the memory transfers will happen in the background while the MXU is busy doing math on the previous block.
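A quick back-of-the-envelope calculation shows why hiding the memory transfer matters. The sketch below uses the 197 TFLOP/s figure quoted above; the HBM bandwidth of 819 GB/s and the 512-wide BF16 block are assumptions for illustration, so substitute the numbers for your own chip and tile size.

```python
PEAK_FLOPS = 197e12        # BF16 peak quoted above for TPU v5e
HBM_BYTES_PER_S = 819e9    # assumed HBM bandwidth; check your chip's spec sheet
BLOCK = 512                # assumed block edge length
BYTES_PER_ELEM = 2         # BF16


def block_step_costs(block=BLOCK):
    """Return (flops, bytes_loaded, compute_seconds, load_seconds) for one
    (block x block) @ (block x block) matmul with two freshly loaded inputs."""
    flops = 2 * block ** 3                               # multiply + add per MAC
    bytes_loaded = 2 * block * block * BYTES_PER_ELEM    # one Q block + one K block
    return flops, bytes_loaded, flops / PEAK_FLOPS, bytes_loaded / HBM_BYTES_PER_S


flops, nbytes, t_compute, t_load = block_step_costs()
print(f"compute: {t_compute * 1e6:.2f} us, load: {t_load * 1e6:.2f} us")
```

With these assumed numbers, the load time and the compute time for one block step land in the same ballpark (roughly a microsecond each), so a kernel that waits for each load before computing would spend close to half its time idle.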
The Secret Weapon: Scalar Prefetch
Pallas gives us a tool for exactly this problem called PrefetchScalarGridSpec.
Instead of passing our attention mask as a massive, wasteful $N \times N$ matrix of booleans, we pass it in a Block-Coordinate (Block-COO) format. PrefetchScalarGridSpec allows us to load this small list of coordinates directly into the accelerator's Scalar Memory (SMEM) right before the main compute pipeline kicks off.
We can then use these scalar values inside our index mapping functions to dynamically route our memory fetches from HBM. The kernel essentially gets a set of instructions saying: "For loop iteration $i$, fetch block $X$ of the Queries and block $Y$ of the Keys."
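As a rough illustration of the Block-COO idea, the sketch below turns a small block-level mask into the two coordinate lists. The function name mask_to_block_coo and the 4x4 mask are made up for this example; in a real pipeline the lists would be int32 arrays handed to the kernel as scalar prefetch operands.

```python
def mask_to_block_coo(block_mask):
    """Convert a 2D list of booleans (one entry per block) into two index
    lists, sorted row-major so blocks sharing an output row are adjacent."""
    indices_i, indices_k = [], []
    for i, row in enumerate(block_mask):
        for k, keep in enumerate(row):
            if keep:
                indices_i.append(i)
                indices_k.append(k)
    return indices_i, indices_k


block_mask = [
    [True,  False, False, False],
    [True,  True,  False, False],
    [False, False, False, False],
    [True,  False, True,  True],
]
indices_i, indices_k = mask_to_block_coo(block_mask)
print(indices_i)  # [0, 1, 1, 3, 3, 3]
print(indices_k)  # [0, 0, 1, 0, 2, 3]
```

Six coordinate pairs replace sixteen booleans here. Keeping the pairs sorted by row matters because it lets a kernel accumulate all contributions to one output block on consecutive steps.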
Writing the Kernel
Let's write a Dense = Sparse @ Dense (DSD) matrix multiplication kernel. This represents the first half of the attention equation ($Logits = Q \times K^T$), where we only want to compute the logits for non-masked blocks.
First, let's set up our dimensions and define the signature of our kernel.
import jax
from jax import numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

# Define our dimensions for the overall matrices and our block sizes
M = N = K = 16384
blk_M = blk_N = blk_K = 512

def sparse_attention_dsd_kernel(
    idxs_i_ref, idxs_k_ref,  # Scalar prefetch inputs (living in SMEM)
    q_ref, k_ref, _, o_ref,  # Kernel inputs (living in VMEM/HBM)
    accum_scratch            # Scratch space for partial sums (VMEM)
):
    """
    A Dense = Sparse @ Dense matmul kernel for Sparse Attention.
    Computes Q @ K^T only for blocks defined by the prefetch indices.
    """
    # Program ID 1 represents our loop over the non-zero sparse blocks
    blk_idx = pl.program_id(1)
    is_start = (blk_idx == 0)
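In the code above, notice how the arguments are ordered. Pallas always passes the scalar prefetch arguments (idxs_i_ref, idxs_k_ref) first. These live in SMEM. Then come the input blocks (q_ref, k_ref, and an unused underscore placeholder for the zeros array that the wrapper aliases to the output), followed by the output block o_ref, and finally our scratch space, which lives in VMEM.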
Next, we need to handle our accumulation. Because we are looping over sparse blocks, we might be adding multiple $Q \times K^T$ chunks together to form a single output block. We need to know when we've moved on to a new output block so we can reset our accumulator to zero.
    # Check if we have moved to a new row in our output matrix
    changed_blocks = (idxs_i_ref[blk_idx] != idxs_i_ref[jnp.maximum(blk_idx-1, 0)])
    # If we are starting a new output block, zero out the accumulator scratchpad
    @pl.when(is_start | changed_blocks)
    def _():
        accum_scratch[...] = jnp.zeros_like(accum_scratch)
Now for the fun part: doing the actual math. TPUs have specialized Matrix Multiply Units (MXUs) that natively accumulate in 32-bit floating point to prevent precision loss, even if the inputs are 16-bit (like BF16). We use jax.lax.dot_general to tell Pallas to map this directly to the MXU.
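    # Perform the MXU matrix multiplication for this block pair.
    # k_ref holds a (blk_K, blk_N) tile, so K_dense must already be laid out
    # as K^T; no transpose happens inside the kernel.
    accum_scratch[...] += jax.lax.dot_general(
        q_ref[0, :, :],
        k_ref[...],
        (((1,), (0,)), ((), ())),  # Standard matmul contraction dimensions
        preferred_element_type=jnp.float32
    )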
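To see why the accumulator type matters, here is a small NumPy experiment that sums the same 16-bit value many times, once with a 16-bit accumulator and once with a 32-bit one. It is only an analogy: NumPy has no native bfloat16, so float16 stands in for the 16-bit input type, and the loop is a scalar stand-in for what the MXU does across a whole tile.

```python
import numpy as np

# float16 is a stand-in here: NumPy has no native bfloat16 type.
x = np.float16(0.01)
n = 4096

acc_low = np.float16(0.0)    # accumulate in the 16-bit input type
acc_high = np.float32(0.0)   # accumulate in 32-bit, as the MXU does
for _ in range(n):
    acc_low = np.float16(acc_low + x)
    acc_high = np.float32(acc_high + np.float32(x))

exact = n * float(x)         # reference value computed in float64
print(f"exact={exact:.3f}  "
      f"16-bit accum={float(acc_low):.3f}  "
      f"32-bit accum={float(acc_high):.3f}")
```

On this input the 16-bit accumulator stalls once the running sum is large enough that adding 0.01 rounds to no change at all, while the 32-bit accumulator stays close to the reference value.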
Finally, we need to check if we are done accumulating for the current block. If we are, we cast the FP32 accumulator back down to the data type of our output matrix (usually BF16) and write it out to HBM.
    # Find out how many total blocks we are processing
    num_blocks = pl.num_programs(1)
    # Check if the next iteration will move us to a different output block.
    # Clamp to num_blocks - 1 so the lookahead never reads past the index array.
    next_block_change = (idxs_i_ref[blk_idx] != idxs_i_ref[jnp.minimum(blk_idx+1, num_blocks - 1)])
    is_end = (blk_idx == num_blocks - 1)
    # If we are done with this block, cast back to BF16 and write to HBM
    @pl.when(is_end | next_block_change)
    def _():
        o_ref[...] = accum_scratch[...].astype(o_ref.dtype)
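The reset and write-back conditions are easier to follow when replayed on a concrete index list. The pure-Python sketch below mimics only the kernel's control flow, with neighbour lookups clamped to valid positions; the function name trace_accumulator_events and the sample indices are hypothetical.

```python
def trace_accumulator_events(idxs_i):
    """Replay the kernel's control flow over a sorted list of output-row
    indices, returning (step, row, reset, flush) for each grid step."""
    num_blocks = len(idxs_i)
    events = []
    for blk_idx in range(num_blocks):
        prev_idx = max(blk_idx - 1, 0)
        next_idx = min(blk_idx + 1, num_blocks - 1)   # clamp to last valid index
        reset = blk_idx == 0 or idxs_i[blk_idx] != idxs_i[prev_idx]
        flush = blk_idx == num_blocks - 1 or idxs_i[blk_idx] != idxs_i[next_idx]
        events.append((blk_idx, idxs_i[blk_idx], reset, flush))
    return events


for step, row, reset, flush in trace_accumulator_events([0, 1, 1, 3, 3, 3]):
    print(f"step {step}: output row {row}  reset={reset}  flush={flush}")
```

Each run of equal row indices produces exactly one reset (on its first step) and one flush (on its last step), which is why the index list has to be sorted by output row.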
Mapping the Pipeline
Our kernel logic is complete, but we haven't actually told Pallas how to fetch the data from HBM to feed into q_ref and k_ref.
This is where BlockSpec and index_map come into play. These mapping functions actually run on the scalar core of the TPU. Each one receives the grid indices first (j, blk_idx), followed by the scalar prefetch references. Notice how k_map and o_map read the blk_idxs_k and blk_idxs_i arrays that we pre-loaded into SMEM and use them to choose which block of the massive HBM matrices to fetch or write.
# Map the Query block to fetch based on the sparse index
def q_map(j, blk_idx, blk_idxs_i, blk_idxs_k):
    # The non-zero Q blocks are packed contiguously, so grid step blk_idx
    # simply reads the blk_idx-th packed block
    return (blk_idx, 0, 0)

# Map the Key block to fetch based on the sparse index from SMEM
def k_map(j, blk_idx, blk_idxs_i, blk_idxs_k):
    return (blk_idxs_k[blk_idx], j)

# Map the Output logit block based on the sparse index from SMEM
def o_map(j, blk_idx, blk_idxs_i, blk_idxs_k):
    return (blk_idxs_i[blk_idx], j)
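To check your understanding of the three index maps, you can call them as ordinary Python functions with plain lists standing in for the SMEM references. The sketch below does that; the trace_grid helper and the tiny index lists are hypothetical, and it assumes the grid is walked with the last axis varying fastest.

```python
def q_map(j, blk_idx, blk_idxs_i, blk_idxs_k):
    return (blk_idx, 0, 0)

def k_map(j, blk_idx, blk_idxs_i, blk_idxs_k):
    return (blk_idxs_k[blk_idx], j)

def o_map(j, blk_idx, blk_idxs_i, blk_idxs_k):
    return (blk_idxs_i[blk_idx], j)


def trace_grid(num_col_blocks, blk_idxs_i, blk_idxs_k):
    """List the block coordinates each grid step would request, iterating
    the grid with the last axis (blk_idx) varying fastest."""
    steps = []
    for j in range(num_col_blocks):
        for blk_idx in range(len(blk_idxs_i)):
            args = (j, blk_idx, blk_idxs_i, blk_idxs_k)
            steps.append((q_map(*args), k_map(*args), o_map(*args)))
    return steps


steps = trace_grid(2, [0, 1, 1], [0, 0, 1])
for q_blk, k_blk, o_blk in steps:
    print(f"Q{q_blk}  K{k_blk}  ->  O{o_blk}")
```

The returned tuples are block indices, not element offsets: Pallas multiplies each entry by the corresponding block dimension to find the actual slice.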
To tie it all together, we write a Python wrapper function. This function configures the grid, defines the scratch space, and calls the kernel.
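There is a clever trick happening here with input_output_aliases. Because we only compute specific blocks, any output block we never visit is never written by the kernel. By passing an array of zeros as an extra input and aliasing it to the output buffer, we guarantee that unvisited blocks hold zeros rather than uninitialized memory, without any extra work inside the kernel. The alias index counts the scalar prefetch operands too, so the zeros array (the fifth positional argument) is input 4. For attention, keep in mind that a zero here means "not computed": the softmax that follows still has to treat those positions as masked rather than as logits of zero.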
def block_sparse_matmul(Q_sparse_blocks, K_dense, indices_i, indices_k, num_nonzero_blocks):
    """
    Wrapper to configure and call the Pallas kernel
    """
    out_shape = jax.ShapeDtypeStruct((M, N), dtype=jnp.bfloat16)

    # Configure the Grid and tell it to expect 2 scalar prefetch arguments
    grid_spec = pltpu.PrefetchScalarGridSpec(
        num_scalar_prefetch=2,
        grid=(N // blk_N, num_nonzero_blocks),
        in_specs=[
            pl.BlockSpec((1, blk_M, blk_K), q_map),
            pl.BlockSpec((blk_K, blk_N), k_map),
            pl.BlockSpec((blk_M, blk_N), o_map),  # Dummy mapping for our zero-alias
        ],
        out_specs=pl.BlockSpec((blk_M, blk_N), o_map),
        scratch_shapes=[pltpu.VMEM((blk_M, blk_N), dtype=jnp.float32)]
    )

    kernel = pl.pallas_call(
        sparse_attention_dsd_kernel,
        grid_spec=grid_spec,
        out_shape=out_shape,
        input_output_aliases={4: 0},  # Alias input 4 (the zeros) to output 0
    )

    # Create the zero tensor for aliasing
    zeros = jnp.zeros((M, N), dtype=jnp.bfloat16)

    # Execute the kernel!
    return kernel(indices_i, indices_k, Q_sparse_blocks, K_dense, zeros)
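When debugging a kernel like this, it helps to have a slow reference to compare against. The sketch below is one possible NumPy reference, not part of Pallas: the function name block_sparse_matmul_reference, the tiny block size, and the float32 output are assumptions chosen so it runs anywhere. It follows the same conventions as the wrapper above: packed Q blocks, K already laid out as K^T, and untouched output rows left at zero.

```python
import numpy as np


def block_sparse_matmul_reference(q_blocks, k_dense, indices_i, indices_k, out_rows):
    """Plain NumPy reference. q_blocks has shape (num_nonzero, blk_m, blk_k);
    k_dense has shape (K, N). Output rows not named in indices_i stay zero."""
    num_nonzero, blk_m, blk_k = q_blocks.shape
    out = np.zeros((out_rows, k_dense.shape[1]), dtype=np.float32)
    for b in range(num_nonzero):
        i, k = int(indices_i[b]), int(indices_k[b])
        k_rows = k_dense[k * blk_k:(k + 1) * blk_k, :]          # (blk_k, N)
        out[i * blk_m:(i + 1) * blk_m, :] += q_blocks[b] @ k_rows
    return out


rng = np.random.default_rng(0)
blk = 4
indices_i, indices_k = [0, 1, 1], [0, 0, 1]
q_blocks = rng.standard_normal((3, blk, blk)).astype(np.float32)
k_dense = rng.standard_normal((2 * blk, 2 * blk)).astype(np.float32)
out = block_sparse_matmul_reference(q_blocks, k_dense, indices_i, indices_k,
                                    out_rows=3 * blk)

# Cross-check against an explicitly densified Q.
q_dense = np.zeros((3 * blk, 2 * blk), dtype=np.float32)
for b, (i, k) in enumerate(zip(indices_i, indices_k)):
    q_dense[i * blk:(i + 1) * blk, k * blk:(k + 1) * blk] = q_blocks[b]
print(np.allclose(out, q_dense @ k_dense, atol=1e-4))
```

When comparing against the real kernel, expect to loosen the tolerance, since the kernel writes its result in BF16.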
The Performance Payoff
By using Pallas's scalar prefetch, we completely sidestep XLA's requirement for static, dense layouts.
If your attention mask dictates that only 10% of the blocks are relevant (for example, if tokens are only allowed to attend to a local sliding window of 1024 tokens plus a few global anchor tokens), this kernel will execute only the MXU instructions and memory fetches required for that 10%.
Even better, Pallas handles the software pipelining (specifically, double-buffering) automatically under the hood. While the MXU is busy crunching the math for the current block, the DMA engine is asynchronously fetching the Key and Query data for the next block in the background.
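The benefit of overlapping loads with compute can be estimated with a toy timing model. The sketch below is an idealized picture of double-buffering, not a description of how Pallas schedules work internally; the function name pipeline_times and the unit-less step times are made up for illustration.

```python
def pipeline_times(load_times, compute_times):
    """Return (serial, overlapped) total time for a sequence of block steps.
    Serial: each load finishes before its compute starts, with no overlap.
    Overlapped: the load for step i+1 runs while step i computes."""
    serial = sum(load_times) + sum(compute_times)
    overlapped = load_times[0]                       # first load cannot be hidden
    for i, compute in enumerate(compute_times):
        next_load = load_times[i + 1] if i + 1 < len(load_times) else 0.0
        overlapped += max(compute, next_load)        # slower of the two sets the pace
    return serial, overlapped


serial, overlapped = pipeline_times([1.0] * 4, [1.5] * 4)
print(serial, overlapped)   # 10.0 7.0
```

In this model only the first load is exposed. After that, each step costs whichever is slower, the compute or the next load, so the pipeline runs at compute speed as long as loads are the faster of the two.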
When you run this exact DSD pattern on a TPU v5e with only 10% of the blocks non-zero, it yields roughly a 6x speedup over a dense matrix multiplication followed by an element-wise mask. It isn't a perfect 10x speedup because of pipeline bubbles and some fixed overhead, but it is a massive architectural win.
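The gap between the ideal and the observed speedup can be turned into a rough overhead estimate. The sketch below assumes that dense runtime scales linearly with the number of blocks processed, which is a simplification; the function name sparse_efficiency is hypothetical.

```python
def sparse_efficiency(density, observed_speedup):
    """Return (ideal_speedup, efficiency, overhead_fraction), where
    overhead_fraction is the share of the sparse kernel's runtime not
    explained by useful block work, assuming dense time scales linearly
    with the number of blocks processed."""
    ideal_speedup = 1.0 / density
    efficiency = observed_speedup / ideal_speedup
    overhead_fraction = 1.0 - efficiency
    return ideal_speedup, efficiency, overhead_fraction


ideal, eff, overhead = sparse_efficiency(density=0.10, observed_speedup=6.0)
print(f"ideal {ideal:.0f}x, efficiency {eff:.0%}, "
      f"overhead {overhead:.0%} of sparse runtime")
```

Under that assumption, a 6x speedup when 10% of the blocks are non-zero means about 60% of the sparse kernel's runtime is useful block work and the remaining 40% is overhead.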
The real beauty of Pallas is the flexibility it gives you. You are no longer constrained by the operators baked into the framework, and you don't have to wait for someone else to write a custom C++ or CUDA kernel when a new sparse attention paper drops. By shifting your thinking to HBM, VMEM, and SMEM, you can build entirely new paradigms of sparse, linear-time attention mechanisms directly in Python.