How to Optimize Multi-Head Attention: From MQA to FlashAttention and Beyond

This article examines techniques for shrinking the KV cache and speeding up attention in transformer models (MQA, GQA, MLA, sliding‑window and linear attention, FlashAttention, paged and ring attention). It also covers training‑side techniques such as mixed‑precision training and ZeRO parallelism, with code snippets, implementation details, and practical trade‑offs.

NewBeeNLP

KV Cache Compression Techniques

The size of the KV cache grows linearly with the model's hidden dimension, the number of layers, the sequence length, and the batch size. Reducing the KV footprint is essential for serving larger models or longer contexts.
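The hypothetical helper below (`kv_cache_bytes` is a name chosen here) is a rough sketch of how these factors combine. It assumes standard multi‑head attention, where every layer caches one key vector and one value vector of width equal to the hidden size for each token, stored in FP16.

```python
def kv_cache_bytes(n_layers, hidden_size, seq_len, batch_size=1, bytes_per_elem=2):
    """Approximate KV-cache size for standard multi-head attention.

    Each layer stores one key and one value vector of width `hidden_size`
    per token (hence the factor 2); bytes_per_elem=2 assumes FP16/BF16.
    """
    return 2 * n_layers * hidden_size * seq_len * batch_size * bytes_per_elem


# Illustrative 7B-class configuration (assumed: 32 layers, hidden size 4096).
per_token = kv_cache_bytes(n_layers=32, hidden_size=4096, seq_len=1)
per_1k_context = kv_cache_bytes(n_layers=32, hidden_size=4096, seq_len=1024)
print(per_token // 1024, "KiB per token")                # 512 KiB per token
print(per_1k_context // 2**20, "MiB for 1024 tokens")    # 512 MiB for 1024 tokens
```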

Multi‑Query Attention (MQA)

All query heads share a single key head and a single value head, which shrinks the KV cache by a factor of h, the number of heads. Because this also removes most of the K/V projection parameters, the feed‑forward network (FFN or GLU) is usually widened to keep the total parameter count constant, which recovers most of the accuracy loss. Models such as PaLM and Gemini use MQA.
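Here is a minimal sketch of one MQA decoding step in plain Python, with illustrative sizes; `mqa_decode_step` is a hypothetical name. Every query head scores the same cached keys, so the cache holds one key and one value per token no matter how many heads there are.

```python
import math


def softmax(xs):
    m = max(xs)                                  # subtract the max for stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def mqa_decode_step(queries, k_cache, v_cache):
    """One decoding step of multi-query attention.

    queries: h query vectors (one per head), each of length d.
    k_cache, v_cache: ONE key/value vector per cached token, shared by all heads.
    """
    d = len(queries[0])
    outputs = []
    for q in queries:                            # every head reads the same K/V
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in k_cache]
        w = softmax(scores)
        outputs.append([sum(wj * v[i] for wj, v in zip(w, v_cache)) for i in range(d)])
    return outputs


h, d = 4, 2
queries = [[0.1 * (i + 1), -0.2 * i] for i in range(h)]
k_cache = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]   # 3 cached tokens, one key each
v_cache = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
out = mqa_decode_step(queries, k_cache, v_cache)
```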

Grouped‑Query Attention (GQA)

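GQA is a middle ground between MQA and full multi‑head attention (MHA). The query heads are divided into g groups, and the heads within a group share one key head and one value head. With g = 1 it reduces to MQA and with g = h to MHA, so the KV cache shrinks by a factor of h/g. This balances memory savings against expressive power, and GQA is used in the larger LLaMA 2 and Code LLaMA models.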
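The grouping can be expressed as an index mapping. The helper names below are hypothetical. The contiguous grouping (query heads 0 to h/g − 1 share K/V head 0, and so on) is one common convention, matching how LLaMA‑style code repeats each K/V head. Here a list stands in for a tensor.

```python
def kv_head_for_query_head(q_head, n_heads, n_kv_heads):
    """Return the K/V head that query head `q_head` reads under GQA.

    Query heads are split into n_kv_heads contiguous groups of equal size.
    n_kv_heads == n_heads is ordinary MHA; n_kv_heads == 1 is MQA.
    """
    if n_heads % n_kv_heads:
        raise ValueError("n_heads must be divisible by n_kv_heads")
    group_size = n_heads // n_kv_heads
    return q_head // group_size


def repeat_kv(kv_heads, n_heads):
    """Expand a list of K (or V) heads so there is one entry per query head."""
    group_size = n_heads // len(kv_heads)
    return [head for head in kv_heads for _ in range(group_size)]


n_heads = 32
for n_kv_heads in (32, 8, 1):   # MHA, GQA with 8 groups, MQA
    print(n_kv_heads, "K/V heads -> cache", n_heads // n_kv_heads, "x smaller than MHA")
```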

Multi‑head Latent Attention (MLA)

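DeepSeek‑V2 introduced MLA. Keys and values are jointly compressed by a low‑rank down‑projection into a small latent vector, and only this latent (plus a small decoupled RoPE key) is cached. Full keys and values are reconstructed from the latent by up‑projections when attention is computed, which shrinks the cache far below that of MHA.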
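The caching idea can be sketched in plain Python. All sizes and matrix names here (`W_down`, `W_up_k`, `W_up_v`, `d_c`) are hypothetical, and the weights are random. The sketch omits two parts of DeepSeek‑V2's design: the decoupled RoPE key, and absorbing the up‑projections into the query and output projections so that full keys and values never have to be materialized.

```python
import random


def matvec(W, x):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def rand_matrix(rows, cols, rng):
    return [[rng.gauss(0.0, 0.1) for _ in range(cols)] for _ in range(rows)]


# Hypothetical sizes: the latent width d_c is much smaller than n_heads * d_head.
d_model, n_heads, d_head, d_c = 64, 4, 16, 8
rng = random.Random(0)
W_down = rand_matrix(d_c, d_model, rng)            # x_t -> latent c_t
W_up_k = rand_matrix(n_heads * d_head, d_c, rng)   # c_t -> keys for all heads
W_up_v = rand_matrix(n_heads * d_head, d_c, rng)   # c_t -> values for all heads

latent_cache = []
for t in range(5):                                 # five decoding steps
    x_t = [rng.gauss(0.0, 1.0) for _ in range(d_model)]
    latent_cache.append(matvec(W_down, x_t))       # only d_c numbers cached per token

# Keys and values are rebuilt from the latents when attention needs them.
keys = [matvec(W_up_k, c) for c in latent_cache]
values = [matvec(W_up_v, c) for c in latent_cache]

mha_floats_per_token = 2 * n_heads * d_head        # K and V for every head
mla_floats_per_token = d_c
print(mha_floats_per_token / mla_floats_per_token)  # 16.0
```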

Sliding‑Window Attention (SWA)

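Each token attends only to the previous N tokens (the window) rather than to the entire prefix, which yields a banded, sparse attention pattern. Mistral 7B uses a window size of N = 4096 with a rolling‑buffer KV cache, so cache memory stays bounded by the window while stacked layers still let information propagate beyond N tokens.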
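Both the mask and the bounded cache fit in a few lines. In this sketch the window counts the current token (token i sees positions i − N + 1 through i); implementations differ on that boundary. The rolling buffer writes position t into slot `t % window`, the layout described in the Mistral 7B paper. The class itself is a hypothetical illustration.

```python
def sliding_window_mask(seq_len, window):
    """mask[i][j] is True when query i may attend key j: causal and within the window."""
    return [[0 <= i - j < window for j in range(seq_len)] for i in range(seq_len)]


class RollingKVCache:
    """Keeps only the last `window` entries; position t is stored in slot t % window."""

    def __init__(self, window):
        self.window = window
        self.slots = [None] * window
        self.length = 0                              # total tokens seen so far

    def append(self, kv):
        self.slots[self.length % self.window] = kv   # overwrites the oldest entry
        self.length += 1

    def contents(self):
        """Cached entries ordered from oldest to newest."""
        start = max(0, self.length - self.window)
        return [self.slots[t % self.window] for t in range(start, self.length)]


mask = sliding_window_mask(seq_len=5, window=3)
cache = RollingKVCache(window=3)
for t in range(5):
    cache.append(f"kv{t}")
print(cache.contents())   # ['kv2', 'kv3', 'kv4']
```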

Linear (Kernel) Attention

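Linear attention replaces the softmax with a kernel feature map, for example φ(x) = elu(x) + 1 applied to queries and keys. The output for query i then becomes φ(q_i)ᵀ S_i / (φ(q_i)ᵀ z_i), where S_i = Σ_{j≤i} φ(k_j) v_jᵀ and z_i = Σ_{j≤i} φ(k_j). Because S and z can be updated incrementally, each decoding step costs the same regardless of context length, and the cache is a fixed‑size state instead of a growing list of keys and values. RWKV is a notable variant that likewise compresses the history into a fixed‑size state, similar to an RNN.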
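Here is a sketch of causal linear attention in plain Python, using the φ(x) = elu(x) + 1 feature map above. The recurrent form keeps only a d × d_v matrix S and a d‑vector z, so its state does not grow with the sequence. The pairwise form is there only to check that both give the same outputs. Function names are hypothetical.

```python
import math
import random


def phi(vec):
    """Feature map phi(x) = elu(x) + 1, applied elementwise; always positive."""
    return [x + 1.0 if x > 0 else math.exp(x) for x in vec]


def causal_linear_attention(qs, ks, vs):
    """Recurrent form with fixed-size state S = sum phi(k_j) v_j^T, z = sum phi(k_j)."""
    d, d_v = len(ks[0]), len(vs[0])
    S = [[0.0] * d_v for _ in range(d)]
    z = [0.0] * d
    outs = []
    for q, k, v in zip(qs, ks, vs):
        fk = phi(k)
        for a in range(d):                      # S += phi(k) v^T ; z += phi(k)
            z[a] += fk[a]
            for b in range(d_v):
                S[a][b] += fk[a] * v[b]
        fq = phi(q)
        denom = sum(fq[a] * z[a] for a in range(d))
        outs.append([sum(fq[a] * S[a][b] for a in range(d)) / denom for b in range(d_v)])
    return outs


def quadratic_linear_attention(qs, ks, vs):
    """Same outputs computed pairwise (O(n^2)), only as a correctness check."""
    outs = []
    for i, q in enumerate(qs):
        fq = phi(q)
        w = [sum(a * b for a, b in zip(fq, phi(ks[j]))) for j in range(i + 1)]
        outs.append([sum(w[j] * vs[j][b] for j in range(i + 1)) / sum(w)
                     for b in range(len(vs[0]))])
    return outs


rng = random.Random(0)
qs = [[rng.uniform(-1, 1) for _ in range(3)] for _ in range(6)]
ks = [[rng.uniform(-1, 1) for _ in range(3)] for _ in range(6)]
vs = [[rng.uniform(-1, 1) for _ in range(2)] for _ in range(6)]
recurrent = causal_linear_attention(qs, ks, vs)
pairwise = quadratic_linear_attention(qs, ks, vs)
```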

MHA Engineering Optimizations (No Accuracy Loss)

In decoder‑only models the KV pairs are cached after each forward step. The following PyTorch‑style loop shows KV accumulation and attention computation:

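```python
import math
import torch

b, h, d, time_step = 2, 8, 64, 4       # illustrative sizes
# K_prev, V_prev store all previous keys and values: [b, h, n, d], n = 0 at the start
K_prev = torch.empty(b, h, 0, d)
V_prev = torch.empty(b, h, 0, d)

for _ in range(time_step):
    # q: [b, h, d] is the current timestep's query; k, v: [b, h, 1, d] its key and value.
    # Random placeholders here; in a model they come from the Q/K/V projections.
    q = torch.randn(b, h, d)
    k = torch.randn(b, h, 1, d)
    v = torch.randn(b, h, 1, d)
    K = torch.cat([K_prev, k], dim=-2)                       # append to cache -> [b, h, n, d]
    V = torch.cat([V_prev, v], dim=-2)                       # [b, h, n, d]
    logits = torch.einsum('bhd,bhnd->bhn', q, K)             # scores vs. all cached keys
    weights = torch.softmax(logits / math.sqrt(d), dim=-1)   # [b, h, n]
    outs = torch.einsum('bhn,bhnd->bhd', weights, V)         # [b, h, d]
    # ... use outs ...
    K_prev, V_prev = K, V
```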

Online Softmax vs. Safe Softmax

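These are two ways of computing the softmax step of attention. Safe softmax subtracts the row maximum before exponentiating so that exp() cannot overflow. The cost is separate passes over the logits to find the maximum, sum the exponentials, and normalize. Online softmax keeps a running maximum and rescales the running sum whenever the maximum grows, so the maximum and the normalizer come out of a single pass. This is what lets FlashAttention process the logits block by block. See Chen Star's flash‑attention notes for implementation details.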
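The difference is easy to see in plain Python. These are simplified scalar versions written for illustration. In FlashAttention the same running‑max rescaling is applied to whole blocks of logits and to the output accumulator.

```python
import math


def safe_softmax(xs):
    """Separate passes: find the max, sum the shifted exponentials, normalize."""
    m = max(xs)
    s = sum(math.exp(x - m) for x in xs)
    return [math.exp(x - m) / s for x in xs]


def online_softmax(xs):
    """A single pass computes the running max m and the normalizer s together.

    When a new maximum appears, the sum so far is rescaled by exp(m_old - m_new),
    so the final result equals safe softmax.
    """
    m, s = float("-inf"), 0.0
    for x in xs:
        m_new = max(m, x)
        s = s * math.exp(m - m_new) + math.exp(x - m_new)
        m = m_new
    return [math.exp(x - m) / s for x in xs]


logits = [1000.0, 1001.0, 999.0]   # math.exp(1000.0) alone would raise OverflowError
print(safe_softmax(logits))
print(online_softmax(logits))
```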

FlashAttention

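For long sequences, the full attention score matrix QKᵀ (sequence length × sequence length per head) does not fit in on‑chip SRAM. A standard implementation therefore writes it to HBM and reads it back for the softmax and the multiplication by V. FlashAttention tiles the computation into SRAM‑resident blocks, dramatically reducing HBM traffic.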

IO‑aware tiling minimizes HBM reads/writes.

Intermediate similarity and softmax probability matrices are never written back to HBM.
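Here is a plain‑Python sketch of the tiled forward pass for a single head, without masking or dropout. The tile sizes `block_q` and `block_k` are arbitrary illustration values; on a GPU each tile would be loaded into SRAM. The loop order (query tiles outside, key/value tiles inside) follows FlashAttention‑2. The point is that only a `block_q × block_k` piece of the score matrix exists at any time, yet the result matches ordinary attention.

```python
import math
import random


def naive_attention(Q, K, V):
    """Reference: materialize each full score row, softmax it, multiply by V."""
    d = len(Q[0])
    out = []
    for q in Q:
        s = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in K]
        m = max(s)
        p = [math.exp(x - m) for x in s]
        out.append([sum(pj * v[c] for pj, v in zip(p, V)) / sum(p) for c in range(len(V[0]))])
    return out


def tiled_attention(Q, K, V, block_q=2, block_k=3):
    """FlashAttention-style forward pass over query tiles and key/value tiles.

    Each query row keeps a running max m, normalizer l and unnormalized output acc,
    rescaled whenever a later key tile raises the max; normalization happens once.
    """
    d, d_v = len(Q[0]), len(V[0])
    out = []
    for i0 in range(0, len(Q), block_q):                 # outer loop: query tile
        Q_blk = Q[i0:i0 + block_q]
        m = [float("-inf")] * len(Q_blk)
        l = [0.0] * len(Q_blk)
        acc = [[0.0] * d_v for _ in Q_blk]
        for j0 in range(0, len(K), block_k):             # inner loop: key/value tile
            K_blk, V_blk = K[j0:j0 + block_k], V[j0:j0 + block_k]
            for r, q in enumerate(Q_blk):
                s = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in K_blk]
                m_new = max(m[r], max(s))
                scale = math.exp(m[r] - m_new)           # rescale earlier partial results
                p = [math.exp(x - m_new) for x in s]
                l[r] = l[r] * scale + sum(p)
                acc[r] = [acc[r][c] * scale + sum(pj * v[c] for pj, v in zip(p, V_blk))
                          for c in range(d_v)]
                m[r] = m_new
        out.extend([[a / l[r] for a in acc[r]] for r in range(len(Q_blk))])
    return out


rng = random.Random(0)
Q = [[rng.uniform(-1, 1) for _ in range(4)] for _ in range(5)]
K = [[rng.uniform(-1, 1) for _ in range(4)] for _ in range(7)]
V = [[rng.uniform(-1, 1) for _ in range(3)] for _ in range(7)]
flash = tiled_attention(Q, K, V)
reference = naive_attention(Q, K, V)
```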

Example: a 7B model with hidden size 4096 (32 layers, FP16) needs 2 × 32 × 4096 × 2 bytes ≈ 512 KB of KV cache per token. A 1024‑token context therefore occupies about 512 MB of HBM, far beyond on‑chip SRAM (on an A100, 192 KB per SM across 108 SMs, roughly 20 MB in total). FlashAttention also benefits from GPUs such as the A100 that offer more SRAM per SM and asynchronous copy instructions.

PagedAttention (vLLM)

vLLM stores the KV cache in fixed‑size blocks (16 tokens by default) that need not be contiguous in memory, much like pages in virtual memory. A per‑sequence block table maps logical blocks to physical ones. Prefixes shared across requests are stored once, and copy‑on‑write avoids duplicating KV blocks during beam search or parallel sampling. The core operation is attention_ops.single_query_cached_kv_attention. The vLLM authors report up to 24× higher throughput than vanilla HuggingFace Transformers without changes to the model architecture.
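The bookkeeping can be sketched with a toy block manager. The class and method names are hypothetical and do not correspond to vLLM's internal API. It tracks only block ids and reference counts, not actual key/value data. That is enough to show a forked sequence sharing its parent's blocks and copying a block only when it writes into a shared one.

```python
class PagedKVCache:
    """Toy block manager: token positions map to fixed-size physical blocks."""

    def __init__(self, num_blocks, block_size=16):
        self.block_size = block_size
        self.free = list(range(num_blocks))
        self.ref_count = [0] * num_blocks
        self.block_tables = {}      # sequence id -> list of physical block ids
        self.lengths = {}           # sequence id -> number of tokens

    def _allocate(self):
        block = self.free.pop()
        self.ref_count[block] = 1
        return block

    def append_token(self, seq):
        table = self.block_tables.setdefault(seq, [])
        n = self.lengths.get(seq, 0)
        if n % self.block_size == 0:                 # last block full (or none yet)
            table.append(self._allocate())
        elif self.ref_count[table[-1]] > 1:          # writing into a shared block
            self.ref_count[table[-1]] -= 1           # copy-on-write: take a private block
            table[-1] = self._allocate()             # (a real system also copies the data)
        self.lengths[seq] = n + 1

    def fork(self, parent, child):
        """Share all of the parent's blocks with a new sequence (e.g. parallel sampling)."""
        self.block_tables[child] = list(self.block_tables[parent])
        self.lengths[child] = self.lengths[parent]
        for block in self.block_tables[child]:
            self.ref_count[block] += 1


cache = PagedKVCache(num_blocks=8, block_size=4)
for _ in range(6):
    cache.append_token("a")      # 6 tokens -> 2 blocks
cache.fork("a", "b")             # "b" shares both blocks
cache.append_token("b")          # partially filled shared block gets copied
```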

Ring Attention

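For ultra‑long sequences, the sequence is split into n blocks across n GPUs. Each GPU keeps its own query block and passes its key/value block to the next GPU in a ring, overlapping this peer‑to‑peer communication with blockwise attention computation. After n steps every query block has attended to every key/value block. Each GPU only ever holds a constant number of blocks, so per‑GPU memory depends on the block size rather than the total sequence length, and the supported context grows linearly with the number of GPUs. This applies to both multi‑GPU training and inference.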
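A single‑process simulation makes the schedule concrete. It makes three simplifying assumptions: the "devices" are list indices, there is no causal mask, and the loop over devices runs sequentially where real hardware runs it concurrently. Each device combines per‑block partial results with the same max‑and‑rescale idea as online softmax, and the final outputs match attention over the whole sequence. All names are hypothetical.

```python
import math
import random


def partial_attention(q, K_blk, V_blk):
    """Unnormalized attention of one query against one K/V block.

    Returns (m, l, acc): the block's max score, sum of exp(score - m), and the
    exp-weighted sum of values. acc / l is the softmax-attention output.
    """
    scale = 1.0 / math.sqrt(len(q))
    s = [scale * sum(a * b for a, b in zip(q, k)) for k in K_blk]
    m = max(s)
    p = [math.exp(x - m) for x in s]
    acc = [sum(pj * v[c] for pj, v in zip(p, V_blk)) for c in range(len(V_blk[0]))]
    return m, sum(p), acc


def merge(a, b):
    """Combine two partial results for the same query by rescaling to a common max."""
    (m1, l1, acc1), (m2, l2, acc2) = a, b
    m = max(m1, m2)
    w1, w2 = math.exp(m1 - m), math.exp(m2 - m)
    return m, l1 * w1 + l2 * w2, [x * w1 + y * w2 for x, y in zip(acc1, acc2)]


def ring_attention(Q_blocks, K_blocks, V_blocks):
    """Device r keeps Q_blocks[r]; K/V blocks move one hop around the ring per step."""
    n = len(Q_blocks)
    results = []
    for r in range(n):                       # real devices run these concurrently
        partials = [None] * len(Q_blocks[r])
        for t in range(n):
            src = (r - t) % n                # K/V block that reached device r at step t
            for i, q in enumerate(Q_blocks[r]):
                p = partial_attention(q, K_blocks[src], V_blocks[src])
                partials[i] = p if partials[i] is None else merge(partials[i], p)
        results.append([[x / l for x in acc] for (_, l, acc) in partials])
    return results


rng = random.Random(0)
n_devices, block, d = 4, 3, 5


def rand_block():
    return [[rng.uniform(-1, 1) for _ in range(d)] for _ in range(block)]


Q_blocks = [rand_block() for _ in range(n_devices)]
K_blocks = [rand_block() for _ in range(n_devices)]
V_blocks = [rand_block() for _ in range(n_devices)]
out = ring_attention(Q_blocks, K_blocks, V_blocks)
```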

Striped Attention

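Striped Attention fixes a load imbalance in Ring Attention under causal masking. With contiguous blocks, the GPU holding the earliest tokens has almost all of its work masked out, while the GPU holding the latest tokens does the most, and every ring step waits for the slowest GPU. Striped Attention instead assigns tokens round‑robin (token i goes to GPU i mod n), so every GPU gets a similar share of unmasked work at each step.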
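The imbalance can be quantified by counting unmasked query/key pairs. This sketch assumes that a ring step takes as long as its busiest device and that fully masked blocks are skipped at no cost; function names are hypothetical. With 16 tokens on 4 devices, contiguous blocks make three of the four steps cost 16 units on the slowest device, whereas striping keeps every step at 10.

```python
def causal_work(q_tokens, k_tokens):
    """Unmasked (query, key) pairs under a causal mask: key position <= query position."""
    return sum(1 for i in q_tokens for j in k_tokens if j <= i)


def contiguous_partition(seq_len, n):
    """Ring Attention layout: device r holds one contiguous block of tokens."""
    block = seq_len // n
    return [list(range(r * block, (r + 1) * block)) for r in range(n)]


def striped_partition(seq_len, n):
    """Striped layout: token i goes to device i % n."""
    return [list(range(r, seq_len, n)) for r in range(n)]


def ring_step_costs(parts):
    """Work of the busiest device at each step; device r sees block (r - t) % n at step t."""
    n = len(parts)
    return [max(causal_work(parts[r], parts[(r - t) % n]) for r in range(n))
            for t in range(n)]


print(ring_step_costs(contiguous_partition(16, 4)))   # [10, 16, 16, 16]
print(ring_step_costs(striped_partition(16, 4)))      # [10, 10, 10, 10]
```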

FFN Optimizations

Approximately two‑thirds of transformer parameters reside in the feed‑forward network. Sparsifying or factorizing the MLP reduces both parameter count and compute, especially for short sequences where the FFN dominates.
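The two‑thirds figure follows from counting weight matrices in a standard layer with hidden size d and an FFN width of 4d, ignoring biases, embeddings and norms. Attention has four d × d projections (4d²) and the FFN has two d × 4d projections (8d²), so the FFN holds 8/12 = 2/3 of the layer's weights. Gated FFNs such as LLaMA's SwiGLU use three matrices with a hidden width of about 8d/3, which keeps roughly the same 8d² count. The helper below is a hypothetical illustration.

```python
def layer_params(d_model, ffn_mult=4):
    """Weight-matrix parameters in one standard transformer layer (biases, norms ignored)."""
    attention = 4 * d_model * d_model               # W_q, W_k, W_v, W_o
    ffn = 2 * d_model * (ffn_mult * d_model)        # up-projection and down-projection
    return attention, ffn


attn, ffn = layer_params(4096)
print(ffn / (attn + ffn))   # about 0.667
```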

Mixture‑of‑Experts (MoE)

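MoE replaces a single FFN with several expert FFNs and uses a router to send each token to a small subset of them (for example the top 1 or 2). Model capacity grows with the number of experts, while compute per token grows only with the number of experts selected. All experts must still be stored, so memory does grow with capacity.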
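Here is a minimal top‑k routing sketch in plain Python. The router weights, the toy "experts" (which just scale their input) and all names are hypothetical. Renormalizing the gate over only the selected experts is one of several conventions in use. The point is that only `top_k` experts execute for a given token.

```python
import math


def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def make_expert(scale):
    """Toy stand-in for an expert FFN: scales its input vector."""
    return lambda v: [scale * vi for vi in v]


def moe_layer(x, gate_weights, experts, top_k=2):
    """Route token vector x to its top_k experts and mix their outputs."""
    logits = [sum(w * xi for w, xi in zip(row, x)) for row in gate_weights]
    chosen = sorted(range(len(experts)), key=lambda e: logits[e], reverse=True)[:top_k]
    gates = softmax([logits[e] for e in chosen])   # renormalized over chosen experts
    out = [0.0] * len(x)
    for g, e in zip(gates, chosen):                # only the chosen experts run
        y = experts[e](x)
        out = [o + g * yi for o, yi in zip(out, y)]
    return out, chosen


experts = [make_expert(s) for s in (1.0, 2.0, 3.0, 4.0)]
gate_weights = [[0.1, 0.0], [0.2, 0.0], [0.3, 0.0], [0.4, 0.0]]   # router weights
out, chosen = moe_layer([1.0, 1.0], gate_weights, experts, top_k=2)
print(chosen)   # [3, 2]
```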

Mixed‑Precision Training

FP16 is used for the forward and backward passes, while an FP32 master copy of the weights and the optimizer states are kept for the weight updates (“FP16 with FP32 master weights”). This halves the memory needed for activations and gradients, typically with little or no loss of accuracy compared with full FP32 training.

Forward pass with FP16 weights produces FP16 activations and gradients.

Optimizer computes FP32 weight updates.

Updated FP32 weights are cast back to FP16 for the next iteration.
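The value of the FP32 master copy can be shown by simulating FP16 and FP32 rounding with Python's `struct` module (format codes `'e'` and `'f'`). The numbers are illustrative. An update of 1e‑4 is smaller than half the FP16 spacing near 1.0 (2⁻¹⁰ ≈ 0.00098), so it vanishes if the weight itself is stored in FP16, while the FP32 master copy keeps accumulating it.

```python
import struct


def to_fp16(x):
    """Round a Python float to the nearest IEEE half-precision value."""
    return struct.unpack("e", struct.pack("e", x))[0]


def to_fp32(x):
    """Round a Python float to the nearest IEEE single-precision value."""
    return struct.unpack("f", struct.pack("f", x))[0]


update = 1e-4          # a small step, e.g. learning rate x gradient
steps = 10

# Weight kept only in FP16: 1.0001 rounds back to 1.0 every time.
w_fp16 = to_fp16(1.0)
for _ in range(steps):
    w_fp16 = to_fp16(w_fp16 + update)

# FP32 master weight: updates accumulate, then it is cast to FP16 for the next pass.
master = to_fp32(1.0)
for _ in range(steps):
    master = to_fp32(master + update)
w_for_next_forward = to_fp16(master)

print(w_fp16, w_for_next_forward)   # 1.0 1.0009765625
```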

For numerical stability, implementations (e.g., HuggingFace Transformers, LLaMA) cast the attention logits to FP32 before softmax and cast the result back to FP16 afterwards.
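Here is a simplified illustration of the precision problem the FP32 cast guards against: summing many softmax terms in an FP16 accumulator. Real kernels differ in how they accumulate, so treat this as a sketch of the failure mode, again simulating the formats with `struct`. Once the running sum reaches 2048, adjacent FP16 values are 2 apart and adding 1.0 no longer changes it.

```python
import struct


def to_fp16(x):
    return struct.unpack("e", struct.pack("e", x))[0]


def to_fp32(x):
    return struct.unpack("f", struct.pack("f", x))[0]


# Softmax normalizer over 4096 positions whose shifted exponentials are all exp(0) = 1.
terms = [1.0] * 4096

sum16 = 0.0
for t in terms:
    sum16 = to_fp16(sum16 + t)   # FP16 accumulator stalls at 2048

sum32 = 0.0
for t in terms:
    sum32 = to_fp32(sum32 + t)   # FP32 accumulator is exact here

print(sum16, sum32)   # 2048.0 4096.0
```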

Parallelism Strategies and Training Frameworks

Common parallelism techniques include:

Data parallelism – replicates the full model on every GPU and splits each batch across them.

Model (tensor) parallelism – slices each layer across GPUs.

Pipeline parallelism – partitions layers into stages.

3D parallelism – combines data, pipeline, and tensor parallelism.

Major frameworks supporting these strategies are HuggingFace Transformers, DeepSpeed, and Megatron‑LM.
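Here is a hypothetical rank‑to‑coordinate mapping that shows how 3D parallelism lays out GPUs. The ordering is a common convention assumed for illustration: the tensor‑parallel index varies fastest so that tensor‑parallel partners sit on the same node, then the pipeline stage, then the data‑parallel replica. Real frameworks make this configurable.

```python
def rank_to_coords(rank, tp, pp, dp):
    """Map a global rank to (data-parallel, pipeline-stage, tensor-parallel) indices."""
    if not 0 <= rank < tp * pp * dp:
        raise ValueError("rank out of range")
    tp_rank = rank % tp                # fastest: tensor-parallel partners are adjacent
    pp_rank = (rank // tp) % pp        # next: pipeline stage
    dp_rank = rank // (tp * pp)        # slowest: data-parallel replica
    return dp_rank, pp_rank, tp_rank


# 16 GPUs = 2 data-parallel replicas x 4 pipeline stages x 2-way tensor parallelism.
tp, pp, dp = 2, 4, 2
for rank in range(tp * pp * dp):
    print(rank, rank_to_coords(rank, tp, pp, dp))
```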

Megatron‑LM

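Megatron‑LM implements tensor parallelism by splitting the weight matrices inside each transformer layer across GPUs. The first linear layer of the MLP is split by columns and the second by rows, and attention is split by heads, so each block needs only one all‑reduce in the forward pass and one in the backward pass. This requires only a few targeted changes to the model code. The Megatron‑LM paper reports 76% scaling efficiency when training an 8.3B‑parameter model on 512 GPUs.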
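The column‑then‑row split can be checked in plain Python. This is a single‑process simulation with hypothetical helper names, and ReLU stands in for the GeLU used in Megatron‑LM. Each "GPU" is just one slice of the weights. Because the activation is elementwise, each GPU can apply it to its own column slice without communication. A single sum over the partial outputs (the all‑reduce) then reproduces the unsplit result.

```python
import random


def matmul(A, B):
    """Plain-Python matrix product of lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def relu(M):
    return [[max(0.0, x) for x in row] for row in M]


def split_columns(W, parts):
    width = len(W[0]) // parts
    return [[row[p * width:(p + 1) * width] for row in W] for p in range(parts)]


def split_rows(W, parts):
    height = len(W) // parts
    return [W[p * height:(p + 1) * height] for p in range(parts)]


def tensor_parallel_mlp(X, W1, W2, n_gpus):
    """relu(X @ W1) @ W2 with W1 split by columns and W2 split by rows."""
    partials = [
        matmul(relu(matmul(X, W1_p)), W2_p)          # local to "GPU" p, no communication
        for W1_p, W2_p in zip(split_columns(W1, n_gpus), split_rows(W2, n_gpus))
    ]
    # All-reduce: element-wise sum of the partial outputs from every GPU.
    return [[sum(vals) for vals in zip(*rows)] for rows in zip(*partials)]


rng = random.Random(0)
d_model, d_ff, n_gpus = 4, 8, 2
X = [[rng.uniform(-1, 1) for _ in range(d_model)] for _ in range(3)]
W1 = [[rng.uniform(-1, 1) for _ in range(d_ff)] for _ in range(d_model)]
W2 = [[rng.uniform(-1, 1) for _ in range(d_model)] for _ in range(d_ff)]
parallel = tensor_parallel_mlp(X, W1, W2, n_gpus)
serial = matmul(relu(matmul(X, W1)), W2)
```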

ZeRO and Offload Strategies

ZeRO shards optimizer states, gradients, and parameters across GPUs:

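ZeRO‑1: Optimizer‑state sharding; gradients are still synchronized with All‑Reduce.

ZeRO‑2: Additionally shards gradients; they are combined with Reduce‑Scatter so each GPU keeps only the gradient shard it updates.

ZeRO‑3: Additionally shards the parameters; each layer's parameters are gathered with All‑Gather just before they are needed in the forward and backward passes and released afterwards.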
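Per‑GPU memory at each stage can be estimated with the model‑state accounting from the ZeRO paper. With mixed‑precision Adam, each parameter costs 2 bytes of FP16 weights, 2 bytes of FP16 gradients, and K = 12 bytes of optimizer state (FP32 master weights, momentum, and variance). The function name is hypothetical, and activations and temporary buffers are ignored. For 7.5B parameters on 64 GPUs it gives 120 GB without ZeRO and about 31.4, 16.6, and 1.9 GB for stages 1 to 3, the figures quoted in the ZeRO paper.

```python
def zero_memory_gb(num_params, n_gpus, stage, k=12):
    """Per-GPU model-state memory (GB) under mixed-precision Adam and ZeRO stage 0-3."""
    weights = 2 * num_params          # FP16 parameters
    grads = 2 * num_params            # FP16 gradients
    optim = k * num_params            # FP32 master weights + momentum + variance
    if stage >= 1:
        optim /= n_gpus               # ZeRO-1: shard optimizer states
    if stage >= 2:
        grads /= n_gpus               # ZeRO-2: also shard gradients
    if stage >= 3:
        weights /= n_gpus             # ZeRO-3: also shard parameters
    return (weights + grads + optim) / 1e9


for stage in range(4):
    print(stage, round(zero_memory_gb(7.5e9, 64, stage), 1))
```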

ZeRO‑Offload moves the optimizer states (including the FP32 master weights) and the gradients, together with the optimizer step itself, to the CPU. The compute‑heavy forward and backward passes (FP16 weights, activations) stay on the GPU. ZeRO‑Infinity extends ZeRO‑3 with offloading to CPU memory and optional NVMe storage, enabling training of >100B‑parameter models on limited GPU memory.

References

The tug‑of‑war between cache size and quality: from MHA, MQA, GQA to MLA – https://kexue.fm/archives/10091/comment-page-2

Flash attention & flash decoding (Chen Star) – https://zhuanlan.zhihu.com/p/691623115

极市 (CVMart) developer platform: a platform for developing and deploying computer vision algorithms – https://www.cvmart.net/community/detail/8302

Illustrated LLM compute acceleration series: Flash Attention V2 – https://zhuanlan.zhihu.com/p/691067658

Ring attention + flash attention: the road to ultra‑long context – https://zhuanlan.zhihu.com/p/683714620

Hugging Face source code – https://github.com/huggingface/transformers/blob/ee4250a35f3bd5e9a4379b4907b3d8f9d5d9523f/src/transformers/models/llama/modeling_llama.py#L350C8-L351C111

LLaMA official implementation – https://github.com/meta-llama/llama/blob/6c7fe276574e78057f917549435a2554000a876d/llama/model.py#L180C17-L180C42
