llm.istanbul·Etüt
TR EN
Workbench →

attention_forward — Causal Multi-Head Attention with Online Softmax

Dosya: 05_attention.wgsl Pipeline adımı: Layer'ın kalbi. RoPE'd Q ve K + V → contextual mixing.

İki kernel:

  • attention_forward — training (full sequence)
  • attention_decode — inference (single query against KV cache)

Nedir bu ya?

Bir grup sohbetine geç kaldın ve 200 mesaj birikmiş. Şu anki mesajı anlamak için geri kaydırıp eski mesajları okuyorsun. Ama hepsini eşit önemde okumuyorsun: birkaç tanesi tam konuyla ilgili, gerisi muhabbet. Kafanda her eski mesaja sessizce bir ağırlık veriyorsun — "bu çok ilgili", "bunu boşver" — sonra o ağırlıklı karışımdan şu anki mesajın ne anlama geldiğini çıkarıyorsun. Attention'ın yaptığı tam bu.

Her token (kelime parçası) bir "sorgu" üretiyor: "ben kimi dinlemeliyim?" Diğer tokenlar da birer "anahtar" sunuyor: "ben şu konuya bakıyorsan ilgiliyim." Sorgu ile anahtarların eşleşmesi (dot product) bir skor veriyor, softmax bu skorları toplamı 1 olan ağırlıklara çeviriyor, ve her tokenın taşıdığı "değer" (value) bu ağırlıklarla harmanlanıp çıktıyı oluşturuyor. Yani fuzzy bir lookup: hash map'te tek bir anahtarla tek satır çekmek yerine, tüm satırları benzerliklerine göre ağırlıklı topluyorsun.

Bir kural var: causal. Token sadece kendinden önceki tokenlara bakabilir, sonrakilere değil — grup sohbetinde henüz yazılmamış mesajı okuyamazsın ya, onun gibi. İleriye bakmak yasak, çünkü model bir sonraki kelimeyi tahmin etmeyi öğreniyor; cevabı görmesine izin veremeyiz.

İşin zor kısmı, bu ağırlıkları (softmax) hesaplamak için normalde tüm skorları belleğe yazıp üç kez taramak gerekir — sekans uzadıkça bu pahalılaşır. Online softmax ise mesajları tek geçişte, akarken okur ve çalışırken sürekli kendini düzeltir. Aşağıda bunun GPU'da nasıl tek pass'e sığdırıldığına bakacağız.


Ne Yapar?

Causal scaled dot-product attention with GQA (Grouped-Query Attention) and online softmax:

score[s, k] = (Q[s] · K[k]) / sqrt(head_dim)
P[s, k] = softmax_k(score[s, k])    (k ≤ s, causal)
O[s] = Σ_k P[s, k] · V[k]

Modelin contextual mixing mekanizması. Token s her kendinden önceki token k'ya "ne kadar dikkat edeceğini" hesaplar (P[s,k] olasılığı), sonra V'lerin ağırlıklı toplamı ile output üretir.

Online Softmax — Niye?

Klasik softmax 3 pass:

  1. m = max(scores) — scan all scores
  2. exp_sum = Σ exp(score - m) — scan again
  3. P = exp(score - m) / exp_sum — scan again

Sequence uzunsa (örn 1024) 3× memory traffic + scratch buffer for scores. Online softmax bunu single-pass'e çevirir, scratch buffer O(WG + head_dim) olur, O(seq_len) değil.

GQA (Grouped Query Attention)

n_heads query, ama sadece n_kv_heads < n_heads key/value. Multiple Q heads, aynı K/V grubunu paylaşır:

n_heads = 12, n_kv_heads = 4
→ q_per_group = 3
→ Query head h uses kv_head h/3

Avantaj: KV cache 1/q_per_group boyutunda. Bizim modelde n_heads=12 / n_kv=4 → KV ⅓ boyutu. Az kalite kaybı, büyük memory tasarrufu.


Online Softmax Algoritması — Çok Önemli

Sequence içindeki K-block'larını sırayla işleyip state'i incremental günceller:

state: m (running max), l (running exp-sum), acc[d] (running V·P)
init: m = -∞, l = 0, acc[d] = 0

for each block of K (size WG=256):
  1. compute scores = (Q · K_block[k]) * scale  for each k in block
  2. block_max = max(scores in block)
  3. m_new = max(m, block_max)
  4. P[k] = exp(score[k] - m_new)
  5. block_sum = Σ P[k]
  6. correction = exp(m - m_new)             ← old state'i scale
  7. l ← l * correction + block_sum
  8. acc[d] ← acc[d] * correction + Σ_k P[k] · V[k, d]
  9. m ← m_new

final:
  O[d] = acc[d] / l
  LSE = m + log(l)                            ← saved for backward

correction = exp(m_old - m_new) term, eğer m_new > m_old ise eski state'i m_new'e göre renormalize ediyor. Bu sayede stable, single-pass.

LSE = log(Σ exp(scores)) saved for backward — backward'da bunu reuse edip recompute'tan kaçınıyoruz.


Bind Group ABI

attention_forward (8 binding)

BindingTürDetay
0storage, readQ: array<f32>[seq × n_heads × head_dim]
1storage, readK: array<f32>[seq × n_kv × head_dim] (GQA: smaller)
2storage, readV: array<f32>[seq × n_kv × head_dim]
3storage, read_writeO: array<f32>[seq × n_heads × head_dim]
4storage, read_writeLSE: array<f32>[seq × n_heads] saved for backward
5uniformdims: vec4<u32>(seq_len, n_heads, n_kv_heads, head_dim)
6uniformparams: vec4<f32>(scale, _, _, _)
7storage, readseg: array<u32> — her sorgunun döküman başlangıç indeksi (cross-doc mask)

scale = 1 / sqrt(head_dim) host'ta hesaplanır.

attention_decode (6 binding)

Decode için Q tek bir row ([n_heads × head_dim]), K/V cache'ten okunur, LSE save yoktur.

BindingTürDetay
0storage, readQ_one: array<f32>[n_heads × head_dim] (tek satır)
1storage, readK_cache: array<f32>[seq × n_kv × head_dim] (geçmiş anahtarlar cache'i)
2storage, readV_cache: array<f32>[seq × n_kv × head_dim] (geçmiş değerler cache'i)
3storage, read_writeO_one: array<f32>[n_heads × head_dim] (tek çıkış satırı)
4uniformdims_dec: vec4<u32>(cache_len, n_heads, n_kv_heads, head_dim)
5uniformparams_dec: vec4<f32>(scale, _, _, _)

Workgroup Memory Layout

wgsl
var<workgroup> sh_q:       array<f32, 256>;  // Q row cache
var<workgroup> sh_p:       array<f32, 256>;  // current K-block scores → P values
var<workgroup> sh_acc:     array<f32, 256>;  // running output accumulator
var<workgroup> sh_partial: array<f32, 256>;  // V·P partial sum reduction scratch
var<workgroup> sh_max:     array<f32, 1>;    // running max m (single value)
var<workgroup> sh_lse:     array<f32, 1>;    // running sum-exp l (single value)

Total: ~4 KB workgroup memory. Apple GPU 32 KB limit'inin %12'si — rahat occupancy.

MAX_HEAD_DIM = 256 constraint: head_dim 256'yı aşamaz. Bizim modelde head_dim=64 (12 head × 64 = 768 d_model), bol bol yer.


Dispatch Şekli

workgroup_size: 256
grid:           (seq_len, n_heads, 1) workgroups

Bir workgroup = bir (s, h) çifti, yani bir query position bir head için.

Örnek (seq=512, n_heads=12):

  • 6144 WG × 256 thread = 1.57M thread
  • Each WG iterates ceil(causal_len / 256) blocks of K

Satır Satır — attention_forward

1) Decode workgroup ID

wgsl
let s = wgid.x;        // query position
let h = wgid.y;        // attention head
// ...
let q_per_group = n_heads / n_kv_heads;
let kv_h = h / q_per_group;          // GQA: query head h uses kv_h

GQA mapping. h=0,1,2 → kv_h=0; h=3,4,5 → kv_h=1; vs.

2) Phase 0 — Cache Q, init state

wgsl
let q_row_base = s * d_model + h * head_dim;
let causal_len = s + 1u;             // attend over [0, s]
let doc_start = seg[s];              // bu sorgunun ait olduğu dökümanın başlangıç indeksi
  • causal_len = s + 1u → bu query position kendi dahil önceki tüm pozisyonlara dikkat edebilir, sonrakilere edemez.
  • doc_start = seg[s] → Çapraz döküman maskelemesi: sorgu s yalnızca kendi döküman sınırları içindeki k anahtarlarına dikkat edebilir. Döküman başlangıcından önceki (k < doc_start) anahtarlar maskelenir.
  • sh_q[i] = Q[...] — Q row'unun head_dim boyutunda copy'sini workgroup memory'ye al. Inner-loop'ta her thread aynı Q'ya erişecek; once-load, many-read.

3) K-block parallelization parameters

wgsl
let k_par      = max(1u, WG / max(1u, head_dim));  // 4 when head_dim=64
let n_active   = head_dim * k_par;                  // 256 when head=64,k_par=4
let d_local    = tid / k_par;                       // d ∈ [0, head_dim)
let kc         = tid % k_par;                       // chunk index ∈ [0, k_par)
let chunk_size = (WG + k_par - 1u) / k_par;         // 64

Phase D (V·P partial)'de WG'i 2D olarak yorumla: (d_local, kc) mesh'i. d_local çıkış dimension'ı (0..head_dim-1), kc chunk index (0..k_par-1).

Eğer head_dim < WG, multiple thread'ler aynı d için farklı chunk'ları paralel hesaplar.

4) K-block iteration (Döküman Öncesini Atlama)

wgsl
let first_blk = doc_start / WG;
let n_blocks = (causal_len + WG - 1u) / WG;
for (var blk = first_blk; blk < n_blocks; blk = blk + 1u) {
    let k_start = blk * WG;
    let k = k_start + tid;
  • first_blk = doc_start / WG → Döküman başlangıcının altındaki K-block'ları tamamen atlanır! Bu sayede hem bellek okumalarından tasarruf edilir hem de ilk işlenen blokta en az bir adet geçerli (maskelenmemiş) anahtar bulunması garanti edilir. Bu garanti, online softmax hesaplamalarındaki running-max değerinin -∞ kalmasını ve running-max düzeltmesinde exp(-∞ - -∞) = NaN hatası üretilmesini (LSE stabilizasyonu) engeller.

5) Phase A — Score computation (Maskelenmiş Konumlar)

wgsl
var score: f32 = NEG_INF;
if (k >= doc_start && k < causal_len) {
    let k_row_base = k * kv_dim + kv_h * head_dim;
    var dot: f32 = 0.0;
    for (var d: u32 = 0u; d < head_dim; d = d + 1u) {
        dot = fma(sh_q[d], K[k_row_base + d], dot);
    }
    score = dot * scale;
}

Her thread bir k için score hesaplar:

  • Q (workgroup mem'den) ⋅ K[k] (global mem'den)
  • head_dim adet FMA
  • Scale = 1/sqrt(head_dim) (LLaMA-style)
  • if (k >= doc_start && k < causal_len) ile hem döküman maskelemesi hem de causal mask tek satırda uygulanır. Koşulu sağlamayan konumlar NEG_INF alır, böylece softmax'ta exp(NEG_INF) = 0 katkı üretirler.

Memory pattern: Adjacent thread'ler adjacent K rows okuyor (k=0,1,2,...) — K[k * kv_dim + kv_h * head_dim + d]. Aynı d için adjacent k → scattered (kv_dim apart). Coalesced değil. Bu attention forward'ın performance handicap'i — flash attention forward (perf-doc #1) bunu fix'leyebilir.

6) Phase B — Block max + running max

wgsl
let block_max = wg_reduce_max(tid, score);
let m_old = sh_max[0];
let m_new = max(m_old, block_max);

wg_reduce_max tüm 256 thread'in max'ını döndürür — her thread'de aynı değer.

7) Phase C — P[k] + block sum + correction

wgsl
var p_val: f32 = 0.0;
if (k < causal_len) {
    p_val = exp(score - m_new);
}
sh_p[tid] = p_val;
let block_sum = wg_reduce_sum(tid, p_val);
let correction = exp(m_old - m_new);
let l_old = sh_lse[0];
let l_new = l_old * correction + block_sum;

p_val = exp(score - m_new) — softened score. m_new şu ana kadar görülen max, exp(score - max) ≤ 1 numerical stability için.

correction = exp(m_old - m_new):

  • İlk iteration: m_old = -∞, correction = exp(-∞ - m_new) → 0 (eski state'i sıfırla, çünkü o zaman bilgi yoktu)
  • Sonraki iterations: m artarsa correction < 1, eski accumulators down-scale

8) Phase D — V·P contribution

wgsl
if (tid < n_active) {
    let kl_start   = kc * chunk_size;
    let kl_end_raw = kl_start + chunk_size;
    let kl_end     = select(kl_end_raw, WG, kl_end_raw > WG);

    var partial: f32 = 0.0;
    for (var kl = kl_start; kl < kl_end; kl = kl + 1u) {
        let kk = k_start + kl;
        if (kk < causal_len) {
            partial = fma(sh_p[kl], V[kk * kv_dim + kv_h * head_dim + d_local], partial);
        }
    }
    sh_partial[d_local * k_par + kc] = partial;
}
workgroupBarrier();

if (kc == 0u && tid < n_active) {
    var sum_pv: f32 = sh_partial[d_local * k_par];
    for (var i = 1u; i < k_par; i = i + 1u) {
        sum_pv = sum_pv + sh_partial[d_local * k_par + i];
    }
    sh_acc[d_local] = sh_acc[d_local] * correction + sum_pv;
}

Phase D iki adım:

  1. Her (d_local, kc) thread'i kendi chunk'ındaki Σ P[kl] · V[k_start+kl, d_local] partial'ı hesaplar
  2. Barrier sonrası, kc == 0u thread'leri kendi d_local için partial'ları toplar (k_par adet) → sh_acc[d_local]'a yaz

Niye iki phase? Doğrudan tek thread'te Σ_kl yapsak head_dim thread çalışıp diğerleri boşta dururdu. Multi-thread reduction: 4 (k_par) thread paralel partial → 1 thread merge. Daha hızlı.

9) Update m, l

wgsl
if (tid == 0u) {
    sh_max[0] = m_new;
    sh_lse[0] = l_new;
}
workgroupBarrier();

Sadece tid 0 yazar. Single writer, race yok. Barrier sonrası tüm thread'ler güncellenmiş m/l görür.

10) Final O ve LSE

wgsl
let l_final = max(sh_lse[0], 1e-30);
let inv_l = 1.0 / l_final;
if (tid < head_dim) {
    O[s * d_model + h * head_dim + tid] = nan_guard(sh_acc[tid] * inv_l);
}
if (tid == 0u) {
    LSE[s * n_heads + h] = sh_max[0] + log(l_final);
}

max(l, 1e-30) — division-by-zero koruması. Eğer hiçbir K hit ettiyse l=0 kalabilir (örneğin causal_len=0 edge case, s=0 dahil değil). Pratikte tetiklenmez.

O[d] = acc[d] / l — softmax denominator division. LSE = m + log(l) — backward'da reuse edilecek, log-sum-exp.


attention_decode Farkı

Inference için. Q tek row (mevcut token), K/V cache'ten okunur (önceki tüm token'ların):

wgsl
let h = wgid.x;             // bir head için bir WG (n_heads workgroup)
// ... causal mask yok, tüm cache_len attend
  • LSE save yok (decode'da backward yok)
  • Workgroup grid daha küçük (n_heads not seq × n_heads)
  • KV pre-loaded cache; bizim KV buffer cache değil her step yeniden yazılan storage olabilir

Performance Notes

Bu kernel bottleneck. Profil ölçümünde:

  • attention_forward: 8.7% of step time (45 ms / 12 layers)
  • Her layer ~3.75 ms

Sebep:

  1. K access pattern coalesced değil (adjacent thread'ler K[k] okuyor, k varies)
  2. V access pattern aynı sorun
  3. Causal mask son block'larda thread idle oluşturuyor
Not

Flash-Attention Deneyleri ve Occupancy Kaybı: Kod tabanında Flash-style bir Q-block tiled varyantı (Br=8, Bc=32, K/V tile cache) denenmiştir. Ancak Apple GPU'larında bu yaklaşım 21 KB workgroup belleği gerektirdiği için occupancy oranını ~32 WG/SM'den ~1 WG/SM'e düşürmüş ve K/V yük amortismanından elde edilen kazançtan daha fazla gecikme gizleme (latency hiding) kaybına yol açmıştır. d=128 h=4 kv=2 swiglu seq=512 konfigürasyonunda adım süresinde net +%47 artışa neden olduğu için kernel, sorgu başına (per-query) online softmax yapısına geri döndürülmüştür.

Optimizasyon fırsatı: Flash-attention pattern (Q ve K'yı tile et, K/V over chunks of Q). Perf-doc #1.


WGSL-Spesifik Notlar

1. wg_reduce_max ve wg_reduce_sum her thread'de aynı sonuç

00_shared.md detayında. Subgroup-accelerated, deterministic.

2. exp(NEG_INF) = 0

WGSL spec garantisi. Causal mask sonrası NEG_INF score → exp → 0 → softmax'a katkı sıfır. Doğru behaviour.

3. var<workgroup> 4 KB total

sh_q + sh_p + sh_acc + sh_partial = 4×256×4B = 4 KB. sh_max, sh_lse 1×4B = ihmaledilir. 32 KB limit'in %12'si.

4. select(b, a, cond) — WGSL ternary

wgsl
let kl_end = select(kl_end_raw, WG, kl_end_raw > WG);

"Eğer raw > WG, WG'i al; aksi raw'ı." Min(raw, WG) gibi, ama explicit.


Code Review

Bulgu 1: K access non-coalesced

RiskAçıklama
🟡 perfScore loop'u her thread K[k * kv_dim + ...] okuyor — k thread başına farklı, dolayısıyla adjacent thread'ler kv_dim byte apart memory'den okuyor. Coalesced değil. Pratik etki: ~%10-15 attention overhead.

Mitigasyon: Flash-attention pattern. Şu an open work.

Bulgu 2: sh_p size 256 — head_dim > 256 fail

RiskAçıklama
🟢 yokMAX_HEAD_DIM = 256 constraint koruma. Bizim model 64, gelecekte 128 yapılabilir, hâlâ güvenli.

Bulgu 3: nan_guard(O[d]) çıkışta var

RiskAçıklama
🟢 yokEğer Q × K = ∞ olursa exp patlayabilir. nan_guard NaN/Inf'i 0'a çekiyor, downstream stable.

Hızlı Kontrol Listesi

Test SenaryosuDurum
Causal mask doğru mu? (s = 0 sadece kendine attend)✅ k < causal_len check
GQA mapping doğru mu?✅ kv_h = h / q_per_group
Online softmax = standard softmax?⚠ formal test yok, pratikte loss curve makul
head_dim = 256 edge case crash?✅ MAX_HEAD_DIM = 256
seq_len = 1 (decode-style)✅ block iteration loop'u atlar
Backward LSE doğru kullanılıyor mu?10_backward_attention.md'da kanıt

Sonraki

06_activation.md — GeLU, SwiGLU. FFN'in non-linearity bileşeni.

WGSL kernel etüdleri · WebGPU üzerinde sıfırdan LLMİstanbul’da Uğur Toprakdeviren tarafından hazırlandı.