Accelerating LLM Inference: Fused INT8 Weight-Only Quantization in Pallas

Community Article Published March 30, 2026

If you’ve ever watched an LLM generate text one token at a time and wondered what exactly is taking so long, the answer might surprise you. The bottleneck in modern Large Language Model inference isn't the math. The processors we use today are incredibly fast at crunching numbers. The real problem is simply getting the data to the processors fast enough.

In hardware terms, we call this being Memory Bandwidth Bound.

Let's break down the math to see why this happens. Imagine a standard linear layer in an LLM with an $8192 \times 8192$ weight matrix. Multiplying an 8192-token batch of activations through this layer takes roughly $2 \times 8192^3$ operations, which comes out to about 1.1 trillion floating-point operations (FLOPs).

To actually do that math, the hardware has to read the weights from its High Bandwidth Memory (HBM). If those weights are stored in bfloat16 format (which takes 2 bytes per number), we are looking at reading about 134 Megabytes of memory just for this one layer.

Now, look at a modern accelerator like Google's TPU v5e. It has an immense compute capacity of 197 TFLOP/s in bfloat16, but its memory bandwidth is "only" 819 GB/s.

When you are generating text auto-regressively (meaning a batch size of 1, generating one token at a time), you have to load that entire 134 MB weight matrix from memory just to multiply it against a single token vector. The ratio of math-to-memory-movement here—called arithmetic intensity—is incredibly low. It sits at roughly 1 FLOP per byte. Because memory is the slow path, those massive Matrix Multiply Units (MXUs on TPUs, or Tensor Cores on GPUs) end up sitting around twiddling their thumbs, waiting for the weights to travel across the bus.
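To see just how lopsided this is, here is a quick back-of-the-envelope check using the numbers above (pure arithmetic, no hardware required):

```python
# Roofline estimate for a batch-1 decode step through one
# 8192 x 8192 bf16 linear layer, using the TPU v5e figures above.
K = N = 8192
flops = 2 * K * N                      # one token: a multiply-add per weight
weight_bytes = K * N * 2               # bf16 = 2 bytes per weight

arithmetic_intensity = flops / weight_bytes    # 1.0 FLOP per byte

peak_flops = 197e12                    # TPU v5e compute: 197 TFLOP/s (bf16)
bandwidth = 819e9                      # TPU v5e HBM: 819 GB/s

time_compute = flops / peak_flops      # time the MXU actually needs
time_memory = weight_bytes / bandwidth # time spent waiting on HBM

print(f"arithmetic intensity: {arithmetic_intensity:.1f} FLOP/byte")
print(f"memory is the bottleneck by ~{time_memory / time_compute:.0f}x")
```

At these figures the MXU spends well over 99% of the decode step idle, waiting on HBM.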

The INT8 Solution (and the XLA Trap)

The most practical way to fix this is Weight-Only Quantization. If we store the LLM's weights as INT8 (1 byte) instead of BF16 (2 bytes), we instantly cut our memory bandwidth requirements in half. If we push it to INT4 (0.5 bytes), we quadruple our efficiency.
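Producing those INT8 weights is itself simple. As a minimal sketch, a symmetric per-column scheme might look like this (illustrative only; production schemes usually quantize in finer-grained groups, and `quantize_weights` is a name of our own invention):

```python
import jax
import jax.numpy as jnp

def quantize_weights(W: jax.Array):
    """Symmetric per-column INT8 quantization: W is approximated by W_int8 * scale."""
    # One scale per output column, chosen so the largest magnitude maps to 127.
    absmax = jnp.max(jnp.abs(W.astype(jnp.float32)), axis=0, keepdims=True)
    scale = absmax / 127.0                                        # shape (1, N)
    W_int8 = jnp.clip(jnp.round(W / scale), -127, 127).astype(jnp.int8)
    return W_int8, scale.astype(jnp.bfloat16)

W = jax.random.normal(jax.random.key(0), (512, 256), dtype=jnp.bfloat16)
W_int8, scale = quantize_weights(W)               # 1 byte/weight + tiny scales
W_approx = W_int8.astype(jnp.bfloat16) * scale    # what the kernel reconstructs
```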

But there's a catch. Accelerators generally don't support native mixed-precision matrix multiplication. You can't just hand the hardware BF16 activations and INT8 weights and expect it to spit out BF16 logits.

If you try to write this naively in standard JAX, the XLA compiler tries to be helpful but ends up shooting you in the foot. XLA will look at your code, realize the types don't match, and decide to dequantize the INT8 weights back into BF16 before doing the matrix multiplication.

The fatal flaw? XLA often writes that newly inflated BF16 tensor back out to the slow High Bandwidth Memory, and then reads it back in for the math. You just paid the memory bandwidth cost for BF16 anyway, completely destroying the speedup you were trying to achieve!
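To make the trap concrete, here is roughly what the naive version looks like (a sketch; the danger is precisely that it reads as perfectly sensible JAX):

```python
import jax.numpy as jnp

def naive_quantized_matmul(X, W_int8, scales):
    # XLA sees the dtype mismatch and inserts a dequantize before the dot.
    # Nothing stops the compiler from materializing this full-size BF16
    # tensor in HBM first, paying the 2-bytes-per-weight bandwidth cost
    # we were trying to avoid.
    W_bf16 = W_int8.astype(jnp.bfloat16) * scales
    return X @ W_bf16
```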

Taking Control with Pallas

To get around this, we need to tell the hardware exactly how to move the data. Using JAX's Pallas extension, we can write a custom kernel that does exactly what we want:

  1. Stream the compressed INT8 weights from slow HBM into the ultra-fast, local SRAM.
  2. Dequantize those weights into BF16 on the fly, entirely within the fast memory.
  3. Feed them directly into the compute units.

The heavy BF16 weights never touch HBM.

Let's build this fused dequantize-and-matmul kernel piece by piece. First, we need to set up the kernel signature and initialize our accumulator. We need our bfloat16 activations (x_ref), our int8 compressed weights (w_int8_ref), and the bfloat16 scaling factors used to unpack the weights (scale_ref).

import jax
from jax import numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
import functools

def fused_dequant_matmul_kernel(
    x_ref, w_int8_ref, scale_ref, z_ref, acc_ref, 
    nsteps: int
):
    """
    Pallas Kernel: Computes X @ (W_int8 * Scale) entirely in fast memory.
    """
    # 1. Initialize the accumulator for the reduction dimension.
    # We only want to do this on the very first step of our loop.
    @pl.when(pl.program_id(2) == 0)
    def _():
        acc_ref[...] = jnp.zeros_like(acc_ref)

Now for the crucial part: the dequantization. Because we are inside a Pallas kernel, the arrays we are manipulating right now (w_int8_ref and scale_ref) have already been fetched into the fast local memory (VMEM on TPUs).

When we cast the INT8 weights to BF16 and multiply by the scale, that new BF16 tensor lives entirely in fast memory. We bypassed the HBM bottleneck completely.

    # 2. Dequantization step (happens in fast VMEM/SRAM)
    # Cast to BF16 and multiply by the scale block.
    w_bf16 = w_int8_ref[...].astype(jnp.bfloat16) * scale_ref[...]

Next, we hand those unpacked weights directly to the hardware's native matrix multiplication unit. We use jax.lax.dot_general to handle the math, and we explicitly tell it to accumulate the results in float32. This keeps our math highly accurate, preventing precision loss as we add up thousands of numbers.

    # 3. Compute step
    # Execute the block matrix multiplication on the MXU/TensorCore
    acc_ref[...] += jax.lax.dot_general(
        x_ref[...], 
        w_bf16, 
        (((1,), (0,)), ((), ())), # Contract dimension 1 of X with 0 of W
        preferred_element_type=jnp.float32 # Accumulate in higher precision
    )

Finally, we need an epilogue. Once we've looped over all the chunks in the reduction dimension (the K dimension), our accumulator holds the final answer. We cast it back to our target data type and write it out to z_ref, which safely stores the final logits back into main memory.

    # 4. Epilogue
    # Once we have iterated over all blocks in the K dimension, write out to HBM
    @pl.when(pl.program_id(2) == nsteps - 1)
    def _():
        z_ref[...] = acc_ref[...].astype(z_ref.dtype)

Structuring the Pipeline

Our kernel logic is solid, but the hardware can't process an $8192 \times 8192$ matrix all at once. The local SRAM is far too small. We have to chunk the matrices into smaller blocks and feed them through the kernel one by one.

We do this using pallas_call and BlockSpec. Tuning these block dimensions (bm, bk, bn) is the real secret to squeezing maximum speed out of your hardware. You want blocks large enough to keep the math units busy, but small enough to fit in SRAM.

Here is a really cool detail: because our weights are stored in INT8, they take up half the space of BF16 weights. That means we can actually make our block size for the K dimension (bk) twice as large as we normally would!
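A quick byte count makes this concrete (pure arithmetic; the real VMEM budget also has to cover Pallas's double-buffering, so treat these as relative numbers):

```python
bm, bk, bn = 128, 256, 128   # the block sizes we will use below

x_block = bm * bk * 2        # bf16 activation block: 2 bytes per element
w_int8_block = bk * bn * 1   # int8 weight block: 1 byte per element
w_bf16_block = bk * bn * 2   # the same weight block if it were bf16

# A (256, 128) int8 block occupies exactly as many bytes as a
# (128, 128) bf16 block, so halving the weight width lets bk double.
print(w_int8_block == (bk // 2) * bn * 2)   # True
```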

Let's write the Python wrapper that chunks the data and calls our kernel.

@functools.partial(jax.jit, static_argnames=['bm', 'bk', 'bn'])
def quantized_matmul(
    X: jax.Array,        # Shape: (M, K), dtype: bfloat16
    W_int8: jax.Array,   # Shape: (K, N), dtype: int8
    Scales: jax.Array,   # Shape: (K, N), dtype: bfloat16
    bm: int = 128,
    bk: int = 256,       # We can make bk larger because INT8 takes less space!
    bn: int = 128,
):
    M, K = X.shape
    _, N = W_int8.shape
    
    # Calculate how many loop iterations the pipeline will do over the K dim
    nsteps = K // bk
    
    # Bind the static arguments to our kernel
    kernel_fn = functools.partial(fused_dequant_matmul_kernel, nsteps=nsteps)

Now we configure the grid. This tells Pallas how to slice up the M, N, and K dimensions, and exactly how those slices map to our inputs.

Notice that w_int8_ref uses the exact same block layout mapping as a standard matrix multiplication. But because its dtype is int8, Pallas is smart enough to know it only needs to fetch half the bytes from HBM.

    return pl.pallas_call(
        kernel_fn,
        out_shape=jax.ShapeDtypeStruct((M, N), X.dtype),
        grid_spec=pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=0,
            in_specs=[
                pl.BlockSpec((bm, bk), lambda m, n, k: (m, k)), # X block
                pl.BlockSpec((bk, bn), lambda m, n, k: (k, n)), # W_int8 block
                pl.BlockSpec((bk, bn), lambda m, n, k: (k, n)), # Scales block
            ],
            out_specs=pl.BlockSpec((bm, bn), lambda m, n, k: (m, n)),
            
            # We explicitly allocate our FP32 accumulator in fast VMEM
            scratch_shapes=[pltpu.VMEM((bm, bn), jnp.float32)],
            
            # Define how many total chunks exist in each dimension
            grid=(M // bm, N // bn, K // bk),
        ),
        # Hint to the compiler to parallelize over M and N, but run K sequentially
        compiler_params=pltpu.CompilerParams(
            dimension_semantics=("parallel", "parallel", "sequential")
        )
    )(X, W_int8, Scales)
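The wrapper only runs on a TPU, but what it must compute is easy to state on any backend: the fused kernel should agree, up to bf16 rounding, with the plain dequantize-then-matmul reference. A sketch of that sanity check, with shapes and scales chosen arbitrarily:

```python
import jax
import jax.numpy as jnp

M, K, N = 256, 512, 256
kx, kw = jax.random.split(jax.random.key(0))

X = jax.random.normal(kx, (M, K), dtype=jnp.bfloat16)
W_int8 = jax.random.randint(kw, (K, N), -127, 128, dtype=jnp.int8)
Scales = jnp.full((K, N), 0.01, dtype=jnp.bfloat16)

# What the fused kernel computes, written the slow-but-obvious way:
reference = (
    X.astype(jnp.float32)
    @ (W_int8.astype(jnp.float32) * Scales.astype(jnp.float32))
).astype(jnp.bfloat16)

# On a TPU, compare against the Pallas kernel:
#   out = quantized_matmul(X, W_int8, Scales)
#   assert jnp.allclose(out.astype(jnp.float32),
#                       reference.astype(jnp.float32), rtol=2e-2)
```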

Why This Changes the Game

By dropping down to Pallas and writing this custom kernel, we have completely transformed the memory profile of our model.

Think about the math again. In a standard BF16 setup, reading a $K \times N$ weight matrix costs $K \times N \times 2$ bytes of memory bandwidth. In our custom Pallas pipeline, we only read the INT8 weights, which costs $K \times N \times 1$ bytes. Even when you factor in reading the scale tensor, that cost is tiny because scales are usually grouped (for example, you might only have one scale value for every 64 weights). The end result is that we effectively double our memory bandwidth efficiency.
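Counting the bytes for the running $8192 \times 8192$ example makes the savings explicit (assuming, for illustration, one bf16 scale per group of 64 weights along K):

```python
K = N = 8192
group = 64                              # weights sharing one scale (illustrative)

bf16_bytes = K * N * 2                  # baseline: read everything in bf16
int8_bytes = K * N * 1                  # quantized weights
scale_bytes = (K // group) * N * 2      # grouped bf16 scales

savings = bf16_bytes / (int8_bytes + scale_bytes)
print(f"effective bandwidth savings: {savings:.2f}x")   # just under 2x
```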

What makes Pallas really special is how it abstracts the brutal complexity of hardware compilation. On TPUs, this code compiles down to Mosaic, which natively sets up the double-buffering needed to fetch the next block of data while the current block is doing math. If we wanted to run this on an NVIDIA GPU, we could swap out the pltpu annotations for plgpu. Pallas would then compile it through Triton or Mosaic GPU, using NVIDIA’s Tensor Memory Accelerator (TMA) to stream the INT8 weights into Shared Memory, handling all the asynchronous warp-specialization for us.

We are reaching a point where writing custom hardware kernels is no longer something only hardcore C++ and CUDA engineers can do. With tools like JAX and Pallas, researchers and developers can implement advanced quantization, sparsity, and memory optimizations directly in Python, squeezing every last drop of performance out of the hardware.
