Streaming softmax: the recurrence that makes FlashAttention work
A two-state recurrence (log-sum-exp with a running-max correction factor) turned 200K-token context windows from a hardware fantasy into a routine training run.
TL;DR. Softmax has a denominator that depends on every input, which is why you cannot stream it the way you stream a running mean. A two-state recurrence converts it into a streaming algorithm with $O(1)$ memory, no matter how long the input is. Published in this form by Milakov and Gimelshein at NVIDIA in 2018, four years before FlashAttention built a kernel around it.
The puzzle
Suppose I hand you a million numbers one at a time, and at every step ask you to commit to a running answer. If the answer is a plain average, the problem is easy. Hold a sum and a count. When a new number arrives, add it to the sum and bump the count by one. The average at any moment is sum divided by count. Two scalars carry every piece of information that matters, no matter how long the stream gets.
Now change the rule. The function I want at every step is called softmax.
Softmax is a function that takes a list of numbers and turns it into a list of weights: positive numbers that sum to exactly 1. You can read these weights two ways: as a probability distribution over the inputs, or as relative shares, where every input ends up with some share of a fixed total of 1 and bigger inputs claim bigger shares. The formula is

$$\mathrm{softmax}(x)_i = \frac{e^{x_i}}{\sum_j e^{x_j}}.$$
[Interactive figure: eight random logits (top row) softmaxed into eight probabilities (bottom row, always summing to 1). The highlighted column is the winner.]
The numerator exponentiates the input we want a weight for. The denominator exponentiates every input and adds them up. Dividing one by the other guarantees the weights add up to 1.
A worked example. For the input $x = (1, 2, 3)$:

$$e^{1} \approx 2.718,\quad e^{2} \approx 7.389,\quad e^{3} \approx 20.086,\quad \text{sum} \approx 30.19,$$
$$\mathrm{softmax}(x) \approx (0.090,\ 0.245,\ 0.665).$$

Three takeaways from this example. First, the weights are positive and sum to 1, as promised: $0.090 + 0.245 + 0.665 = 1$. Second, the largest input (3) ends up with the largest weight (0.665), about two thirds of the total share, even though its raw value is only 50% bigger than the next input. Softmax sharply favors the biggest input. Third, the gap between consecutive weights is exponential: input 3 is 1 larger than input 2, and gets $e^{1} \approx 2.7$ times its weight; input 3 is 2 larger than input 1, and gets $e^{2} \approx 7.4$ times its weight. A small change in the input means a large change in the weight.
To see the third point in your hands, drag the third value in the figure below. Push it well above the others and softmax sharpens toward a one-hot distribution; pull all three level and the weights collapse to uniform. Take it up to 6 and the last weight absorbs nearly everything; pull it down to 1 and the three weights become roughly equal.
softmax( [ 1, 2, 3.0 ] ) = [ 0.090, 0.245, 0.665 ]
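To check those numbers yourself, a direct numpy translation of the formula (nothing streaming yet) is a few lines:

import numpy as np

def softmax(x):
    """Direct translation of the formula: exponentiate, then normalize."""
    e = np.exp(x)
    return e / e.sum()

print(softmax(np.array([1.0, 2.0, 3.0])))  # [0.090 0.245 0.665] (rounded)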
So softmax converts a list of raw numbers into a list of weights that emphasize the largest entries. With that in hand, here is the puzzle. I hand you the numbers one at a time, and at every step you have to tell me the softmax weight of the number you just saw. Look at the formula again. The denominator sums over every exponentiated input, including ones that have not arrived yet. You cannot give me the weight for the first number until you have seen the last. The memory wall is in the formula itself.
This puzzle is not abstract. It is the bottleneck that shaped transformer attention for half a decade. The small recurrence that resolves it is what FlashAttention is built on (Figure 1).
Where this shows up
Transformers, the architecture behind every modern language model, use a mechanism called attention. When the model processes the word bank in the sentence “she walked along the river bank,” it should pay more weight to river than to walked, because river disambiguates which kind of bank. Attention is how that weighting gets computed.
Concretely, each word is assigned a query vector and a key vector (just lists of numbers). For every pair of words the model computes a similarity score, the dot product $q_i \cdot k_j$. For a sequence of length $n$ this produces an $n \times n$ matrix of scores. Each row of that matrix is passed through softmax to turn the scores into weights that sum to 1, and the weights are used to take a weighted average of a third set of vectors called values.
The bottleneck lives in the shape of that score matrix. At 8,000 tokens (a modest context by 2026 standards), one score matrix holds 64 million numbers. A frontier transformer stacks dozens of these (typically 96 layers), and inside each layer the same computation runs 64 times in parallel under different parameters; these parallel copies are called attention heads. The score matrices alone come to hundreds of billions of numbers per forward pass, orders of magnitude past what fits in the GPU’s fast on-chip caches. They spill into HBM (High Bandwidth Memory, the GPU’s larger but slower main memory), and attention becomes bandwidth-bound: the time it takes is dominated by how fast we can move data, not by how fast we can do the arithmetic.
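The arithmetic behind those counts, using the representative layer and head numbers above (not tied to any particular model):

# Back-of-envelope score-matrix accounting (representative numbers, not a specific model).
n_tokens = 8_000
layers, heads = 96, 64

scores_per_matrix = n_tokens ** 2                         # 64 million scores per head
scores_per_pass = scores_per_matrix * layers * heads
print(f"{scores_per_matrix:.2e} scores per head")         # 6.40e+07
print(f"{scores_per_pass:.2e} scores per forward pass")   # 3.93e+11: hundreds of billions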
FlashAttention is the kernel that removes this bottleneck; the recurrence it is built on comes next.
The recurrence
Two ingredients. The first is a numerical trick most people meet in floating-point arithmetic class. The second is the recurrence that turns it into a streaming algorithm.
Ingredient 1: subtract the max. The expression $e^{x_i} / \sum_j e^{x_j}$ is numerically dangerous. For $x_i = 89$ the numerator already exceeds the limit of float32; for inputs around 1000 you get nan (infinity divided by infinity) even though the answer is well-defined. The fix is algebraic: subtract any constant $c$ from every input and the softmax is unchanged, because the factor $e^{-c}$ appears in both numerator and denominator and cancels.
Picking $c = \max_j x_j$ keeps every exponent at most 0, so the largest term is exactly 1 and nothing overflows. This is the log-sum-exp trick, present in numerical analysis textbooks for decades. Modern attention kernels still subtract the max at every block.
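A quick demonstration of both the failure and the fix, in float32 (the threshold near 89 is specific to float32, whose maximum is about 3.4e38):

import numpy as np

x = np.array([10.0, 1000.0, 1001.0], dtype=np.float32)

# Naive: exp(1000) overflows float32, and the division produces nan.
naive = np.exp(x) / np.exp(x).sum()
print(naive)  # [ 0. nan nan], plus overflow warnings

# Stable: subtract the max first; the largest exponent is exactly 0.
shifted = x - x.max()
stable = np.exp(shifted) / np.exp(shifted).sum()
print(stable)  # [0. 0.269 0.731], finite and correct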
Ingredient 2: turn it into a recurrence. Process the input one element at a time and maintain two running quantities:
- $m_i = \max(x_1, \dots, x_i)$: the running maximum after seeing the first $i$ elements
- $\ell_i = \sum_{j \le i} e^{\,x_j - m_i}$: the running sum of exponentials, each normalized by the current max
When a new element $x_i$ arrives, both quantities update. The max is easy: $m_i = \max(m_{i-1}, x_i)$. The sum requires more care. The old terms inside $\ell_{i-1}$ were normalized by the old max $m_{i-1}$; if the new element is larger, the old max is wrong and every old term needs to be rescaled. The algebra gives the update:

$$\ell_i = \ell_{i-1}\, e^{\,m_{i-1} - m_i} + e^{\,x_i - m_i}.$$

The factor $e^{\,m_{i-1} - m_i}$ is the correction. When the max grows it is less than 1 and shrinks the old contribution by exactly the right amount; when the max does not grow it is 1 and nothing changes. After the last element, $\mathrm{softmax}(x)_k = e^{\,x_k - m_n} / \ell_n$ reconstructs the exact answer using only the two scalars.
The recurrence in this form was published by Milakov and Gimelshein at NVIDIA in 2018 (see the reading list at the end of this post).
Worked example, $x = (1, 2, 3)$ again (Figure 2):
- Step 1. See $x_1 = 1$. Set $m_1 = 1$, $\ell_1 = e^{1-1} = 1$.
- Step 2. See $x_2 = 2$. New max is $m_2 = 2$, old sum is rescaled: $\ell_2 = 1 \cdot e^{1-2} + e^{2-2} \approx 0.368 + 1 = 1.368$.
- Step 3. See $x_3 = 3$. New max is $m_3 = 3$: $\ell_3 = 1.368 \cdot e^{2-3} + e^{3-3} \approx 0.503 + 1 = 1.503$.
- Reconstruct. $\mathrm{softmax}(x) = (e^{1-3},\ e^{2-3},\ e^{3-3})/1.503 \approx (0.090,\ 0.245,\ 0.665)$.
Identical to the textbook computation, with only two scalars in memory at any time.
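The same three steps as a loop you can run on inputs of any length; a minimal numpy sketch of ingredient 2 on its own, before value vectors enter the picture:

import numpy as np

def streaming_softmax(xs):
    """Online softmax: one pass to get (m, l), then emit the weights."""
    m, l = -np.inf, 0.0
    for x in xs:                                          # stream the inputs one at a time
        m_new = max(m, x)
        l = l * np.exp(m - m_new) + np.exp(x - m_new)     # rescale old sum, add new term
        m = m_new
    return np.exp(np.asarray(xs) - m) / l                 # weights from the two scalars

print(streaming_softmax([1.0, 2.0, 3.0]))  # [0.090 0.245 0.665]

The final line re-reads the inputs to emit the full weight vector; attention never needs that vector explicitly, which is the point of the next section.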
From softmax to attention: adding the output
Attention does not just want the softmax weights; it wants the weighted sum of the value vectors. Maintain a third running quantity:
- $o_i = \dfrac{1}{\ell_i}\sum_{j \le i} e^{\,x_j - m_i}\, v_j$: the running output

where $v_j$ is the value vector for element $j$. The update is the same idea: when the max grows or a new element arrives, rescale and add:

$$o_i = o_{i-1}\,\frac{\ell_{i-1}\, e^{\,m_{i-1} - m_i}}{\ell_i} + v_i\,\frac{e^{\,x_i - m_i}}{\ell_i}.$$
The first term carries the old output forward, rescaled so its implicit normalizer matches the new $m_i$ and $\ell_i$. The second term adds the new value, weighted by its softmax probability under the new normalization.
The full state per query row is the triple $(m, \ell, o)$: three scalars (well, two scalars and one short vector), regardless of how long the input is (Figure 3). Process the entire input in a single pass and the triple becomes the final attention output. The score matrix is never materialized.
In FlashAttention this triple lives per-tile rather than per-element: the kernel processes blocks of keys at a time, keeps the running $(m, \ell, o)$ for each query row in fast SRAM, and uses HBM only to stream blocks in and out. The math does not change; only the granularity of “an element” does.
The code, end to end
A clean numpy implementation is short enough to read in one sitting. The streaming version produces output identical to the dense version to within floating-point noise:
import numpy as np
def streaming_attention(Q, K, V):
    """softmax(Q @ K^T) @ V, computed without materializing the score matrix."""
    N, d_v = V.shape
    m = -np.inf        # running max
    l = 0.0            # running normalizer
    o = np.zeros(d_v)  # running output
    for i in range(N):
        x = Q @ K[i]   # one score at a time
        m_new = max(m, x)
        rescale = np.exp(m - m_new)
        l_new = l * rescale + np.exp(x - m_new)
        o = o * (l * rescale / l_new) + V[i] * (np.exp(x - m_new) / l_new)
        m, l = m_new, l_new
    return o

def dense_attention(Q, K, V):
    """Reference: materialize the full score vector."""
    scores = K @ Q
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()
    return weights @ V
rng = np.random.default_rng(0)
N, d, d_v = 1024, 64, 128
Q = rng.standard_normal(d)
K = rng.standard_normal((N, d))
V = rng.standard_normal((N, d_v))
streaming = streaming_attention(Q, K, V)
dense = dense_attention(Q, K, V)
print(f"Max abs diff: {np.max(np.abs(streaming - dense)):.2e}")
# Max abs diff: 2.84e-15
Thirty lines, no GPU, no kernel work, and the recurrence is right there in the loop body. The FlashAttention CUDA kernel does the same arithmetic at the block level, with the running triple held in registers and shared memory rather than Python locals. The minimal CUDA version (referenced at the end of this post) is about 60 lines of device code.
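To see the block-level version of the same arithmetic without touching CUDA, here is a sketch that streams keys and values in tiles, reusing the Q, K, V, and dense variables defined above. The tile size of 128 is an arbitrary illustration, not a tuned FlashAttention parameter:

def blocked_attention(Q, K, V, tile=128):
    """Same recurrence as streaming_attention, but one tile of keys per step."""
    N, d_v = V.shape
    m, l, o = -np.inf, 0.0, np.zeros(d_v)
    for start in range(0, N, tile):
        Kb, Vb = K[start:start + tile], V[start:start + tile]
        x = Kb @ Q                            # a whole tile of scores at once
        m_new = max(m, x.max())
        p = np.exp(x - m_new)                 # the tile's unnormalized weights
        rescale = np.exp(m - m_new)
        l_new = l * rescale + p.sum()
        o = o * (l * rescale / l_new) + (p @ Vb) / l_new
        m, l = m_new, l_new
    return o

print(np.max(np.abs(blocked_attention(Q, K, V) - dense)))  # matches dense to floating-point noise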
Why this matters in production
The accounting is brutal. Dense attention needs $O(n^2)$ memory for the score matrix; streaming attention needs $O(d)$ memory per query row for the running triple, where $d$ is the head dimension and is fixed at training time (usually 64 or 128). For a 1-million-token context at fp16, the dense score matrix is roughly 2 TB per head per layer; the streaming version uses around 256 bytes per query row (Figure 4).
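The two numbers in that sentence come straight from the shapes; fp16 means 2 bytes per number, and the head dimension of 128 is one of the common choices mentioned above:

n, d, bytes_per = 1_000_000, 128, 2                # 1M-token context, head dim 128, fp16

dense_scores = n * n * bytes_per                   # full n-by-n score matrix
o_bytes = d * bytes_per                            # the running output vector o

print(f"dense score matrix: {dense_scores / 1e12:.0f} TB per head per layer")   # 2 TB
print(f"streaming state: {o_bytes} bytes for o, plus two scalars (m and l)")     # 256 bytes and change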
This is the reason long-context models exist at all. Claude’s 200K window, Gemini’s 1M window, and the experimental 10M-token runs published in late 2024 all assume an attention implementation that does not materialize the score matrix. Before FlashAttention, training a transformer at 16K context already required gradient checkpointing tricks; today the same model trains at 200K without those tricks because the bottleneck moved from softmax memory to data movement.
Compute does not change. Dense and streaming attention do the same multiply-adds. What changes is the memory bandwidth the algorithm demands. Streaming attention reads each key once into fast on-chip memory (SRAM, the small but very fast cache sitting right next to the compute units) and discards it; dense attention writes the entire score matrix out to HBM and reads it back, and HBM is roughly an order of magnitude slower to touch than SRAM on an H100. The recurrence converts a memory-bound op into a compute-bound one (Figure 5).
The lineage matters because it is not romantic. The recurrence was not invented for transformers. Milakov and Gimelshein (NVIDIA, 2018) published it for a softmax kernel months before BERT appeared. Rabe and Staats (Google, 2021) adapted it for self-attention. Dao et al. (Stanford, 2022) wrapped it in IO-aware tiling, named it FlashAttention, and earned the spotlight. The headline contribution of FlashAttention is the tiling and the kernel engineering; the load-bearing math is older.
Caveats and further reading
A few honest notes for anyone implementing this.
Backwards pass is harder. The forwards pass is the recurrence above. The backwards pass needs the softmax weights to compute gradients, which means either saving the per-block statistics so the weights can be recomputed during backprop, or recomputing the forward pass entirely. FlashAttention picks the recompute path. This is the part of the paper that took the most engineering and is rarely discussed in summaries.
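A sketch of the statistics-saving idea, in the same numpy register as above: if the forward pass stashes each row's max $m$ and normalizer $\ell$ (in practice folded into a single log-sum-exp value), any block of softmax weights can be rebuilt on demand during backprop without the full matrix ever having been stored. This illustrates the bookkeeping only; it is not FlashAttention's actual backward kernel:

def recompute_weight_block(Q_rows, K_block, m_rows, l_rows):
    """Rebuild one block of softmax weights from saved per-row statistics."""
    scores = Q_rows @ K_block.T                    # (rows, block) scores, recomputed on the fly
    return np.exp(scores - m_rows[:, None]) / l_rows[:, None]

The backward pass then combines these recomputed blocks with the incoming gradient; the cost is an extra matmul per block, which is exactly the recompute-versus-store trade the paper analyzes.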
The recurrence is masking-agnostic. Causal masking, ALiBi, sliding-window, and most attention variants decompose into “compute scores, apply softmax, weight values,” so the streaming form survives all of them. Linear-attention variants (Performer, Linformer, RetNet) change the underlying operation and follow different math entirely.
FlashAttention 2 and 3 are not new math. v2 reworks the parallelism and the work partitioning across thread blocks and warps and cuts non-matmul FLOPs; v3 leans on Hopper-specific hardware features such as asynchronous data movement and FP8. The streaming recurrence underneath is unchanged.
Numerical care in bf16. The correction factor $e^{\,m_{\text{old}} - m_{\text{new}}}$ can underflow in low precision when one new element is much larger than the running max. Production kernels clamp the exponent and use mixed-precision accumulation (compute in bf16, accumulate in fp32) to keep this stable.
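Numpy has no native bf16, but fp16 shows the same failure mode: once the gap between the old and new max exceeds what the format can represent after exponentiation, the correction factor flushes to zero and the old contribution is silently erased, which is why the accumulators stay in fp32:

gap = np.float16(-20.0)            # m_old - m_new, a large jump in the running max
print(np.exp(gap))                 # 0.0 in fp16: the old sum would be wiped out
print(np.exp(np.float32(gap)))     # ~2.06e-09 in fp32: tiny but nonzero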
Three papers and one implementation are worth the read:
- Milakov & Gimelshein, Online normalizer calculation for softmax (2018): the cleanest derivation of the recurrence.
- Rabe & Staats, Self-attention does not need O(n²) memory (2021): applies it to attention.
- Dao et al., FlashAttention (2022): the IO-aware kernel that made it the default.
- github.com/tspeterkim/flash-attention-minimal: a 60-line annotated CUDA implementation worth running.
Worth the time it takes to internalize. The idea is small; the consequences are not.
If you have implemented streaming attention from scratch, or hit a corner of it where the math broke down, what was the hardest part for you? The backwards pass? Tile-size selection? The bf16 numerics around the exp call? I would like to know what surprised you.