attention_backward* — Backward Attention (3 Varyant)
Dosya: 10_backward_attention.wgsl Pipeline adımı: Backward'ın en pahalı + en karmaşık kısmı. Toplam step time'ın ~%20'si.
5 kernel:
attention_backward— streaming, seq>1024 için (CAS atomics on dK/dV)attention_backward_short— split-kernel Phase 0-4, dQ owned, P/dS scratch'e flushattention_backward_short_dKdV— split-kernel Phase 5, dK/dV owned, no atomicsrope_q_backward— Q rotation invertrope_k_backward— K rotation invert
Nedir bu ya?
Forward'da attention'ı düşün: her token, geçmişteki bütün token'lara "ne kadar dikkat edeyim?" diye bakıyor, sonra onların değerlerini bu ağırlıklarla karıştırıp çıktısını üretiyor. Backward ise tam tersi bir soru soruyor: çıktı yanlış çıktıysa, suçun/payın kime ne kadar düştüğünü geri dağıtmak. Sınıfta grup ödevi bozuldu, şimdi kimin ne kadar katkısı vardı diye geri sarıp herkesin notunu güncelliyorsun gibi.
Sorun şu: forward'da herkes "kim okuyor" (query) açısından düşünüyordu — her token kendi satırını topluyordu. Ama gradyanı geri yazarken aynı veriyi "kim okunuyor" (key/value) açısından toplamak gerekiyor. Aynı defteri bir kere satırlara göre, bir kere sütunlara göre okumak gibi. İki farklı yön. Bunu tek geçişte yapmaya kalkınca ya bilgiyi bir yere park edip (scratch buffer) sonra ters yönden okuyorsun, ya da herkesin aynı anda aynı satıra yazmaya çalıştığı bir trafik sıkışıklığına düşüyorsun.
İşte 3 varyantın sebebi bu. Eski tek-kernel yaklaşımı, ortak bir kasaya herkesin aynı anda para yatırmaya çalışması gibiydi — kuyruk, kilit, bekleme (atomik yazma). Yeni tasarım işi ikiye bölüyor: bir kernel "okuyan" tarafın gradyanını hesaplayıp ara veriyi bir kenara not ediyor, ikinci kernel o notu ters yönden okuyup "okunan" tarafın gradyanını topluyor. Üçüncü varyant ise dizi çok uzayınca (ara not'lar dev olunca) belleği idare etmek için akış (streaming) moduna düşen yedek plan.
Akıllıca kısım, bu üç yolun da matematiksel olarak aynı zincir kuralını çözmesi ama farklı bellek/hız dengeleri kurması. Aşağıda önce hangi derdin hangi varyantı doğurduğunu, sonra her birinin nasıl çalıştığını adım adım açacağız.
Niye 3 Varyant?
Önceki versiyon: tek kernel, CAS atomics
attention_backward (streaming):
Phase 0-4: recompute P, dP, accumulate dQ
Phase 5: dV += P·dO via CAS atomic
dK += dS·Q via CAS atomicSorun: GQA'da n_q_per_kv = 3-4, 3-4 query head aynı kv head'e CAS atomic ediyordu. seq=512'de erken pozisyonlar için ~1024 paralel atomic write. Contention: tahmini 17.8 ms / step kayıp.
Yeni: split kernel design
attention_backward_short:
Phase 0-4: P[j], dS[j] hesapla, scratch buffer'a flush (no atomic)
Phase 5: dQ_acc[d] hesapla, owned write (no atomic)
attention_backward_short_dKdV:
Phase 5: per (j, kv_h) workgroup
loop s ∈ [j, seq) × h_q ∈ q_group(kv_h)
P[s, j], dS[s, j] scratch'ten oku
dV[j, d] += P · dO (owned write, no atomic)
dK[j, d] += dS · Q (owned write, no atomic)Sonuç: ~17.8 ms / step kazanç.
Streaming kernel hâlâ var — seq_len > 1024 durumlarında scratch buffer çok büyük olur (1024² × 4B = 4MB / head, çok), o yüzden streaming fallback. Production seq=512, split path kullanılır.
Matematiksel Tanım
Forward (özetle): O = softmax(Q·K^T · scale)·V
Backward — chain rule (Aşırı klasik formül):
P = softmax(Q·K^T · scale) (forward already computed, reusable via LSE)
D[s] = Σ_j P[s,j] · dP[s,j] per-row scalar
dP[s,j] = Σ_d dO[s,d] · V[j,d] via dot product
dS[s,j] = scale · P[s,j] · (dP[s,j] - D[s]) gradient of scores
dQ[s,d] = Σ_j dS[s,j] · K[j,d]
dK[j,d] = Σ_s dS[s,j] · Q[s,d] cross-row scatter
dV[j,d] = Σ_s P[s,j] · dO[s,d] cross-row scatterattention_backward_short (Production Path)
Used when seq_len ≤ 1024. Tek-pass, tüm P[seq] ve dS[seq] workgroup memory'de.
Workgroup memory
sh_q: array<f32, 256> // Q row cache
sh_do: array<f32, 256> // dO row cache
sh_p_full: array<f32, 1024> // P[s, *] for all j (causal up to s)
sh_dp_full: array<f32, 1024> // dP[s, *]
sh_partial: array<f32, 256> // dQ accumulator reductionTotal: ~11 KB. Apple GPU 32 KB limit'inde rahat.
Algorithm
1. cache Q[s, h, *] and dO[s, h, *]
2. for each j in [0, s]:
p[j] = exp(scale·Q·K[j] - LSE[s,h])
dp[j] = Σ_d dO[d] · V[j,d]
write to sh_p_full[j], sh_dp_full[j]
3. D = Σ_j P[j] · dP[j]
4. for each j:
dS[j] = scale · P[j] · (dP[j] - D)
5. dQ[d] = Σ_j dS[j] · K[j,d]
write dQ[s,h,d] = previous + dQ[d] (owned, no atomic)
6. flush P[j], dS[j] to scratch buffers (P_scratch, dS_scratch)dK[j,d] ve dV[j,d] bu kernel'da YAZILMAZ. Onları companion kernel handles.
Flush — niye?
attention_backward_short per (s, h) çalışıyor. P[s,j] ve dS[s,j] tüm j için ihtiyaç var → ama dK/dV update'i için per (j, kv_h) organizasyonu lazım. Yani aynı veri iki farklı yönde akmalı:
- s-major:
attention_backward_shortüretiyor - j-major:
attention_backward_short_dKdVtüketiyor
Çözüm: Geçici scratch mega-buffer'lar:
P_scratch: [seq_len, n_heads, seq_len] f32 array
dS_scratch: [seq_len, n_heads, seq_len] f32 arrayP_scratch[s, h, j] = P[s, h, j]. seq=512, n_heads=12 → 512 × 12 × 512 × 4B = 12 MB per scratch. 2 scratch = 24 MB. Activation memory'ye eklenir.
attention_backward_short_dKdV (Companion Path)
Per (j, kv_h) bir workgroup. Belirli bir KV pozisyon ve KV head için tüm relevant query position'ları topla.
Workgroup memory
sh_dk_partial: array<f32, 256>
sh_dv_partial: array<f32, 256>Sadece 2 KB. Çok hafif.
Algorithm
for thread (d in head_dim):
dk_partial = 0
dv_partial = 0
for s in [j, seq): # causal: j ≤ s
for h_q in q_group(kv_h): # all query heads using this kv head
P = P_scratch[s, h_q, j]
dS = dS_scratch[s, h_q, j]
Q = Q[s, h_q, d]
dO = dO[s, h_q, d]
dk_partial += dS · Q
dv_partial += P · dO
dK[j, kv_h, d] += dk_partial (owned)
dV[j, kv_h, d] += dv_partial (owned)No atomic. Her (j, kv_h, d) tek workgroup tarafından dokunuluyor — ownership pattern.
GQA için: q_group(kv_h) = q_per_group query heads share this kv. Bu loop hepsi üzerinden. Pratik: n_heads=12, n_kv=4 → q_per_group=3 → 3 query head iteration.
Niye atomik yok?
(j, kv_h, d) 3-tuple unique workgroup ownership. Aynı (j, kv_h, d) iki workgroup tarafından yazılmaz → race yok → atomic gereksiz.
+= semantic dikkat: bu kernel kendi içinde sadece bir kez her slot'a yazıyor (workgroup başı). Ama host-side gradient accumulation (multiple micro-steps) dK/dV zero-init önce eklenir. Doğru.
attention_backward (Streaming Fallback)
seq_len > 1024 için. P_scratch ve dS_scratch çok büyük olur (1024² × 4B = 4MB per head per scratch — 12 head × 4MB × 2 = 96 MB just for scratch).
Streaming approach: P, dP'i sadece block by block hesapla, recompute (storage tasarrufu). 2-pass:
- Pass A: D hesapla (tüm K-block'lar üzerinden)
- Pass B: D bilindikten sonra dQ accumulate, dK/dV CAS atomic write
Her K-block 256 wide. Workgroup memory: sh_p[256], sh_dp[256] per block — sadece 2KB.
CAS atomic dK/dV — race olduğu için (multiple s positions, same j slot, same kv_h via GQA). Tahmini contention ~1024-way per j slot.
RoPE Backward
// rope_q_backward: backward through Q rotation
// dQ' (rotated) → dQ (un-rotated)
// Forward: q' = R(angle) · q
// Backward: dq = R(-angle) · dq' = R(angle)^T · dq'
let dq0 = dQ[i0];
let dq1 = dQ[i1];
dQ[i0] = dq0 * c + dq1 * sn; (note: + sin instead of - sin)
dQ[i1] = -dq0 * sn + dq1 * c;RoPE rotasyon matrisi orthonormal — R^T = R^-1. Yani backward = forward with negated angle, yani same rotation matrix with -angle:
R(-θ) = [ cos -(-sin)] = [ cos sin]
[-(-sin) cos] [-sin cos]Sin işareti tersine, cos aynı. Klasik 2D rotation invert.
Bind Group ABI
attention_backward_short (11 binding - Production Path)
| Binding | Tür | Detay |
|---|---|---|
| 0 | storage, read | Q: array<f32> |
| 1 | storage, read | K: array<f32> |
| 2 | storage, read | V: array<f32> |
| 3 | storage, read | dO: array<f32> |
| 4 | storage, read | LSE: array<f32> (forward'dan saved) |
| 5 | storage, read_write | dQ: array<f32> (owned write) |
| 6 | storage, read_write | P_scratch: array<f32> |
| 7 | storage, read_write | dS_scratch: array<f32> |
| 8 | uniform | dims: vec4<u32> — (seq_len, n_heads, n_kv_heads, head_dim) |
| 9 | uniform | params: vec4<f32> — (scale, _, _, _) |
| 10 | storage, read | seg: array<u32> — her sorgunun döküman başlangıç indeksi (cross-doc mask) |
dK ve dV BU KERNEL'DA YAZILMAZ — companion attention_backward_short_dKdV kernel'ı ile hesaplanır.
attention_backward_short_dKdV (7 binding - Companion Path)
| Binding | Tür | Detay |
|---|---|---|
| 0 | storage, read | Q: array<f32> |
| 1 | storage, read | dO: array<f32> |
| 2 | storage, read | P_scratch: array<f32> |
| 3 | storage, read | dS_scratch: array<f32> |
| 4 | storage, read_write | dK: array<f32> (owned write, no atomic) |
| 5 | storage, read_write | dV: array<f32> (owned write, no atomic) |
| 6 | uniform | dims: vec4<u32> — (seq_len, n_heads, n_kv_heads, head_dim) |
attention_backward (11 binding - Streaming Fallback)
| Binding | Tür | Detay |
|---|---|---|
| 0 | storage, read | Q: array<f32> |
| 1 | storage, read | K: array<f32> |
| 2 | storage, read | V: array<f32> |
| 3 | storage, read | dO: array<f32> |
| 4 | storage, read | LSE: array<f32> (forward'dan saved) |
| 5 | storage, read_write | dQ: array<f32> (owned write) |
| 6 | storage, read_write | dK: array<atomic<u32>> (CAS atomic scatter) |
| 7 | storage, read_write | dV: array<atomic<u32>> (CAS atomic scatter) |
| 8 | uniform | dims: vec4<u32> — (seq_len, n_heads, n_kv_heads, head_dim) |
| 9 | uniform | params: vec4<f32> — (scale, _, _, _) |
| 10 | storage, read | seg: array<u32> — her sorgunun döküman başlangıç indeksi (cross-doc mask) |
rope_q_backward / rope_k_backward (3 binding)
| Binding | Tür | Detay |
|---|---|---|
| 0 | storage, read_write | dQ veya dK: array<f32> |
| 1 | uniform | dims: vec4<u32> — (seq_len, n_heads/n_kv_heads, head_dim, pos_offset) |
| 2 | uniform | params: vec4<f32> — (rope_base, _, _, _) |
Dispatch Şekli
attention_backward_short
grid: (seq_len, n_heads, 1)
WG=256Aynı attention_forward ile.
attention_backward_short_dKdV
grid: (seq_len, n_kv_heads, 1) ← n_kv_heads, not n_heads!
WG=256n_kv_heads daha az (n_heads=12, n_kv=4) → 3x daha az WG. Ama her WG q_per_group kez fazla iş yapar. Net iş kabaca eşit.
Performance — Profile snapshot'tan
attention_backward_short (×12) 63.18 ms 12.2% of step
attention_backward_short_dKdV (×12) 39.19 ms 7.6% of step
combined 102.37 ms 19.8% of stepBackward'ın %32'si — backward'da en pahalı kernel grubu. (Sonra matmul_at, matmul_t geliyor.)
WGSL-Spesifik Notlar
1. Tint dead-code elimination
Workgroup memory variables tüm kernel'larda görünür ama Tint sadece referenced ones'u allocate ediyor. Yani:
attention_backward_short→ sh_q, sh_do, sh_p_full, sh_dp_full, sh_partial (~11 KB)attention_backward_short_dKdV→ sh_dk_partial, sh_dv_partial (~2 KB)attention_backward→ sh_q, sh_do, sh_p, sh_dp, sh_partial, sh_dq_acc, sh_D (~6 KB)
Aynı var<workgroup> declaration list, farklı occupancy per kernel.
2. Atomik split kazançları
Split design literal CAS atomic'i ortadan kaldırır (companion kernel'da). Apple GPU CAS retry overhead sub-linear ama anlamlı. ~17.8ms total step time savings.
3. Scratch buffer host-side allocation
P_scratch, dS_scratch host-side allocate, lifetime training step boyunca. Memory cost ~24MB per step. Reused her step.
Code Review
Bulgu 1: P, dP recompute (streaming) — duplicate work
| Risk | Açıklama |
|---|---|
| 🟢 trade-off | Streaming kernel P, dP'i 2 kez hesaplar (Pass A + Pass B). seq>1024 için memory tasarrufu (kompakt scratch). Pratik etki ~5% backward overhead. Trade-off doğru. |
Bulgu 2: split path scratch overhead
| Risk | Açıklama |
|---|---|
| 🟡 minor | 24 MB scratch için yer açıyoruz. Pratikte bu activation-tier memory; concern değil. seq=512 için doğru tercih. |
Hızlı Kontrol Listesi
| Test Senaryosu | Durum |
|---|---|
| Forward LSE backward'da reuse mü? | ✅ recompute yok |
dK/dV race-free mı (split path)? | ✅ ownership pattern |
attention_backward streaming hâlâ correct mi? | ⚠ seq>1024 path az test edildi |
| RoPE invert correct mu? | ⚠ identity test yok |
| GQA q_per_group iteration correct mu? | ✅ kernel logic |
Sonraki
11_backward_linear.md — backward matmul'leri: matmul_t, matmul_t_acc, matmul_at, matmul_at_acc. Weight gradient'ları nasıl hesaplanıyor.