DecodeBatch Load Imbalance in LLM Inference: How Request-Length Differences Get Amplified
During LLM decoding, the DecodeBatch stage can suffer severe load imbalance: requests with different historical token lengths (kv_len) produce attention tasks of very different sizes, which spread unevenly across GPU SMs. This article examines the problem through task granularity, SplitKV heuristics, FlashInfer’s batch-size thresholds, and FA3’s dynamic scheduling and split strategies.
Decode Batch Load Imbalance
In a typical GQA architecture the decode stage processes one token per request (q_len = 1). Kernel execution time is roughly proportional to the total number of KV tokens (Σ kv_len) only when the KV sequence lengths (kv_len) of all requests are similar. If kv_len varies widely, some SMs finish early while others are still busy, and the linear relationship breaks down.
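A toy bound makes the broken linearity concrete. The sketch below assumes (an illustration, not a measurement) that each request is one indivisible task whose cost equals its kv_len and that num_sm SMs run tasks in parallel; the function name is hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Toy model (assumption): one indivisible task per request, cost = kv_len,
// num_sm SMs working in parallel. Returns a lower bound on kernel time in
// "KV tokens per SM" units.
long long makespan_lower_bound(const std::vector<long long>& kv_len, long long num_sm) {
    long long total = 0;
    long long longest = 0;
    for (long long n : kv_len) {
        total += n;
        longest = std::max(longest, n);
    }
    // With perfect spreading every SM would process ceil(total / num_sm) tokens.
    const long long balanced = (total + num_sm - 1) / num_sm;
    // A task cannot be split, so the longest request is also a hard floor.
    return std::max(balanced, longest);
}
```

Eight requests of 1000 tokens on 4 SMs give a bound of 2000, while {6400, 400, 400, 400, 400} (also 8000 tokens in total) gives 6400: the time is set by the heaviest task, not by the token count.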
Attention Task Partitioning
M/N Blocking
FlashAttention‑style kernels compute S = QKᵀ block‑wise. For a single request and head the tensors are
Q_{b,h}: [1, D]
K_{b,h}: [kv_len, D]
V_{b,h}: [kv_len, D]
S_{b,h} = Q_{b,h} K_{b,h}ᵀ   // [1, kv_len]
BLOCK_M is usually 128; BLOCK_N is 128 or 176. Because q_len = 1, num_m_blocks = ceil(1 / BLOCK_M) = 1, so the variation in work comes from the N direction (kv_len).
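The tile counts for one (request, head) pair follow directly from the block sizes. A minimal C++ sketch; the helper names are illustrative and not taken from any library:

```cpp
#include <cassert>
#include <cstdint>

// Integer ceiling division, as used for tile counts.
constexpr std::int64_t ceil_div(std::int64_t a, std::int64_t b) { return (a + b - 1) / b; }

// Tiles for one (request, head) pair: q_len rows along M, kv_len columns along N.
struct TileCounts {
    std::int64_t m_blocks;
    std::int64_t n_blocks;
};

constexpr TileCounts tile_counts(std::int64_t q_len, std::int64_t kv_len,
                                 std::int64_t block_m, std::int64_t block_n) {
    return {ceil_div(q_len, block_m), ceil_div(kv_len, block_n)};
}
```

With q_len = 1 the M count stays at 1 for any BLOCK_M, while the N count, and therefore the task length, grows linearly with kv_len: 4096 tokens give 32 tiles at BLOCK_N = 128 and 24 tiles at BLOCK_N = 176.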
Basic Attention Task
A basic task is identified by (m_block, request, kv_head). In decode the M direction has only one effective block, so the task iterates over all KV tiles of a request. Longer kv_len ⇒ more KV tiles ⇒ longer task.
PackGQA
In GQA the number of query heads is H_q = G·H_kv, so each KV head is shared by G query heads. Organising tasks by (m_block, request, kv_head) and packing those G heads into the M dimension lets them read the KV cache once, reducing memory traffic. Because G is usually much smaller than BLOCK_M, the packed heads still fit in a single M block, so M-direction parallelism remains limited.
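The PackGQA task layout can be sketched as follows, assuming the common contiguous grouping in which KV head h serves query heads h*G through h*G + G - 1. The struct and function names are hypothetical.

```cpp
#include <cassert>
#include <vector>

// One decode task under PackGQA (illustrative layout, not a library type):
// the G query heads that share kv_head become rows of a single M tile.
struct PackedTask {
    int request;
    int kv_head;
    std::vector<int> q_heads;  // one M-tile row per query head
};

// Enumerates (request, kv_head) tasks, assuming contiguous grouping:
// KV head h serves query heads h*G .. h*G + G - 1.
std::vector<PackedTask> make_packed_tasks(int batch_size, int num_q_heads, int num_kv_heads) {
    const int group = num_q_heads / num_kv_heads;  // G = H_q / H_kv
    std::vector<PackedTask> tasks;
    for (int b = 0; b < batch_size; ++b) {
        for (int h = 0; h < num_kv_heads; ++h) {
            PackedTask t{b, h, {}};
            for (int g = 0; g < group; ++g) t.q_heads.push_back(h * group + g);
            tasks.push_back(t);
        }
    }
    return tasks;
}

// Fraction of an M tile's BLOCK_M rows filled by real query rows when q_len = 1.
double m_tile_utilisation(int group, int block_m) {
    return static_cast<double>(group) / block_m;
}
```

With H_q = 32 and H_kv = 8 (G = 4), two requests yield 16 tasks, i.e. batch_size × num_kv_heads, and each task fills only 4 of 128 rows in its M tile (about 3%).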
SplitKV: Adding Parallelism Along N
Without SplitKV the number of basic tasks is approximately
num_tasks ≈ batch_size × num_kv_heads_local
When the batch size is small, or tensor parallelism reduces the local KV-head count, this can be far below the number of SMs, leaving many SMs idle.
SplitKV divides a long KV sequence into several contiguous intervals, each handled by a separate CTA. The task count becomes
num_tasks = Σ_b (num_m_blocks[b] × num_kv_heads_local × num_splits[b])
≈ num_kv_heads_local × Σ_b num_splits[b]   (since num_m_blocks[b] = 1 in decode)
SplitKV (1) increases the number of CTA-level tasks so that more SMs are occupied, and (2) breaks long requests into smaller pieces so that a few heavy tasks cannot dominate the kernel tail. The price is an extra merge step that combines the partial results, adding an extra kernel or synchronization overhead.
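The merge step is needed because each split produces a softmax normalised only over its own interval. A minimal sketch of split-and-merge for one query row follows, written as plain CPU C++ rather than a CUDA kernel. It combines partial outputs through their log-sum-exp, the standard way to merge partial softmax results; the function names are hypothetical and this is not FlashInfer or FA3 code.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

using Matrix = std::vector<std::vector<double>>;

// Result of single-query attention over one contiguous KV interval.
struct Partial {
    std::vector<double> out;  // softmax(scores) @ V, normalised within the interval
    double lse;               // log-sum-exp of the interval's scaled scores
};

// Attention of q over rows [begin, end) of K and V: one CTA's share under SplitKV.
Partial attend(const std::vector<double>& q, const Matrix& K, const Matrix& V,
               std::size_t begin, std::size_t end) {
    const std::size_t D = q.size();
    const double scale = 1.0 / std::sqrt(static_cast<double>(D));
    std::vector<double> s(end - begin);
    double m = -std::numeric_limits<double>::infinity();
    for (std::size_t j = begin; j < end; ++j) {
        double dot = 0.0;
        for (std::size_t d = 0; d < D; ++d) dot += q[d] * K[j][d];
        s[j - begin] = dot * scale;
        m = std::max(m, s[j - begin]);
    }
    double denom = 0.0;
    std::vector<double> out(D, 0.0);
    for (std::size_t j = begin; j < end; ++j) {
        const double p = std::exp(s[j - begin] - m);  // shifted by the max for stability
        denom += p;
        for (std::size_t d = 0; d < D; ++d) out[d] += p * V[j][d];
    }
    for (double& x : out) x /= denom;
    return {out, m + std::log(denom)};
}

// Merge step: reweight each split's output by exp(lse_k) and renormalise.
std::vector<double> merge(const std::vector<Partial>& parts) {
    double m = -std::numeric_limits<double>::infinity();
    for (const Partial& p : parts) m = std::max(m, p.lse);
    std::vector<double> out(parts.front().out.size(), 0.0);
    double total = 0.0;
    for (const Partial& p : parts) {
        const double w = std::exp(p.lse - m);
        total += w;
        for (std::size_t d = 0; d < out.size(); ++d) out[d] += w * p.out[d];
    }
    for (double& x : out) x /= total;
    return out;
}

// SplitKV: cut [0, kv_len) into at most num_splits contiguous intervals, then merge.
// Assumes kv_len >= 1 and num_splits >= 1.
std::vector<double> split_kv_attention(const std::vector<double>& q, const Matrix& K,
                                       const Matrix& V, std::size_t num_splits) {
    const std::size_t kv_len = K.size();
    const std::size_t chunk = (kv_len + num_splits - 1) / num_splits;
    std::vector<Partial> parts;
    for (std::size_t b = 0; b < kv_len; b += chunk)
        parts.push_back(attend(q, K, V, b, std::min(b + chunk, kv_len)));
    return merge(parts);
}

// Self-check on deterministic data: largest gap between split and unsplit results.
double max_split_error(std::size_t kv_len, std::size_t D, std::size_t num_splits) {
    std::vector<double> q(D);
    Matrix K(kv_len, std::vector<double>(D)), V(kv_len, std::vector<double>(D));
    for (std::size_t d = 0; d < D; ++d) q[d] = 0.1 * static_cast<double>(d + 1);
    for (std::size_t j = 0; j < kv_len; ++j)
        for (std::size_t d = 0; d < D; ++d) {
            K[j][d] = std::sin(0.3 * j + 0.7 * d);
            V[j][d] = std::cos(0.5 * j - 0.2 * d);
        }
    const std::vector<double> ref = split_kv_attention(q, K, V, 1);
    const std::vector<double> got = split_kv_attention(q, K, V, num_splits);
    double err = 0.0;
    for (std::size_t d = 0; d < D; ++d) err = std::max(err, std::fabs(ref[d] - got[d]));
    return err;
}
```

Because the merge is exact up to floating-point rounding, the split count acts purely as a scheduling knob: it changes how the work is cut into CTAs, not the attention result.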
FlashInfer Task Mapping
CUDA‑core Path
Logical tasks are (request, kv_head) and are launched with a 2‑D grid:
dim3 nblks(padded_batch_size, num_kv_heads);
Each CTA processes all Q heads that share its KV head.
Tensor‑core Path
When use_tensor_core=True the same logical tasks are launched with a 3‑D grid:
dim3 nblks(padded_batch_size, 1, num_kv_heads);
The CTA packs the Q heads into the M dimension and computes with MMA instructions on Tensor Cores.
Effect of SplitKV
Enabling SplitKV adds a split_idx dimension, turning a task into (request, kv_head, split_idx). The grid becomes
grid = (num_work_items, H_kv_local)
where num_work_items = Σ_b num_splits[b].
The number of CTA tasks therefore grows with the number of splits.
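How per-request split counts turn into CTAs can be sketched as below, assuming the first grid dimension enumerates (request, split_idx) pairs in request order. The types and names are hypothetical, and FlashInfer's real planner also decides the split sizes.

```cpp
#include <cassert>
#include <vector>

// One entry of the work-item axis: which request and which of its KV splits.
struct WorkItem {
    int request;
    int split_idx;
};

// Flattens per-request split counts into the work-item axis (illustrative only).
std::vector<WorkItem> make_work_items(const std::vector<int>& num_splits) {
    std::vector<WorkItem> items;
    for (int b = 0; b < static_cast<int>(num_splits.size()); ++b)
        for (int s = 0; s < num_splits[b]; ++s) items.push_back({b, s});
    return items;
}

// CTAs launched for grid = (num_work_items, H_kv_local).
long long num_ctas(const std::vector<int>& num_splits, int num_kv_heads_local) {
    return static_cast<long long>(make_work_items(num_splits).size()) * num_kv_heads_local;
}
```

A batch with split counts {1, 4, 1} and 8 local KV heads launches 6 × 8 = 48 CTAs instead of 3 × 8 = 24, and the long request's work occupies 4 of the 6 work items.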
Threshold Phenomenon in FlashInfer
Experiments keep the total number of tokens constant while varying the kv_len distribution (Uniform, Skewed 40 %, Skewed 60 %, Skewed 80 %, 2×40 %). Once the batch size exceeds a certain threshold, kernel latency jumps, most visibly for the skewed distributions. The CUDA‑core path shows a higher threshold than the Tensor‑core path.
FlashInfer decides whether to enable SplitKV with a heuristic. For the CUDA‑core path it queries the maximum active blocks per SM:
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm, kernel, num_threads, smem);
max_grid_size = num_blocks_per_sm * num_sm;
if (batch_size * gdy >= max_grid_size) {
split_kv = false;
}
where gdy is the number of KV heads on this GPU. The Tensor‑core path instead targets a fixed number of CTAs per SM (≈ 2) and disables SplitKV when the estimated task count exceeds that target. Consequently, once the batch size passes the heuristic bound, long requests are no longer split and dominate the kernel tail.
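The CUDA-core rule can be restated on the host to see where the threshold comes from. In the real code num_blocks_per_sm comes from cudaOccupancyMaxActiveBlocksPerMultiprocessor; here it is a plain parameter so the arithmetic can be checked without a GPU, and the function names are hypothetical.

```cpp
#include <cassert>

// max_grid_size = num_blocks_per_sm * num_sm, as in the snippet above.
long long max_grid_size(int num_blocks_per_sm, int num_sm) {
    return static_cast<long long>(num_blocks_per_sm) * num_sm;
}

// False once batch_size * gdy (one CTA per request and KV head) fills a full wave.
// True only means this check does not veto SplitKV; other logic may still decide.
bool split_kv_allowed(int batch_size, int gdy, int num_blocks_per_sm, int num_sm) {
    return static_cast<long long>(batch_size) * gdy < max_grid_size(num_blocks_per_sm, num_sm);
}

// Smallest batch size at which the check above switches SplitKV off.
long long split_kv_off_batch(int gdy, int num_blocks_per_sm, int num_sm) {
    return (max_grid_size(num_blocks_per_sm, num_sm) + gdy - 1) / gdy;
}
```

With 8 local KV heads and 78 SMs (example numbers), an occupancy of 2 CTAs per SM switches SplitKV off from batch size 20, while 4 CTAs per SM moves the cutoff to 39. A path with a lower per-SM target therefore loses SplitKV at a smaller batch size, which is consistent with the Tensor-core path showing the lower threshold.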
FA3: Persistent Kernel Dynamic Scheduling
Work Tile Definition
FA3 handles a mixed-length decode batch through its varlen (variable-length) path. A work tile corresponds to (m_block, virtual_batch, kv_head[, split_idx]). Persistent CTAs repeatedly fetch tiles from a global queue until all work is done.
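A toy, single-threaded model of persistent CTAs draining a global tile queue makes the scheduling concrete. It is a sketch under simplifying assumptions (known per-tile costs, no launch or atomic overhead, a free CTA always claims the next index), and the function name is hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <functional>
#include <queue>
#include <vector>

// tile_cost[i] is the (hypothetical) duration of tile i. Each CTA starts with the
// tile matching its index; whenever a CTA finishes, it claims the next index from
// the shared counter. Returns the time at which the last CTA finishes.
long long persistent_makespan(const std::vector<long long>& tile_cost, int num_ctas) {
    // Min-heap of the times at which each CTA becomes free.
    std::priority_queue<long long, std::vector<long long>, std::greater<long long>> free_at;
    for (int c = 0; c < num_ctas; ++c) free_at.push(0);
    long long makespan = 0;
    for (long long cost : tile_cost) {  // iteration order plays the global counter
        const long long t = free_at.top();  // earliest-free CTA claims this tile
        free_at.pop();
        const long long done = t + cost;
        makespan = std::max(makespan, done);
        free_at.push(done);
    }
    return makespan;
}
```

The tiles {1, 1, 1, 1, 4} on 2 CTAs finish at t = 6 in arrival order but at t = 4 when the heavy tile is claimed first. Dynamic claiming alone cannot help a single indivisible heavy tile ({8, 1, 1} still ends at 8), whereas cutting it into two halves ({4, 4, 1, 1}) brings the end to 5.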
Prepare Kernel
Compute the workload of each request.
Decide per‑request whether KV‑direction splitting is needed.
Reorder requests so that heavier tasks are scheduled earlier (virtual_batch → real_batch mapping).
After preparation, the scheduler dispatches tiles in order; each persistent CTA processes an initial tile based on its block index and then claims the next tile via a global counter.
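The reordering step can be sketched on the host as follows. This is an assumption-laden illustration: it measures workload as N-blocks only and sorts with std::stable_sort, whereas FA3 performs this inside its prepare kernel on the GPU and may use a different key. The names are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <numeric>
#include <vector>

// Workload per request, measured in N-blocks (KV tiles).
std::vector<long long> n_block_workload(const std::vector<long long>& kv_len, long long block_n) {
    std::vector<long long> w;
    for (long long n : kv_len) w.push_back((n + block_n - 1) / block_n);
    return w;
}

// virtual_batch -> real_batch mapping with heavier requests first.
// stable_sort keeps the original order among requests of equal workload.
std::vector<int> heavy_first_order(const std::vector<long long>& workload) {
    std::vector<int> order(workload.size());
    std::iota(order.begin(), order.end(), 0);
    std::stable_sort(order.begin(), order.end(),
                     [&workload](int a, int b) { return workload[a] > workload[b]; });
    return order;  // order[virtual_batch] == real_batch
}
```

Dispatching tiles in this virtual order starts the longest requests in the first wave, so they overlap with the many short tiles instead of running alone at the end.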
Static Split Upper Bound
FA3 first computes a static upper bound, num_splits_static, from the L2 cache size, the size of one KV head, and the M-block count. If the estimated total M-block count exceeds 0.8 × num_SM and one KV head occupies more than 50 MiB (the L2-size threshold), the static split count may be greater than 1; otherwise it is 1.
Per‑request Dynamic Split
For each request the number of N‑blocks is
num_n_blocks[b] = ceil(kv_len[b] / BLOCK_N)
The total block count and the target number of blocks per SM are
total_blocks = Σ_b num_n_blocks[b]
blocks_per_sm = ceil(total_blocks × 1.1 × H_schedule / num_sm)
The dynamic split count is
num_splits_dynamic[b] = clamp(ceil(num_n_blocks[b] / blocks_per_sm), 1, num_splits_static)
For example, with blocks_per_sm = 16, a request with num_n_blocks = 8 is not split, while a request with num_n_blocks = 80 is split into five pieces (provided num_splits_static ≥ 5). The prepare kernel handles at most 992 requests; larger batches fall back to num_splits_dynamic = 1.
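The per-request formulas can be transcribed into host-side C++ to check the worked example. This sketch follows the formulas as stated above, with H_schedule taken as a plain parameter; the names are hypothetical, and the real FA3 code computes this on the GPU and may round differently.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

constexpr long long ceil_div(long long a, long long b) { return (a + b - 1) / b; }

// Batches larger than this fall back to one split per request (limit quoted in the text).
constexpr std::size_t kMaxPrepareBatch = 992;

// blocks_per_sm = ceil(total_blocks * 1.1 * H_schedule / num_sm)
long long blocks_per_sm(const std::vector<long long>& num_n_blocks, int h_schedule, int num_sm) {
    long long total_blocks = 0;
    for (long long n : num_n_blocks) total_blocks += n;
    return static_cast<long long>(std::ceil(total_blocks * 1.1 * h_schedule / num_sm));
}

// num_splits_dynamic[b] = clamp(ceil(num_n_blocks[b] / blocks_per_sm), 1, num_splits_static)
long long dynamic_splits(long long n_blocks, long long per_sm, long long num_splits_static) {
    return std::min(std::max(ceil_div(n_blocks, per_sm), 1LL), num_splits_static);
}

// Whole-batch plan from kv_len values.
std::vector<long long> plan_splits(const std::vector<long long>& kv_len, long long block_n,
                                   int h_schedule, int num_sm, long long num_splits_static) {
    if (kv_len.size() > kMaxPrepareBatch) return std::vector<long long>(kv_len.size(), 1);
    std::vector<long long> n_blocks;
    for (long long n : kv_len) n_blocks.push_back(ceil_div(n, block_n));
    const long long per_sm = blocks_per_sm(n_blocks, h_schedule, num_sm);
    std::vector<long long> splits;
    for (long long n : n_blocks) splits.push_back(dynamic_splits(n, per_sm, num_splits_static));
    return splits;
}
```

dynamic_splits(80, 16, 8) reproduces the five-way split from the example, a num_splits_static of 4 would cap it at four, and batches of more than 992 requests get all ones.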
Effect
FA3 keeps SplitKV active for long requests even as the batch size grows (up to the 992-request prepare limit), unlike FlashInfer’s batch-level heuristic. Per-request dynamic splitting combined with persistent scheduling reduces tail latency and improves SM utilisation.
Experimental Comparison
Two FA3 configurations were tested on an H20 GPU:
splits=1 : SplitKV forced off (equivalent to a single split).
splits=0 : FA3 decides dynamically whether to enable SplitKV.
With splits=1, kernel latency grows sharply with the longest sequence length, showing that a persistent kernel alone cannot eliminate the tail caused by a single heavy request. With splits=0, latency remains stable for mixed-length batches, showing that FA3’s dynamic SplitKV effectively breaks long requests into more work tiles. For batches of uniformly short requests, occasional latency fluctuations appear because splitting introduces wave quantization.
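One simple way to see how splitting can hurt a uniformly short batch is a wave-quantization count. The sketch assumes equal-cost tiles and a fixed number of concurrent slots, which a persistent, dynamically scheduled kernel only approximates; the names are hypothetical.

```cpp
#include <cassert>
#include <cmath>

// Toy wave model: num_tiles equal-cost tiles, `slots` of which can run at once
// (for example num_sm * CTAs per SM).
long long num_waves(long long num_tiles, long long slots) {
    return (num_tiles + slots - 1) / slots;
}

// Fraction of the launched wave capacity that does useful work.
double wave_efficiency(long long num_tiles, long long slots) {
    return static_cast<double>(num_tiles) /
           static_cast<double>(num_waves(num_tiles, slots) * slots);
}
```

With 100 slots, 100 tiles fill one wave exactly, but a split that yields 101 tiles needs a second, nearly empty wave and drops utilisation to about 50%. Small changes in tile count near such boundaries can produce step-like latency changes.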
Summary
Decode-batch load imbalance arises because task partitioning amplifies differences in request length. Because q_len = 1, parallelism in the M direction is limited, and the main source of work variation is the N direction (kv_len). When long-KV requests are not split, they become heavy CTAs that extend the kernel tail.
FlashInfer’s observed batch‑size thresholds stem from its SplitKV heuristic: once the estimated task count exceeds a heuristic bound, SplitKV is disabled, which is acceptable for uniform batches but harmful for mixed‑length batches.
FA3 mitigates the problem by retaining per‑request dynamic SplitKV in the varlen path, generating more balanced work tiles, and using a persistent kernel that dynamically schedules tiles. Consequently, the longest request no longer dominates kernel time and overall SM load becomes more uniform.
This article has been distilled and summarized from source material, then republished for learning and reference. If you believe it infringes your rights, please get in touch and we will review it promptly.
Machine Learning Algorithms & Natural Language Processing
Focused on frontier AI technologies, empowering AI researchers' progress.
