Section 3

Attention

Q/K/V projection, the Q · Kᵀ quadratic, softmax, Attn · V (the KV-cache reader), and output projection. Where context length pays its bill.

[Diagram: hidden → W_Q → Q (64 heads) · W_K → K (8 heads) · W_V → V (8 heads) · GQA: K/V are 1/8 the size of Q]

QKV projection

From hidden, three parallel projections produce Q (queries), K (keys), V (values).

Q · bytes/step: 5.4 GB
K · bytes/step: 671.1 MB
V · bytes/step: 671.1 MB

Each transformer block reads the residual hidden vector and projects it through three weight matrices — W_Q, W_K, W_V — to produce queries Q (what each token looks for), keys K (what each token offers), and values V (what each token contributes when matched).

In dense MHA the three matrices are the same shape. In Grouped Query Attention (Llama-3, Qwen-2.5), Q has full width but K and V share heads — Llama-3-70B uses 8 KV heads vs 64 Q heads, so W_K and W_V are 1/8 the size of W_Q. That asymmetry is one of the largest single savings any modern model makes; it pulls the KV cache down by the same factor and is what makes long-context decode tractable.

Try it: switch to a model with a different num_kv_heads and watch the K/V meters shrink.

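A back-of-envelope sketch of the projection accounting, in Python. The Llama-3-70B-like shapes (hidden 8192, 64 Q heads, 8 KV heads, head_dim 128, 80 layers) and the fp8 weights are illustrative assumptions chosen to line up with the meters above:

```python
# Hypothetical Llama-3-70B-like shapes; adjust to the model you care about.
hidden, head_dim, layers = 8192, 128, 80
num_heads, num_kv_heads = 64, 8
bytes_per_param = 1  # assuming fp8 weights

w_q = hidden * num_heads * head_dim      # per-layer W_Q params
w_k = hidden * num_kv_heads * head_dim   # 1/8 of W_Q under GQA
w_v = hidden * num_kv_heads * head_dim

for name, per_layer in [("Q", w_q), ("K", w_k), ("V", w_v)]:
    total = per_layer * layers
    print(f"W_{name}: {total / 1e9:.2f} B params, "
          f"{total * bytes_per_param / 1e9:.2f} GB streamed per decode step")
# W_Q ≈ 5.37 B params / 5.37 GB per step; W_K and W_V ≈ 0.67 B / 0.67 GB each
```
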
[Diagram: Q · Kᵀ → N scores · no weights · KV bytes scale with N]

Q · Kᵀ scores

No weights, but quadratic in N — every query reads the whole KV cache.

FLOPs: 983.0 M
KV bytes: 61.4 MB

Each query Q is dotted with every key K already in the cache, producing one score per cached token. There are no weights here — but the work is **N × N** at prefill (the full causal score matrix) and **1 × N** per decode step (one query against N cached keys).

This is where transformer attention's *quadratic* lives. Doubling the context doubles the per-step KV bytes streamed *and* the per-step FLOPs at decode, and quadruples the total score work at prefill. For a 128k-token context, the KV cache read at decode dominates everything — even the FFN.

GQA helps: with 8 KV heads instead of 64, the KV cache shrinks 8× and so does the read here. MLA (DeepSeek-V3) compresses further by sharing a small latent KV across heads.

Try it: increase output_tokens and watch the per-step KV bytes climb.

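Numbers in this ballpark fall out of the same kind of sketch; the cache length N = 750 here is an illustrative assumption, not a property of any particular model:

```python
# Same Llama-3-70B-like shapes as before; N is the current cache length.
num_heads, num_kv_heads, head_dim, layers = 64, 8, 128, 80
kv_bytes_per_elt = 1   # assuming an fp8 KV cache
N = 750                # illustrative: prompt + tokens generated so far

flops = 2 * num_heads * head_dim * N * layers                      # one dot product per cached key, per head
k_read = num_kv_heads * head_dim * N * kv_bytes_per_elt * layers   # the whole K cache, every step

print(f"score FLOPs per step: {flops / 1e6:.1f} M")     # ~983.0 M
print(f"K cache read per step: {k_read / 1e6:.1f} MB")  # ~61.4 MB, and it grows linearly with N
```
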
[Diagram: raw scores → softmax → probabilities (Σ = 1) · 0 weights · ~free FLOPs]

Softmax

Scores normalize row-wise into probabilities — free cost-wise, central to attention.

FLOPs: 0

For each query Q, softmax takes the row of raw scores from Q · Kᵀ and converts it into a probability distribution: subtract the max, exponentiate, divide by the sum. The result is one weight per cached token, summing to 1.

The work is tiny — a few FLOPs per score, no parameters, no HBM weight reads. In cost terms the phase is essentially *free*; on any roofline plot it sits near the origin.

But it is where attention does its job. The shape of the resulting distribution decides which past tokens contribute and which are ignored. Fused implementations (Flash attention) merge softmax with the surrounding multiplies to avoid materializing the full N × N score matrix in HBM — a memory savings, not a compute one.

Try it: this phase will stay near zero across every scenario.

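A minimal sketch of that recipe (subtract the max, exponentiate, normalize), using NumPy:

```python
import numpy as np

def stable_softmax(scores: np.ndarray) -> np.ndarray:
    """One row of raw scores (one per cached token) -> attention weights."""
    shifted = scores - scores.max()   # subtract the max for numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum()            # weights over cached tokens, summing to 1

probs = stable_softmax(np.array([2.0, 0.5, -1.0, 3.0]))
print(probs, probs.sum())             # the distribution sums to 1.0
```
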
[Diagram: probs × V → Σ → output · second quadratic · V cache read again]

Attn · V

Probabilities weight every cached V vector — the V cache is read all over again.

FLOPs: 983.0 M
KV bytes: 61.4 MB

The probability row from softmax meets the cached V vectors: each cached token contributes its V scaled by its attention weight, and the model sums these across all cached tokens to produce one attended output vector per query.

This is the second half of attention, symmetric with Q · Kᵀ. No weights, no parameters — but **the bytes streamed are real**: the V cache must be read top-to-bottom, just like the K cache was read for scoring. So Q · Kᵀ and Attn · V each contribute the same KV-cache traffic per step, and they sum.

Flash attention fuses this with softmax and Q · Kᵀ into a single tiled pass over the cache, so the score matrix never has to round-trip through HBM.

Try it: change kv_cache quantization (fp8 → fp16) and watch the KV bytes double.

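A toy single-head decode step makes the symmetry concrete; the sizes are arbitrary, and real kernels run this per head and per layer:

```python
import numpy as np

N, head_dim = 750, 128
rng = np.random.default_rng(0)
q = rng.standard_normal(head_dim)        # this step's query
K = rng.standard_normal((N, head_dim))   # cached keys
V = rng.standard_normal((N, head_dim))   # cached values

scores = K @ q / np.sqrt(head_dim)       # reads the K cache (scoring)
probs = np.exp(scores - scores.max())
probs /= probs.sum()
out = probs @ V                          # reads the V cache: the same bytes all over again

print(out.shape)                         # (128,) -> one attended vector for this query
```
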
[Diagram: 64 heads → concat → W_O (hidden × hidden) → hidden · same shape as W_Q, same cost]

O projection

Heads are concatenated and projected back to hidden — attention's closing matmul.

Params: 5.37 B
FLOPs: 10.7 B
Bytes/step: 5.4 GB

Each attention head produces an attended output of size head_dim. The full attention sublayer produces num_heads × head_dim values per token; the O projection concatenates them and projects back to hidden_dim through a single weight matrix W_O of shape [hidden × hidden].

Cost-wise W_O is identical to W_Q: same parameter count, same FLOPs per token, same bytes streamed per decode step. In Llama-3-70B that's ~5.4 GB of weight bytes per token at fp8, on top of every other matmul.

This is the closing matmul of attention. Whatever attention computed gets re-mixed across the hidden vector and added back into the residual stream — passing the rich attention signal into the FFN that follows.

Try it: this phase will track Q proj exactly as you change scenarios — same shape.

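A toy sketch of the closing matmul; the sizes here are deliberately small, but in Llama-3-70B the concatenated width and the hidden width are both 8192, so W_O is the same [hidden × hidden] shape as W_Q:

```python
import numpy as np

num_heads, head_dim = 4, 8               # toy sizes (Llama-3-70B: 64 and 128)
hidden = num_heads * head_dim

head_outputs = np.random.randn(num_heads, head_dim)   # one attended vector per head
W_O = np.random.randn(hidden, hidden)                 # [hidden × hidden], same shape as W_Q

concat = head_outputs.reshape(hidden)    # concatenate the heads
out = concat @ W_O                       # re-mix across hidden, ready for the residual add

print(out.shape)                         # (hidden,)
```
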
[Diagram: MHA vs GQA · MHA: 8 Q, 8 KV heads · GQA: 8 Q, 2 KV heads]

Grouped Query Attention

Many Q heads share a few K/V heads — the KV cache shrinks by num_heads / num_kv_heads.

K params: 671.1 M
KV / step: 61.4 MB

Multi-Head Attention dedicates one K head and one V head to every Q head: a Llama-2-style dense model at this scale would carry 64 of each. Grouped Query Attention shares a single K/V pair across a group of Q heads — Llama-3-70B uses 64 Q heads but only 8 KV heads, an 8:1 ratio.

The savings show up everywhere KV is involved: the KV cache shrinks 8×, the K/V projection weights shrink 8×, and per-step KV bandwidth at decode shrinks 8×. Long-context decode at scale is essentially impossible without it.

Multi-Query Attention (MQA, num_kv_heads = 1) is the extreme version; GQA is the compromise that preserves quality. MLA goes further still by caching a small latent rather than per-head K/V.

Try it: switch to a model with a different num_kv_heads (Mistral-7B: 8, Llama-3-8B: 8) and watch the K/V meters shift.

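The per-token arithmetic, sketched with the same illustrative Llama-3-70B-like shapes and an fp8 cache:

```python
layers, head_dim, bytes_per_elt = 80, 128, 1   # assuming an fp8 KV cache

def kv_bytes_per_token(num_kv_heads: int) -> int:
    # K and V, for every layer and every KV head
    return 2 * num_kv_heads * head_dim * bytes_per_elt * layers

mha = kv_bytes_per_token(64)   # one KV head per Q head
gqa = kv_bytes_per_token(8)    # Llama-3-70B's 8 KV heads

print(f"MHA: {mha / 1e3:.0f} KB/token · GQA: {gqa / 1e3:.0f} KB/token · ratio {mha / gqa:.0f}x")
# MHA: 1311 KB/token · GQA: 164 KB/token · ratio 8x
```
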
[Diagram: full K/V → latent (cached) → K/V reconstructed on demand]

MLA · latent KV

Cache a small latent per token; decompress K/V on demand. KV shrinks again.

KV-A params: 0

Multi-head Latent Attention (MLA), introduced in DeepSeek-V2/V3, takes the GQA idea further. Instead of caching per-head K/V vectors, MLA caches a single low-rank *latent* per token; K and V are decompressed on demand at attention time via two small projection matrices.

The KV cache shrinks by roughly another 3.5× vs GQA — DeepSeek-V3 caches roughly 70 KB per token vs ~250 KB for an equivalent GQA model. That makes long-context inference dramatically cheaper at the cost of a few extra small matmuls per step.

The trade-off lives in the q_a/q_b/kv_a/kv_b projections — they're new compute that GQA doesn't have, but they're tiny relative to the KV bandwidth they save.

Try it: switch to deepseek-v3 to see the MLA-specific phases populate.

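A rough reconstruction of those per-token figures, assuming DeepSeek-V3-like MLA dimensions (61 layers, a 512-wide latent plus a 64-wide decoupled RoPE component per token per layer) and a 2-byte cache; treat the exact values as illustrative:

```python
layers, kv_lora_rank, rope_dim, bytes_per_elt = 61, 512, 64, 2   # assumed DeepSeek-V3-like dims

mla_per_token = (kv_lora_rank + rope_dim) * layers * bytes_per_elt
gqa_per_token = 2 * 8 * 128 * layers * bytes_per_elt   # a GQA model of similar depth: 8 KV heads, head_dim 128

print(f"MLA latent cache: ~{mla_per_token / 1e3:.0f} KB/token")   # ~70 KB
print(f"GQA equivalent:   ~{gqa_per_token / 1e3:.0f} KB/token")   # ~250 KB
```
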
[Diagram: naive vs Flash · naive: HBM · Flash: SRAM]

Flash attention

Tile the score matrix into SRAM-sized blocks — same FLOPs, far less HBM traffic.

KV / step (logical): 61.4 MB

A naive attention implementation writes the full N × N score matrix to HBM, then reads it back for softmax, then again for the V multiply. That's three round-trips through the slow memory. For 128k contexts the score matrix alone is huge, and the bandwidth cost dwarfs the FLOPs.

Flash attention (Dao et al., 2022; v2 and v3 since) tiles the operation: load a block of Q rows and K columns into on-chip SRAM, compute the partial scores, run a numerically-stable softmax against running statistics, multiply against a block of V, accumulate the output. Repeat for the next block.

Same FLOPs. No materialised N × N matrix. HBM traffic drops to O(N · d) instead of O(N²). Almost every modern transformer kernel uses some Flash variant.

Try it: this is a kernel-level optimization — the library's per-phase numbers count "logical" attention work, but real-world HBM traffic on bandwidth-bound decode is much smaller with Flash kernels.

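A minimal single-query sketch of the tiling idea in NumPy. Real Flash kernels block over queries too, keep the working set in SRAM, and fuse everything into one GPU kernel; this only shows the running-statistics trick that avoids materializing the score row:

```python
import numpy as np

def tiled_attention(q, K, V, block=128):
    """One decode-step query against cached K/V, processed block by block."""
    d = q.shape[0]
    m = -np.inf              # running max of scores seen so far
    l = 0.0                  # running sum of exp(score - m)
    acc = np.zeros_like(q)   # running weighted sum of V rows

    for start in range(0, K.shape[0], block):
        k_blk, v_blk = K[start:start + block], V[start:start + block]
        s = k_blk @ q / np.sqrt(d)        # partial scores for this block only
        m_new = max(m, s.max())
        scale = np.exp(m - m_new)         # rescale what was accumulated so far
        p = np.exp(s - m_new)
        l = l * scale + p.sum()
        acc = acc * scale + p @ v_blk
        m = m_new
    return acc / l

N, d = 1024, 64
rng = np.random.default_rng(0)
q, K, V = rng.standard_normal(d), rng.standard_normal((N, d)), rng.standard_normal((N, d))

# Reference: the naive version that materializes the full score row.
s = K @ q / np.sqrt(d)
p = np.exp(s - s.max()); p /= p.sum()
print(np.allclose(tiled_attention(q, K, V), p @ V))   # True: same result, no full score row kept
```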