DecodeBatch Load Imbalance in LLM Inference: Request Length Differences Amplify

During LLM decoding, the DecodeBatch stage can suffer severe load imbalance: differing historical token lengths (kv_len) spread attention work unevenly across GPU SMs. This article analyses the problem through task granularity, SplitKV heuristics, FlashInfer’s batch‑size thresholds, and FA3’s dynamic scheduling and split strategy.

Machine Learning Algorithms & Natural Language Processing

Decode Batch Load Imbalance

In a typical GQA architecture, the decode stage processes one new token per request (q_len=1). Kernel execution time is roughly proportional to the total number of tokens in the batch only when the KV sequence lengths (kv_len) of all requests are similar. If kv_len varies widely, SMs that received short requests finish early and sit idle while the SMs handling long requests keep running, so latency follows the longest request rather than the token total.

Attention Task Partitioning

M/N Blocking

FlashAttention‑style kernels compute S = QKᵀ block‑wise. For a single request and head the tensors are

Q_{b,h}: [1, D]
K_{b,h}: [kv_len, D]
V_{b,h}: [kv_len, D]
S_{b,h} = Q_{b,h} K_{b,h}ᵀ   // [1, kv_len]
BLOCK_M is usually 128; BLOCK_N is 128 or 176. Because q_len=1, num_m_blocks = ceil(1 / BLOCK_M) = 1, so variation in work comes only from the N direction (kv_len).
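To make the tile counts concrete, here is a small host‑side C++ sketch that computes the M and N block counts for one (request, head) pair. The helper names `ceil_div`, `TileCounts` and `decode_tiles` are illustrative, not taken from any library.

```cpp
#include <cassert>
#include <cstdint>

// Integer ceiling division used for every tile count.
constexpr int64_t ceil_div(int64_t a, int64_t b) { return (a + b - 1) / b; }

// Tile counts for one (request, head) pair.
struct TileCounts {
    int64_t m_blocks;  // tiles along the query (M) direction
    int64_t n_blocks;  // tiles along the KV (N) direction
};

TileCounts decode_tiles(int64_t q_len, int64_t kv_len, int64_t block_m, int64_t block_n) {
    assert(q_len > 0 && kv_len > 0 && block_m > 0 && block_n > 0);
    return {ceil_div(q_len, block_m), ceil_div(kv_len, block_n)};
}
```

For example, `decode_tiles(1, 4096, 128, 128)` yields 1 M block and 32 N blocks, and `decode_tiles(1, 300, 128, 176)` yields 1 and 2: only kv_len changes the amount of work.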

Figure: M/N blocking diagram

Basic Attention Task

A basic task is identified by (m_block, request, kv_head). In decode the M direction has only one effective block, so the task iterates over all KV tiles of a request. Longer kv_len ⇒ more KV tiles ⇒ longer task.
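The effect of task length on the kernel tail can be illustrated with a deliberately simplified model: each task costs its number of KV tiles, each SM runs one task at a time, and tasks are handed out in launch order to whichever SM frees up first. Memory contention, multiple resident CTAs per SM and the real hardware block scheduler are ignored, and `simulated_makespan` is a hypothetical helper.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <functional>
#include <queue>
#include <vector>

// Simplified kernel-time model: tasks (cost = number of KV tiles) are taken in
// launch order, and each goes to the SM that becomes free earliest.
int64_t simulated_makespan(const std::vector<int64_t>& task_costs, int num_sm) {
    assert(num_sm > 0);
    // Min-heap holding the time at which each SM becomes free.
    std::priority_queue<int64_t, std::vector<int64_t>, std::greater<int64_t>> free_at;
    for (int i = 0; i < num_sm; ++i) free_at.push(0);
    int64_t makespan = 0;
    for (int64_t cost : task_costs) {
        const int64_t start = free_at.top();
        free_at.pop();
        const int64_t finish = start + cost;
        makespan = std::max(makespan, finish);
        free_at.push(finish);
    }
    return makespan;
}
```

With 4 SMs and 128 KV tiles in total, eight equal tasks of 16 tiles finish at time 32, whereas one 72‑tile task plus seven 8‑tile tasks finishes at time 72: the same token count, more than twice the latency.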

PackGQA

In GQA the number of query heads is H_q = G·H_kv, so each KV head is shared by G Q heads. Organising tasks by (m_block, request, kv_head) packs those G heads into the M dimension of a single task, so the group reads the KV cache once instead of G times, reducing memory bandwidth. Because G is usually smaller than BLOCK_M, the packed heads still fit in one M block, and M‑direction parallelism remains one block per (request, kv_head).
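A sketch of the head grouping, assuming the common convention that Q head `q` uses KV head `q / G`. It also computes how much of a BLOCK_M tile the packed heads fill in decode. Both function names are hypothetical.

```cpp
#include <cassert>
#include <vector>

// Q heads served by one KV head, assuming the consecutive grouping
// q_head / G == kv_head with G = H_q / H_kv.
std::vector<int> q_heads_for_kv_head(int kv_head, int num_q_heads, int num_kv_heads) {
    assert(num_kv_heads > 0 && num_q_heads % num_kv_heads == 0);
    const int group = num_q_heads / num_kv_heads;  // G
    std::vector<int> heads;
    for (int g = 0; g < group; ++g) heads.push_back(kv_head * group + g);
    return heads;
}

// Fraction of the BLOCK_M rows that carry real work when the G Q heads of one
// KV head are packed into M (q_len rows per head).
double packed_m_fill(int group, int q_len, int block_m) {
    assert(group > 0 && q_len > 0 && block_m > 0);
    const int rows = group * q_len;
    const int m_blocks = (rows + block_m - 1) / block_m;
    return static_cast<double>(rows) / (m_blocks * block_m);
}
```

With H_q = 32 and H_kv = 8, KV head 1 serves Q heads 4 to 7; with G = 8 and BLOCK_M = 128, only 8 of 128 rows (6.25 %) carry work.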

SplitKV: Adding Parallelism Along N

Without SplitKV the number of basic tasks is approximately

num_tasks ≈ batch_size × num_kv_heads_local

When batch size is small or tensor‑parallelism reduces the local KV‑head count, this can be far below the number of SMs, leaving many SMs idle.
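The estimate above in code, as a minimal sketch that assumes KV heads divide evenly across tensor‑parallel ranks and counts one resident task per SM. Both function names are hypothetical.

```cpp
#include <cassert>
#include <cstdint>

// Basic decode tasks without SplitKV: one per (request, local KV head).
// Assumes the KV heads are divided evenly across tensor-parallel ranks.
int64_t num_tasks_no_split(int64_t batch_size, int64_t num_kv_heads, int64_t tp_size) {
    assert(tp_size > 0 && num_kv_heads % tp_size == 0);
    return batch_size * (num_kv_heads / tp_size);
}

// SMs that receive no task at all, counting one resident task per SM.
int64_t idle_sms(int64_t num_tasks, int64_t num_sm) {
    return num_tasks >= num_sm ? 0 : num_sm - num_tasks;
}
```

With 8 KV heads, TP = 8 and batch size 16 there are only 16 tasks, so a hypothetical 132‑SM GPU would leave 116 SMs without work.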

SplitKV divides a long KV sequence into several contiguous intervals, each handled by a separate CTA. The task count becomes

num_tasks = Σ_b (num_m_blocks[b] × num_kv_heads_local × num_splits[b])
≈ num_kv_heads_local × Σ_b num_splits[b]
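The task‑count formula can be evaluated directly. This hypothetical helper takes per‑request M‑block and split counts.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Total CTA-level tasks with SplitKV:
//   sum_b num_m_blocks[b] * num_kv_heads_local * num_splits[b]
int64_t num_tasks_split(const std::vector<int64_t>& num_m_blocks,
                        const std::vector<int64_t>& num_splits,
                        int64_t num_kv_heads_local) {
    assert(num_m_blocks.size() == num_splits.size());
    int64_t total = 0;
    for (std::size_t b = 0; b < num_splits.size(); ++b)
        total += num_m_blocks[b] * num_kv_heads_local * num_splits[b];
    return total;
}
```

In decode every num_m_blocks[b] is 1, so splitting one long request into 8 pieces in a batch of 4 raises the count from 4 to 11 per local KV head.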

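SplitKV (1) increases the number of CTA‑level tasks to better occupy the SMs, and (2) breaks long requests into smaller pieces so that a few heavy tasks cannot dominate the kernel tail. The price is a merge step: each split produces a partial output together with its log‑sum‑exp, and these partial results must be combined, which adds an extra kernel or synchronization overhead.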
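The merge step relies on each split returning its log‑sum‑exp (LSE) alongside its partial output. The sketch below shows the standard LSE‑based combination used by FlashAttention‑style split‑KV kernels for one query row, in plain host C++ with double precision. `attend`, `merge_splits` and `PartialResult` are illustrative names, not the API of FlashInfer or FA3.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

using Vec = std::vector<double>;

// Partial result of one KV split for a single query row.
struct PartialResult {
    Vec out;     // softmax-weighted sum of V rows, normalised within this split
    double lse;  // log-sum-exp of this split's attention scores
};

// Reference attention over one KV range: scores[j] = scaled q.k_j, values[j] = v_j.
PartialResult attend(const Vec& scores, const std::vector<Vec>& values) {
    assert(!scores.empty() && scores.size() == values.size());
    const double m = *std::max_element(scores.begin(), scores.end());
    double sum = 0.0;
    Vec out(values[0].size(), 0.0);
    for (std::size_t j = 0; j < scores.size(); ++j) {
        const double p = std::exp(scores[j] - m);  // subtract max for stability
        sum += p;
        for (std::size_t d = 0; d < out.size(); ++d) out[d] += p * values[j][d];
    }
    for (double& x : out) x /= sum;
    return {out, m + std::log(sum)};
}

// Combine step: weight each split by exp(lse_i - lse_total) and add.
PartialResult merge_splits(const std::vector<PartialResult>& parts) {
    assert(!parts.empty());
    double m = parts[0].lse;
    for (const auto& p : parts) m = std::max(m, p.lse);
    double sum = 0.0;
    for (const auto& p : parts) sum += std::exp(p.lse - m);
    const double lse = m + std::log(sum);
    Vec out(parts[0].out.size(), 0.0);
    for (const auto& p : parts) {
        const double w = std::exp(p.lse - lse);  // this split's share of the softmax mass
        for (std::size_t d = 0; d < out.size(); ++d) out[d] += w * p.out[d];
    }
    return {out, lse};
}
```

Merging `attend` over two halves of a KV range reproduces `attend` over the whole range up to rounding error, so SplitKV changes the schedule but not the result.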

FlashInfer Task Mapping

CUDA‑core Path

Logical tasks are (request, kv_head) and are launched with a 2‑D grid:

dim3 nblks(padded_batch_size, num_kv_heads);

Each CTA processes the Q heads belonging to the KV head.

Tensor‑core Path

When use_tensor_core=True the same logical tasks are launched with a 3‑D grid:

dim3 nblks(padded_batch_size, 1, num_kv_heads);

The CTA packs the Q heads into the M dimension and uses MMA on Tensor Cores.
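Both launch shapes cover the same logical (request, kv_head) tasks; only the grid axis that carries the KV head differs. A host‑side sketch that enumerates the CTAs of each grid, using a stand‑in `Dim3` struct instead of CUDA's `dim3` and ignoring batch padding. `task_of` and `tasks_of_grid` are hypothetical names, not FlashInfer code.

```cpp
#include <cassert>
#include <set>
#include <utility>

// Host-side stand-in for CUDA's dim3 (blockIdx / gridDim).
struct Dim3 {
    unsigned x, y, z;
};

using Task = std::pair<unsigned, unsigned>;  // (request, kv_head)

// Logical task of one CTA: request on x, KV head on y (2-D grid) or z (3-D grid).
Task task_of(Dim3 block_idx, bool kv_head_in_z) {
    return {block_idx.x, kv_head_in_z ? block_idx.z : block_idx.y};
}

// Visit every CTA of a grid and collect the distinct logical tasks.
std::set<Task> tasks_of_grid(Dim3 grid, bool kv_head_in_z) {
    std::set<Task> tasks;
    for (unsigned z = 0; z < grid.z; ++z)
        for (unsigned y = 0; y < grid.y; ++y)
            for (unsigned x = 0; x < grid.x; ++x)
                tasks.insert(task_of(Dim3{x, y, z}, kv_head_in_z));
    return tasks;
}
```

For 4 requests and 8 KV heads, the (4, 8) and (4, 1, 8) grids both produce the same 32 tasks.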

Effect of SplitKV

Enabling SplitKV adds a split_idx dimension, turning a task into (request, kv_head, split_idx). The grid becomes

grid = (num_work_items, H_kv_local)
where num_work_items = Σ_b num_splits[b]

Thus the number of CTA tasks grows with the number of splits.
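One way to build the work‑item list is to cut every request into fixed‑size KV chunks. The chunk size here is simply a parameter (an assumption for illustration; FlashInfer chooses its own split sizes), and `WorkItem` and `make_work_items` are hypothetical names.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// One SplitKV work item: the KV interval [kv_begin, kv_end) of one request.
struct WorkItem {
    int64_t request;
    int64_t split_idx;
    int64_t kv_begin;
    int64_t kv_end;
};

// Cut each request into chunks of at most chunk_size tokens. The number of items
// equals sum_b num_splits[b], the x dimension of the SplitKV grid.
std::vector<WorkItem> make_work_items(const std::vector<int64_t>& kv_len, int64_t chunk_size) {
    assert(chunk_size > 0);
    std::vector<WorkItem> items;
    for (int64_t b = 0; b < static_cast<int64_t>(kv_len.size()); ++b) {
        int64_t split = 0;
        for (int64_t begin = 0; begin < kv_len[b]; begin += chunk_size, ++split) {
            items.push_back({b, split, begin, std::min(begin + chunk_size, kv_len[b])});
        }
    }
    return items;
}
```

With a chunk size of 1024, kv_len = {4096, 100, 300} yields 4 + 1 + 1 = 6 work items, so the grid would be (6, H_kv_local).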

Figure: FlashInfer CUDA and Tensor core grid

Threshold Phenomenon in FlashInfer

Experiments keep the total number of tokens constant while varying the kv_len distribution (Uniform, Skewed 40 %, Skewed 60 %, Skewed 80 %, 2×40 %). Once the batch size exceeds a certain threshold, kernel latency jumps, most visibly for the skewed distributions. The threshold is higher for the CUDA‑core path than for the Tensor‑core path.

FlashInfer decides whether to enable SplitKV with a heuristic. For the CUDA‑core path it queries the maximum active blocks per SM:

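// Ask the CUDA runtime how many CTAs of this kernel can be resident on one SM.
int num_blocks_per_sm = 0;
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm, kernel, num_threads, smem);
// One full wave of CTAs across all SMs.
int max_grid_size = num_blocks_per_sm * num_sm;
// If the unsplit grid (batch_size x gdy CTAs) already fills a wave, do not split.
if (batch_size * gdy >= max_grid_size) {
    split_kv = false;
}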

Here gdy is the number of KV heads on the GPU (the y dimension of the CUDA‑core grid). The Tensor‑core path instead targets a fixed number of CTAs per SM (≈ 2) and disables SplitKV when the estimated task count exceeds this target. In both paths, once the batch size passes the heuristic bound, long requests are no longer split and dominate the kernel tail.
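The CUDA‑core rule can be restated as a pure function of the occupancy result, which makes the batch‑size threshold explicit. This is a minimal sketch: the occupancy value is passed in rather than queried, and `split_kv_allowed` and `split_off_threshold` are hypothetical names. A `true` result only means this particular check does not turn SplitKV off.

```cpp
#include <cassert>
#include <cstdint>

// The CUDA-core check, with the occupancy result passed in: returns false when the
// unsplit grid (batch_size x gdy CTAs) already fills one full wave of CTAs.
bool split_kv_allowed(int64_t batch_size, int64_t gdy, int64_t num_blocks_per_sm, int64_t num_sm) {
    const int64_t max_grid_size = num_blocks_per_sm * num_sm;
    return batch_size * gdy < max_grid_size;
}

// Smallest batch size at which the check turns SplitKV off: ceil(max_grid_size / gdy).
int64_t split_off_threshold(int64_t gdy, int64_t num_blocks_per_sm, int64_t num_sm) {
    assert(gdy > 0);
    const int64_t max_grid_size = num_blocks_per_sm * num_sm;
    return (max_grid_size + gdy - 1) / gdy;
}
```

With 100 SMs, 2 resident CTAs per SM and gdy = 2, SplitKV is switched off from batch size 100 onward; at 1 CTA per SM the threshold drops to 50. A smaller per‑SM target therefore lowers the threshold, which would be consistent with the lower Tensor‑core threshold if its target of about 2 CTAs per SM is below the CUDA‑core occupancy.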

Figure: Threshold experiment results

FA3: Persistent Kernel Dynamic Scheduling

Work Tile Definition

FA3 treats a variable‑length batch as a varlen workload. A work tile corresponds to (m_block, virtual_batch, kv_head[, split_idx]). Persistent CTAs repeatedly fetch tiles from a global queue until all work is done.
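The claiming pattern of a persistent kernel can be modelled on the host, with threads standing in for persistent CTAs and a `std::atomic` counter standing in for the global tile counter. This illustrates the scheduling idea only and is not FA3's kernel; `run_persistent` is a hypothetical name.

```cpp
#include <atomic>
#include <cassert>
#include <thread>
#include <vector>

// Each worker starts on the tile equal to its own index, then keeps claiming the
// next unassigned tile from a shared counter until none remain.
std::vector<int> run_persistent(int num_tiles, int num_workers) {
    assert(num_tiles >= 0 && num_workers > 0);
    std::vector<int> processed_by(num_tiles, -1);  // worker that handled each tile
    std::atomic<int> next_tile{num_workers};        // first tile not assigned at launch
    std::vector<std::thread> workers;
    for (int w = 0; w < num_workers; ++w) {
        workers.emplace_back([&processed_by, &next_tile, num_tiles, w] {
            for (int tile = w; tile < num_tiles; tile = next_tile.fetch_add(1)) {
                processed_by[tile] = w;  // distinct tiles -> distinct elements, no data race
            }
        });
    }
    for (std::thread& t : workers) t.join();
    return processed_by;
}
```

Every tile index is returned by `fetch_add` at most once, so each tile is processed by exactly one worker regardless of how long individual tiles take; a worker that draws short tiles simply claims more of them.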

Prepare Kernel

Compute the workload of each request.

Decide per‑request whether KV‑direction splitting is needed.

Reorder requests so that heavier tasks are scheduled earlier (virtual_batch → real_batch mapping).

After preparation, the scheduler dispatches tiles in order; each persistent CTA processes an initial tile based on its block index and then claims the next tile via a global counter.
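Steps 1 and 3 can be sketched on the host as follows. Measuring workload as the number of KV tiles and sorting purely by it are assumptions for illustration (the exact key FA3 uses may differ), and the function names are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <numeric>
#include <vector>

// Step 1: workload of each request, measured in KV tiles (num_n_blocks).
std::vector<int64_t> workloads(const std::vector<int64_t>& kv_len, int64_t block_n) {
    assert(block_n > 0);
    std::vector<int64_t> w;
    for (int64_t len : kv_len) w.push_back((len + block_n - 1) / block_n);
    return w;
}

// Step 3: virtual_batch -> real_batch mapping, heaviest request first.
// stable_sort keeps the original order among equally heavy requests.
std::vector<int> heavy_first_order(const std::vector<int64_t>& work) {
    std::vector<int> virtual_to_real(work.size());
    std::iota(virtual_to_real.begin(), virtual_to_real.end(), 0);
    std::stable_sort(virtual_to_real.begin(), virtual_to_real.end(),
                     [&](int a, int b) { return work[a] > work[b]; });
    return virtual_to_real;
}
```

Dispatching the heaviest work first follows the classic longest‑processing‑time rule: long tiles start early and short ones fill the gaps near the end, instead of a long tile starting last and stretching the tail.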

Static Split Upper Bound

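FA3 first computes a static upper bound num_splits_static from the estimated M‑block count, the size of one KV head, and an assumed L2 cache size of 50 MiB. If the estimated M‑blocks already reach 0.8 × num_SM, the bound is 1 unless one KV head is larger than 50 MiB, in which case it may exceed 1. With fewer M‑blocks, the heuristic can return a bound greater than 1, which leaves room for per‑request splitting.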

Per‑request Dynamic Split

For each request the number of N‑blocks is

num_n_blocks[b] = ceil(kv_len[b] / BLOCK_N)

The total blocks and target blocks per SM are

total_blocks = Σ_b num_n_blocks[b]
blocks_per_sm = ceil(total_blocks × 1.1 × H_schedule / num_sm)

The dynamic split count is

num_splits_dynamic[b] = clamp(ceil(num_n_blocks[b] / blocks_per_sm), 1, num_splits_static)

For example, with blocks_per_sm = 16, a request with num_n_blocks = 8 is not split (ceil(8/16) = 1), while a request with num_n_blocks = 80 is split into ceil(80/16) = 5 pieces, provided num_splits_static ≥ 5. The prepare kernel can handle at most 992 requests; larger batches fall back to num_splits_dynamic = 1.
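The formulas above translate directly into code. In this sketch H_schedule, num_sm and num_splits_static are plain inputs, the 1.1 factor is applied in double precision (which may round differently from the GPU kernel at exact boundaries), and the function names are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

constexpr int64_t kMaxPrepareBatch = 992;  // larger batches fall back to one split each

// Per-request split count: clamp(ceil(num_n_blocks / blocks_per_sm), 1, num_splits_static).
int64_t dynamic_splits(int64_t num_n_blocks, int64_t blocks_per_sm, int64_t num_splits_static) {
    assert(blocks_per_sm > 0 && num_splits_static >= 1);
    const int64_t s = (num_n_blocks + blocks_per_sm - 1) / blocks_per_sm;
    return std::clamp<int64_t>(s, 1, num_splits_static);
}

// Whole-batch version following the total_blocks / blocks_per_sm formulas.
std::vector<int64_t> dynamic_splits_batch(const std::vector<int64_t>& kv_len, int64_t block_n,
                                          int64_t h_schedule, int64_t num_sm,
                                          int64_t num_splits_static) {
    assert(block_n > 0 && num_sm > 0);
    std::vector<int64_t> splits(kv_len.size(), 1);
    if (static_cast<int64_t>(kv_len.size()) > kMaxPrepareBatch) return splits;
    std::vector<int64_t> n_blocks;
    int64_t total_blocks = 0;
    for (int64_t len : kv_len) {
        n_blocks.push_back((len + block_n - 1) / block_n);
        total_blocks += n_blocks.back();
    }
    int64_t blocks_per_sm = static_cast<int64_t>(
        std::ceil(static_cast<double>(total_blocks) * 1.1 * h_schedule / num_sm));
    blocks_per_sm = std::max<int64_t>(blocks_per_sm, 1);  // guard for an empty batch
    for (std::size_t b = 0; b < kv_len.size(); ++b)
        splits[b] = dynamic_splits(n_blocks[b], blocks_per_sm, num_splits_static);
    return splits;
}
```

`dynamic_splits(8, 16, 8)` returns 1 and `dynamic_splits(80, 16, 8)` returns 5, matching the example; with num_splits_static = 4 the second request would be capped at 4.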

Effect

FA3 keeps SplitKV active for long requests even when the batch size grows, unlike FlashInfer’s static heuristic. The combination of per‑request dynamic splitting and persistent scheduling reduces tail latency and improves SM utilisation.

Figure: FA3 load‑balancing flowchart

Experimental Comparison

Two FA3 configurations were tested on an H20 GPU:

splits=1: SplitKV forced off (equivalent to a single split).

splits=0: FA3 decides dynamically whether to enable SplitKV.

With splits=1, kernel latency grows sharply with the longest sequence length, showing that a persistent kernel alone cannot eliminate the tail caused by a single heavy request. With splits=0, latency remains stable for mixed‑length batches, demonstrating that FA3’s dynamic SplitKV effectively breaks long requests into more work tiles. For uniformly short batches, occasional latency fluctuations appear because splitting changes the number of CTA waves (wave quantization).
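Wave quantization can be seen in an idealised count of CTA waves. The sketch assumes every request has the same number of KV tiles, every piece costs ceil(n_tiles / splits) tiles, and `slots` CTAs run concurrently; merge cost is ignored and the helper names are hypothetical.

```cpp
#include <cassert>
#include <cstdint>

// Waves needed to run num_tasks equal CTAs when `slots` CTAs fit on the GPU at once
// (slots = num_sm x resident CTAs per SM).
int64_t num_waves(int64_t num_tasks, int64_t slots) {
    assert(slots > 0);
    return (num_tasks + slots - 1) / slots;
}

// Idealised kernel time, in KV-tile units, for `batch` requests of n_tiles tiles each,
// every request cut into `splits` pieces of ceil(n_tiles / splits) tiles.
int64_t uniform_batch_time(int64_t batch, int64_t n_tiles, int64_t splits, int64_t slots) {
    assert(splits >= 1);
    const int64_t piece = (n_tiles + splits - 1) / splits;
    return num_waves(batch * splits, slots) * piece;
}
```

For 80 requests of 8 tiles on 100 slots, no split and a 2‑way split both take 8 tile‑times, but a 3‑way split needs 3 waves of 3‑tile pieces, i.e. 9: splitting a short, uniform batch can make it slower.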

Summary

Decode‑batch load imbalance originates from amplified differences in request lengths during task partitioning. Because q_len=1, parallelism in the M direction is limited; the main source of work variation is the N direction (kv_len). When long KV requests are not split, they become heavy CTAs that extend the kernel tail.

FlashInfer’s observed batch‑size thresholds stem from its SplitKV heuristic: once the estimated task count exceeds a heuristic bound, SplitKV is disabled, which is acceptable for uniform batches but harmful for mixed‑length batches.

FA3 mitigates the problem by retaining per‑request dynamic SplitKV in the varlen path, generating more balanced work tiles, and using a persistent kernel that dynamically schedules tiles. Consequently, the longest request no longer dominates kernel time and overall SM load becomes more uniform.

Republication Notice

This article has been distilled and summarized from source material, then republished for learning and reference. If you believe it infringes your rights, please contact admin@besthub.dev and we will review it promptly.

Written by Machine Learning Algorithms & Natural Language Processing