Accelerating LLM Inference: Fused INT8 Weight-Only Quantization in Pallas
When an LLM generates text one token at a time, the accelerator spends most of its time moving weights out of memory rather than doing math. In hardware terms, we call this being Memory Bandwidth Bound.
Let's break down the math to see why this happens. Imagine a standard linear layer in an LLM with an $8192 \times 8192$ weight matrix. To push a single token's activation vector through this layer, your accelerator performs roughly $2 \times 8192^2$ operations, which comes out to about 134 Million Floating Point Operations (FLOPs).
To actually do that math, the hardware has to read the weights from its High Bandwidth Memory (HBM). If those weights are stored in bfloat16 format (which takes 2 bytes per number), we are looking at reading about 134 Megabytes of memory just for this one layer.
Now, look at a modern accelerator like Google's TPU v5e. It has an immense compute capacity of 197 TFLOP/s, but its memory bandwidth is "only" 819 GB/s.
When you are generating text auto-regressively (meaning a batch size of 1, generating one token at a time), you have to load that entire 134 MB weight matrix from memory just to multiply it against a single token vector. The ratio of math-to-memory-movement here—called arithmetic intensity—is incredibly low. It sits at roughly 1 FLOP per byte. Because memory is the slow path, those massive Matrix Multiply Units (MXUs on TPUs, or Tensor Cores on GPUs) end up sitting around twiddling their thumbs, waiting for the weights to travel across the bus.
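We can sanity-check that claim with a quick back-of-envelope roofline calculation, using the TPU v5e peak numbers quoted above:

```python
# Roofline sketch for one 8192 x 8192 layer at batch size 1 (a single token).
K = N = 8192
flops = 2 * K * N          # one multiply + one add per weight
weight_bytes = K * N * 2   # bfloat16 = 2 bytes per weight

intensity = flops / weight_bytes   # arithmetic intensity, FLOP per byte

peak_compute = 197e12      # TPU v5e: 197 TFLOP/s
peak_bandwidth = 819e9     # TPU v5e: 819 GB/s

compute_time_us = flops / peak_compute * 1e6          # time if compute-bound
memory_time_us = weight_bytes / peak_bandwidth * 1e6  # time if memory-bound

print(f"intensity:    {intensity:.1f} FLOP/byte")
print(f"compute time: {compute_time_us:.2f} us")
print(f"memory time:  {memory_time_us:.1f} us")
```

Loading the weights takes over two hundred times longer than the math itself, so the matrix units sit idle almost the entire time.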
The INT8 Solution (and the XLA Trap)
The most practical way to fix this is Weight-Only Quantization. If we store the LLM's weights as INT8 (1 byte) instead of BF16 (2 bytes), we instantly cut our memory bandwidth requirements in half. If we push it to INT4 (0.5 bytes), we quadruple our efficiency.
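Producing those INT8 weights is straightforward. Here is a minimal NumPy sketch of symmetric quantization with one scale per output column; the function name and the per-column grouping are illustrative choices, not a fixed recipe:

```python
import numpy as np

def quantize_int8(w: np.ndarray):
    """Symmetric INT8 quantization with one scale per output column (a sketch)."""
    # The largest magnitude in each column maps to the int8 limit of 127.
    scale = np.abs(w).max(axis=0, keepdims=True) / 127.0
    w_int8 = np.round(w / scale).clip(-127, 127).astype(np.int8)
    return w_int8, scale

w = np.random.randn(512, 256).astype(np.float32)
w_int8, scale = quantize_int8(w)

# Dequantization recovers each weight to within half a quantization step.
w_hat = w_int8.astype(np.float32) * scale
assert np.all(np.abs(w - w_hat) <= scale / 2 + 1e-6)
```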
But there's a catch. Matrix units generally don't support mixed-precision inputs: you can't just hand the hardware BF16 activations and INT8 weights and expect it to spit out BF16 outputs.
If you try to write this naively in standard JAX, the XLA compiler tries to be helpful but ends up shooting you in the foot. XLA will look at your code, realize the types don't match, and decide to dequantize the INT8 weights back into BF16 before doing the matrix multiplication.
The fatal flaw? XLA often writes that newly inflated BF16 tensor back out to the slow High Bandwidth Memory, and then reads it back in for the math. You just paid the memory bandwidth cost for BF16 anyway, completely destroying the speedup you were trying to achieve!
Taking Control with Pallas
To get around this, we need to tell the hardware exactly how to move the data. Using JAX's Pallas extension, we can write a custom kernel that does exactly what we want:
- Stream the compressed INT8 weights from slow HBM into the ultra-fast, local SRAM.
- Dequantize those weights into BF16 on the fly, entirely within the fast memory.
- Feed them directly into the compute units.
The heavy BF16 weights never touch HBM.
Let's build this fused dequantize-and-matmul kernel piece by piece. First, we need to set up the kernel signature and initialize our accumulator. We need our bfloat16 activations (x_ref), our int8 compressed weights (w_int8_ref), and the bfloat16 scaling factors used to unpack the weights (scale_ref).
import jax
from jax import numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
import functools
def fused_dequant_matmul_kernel(
    x_ref, w_int8_ref, scale_ref, z_ref, acc_ref,
    nsteps: int
):
    """
    Pallas Kernel: Computes X @ (W_int8 * Scale) entirely in fast memory.
    """
    # 1. Initialize the accumulator for the reduction dimension.
    # We only want to do this on the very first step of our loop.
    @pl.when(pl.program_id(2) == 0)
    def _():
        acc_ref[...] = jnp.zeros_like(acc_ref)
Now for the crucial part: the dequantization. Because we are inside a Pallas kernel, the arrays we are manipulating right now (w_int8_ref and scale_ref) have already been fetched into the fast local memory (VMEM on TPUs).
When we cast the INT8 weights to BF16 and multiply by the scale, that new BF16 tensor lives entirely in fast memory. We bypassed the HBM bottleneck completely.
    # 2. Dequantization step (happens in fast VMEM/SRAM)
    # Cast to BF16 and multiply by the scale block.
    w_bf16 = w_int8_ref[...].astype(jnp.bfloat16) * scale_ref[...]
Next, we hand those unpacked weights directly to the hardware's native matrix multiplication unit. We use jax.lax.dot_general to handle the math, and we explicitly tell it to accumulate the results in float32. This keeps our math highly accurate, preventing precision loss as we add up thousands of numbers.
    # 3. Compute step
    # Execute the block matrix multiplication on the MXU/TensorCore
    acc_ref[...] += jax.lax.dot_general(
        x_ref[...],
        w_bf16,
        (((1,), (0,)), ((), ())),            # Contract dimension 1 of X with 0 of W
        preferred_element_type=jnp.float32,  # Accumulate in higher precision
    )
Finally, we need an epilogue. Once we've looped over all the chunks in the reduction dimension (the K dimension), our accumulator holds the final answer. We cast it back to our target data type and write it out to z_ref, which stores the final output back in main memory.
    # 4. Epilogue
    # Once we have iterated over all blocks in the K dimension, write out to HBM
    @pl.when(pl.program_id(2) == nsteps - 1)
    def _():
        z_ref[...] = acc_ref[...].astype(z_ref.dtype)
Structuring the Pipeline
Our kernel logic is solid, but the hardware can't process an $8192 \times 8192$ matrix all at once. The local SRAM is far too small. We have to chunk the matrices into smaller blocks and feed them through the kernel one by one.
We do this using pallas_call and BlockSpec. Tuning these block dimensions (bm, bk, bn) is the real secret to squeezing maximum speed out of your hardware. You want blocks large enough to keep the math units busy, but small enough to fit in SRAM.
Here is a really cool detail: because our weights are stored in INT8, they take up half the space of BF16 weights. That means we can actually make our block size for the K dimension (bk) twice as large as we normally would!
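It's worth checking that the default block sizes actually fit. A TPU core has tens of megabytes of VMEM (the exact capacity varies by generation), so treat this as a rough budget check rather than a hard spec:

```python
# Rough VMEM budget for one pipeline step, using the default block sizes below.
bm, bk, bn = 128, 256, 128

x_block = bm * bk * 2   # bfloat16 activations
w_block = bk * bn * 1   # int8 weights -- half the bf16 footprint
s_block = bk * bn * 2   # bfloat16 scales (full-size in this simple kernel)
acc     = bm * bn * 4   # float32 accumulator scratch

# Pallas double-buffers the streamed inputs so the next blocks can be
# fetched while the current ones are being consumed.
total = 2 * (x_block + w_block + s_block) + acc
print(f"{total / 1024:.0f} KiB")  # well under VMEM capacity
```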
Let's write the Python wrapper that chunks the data and calls our kernel.
@functools.partial(jax.jit, static_argnames=['bm', 'bk', 'bn'])
def quantized_matmul(
    X: jax.Array,       # Shape: (M, K), dtype: bfloat16
    W_int8: jax.Array,  # Shape: (K, N), dtype: int8
    Scales: jax.Array,  # Shape: (K, N), dtype: bfloat16
    bm: int = 128,
    bk: int = 256,      # We can make bk larger because INT8 takes less space!
    bn: int = 128,
):
    M, K = X.shape
    _, N = W_int8.shape

    # Calculate how many loop iterations the pipeline will do over the K dim
    nsteps = K // bk

    # Bind the static arguments to our kernel
    kernel_fn = functools.partial(fused_dequant_matmul_kernel, nsteps=nsteps)
Now we configure the grid. This tells Pallas how to slice up the M, N, and K dimensions, and exactly how those slices map to our inputs.
Notice that w_int8_ref uses the exact same block layout mapping as a standard matrix multiplication. But because its dtype is int8, Pallas is smart enough to know it only needs to fetch half the bytes from HBM.
    return pl.pallas_call(
        kernel_fn,
        out_shape=jax.ShapeDtypeStruct((M, N), X.dtype),
        grid_spec=pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=0,
            in_specs=[
                pl.BlockSpec((bm, bk), lambda m, n, k: (m, k)),  # X block
                pl.BlockSpec((bk, bn), lambda m, n, k: (k, n)),  # W_int8 block
                pl.BlockSpec((bk, bn), lambda m, n, k: (k, n)),  # Scales block
            ],
            out_specs=pl.BlockSpec((bm, bn), lambda m, n, k: (m, n)),
            # We explicitly allocate our FP32 accumulator in fast VMEM
            scratch_shapes=[pltpu.VMEM((bm, bn), jnp.float32)],
            # Define how many total chunks exist in each dimension
            grid=(M // bm, N // bn, K // bk),
        ),
        # Hint to the compiler to parallelize over M and N, but run K sequentially
        compiler_params=pltpu.CompilerParams(
            dimension_semantics=("parallel", "parallel", "sequential")
        ),
    )(X, W_int8, Scales)
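Since the fused kernel only runs on TPU hardware, it helps to pin down its semantics with a plain NumPy reference you can test against on any machine. This unfused version materializes the full dequantized matrix, which is exactly the slow path the kernel avoids (the function name is ours, for illustration):

```python
import numpy as np

def quantized_matmul_reference(x, w_int8, scales):
    """Unfused reference: dequantize everything, then matmul."""
    w = w_int8.astype(np.float32) * scales.astype(np.float32)
    return x.astype(np.float32) @ w

# Tiny smoke test of the semantics.
x = np.ones((2, 4), dtype=np.float32)
w_int8 = np.full((4, 3), 2, dtype=np.int8)
scales = np.full((4, 3), 0.5, dtype=np.float32)
out = quantized_matmul_reference(x, w_int8, scales)
assert np.allclose(out, 4.0)  # each output sums 4 terms of (2 * 0.5)
```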
Why This Changes the Game
By dropping down to Pallas and writing this custom kernel, we have completely transformed the memory profile of our model.
Think about the math again. In a standard BF16 setup, reading a $K \times N$ weight matrix costs $K \times N \times 2$ bytes of memory bandwidth. In our custom Pallas pipeline, we only read the INT8 weights, which costs $K \times N \times 1$ bytes. Our example kernel passes a full-size scale tensor to keep the code simple, but in practice scales are grouped (for example, one scale value for every 64 weights, or one per output channel), so their bandwidth cost is negligible. The end result is that we effectively double our memory bandwidth efficiency.
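Under grouped scales (here with a hypothetical group size of 64), the savings work out as follows:

```python
K, N, group = 8192, 8192, 64

bf16_bytes  = K * N * 2             # baseline: bfloat16 weights
int8_bytes  = K * N * 1             # quantized weights
scale_bytes = (K // group) * N * 2  # one bf16 scale per 64-weight group

speedup = bf16_bytes / (int8_bytes + scale_bytes)
print(f"bandwidth reduction: {speedup:.2f}x")  # close to the ideal 2x
```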
What makes Pallas really special is how it abstracts the brutal complexity of hardware compilation. On TPUs, this code compiles down to Mosaic, which natively sets up the double-buffering needed to fetch the next block of data while the current block is doing math. If we wanted to run this on an NVIDIA GPU, we could adapt the pltpu annotations to their plgpu counterparts. Pallas would then compile the kernel through Triton or Mosaic GPU, which on modern GPUs can use NVIDIA's Tensor Memory Accelerator (TMA) to stream the INT8 weights into Shared Memory and handle the asynchronous pipelining for us.
We are reaching a point where writing custom hardware kernels is no longer something only hardcore C++ and CUDA engineers can do. With tools like JAX and Pallas, researchers and developers can implement advanced quantization, sparsity, and memory optimizations directly in Python, squeezing every last drop of performance out of the hardware.