adamw_update, reduce_norm_sq, finalize_grad_stats — AdamW Optimizer
Dosya: 12_optimizer.wgsl Pipeline adımı: Pipeline'ın son adımı. Backward bittikten sonra weight'leri günceller.
5 kernel:
reduce_norm_sq—Σ g²global gradient norm reductionfinalize_grad_stats—gradNorm = √Σ g²,clipScale = min(1, max_norm / gradNorm)adamw_update— fused AdamW + clip + grad zero (saf fp32)adamw_update_f16— fused AdamW + clip + grad zero + f16 forward mirror castadamw_8bit_update— 8-bit quantized state variant (memory tasarrufu için)
Ayrıca bizim WGSL'da multi-tensor desteği — tek kernel call mega-buffer üzerinde çalışır, sub-range binding'lerle her parametreyi adresliyor.
Nedir bu ya?
Bütün forward + backward gürültüsünden sonra elinde tek bir şey kalıyor: her ağırlık için bir gradyan — "bu sayıyı şu yöne itersen loss biraz düşer" diyen bir ok. Optimizer'ın tek işi o okları alıp ağırlıkları gerçekten itmek. En aptal hâli w = w - lr * grad: okun gösterdiği yöne küçük bir adım at. Buna SGD deniyor ve çalışıyor ama dağ yolunda gözü kapalı yürümek gibi — her adımda zikzak çiziyorsun.
AdamW bu yürüyüşe iki tane hafıza ekliyor. Düşün ki bir topu yokuş aşağı yuvarlıyorsun: top tek bir gradyan okuna körü körüne uymaz, momentum taşır — son birkaç adımda hangi yöne gittiğine bakar (bu m). İkincisi, her parametrenin gradyanı ne kadar zıpkın onu da takip eder (bu v): sürekli sağa-sola savrulan bir ağırlığa temkinli, sakin ilerleyen birine cömert adım atar. İkisi de "exponential moving average" — yani eski bilgiyi yavaşça unutan kayan ortalama, tıpkı bir uygulamadaki "son 7 günün ortalaması" göstergesi gibi.
Bir incelik daha var: weight decay (ağırlıkları hafifçe sıfıra doğru çekme, overfit'i frenler). AdamW'nin "W"si tam burada — bu çekmeyi gradyan hesabının içine karıştırmaz, ayrı ve temiz bir terim olarak en sonda uygular ("decoupled"). Eski Adam'ın yaptığı hata buydu; AdamW düzeltir.
Geri kalan her şey muhasebe ve hız. Modelde ~74 ayrı ağırlık tensörü var; her birine ayrı ayrı kernel atmak yerine hepsini tek dev buffer'da yan yana diziyoruz (multi-tensor), ve fp16 aynasını da aynı geçişte güncelleyip ekstra bir tam taramadan kurtuluyoruz (fused mirroring). Asıl ince iş aşağıda başlıyor.
Multi-tensor AdamW — En Önemli Optimizasyon
Klasik (per-tensor) AdamW:
for each parameter tensor (~74 in our model):
adamw_update kernel call
74 dispatch / step → enormous overheadBizim multi-tensor:
1 kernel call → mega_w/g/m/v contiguous buffers
sub-range bindings: per-param view via offset+sizeMega-buffer:
mega_w: array<f32>— TÜM weight'ler tek buffer'da concatenatemega_g,mega_m,mega_v— aynı yapı
Sub-range binding:
{ buffer: mega_w, offset: param_offset[i], size: param_size[i] }WebGPU bunu tek param tensor gibi görür ama physical olarak mega-buffer'ın bir slice'ı.
Sonuç:
- 74 → 2 dispatch (wd block + no-wd block)
- 50 ms → 13.7 ms optimizer phase
- −36 ms / step
AdamW Math (Decoupled Weight Decay)
Klasik AdamW:
m ← β₁·m + (1-β₁)·g·clip
v ← β₂·v + (1-β₂)·(g·clip)²
m̂ = m / (1 - β₁ᵗ) ← bias correction
v̂ = v / (1 - β₂ᵗ)
w ← w - lr·( m̂ / (√v̂ + ε) + λ·w )
↑ decoupled weight decayg·clip = grad × clipScale (gradient clipping fused).
λ (weight_decay) decoupled — momentum'lu update'le karışmıyor, separate term.
Niye decoupled? Klasik Adam (λ·w momentum'a dahil) weight'i de momentum gibi update ediyor — bu değil "decay", "shrinkage" oluyor. AdamW (Loshchilov 2017) bu hatayı düzeltiyor: λ·w direct olarak weight'ten çıkarılır.
Pipeline — 3 Adım
1. reduce_norm_sq — Global gradient norm
input: mega_g (all gradients concatenated)
output: norm_sq scalar (atomic accumulator)
for each gradient element:
if finite: accumulate g²
return Σ g²Atomik CAS-add scalar slot'una. Tek scalar — ne kadar contended olsa da küçük buffer.
2. finalize_grad_stats
gradNorm = √Σg²
clipScale = min(1, max_grad_norm / max(gradNorm, 1e-6))Single thread. Sonuç grad_stats[0..1] array'inde.
max_grad_norm host'tan uniform (örn 1.0). clipScale ≤ 1. Eğer norm < max_norm → clipScale=1 (no clipping). Aksi halde clipScale < 1 → clip applied.
3. adamw_update
Her thread bir element için:
1. read grad, finite-check, multiply by clip_scale
2. update m, v
3. apply bias correction
4. compute weight update
5. write w_new, m_new, v_new
6. zero grad (for next iter)adamw_update mega_w/g/m/v üzerinde çalışır. Sub-range binding ile tek dispatch tüm tensorları update ediyor.
Bind Group ABI
adamw_update (8 binding — Saf fp32)
| Binding | Tür | Detay |
|---|---|---|
| 0 | storage, read_write | w: array<f32> — master weights |
| 1 | storage, read_write | g: array<f32> — gradients (zero'd after) |
| 2 | storage, read_write | m_buf: array<f32> — 1st moment |
| 3 | storage, read_write | v_buf: array<f32> — 2nd moment |
| 4 | storage, read | grad_stats: array<f32> — [gradNorm, clipScale] |
| 5 | uniform | size: u32 |
| 6 | uniform | hp: vec4<f32> — (lr, beta1, beta2, wd) |
| 7 | uniform | bias: vec4<f32> — (beta1_t, beta2_t, eps, _) |
adamw_update_f16 (9 binding — Mixed Precision / Fused Mirroring)
| Binding | Tür | Detay |
|---|---|---|
| 0 | storage, read_write | w_mp: array<f32> — master weights |
| 1 | storage, read_write | g_mp: array<f32> — gradients (zero'd after) |
| 2 | storage, read_write | m_buf_mp: array<f32> — 1st moment |
| 3 | storage, read_write | v_buf_mp: array<f32> — 2nd moment |
| 4 | storage, read | grad_stats_mp: array<f32> — [gradNorm, clipScale] |
| 5 | uniform | size_mp: u32 |
| 6 | uniform | hp_mp: vec4<f32> |
| 7 | uniform | bias_mp: vec4<f32> |
| 8 | storage, read_write | dst_w16: array<f16> — fp16 formatında forward ağırlık aynası (fused mirror) |
adamw_8bit_update (10 binding — 8-bit Quantized State)
| Binding | Tür | Detay |
|---|---|---|
| 0 | storage, read_write | w: array<f32> — master weights |
| 1 | storage, read_write | g: array<f32> — gradients (zero'd after) |
| 2 | storage, read_write | m_packed: array<u32> — packed int8 1st moment (4 values per u32) |
| 3 | storage, read_write | v_packed: array<u32> — packed int8 2nd moment (4 values per u32) |
| 4 | storage, read_write | m_scale: array<f32> — per-block (256 elements) scale factor |
| 5 | storage, read_write | v_scale: array<f32> — per-block (256 elements) scale factor |
| 6 | storage, read | grad_stats: array<f32> — [gradNorm, clipScale] |
| 7 | uniform | size: u32 |
| 8 | uniform | hp: vec4<f32> — (lr, b1, b2, wd) |
| 9 | uniform | bias: vec4<f32> — (b1_t, b2_t, eps, _) |
beta1_t = beta1^step, beta2_t = beta2^step host'ta hesaplanır.
Dispatch Şekli
reduce_norm_sq
workgroup_size: 256
grid: ceil(total_params / 256)Total parametre sayısı — örn 116M. Her thread strided loop ile element işler, workgroup tree-reduction, atomic-add.
adamw_update
workgroup_size: 256
grid: ceil(size / 256)size mega-buffer slice'ın eleman sayısı. Multi-tensor pattern:
- Compute beta_t for current step (host)
- Bind mega_w/g/m/v sub-range for "wd block" (decay applied)
- Dispatch
adamw_updatefor wd block - Bind sub-range for "no-wd block" (norms, biases — wd skipped)
- Dispatch
adamw_updatefor no-wd block
Toplam 2 dispatch yerine 74 (per-tensor).
Satır Satır — adamw_update
1) Setup
let i = flat_id(gid, nwg);
if (i >= size) { return; }
let lr = hp.x;
let beta1 = hp.y;
let beta2 = hp.z;
let wd = hp.w;
let beta1_t = bias.x;
let beta2_t = bias.y;
let eps = bias.z;
let clip = grad_stats[1];Tüm hyperparameters uniform'dan gelir. clip runtime'da hesaplandı (önceki kernel'da).
2) Grad finite-check + clip
var grad = g[i];
grad = select(0.0, grad, is_finite(grad));
grad = grad * clip;NaN/Inf grad → 0'a çek (gradient explosion'dan korunma). Sonra clip ile multiply (norm'u max_norm'un altına).
3) Adam moments
let m_old = m_buf[i];
let v_old = v_buf[i];
let w_old = w[i];
let m_new = beta1 * m_old + (1.0 - beta1) * grad;
let v_new = beta2 * v_old + (1.0 - beta2) * grad * grad;EMA updates:
m: gradient'ın exponential moving average (1st moment, "velocity")v: gradient²'nin EMA (2nd moment, "variance")
beta1 ≈ 0.9, beta2 ≈ 0.999 default. m hızlı reaktif (recent grads), v yavaş (long-term grad²).
4) Bias correction
let b1_corr = 1.0 / max(1.0 - beta1_t, 1e-12);
let b2_corr = 1.0 / max(1.0 - beta2_t, 1e-12);
let m_hat = m_new * b1_corr;
let v_hat = v_new * b2_corr;Erken step'lerde m, v ≈ 0 (zero-init'ten geliyor). Bias correction bunu compansate eder. Step ilerledikçe beta_t → 0, correction → 1.
5) Update + decoupled WD
let denom = sqrt(v_hat) + eps;
let step = m_hat / denom + wd * w_old;
let w_new = w_old - lr * step;wd * w_old — decoupled weight decay. Adam moments-dependent değil.
6) Write back
w[i] = w_new;
m_buf[i] = m_new;
v_buf[i] = v_new;
g[i] = 0.0; ← grad zero for next stepGrad zero fused. Ayrı fill_zero kernel'ı çağrısı yok. Bandwidth tasarrufu.
reduce_norm_sq — Atomic Reduction
let grid_size = nwg.x * nwg.y * WG;
var local: f32 = 0.0;
var i = flat_id(gid, nwg);
loop {
if (i >= n) { break; }
let v = data[i];
if (is_finite(v)) { local = local + v * v; }
i = i + grid_size;
}
sh_red[tid] = local;
workgroupBarrier();
// classical tree reduction
var s = WG / 2u;
loop {
if (s == 0u) { break; }
if (tid < s) { sh_red[tid] = sh_red[tid] + sh_red[tid + s]; }
workgroupBarrier();
s = s >> 1u;
}
if (tid == 0u) {
// CAS atomic add to global accumulator
...
}Niye wg_reduce_sum değil tree reduction?
wg_reduce_sum subgroup operations gerektirir. Tüm GPU'larda subgroup destek yok (eski browsers). Tree reduction subgroups extension'a bağımlı değil — portable. Optimizer kernel mobile/older hardware için tasarlandı.
Hot path değil — global norm sadece 1 kez per step. Ekstra barrier'ları kabul ediyoruz.
adamw_8bit_update
Production'da kullanılmayan alternative kernel. Adam state'i (m, v) packed u32 quantized + scale factors olarak tutar:
- Memory savings: m + v 8-bit ile ~%75 azalır (4 bytes → 1 byte per state)
- Quality cost: ~%1 worse final loss
8-bit dynamic quantization — block-wide max → scale factor → quantize 8-bit signed. Detay 12_optimizer.wgsl'da ~80+ line kernel.
Bizim default fp32 AdamW. 8-bit fallback gerektiğinde (memory budget düşük) UI'dan seçilebilir.
WGSL-Spesifik Notlar
1. Sub-range binding
WebGPU bindGroupLayout standard. entries[].resource.{buffer, offset, size} ile bir buffer'ın slice'ı bind edilebilir. Bizim engine bunu kullanarak mega-buffer üzerinde per-param view oluşturuyor.
2. atomic<u32> declaration
@group(0) @binding(1) var<storage, read_write> result: atomic<u32>;Reduce kernel'da global accumulator atomic-tagged. Per-parameter scalar (1 element).
3. nwg.x * nwg.y * WG grid size
let grid_size = nwg.x * nwg.y * WG;2D grid'i lineerize ediyor (1D fallback için flat_id kullanımı). Total thread sayısı.
4. select(0.0, grad, is_finite(grad))
WGSL ternary. is_finite true ise grad, değilse 0. Standard NaN/Inf guard.
Performance — Profile snapshot'tan
adamw_update (× 2) 35.78 ms → 6.9% step
scale (× 1) 8.91 ms → 1.7% step (gradient clip apply)
cast_f32_to_f16 (×110) 8.59 ms → 1.7% step (mixed precision sync)
reduce_norm_sq (× 1) 4.26 ms → 0.8% step
finalize_grad_stats (× 1) 0.07 ms
───────
Total ~57 ms = 11% stepadamw_update 35.8 ms / 2 calls = 17.9 ms per call. Yüksek. Multi-tensor AdamW'ın SUB-LINEAR olması beklenirdi ama 17.9 ms her call yine de büyük (mega_w 100M+ element üzerinde dispatch).
Tamamlanan Optimizasyon (perf-doc #B): cast_f32_to_f16 standalone sweep'i tamamen ekarte edilerek adamw_update_f16 kernel'ı ile optimizer içerisine gömülmüştür. Bu sayede katman başına parametreler üzerinde fazladan koşan ~110 dispatch engellenerek ~%1.7 step süresi kazancı doğrudan elde edilmiştir.
Code Review
Bulgu 1: 2 dispatch yerine 1 olamaz mı?
| Risk | Açıklama |
|---|---|
| 🟡 minor | Şu an wd-block + no-wd-block ayrı dispatch. Tek dispatch'te select(wd, 0.0, is_norm_or_bias) yapılabilir. Ek register/uniform gerekir, marjinal speedup. |
Bulgu 2: Bias correction precision
| Risk | Açıklama |
|---|---|
| 🟢 minor | 1 - beta_t küçük olunca 1/(1-beta_t) büyük olur. Erken step'lerde precision loss riski. max(..., 1e-12) clamp guard. |
Bulgu 3: Grad zero fused — efficient
| Risk | Açıklama |
|---|---|
| 🟢 yok | g[i] = 0.0 end-of-kernel. Ayrı fill_zero call gereği yok. |
Hızlı Kontrol Listesi
| Test Senaryosu | Durum |
|---|---|
| Adam moments default zero-init mi? | ✅ allocate'ta fill_zero |
step = 1 bias correction sonsuz değil mi? | ✅ max(1e-12) clamp |
| NaN grad → 0 cleanup? | ✅ select |
| Mega-buffer sub-range binding doğru mu? | ✅ runtime test passing |
wd_block vs no_wd_block ayrımı doğru mu? | ✅ host-side param layout |
| 8-bit kernel parity ile fp32 mı? | ⚠ %1-2 worse final loss accepted |
Pipeline Tamamlandı
Bu son chapter. Forward → Backward → Optimizer akışının her kernel'i incelendi. Aynı corpus üzerinde her training step şu sırayla çalışır:
01 Embed → 02 Norm → 03 Linear (Q/K/V/O) → 04 RoPE → 05 Attention
→ 02 Norm → 03 Linear (gate/up) → 06 SwiGLU → 03 Linear (down) → residual
... 12 layer ...
→ 02 Norm → 03 Linear (lm_head) → 07 Loss + dLogits
← 11 Linear bwd ← 09 FFN bwd ← 10 Attn bwd ← 02 Norm bwd ← 01 Embed bwd ...
→ reduce_norm_sq → finalize_grad_stats → adamw_update_f16 (fused) → next stepHer parça kendi md dokümanında. Geri dön index'e: index.md.