attention_forward — Causal Multi-Head Attention with Online Softmax
Dosya: 05_attention.wgsl Pipeline adımı: Layer'ın kalbi. RoPE'd Q ve K + V → contextual mixing.
İki kernel:
attention_forward— training (full sequence)attention_decode— inference (single query against KV cache)
Nedir bu ya?
Bir grup sohbetine geç kaldın ve 200 mesaj birikmiş. Şu anki mesajı anlamak için geri kaydırıp eski mesajları okuyorsun. Ama hepsini eşit önemde okumuyorsun: birkaç tanesi tam konuyla ilgili, gerisi muhabbet. Kafanda her eski mesaja sessizce bir ağırlık veriyorsun — "bu çok ilgili", "bunu boşver" — sonra o ağırlıklı karışımdan şu anki mesajın ne anlama geldiğini çıkarıyorsun. Attention'ın yaptığı tam bu.
Her token (kelime parçası) bir "sorgu" üretiyor: "ben kimi dinlemeliyim?" Diğer tokenlar da birer "anahtar" sunuyor: "ben şu konuya bakıyorsan ilgiliyim." Sorgu ile anahtarların eşleşmesi (dot product) bir skor veriyor, softmax bu skorları toplamı 1 olan ağırlıklara çeviriyor, ve her tokenın taşıdığı "değer" (value) bu ağırlıklarla harmanlanıp çıktıyı oluşturuyor. Yani fuzzy bir lookup: hash map'te tek bir anahtarla tek satır çekmek yerine, tüm satırları benzerliklerine göre ağırlıklı topluyorsun.
Bir kural var: causal. Token sadece kendinden önceki tokenlara bakabilir, sonrakilere değil — grup sohbetinde henüz yazılmamış mesajı okuyamazsın ya, onun gibi. İleriye bakmak yasak, çünkü model bir sonraki kelimeyi tahmin etmeyi öğreniyor; cevabı görmesine izin veremeyiz.
İşin zor kısmı, bu ağırlıkları (softmax) hesaplamak için normalde tüm skorları belleğe yazıp üç kez taramak gerekir — sekans uzadıkça bu pahalılaşır. Online softmax ise mesajları tek geçişte, akarken okur ve çalışırken sürekli kendini düzeltir. Aşağıda bunun GPU'da nasıl tek pass'e sığdırıldığına bakacağız.
Ne Yapar?
Causal scaled dot-product attention with GQA (Grouped-Query Attention) and online softmax:
score[s, k] = (Q[s] · K[k]) / sqrt(head_dim)
P[s, k] = softmax_k(score[s, k]) (k ≤ s, causal)
O[s] = Σ_k P[s, k] · V[k]Modelin contextual mixing mekanizması. Token s her kendinden önceki token k'ya "ne kadar dikkat edeceğini" hesaplar (P[s,k] olasılığı), sonra V'lerin ağırlıklı toplamı ile output üretir.
Online Softmax — Niye?
Klasik softmax 3 pass:
m = max(scores)— scan all scoresexp_sum = Σ exp(score - m)— scan againP = exp(score - m) / exp_sum— scan again
Sequence uzunsa (örn 1024) 3× memory traffic + scratch buffer for scores. Online softmax bunu single-pass'e çevirir, scratch buffer O(WG + head_dim) olur, O(seq_len) değil.
GQA (Grouped Query Attention)
n_heads query, ama sadece n_kv_heads < n_heads key/value. Multiple Q heads, aynı K/V grubunu paylaşır:
n_heads = 12, n_kv_heads = 4
→ q_per_group = 3
→ Query head h uses kv_head h/3Avantaj: KV cache 1/q_per_group boyutunda. Bizim modelde n_heads=12 / n_kv=4 → KV ⅓ boyutu. Az kalite kaybı, büyük memory tasarrufu.
Online Softmax Algoritması — Çok Önemli
Sequence içindeki K-block'larını sırayla işleyip state'i incremental günceller:
state: m (running max), l (running exp-sum), acc[d] (running V·P)
init: m = -∞, l = 0, acc[d] = 0
for each block of K (size WG=256):
1. compute scores = (Q · K_block[k]) * scale for each k in block
2. block_max = max(scores in block)
3. m_new = max(m, block_max)
4. P[k] = exp(score[k] - m_new)
5. block_sum = Σ P[k]
6. correction = exp(m - m_new) ← old state'i scale
7. l ← l * correction + block_sum
8. acc[d] ← acc[d] * correction + Σ_k P[k] · V[k, d]
9. m ← m_new
final:
O[d] = acc[d] / l
LSE = m + log(l) ← saved for backwardcorrection = exp(m_old - m_new) term, eğer m_new > m_old ise eski state'i m_new'e göre renormalize ediyor. Bu sayede stable, single-pass.
LSE = log(Σ exp(scores)) saved for backward — backward'da bunu reuse edip recompute'tan kaçınıyoruz.
Bind Group ABI
attention_forward (8 binding)
| Binding | Tür | Detay |
|---|---|---|
| 0 | storage, read | Q: array<f32> — [seq × n_heads × head_dim] |
| 1 | storage, read | K: array<f32> — [seq × n_kv × head_dim] (GQA: smaller) |
| 2 | storage, read | V: array<f32> — [seq × n_kv × head_dim] |
| 3 | storage, read_write | O: array<f32> — [seq × n_heads × head_dim] |
| 4 | storage, read_write | LSE: array<f32> — [seq × n_heads] saved for backward |
| 5 | uniform | dims: vec4<u32> — (seq_len, n_heads, n_kv_heads, head_dim) |
| 6 | uniform | params: vec4<f32> — (scale, _, _, _) |
| 7 | storage, read | seg: array<u32> — her sorgunun döküman başlangıç indeksi (cross-doc mask) |
scale = 1 / sqrt(head_dim) host'ta hesaplanır.
attention_decode (6 binding)
Decode için Q tek bir row ([n_heads × head_dim]), K/V cache'ten okunur, LSE save yoktur.
| Binding | Tür | Detay |
|---|---|---|
| 0 | storage, read | Q_one: array<f32> — [n_heads × head_dim] (tek satır) |
| 1 | storage, read | K_cache: array<f32> — [seq × n_kv × head_dim] (geçmiş anahtarlar cache'i) |
| 2 | storage, read | V_cache: array<f32> — [seq × n_kv × head_dim] (geçmiş değerler cache'i) |
| 3 | storage, read_write | O_one: array<f32> — [n_heads × head_dim] (tek çıkış satırı) |
| 4 | uniform | dims_dec: vec4<u32> — (cache_len, n_heads, n_kv_heads, head_dim) |
| 5 | uniform | params_dec: vec4<f32> — (scale, _, _, _) |
Workgroup Memory Layout
var<workgroup> sh_q: array<f32, 256>; // Q row cache
var<workgroup> sh_p: array<f32, 256>; // current K-block scores → P values
var<workgroup> sh_acc: array<f32, 256>; // running output accumulator
var<workgroup> sh_partial: array<f32, 256>; // V·P partial sum reduction scratch
var<workgroup> sh_max: array<f32, 1>; // running max m (single value)
var<workgroup> sh_lse: array<f32, 1>; // running sum-exp l (single value)Total: ~4 KB workgroup memory. Apple GPU 32 KB limit'inin %12'si — rahat occupancy.
MAX_HEAD_DIM = 256 constraint: head_dim 256'yı aşamaz. Bizim modelde head_dim=64 (12 head × 64 = 768 d_model), bol bol yer.
Dispatch Şekli
workgroup_size: 256
grid: (seq_len, n_heads, 1) workgroupsBir workgroup = bir (s, h) çifti, yani bir query position bir head için.
Örnek (seq=512, n_heads=12):
- 6144 WG × 256 thread = 1.57M thread
- Each WG iterates
ceil(causal_len / 256)blocks of K
Satır Satır — attention_forward
1) Decode workgroup ID
let s = wgid.x; // query position
let h = wgid.y; // attention head
// ...
let q_per_group = n_heads / n_kv_heads;
let kv_h = h / q_per_group; // GQA: query head h uses kv_hGQA mapping. h=0,1,2 → kv_h=0; h=3,4,5 → kv_h=1; vs.
2) Phase 0 — Cache Q, init state
let q_row_base = s * d_model + h * head_dim;
let causal_len = s + 1u; // attend over [0, s]
let doc_start = seg[s]; // bu sorgunun ait olduğu dökümanın başlangıç indeksicausal_len = s + 1u→ bu query position kendi dahil önceki tüm pozisyonlara dikkat edebilir, sonrakilere edemez.doc_start = seg[s]→ Çapraz döküman maskelemesi: sorgusyalnızca kendi döküman sınırları içindekikanahtarlarına dikkat edebilir. Döküman başlangıcından önceki (k < doc_start) anahtarlar maskelenir.sh_q[i] = Q[...]— Q row'unun head_dim boyutunda copy'sini workgroup memory'ye al. Inner-loop'ta her thread aynı Q'ya erişecek; once-load, many-read.
3) K-block parallelization parameters
let k_par = max(1u, WG / max(1u, head_dim)); // 4 when head_dim=64
let n_active = head_dim * k_par; // 256 when head=64,k_par=4
let d_local = tid / k_par; // d ∈ [0, head_dim)
let kc = tid % k_par; // chunk index ∈ [0, k_par)
let chunk_size = (WG + k_par - 1u) / k_par; // 64Phase D (V·P partial)'de WG'i 2D olarak yorumla: (d_local, kc) mesh'i. d_local çıkış dimension'ı (0..head_dim-1), kc chunk index (0..k_par-1).
Eğer head_dim < WG, multiple thread'ler aynı d için farklı chunk'ları paralel hesaplar.
4) K-block iteration (Döküman Öncesini Atlama)
let first_blk = doc_start / WG;
let n_blocks = (causal_len + WG - 1u) / WG;
for (var blk = first_blk; blk < n_blocks; blk = blk + 1u) {
let k_start = blk * WG;
let k = k_start + tid;first_blk = doc_start / WG→ Döküman başlangıcının altındaki K-block'ları tamamen atlanır! Bu sayede hem bellek okumalarından tasarruf edilir hem de ilk işlenen blokta en az bir adet geçerli (maskelenmemiş) anahtar bulunması garanti edilir. Bu garanti, online softmax hesaplamalarındaki running-max değerinin-∞kalmasını ve running-max düzeltmesindeexp(-∞ - -∞) = NaNhatası üretilmesini (LSE stabilizasyonu) engeller.
5) Phase A — Score computation (Maskelenmiş Konumlar)
var score: f32 = NEG_INF;
if (k >= doc_start && k < causal_len) {
let k_row_base = k * kv_dim + kv_h * head_dim;
var dot: f32 = 0.0;
for (var d: u32 = 0u; d < head_dim; d = d + 1u) {
dot = fma(sh_q[d], K[k_row_base + d], dot);
}
score = dot * scale;
}Her thread bir k için score hesaplar:
- Q (workgroup mem'den) ⋅ K[k] (global mem'den)
head_dimadet FMA- Scale =
1/sqrt(head_dim)(LLaMA-style) if (k >= doc_start && k < causal_len)ile hem döküman maskelemesi hem de causal mask tek satırda uygulanır. Koşulu sağlamayan konumlarNEG_INFalır, böylece softmax'taexp(NEG_INF) = 0katkı üretirler.
Memory pattern: Adjacent thread'ler adjacent K rows okuyor (k=0,1,2,...) — K[k * kv_dim + kv_h * head_dim + d]. Aynı d için adjacent k → scattered (kv_dim apart). Coalesced değil. Bu attention forward'ın performance handicap'i — flash attention forward (perf-doc #1) bunu fix'leyebilir.
6) Phase B — Block max + running max
let block_max = wg_reduce_max(tid, score);
let m_old = sh_max[0];
let m_new = max(m_old, block_max);wg_reduce_max tüm 256 thread'in max'ını döndürür — her thread'de aynı değer.
7) Phase C — P[k] + block sum + correction
var p_val: f32 = 0.0;
if (k < causal_len) {
p_val = exp(score - m_new);
}
sh_p[tid] = p_val;
let block_sum = wg_reduce_sum(tid, p_val);
let correction = exp(m_old - m_new);
let l_old = sh_lse[0];
let l_new = l_old * correction + block_sum;p_val = exp(score - m_new) — softened score. m_new şu ana kadar görülen max, exp(score - max) ≤ 1 numerical stability için.
correction = exp(m_old - m_new):
- İlk iteration:
m_old = -∞,correction = exp(-∞ - m_new) → 0(eski state'i sıfırla, çünkü o zaman bilgi yoktu) - Sonraki iterations: m artarsa
correction < 1, eski accumulators down-scale
8) Phase D — V·P contribution
if (tid < n_active) {
let kl_start = kc * chunk_size;
let kl_end_raw = kl_start + chunk_size;
let kl_end = select(kl_end_raw, WG, kl_end_raw > WG);
var partial: f32 = 0.0;
for (var kl = kl_start; kl < kl_end; kl = kl + 1u) {
let kk = k_start + kl;
if (kk < causal_len) {
partial = fma(sh_p[kl], V[kk * kv_dim + kv_h * head_dim + d_local], partial);
}
}
sh_partial[d_local * k_par + kc] = partial;
}
workgroupBarrier();
if (kc == 0u && tid < n_active) {
var sum_pv: f32 = sh_partial[d_local * k_par];
for (var i = 1u; i < k_par; i = i + 1u) {
sum_pv = sum_pv + sh_partial[d_local * k_par + i];
}
sh_acc[d_local] = sh_acc[d_local] * correction + sum_pv;
}Phase D iki adım:
- Her
(d_local, kc)thread'i kendi chunk'ındakiΣ P[kl] · V[k_start+kl, d_local]partial'ı hesaplar - Barrier sonrası,
kc == 0uthread'leri kendid_localiçin partial'ları toplar (k_par adet) →sh_acc[d_local]'a yaz
Niye iki phase? Doğrudan tek thread'te Σ_kl yapsak head_dim thread çalışıp diğerleri boşta dururdu. Multi-thread reduction: 4 (k_par) thread paralel partial → 1 thread merge. Daha hızlı.
9) Update m, l
if (tid == 0u) {
sh_max[0] = m_new;
sh_lse[0] = l_new;
}
workgroupBarrier();Sadece tid 0 yazar. Single writer, race yok. Barrier sonrası tüm thread'ler güncellenmiş m/l görür.
10) Final O ve LSE
let l_final = max(sh_lse[0], 1e-30);
let inv_l = 1.0 / l_final;
if (tid < head_dim) {
O[s * d_model + h * head_dim + tid] = nan_guard(sh_acc[tid] * inv_l);
}
if (tid == 0u) {
LSE[s * n_heads + h] = sh_max[0] + log(l_final);
}max(l, 1e-30) — division-by-zero koruması. Eğer hiçbir K hit ettiyse l=0 kalabilir (örneğin causal_len=0 edge case, s=0 dahil değil). Pratikte tetiklenmez.
O[d] = acc[d] / l — softmax denominator division.
LSE = m + log(l) — backward'da reuse edilecek, log-sum-exp.
attention_decode Farkı
Inference için. Q tek row (mevcut token), K/V cache'ten okunur (önceki tüm token'ların):
let h = wgid.x; // bir head için bir WG (n_heads workgroup)
// ... causal mask yok, tüm cache_len attend- LSE save yok (decode'da backward yok)
- Workgroup grid daha küçük (
n_headsnotseq × n_heads) - KV pre-loaded cache; bizim KV buffer cache değil her step yeniden yazılan storage olabilir
Performance Notes
Bu kernel bottleneck. Profil ölçümünde:
attention_forward: 8.7% of step time (45 ms / 12 layers)- Her layer ~3.75 ms
Sebep:
- K access pattern coalesced değil (adjacent thread'ler K[k] okuyor, k varies)
- V access pattern aynı sorun
- Causal mask son block'larda thread idle oluşturuyor
Flash-Attention Deneyleri ve Occupancy Kaybı:
Kod tabanında Flash-style bir Q-block tiled varyantı (Br=8, Bc=32, K/V tile cache) denenmiştir. Ancak Apple GPU'larında bu yaklaşım 21 KB workgroup belleği gerektirdiği için occupancy oranını ~32 WG/SM'den ~1 WG/SM'e düşürmüş ve K/V yük amortismanından elde edilen kazançtan daha fazla gecikme gizleme (latency hiding) kaybına yol açmıştır. d=128 h=4 kv=2 swiglu seq=512 konfigürasyonunda adım süresinde net +%47 artışa neden olduğu için kernel, sorgu başına (per-query) online softmax yapısına geri döndürülmüştür.
Optimizasyon fırsatı: Flash-attention pattern (Q ve K'yı tile et, K/V over chunks of Q). Perf-doc #1.
WGSL-Spesifik Notlar
1. wg_reduce_max ve wg_reduce_sum her thread'de aynı sonuç
00_shared.md detayında. Subgroup-accelerated, deterministic.
2. exp(NEG_INF) = 0
WGSL spec garantisi. Causal mask sonrası NEG_INF score → exp → 0 → softmax'a katkı sıfır. Doğru behaviour.
3. var<workgroup> 4 KB total
sh_q + sh_p + sh_acc + sh_partial = 4×256×4B = 4 KB. sh_max, sh_lse 1×4B = ihmaledilir. 32 KB limit'in %12'si.
4. select(b, a, cond) — WGSL ternary
let kl_end = select(kl_end_raw, WG, kl_end_raw > WG);"Eğer raw > WG, WG'i al; aksi raw'ı." Min(raw, WG) gibi, ama explicit.
Code Review
Bulgu 1: K access non-coalesced
| Risk | Açıklama |
|---|---|
| 🟡 perf | Score loop'u her thread K[k * kv_dim + ...] okuyor — k thread başına farklı, dolayısıyla adjacent thread'ler kv_dim byte apart memory'den okuyor. Coalesced değil. Pratik etki: ~%10-15 attention overhead. |
Mitigasyon: Flash-attention pattern. Şu an open work.
Bulgu 2: sh_p size 256 — head_dim > 256 fail
| Risk | Açıklama |
|---|---|
| 🟢 yok | MAX_HEAD_DIM = 256 constraint koruma. Bizim model 64, gelecekte 128 yapılabilir, hâlâ güvenli. |
Bulgu 3: nan_guard(O[d]) çıkışta var
| Risk | Açıklama |
|---|---|
| 🟢 yok | Eğer Q × K = ∞ olursa exp patlayabilir. nan_guard NaN/Inf'i 0'a çekiyor, downstream stable. |
Hızlı Kontrol Listesi
| Test Senaryosu | Durum |
|---|---|
| Causal mask doğru mu? (s = 0 sadece kendine attend) | ✅ k < causal_len check |
| GQA mapping doğru mu? | ✅ kv_h = h / q_per_group |
| Online softmax = standard softmax? | ⚠ formal test yok, pratikte loss curve makul |
head_dim = 256 edge case crash? | ✅ MAX_HEAD_DIM = 256 |
seq_len = 1 (decode-style) | ✅ block iteration loop'u atlar |
| Backward LSE doğru kullanılıyor mu? | ✅ 10_backward_attention.md'da kanıt |
Sonraki
06_activation.md — GeLU, SwiGLU. FFN'in non-linearity bileşeni.