attention_forward — Causal Multi-Head Attention with Online Softmax
File: 05_attention.wgsl
Pipeline step: The heart of the layer. RoPE'd Q and K + V → contextual mixing.
Two kernels:
- attention_forward: training (full sequence)
- attention_decode: inference (single query against KV cache)
Wait, what is this?
You join a group chat late and 200 messages have piled up. To make sense of the current message, you scroll back and read the old ones. But you don't read them all as equally important: a few are right on topic, the rest is chit-chat. In your head you quietly assign each old message a weight — "very relevant", "ignore this" — then blend that weighted mix into what the current message actually means. That's exactly what attention does.
Each token (word piece) puts out a "query": "who should I be listening to?" The other tokens each offer a "key": "I'm relevant if you're after this topic." Matching the query against the keys (a dot product) gives a score, softmax turns those scores into weights that sum to 1, and each token's "value" gets blended by those weights to form the output. So it's a fuzzy lookup: instead of pulling one row by one key out of a hash map, you take a weighted sum over all the rows by how similar they are.
There's one rule: it's causal. A token can only look at itself and the tokens before it, never the ones after, just like you can't read a message in the group chat that hasn't been typed yet. Looking ahead is off-limits because the model is learning to predict the next word; we can't let it peek at the answer.
The hard part: to work out those weights (softmax), you'd normally write all the scores to memory and scan them three times — which gets expensive as the sequence grows. Online softmax instead reads the messages in a single pass, as they stream by, and keeps correcting itself as it goes. Below we'll look at how all this is squeezed into one pass on the GPU.
What Does It Do?
Causal scaled dot-product attention with GQA (Grouped-Query Attention) and online softmax:
score[s, k] = (Q[s] · K[k]) / sqrt(head_dim)
P[s, k] = softmax_k(score[s, k]) (k ≤ s, causal)
O[s] = Σ_k P[s, k] · V[k]
This is the model's contextual mixing mechanism. Token s computes how much attention it should pay to each token k ≤ s (the probability P[s, k]), then produces its output as a weighted sum of the V rows.
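The three formulas can be written out as a naive reference for a single head. This is a minimal sketch in Python rather than the kernel's WGSL; the function name `causal_attention` and the list-of-rows layout are assumptions made for illustration, and GQA and document masking are left out.

```python
import math

def causal_attention(Q, K, V):
    """Naive causal attention for one head. Q, K, V are lists of rows (lists of floats)."""
    head_dim = len(Q[0])
    scale = 1.0 / math.sqrt(head_dim)
    out = []
    for s, q in enumerate(Q):
        # score[s, k] for k <= s only (causal mask).
        scores = [scale * sum(a * b for a, b in zip(q, K[k])) for k in range(s + 1)]
        # Numerically stable softmax over the visible keys.
        m = max(scores)
        weights = [math.exp(x - m) for x in scores]
        total = sum(weights)
        probs = [w / total for w in weights]
        # O[s] = sum_k P[s, k] * V[k]
        out.append([sum(p * V[k][d] for k, p in enumerate(probs))
                    for d in range(len(V[0]))])
    return out

Q = [[1.0, 0.0], [0.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0]]
O = causal_attention(Q, K, V)
```

Position 0 can only see itself, so `O[0]` is exactly `V[0]`; position 1 blends both value rows, leaning toward `V[1]` because its query matches `K[1]`.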
Online Softmax — Why?
Classic softmax takes 3 passes:
1. m = max(scores): scan all scores
2. exp_sum = Σ exp(score - m): scan again
3. P = exp(score - m) / exp_sum: scan again
If the sequence is long (e.g. 1024) this means 3× memory traffic + scratch buffer for scores. Online softmax turns this into a single pass, with the scratch buffer becoming O(WG + head_dim), not O(seq_len).
GQA (Grouped Query Attention)
There are n_heads query heads but only n_kv_heads < n_heads key/value heads. Multiple Q heads share the same K/V group:
n_heads = 12, n_kv_heads = 4
→ q_per_group = 3
→ Query head h uses kv_head h/3
Advantage: the KV cache is 1/q_per_group the size. In our model n_heads=12 / n_kv=4 → KV is ⅓ the size. Little quality loss, large memory savings.
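The head-to-group mapping is plain integer division. A small sketch, assuming n_heads is a multiple of n_kv_heads (the helper name `kv_head_for` is hypothetical):

```python
def kv_head_for(h, n_heads, n_kv_heads):
    """Return the K/V head that query head h reads from under GQA."""
    q_per_group = n_heads // n_kv_heads
    return h // q_per_group

n_heads, n_kv_heads = 12, 4
mapping = [kv_head_for(h, n_heads, n_kv_heads) for h in range(n_heads)]
# Fraction of the full multi-head KV size that GQA actually stores.
kv_fraction = n_kv_heads / n_heads
```

With the document's configuration, `mapping` groups the twelve query heads in threes and `kv_fraction` is one third.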
Online Softmax Algorithm — Very Important
It processes the K-blocks within the sequence in order and updates the state incrementally:
state: m (running max), l (running exp-sum), acc[d] (running V·P)
init: m = -∞, l = 0, acc[d] = 0
for each block of K (size WG=256):
1. compute scores = (Q · K_block[k]) * scale for each k in block
2. block_max = max(scores in block)
3. m_new = max(m, block_max)
4. P[k] = exp(score[k] - m_new)
5. block_sum = Σ P[k]
6. correction = exp(m - m_new) ← rescales the old state
7. l ← l * correction + block_sum
8. acc[d] ← acc[d] * correction + Σ_k P[k] · V[k, d]
9. m ← m_new
final:
O[d] = acc[d] / l
LSE = m + log(l) ← saved for backward
The correction = exp(m_old - m_new) term renormalizes the old state with respect to m_new whenever m_new > m_old (when the max does not change, correction = 1). This keeps the computation stable and single-pass.
LSE = log(Σ exp(scores)) is saved for backward — in the backward pass we reuse it and avoid recomputation.
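The update rules can be checked against a plain softmax on the CPU. The sketch below follows steps 1 to 9 for one query row in Python; the names `online_softmax_attend` and `reference`, the tiny block size of 4, and the sample data are all assumptions for illustration, not part of the kernel.

```python
import math

def online_softmax_attend(scores, values, block=4):
    """Single pass over precomputed scores and value rows, one block at a time."""
    m = float("-inf")              # running max
    l = 0.0                        # running exp-sum
    acc = [0.0] * len(values[0])   # running sum of P * V
    for start in range(0, len(scores), block):
        blk = scores[start:start + block]
        m_new = max(m, max(blk))
        p = [math.exp(x - m_new) for x in blk]
        correction = math.exp(m - m_new)   # 0.0 on the first block (m is -inf)
        l = l * correction + sum(p)
        for d in range(len(acc)):
            acc[d] = acc[d] * correction + sum(
                p[i] * values[start + i][d] for i in range(len(p)))
        m = m_new
    return [a / l for a in acc], m + math.log(l)

def reference(scores, values):
    """Classic multi-pass softmax followed by the weighted sum."""
    m = max(scores)
    w = [math.exp(x - m) for x in scores]
    total = sum(w)
    out = [sum(w[k] * values[k][d] for k in range(len(scores))) / total
           for d in range(len(values[0]))]
    return out, m + math.log(total)

scores = [0.5, -1.0, 2.0, 0.1, 3.5, -0.2, 1.0, 0.0, 2.2, -3.0]
values = [[float(k), float(k * k)] for k in range(len(scores))]
out_online, lse_online = online_softmax_attend(scores, values, block=4)
out_ref, lse_ref = reference(scores, values)
```

Both paths should agree to floating-point rounding, including the case here where the running max changes between blocks.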
Bind Group ABI
attention_forward (8 bindings)
| Binding | Type | Detail |
|---|---|---|
| 0 | storage, read | Q: array<f32> — [seq × n_heads × head_dim] |
| 1 | storage, read | K: array<f32> — [seq × n_kv × head_dim] (GQA: smaller) |
| 2 | storage, read | V: array<f32> — [seq × n_kv × head_dim] |
| 3 | storage, read_write | O: array<f32> — [seq × n_heads × head_dim] |
| 4 | storage, read_write | LSE: array<f32> — [seq × n_heads] saved for backward |
| 5 | uniform | dims: vec4<u32> — (seq_len, n_heads, n_kv_heads, head_dim) |
| 6 | uniform | params: vec4<f32> — (scale, _, _, _) |
| 7 | storage, read | seg: array<u32> — the document start index of each query (cross-doc mask) |
scale = 1 / sqrt(head_dim) is computed on the host.
attention_decode (6 bindings)
For decode, Q is a single row ([n_heads × head_dim]), K/V are read from the cache, and there is no LSE save.
| Binding | Type | Detail |
|---|---|---|
| 0 | storage, read | Q_one: array<f32> — [n_heads × head_dim] (single row) |
| 1 | storage, read | K_cache: array<f32> — [seq × n_kv × head_dim] (cache of past keys) |
| 2 | storage, read | V_cache: array<f32> — [seq × n_kv × head_dim] (cache of past values) |
| 3 | storage, read_write | O_one: array<f32> — [n_heads × head_dim] (single output row) |
| 4 | uniform | dims_dec: vec4<u32> — (cache_len, n_heads, n_kv_heads, head_dim) |
| 5 | uniform | params_dec: vec4<f32> — (scale, _, _, _) |
Workgroup Memory Layout
var<workgroup> sh_q: array<f32, 256>; // Q row cache
var<workgroup> sh_p: array<f32, 256>; // current K-block scores → P values
var<workgroup> sh_acc: array<f32, 256>; // running output accumulator
var<workgroup> sh_partial: array<f32, 256>; // V·P partial sum reduction scratch
var<workgroup> sh_max: array<f32, 1>; // running max m (single value)
var<workgroup> sh_lse: array<f32, 1>; // running sum-exp l (single value)
Total: ~4 KB workgroup memory. That is about 12% of the Apple GPU's 32 KB limit, which leaves comfortable occupancy.
MAX_HEAD_DIM = 256 constraint: head_dim cannot exceed 256. In our model head_dim=64 (12 heads × 64 = 768 d_model), with plenty of room to spare.
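The memory budget is simple arithmetic over the declared arrays. A short sketch, taking the 32 KB limit from the text as given (the dictionary and variable names are illustrative):

```python
BYTES_PER_F32 = 4

# Element counts of the var<workgroup> arrays declared above.
arrays = {
    "sh_q": 256, "sh_p": 256, "sh_acc": 256, "sh_partial": 256,
    "sh_max": 1, "sh_lse": 1,
}
total_bytes = sum(arrays.values()) * BYTES_PER_F32
limit_bytes = 32 * 1024            # limit quoted in the text
fraction = total_bytes / limit_bytes
```

The four large arrays account for 4096 bytes and the two scalars add 8 more, which lands at roughly one eighth of the limit.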
Dispatch Shape
workgroup_size: 256
grid: (seq_len, n_heads, 1) workgroups
One workgroup = one (s, h) pair, i.e. one query position for one head.
Example (seq=512, n_heads=12):
- 6144 WG × 256 threads = 1.57M threads
- Each WG iterates up to ceil(causal_len / 256) blocks of K
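The numbers in the example can be reproduced directly. This sketch assumes causal_len = s + 1 (as used later in the kernel) and ignores the skipping of pre-document blocks, so `k_blocks` is an upper bound; both function names are hypothetical.

```python
WG = 256

def dispatch(seq_len, n_heads):
    """Workgroup count and total thread count for a (seq_len, n_heads, 1) grid."""
    workgroups = seq_len * n_heads
    return workgroups, workgroups * WG

def k_blocks(s):
    """Upper bound on K-blocks visited by the workgroup for query position s."""
    causal_len = s + 1
    return (causal_len + WG - 1) // WG

workgroups, threads = dispatch(512, 12)
```

Late query positions do more work than early ones: position 0 touches one block, position 511 touches two.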
Line by Line — attention_forward
1) Decode workgroup ID
let s = wgid.x; // query position
let h = wgid.y; // attention head
// ...
let q_per_group = n_heads / n_kv_heads;
let kv_h = h / q_per_group; // GQA: query head h uses kv_h
GQA mapping. h=0,1,2 → kv_h=0; h=3,4,5 → kv_h=1; etc.
2) Phase 0 — Cache Q, init state
let q_row_base = s * d_model + h * head_dim;
let causal_len = s + 1u; // attend over [0, s]
let doc_start = seg[s]; // start index of the document this query belongs to
- causal_len = s + 1u → this query position can attend to all preceding positions including itself, but not to subsequent ones.
- doc_start = seg[s] → cross-document masking: query s can only attend to keys k within its own document boundaries. Keys before the document start (k < doc_start) are masked out.
- sh_q[i] = Q[...]: pull a head_dim-sized copy of the Q row into workgroup memory. In the inner loop every thread reads the same Q row, so it is loaded once and read many times.
3) K-block parallelization parameters
let k_par = max(1u, WG / max(1u, head_dim)); // 4 when head_dim=64
let n_active = head_dim * k_par; // 256 when head=64,k_par=4
let d_local = tid / k_par; // d ∈ [0, head_dim)
let kc = tid % k_par; // chunk index ∈ [0, k_par)
let chunk_size = (WG + k_par - 1u) / k_par; // 64
In Phase D (V·P partial), interpret the WG as a 2D (d_local, kc) mesh: d_local is the output dimension (0..head_dim-1) and kc is the chunk index (0..k_par-1).
If head_dim < WG, multiple threads compute different chunks for the same d in parallel.
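To see how thread IDs map onto the (d_local, kc) mesh, the same integer arithmetic can be replayed in Python. The function name `layout` is an assumption; the formulas mirror the WGSL lines above.

```python
WG = 256

def layout(head_dim):
    """Replay the k_par / n_active / chunk_size arithmetic and the tid -> (d_local, kc) map."""
    k_par = max(1, WG // max(1, head_dim))
    n_active = head_dim * k_par
    chunk_size = (WG + k_par - 1) // k_par
    threads = [(tid // k_par, tid % k_par) for tid in range(n_active)]
    return k_par, n_active, chunk_size, threads

k_par, n_active, chunk_size, threads = layout(64)
```

For head_dim = 64 every one of the 256 threads is active and each (d_local, kc) pair is owned by exactly one thread. For a head_dim that does not divide 256 evenly, such as 96, some threads stay idle in Phase D (n_active = 192).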
4) K-block iteration (Skipping the Pre-Document Region)
let first_blk = doc_start / WG;
let n_blocks = (causal_len + WG - 1u) / WG;
for (var blk = first_blk; blk < n_blocks; blk = blk + 1u) {
let k_start = blk * WG;
let k = k_start + tid;
first_blk = doc_start / WG → K-blocks that lie entirely before the document start are skipped. This saves memory reads and guarantees that the first processed block contains at least one valid (unmasked) key, namely k = doc_start. Without that guarantee the running max would stay at -∞ after the first block, and the running-max correction would evaluate exp(-∞ - (-∞)) = NaN.
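Both claims, the block range and the NaN hazard, can be illustrated on the CPU. A sketch assuming doc_start ≤ s (a query always lies inside its own document); `block_range` and `first_block_has_valid_key` are hypothetical helper names.

```python
import math

WG = 256

def block_range(s, doc_start):
    """Return (first_blk, n_blocks) exactly as the kernel computes them."""
    causal_len = s + 1
    first_blk = doc_start // WG
    n_blocks = (causal_len + WG - 1) // WG
    return first_blk, n_blocks

def first_block_has_valid_key(s, doc_start):
    """True if the first visited block holds a key with doc_start <= k <= s."""
    first_blk, _ = block_range(s, doc_start)
    k_lo = first_blk * WG
    return any(doc_start <= k <= s for k in range(k_lo, k_lo + WG))

# What m_old - m_new would be if the first block were fully masked (IEEE-754 arithmetic).
nan_hazard = float("-inf") - float("-inf")
```

For s = 600 and doc_start = 300 the loop starts at block 1 and stops before block 3, so block 0 (keys 0 to 255) is never read.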
5) Phase A — Score computation (Masked Positions)
var score: f32 = NEG_INF;
if (k >= doc_start && k < causal_len) {
let k_row_base = k * kv_dim + kv_h * head_dim;
var dot: f32 = 0.0;
for (var d: u32 = 0u; d < head_dim; d = d + 1u) {
dot = fma(sh_q[d], K[k_row_base + d], dot);
}
score = dot * scale;
}
Each thread computes a score for one k:
- Q (from workgroup mem) ⋅ K[k] (from global mem)
- head_dim FMAs
- Scale = 1/sqrt(head_dim) (LLaMA-style)
The condition if (k >= doc_start && k < causal_len) applies both document masking and the causal mask in a single line. Positions that fail the condition keep score = NEG_INF, so they contribute exp(NEG_INF) = 0 to the softmax.
Memory pattern: adjacent threads read adjacent K rows (k = 0, 1, 2, ...) at K[k * kv_dim + kv_h * head_dim + d]. For the same d, adjacent threads therefore touch addresses kv_dim elements apart, so the loads are not coalesced. This is the performance handicap of attention forward; a flash-attention style forward (perf-doc #1) could fix it.
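The stride is easy to see by listing the element offsets that neighbouring threads touch at the same loop step d. This sketch assumes kv_dim = n_kv_heads × head_dim, which matches the K layout [seq × n_kv × head_dim] but is not shown explicitly in the kernel excerpt; `k_offsets` is a hypothetical helper.

```python
def k_offsets(tids, d, kv_h, n_kv_heads, head_dim, k_start=0):
    """Element offsets into K read by each thread at inner-loop step d."""
    kv_dim = n_kv_heads * head_dim   # assumed row stride of K, in elements
    return [(k_start + tid) * kv_dim + kv_h * head_dim + d for tid in tids]

offsets = k_offsets(range(4), d=0, kv_h=0, n_kv_heads=4, head_dim=64)
strides = [b - a for a, b in zip(offsets, offsets[1:])]
stride_bytes = strides[0] * 4        # f32 elements
```

With the document's model (n_kv = 4, head_dim = 64) neighbouring threads are 256 elements, or 1 KiB, apart on every load.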
6) Phase B — Block max + running max
let block_max = wg_reduce_max(tid, score);
let m_old = sh_max[0];
let m_new = max(m_old, block_max);
wg_reduce_max returns the max across all 256 threads; every thread receives the same value.
7) Phase C — P[k] + block sum + correction
var p_val: f32 = 0.0;
if (k < causal_len) {
p_val = exp(score - m_new);
}
sh_p[tid] = p_val;
let block_sum = wg_reduce_sum(tid, p_val);
let correction = exp(m_old - m_new);
let l_old = sh_lse[0];
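let l_new = l_old * correction + block_sum;
p_val = exp(score - m_new) is the unnormalized softmax weight. m_new is the max seen so far, so exp(score - m_new) ≤ 1, which keeps the computation numerically stable. Keys before doc_start pass the k < causal_len check but carry score = NEG_INF, so their p_val is still 0.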
correction = exp(m_old - m_new):
- First iteration: m_old = -∞, so correction = exp(-∞ - m_new) → 0 (zero out the old state, since there was no information yet)
- Subsequent iterations: if m grows, correction < 1, down-scaling the old accumulators
8) Phase D — V·P contribution
if (tid < n_active) {
let kl_start = kc * chunk_size;
let kl_end_raw = kl_start + chunk_size;
let kl_end = select(kl_end_raw, WG, kl_end_raw > WG);
var partial: f32 = 0.0;
for (var kl = kl_start; kl < kl_end; kl = kl + 1u) {
let kk = k_start + kl;
if (kk < causal_len) {
partial = fma(sh_p[kl], V[kk * kv_dim + kv_h * head_dim + d_local], partial);
}
}
sh_partial[d_local * k_par + kc] = partial;
}
workgroupBarrier();
if (kc == 0u && tid < n_active) {
var sum_pv: f32 = sh_partial[d_local * k_par];
for (var i = 1u; i < k_par; i = i + 1u) {
sum_pv = sum_pv + sh_partial[d_local * k_par + i];
}
sh_acc[d_local] = sh_acc[d_local] * correction + sum_pv;
}
Phase D has two steps:
- Each (d_local, kc) thread computes the partial Σ P[kl] · V[k_start+kl, d_local] over its own chunk
- After the barrier, the kc == 0u threads sum the k_par partials for their d_local and write the result to sh_acc[d_local]
Why two steps? If each output dimension summed Σ_kl in a single thread, only head_dim threads would do work while the rest of the workgroup sat idle. Splitting the sum lets k_par threads (4 here) produce partials in parallel before one thread merges them, which is faster.
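The split-then-merge structure can be simulated sequentially for one output dimension. This is a sketch of the arithmetic only (no real parallelism); `phase_d` and the sample data are assumptions, and the block length stands in for WG.

```python
def phase_d(p, v_col, k_par):
    """p: block weights; v_col: V[k, d] over the same block for one fixed d."""
    wg = len(p)
    chunk_size = (wg + k_par - 1) // k_par
    partials = []
    for kc in range(k_par):                    # step 1: one partial per chunk "thread"
        lo = kc * chunk_size
        hi = min(lo + chunk_size, wg)          # same clamp as the select() in the kernel
        partials.append(sum(p[kl] * v_col[kl] for kl in range(lo, hi)))
    return sum(partials)                       # step 2: the kc == 0 thread merges

p = [0.5, 0.25, 0.125, 0.125, 1.0, 0.0, 0.0, 0.5]
v = [1.0, 2.0, 4.0, 8.0, 1.0, 3.0, 5.0, 2.0]
direct = sum(a * b for a, b in zip(p, v))
```

Whatever k_par is, the merged result equals the direct sum; only the amount of work per thread changes.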
9) Update m, l
if (tid == 0u) {
sh_max[0] = m_new;
sh_lse[0] = l_new;
}
workgroupBarrier();
Only tid 0 writes: a single writer, so there is no race. After the barrier all threads see the updated m/l.
10) Final O and LSE
let l_final = max(sh_lse[0], 1e-30);
let inv_l = 1.0 / l_final;
if (tid < head_dim) {
O[s * d_model + h * head_dim + tid] = nan_guard(sh_acc[tid] * inv_l);
}
if (tid == 0u) {
LSE[s * n_heads + h] = sh_max[0] + log(l_final);
}
max(l, 1e-30) is division-by-zero protection. l could only stay 0 if no key were ever accumulated, i.e. a causal_len = 0 case. That cannot occur here, because even s = 0 has causal_len = 1, so in practice the guard is never triggered.
O[d] = acc[d] / l — softmax denominator division.
LSE = m + log(l) — log-sum-exp, to be reused in backward.
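One reason a saved LSE is useful: any softmax probability can be rebuilt from a single score as exp(score − LSE), with no second pass for the max or the sum. The sketch below checks that identity in Python; how 10_backward_attention actually consumes LSE is not shown here, so treat this as an illustration of the identity only.

```python
import math

scores = [0.3, -1.2, 2.5, 0.0]

# What the forward pass ends up with after its last block.
m = max(scores)
l = sum(math.exp(x - m) for x in scores)
lse = m + math.log(l)

# Rebuild probabilities from LSE alone versus the usual normalization.
probs_from_lse = [math.exp(x - lse) for x in scores]
probs_direct = [math.exp(x - m) / l for x in scores]
```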
How attention_decode Differs
For inference. Q is a single row (the current token), K/V are read from the cache (all previous tokens'):
let h = wgid.x; // one WG per head (n_heads workgroups)
// ... no causal mask, attend over the whole cache_len
- No LSE save (no backward in decode)
- Smaller workgroup grid (n_heads, not seq × n_heads)
- K/V are meant to be a pre-loaded cache; in our setup the KV buffer may not be a true cache but storage rewritten every step
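A CPU reference for the decode case, as a hedged sketch: one query row per head, every cached key visible, GQA mapping as in the forward kernel. It uses an ordinary softmax rather than the kernel's online form, and the flat row layout and the name `attention_decode_ref` are assumptions for illustration.

```python
import math

def attention_decode_ref(q_one, k_cache, v_cache, n_heads, n_kv_heads, head_dim):
    """q_one: flat [n_heads*head_dim]; caches: list of flat rows [n_kv_heads*head_dim]."""
    scale = 1.0 / math.sqrt(head_dim)
    q_per_group = n_heads // n_kv_heads
    out = []
    for h in range(n_heads):
        base = (h // q_per_group) * head_dim      # start of this head's K/V slice
        q = q_one[h * head_dim:(h + 1) * head_dim]
        # No causal mask: every cached key belongs to the past.
        scores = [scale * sum(q[d] * row[base + d] for d in range(head_dim))
                  for row in k_cache]
        m = max(scores)
        w = [math.exp(x - m) for x in scores]
        total = sum(w)
        out += [sum(w[k] * v_cache[k][base + d] for k in range(len(w))) / total
                for d in range(head_dim)]
    return out

# Two query heads sharing one K/V head, cache of length 1.
o_one = attention_decode_ref([1.0, 0.0, 0.0, 1.0], [[1.0, 0.0]], [[5.0, 7.0]],
                             n_heads=2, n_kv_heads=1, head_dim=2)
```

With a single cached token the softmax weight is 1, so both heads return the cached value row unchanged.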
Performance Notes
This kernel is the bottleneck. In the profile snapshot:
- attention_forward: 8.7% of step time (45 ms / 12 layers)
- Each layer ~3.75 ms
Reasons:
- The K access pattern is not coalesced (adjacent threads read K[k], with k varying)
- The V access pattern has the same problem
- The causal mask leaves threads idle in the last blocks
Flash-Attention Experiments and Occupancy Loss:
A Flash-style Q-block tiled variant (Br=8, Bc=32, K/V tile cache) was tried in the codebase. However, on Apple GPUs this approach required 21 KB of workgroup memory, which dropped the occupancy ratio from ~32 WG/SM to ~1 WG/SM and caused more loss of latency hiding than the gain obtained from amortizing K/V loads. Since it caused a net +47% increase in step time in the d=128 h=4 kv=2 swiglu seq=512 configuration, the kernel was reverted to the per-query online softmax structure.
Optimization opportunity: a Flash-attention pattern (tile both Q and K, reusing each K/V tile across a chunk of Q rows). Perf-doc #1.
WGSL-Specific Notes
1. wg_reduce_max and wg_reduce_sum give the same result in every thread
Detailed in 00_shared.md. Subgroup-accelerated, deterministic.
2. exp(NEG_INF) = 0
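The kernel relies on this. After the causal mask, a NEG_INF score → exp → 0 → zero contribution to the softmax, which is the correct behaviour. Strictly, this is IEEE-754 behaviour rather than a WGSL spec guarantee, since the spec lets implementations assume infinities do not occur; with a large finite NEG_INF the result is ordinary underflow to 0.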
3. var<workgroup> 4 KB total
sh_q + sh_p + sh_acc + sh_partial = 4×256×4B = 4 KB. sh_max, sh_lse 1×4B = negligible. 12% of the 32 KB limit.
4. select(b, a, cond) — the WGSL ternary
let kl_end = select(kl_end_raw, WG, kl_end_raw > WG);
"If raw > WG, take WG; otherwise take raw." Equivalent to min(raw, WG), written out explicitly.
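The argument order of select is easy to get backwards, since the false-case value comes first. A Python stand-in (the `select` and `clamp_end` functions here are illustrative, not WGSL) makes the equivalence with min explicit:

```python
def select(f, t, cond):
    """Stand-in for WGSL select(f, t, cond): returns t when cond is true, else f."""
    return t if cond else f

WG = 256

def clamp_end(kl_end_raw):
    """Same expression as the kernel's kl_end computation."""
    return select(kl_end_raw, WG, kl_end_raw > WG)
```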
Code Review
Finding 1: K access non-coalesced
| Risk | Description |
|---|---|
| 🟡 perf | The score loop has each thread read K[k * kv_dim + ...]. k differs per thread, so adjacent threads read from addresses kv_dim elements (kv_dim × 4 bytes) apart. Not coalesced. Practical impact: ~10-15% attention overhead. |
Mitigation: Flash-attention pattern. Currently open work.
Finding 2: sh_p size 256 — head_dim > 256 fails
| Risk | Description |
|---|---|
| 🟢 none | The MAX_HEAD_DIM = 256 constraint guards against this. Our model is 64, could become 128 in the future, still safe. |
Finding 3: nan_guard(O[d]) at the output
| Risk | Description |
|---|---|
| 🟢 none | If Q × K = ∞, exp can blow up. nan_guard pulls NaN/Inf to 0, keeping downstream stable. |
Quick Checklist
| Test Scenario | Status |
|---|---|
| Is the causal mask correct? (s = 0 attends only to itself) | ✅ k < causal_len check |
| Is the GQA mapping correct? | ✅ kv_h = h / q_per_group |
| Online softmax = standard softmax? | ⚠ no formal test, but loss curve reasonable in practice |
| head_dim = 256 edge case crash? | ✅ MAX_HEAD_DIM = 256 |
| seq_len = 1 (decode-style) | ✅ runs a single block iteration (n_blocks = 1) |
| Is backward LSE used correctly? | ✅ proof in 10_backward_attention.md |
Next
06_activation.md — GeLU, SwiGLU. The non-linearity component of the FFN.