TPU Architecture and Pallas Kernels: From Memory Hierarchy to FlashAttention

This article explains why programming a TPU differs from programming a GPU, describes the explicit HBM-to-VMEM-to-register data movement a TPU requires, introduces the Pallas model of grid, BlockSpec, and Ref, and walks through four progressively more complex kernels: element-wise add, tiled dot product, fused RMSNorm with scratch memory, and a FlashAttention forward kernel. Each example shows how the kernel maps onto the TPU memory hierarchy and uses Pallas features such as input_output_aliases and PrefetchScalarGridSpec.

DeepHub IMBA

TPUs are specialized matrix-multiply accelerators that lack automatic cache management; data must be moved explicitly between high-bandwidth memory (HBM), on-chip vector memory (VMEM), and registers, with DMA transfers carrying data between HBM and VMEM. Unlike a GPU kernel, a TPU kernel cannot simply index into global memory inside a loop and rely on caches to fetch the data: every transfer between memory levels has to be expressed in the program.

TPU Memory Hierarchy

HBM (≈16 GB on v5e) – off‑chip storage, slower access.

VMEM (16+ MB of on-chip SRAM) – fast but limited capacity; data must reside here for the compute units.

Registers – where arithmetic is performed; values are loaded from VMEM and written back after computation.

Because the hardware does not automatically copy data from HBM to registers, programmers must schedule DMA loads from HBM to VMEM before a kernel runs and DMA stores from VMEM back to HBM after execution.
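To make the capacity constraint concrete, here is a back-of-the-envelope sketch of how much VMEM a tiled kernel occupies. The 16 MiB budget follows the "16+ MB" figure quoted above. The helper names and the assumption that every operand keeps two tiles resident (one being computed on while the next is in flight) are illustrative and not taken from the article.

```python
# Rough VMEM budgeting for a tiled kernel (illustrative sketch, not a compiler model).

VMEM_BUDGET_BYTES = 16 * 1024 * 1024  # assumed budget, based on the "16+ MB" figure


def vmem_bytes_needed(tile_elems, bytes_per_elem, num_operands, buffers_per_operand=2):
    """Bytes of VMEM used when every operand keeps `buffers_per_operand` tiles resident."""
    return tile_elems * bytes_per_elem * num_operands * buffers_per_operand


def fits_in_vmem(tile_elems, bytes_per_elem, num_operands):
    """True if the tiles for all operands fit inside the assumed budget."""
    return vmem_bytes_needed(tile_elems, bytes_per_elem, num_operands) <= VMEM_BUDGET_BYTES


# Element-wise add on float32: two inputs plus one output, 256 elements per tile.
small = vmem_bytes_needed(tile_elems=256, bytes_per_elem=4, num_operands=3)

# The same kernel with 1M-element tiles would not fit in the assumed budget.
large_fits = fits_in_vmem(tile_elems=1024 * 1024, bytes_per_elem=4, num_operands=3)
```

Under these assumptions the 256-element tiles need only a few kilobytes, which is why tile size is usually limited by other considerations long before VMEM capacity for simple kernels.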

Pallas Programming Model

Pallas abstracts the explicit data movement with three concepts:

Grid : an integer tuple defining the iteration space.

BlockSpec : describes which slice of an HBM tensor is loaded into VMEM for each grid step.

Ref : a handle to a VMEM buffer that the kernel reads from or writes to.

These abstractions let the compiler generate the necessary DMA operations while the user focuses on the computation.
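The three concepts can be emulated in a few lines of plain Python to show what happens at each grid step. This is only a conceptual sketch: `emulate_grid_1d` is a hypothetical helper, not a Pallas API, lists stand in for HBM tensors and VMEM tiles, and for simplicity the output uses the same index map as the inputs.

```python
# Pure-Python emulation of a 1-D grid with BlockSpec-style slicing.
# `emulate_grid_1d` is a hypothetical helper used only for illustration.

def emulate_grid_1d(kernel, inputs, block_size, index_map):
    """Run `kernel` once per grid step on block-sized slices of each input list."""
    n = len(inputs[0])
    grid = n // block_size                 # number of grid steps
    out = [None] * n                       # stands in for the HBM output tensor
    for i in range(grid):
        (block_idx,) = index_map(i)        # index_map: grid index -> block index
        lo, hi = block_idx * block_size, (block_idx + 1) * block_size
        in_tiles = [list(buf[lo:hi]) for buf in inputs]  # "DMA" copy HBM -> VMEM
        out_tile = [None] * block_size                   # output Ref lives in "VMEM"
        kernel(*in_tiles, out_tile)                      # the kernel only sees tiles
        out[lo:hi] = out_tile                            # "DMA" copy VMEM -> HBM
    return out


def add_kernel(x_tile, y_tile, o_tile):
    """Element-wise add over one tile, written through the output handle."""
    for j in range(len(o_tile)):
        o_tile[j] = x_tile[j] + y_tile[j]


x = list(range(8))
y = [10 * v for v in range(8)]
result = emulate_grid_1d(add_kernel, [x, y], block_size=4, index_map=lambda i: (i,))
```

The kernel body never computes a global offset; the slicing is entirely determined by the block shape and the index map, which is the division of labor Pallas relies on.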

Level 4 – Element‑wise Add Kernel

import jax, jax.numpy as jnp
from jax.experimental import pallas as pl

def add_kernel(x_ref, y_ref, o_ref):
    o_ref[...] = x_ref[...] + y_ref[...]

def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
    block_size = 256
    return pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        grid=(x.shape[0] // block_size,),
        in_specs=[pl.BlockSpec((block_size,), lambda i: (i,)),
                  pl.BlockSpec((block_size,), lambda i: (i,))],
        out_specs=pl.BlockSpec((block_size,), lambda i: (i,))
    )(x, y)

For a 1024-element input, the kernel processes the vector in four 256-element tiles. The grid evaluates to (4,), which creates four sequential steps; each step loads a slice from HBM to VMEM, performs the addition in registers, and writes the result back.
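The index map `lambda i: (i,)` returns a block index, not an element offset, so block i covers elements [i * 256, (i + 1) * 256). Note also that `x.shape[0] // block_size` rounds down: if the length is not a multiple of the block size, the tail is never visited. The sketch below spells out both points; `block_range` and `padded_length` are hypothetical helper names, and padding the input up front (for example with `jnp.pad`) is one possible remedy rather than something the article prescribes.

```python
def block_range(i, block_size):
    """Half-open element range [start, stop) covered by block index i."""
    return i * block_size, (i + 1) * block_size


def padded_length(n, block_size):
    """Smallest multiple of block_size that is >= n (ceiling division)."""
    return -(-n // block_size) * block_size


# The four tiles of a 1024-element vector with block_size=256.
ranges = [block_range(i, 256) for i in range(1024 // 256)]

# With 1000 elements, floor division yields only 3 grid steps,
# leaving this many trailing elements unprocessed:
uncovered = 1000 - (1000 // 256) * 256

# Padding to the next multiple of the block size restores full coverage.
target = padded_length(1000, 256)
```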

Level 3 – Tiled Dot‑Product Kernel

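def dot_kernel(x_ref, y_ref, acc_in_ref, out_ref):
    # Seed the resident output block from the aliased accumulator on the first
    # grid step, so the running sum never starts from uninitialized memory.
    @pl.when(pl.program_id(0) == 0)
    def _init():
        out_ref[...] = acc_in_ref[...]
    # Multiply in float32 so the partial sum matches the float32 accumulator.
    prod = x_ref[...].astype(jnp.float32) * y_ref[...].astype(jnp.float32)
    out_ref[...] += jnp.sum(prod, keepdims=True)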

def tiled_dot(x: jax.Array, y: jax.Array) -> jax.Array:
    block_size = 256
    n_blocks = x.shape[0] // block_size
    zero = jnp.zeros((1,), dtype=jnp.float32)
    return pl.pallas_call(
        dot_kernel,
        out_shape=jax.ShapeDtypeStruct((1,), jnp.float32),
        grid=(n_blocks,),
        in_specs=[pl.BlockSpec((block_size,), lambda i: (i,)),
                  pl.BlockSpec((block_size,), lambda i: (i,)),
                  pl.BlockSpec((1,), lambda i: (0,))],
        out_specs=pl.BlockSpec((1,), lambda i: (0,)),
        input_output_aliases={2: 0}
    )(x, y, zero)

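Here input_output_aliases={2: 0} makes input 2 (the zero-initialized accumulator) share its buffer with output 0. Because the output BlockSpec maps every grid step to block 0, the same output block is revisited on each step, which lets the kernel keep a running sum across grid steps without extra copies.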
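The accumulation pattern can be checked against a plain-Python model: each grid step reduces one tile to a partial sum and adds it to a single running value, which plays the role of the (1,)-shaped output block. `tiled_dot_reference` is a hypothetical name used only for this sketch.

```python
def tiled_dot_reference(x, y, block_size):
    """Accumulate per-tile partial sums into one running scalar, one step at a time."""
    assert len(x) == len(y) and len(x) % block_size == 0
    acc = 0.0                                  # stands in for the (1,)-shaped output block
    for i in range(len(x) // block_size):      # grid=(n_blocks,)
        lo, hi = i * block_size, (i + 1) * block_size
        partial = sum(a * b for a, b in zip(x[lo:hi], y[lo:hi]))
        acc += partial                         # the same buffer is revisited every step
    return acc


x = [float(v) for v in range(8)]
y = [1.0] * 8
total = tiled_dot_reference(x, y, block_size=4)
```

The result is independent of the block size as long as it divides the length, up to floating-point rounding from the different summation order.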

Level 2 – Fused RMSNorm with Scratch Memory

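from jax.experimental.pallas import tpu as pltpu

# Example sizes: the original does not define these constants, so the values
# below are assumptions chosen only to make the snippet self-contained.
BATCH, DIM, EPS = 8, 512, 1e-6

def rmsnorm_kernel(x_ref, weight_ref, o_ref, scratch_ref):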
    x = x_ref[...].astype(jnp.float32)
    w = weight_ref[...].astype(jnp.float32)
    mean_sq = jnp.mean(x * x, axis=-1, keepdims=True)
    scratch_ref[...] = jnp.broadcast_to(mean_sq, scratch_ref.shape)
    rms = jnp.sqrt(scratch_ref[0:BATCH, 0:1] + EPS)
    o_ref[...] = (x / rms * w).astype(jnp.bfloat16)

def fused_rmsnorm(x: jax.Array, weight: jax.Array) -> jax.Array:
    grid_spec = pltpu.PrefetchScalarGridSpec(
        num_scalar_prefetch=0,
        grid=(1,),
        in_specs=[pl.BlockSpec((BATCH, DIM), lambda i: (0, 0)),
                  pl.BlockSpec((DIM,), lambda i: (0,))],
        out_specs=pl.BlockSpec((BATCH, DIM), lambda i: (0, 0)),
        scratch_shapes=[pltpu.VMEM((BATCH, 128), jnp.float32)]
    )
    return pl.pallas_call(rmsnorm_kernel, grid_spec=grid_spec,
        out_shape=jax.ShapeDtypeStruct((BATCH, DIM), jnp.bfloat16))(x, weight)

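The scratch_shapes argument allocates a temporary VMEM buffer that holds the per-row mean-square values; this buffer never touches HBM, which reduces memory traffic. PrefetchScalarGridSpec is used here with num_scalar_prefetch=0, so it serves only to carry the grid, the block specs, and the scratch shapes; no scalars are actually prefetched.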
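For reference, the arithmetic the kernel performs on each row is y = x / sqrt(mean(x^2) + eps) * w. A minimal plain-Python version for a single row, useful as a mental check of the kernel's output, might look like this (`rmsnorm_row` is a hypothetical name, and the default `eps` is an assumed value):

```python
import math


def rmsnorm_row(x, weight, eps=1e-6):
    """RMSNorm of one row: divide by the root-mean-square, then scale by weight."""
    mean_sq = sum(v * v for v in x) / len(x)   # the value the kernel parks in scratch
    rms = math.sqrt(mean_sq + eps)             # eps sits inside the square root
    return [v / rms * w for v, w in zip(x, weight)]


# With unit weights and eps=0, the output row has a mean square of exactly 1.
row = rmsnorm_row([3.0, 4.0], [1.0, 1.0], eps=0.0)
out_mean_sq = sum(v * v for v in row) / len(row)
```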

Level 1 – FlashAttention Kernel

def flash_kernel(q_ref, k_ref, v_ref, o_ref, m_ref, l_ref, acc_ref, *, num_kv_blocks, sm_scale):
    kv_idx = pl.program_id(axis=1)
    @pl.when(kv_idx == 0)
    def init():
        m_ref[...] = jnp.full(m_ref.shape, -jnp.inf, jnp.float32)
        l_ref[...] = jnp.zeros(l_ref.shape, jnp.float32)
        acc_ref[...] = jnp.zeros(acc_ref.shape, jnp.float32)
    q = q_ref[...].astype(jnp.float32)
    k = k_ref[...].astype(jnp.float32)
    v = v_ref[...]
    s = jax.lax.dot_general(q, k, (((1,), (1,)), ((), ()))) * sm_scale  # q @ k.T, scaled
    m_prev = m_ref[...]
    m_curr = jnp.max(s, axis=1)[:, None]
    m_next = jnp.maximum(m_prev, m_curr)
    alpha = jnp.exp(m_prev - m_next)
    p = jnp.exp(s - m_next)
    l_ref[...] = alpha * l_ref[...] + jnp.sum(p, axis=1)[:, None]
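    # Request a float32 result so the product is not rounded to v's dtype before accumulation.
    acc_ref[...] = alpha * acc_ref[...] + jax.lax.dot(
        p.astype(v.dtype), v, preferred_element_type=jnp.float32)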
    m_ref[...] = m_next
    @pl.when(kv_idx == num_kv_blocks - 1)
    def store():
        o_ref[...] = (acc_ref[...] / l_ref[...]).astype(o_ref.dtype)

The kernel uses a two‑dimensional grid (query blocks × key/value blocks). It maintains three VMEM scratch buffers: m (running max), l (running softmax denominator), and acc (running weighted sum). The pl.when primitive performs conditional initialization on the first KV block and final normalization on the last block, implementing an online softmax that avoids materializing the full attention matrix.
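The online-softmax recurrence is easier to verify in scalar form. The sketch below handles a single query row with scalar values, processing the scores one block at a time with the same m, l, and acc updates as the kernel, and compares against a direct softmax. The function names are hypothetical; this models the arithmetic only, not the Pallas memory behavior.

```python
import math


def online_softmax_weighted_sum(scores, values, block_size):
    """Compute sum_j softmax(scores)[j] * values[j] one block at a time."""
    m = -math.inf   # running max
    l = 0.0         # running softmax denominator
    acc = 0.0       # running weighted sum (unnormalized)
    for start in range(0, len(scores), block_size):
        s_blk = scores[start:start + block_size]
        v_blk = values[start:start + block_size]
        m_next = max(m, max(s_blk))
        alpha = math.exp(m - m_next)          # rescales old state; 0.0 on the first block
        p = [math.exp(s - m_next) for s in s_blk]
        l = alpha * l + sum(p)
        acc = alpha * acc + sum(pj * vj for pj, vj in zip(p, v_blk))
        m = m_next
    return acc / l                            # final normalization, as in store()


def direct_softmax_weighted_sum(scores, values):
    """Reference that materializes all softmax weights at once."""
    m = max(scores)
    p = [math.exp(s - m) for s in scores]
    return sum(pj * vj for pj, vj in zip(p, values)) / sum(p)


scores = [0.5, 2.0, -1.0, 3.0, 0.0, 1.5]
values = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
blocked = online_softmax_weighted_sum(scores, values, block_size=2)
direct = direct_softmax_weighted_sum(scores, values)
```

Whenever a later block raises the running max, alpha shrinks the previously accumulated l and acc by the same factor, so the final ratio acc / l matches the direct computation up to rounding.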

Conclusion

TPU programming requires explicit data movement across HBM, VMEM, and registers. Pallas abstracts this with Grid, BlockSpec, and Ref, letting developers write kernels that focus on the mathematical logic while the compiler inserts the necessary DMA operations. The four kernels demonstrated (element-wise add, tiled dot product, fused RMSNorm, and FlashAttention) show how increasingly complex patterns, namely accumulation across grid steps, VMEM scratch buffers, and online softmax, are expressed using Pallas primitives.

