attention_backward* — Backward Attention (3 Variants)
File: 10_backward_attention.wgsl

Pipeline step: The most expensive and most complex part of backward. Around 20% of total step time.
5 kernels:
- attention_backward: streaming, for seq>1024 (CAS atomics on dK/dV)
- attention_backward_short: split-kernel Phase 0-4, dQ owned, P/dS flushed to scratch
- attention_backward_short_dKdV: split-kernel Phase 5, dK/dV owned, no atomics
- rope_q_backward: Q rotation invert
- rope_k_backward: K rotation invert
Wait, what is this?
Think back to attention in the forward pass: each token looks at every earlier token and asks "how much should I pay attention to you?", then blends their values by those weights to produce its output. Backward asks the opposite question: if the output came out wrong, how do we hand the blame and credit back to everyone who contributed? It's like a group project that went sideways — you rewind, figure out who pulled what weight, and update each person's grade accordingly.
Here's the catch. In the forward pass everything was organized by the reader (the query) — each token summed up its own row. But to push the gradient back you also need to sum the same data organized by who got read (the keys and values). It's like having to read the same ledger once by rows and once by columns. Two different directions. If you try to do both in one pass, you either park the intermediate data somewhere (a scratch buffer) and read it back the other way, or you fall into a traffic jam where everyone tries to write the same row at once.
That's exactly why there are 3 variants. The old single-kernel approach was like everyone trying to deposit into one shared safe at the same time — queues, locks, waiting (atomic writes). The new design splits the job in two: one kernel computes the reader-side gradient and jots the intermediate data off to the side, and a second kernel reads those notes back the other direction to sum the read-side gradient. The third variant is the fallback that drops into a streaming mode to keep memory in check when the sequence gets long (and those side notes would balloon).
The neat part is that all three paths solve the same chain rule mathematically — they just strike different memory/speed trade-offs. Below we'll walk through which pain point gave rise to which variant, and then how each one works step by step.
Why 3 Variants?
Previous version: single kernel, CAS atomics
attention_backward (streaming):
Phase 0-4: recompute P, dP, accumulate dQ
Phase 5: dV += P·dO via CAS atomic
dK += dS·Q via CAS atomic

Problem: With GQA, n_q_per_kv = 3-4, so 3-4 query heads were doing CAS atomics on the same KV head. At seq=512, for early positions that meant ~1024 parallel atomic writes. Estimated contention cost: 17.8 ms per step.
New: split kernel design
attention_backward_short:
Phase 0-4: compute P[j], dS[j], flush to scratch buffer (no atomic)
Phase 5: compute dQ_acc[d], owned write (no atomic)
attention_backward_short_dKdV:
Phase 5: per (j, kv_h) workgroup
loop s ∈ [j, seq) × h_q ∈ q_group(kv_h)
read P[s, j], dS[s, j] from scratch
dV[j, d] += P · dO (owned write, no atomic)
dK[j, d] += dS · Q (owned write, no atomic)

Result: ~17.8 ms / step gained.
The streaming kernel still exists as a fallback: for seq_len > 1024 the scratch buffers become too large (already at seq=1024, 1024² × 4 B = 4 MB per head per scratch buffer). Production runs seq=512, where the split path is used.
Mathematical Definition
Forward (in brief): O = softmax(Q·K^T · scale)·V
Backward — chain rule (very classic formula):
P = softmax(Q·K^T · scale) (forward already computed, reusable via LSE)
D[s] = Σ_j P[s,j] · dP[s,j] per-row scalar
dP[s,j] = Σ_d dO[s,d] · V[j,d] via dot product
dS[s,j] = scale · P[s,j] · (dP[s,j] - D[s]) gradient of scores
dQ[s,d] = Σ_j dS[s,j] · K[j,d]
dK[j,d] = Σ_s dS[s,j] · Q[s,d] cross-row scatter
dV[j,d] = Σ_s P[s,j] · dO[s,d] cross-row scatter

attention_backward_short (Production Path)
Used when seq_len ≤ 1024. Single-pass, all of P[seq] and dS[seq] held in workgroup memory.
Workgroup memory
sh_q: array<f32, 256> // Q row cache
sh_do: array<f32, 256> // dO row cache
sh_p_full: array<f32, 1024> // P[s, *] for all j (causal up to s)
sh_dp_full: array<f32, 1024> // dP[s, *]
sh_partial: array<f32, 256> // dQ accumulator reduction

Total: ~11 KB. Comfortably within the Apple GPU 32 KB limit.
Algorithm
1. cache Q[s, h, *] and dO[s, h, *]
2. for each j in [0, s]:
p[j] = exp(scale·Q·K[j] - LSE[s,h])
dp[j] = Σ_d dO[d] · V[j,d]
write to sh_p_full[j], sh_dp_full[j]
3. D = Σ_j P[j] · dP[j]
4. for each j:
dS[j] = scale · P[j] · (dP[j] - D)
5. dQ[d] = Σ_j dS[j] · K[j,d]
write dQ[s,h,d] = previous + dQ[d] (owned, no atomic)
6. flush P[j], dS[j] to scratch buffers (P_scratch, dS_scratch)

dK[j,d] and dV[j,d] are NOT WRITTEN in this kernel. The companion kernel handles them.
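Steps 1-5 above can be sketched on the CPU for a single (s, h) pair. This is a minimal reference in plain Python, not the kernel itself: the function name `attention_backward_row` is hypothetical, the cross-document mask (`seg`) is ignored, and LSE is recomputed locally only so the sketch is self-contained (the kernel reads it from the forward pass).

```python
import math

def attention_backward_row(q, K, V, dO, scale):
    """Reference for one (s, h): q and dO are rows, K and V hold positions 0..s."""
    head_dim = len(q)
    scores = [scale * sum(qd * kd for qd, kd in zip(q, k)) for k in K]
    # The kernel loads LSE[s, h] from forward; recomputed here for self-containment.
    m = max(scores)
    lse = m + math.log(sum(math.exp(x - m) for x in scores))
    # Step 2: P[j] = exp(score - LSE), dP[j] = dO . V[j]
    P = [math.exp(x - lse) for x in scores]
    dP = [sum(g * vd for g, vd in zip(dO, v)) for v in V]
    # Step 3: per-row scalar D
    D = sum(p * dp for p, dp in zip(P, dP))
    # Step 4: gradient of the scores
    dS = [scale * p * (dp - D) for p, dp in zip(P, dP)]
    # Step 5: dQ[d] = sum_j dS[j] * K[j][d]
    dQ = [sum(dS[j] * K[j][d] for j in range(len(K))) for d in range(head_dim)]
    # Step 6 in the kernel flushes P and dS to scratch; here they are simply returned.
    return P, dS, dQ
```

Two cheap sanity properties follow from the formulas: P sums to 1, and dS sums to 0 (because Σ_j P[j]·(dP[j] − D) = D − D).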
Flush — why?
attention_backward_short runs per (s, h) and produces P[s,j] and dS[s,j] for all j, but the dK/dV update needs a per (j, kv_h) organization. So the same data must flow in two different directions:
- s-major: produced by attention_backward_short
- j-major: consumed by attention_backward_short_dKdV
Solution: Temporary scratch mega-buffers:
P_scratch: [seq_len, n_heads, seq_len] f32 array
dS_scratch: [seq_len, n_heads, seq_len] f32 array

P_scratch[s, h, j] = P[s, h, j]. With seq=512 and n_heads=12: 512 × 12 × 512 × 4 B = 12 MB per scratch buffer, so the two buffers total 24 MB, added on top of activation memory.
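The sizing arithmetic and the [seq_len, n_heads, seq_len] layout can be checked with a few lines of Python. The helper names are hypothetical, and the flat index assumes a plain row-major layout, which the source does not state explicitly.

```python
MIB = 1024 * 1024

def scratch_bytes(seq_len, n_heads, bytes_per_elem=4):
    """Size of one scratch buffer of shape [seq_len, n_heads, seq_len] in f32."""
    return seq_len * n_heads * seq_len * bytes_per_elem

def scratch_index(s, h, j, seq_len, n_heads):
    """Flat element index of P_scratch[s, h, j], assuming row-major order."""
    return (s * n_heads + h) * seq_len + j

print(scratch_bytes(512, 12) // MIB)       # 12  (one buffer at seq=512)
print(2 * scratch_bytes(1024, 12) // MIB)  # 96  (both buffers at seq=1024)
```

The quadratic growth in seq_len is what pushes longer sequences onto the streaming path.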
attention_backward_short_dKdV (Companion Path)
One workgroup per (j, kv_h). For a given KV position and KV head, gather all relevant query positions.
Workgroup memory
sh_dk_partial: array<f32, 256>
sh_dv_partial: array<f32, 256>

Only 2 KB. Very lightweight.
Algorithm
for thread (d in head_dim):
dk_partial = 0
dv_partial = 0
for s in [j, seq): # causal: j ≤ s
for h_q in q_group(kv_h): # all query heads using this kv head
P = P_scratch[s, h_q, j]
dS = dS_scratch[s, h_q, j]
Q = Q[s, h_q, d]
dO = dO[s, h_q, d]
dk_partial += dS · Q
dv_partial += P · dO
dK[j, kv_h, d] += dk_partial (owned)
dV[j, kv_h, d] += dv_partial (owned)

No atomic. Each (j, kv_h, d) is touched by a single workgroup: the ownership pattern.
For GQA: q_group(kv_h) is the set of q_per_group query heads that share this KV head, and the inner loop iterates over all of them. In practice: n_heads=12, n_kv=4 → q_per_group=3 → 3 query-head iterations.
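The gather performed by one (j, kv_h) workgroup can be sketched in plain Python. Two assumptions are made for illustration: the name `dkdv_slot` is hypothetical, and `q_group` maps a KV head to a contiguous range of query heads, which is a common GQA convention but is not confirmed by the source.

```python
def q_group(kv_h, q_per_group):
    # Assumption: query heads that share a KV head are numbered contiguously.
    return range(kv_h * q_per_group, (kv_h + 1) * q_per_group)

def dkdv_slot(j, kv_h, P_scratch, dS_scratch, Q, dO, seq_len, q_per_group, head_dim):
    """Compute the dK and dV rows owned by workgroup (j, kv_h).

    P_scratch and dS_scratch are indexed [s][h][j]; Q and dO are indexed [s][h][d].
    """
    dk = [0.0] * head_dim
    dv = [0.0] * head_dim
    for s in range(j, seq_len):                 # causal: only queries at or after j
        for h_q in q_group(kv_h, q_per_group):  # every query head using this KV head
            p = P_scratch[s][h_q][j]
            ds = dS_scratch[s][h_q][j]
            for d in range(head_dim):           # one GPU thread per d in the kernel
                dk[d] += ds * Q[s][h_q][d]
                dv[d] += p * dO[s][h_q][d]
    return dk, dv
```

Because the output rows depend only on (j, kv_h), no other workgroup ever needs to touch them, which is the property that removes the atomics.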
Why no atomics?
The 3-tuple (j, kv_h, d) has unique workgroup ownership. The same (j, kv_h, d) is never written by two workgroups → no race → atomic unnecessary.
Note the += semantics: within this kernel each slot is written only once, by its owning workgroup. The += exists for host-side gradient accumulation across multiple micro-steps; the host zero-initializes dK/dV before accumulation starts, so the result is correct.
attention_backward (Streaming Fallback)
Used when seq_len > 1024, where P_scratch and dS_scratch become too large: already at seq=1024, 1024² × 4 B = 4 MB per head per scratch buffer, so 12 heads × 4 MB × 2 = 96 MB just for scratch.
Streaming approach: compute P and dP one K-block at a time and recompute them instead of storing them (this saves storage). Two passes:
- Pass A: compute D (over all K-blocks)
- Pass B: once D is known, accumulate dQ and write dK/dV via CAS atomics
Each K-block is 256 wide. Workgroup memory: sh_p[256], sh_dp[256] per block — only 2KB.
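The two-pass structure can be illustrated for one (s, h) pair in plain Python. This is a sketch under assumptions: `streaming_dq` is a hypothetical name, only the dQ side is computed, and the dK/dV scatter is reduced to a comment.

```python
import math

def streaming_dq(q, K, V, dO, scale, lse, block=256):
    """Two-pass dQ for one (s, h); only one K-block of P and dP is live at a time."""
    n = len(K)

    def block_p_dp(start):
        end = min(start + block, n)
        p = [math.exp(scale * sum(a * b for a, b in zip(q, K[j])) - lse)
             for j in range(start, end)]
        dp = [sum(g * v for g, v in zip(dO, V[j])) for j in range(start, end)]
        return p, dp

    # Pass A: accumulate the row scalar D over all K-blocks.
    D = 0.0
    for start in range(0, n, block):
        p, dp = block_p_dp(start)
        D += sum(a * b for a, b in zip(p, dp))

    # Pass B: recompute each block, form dS, accumulate dQ.
    dQ = [0.0] * len(q)
    for start in range(0, n, block):
        p, dp = block_p_dp(start)
        for i, (pj, dpj) in enumerate(zip(p, dp)):
            ds = scale * pj * (dpj - D)
            for d in range(len(q)):
                dQ[d] += ds * K[start + i][d]
            # The kernel would also scatter ds * q into dK[j] and pj * dO into dV[j] here.
    return dQ
```

The block size changes only how much is held at once, not the result, so a small block and a single large block should agree up to floating-point rounding.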
dK/dV use CAS atomics because there is a race: multiple s positions, and multiple query heads mapped to the same kv_h via GQA, all write the same j slot. Estimated contention is ~1024-way per j slot.
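A CAS-based float add works by reinterpreting the f32 as its u32 bit pattern and retrying a compare-exchange until no other writer intervened. The Python simulation below illustrates the retry loop only; `AtomicU32` and `atomic_add_f32` are hypothetical stand-ins (a lock models hardware atomicity), not the kernel's code. In WGSL the corresponding builtin is `atomicCompareExchangeWeak`, which returns the old value together with an `exchanged` flag.

```python
import struct
import threading

class AtomicU32:
    """Stand-in for one atomic<u32> storage slot."""
    def __init__(self, bits=0):
        self._bits = bits
        self._lock = threading.Lock()  # models the hardware's atomicity

    def load(self):
        with self._lock:
            return self._bits

    def compare_exchange(self, expected, desired):
        """Store desired only if the slot still holds expected; return (old, exchanged)."""
        with self._lock:
            old = self._bits
            if old == expected:
                self._bits = desired
                return old, True
            return old, False

def f32_to_bits(x):
    return struct.unpack("<I", struct.pack("<f", x))[0]

def bits_to_f32(b):
    return struct.unpack("<f", struct.pack("<I", b))[0]

def atomic_add_f32(slot, value):
    """Add value to the f32 stored in slot; returns how many retries were needed."""
    retries = 0
    old = slot.load()
    while True:
        new = f32_to_bits(bits_to_f32(old) + value)
        old, ok = slot.compare_exchange(old, new)
        if ok:
            return retries
        retries += 1  # another writer got in first: recompute from the fresh value
```

Every retry repeats the read, add, and exchange, so the cost grows with the number of writers targeting the same slot.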
RoPE Backward
// rope_q_backward: backward through Q rotation
// dQ' (rotated) → dQ (un-rotated)
// Forward: q' = R(angle) · q
// Backward: dq = R(-angle) · dq' = R(angle)^T · dq'
let dq0 = dQ[i0];
let dq1 = dQ[i1];
dQ[i0] = dq0 * c + dq1 * sn;  // note: + sin instead of - sin
dQ[i1] = -dq0 * sn + dq1 * c;

The RoPE rotation matrix is orthogonal: R^T = R^-1. So backward is the forward rotation with the angle negated, i.e. the same rotation matrix evaluated at -θ:

$$
R(-\theta) =
\begin{bmatrix} \cos(-\theta) & -\sin(-\theta) \\ \sin(-\theta) & \cos(-\theta) \end{bmatrix}
=
\begin{bmatrix} \cos\theta & \sin\theta \\ -\sin\theta & \cos\theta \end{bmatrix}
$$

The sin sign flips, the cos stays the same. The classic 2D rotation inverse.
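The inverse relationship is easy to check numerically on a single (i0, i1) pair. The sketch below assumes the forward convention R(θ) = [[cos, −sin], [sin, cos]] implied by the matrix above; the function names are hypothetical.

```python
import math

def rope_forward(x0, x1, angle):
    """Forward rotation of one pair: x' = R(angle) . x"""
    c, sn = math.cos(angle), math.sin(angle)
    return x0 * c - x1 * sn, x0 * sn + x1 * c

def rope_backward(g0, g1, angle):
    """Backward through the rotation: g = R(angle)^T . g' = R(-angle) . g'"""
    c, sn = math.cos(angle), math.sin(angle)
    return g0 * c + g1 * sn, -g0 * sn + g1 * c
```

Two properties make a useful identity test: applying `rope_backward` to the output of `rope_forward` returns the original pair, and ⟨R·x, g⟩ equals ⟨x, Rᵀ·g⟩ for any x and g.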
Bind Group ABI
attention_backward_short (11 bindings, Production Path)
| Binding | Type | Detail |
|---|---|---|
| 0 | storage, read | Q: array<f32> |
| 1 | storage, read | K: array<f32> |
| 2 | storage, read | V: array<f32> |
| 3 | storage, read | dO: array<f32> |
| 4 | storage, read | LSE: array<f32> (saved from forward) |
| 5 | storage, read_write | dQ: array<f32> (owned write) |
| 6 | storage, read_write | P_scratch: array<f32> |
| 7 | storage, read_write | dS_scratch: array<f32> |
| 8 | uniform | dims: vec4<u32> — (seq_len, n_heads, n_kv_heads, head_dim) |
| 9 | uniform | params: vec4<f32> — (scale, _, _, _) |
| 10 | storage, read | seg: array<u32> — document start index for each query (cross-doc mask) |
dK and dV are NOT WRITTEN IN THIS KERNEL — they are computed by the companion attention_backward_short_dKdV kernel.
attention_backward_short_dKdV (7 bindings, Companion Path)
| Binding | Type | Detail |
|---|---|---|
| 0 | storage, read | Q: array<f32> |
| 1 | storage, read | dO: array<f32> |
| 2 | storage, read | P_scratch: array<f32> |
| 3 | storage, read | dS_scratch: array<f32> |
| 4 | storage, read_write | dK: array<f32> (owned write, no atomic) |
| 5 | storage, read_write | dV: array<f32> (owned write, no atomic) |
| 6 | uniform | dims: vec4<u32> — (seq_len, n_heads, n_kv_heads, head_dim) |
attention_backward (11 bindings, Streaming Fallback)
| Binding | Type | Detail |
|---|---|---|
| 0 | storage, read | Q: array<f32> |
| 1 | storage, read | K: array<f32> |
| 2 | storage, read | V: array<f32> |
| 3 | storage, read | dO: array<f32> |
| 4 | storage, read | LSE: array<f32> (saved from forward) |
| 5 | storage, read_write | dQ: array<f32> (owned write) |
| 6 | storage, read_write | dK: array<atomic<u32>> (CAS atomic scatter) |
| 7 | storage, read_write | dV: array<atomic<u32>> (CAS atomic scatter) |
| 8 | uniform | dims: vec4<u32> — (seq_len, n_heads, n_kv_heads, head_dim) |
| 9 | uniform | params: vec4<f32> — (scale, _, _, _) |
| 10 | storage, read | seg: array<u32> — document start index for each query (cross-doc mask) |
rope_q_backward / rope_k_backward (3 bindings)
| Binding | Type | Detail |
|---|---|---|
| 0 | storage, read_write | dQ or dK: array<f32> |
| 1 | uniform | dims: vec4<u32> — (seq_len, n_heads/n_kv_heads, head_dim, pos_offset) |
| 2 | uniform | params: vec4<f32> — (rope_base, _, _, _) |
Dispatch Shape
attention_backward_short
grid: (seq_len, n_heads, 1)
WG=256

Same as attention_forward.
attention_backward_short_dKdV
grid: (seq_len, n_kv_heads, 1) ← n_kv_heads, not n_heads!
WG=256

n_kv_heads is smaller (n_heads=12, n_kv=4) → 3x fewer WGs. But each WG does q_per_group times more work. Net work roughly equal.
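The balance between the two grids can be checked by counting inner-loop iterations. This sketch counts (s, j, head) visits only, ignores the cross-document mask, and says nothing about per-iteration cost; the helper names are hypothetical.

```python
def short_iterations(seq_len, n_heads):
    # Each (s, h) workgroup loops over j in [0, s].
    return sum(s + 1 for s in range(seq_len)) * n_heads

def dkdv_iterations(seq_len, n_kv_heads, q_per_group):
    # Each (j, kv_h) workgroup loops over s in [j, seq) for q_per_group query heads.
    return sum((seq_len - j) * q_per_group for j in range(seq_len)) * n_kv_heads

seq_len, n_heads, n_kv_heads = 512, 12, 4
q_per_group = n_heads // n_kv_heads

print(seq_len * n_heads, seq_len * n_kv_heads)  # 6144 2048 workgroups
print(short_iterations(seq_len, n_heads) ==
      dkdv_iterations(seq_len, n_kv_heads, q_per_group))  # True
```

Both kernels visit the same set of causal (s, j, query head) triples, just grouped differently, which is why fewer workgroups does not mean less total work.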
Performance — from a profile snapshot
attention_backward_short (×12) 63.18 ms 12.2% of step
attention_backward_short_dKdV (×12) 39.19 ms 7.6% of step
combined 102.37 ms 19.8% of step

Combined, that is 32% of backward time: the most expensive kernel group in backward. (Then come matmul_at, matmul_t.)
WGSL-Specific Notes
1. Tint dead-code elimination
Workgroup memory variables appear in all kernels, but Tint only allocates the referenced ones. That is:
- attention_backward_short → sh_q, sh_do, sh_p_full, sh_dp_full, sh_partial (~11 KB)
- attention_backward_short_dKdV → sh_dk_partial, sh_dv_partial (~2 KB)
- attention_backward → sh_q, sh_do, sh_p, sh_dp, sh_partial, sh_dq_acc, sh_D (~6 KB)
Same var<workgroup> declaration list, different occupancy per kernel.
2. Atomic split gains
The split design literally eliminates the CAS atomic (in the companion kernel). Apple GPU CAS retry overhead is sub-linear but meaningful. ~17.8ms total step time savings.
3. Scratch buffer host-side allocation
P_scratch and dS_scratch are allocated host-side, with a lifetime spanning the training step. Memory cost is ~24 MB, and the same buffers are reused each step.
Code Review
Finding 1: P, dP recompute (streaming) — duplicate work
| Risk | Description |
|---|---|
| 🟢 trade-off | The streaming kernel computes P, dP twice (Pass A + Pass B). For seq>1024 this saves memory (compact scratch). Practical impact ~5% backward overhead. The trade-off is correct. |
Finding 2: split path scratch overhead
| Risk | Description |
|---|---|
| 🟡 minor | We make room for 24 MB of scratch. In practice this is activation-tier memory; not a concern. The right choice for seq=512. |
Quick Checklist
| Test Scenario | Status |
|---|---|
| Is forward LSE reused in backward? | ✅ no recompute |
| Is dK/dV race-free (split path)? | ✅ ownership pattern |
| Is attention_backward streaming still correct? | ⚠ seq>1024 path lightly tested |
| Is RoPE invert correct? | ⚠ no identity test |
| Is GQA q_per_group iteration correct? | ✅ kernel logic |
Next
11_backward_linear.md — the backward matmuls: matmul_t, matmul_t_acc, matmul_at, matmul_at_acc. How the weight gradients are computed.