llm.istanbul·Study

attention_forward — Causal Multi-Head Attention with Online Softmax

File: 05_attention.wgsl
Pipeline step: the heart of the layer. RoPE-rotated Q and K, plus V → contextual mixing.

Two kernels:

  • attention_forward — training (full sequence)
  • attention_decode — inference (single query against KV cache)

Wait, what is this?

You join a group chat late and 200 messages have piled up. To make sense of the current message, you scroll back and read the old ones. But you don't read them all as equally important: a few are right on topic, the rest is chit-chat. In your head you quietly assign each old message a weight — "very relevant", "ignore this" — then blend that weighted mix into what the current message actually means. That's exactly what attention does.

Each token (word piece) puts out a "query": "who should I be listening to?" The other tokens each offer a "key": "I'm relevant if you're after this topic." Matching the query against the keys (a dot product) gives a score, softmax turns those scores into weights that sum to 1, and each token's "value" gets blended by those weights to form the output. So it's a fuzzy lookup: instead of pulling one row by one key out of a hash map, you take a weighted sum over all the rows by how similar they are.

There's one rule: it's causal. A token can only look at tokens before it, never the ones after — just like you can't read a message in the group chat that hasn't been typed yet. Looking ahead is off-limits, because the model is learning to predict the next word; we can't let it peek at the answer.

The hard part: to work out those weights (softmax), you'd normally write all the scores to memory and scan them three times — which gets expensive as the sequence grows. Online softmax instead reads the messages in a single pass, as they stream by, and keeps correcting itself as it goes. Below we'll look at how all this is squeezed into one pass on the GPU.


What Does It Do?

Causal scaled dot-product attention with GQA (Grouped-Query Attention) and online softmax:

score[s, k] = (Q[s] · K[k]) / sqrt(head_dim)
P[s, k] = softmax_k(score[s, k])    (k ≤ s, causal; also k ≥ the start of s's document)
O[s] = Σ_k P[s, k] · V[k]

The model's contextual mixing mechanism. Token s computes "how much attention it should pay" to each preceding token k (the probability P[s,k]), then produces output via a weighted sum of the V's.
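A plain-Python reference for a single head can serve as ground truth when testing the kernel. This is a minimal sketch. `naive_causal_attention` is a hypothetical helper, not part of the codebase. It materializes every score and ignores both GQA and the cross-document mask.

```python
import math

def naive_causal_attention(Q, K, V):
    """One head. Q, K, V: lists of seq rows, each a list of head_dim floats."""
    head_dim = len(Q[0])
    scale = 1.0 / math.sqrt(head_dim)
    out = []
    for s in range(len(Q)):
        # Causal: query s only scores keys 0..s
        scores = [sum(q * k for q, k in zip(Q[s], K[k_idx])) * scale
                  for k_idx in range(s + 1)]
        m = max(scores)                      # subtract the max for stability
        weights = [math.exp(x - m) for x in scores]
        total = sum(weights)
        P = [w / total for w in weights]     # P[s, k], sums to 1 over k
        # O[s] = sum_k P[s, k] * V[k]
        out.append([sum(P[k_idx] * V[k_idx][d] for k_idx in range(s + 1))
                    for d in range(head_dim)])
    return out
```

With two tokens, the first output row is exactly V[0], because s = 0 can only attend to itself.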

Online Softmax — Why?

Classic softmax takes 3 passes:

  1. m = max(scores) — scan all scores
  2. exp_sum = Σ exp(score - m) — scan again
  3. P = exp(score - m) / exp_sum — scan again
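The three passes map directly onto code. Below is a minimal sketch with a hypothetical helper name. Subtracting the max in pass 1 is what keeps exp from overflowing on large scores.

```python
import math

def softmax_three_pass(scores):
    m = max(scores)                                       # pass 1: max
    exp_sum = sum(math.exp(x - m) for x in scores)        # pass 2: sum of shifted exps
    return [math.exp(x - m) / exp_sum for x in scores]    # pass 3: normalize
```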

For a long sequence (e.g. 1024 tokens), that means three passes of memory traffic plus a scratch buffer holding every score. Online softmax does the work in a single pass, and the scratch space becomes O(WG + head_dim) instead of O(seq_len).

GQA (Grouped Query Attention)

There are n_heads query heads but only n_kv_heads < n_heads key/value heads. Several Q heads share the same K/V head:

n_heads = 12, n_kv_heads = 4
→ q_per_group = 3
→ Query head h uses kv_head h / 3 (integer division)

Advantage: the KV cache is 1/q_per_group the size it would be with one K/V head per query head. In our model (n_heads = 12, n_kv = 4) it is ⅓ the size. Little quality loss, large memory savings.
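Both the head mapping and the cache saving reduce to integer arithmetic. This sketch uses hypothetical helper names. The divisibility assert is our assumption about valid configurations; we do not know that the kernel checks it.

```python
def kv_head_for(h, n_heads, n_kv_heads):
    """GQA: which K/V head query head h reads."""
    assert n_heads % n_kv_heads == 0, "assumed: n_heads is a multiple of n_kv_heads"
    q_per_group = n_heads // n_kv_heads
    return h // q_per_group            # integer division, as in the kernel

def kv_cache_floats(seq_len, n_kv_heads, head_dim):
    # K and V caches, each laid out as [seq × n_kv × head_dim]
    return 2 * seq_len * n_kv_heads * head_dim
```

For n_heads = 12 and n_kv_heads = 4, heads 0-2 map to K/V head 0, heads 3-5 to head 1, and so on. The cache is exactly one third of the n_kv_heads = 12 size.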


Online Softmax Algorithm — Very Important

It processes the K-blocks within the sequence in order and updates the state incrementally:

state: m (running max), l (running exp-sum), acc[d] (running V·P)
init: m = -∞, l = 0, acc[d] = 0

for each block of K (size WG=256):
  1. compute scores = (Q · K_block[k]) * scale  for each k in block
  2. block_max = max(scores in block)
  3. m_new = max(m, block_max)
  4. P[k] = exp(score[k] - m_new)
  5. block_sum = Σ P[k]
  6. correction = exp(m - m_new)             ← rescales the old state
  7. l ← l * correction + block_sum
  8. acc[d] ← acc[d] * correction + Σ_k P[k] · V[k, d]
  9. m ← m_new

final:
  O[d] = acc[d] / l
  LSE = m + log(l)                            ← saved for backward

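The correction = exp(m_old - m_new) term rescales the old state to the new reference max. It is below 1 whenever m_new > m_old and exactly 1 when the max does not change. Every exponent stays ≤ 0, so the update is numerically stable and remains single-pass.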

LSE = log(Σ exp(scores)) is saved for backward. With it, the backward pass can rebuild each probability as P[s, k] = exp(score[s, k] - LSE[s]) instead of recomputing the max and the sum.
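The recurrence above translates almost line for line into Python. This is a CPU sketch for one query and one head, assuming the scores are already scaled and masked. It uses a true -inf for the initial max, which Python's math.exp handles (exp(-inf) = 0.0).

Comparing the sketch against a direct softmax is one way to test "online softmax = standard softmax". The last assertion shows why LSE is worth saving: exp(score - LSE) recovers each probability directly.

```python
import math

def online_softmax_attention(scores, values, block=256):
    """Single query, single head.
    scores[k]: already scaled (and masked) score for key k.
    values[k]: V[k], a list of head_dim floats.
    Returns (O, LSE) via the block-wise online-softmax recurrence."""
    head_dim = len(values[0])
    m, l = -math.inf, 0.0              # running max, running exp-sum
    acc = [0.0] * head_dim             # running sum of P[k] * V[k]
    for start in range(0, len(scores), block):
        blk = scores[start:start + block]
        m_new = max(m, max(blk))                     # steps 2-3
        p = [math.exp(x - m_new) for x in blk]       # step 4
        correction = math.exp(m - m_new)             # step 6: 0.0 on the first block
        l = l * correction + sum(p)                  # steps 5 and 7
        for d in range(head_dim):                    # step 8
            pv = sum(pk * values[start + i][d] for i, pk in enumerate(p))
            acc[d] = acc[d] * correction + pv
        m = m_new                                    # step 9
    out = [a / l for a in acc]                       # final: O[d] = acc[d] / l
    return out, m + math.log(l)                      # LSE = m + log(l)
```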


Bind Group ABI

attention_forward (8 bindings)

Binding | Type | Detail
0 | storage, read | Q: array<f32>[seq × n_heads × head_dim]
1 | storage, read | K: array<f32>[seq × n_kv × head_dim] (GQA: smaller)
2 | storage, read | V: array<f32>[seq × n_kv × head_dim]
3 | storage, read_write | O: array<f32>[seq × n_heads × head_dim]
4 | storage, read_write | LSE: array<f32>[seq × n_heads], saved for backward
5 | uniform | dims: vec4<u32>(seq_len, n_heads, n_kv_heads, head_dim)
6 | uniform | params: vec4<f32>(scale, _, _, _)
7 | storage, read | seg: array<u32>, the document start index of each query (cross-document mask)

scale = 1 / sqrt(head_dim) is computed on the host.
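The two uniforms are 16-byte vec4 values. The real host is a WebGPU application, so the Python below only illustrates the byte layout. It assumes a little-endian host and zero-filled unused lanes, and `pack_attention_uniforms` is a hypothetical name. The resulting bytes are what a host would upload with something like GPUQueue.writeBuffer.

```python
import math
import struct

def pack_attention_uniforms(seq_len, n_heads, n_kv_heads, head_dim):
    # dims: vec4<u32>(seq_len, n_heads, n_kv_heads, head_dim), little-endian (assumed)
    dims = struct.pack("<4I", seq_len, n_heads, n_kv_heads, head_dim)
    # params: vec4<f32>(scale, _, _, _); scale computed on the host, unused lanes zeroed
    scale = 1.0 / math.sqrt(head_dim)
    params = struct.pack("<4f", scale, 0.0, 0.0, 0.0)
    return dims, params
```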

attention_decode (6 bindings)

For decode, Q is a single row ([n_heads × head_dim]), K/V are read from the cache, and there is no LSE save.

Binding | Type | Detail
0 | storage, read | Q_one: array<f32>[n_heads × head_dim] (single row)
1 | storage, read | K_cache: array<f32>[seq × n_kv × head_dim] (cache of past keys)
2 | storage, read | V_cache: array<f32>[seq × n_kv × head_dim] (cache of past values)
3 | storage, read_write | O_one: array<f32>[n_heads × head_dim] (single output row)
4 | uniform | dims_dec: vec4<u32>(cache_len, n_heads, n_kv_heads, head_dim)
5 | uniform | params_dec: vec4<f32>(scale, _, _, _)

Workgroup Memory Layout

```wgsl
var<workgroup> sh_q:       array<f32, 256>;  // Q row cache
var<workgroup> sh_p:       array<f32, 256>;  // current K-block scores → P values
var<workgroup> sh_acc:     array<f32, 256>;  // running output accumulator
var<workgroup> sh_partial: array<f32, 256>;  // V·P partial sum reduction scratch
var<workgroup> sh_max:     array<f32, 1>;    // running max m (single value)
var<workgroup> sh_lse:     array<f32, 1>;    // running sum-exp l (single value)
```

Total: ~4 KB workgroup memory. That is 12% of the Apple GPU's 32 KB limit — comfortable occupancy.

MAX_HEAD_DIM = 256 constraint: head_dim cannot exceed 256. In our model head_dim=64 (12 heads × 64 = 768 d_model), with plenty of room to spare.
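The ~4 KB figure can be checked with a few lines of arithmetic. The array sizes come from the declarations above and the 32 KB limit is the one quoted in the text.

```python
F32_BYTES = 4
WG = 256
# element counts of each var<workgroup> array declared above
arrays = {"sh_q": WG, "sh_p": WG, "sh_acc": WG, "sh_partial": WG, "sh_max": 1, "sh_lse": 1}
total_bytes = sum(n * F32_BYTES for n in arrays.values())
limit_bytes = 32 * 1024  # the 32 KB limit quoted in the text
print(total_bytes, f"{100 * total_bytes / limit_bytes:.1f}%")  # → 4104 12.5%
```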


Dispatch Shape

workgroup_size: 256
grid:           (seq_len, n_heads, 1) workgroups

One workgroup = one (s, h) pair, i.e. one query position for one head.

Example (seq=512, n_heads=12):

  • 6144 WG × 256 threads = 1.57M threads
  • Each WG iterates ceil(causal_len / 256) blocks of K
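The dispatch numbers follow directly from the grid shape. This sketch uses a hypothetical helper. It also counts the total K-block iterations, ignoring the skip of pre-document blocks.

```python
def dispatch_stats(seq_len, n_heads, wg=256):
    n_workgroups = seq_len * n_heads        # grid (seq_len, n_heads, 1)
    n_threads = n_workgroups * wg
    # query s visits ceil((s + 1) / wg) K-blocks (document-start skip ignored)
    blocks = [(s + 1 + wg - 1) // wg for s in range(seq_len)]
    return n_workgroups, n_threads, max(blocks), n_heads * sum(blocks)

print(dispatch_stats(512, 12))  # → (6144, 1572864, 2, 9216)
```

At seq = 512, no query needs more than two K-blocks. The first 256 positions finish in a single block.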

Line by Line — attention_forward

1) Decode workgroup ID

```wgsl
let s = wgid.x;        // query position
let h = wgid.y;        // attention head
// ...
let q_per_group = n_heads / n_kv_heads;
let kv_h = h / q_per_group;          // GQA: query head h uses kv_h
```

GQA mapping. h=0,1,2 → kv_h=0; h=3,4,5 → kv_h=1; etc.

2) Phase 0 — Cache Q, init state

```wgsl
let q_row_base = s * d_model + h * head_dim;
let causal_len = s + 1u;             // attend over [0, s]
let doc_start = seg[s];              // start index of the document this query belongs to
```

  • causal_len = s + 1u → this query position can attend to all preceding positions including itself, but not to subsequent ones.
  • doc_start = seg[s] → Cross-document masking: query s can only attend to keys k within its own document boundaries. Keys before the document start (k < doc_start) are masked out.
  • sh_q[i] = Q[...] (the load loop is not shown above) pulls a head_dim-sized copy of the Q row into workgroup memory. Every thread reads the same Q in the inner loop, so it is loaded once and read many times.
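This page does not show how the host fills seg. Assuming the training sequence is packed from whole documents of known lengths, a hypothetical helper could build it as below. A second helper lists the key range each query may see.

```python
def build_seg(doc_lengths):
    """Hypothetical host helper: seg[s] = index of the first token of s's document."""
    seg, start = [], 0
    for n in doc_lengths:
        seg.extend([start] * n)
        start += n
    return seg

def allowed_keys(seg, s):
    # keys query s may attend to: doc_start <= k < causal_len (= s + 1)
    return list(range(seg[s], s + 1))
```

For documents of lengths 3, 2 and 4, query 4 (the second token of the second document) sees only keys 3 and 4.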

3) K-block parallelization parameters

```wgsl
let k_par      = max(1u, WG / max(1u, head_dim));  // 4 when head_dim=64
let n_active   = head_dim * k_par;                  // 256 when head_dim=64, k_par=4
let d_local    = tid / k_par;                       // d ∈ [0, head_dim)
let kc         = tid % k_par;                       // chunk index ∈ [0, k_par)
let chunk_size = (WG + k_par - 1u) / k_par;         // 64 when k_par=4
```

In Phase D (the V·P partial sums), the workgroup is treated as a 2D (d_local, kc) grid. d_local is the output dimension (0..head_dim-1) and kc is the chunk index (0..k_par-1).

If head_dim < WG, multiple threads compute different chunks for the same d in parallel.
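Evaluating the same index formulas for a few head sizes shows how the thread grid adapts. Below is a Python transcription with a hypothetical helper name. A head_dim that does not divide 256, such as 96, leaves some threads idle (n_active = 192).

```python
def phase_d_layout(head_dim, wg=256):
    """Transcription of the kernel's Phase D index math (hypothetical helper)."""
    k_par = max(1, wg // max(1, head_dim))       # chunks per output dimension
    n_active = head_dim * k_par                  # threads that take part in Phase D
    chunk_size = (wg + k_par - 1) // k_par       # keys per chunk
    # thread tid -> (d_local, kc) for every active thread
    mapping = [(tid // k_par, tid % k_par) for tid in range(n_active)]
    return k_par, n_active, chunk_size, mapping

for hd in (64, 96, 128, 256):
    k_par, n_active, chunk_size, _ = phase_d_layout(hd)
    print(hd, k_par, n_active, chunk_size)
# → 64 4 256 64
#   96 2 192 128
#   128 2 256 128
#   256 1 256 256
```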

4) K-block iteration (Skipping the Pre-Document Region)

```wgsl
let first_blk = doc_start / WG;
let n_blocks = (causal_len + WG - 1u) / WG;
for (var blk = first_blk; blk < n_blocks; blk = blk + 1u) {
    let k_start = blk * WG;
    let k = k_start + tid;
    // ... Phases A-D run here for this block ...
}
```

  • first_blk = doc_start / WG → K-blocks lying entirely before the document start are skipped. This saves memory reads. It also guarantees that the first processed block contains at least one valid (unmasked) key: doc_start itself, since doc_start ≤ s < causal_len. Without that guarantee, a fully masked first block would leave the running max at NEG_INF. The running-max correction (LSE stabilization) would then evaluate exp(NEG_INF - NEG_INF), which is NaN with a true infinity and 1 with a finite sentinel. Either way the softmax is corrupted.
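The block-skip guarantee can be checked by brute force. This sketch uses hypothetical helper names and assumes seg[s] ≤ s, as the cross-document mask implies.

```python
def block_range(s, doc_start, wg=256):
    """K-blocks visited by query s (mirrors first_blk / n_blocks)."""
    causal_len = s + 1
    first_blk = doc_start // wg
    n_blocks = (causal_len + wg - 1) // wg
    return range(first_blk, n_blocks)

def first_block_has_valid_key(s, doc_start, wg=256):
    """True if the first visited block holds a key with doc_start <= k <= s."""
    blk = block_range(s, doc_start, wg)[0]
    return any(doc_start <= k <= s for k in range(blk * wg, (blk + 1) * wg))
```

For example, a query at s = 700 whose document starts at 600 visits only block 2, instead of blocks 0 through 2.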

5) Phase A — Score computation (Masked Positions)

```wgsl
var score: f32 = NEG_INF;
if (k >= doc_start && k < causal_len) {
    let k_row_base = k * kv_dim + kv_h * head_dim;
    var dot: f32 = 0.0;
    for (var d: u32 = 0u; d < head_dim; d = d + 1u) {
        dot = fma(sh_q[d], K[k_row_base + d], dot);
    }
    score = dot * scale;
}
```

Each thread computes a score for one k:

  • Q (from workgroup mem) ⋅ K[k] (from global mem)
  • head_dim FMAs
  • Scale = 1/sqrt(head_dim) (LLaMA-style)
  • if (k >= doc_start && k < causal_len) applies the document mask and the causal mask in a single condition. Positions that fail it keep score = NEG_INF, so exp(NEG_INF - m_new) contributes 0 to the softmax.

Memory pattern: adjacent threads read adjacent K rows (k = 0, 1, 2, ...) at K[k * kv_dim + kv_h * head_dim + d]. For the same d, adjacent k values are kv_dim elements (kv_dim × 4 bytes) apart, so the loads are not coalesced. This is the main performance handicap of attention_forward. A flash-attention-style forward (perf-doc #1) could address it.
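To make the stride concrete, compute the byte addresses that four adjacent threads touch in the same iteration of the d loop. For illustration only, this assumes n_kv_heads = 4, head_dim = 64 (so kv_dim = 256) and kv_h = 1.

```python
def k_element_index(k, d, kv_h, n_kv_heads, head_dim):
    """Element index of K[k, kv_h, d] in the flat [seq × n_kv × head_dim] buffer."""
    kv_dim = n_kv_heads * head_dim
    return k * kv_dim + kv_h * head_dim + d

F32_BYTES = 4
# threads tid = 0..3 score keys k = 0..3; in one loop iteration they all read d = 0
addrs = [F32_BYTES * k_element_index(k, 0, kv_h=1, n_kv_heads=4, head_dim=64) for k in range(4)]
strides = [b - a for a, b in zip(addrs, addrs[1:])]
print(addrs, strides)  # → [256, 1280, 2304, 3328] [1024, 1024, 1024]
```

In a coalesced pattern those four threads would read consecutive 4-byte words. Here each read lands 1 KB away from its neighbour.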

6) Phase B — Block max + running max

```wgsl
let block_max = wg_reduce_max(tid, score);
let m_old = sh_max[0];
let m_new = max(m_old, block_max);
```

wg_reduce_max returns the max across all 256 threads — the same value in every thread.
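The real wg_reduce_max is subgroup-accelerated (see 00_shared.md). For intuition only, the Python below simulates a classic shared-memory tree reduction, which computes the same max. This is a swapped-in technique, not the codebase's implementation, and it assumes a power-of-two workgroup size.

```python
def tree_reduce(values, op):
    """Simulate a shared-memory tree reduction across one workgroup."""
    n = len(values)
    assert n > 0 and (n & (n - 1)) == 0, "workgroup size must be a power of two"
    sh = list(values)            # stands in for a var<workgroup> array
    stride = n // 2
    while stride > 0:
        # threads tid < stride combine two slots; a workgroupBarrier() would follow
        for tid in range(stride):
            sh[tid] = op(sh[tid], sh[tid + stride])
        stride //= 2
    return sh[0]                 # every thread then reads sh[0]: same value everywhere
```

The combining order is fixed by the strides, so repeated runs give the same result. That is one way such a reduction can be deterministic.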

7) Phase C — P[k] + block sum + correction

```wgsl
var p_val: f32 = 0.0;
if (k < causal_len) {
    p_val = exp(score - m_new);
}
sh_p[tid] = p_val;
let block_sum = wg_reduce_sum(tid, p_val);
let correction = exp(m_old - m_new);
let l_old = sh_lse[0];
let l_new = l_old * correction + block_sum;
```

p_val = exp(score - m_new) is the unnormalized probability. m_new is the largest score seen so far, so exp(score - m_new) ≤ 1 and exp cannot overflow. This is what keeps the computation numerically stable.

correction = exp(m_old - m_new):

  • First iteration: m_old = -∞, correction = exp(-∞ - m_new) → 0 (zero out the old state, since there was no information then)
  • Subsequent iterations: if m grows, correction < 1, down-scaling the old accumulators

8) Phase D — V·P contribution

```wgsl
if (tid < n_active) {
    let kl_start   = kc * chunk_size;
    let kl_end_raw = kl_start + chunk_size;
    let kl_end     = select(kl_end_raw, WG, kl_end_raw > WG);

    var partial: f32 = 0.0;
    for (var kl = kl_start; kl < kl_end; kl = kl + 1u) {
        let kk = k_start + kl;
        if (kk < causal_len) {
            partial = fma(sh_p[kl], V[kk * kv_dim + kv_h * head_dim + d_local], partial);
        }
    }
    sh_partial[d_local * k_par + kc] = partial;
}
workgroupBarrier();

if (kc == 0u && tid < n_active) {
    var sum_pv: f32 = sh_partial[d_local * k_par];
    for (var i = 1u; i < k_par; i = i + 1u) {
        sum_pv = sum_pv + sh_partial[d_local * k_par + i];
    }
    sh_acc[d_local] = sh_acc[d_local] * correction + sum_pv;
}
```

Phase D has two steps:

  1. Each (d_local, kc) thread computes the partial Σ P[kl] · V[k_start+kl, d_local] over its own chunk
  2. After the barrier, the kc == 0u threads sum the partials for their d_local (k_par of them) → write to sh_acc[d_local]

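Why two steps? If a single thread per d_local summed over all WG keys, only head_dim threads (64 of 256 for head_dim = 64) would work while the rest sat idle. Splitting the keys into k_par chunks lets all n_active threads produce partials in parallel. One thread per d_local then merges the k_par partials. This shortens each thread's serial loop from WG to chunk_size iterations (256 → 64 for head_dim = 64).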
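The two-step reduction can be simulated on the CPU to confirm that it matches a direct Σ P·V. This sketch uses a hypothetical function name. It mirrors the kernel's index math, including the select clamp (written here as min). It also accepts a smaller wg so it can be tested on tiny inputs.

```python
def phase_d(sh_p, V_rows, k_start, causal_len, head_dim, wg=256):
    """Simulate the two-step V·P reduction for one K-block.
    sh_p[kl]: P for key k_start + kl. V_rows[k]: V[k] for this kv head."""
    k_par = max(1, wg // max(1, head_dim))
    n_active = head_dim * k_par
    chunk_size = (wg + k_par - 1) // k_par
    sh_partial = [0.0] * n_active
    # Step 1: thread (d_local, kc) sums P * V over its chunk of keys
    for tid in range(n_active):
        d_local, kc = tid // k_par, tid % k_par
        kl_start = kc * chunk_size
        kl_end = min(kl_start + chunk_size, wg)      # select(raw, WG, raw > WG)
        partial = 0.0
        for kl in range(kl_start, kl_end):
            kk = k_start + kl
            if kk < causal_len:                      # causal mask
                partial += sh_p[kl] * V_rows[kk][d_local]
        sh_partial[d_local * k_par + kc] = partial
    # Step 2 (after the barrier): the kc == 0 thread merges its k_par partials
    return [sum(sh_partial[d * k_par:(d + 1) * k_par]) for d in range(head_dim)]
```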

9) Update m, l

```wgsl
if (tid == 0u) {
    sh_max[0] = m_new;
    sh_lse[0] = l_new;
}
workgroupBarrier();
```

Only tid 0 writes. Single writer, no race. After the barrier all threads see the updated m/l.

10) Final O and LSE

```wgsl
let l_final = max(sh_lse[0], 1e-30);
let inv_l = 1.0 / l_final;
if (tid < head_dim) {
    O[s * d_model + h * head_dim + tid] = nan_guard(sh_acc[tid] * inv_l);
}
if (tid == 0u) {
    LSE[s * n_heads + h] = sh_max[0] + log(l_final);
}
```

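max(l, 1e-30) is division-by-zero protection: if no key contributed, l would stay 0. That cannot actually happen here. causal_len = s + 1 ≥ 1, and the first processed block always contains the valid key doc_start. The key with the maximum score therefore contributes exp(0) = 1, so l ≥ 1. The guard is purely defensive and is never triggered in practice.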

O[d] = acc[d] / l — softmax denominator division. LSE = m + log(l) — log-sum-exp, to be reused in backward.


How attention_decode Differs

For inference. Q is a single row (the current token), K/V are read from the cache (all previous tokens'):

```wgsl
let h = wgid.x;             // one WG per head (n_heads workgroups)
// ... no causal mask; attend over the whole cache_len
```

  • No LSE save (decode has no backward pass)
  • Smaller workgroup grid (n_heads instead of seq × n_heads)
  • K/V come from a pre-filled cache; in our setup the KV buffer may instead be a plain storage buffer rewritten every step
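Below is a CPU reference for the decode path. It assumes the flat [seq × n_kv × head_dim] cache layout from the binding table and cache_len ≥ 1. `attention_decode_ref` is a hypothetical helper. It loops over heads where the kernel launches one workgroup per head, and it uses a plain three-pass softmax because all scores fit in memory here.

```python
import math

def attention_decode_ref(q_one, k_cache, v_cache, cache_len, n_heads, n_kv_heads, head_dim):
    """Single-token attention over a flat KV cache. No causal mask:
    the query attends to all cache_len entries."""
    kv_dim = n_kv_heads * head_dim
    q_per_group = n_heads // n_kv_heads
    scale = 1.0 / math.sqrt(head_dim)
    out = [0.0] * (n_heads * head_dim)
    for h in range(n_heads):                 # one workgroup per head in the kernel
        kv_h = h // q_per_group              # GQA mapping
        q = q_one[h * head_dim:(h + 1) * head_dim]
        scores = []
        for t in range(cache_len):
            base = t * kv_dim + kv_h * head_dim
            scores.append(scale * sum(q[d] * k_cache[base + d] for d in range(head_dim)))
        m = max(scores)
        w = [math.exp(x - m) for x in scores]
        z = sum(w)
        for d in range(head_dim):
            out[h * head_dim + d] = sum(
                w[t] * v_cache[t * kv_dim + kv_h * head_dim + d] for t in range(cache_len)) / z
    return out
```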

Performance Notes

This kernel is the bottleneck. In the profile snapshot:

  • attention_forward: 8.7% of step time (45 ms across all 12 layers)
  • Each layer ~3.75 ms

Reasons:

  1. The K access pattern is not coalesced (adjacent threads read K[k], with k varying)
  2. The V access pattern has the same problem
  3. The causal mask leaves threads idle in each query's last, partially filled K-block
Note

Flash-Attention Experiments and Occupancy Loss: a Flash-style Q-block tiled variant (Br=8, Bc=32, K/V tile cache) was tried in the codebase. On Apple GPUs, however, it needed 21 KB of workgroup memory, which dropped occupancy from ~32 WG/SM to ~1 WG/SM. The lost latency hiding outweighed the gain from amortizing K/V loads. In the d=128, h=4, kv=2, SwiGLU, seq=512 configuration it increased step time by a net 47%, so the kernel was reverted to the per-query online-softmax structure.

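Optimization opportunity: the Flash-attention pattern (tile Q into blocks so that each K/V tile loaded into workgroup memory is reused by several queries). Perf-doc #1. Given the occupancy result above, a tiled variant would likely need a much smaller workgroup-memory footprint than 21 KB to pay off on Apple GPUs.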


WGSL-Specific Notes

1. wg_reduce_max and wg_reduce_sum give the same result in every thread

Detailed in 00_shared.md. Subgroup-accelerated, deterministic.

2. exp(NEG_INF) = 0

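Treat this as practical behaviour rather than a hard spec guarantee. WGSL lets implementations assume that infinities and NaNs do not occur at runtime, so NEG_INF is safest as a large finite negative constant. After masking, exp(NEG_INF - m_new) underflows to 0, giving zero contribution to the softmax. Correct behaviour.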

3. var<workgroup> 4 KB total

sh_q + sh_p + sh_acc + sh_partial = 4×256×4B = 4 KB. sh_max, sh_lse 1×4B = negligible. 12% of the 32 KB limit.

4. select(b, a, cond) — the WGSL ternary

```wgsl
let kl_end = select(kl_end_raw, WG, kl_end_raw > WG);
```

"If kl_end_raw > WG, take WG; otherwise take kl_end_raw." Equivalent to min(kl_end_raw, WG), written out explicitly.


Code Review

Finding 1: K access non-coalesced

Risk | Description
🟡 perf | The score loop has each thread read K[k * kv_dim + ...]. Since k differs per thread, adjacent threads read addresses kv_dim elements (kv_dim × 4 bytes) apart, so the loads are not coalesced. Practical impact: ~10-15% attention overhead.

Mitigation: Flash-attention pattern. Currently open work.

Finding 2: workgroup arrays sized 256 (sh_q, sh_acc, sh_partial), so head_dim > 256 fails

Risk | Description
🟢 none | The MAX_HEAD_DIM = 256 constraint guards against this. Our model uses head_dim = 64; moving to 128 in the future would still be safe.

Finding 3: nan_guard(O[d]) at the output

Risk | Description
🟢 none | If Q · K overflows to ±∞, the score and then exp can produce Inf or NaN. nan_guard maps NaN/Inf to 0, keeping downstream layers stable.

Quick Checklist

Test scenario | Status
Is the causal mask correct? (s = 0 attends only to itself) | ✅ k < causal_len check
Is the GQA mapping correct? | ✅ kv_h = h / q_per_group
Online softmax = standard softmax? | ⚠ no formal test, but the loss curve looks reasonable in practice
Does the head_dim = 256 edge case crash? | ✅ no, within MAX_HEAD_DIM = 256
seq_len = 1 (decode-style) | ✅ runs exactly one block iteration (n_blocks = 1)
Is the LSE used correctly in backward? | ✅ proof in 10_backward_attention.md

Next

06_activation.md — GeLU, SwiGLU. The non-linearity component of the FFN.

WGSL kernel studies · an LLM from scratch on WebGPU. Built in Istanbul by Uğur Toprakdeviren.