llm.istanbul·Study

attention_backward* — Backward Attention (3 Variants)

File: 10_backward_attention.wgsl
Pipeline step: the most expensive and most complex part of backward, around 20% of total step time.

5 kernels (3 attention backward variants plus 2 RoPE backward kernels):

  • attention_backward: streaming, for seq>1024 (CAS atomics on dK/dV)
  • attention_backward_short: split-kernel Phase 0-4, dQ owned, P/dS flushed to scratch
  • attention_backward_short_dKdV: split-kernel Phase 5, dK/dV owned, no atomics
  • rope_q_backward: inverts the Q rotation
  • rope_k_backward: inverts the K rotation

Wait, what is this?

Think back to attention in the forward pass: each token looks at every earlier token and asks "how much should I pay attention to you?", then blends their values by those weights to produce its output. Backward asks the opposite question: if the output came out wrong, how do we hand the blame and credit back to everyone who contributed? It's like a group project that went sideways — you rewind, figure out who pulled what weight, and update each person's grade accordingly.

Here's the catch. In the forward pass everything was organized by the reader (the query) — each token summed up its own row. But to push the gradient back you also need to sum the same data organized by who got read (the keys and values). It's like having to read the same ledger once by rows and once by columns. Two different directions. If you try to do both in one pass, you either park the intermediate data somewhere (a scratch buffer) and read it back the other way, or you fall into a traffic jam where everyone tries to write the same row at once.

That's exactly why there are 3 variants. The old single-kernel approach was like everyone trying to deposit into one shared safe at the same time: queues, locks, waiting (atomic writes). The new design splits the job in two. One kernel computes the reader-side gradient and jots the intermediate data down on the side, and a second kernel reads those notes back in the other direction to sum the read-side gradient. The third variant is the old single kernel itself, kept as a streaming fallback that holds memory in check when the sequence gets long (and those side notes would balloon).

The neat part is that all three paths solve the same chain rule mathematically — they just strike different memory/speed trade-offs. Below we'll walk through which pain point gave rise to which variant, and then how each one works step by step.


Why 3 Variants?

Previous version: single kernel, CAS atomics

attention_backward (streaming):
  Phase 0-4: recompute P, dP, accumulate dQ
  Phase 5:   dV += P·dO    via CAS atomic
             dK += dS·Q    via CAS atomic

Problem: in GQA, n_q_per_kv = 3-4, so 3-4 query heads were doing CAS atomics on the same KV head. At seq=512, the early KV positions (small j, which every later query attends to) saw ~1024 parallel atomic writes. The contention cost an estimated 17.8 ms per step.

New: split kernel design

attention_backward_short:
  Phase 0-4: compute P[j], dS[j], flush to scratch buffer (no atomic)
  Phase 5:   compute dQ_acc[d], owned write (no atomic)

attention_backward_short_dKdV:
  Phase 5: per (j, kv_h) workgroup
           loop s ∈ [j, seq) × h_q ∈ q_group(kv_h)
           read P[s, j], dS[s, j] from scratch
           dV[j, d] += P · dO  (owned write, no atomic)
           dK[j, d] += dS · Q  (owned write, no atomic)

Result: ~17.8 ms / step gained.

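The streaming kernel still exists. For seq_len > 1024, the short kernel's row buffers (sh_p_full and sh_dp_full hold 1024 entries) can no longer hold a full row. The scratch buffers also become too large: already 1024² × 4 B = 4 MB per head per buffer at seq=1024, growing quadratically. So the streaming path serves as the fallback. Production runs at seq=512, which uses the split path.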
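As a small illustration of the selection rule, here is a hedged Python sketch of what the host-side choice could look like. `select_attention_backward_path` and `scratch_bytes_per_head` are hypothetical names. Only the 1024 threshold and the kernel names come from the text.

```python
# Hypothetical host-side helper. The 1024 threshold and kernel names come from
# the text; the function names are illustrative, not the project's actual code.
SHORT_PATH_MAX_SEQ = 1024  # sh_p_full / sh_dp_full are array<f32, 1024>


def select_attention_backward_path(seq_len):
    """Return the ordered list of kernels to dispatch for one layer's backward."""
    if seq_len <= SHORT_PATH_MAX_SEQ:
        # Split path: row-owned dQ first, then column-owned dK/dV; no atomics.
        return ["attention_backward_short", "attention_backward_short_dKdV"]
    # Streaming fallback: a single kernel, dK/dV scattered with CAS atomics.
    return ["attention_backward"]


def scratch_bytes_per_head(seq_len, bytes_per_elem=4):
    """One head's [seq_len, seq_len] f32 slice of P_scratch (or dS_scratch)."""
    return seq_len * seq_len * bytes_per_elem


print(select_attention_backward_path(512))   # production: split path
print(scratch_bytes_per_head(1024) / 2**20)  # 4.0 MiB per head per buffer
```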


Mathematical Definition

Forward (in brief): O = softmax(Q·K^T · scale)·V

Backward via the chain rule (the standard formula). Causal mask: only pairs with j ≤ s contribute, so Σ_j runs over j ≤ s and Σ_s over s ≥ j.

P[s,j]  = softmax(Q·K^T · scale)[s,j]        (recomputed from the LSE saved by forward)
dP[s,j] = Σ_d dO[s,d] · V[j,d]               via dot product
D[s]    = Σ_j P[s,j] · dP[s,j]               per-row scalar
dS[s,j] = scale · P[s,j] · (dP[s,j] - D[s])  gradient w.r.t. the raw scores Q·K^T (scale folded in)
dQ[s,d] = Σ_j dS[s,j] · K[j,d]
dK[j,d] = Σ_s dS[s,j] · Q[s,d]               cross-row scatter
dV[j,d] = Σ_s P[s,j] · dO[s,d]               cross-row scatter
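To sanity-check these formulas, here is a minimal CPU reference in plain Python. It covers a single head with the causal mask only, with no GQA and no cross-document mask. All function names are hypothetical, not taken from the project. It folds `scale` into dS exactly as above. It then compares every gradient against central finite differences of L = Σ O·dO, whose gradient with respect to O is exactly dO.

```python
import math
import random


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def attention_forward(Q, K, V, scale):
    """Causal single-head attention; returns O and the per-row log-sum-exp."""
    O, LSE = [], []
    for s in range(len(Q)):
        z = [scale * dot(Q[s], K[j]) for j in range(s + 1)]   # causal: j <= s
        m = max(z)
        lse = m + math.log(sum(math.exp(x - m) for x in z))
        P = [math.exp(x - lse) for x in z]
        O.append([sum(P[j] * V[j][d] for j in range(s + 1)) for d in range(len(V[0]))])
        LSE.append(lse)
    return O, LSE


def attention_backward_reference(Q, K, V, dO, LSE, scale):
    """Direct transcription of the formulas above, scale folded into dS."""
    n, hd = len(Q), len(Q[0])
    dQ = [[0.0] * hd for _ in range(n)]
    dK = [[0.0] * hd for _ in range(n)]
    dV = [[0.0] * hd for _ in range(n)]
    for s in range(n):
        P = [math.exp(scale * dot(Q[s], K[j]) - LSE[s]) for j in range(s + 1)]
        dP = [dot(dO[s], V[j]) for j in range(s + 1)]
        D = dot(P, dP)                                           # per-row scalar
        dS = [scale * P[j] * (dP[j] - D) for j in range(s + 1)]
        for j in range(s + 1):
            for d in range(hd):
                dQ[s][d] += dS[j] * K[j][d]
                dK[j][d] += dS[j] * Q[s][d]                      # cross-row scatter
                dV[j][d] += P[j] * dO[s][d]                      # cross-row scatter
    return dQ, dK, dV


def loss(Q, K, V, dO, scale):
    """L = sum(O * dO), so that dL/dO is exactly dO."""
    O, _ = attention_forward(Q, K, V, scale)
    return sum(dot(o, g) for o, g in zip(O, dO))


def finite_diff(X, s, d, f, eps=1e-6):
    """Central difference of f() with respect to X[s][d] (X is modified in place)."""
    orig = X[s][d]
    X[s][d] = orig + eps
    up = f()
    X[s][d] = orig - eps
    down = f()
    X[s][d] = orig
    return (up - down) / (2 * eps)


rng = random.Random(0)
n, hd = 5, 4
scale = 1.0 / math.sqrt(hd)
Q, K, V, dO = ([[rng.uniform(-1, 1) for _ in range(hd)] for _ in range(n)] for _ in range(4))
_, LSE = attention_forward(Q, K, V, scale)
dQ, dK, dV = attention_backward_reference(Q, K, V, dO, LSE, scale)


def current_loss():
    return loss(Q, K, V, dO, scale)


max_err = max(
    abs(finite_diff(X, s, d, current_loss) - G[s][d])
    for X, G in ((Q, dQ), (K, dK), (V, dV))
    for s in range(n)
    for d in range(hd)
)
print(f"max |analytic - numeric| = {max_err:.2e}")
```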

attention_backward_short (Production Path)

Used when seq_len ≤ 1024. Single pass: the full rows P[s, *] and dP[s, *] (from which dS[s, *] is formed) are held in workgroup memory.

Workgroup memory

wgsl
var<workgroup> sh_q:       array<f32, 256>;    // Q row cache
var<workgroup> sh_do:      array<f32, 256>;    // dO row cache
var<workgroup> sh_p_full:  array<f32, 1024>;   // P[s, *] for all j (causal, up to s)
var<workgroup> sh_dp_full: array<f32, 1024>;   // dP[s, *]
var<workgroup> sh_partial: array<f32, 256>;    // dQ accumulator reduction

Total: ~11 KB. Comfortably within the Apple GPU 32 KB limit.

Algorithm

1. cache Q[s, h, *] and dO[s, h, *]         (kv_h = the KV head that h reads under GQA)
2. for each j in [seg[s], s]:                (causal + cross-document mask)
     P[j]  = exp(scale · Q·K[j, kv_h] - LSE[s, h])
     dP[j] = Σ_d dO[d] · V[j, kv_h, d]
     write to sh_p_full[j], sh_dp_full[j]
3. D = Σ_j P[j] · dP[j]
4. for each j:
     dS[j] = scale · P[j] · (dP[j] - D)
5. dQ[d] = Σ_j dS[j] · K[j, kv_h, d]
   write dQ[s, h, d] = previous + dQ[d]     (owned, no atomic)
6. flush P[j], dS[j] to the scratch buffers (P_scratch, dS_scratch)

dK[j,d] and dV[j,d] are NOT WRITTEN in this kernel. The companion kernel handles them.
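A serial CPU sketch of one (s, h) workgroup can make steps 1-6 concrete. It rests on several assumptions:

- flat row-major layouts: [seq_len, n_heads, head_dim] for Q/dO/dQ, [seq_len, n_kv_heads, head_dim] for K/V, [seq_len, n_heads] for LSE, and [seq_len, n_heads, seq_len] for the scratch buffers;
- the GQA mapping kv_h = h // q_per_group;
- seg[s] is the first key position query s may see.

`short_row` is a hypothetical name. The driver checks two invariants that must hold when LSE matches forward. Each P row sums to 1, and each dS row sums to 0, because Σ_j P·(dP - D) = D - D.

```python
import math
import random


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def short_row(s, h, Q, K, V, dO, LSE, seg, dQ, P_scratch, dS_scratch,
              seq_len, n_heads, n_kv_heads, head_dim, scale):
    """CPU stand-in for one (s, h) workgroup of attention_backward_short."""
    kv_h = h // (n_heads // n_kv_heads)               # assumed GQA head mapping
    qi = (s * n_heads + h) * head_dim                 # row offset in Q / dO / dQ
    q, do = Q[qi:qi + head_dim], dO[qi:qi + head_dim]  # 1. row caches

    def kv_row(X, j):
        ki = (j * n_kv_heads + kv_h) * head_dim
        return X[ki:ki + head_dim]

    js = range(seg[s], s + 1)                         # causal + same document
    lse = LSE[s * n_heads + h]
    P = {j: math.exp(scale * dot(q, kv_row(K, j)) - lse) for j in js}  # 2.
    dP = {j: dot(do, kv_row(V, j)) for j in js}
    D = sum(P[j] * dP[j] for j in js)                 # 3.
    dS = {j: scale * P[j] * (dP[j] - D) for j in js}  # 4.
    for d in range(head_dim):                         # 5. owned dQ write
        dQ[qi + d] += sum(dS[j] * kv_row(K, j)[d] for j in js)
    base = (s * n_heads + h) * seq_len                # 6. flush, s-major rows
    for j in js:
        P_scratch[base + j] = P[j]
        dS_scratch[base + j] = dS[j]


# Tiny driver: 12 query heads over 4 KV heads; tokens 0-2 and 3-5 are two documents.
rng = random.Random(1)
seq_len, n_heads, n_kv_heads, head_dim = 6, 12, 4, 4
scale = 1.0 / math.sqrt(head_dim)
seg = [0, 0, 0, 3, 3, 3]
Q = [rng.uniform(-1, 1) for _ in range(seq_len * n_heads * head_dim)]
dO = [rng.uniform(-1, 1) for _ in range(seq_len * n_heads * head_dim)]
K = [rng.uniform(-1, 1) for _ in range(seq_len * n_kv_heads * head_dim)]
V = [rng.uniform(-1, 1) for _ in range(seq_len * n_kv_heads * head_dim)]

LSE = []                                              # what forward would have saved
for s in range(seq_len):
    for h in range(n_heads):
        kv_h = h // (n_heads // n_kv_heads)
        qi = (s * n_heads + h) * head_dim
        q = Q[qi:qi + head_dim]
        z = []
        for j in range(seg[s], s + 1):
            ki = (j * n_kv_heads + kv_h) * head_dim
            z.append(scale * dot(q, K[ki:ki + head_dim]))
        m = max(z)
        LSE.append(m + math.log(sum(math.exp(x - m) for x in z)))

dQ = [0.0] * len(Q)
P_scratch = [0.0] * (seq_len * n_heads * seq_len)
dS_scratch = [0.0] * (seq_len * n_heads * seq_len)
for s in range(seq_len):
    for h in range(n_heads):
        short_row(s, h, Q, K, V, dO, LSE, seg, dQ, P_scratch, dS_scratch,
                  seq_len, n_heads, n_kv_heads, head_dim, scale)
```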

Flush — why?

attention_backward_short runs per (s, h) and produces P[s, j] and dS[s, j] for every j. The dK/dV update, however, is organized per (j, kv_h). So the same data has to be traversed in two different directions:

  • s-major: produced by attention_backward_short
  • j-major: consumed by attention_backward_short_dKdV

Solution: Temporary scratch mega-buffers:

P_scratch:  [seq_len, n_heads, seq_len]    f32 array
dS_scratch: [seq_len, n_heads, seq_len]    f32 array

P_scratch[s, h, j] = P[s, h, j]. At seq=512, n_heads=12: 512 × 12 × 512 × 4 B = 12 MB per scratch buffer, or 24 MB for both, on top of activation memory.
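The scratch shape can be checked with a few lines of Python. This sketch assumes the [seq_len, n_heads, seq_len] shape is laid out row-major. The helper names are hypothetical. It reproduces the 12 MiB / 24 MiB figures at seq=512 and the 96 MiB figure at seq=1024. It also shows why this s-major layout suits the producer but not the consumer. One (s, h) row is contiguous, while stepping s for a fixed (h, j) jumps n_heads × seq_len elements.

```python
def scratch_index(s, h, j, n_heads, seq_len):
    """Flat offset of P_scratch[s, h, j] in a row-major [seq_len, n_heads, seq_len] buffer."""
    return (s * n_heads + h) * seq_len + j


def scratch_total_bytes(seq_len, n_heads, n_buffers=2, bytes_per_elem=4):
    """P_scratch + dS_scratch together, in bytes."""
    return n_buffers * seq_len * n_heads * seq_len * bytes_per_elem


MiB = 2**20
# The producer (one (s, h) workgroup) writes a contiguous run of j values...
row_start = scratch_index(5, 3, 0, n_heads=12, seq_len=512)
# ...while the consumer (one (j, kv_h) workgroup) strides by n_heads * seq_len per s.
stride_s = (scratch_index(1, 0, 0, n_heads=12, seq_len=512)
            - scratch_index(0, 0, 0, n_heads=12, seq_len=512))

print(scratch_total_bytes(512, 12) / MiB)   # 24.0: the production figure
print(scratch_total_bytes(1024, 12) / MiB)  # 96.0: why seq > 1024 streams instead
```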


attention_backward_short_dKdV (Companion Path)

One workgroup per (j, kv_h). For a given KV position and KV head, gather all relevant query positions.

Workgroup memory

wgsl
var<workgroup> sh_dk_partial: array<f32, 256>;
var<workgroup> sh_dv_partial: array<f32, 256>;

Only 2 KB. Very lightweight.

Algorithm

for each thread d in [0, head_dim):
  dk_partial = 0
  dv_partial = 0
  for s in [j, seq):                            # causal: j ≤ s
    for h_q in q_group(kv_h):                   # all query heads sharing this KV head
      p  = P_scratch[s, h_q, j]
      ds = dS_scratch[s, h_q, j]
      dk_partial += ds · Q[s, h_q, d]
      dv_partial += p  · dO[s, h_q, d]
  dK[j, kv_h, d] += dk_partial    (owned)
  dV[j, kv_h, d] += dv_partial    (owned)

No atomics: each (j, kv_h, d) slot is touched by exactly one thread of one workgroup (ownership pattern).

For GQA: q_group(kv_h) is the set of q_per_group = n_heads / n_kv_heads query heads that share KV head kv_h, and the loop iterates over all of them. In practice n_heads=12, n_kv=4 → q_per_group=3 → 3 query-head iterations per s.

Why no atomics?

The 3-tuple (j, kv_h, d) has unique workgroup ownership. The same (j, kv_h, d) is never written by two workgroups → no race → atomic unnecessary.

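Note the += semantics. It is a plain read-modify-write, which is safe because within this kernel each slot is written exactly once, by its owning workgroup. On the host side, gradient accumulation over multiple micro-steps zero-initializes dK/dV before adding, so the += starts from zero and the result is correct.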
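The ownership argument can be checked on the CPU. Summing per (j, kv_h) in the companion kernel's gather order gives the same dK/dV as the old scatter order, where every (s, h_q) adds into shared rows. This is a hedged sketch. The function names and the contiguous grouping q_group(kv_h) = [kv_h·q_per_group, (kv_h+1)·q_per_group) are assumptions. The scratch values are random, because the equivalence does not depend on them being real softmax outputs.

```python
import random


def dkdv_gather(j, kv_h, Q, dO, P_s, dS_s, seq_len, n_heads, n_kv_heads, head_dim):
    """One (j, kv_h) workgroup of attention_backward_short_dKdV, run serially.

    Each output element d is produced by exactly one 'thread', so no atomics."""
    q_per_group = n_heads // n_kv_heads
    group = range(kv_h * q_per_group, (kv_h + 1) * q_per_group)  # assumed q_group(kv_h)
    dk, dv = [0.0] * head_dim, [0.0] * head_dim
    for d in range(head_dim):
        for s in range(j, seq_len):                          # causal: j <= s
            for h_q in group:
                idx = (s * n_heads + h_q) * seq_len + j      # scratch[s, h_q, j]
                row = (s * n_heads + h_q) * head_dim + d     # Q / dO[s, h_q, d]
                dk[d] += dS_s[idx] * Q[row]
                dv[d] += P_s[idx] * dO[row]
    return dk, dv


def dkdv_scatter(Q, dO, P_s, dS_s, seq_len, n_heads, n_kv_heads, head_dim):
    """The old order: every (s, h_q) adds into shared dK/dV rows.

    On the GPU these += are the contended CAS atomics; serially they are plain adds."""
    q_per_group = n_heads // n_kv_heads
    dK = [[[0.0] * head_dim for _ in range(n_kv_heads)] for _ in range(seq_len)]
    dV = [[[0.0] * head_dim for _ in range(n_kv_heads)] for _ in range(seq_len)]
    for s in range(seq_len):
        for h_q in range(n_heads):
            kv_h = h_q // q_per_group
            for j in range(s + 1):
                idx = (s * n_heads + h_q) * seq_len + j
                for d in range(head_dim):
                    row = (s * n_heads + h_q) * head_dim + d
                    dK[j][kv_h][d] += dS_s[idx] * Q[row]
                    dV[j][kv_h][d] += P_s[idx] * dO[row]
    return dK, dV


rng = random.Random(2)
seq_len, n_heads, n_kv_heads, head_dim = 5, 12, 4, 3
Q = [rng.uniform(-1, 1) for _ in range(seq_len * n_heads * head_dim)]
dO = [rng.uniform(-1, 1) for _ in range(seq_len * n_heads * head_dim)]
P_s = [rng.random() for _ in range(seq_len * n_heads * seq_len)]
dS_s = [rng.uniform(-1, 1) for _ in range(seq_len * n_heads * seq_len)]

dK_ref, dV_ref = dkdv_scatter(Q, dO, P_s, dS_s, seq_len, n_heads, n_kv_heads, head_dim)
max_diff = 0.0
for j in range(seq_len):
    for kv_h in range(n_kv_heads):
        dk, dv = dkdv_gather(j, kv_h, Q, dO, P_s, dS_s, seq_len, n_heads, n_kv_heads, head_dim)
        for d in range(head_dim):
            max_diff = max(max_diff, abs(dk[d] - dK_ref[j][kv_h][d]),
                           abs(dv[d] - dV_ref[j][kv_h][d]))
print(max_diff)
```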


attention_backward (Streaming Fallback)

For seq_len > 1024. Here P_scratch and dS_scratch become too large. Already at seq=1024 each takes 1024² × 4 B = 4 MB per head, i.e. 12 heads × 4 MB × 2 buffers = 96 MB just for scratch, and the size grows quadratically with seq_len.

Streaming approach: compute P and dP block by block only, recomputing them rather than storing them (saving memory). Two passes:

  • Pass A: compute D (over all K-blocks)
  • Pass B: with D known, accumulate dQ and write dK/dV with CAS atomics

Each K-block is 256 wide. The per-block workgroup buffers sh_p[256] and sh_dp[256] take only 2 KB.
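The two passes can be sketched in Python for a single query row. The K-block width is shrunk from 256 to 4 so one row spans several blocks. Names are hypothetical, and only dQ is accumulated; a comment marks where the kernel would CAS-add into dK/dV. Pass A keeps nothing but the running D, and Pass B recomputes P and dP per block once D is known.

```python
import math
import random

BLOCK = 4  # the kernel uses 256-wide K-blocks; shrunk so the demo spans several


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def block_p_dp(q, do, K, V, lse, scale, start, stop):
    """Recompute P and dP for keys [start, stop); nothing outlives the block."""
    P = [math.exp(scale * dot(q, K[j]) - lse) for j in range(start, stop)]
    dP = [dot(do, V[j]) for j in range(start, stop)]
    return P, dP


def streaming_row(s, q, do, K, V, lse, scale):
    """Two-pass backward for query row s with O(BLOCK) working memory."""
    n_keys = s + 1                                   # causal: keys 0..s
    blocks = [(b, min(b + BLOCK, n_keys)) for b in range(0, n_keys, BLOCK)]
    D = 0.0
    for start, stop in blocks:                       # Pass A: D only
        P, dP = block_p_dp(q, do, K, V, lse, scale, start, stop)
        D += dot(P, dP)
    dq, dS_row = [0.0] * len(q), []
    for start, stop in blocks:                       # Pass B: recompute, now with D
        P, dP = block_p_dp(q, do, K, V, lse, scale, start, stop)
        for off, (p, dp) in enumerate(zip(P, dP)):
            ds = scale * p * (dp - D)
            dS_row.append(ds)
            for d in range(len(q)):
                dq[d] += ds * K[start + off][d]
            # the GPU kernel would also CAS-add ds * q into dK and p * do into dV here
    return D, dq, dS_row


rng = random.Random(3)
seq_len, head_dim, s = 11, 4, 10                     # 11 keys -> blocks of 4, 4, 3
scale = 1.0 / math.sqrt(head_dim)
K = [[rng.uniform(-1, 1) for _ in range(head_dim)] for _ in range(seq_len)]
V = [[rng.uniform(-1, 1) for _ in range(head_dim)] for _ in range(seq_len)]
q = [rng.uniform(-1, 1) for _ in range(head_dim)]
do = [rng.uniform(-1, 1) for _ in range(head_dim)]
z = [scale * dot(q, K[j]) for j in range(s + 1)]
m = max(z)
lse = m + math.log(sum(math.exp(x - m) for x in z))  # as forward would have saved it

D, dq, dS_row = streaming_row(s, q, do, K, V, lse, scale)
D_direct = sum(math.exp(z[j] - lse) * dot(do, V[j]) for j in range(s + 1))
```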

dK/dV use CAS atomics because there is a genuine race: many s positions (and, via GQA, several query heads) write the same (j, kv_h) slot. Estimated contention is ~1024-way per j slot.
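WGSL has no floating-point atomics (only atomic<i32> and atomic<u32>), which fits dK/dV being bound as array<atomic<u32>>. The usual way to emulate a float add is a compare-and-swap loop on the bit pattern using atomicCompareExchangeWeak. The Python sketch below models that loop under stated assumptions. `AtomicU32Buffer` is a hypothetical single-threaded stand-in. The `before_cas` hook injects one competing write so the retry path is exercised. The text does not show whether the kernel's loop looks exactly like this.

```python
import struct


def f32_to_bits(x):
    """bitcast<u32>(f32(x))."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def bits_to_f32(b):
    """bitcast<f32>(u32 b)."""
    return struct.unpack("<f", struct.pack("<I", b))[0]


class AtomicU32Buffer:
    """Hypothetical single-threaded stand-in for array<atomic<u32>>."""

    def __init__(self, n):
        self.words = [0] * n                 # bit pattern 0 is 0.0f

    def load(self, i):
        return self.words[i]

    def compare_exchange_weak(self, i, cmp, value):
        """Mirrors atomicCompareExchangeWeak's (old_value, exchanged) result.
        The real 'weak' op may also fail spuriously; the loop below tolerates that."""
        old = self.words[i]
        if old == cmp:
            self.words[i] = value
            return old, True
        return old, False


def atomic_add_f32(buf, i, value, before_cas=None):
    """Float add via a CAS loop on the u32 bit pattern; returns the retry count."""
    retries = 0
    old = buf.load(i)
    while True:
        new = f32_to_bits(bits_to_f32(old) + value)
        if before_cas is not None:           # simulate one competing writer
            hook, before_cas = before_cas, None
            hook()
        seen, exchanged = buf.compare_exchange_weak(i, old, new)
        if exchanged:
            return retries
        old = seen                           # lost the race: retry from the winner's value
        retries += 1


buf = AtomicU32Buffer(1)
atomic_add_f32(buf, 0, 1.5)                                  # uncontended: 0 retries
retries = atomic_add_f32(buf, 0, 2.0,
                         before_cas=lambda: atomic_add_f32(buf, 0, 0.25))
print(retries, bits_to_f32(buf.load(0)))                     # 1 3.75
```

Every retry repeats the load, add and CAS. This is why many writers hitting one (j, kv_h) slot add up to real time.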


RoPE Backward

wgsl
// rope_q_backward: backward through the Q rotation
// dQ' (rotated) → dQ (un-rotated)
// Forward:  q' = R(angle) · q
// Backward: dq = R(-angle) · dq' = R(angle)^T · dq'
// c = cos(angle), sn = sin(angle) for this (position, pair)

let dq0 = dQ[i0];
let dq1 = dQ[i1];
dQ[i0] = dq0 * c + dq1 * sn;     // note: +sin here, where forward has -sin
dQ[i1] = -dq0 * sn + dq1 * c;

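The RoPE rotation matrix is orthogonal: R^T = R^-1. So backward is the forward rotation with the angle negated:

$$R(\theta) = \begin{pmatrix} \cos\theta & -\sin\theta \\ \sin\theta & \cos\theta \end{pmatrix}, \qquad R(-\theta) = \begin{pmatrix} \cos\theta & \sin\theta \\ -\sin\theta & \cos\theta \end{pmatrix} = R(\theta)^\top$$

The sin terms flip sign and the cos terms stay the same: the classic 2D rotation inverse.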
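The checklist notes that the RoPE inverse has no identity test. A minimal Python version of one could look like this. Per rotation pair, it checks that backward undoes forward (R^T R = I). It also checks the adjoint identity ⟨R x, g⟩ = ⟨x, R^T g⟩, which is exactly what makes R^T the correct backward map. The frequency schedule in the hypothetical `rope_angle` (angle = (pos + pos_offset) · rope_base^(-2i/head_dim)) is the common RoPE convention and an assumption here. How the kernel pairs i0/i1 is not shown, and it does not affect these checks.

```python
import math
import random


def rope_angle(pos, i, head_dim, rope_base=10000.0, pos_offset=0):
    """Angle for rotation pair i at position pos (the usual RoPE schedule; assumed)."""
    return (pos + pos_offset) * rope_base ** (-2.0 * i / head_dim)


def rope_forward(x0, x1, c, sn):
    """q' = R(angle) q."""
    return x0 * c - x1 * sn, x0 * sn + x1 * c


def rope_backward(g0, g1, c, sn):
    """dq = R(angle)^T dq', mirroring the WGSL snippet (+sin in the first row)."""
    return g0 * c + g1 * sn, -g0 * sn + g1 * c


rng = random.Random(4)
head_dim, seq_len = 8, 16
max_roundtrip_err = max_adjoint_err = 0.0
for pos in range(seq_len):
    for i in range(head_dim // 2):
        theta = rope_angle(pos, i, head_dim)
        c, sn = math.cos(theta), math.sin(theta)
        x0, x1, g0, g1 = (rng.uniform(-1, 1) for _ in range(4))
        # 1) R^T R = I: un-rotating a rotated pair gives the pair back.
        r0, r1 = rope_backward(*rope_forward(x0, x1, c, sn), c, sn)
        max_roundtrip_err = max(max_roundtrip_err, abs(r0 - x0), abs(r1 - x1))
        # 2) Adjoint identity <R x, g> == <x, R^T g>: the property backward relies on.
        y0, y1 = rope_forward(x0, x1, c, sn)
        b0, b1 = rope_backward(g0, g1, c, sn)
        max_adjoint_err = max(max_adjoint_err,
                              abs((y0 * g0 + y1 * g1) - (x0 * b0 + x1 * b1)))
```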


Bind Group ABI

attention_backward_short (11 bindings, production path)

Binding  Type                  Detail
0        storage, read         Q: array<f32>
1        storage, read         K: array<f32>
2        storage, read         V: array<f32>
3        storage, read         dO: array<f32>
4        storage, read         LSE: array<f32> (saved from forward)
5        storage, read_write   dQ: array<f32> (owned write)
6        storage, read_write   P_scratch: array<f32>
7        storage, read_write   dS_scratch: array<f32>
8        uniform               dims: vec4<u32>(seq_len, n_heads, n_kv_heads, head_dim)
9        uniform               params: vec4<f32>(scale, _, _, _)
10       storage, read         seg: array<u32>, document start index for each query (cross-doc mask)

dK and dV are NOT WRITTEN IN THIS KERNEL — they are computed by the companion attention_backward_short_dKdV kernel.

attention_backward_short_dKdV (7 bindings, companion path)

Binding  Type                  Detail
0        storage, read         Q: array<f32>
1        storage, read         dO: array<f32>
2        storage, read         P_scratch: array<f32>
3        storage, read         dS_scratch: array<f32>
4        storage, read_write   dK: array<f32> (owned write, no atomic)
5        storage, read_write   dV: array<f32> (owned write, no atomic)
6        uniform               dims: vec4<u32>(seq_len, n_heads, n_kv_heads, head_dim)

attention_backward (11 bindings, streaming fallback)

Binding  Type                  Detail
0        storage, read         Q: array<f32>
1        storage, read         K: array<f32>
2        storage, read         V: array<f32>
3        storage, read         dO: array<f32>
4        storage, read         LSE: array<f32> (saved from forward)
5        storage, read_write   dQ: array<f32> (owned write)
6        storage, read_write   dK: array<atomic<u32>> (CAS atomic scatter)
7        storage, read_write   dV: array<atomic<u32>> (CAS atomic scatter)
8        uniform               dims: vec4<u32>(seq_len, n_heads, n_kv_heads, head_dim)
9        uniform               params: vec4<f32>(scale, _, _, _)
10       storage, read         seg: array<u32>, document start index for each query (cross-doc mask)

rope_q_backward / rope_k_backward (3 bindings)

Binding  Type                  Detail
0        storage, read_write   dQ or dK: array<f32>
1        uniform               dims: vec4<u32>(seq_len, n_heads (Q) or n_kv_heads (K), head_dim, pos_offset)
2        uniform               params: vec4<f32>(rope_base, _, _, _)

Dispatch Shape

attention_backward_short

grid: (seq_len, n_heads, 1)
WG=256

Same as attention_forward.

attention_backward_short_dKdV

grid: (seq_len, n_kv_heads, 1)        ← n_kv_heads, not n_heads!
WG=256

n_kv_heads is smaller (n_heads=12, n_kv=4), so there are 3x fewer workgroups. Each one, however, does q_per_group times more work, so the total work is roughly the same.
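The "roughly the same" claim can be checked by counting (s, h, j) visits for both dispatch shapes. This hedged sketch counts visits only, not the cost per visit, which differs between the two kernels. The function names are hypothetical. The count also makes visible how unevenly the dKdV work is spread across j.

```python
def short_visits(seq_len, n_heads):
    """(s, h, j) visits by attention_backward_short: one WG per (s, h), j in [0, s]."""
    return n_heads * sum(s + 1 for s in range(seq_len))


def dkdv_visits(seq_len, n_heads, n_kv_heads):
    """(s, h_q, j) visits by the dKdV kernel: one WG per (j, kv_h), s in [j, seq)."""
    q_per_group = n_heads // n_kv_heads
    per_wg = [(seq_len - j) * q_per_group for j in range(seq_len)]
    return n_kv_heads * sum(per_wg), max(per_wg), min(per_wg)


total_short = short_visits(512, 12)
total_dkdv, busiest_wg, lightest_wg = dkdv_visits(512, 12, 4)
print(total_short, total_dkdv)      # 1575936 1575936: same total number of visits
print(busiest_wg, lightest_wg)      # 1536 3: j=0 does 512x the work of j=511
```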


Performance — from a profile snapshot

kernel                          dispatches  time        share of step
attention_backward_short        ×12         63.18 ms    12.2%
attention_backward_short_dKdV   ×12         39.19 ms     7.6%
combined                                    102.37 ms   19.8%

That is 32% of backward, making this the most expensive kernel group in backward (followed by matmul_at and matmul_t).


WGSL-Specific Notes

1. Tint dead-code elimination

All kernels share the same list of module-scope var<workgroup> declarations, but Tint only allocates the ones each entry point actually references:

  • attention_backward_short → sh_q, sh_do, sh_p_full, sh_dp_full, sh_partial (~11 KB)
  • attention_backward_short_dKdV → sh_dk_partial, sh_dv_partial (~2 KB)
  • attention_backward → sh_q, sh_do, sh_p, sh_dp, sh_partial, sh_dq_acc, sh_D (~6 KB)

Same var<workgroup> declaration list, different footprint (and therefore occupancy) per kernel.
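The footprints above follow from the declared element counts. This sketch tallies them per entry point. The element counts for attention_backward's sh_partial, sh_dq_acc and sh_D are assumptions, chosen to be consistent with the ~6 KB figure, and padding is ignored. For reference, WebGPU's default maxComputeWorkgroupStorageSize limit is 16384 bytes, and all three kernels stay under it.

```python
F32_BYTES = 4
WEBGPU_DEFAULT_WG_STORAGE = 16384   # WebGPU's default maxComputeWorkgroupStorageSize

# Element counts per entry point. For attention_backward, the sizes of sh_partial,
# sh_dq_acc and sh_D are assumptions chosen to match the ~6 KB quoted above.
WORKGROUP_VARS = {
    "attention_backward_short": {
        "sh_q": 256, "sh_do": 256, "sh_p_full": 1024, "sh_dp_full": 1024, "sh_partial": 256,
    },
    "attention_backward_short_dKdV": {"sh_dk_partial": 256, "sh_dv_partial": 256},
    "attention_backward": {
        "sh_q": 256, "sh_do": 256, "sh_p": 256, "sh_dp": 256,
        "sh_partial": 256, "sh_dq_acc": 256, "sh_D": 1,
    },
}


def workgroup_bytes(kernel):
    """f32 var<workgroup> bytes statically referenced by one kernel (no padding)."""
    return F32_BYTES * sum(WORKGROUP_VARS[kernel].values())


for name in WORKGROUP_VARS:
    size = workgroup_bytes(name)
    print(f"{name:<30} {size:>6} B  ({size / 1024:.1f} KiB)  "
          f"under WebGPU default: {size <= WEBGPU_DEFAULT_WG_STORAGE}")
```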

2. Atomic split gains

The split design removes the CAS atomics entirely: the companion kernel writes dK/dV through ownership instead. On Apple GPUs the CAS retry overhead grows sub-linearly with contention but is still significant. The split saves ~17.8 ms of total step time.

3. Scratch buffer host-side allocation

P_scratch and dS_scratch are allocated host-side and live for the whole training step. Together they cost ~24 MB and are reused every step.


Code Review

Finding 1: P, dP recompute (streaming) — duplicate work

Risk          Description
🟢 trade-off  The streaming kernel computes P and dP twice (Pass A + Pass B). For seq>1024 this saves memory: only small per-block workgroup buffers instead of global scratch. Practical impact is ~5% backward overhead. The trade-off is correct.

Finding 2: split path scratch overhead

Risk          Description
🟡 minor      We reserve 24 MB of scratch. In practice this is activation-tier memory and not a concern. It is the right choice for seq=512.

Quick Checklist

Test scenario                                      Status
Is forward LSE reused in backward?                 ✅ no recompute
Is dK/dV race-free (split path)?                   ✅ ownership pattern
Is attention_backward (streaming) still correct?   ⚠ seq>1024 path lightly tested
Is RoPE invert correct?                            ⚠ no identity test
Is GQA q_per_group iteration correct?              ✅ kernel logic

Next

11_backward_linear.md — the backward matmuls: matmul_t, matmul_t_acc, matmul_at, matmul_at_acc. How the weight gradients are computed.

WGSL kernel studies · an LLM from scratch on WebGPU. Built in Istanbul by Uğur Toprakdeviren.