Demystifying FlashAttention: A Minimalist Derivation of the Algorithm

This article presents a concise, step‑by‑step derivation of FlashAttention, explaining the prerequisite linear‑algebra concepts, the softmax simplifications, and the parallel computation workflow—including the LSE‑enhanced version—so readers can grasp the algorithm’s elegance without heavy mathematics.

Baobao Algorithm Notes

Prerequisite Knowledge

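Only a few basics are needed: matrix multiplication as in O = Q Kᵀ V (attention without the softmax); the numerically stable softmax, softmax(x) = exp(x - m) / sum(exp(x - m)) with m = max(x); the exponent and logarithm identities exp(a + b) = exp(a)·exp(b) and log(ab) = log a + log b; and block-wise multiplication, [A1, A2]·[V1; V2] = A1·V1 + A2·V2, where A1 and A2 sit side by side and V1 and V2 are stacked vertically. When Q has several rows, max and sum are taken separately for each row.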
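As a quick illustration of these prerequisites, the sketch below uses only the Python standard library. The helper name softmax and all the numbers are made up for this example; it checks that subtracting the maximum leaves the softmax usable on large inputs and that a blocked product matches the unblocked one.

```python
import math

def softmax(x):
    """Numerically stable softmax of a list of floats."""
    m = max(x)                              # shift by the maximum
    exps = [math.exp(v - m) for v in x]     # every exponent is <= 0
    total = sum(exps)
    return [e / total for e in exps]

# Without the shift, math.exp(1000.0) would raise OverflowError.
probs = softmax([1000.0, 999.0, 998.0])

# Block-wise product: [a1, a2] times the stacked [v1; v2] equals a1*v1 + a2*v2.
a1, a2 = [0.2, 0.3], [0.5]        # hypothetical weights, split into two blocks
v1, v2 = [10.0, 20.0], [30.0]     # hypothetical values, split the same way
full = sum(a * v for a, v in zip(a1 + a2, v1 + v2))
blocked = (sum(a * v for a, v in zip(a1, v1))
           + sum(a * v for a, v in zip(a2, v2)))
```

Both full and blocked describe the same weighted sum; the blocked form is what allows attention to be accumulated one K / V block at a time.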

Goal

Compute O = softmax(Q·[K1; K2]ᵀ)·[V1; V2] efficiently, without materialising the full Q Kᵀ matrix. K and V are split row-wise into blocks, so Q·Kᵀ = [Q·K1ᵀ, Q·K2ᵀ]. The derivation below uses two blocks, and the same steps extend to any number of K / V blocks.

FlashAttention computation (two‑block example)

Step 1 – Compute Q·Kᵀ for each block

X1 = Q × K1ᵀ
X2 = Q × K2ᵀ

Step 2 – Block‑wise maximum

m1 = max(X1)
m2 = max(m1, max(X2))
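Steps 1 and 2 can be sketched for a single query row as follows. The toy values of q, K1 and K2 and the helper dot are assumptions made for this illustration; with a full matrix Q the same maxima would be taken per row.

```python
def dot(u, v):
    """Inner product of two equal-length lists."""
    return sum(a * b for a, b in zip(u, v))

# Hypothetical toy data: one query row and two key blocks.
q = [1.0, 0.0]
K1 = [[1.0, 2.0], [3.0, 0.0]]
K2 = [[5.0, 1.0]]

# Step 1: scores of the query against each key block.
X1 = [dot(q, k) for k in K1]
X2 = [dot(q, k) for k in K2]

# Step 2: m1 is the maximum of block 1, m2 the running maximum after block 2.
m1 = max(X1)
m2 = max(m1, max(X2))
```

Note that m2 equals the maximum over all scores, even though block 2 was never concatenated with block 1.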

Step 3 – Numerators of the softmax

alpha = exp(m1 - m2)

(scale factor that re-expresses block 1, computed relative to m1, relative to the running maximum m2)

a1 = exp(X1 - m1)
a2 = exp(X2 - m2)
a1' = a1 × alpha

Step 4 – Denominators of the softmax

d1 = sum(a1) = sum(exp(X1 - m1))
d2 = sum(a2) = sum(exp(X2 - m2))
d1' = d1 × alpha
d12 = d1' + d2 = d1 × alpha + d2
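A small numeric check of Steps 3 and 4 for one query row, using hypothetical scores X1 and X2. It confirms that rescaling by alpha reproduces the numerators and the denominator that a single pass over all scores with the global maximum would give.

```python
import math

X1, X2 = [1.0, 3.0], [5.0]      # hypothetical scores for one query row
m1 = max(X1)
m2 = max(m1, max(X2))

# Step 3: numerators, each block relative to the maximum known at that point.
alpha = math.exp(m1 - m2)
a1 = [math.exp(x - m1) for x in X1]
a2 = [math.exp(x - m2) for x in X2]
a1_rescaled = [a * alpha for a in a1]   # a1' in the text

# Step 4: denominators.
d1, d2 = sum(a1), sum(a2)
d12 = d1 * alpha + d2

# Direct computation with the global maximum, for comparison only.
direct_num = [math.exp(x - m2) for x in X1]
direct_den = sum(math.exp(x - m2) for x in X1 + X2)
```

The agreement follows from exp(x - m1)·exp(m1 - m2) = exp(x - m2), which is the exponent identity listed in the prerequisites.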

Step 5 – Assemble the output

O1 = a1 × V1 / d1
O2 = a2 × V2 / d2

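Final result:

O = O1 × d1 / d12 × alpha + O2 × d2 / d12

Only block 1 carries the factor alpha, because block 2 was already computed relative to m2. Additional blocks (K3, V3, …) are handled by repeating the same steps: update the running maximum, rescale the accumulated numerator and denominator by the new alpha, and add the new block's contribution.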
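The five steps can be put together in a minimal sketch for one query row and compared against a direct softmax over all keys. The function names, helpers and toy data are assumptions made for this example, not part of any FlashAttention library, and the sketch ignores tiling, scaling by the head dimension and masking.

```python
import math

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def weighted_sum(weights, rows):
    """Return sum_i weights[i] * rows[i] as a list."""
    return [sum(w * r[j] for w, r in zip(weights, rows))
            for j in range(len(rows[0]))]

def two_block_attention(q, K1, V1, K2, V2):
    X1 = [dot(q, k) for k in K1]                    # Step 1
    X2 = [dot(q, k) for k in K2]
    m1 = max(X1)                                    # Step 2
    m2 = max(m1, max(X2))
    alpha = math.exp(m1 - m2)                       # Step 3
    a1 = [math.exp(x - m1) for x in X1]
    a2 = [math.exp(x - m2) for x in X2]
    d1, d2 = sum(a1), sum(a2)                       # Step 4
    d12 = d1 * alpha + d2
    O1 = [o / d1 for o in weighted_sum(a1, V1)]     # Step 5
    O2 = [o / d2 for o in weighted_sum(a2, V2)]
    return [o1 * d1 / d12 * alpha + o2 * d2 / d12 for o1, o2 in zip(O1, O2)]

def reference_attention(q, K, V):
    """Direct softmax(q K^T) V over all keys, for comparison only."""
    X = [dot(q, k) for k in K]
    m = max(X)
    a = [math.exp(x - m) for x in X]
    d = sum(a)
    return [o / d for o in weighted_sum(a, V)]

# Hypothetical toy data.
q = [1.0, 0.0]
K1, V1 = [[1.0, 2.0], [3.0, 0.0]], [[1.0, 0.0], [0.0, 1.0]]
K2, V2 = [[5.0, 1.0]], [[2.0, 2.0]]

blocked = two_block_attention(q, K1, V1, K2, V2)
direct = reference_attention(q, K1 + K2, V1 + V2)
```

On this data blocked and direct agree to floating-point precision, although the blocked version never forms the full score vector.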

Log‑Sum‑Exp (LSE) enhanced version

Keeping the log-sum-exp (the log of the softmax denominator) of each block lets partial outputs be merged in log space, which improves numerical stability.

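Steps 1-4 (same as above) plus LSE values

lse1 = log(d1) = log-sum-exp(X1 - m1)
lse2 = log(d2) = log-sum-exp(X2 - m2)
lse1' = lse1 + m1 - m2 = log(d1')

(lse1' shifts block 1 to the running maximum m2, mirroring d1' = d1 × alpha)

Step 5 – Assemble with LSE

lse12 = log(d12) = lse1 + log(alpha + exp(lse2 - lse1))

Final LSE-stable output:

O = O1 × exp(lse1' - lse12) + O2 × exp(lse2 - lse12)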
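The LSE merge can be checked numerically with a short sketch. It assumes lse1 = log(d1) and lse2 = log(d2), and that block 1 is shifted by m1 - m2 (called lse1p here) before merging, so that both blocks are expressed relative to m2. Scores and values are hypothetical, and values are scalars to keep the example short.

```python
import math

def logsumexp(xs):
    """Stable log(sum(exp(x))) for a list of floats."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))

X1, V1 = [1.0, 3.0], [10.0, 20.0]   # hypothetical block 1 scores and values
X2, V2 = [5.0], [30.0]              # hypothetical block 2 scores and values

m1 = max(X1)
m2 = max(m1, max(X2))
alpha = math.exp(m1 - m2)

lse1 = logsumexp([x - m1 for x in X1])                   # log(d1)
lse2 = logsumexp([x - m2 for x in X2])                   # log(d2)
lse1p = lse1 + m1 - m2                                   # log(d1 * alpha)
lse12 = lse1 + math.log(alpha + math.exp(lse2 - lse1))   # log(d12)

# Normalised block outputs O1 and O2.
O1 = sum(math.exp(x - m1 - lse1) * v for x, v in zip(X1, V1))
O2 = sum(math.exp(x - m2 - lse2) * v for x, v in zip(X2, V2))

# Merge in log space.
O = O1 * math.exp(lse1p - lse12) + O2 * math.exp(lse2 - lse12)

# Direct softmax-weighted average over all scores, for comparison only.
full_lse = logsumexp(X1 + X2)
expected = sum(math.exp(x - full_lse) * v for x, v in zip(X1 + X2, V1 + V2))
```

The two merge weights exp(lse1p - lse12) and exp(lse2 - lse12) equal d1 × alpha / d12 and d2 / d12, so they always sum to 1.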

When more blocks are added, maintain the running maximum m and the accumulated LSE value using the same recurrence, ensuring stable and memory‑efficient attention computation.
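One way to write that recurrence for an arbitrary number of blocks is sketched below for a single query row. The function name streaming_attention, the use of scalar values, and the choice to compute each block relative to its own maximum before shifting are assumptions of this sketch rather than details given in the article.

```python
import math

def streaming_attention(blocks):
    """blocks: iterable of (scores, values) pairs for one query row.

    Values are scalars here for brevity. Returns the softmax-weighted average.
    """
    m = -math.inf     # running maximum
    lse = -math.inf   # log of the running denominator, relative to m
    O = 0.0           # running normalised output
    for X, V in blocks:
        m_blk = max(X)
        m_new = max(m, m_blk)
        a = [math.exp(x - m_blk) for x in X]      # block relative to its own max
        d_blk = sum(a)
        O_blk = sum(w * v for w, v in zip(a, V)) / d_blk
        lse_blk = math.log(d_blk) + (m_blk - m_new)   # shift block to m_new
        lse_old = lse + (m - m_new)                   # shift accumulator to m_new
        lse_new = math.log(math.exp(lse_old) + math.exp(lse_blk))
        O = O * math.exp(lse_old - lse_new) + O_blk * math.exp(lse_blk - lse_new)
        m, lse = m_new, lse_new
    return O

# Hypothetical toy data: three blocks of scores and scalar values.
blocks = [([1.0, 3.0], [10.0, 20.0]),
          ([5.0], [30.0]),
          ([2.0, 4.0], [40.0, 50.0])]
result = streaming_attention(blocks)
```

Only three scalars per query row (m, lse and the running output) are carried between blocks, which is why memory use does not grow with the number of keys.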

Written by

Baobao Algorithm Notes

Author of the BaiMian large model, offering technology and industry insights.
