llm.istanbul·Etüt
TR EN
Workbench →

adamw_update, reduce_norm_sq, finalize_grad_stats — AdamW Optimizer

Dosya: 12_optimizer.wgsl Pipeline adımı: Pipeline'ın son adımı. Backward bittikten sonra weight'leri günceller.

5 kernel:

  • reduce_norm_sqΣ g² global gradient norm reduction
  • finalize_grad_statsgradNorm = √Σ g², clipScale = min(1, max_norm / gradNorm)
  • adamw_update — fused AdamW + clip + grad zero (saf fp32)
  • adamw_update_f16 — fused AdamW + clip + grad zero + f16 forward mirror cast
  • adamw_8bit_update — 8-bit quantized state variant (memory tasarrufu için)

Ayrıca bizim WGSL'da multi-tensor desteği — tek kernel call mega-buffer üzerinde çalışır, sub-range binding'lerle her parametreyi adresliyor.


Nedir bu ya?

Bütün forward + backward gürültüsünden sonra elinde tek bir şey kalıyor: her ağırlık için bir gradyan — "bu sayıyı şu yöne itersen loss biraz düşer" diyen bir ok. Optimizer'ın tek işi o okları alıp ağırlıkları gerçekten itmek. En aptal hâli w = w - lr * grad: okun gösterdiği yöne küçük bir adım at. Buna SGD deniyor ve çalışıyor ama dağ yolunda gözü kapalı yürümek gibi — her adımda zikzak çiziyorsun.

AdamW bu yürüyüşe iki tane hafıza ekliyor. Düşün ki bir topu yokuş aşağı yuvarlıyorsun: top tek bir gradyan okuna körü körüne uymaz, momentum taşır — son birkaç adımda hangi yöne gittiğine bakar (bu m). İkincisi, her parametrenin gradyanı ne kadar zıpkın onu da takip eder (bu v): sürekli sağa-sola savrulan bir ağırlığa temkinli, sakin ilerleyen birine cömert adım atar. İkisi de "exponential moving average" — yani eski bilgiyi yavaşça unutan kayan ortalama, tıpkı bir uygulamadaki "son 7 günün ortalaması" göstergesi gibi.

Bir incelik daha var: weight decay (ağırlıkları hafifçe sıfıra doğru çekme, overfit'i frenler). AdamW'nin "W"si tam burada — bu çekmeyi gradyan hesabının içine karıştırmaz, ayrı ve temiz bir terim olarak en sonda uygular ("decoupled"). Eski Adam'ın yaptığı hata buydu; AdamW düzeltir.

Geri kalan her şey muhasebe ve hız. Modelde ~74 ayrı ağırlık tensörü var; her birine ayrı ayrı kernel atmak yerine hepsini tek dev buffer'da yan yana diziyoruz (multi-tensor), ve fp16 aynasını da aynı geçişte güncelleyip ekstra bir tam taramadan kurtuluyoruz (fused mirroring). Asıl ince iş aşağıda başlıyor.


Multi-tensor AdamW — En Önemli Optimizasyon

Klasik (per-tensor) AdamW:

for each parameter tensor (~74 in our model):
    adamw_update kernel call
74 dispatch / step → enormous overhead

Bizim multi-tensor:

1 kernel call → mega_w/g/m/v contiguous buffers
sub-range bindings: per-param view via offset+size

Mega-buffer:

  • mega_w: array<f32> — TÜM weight'ler tek buffer'da concatenate
  • mega_g, mega_m, mega_v — aynı yapı

Sub-range binding:

javascript
{ buffer: mega_w, offset: param_offset[i], size: param_size[i] }

WebGPU bunu tek param tensor gibi görür ama physical olarak mega-buffer'ın bir slice'ı.

Sonuç:

  • 74 → 2 dispatch (wd block + no-wd block)
  • 50 ms → 13.7 ms optimizer phase
  • −36 ms / step

AdamW Math (Decoupled Weight Decay)

Klasik AdamW:

m  ← β₁·m + (1-β₁)·g·clip
v  ← β₂·v + (1-β₂)·(g·clip)²
m̂  = m / (1 - β₁ᵗ)        ← bias correction
v̂  = v / (1 - β₂ᵗ)
w  ← w - lr·( m̂ / (√v̂ + ε) + λ·w )
                          ↑ decoupled weight decay

g·clip = grad × clipScale (gradient clipping fused). λ (weight_decay) decoupled — momentum'lu update'le karışmıyor, separate term.

Niye decoupled? Klasik Adam (λ·w momentum'a dahil) weight'i de momentum gibi update ediyor — bu değil "decay", "shrinkage" oluyor. AdamW (Loshchilov 2017) bu hatayı düzeltiyor: λ·w direct olarak weight'ten çıkarılır.


Pipeline — 3 Adım

1. reduce_norm_sq — Global gradient norm

input:  mega_g (all gradients concatenated)
output: norm_sq scalar (atomic accumulator)

for each gradient element:
  if finite: accumulate g²
return Σ g²

Atomik CAS-add scalar slot'una. Tek scalar — ne kadar contended olsa da küçük buffer.

2. finalize_grad_stats

gradNorm  = √Σg²
clipScale = min(1, max_grad_norm / max(gradNorm, 1e-6))

Single thread. Sonuç grad_stats[0..1] array'inde.

max_grad_norm host'tan uniform (örn 1.0). clipScale ≤ 1. Eğer norm < max_norm → clipScale=1 (no clipping). Aksi halde clipScale < 1 → clip applied.

3. adamw_update

Her thread bir element için:

1. read grad, finite-check, multiply by clip_scale
2. update m, v
3. apply bias correction
4. compute weight update
5. write w_new, m_new, v_new
6. zero grad (for next iter)

adamw_update mega_w/g/m/v üzerinde çalışır. Sub-range binding ile tek dispatch tüm tensorları update ediyor.


Bind Group ABI

adamw_update (8 binding — Saf fp32)

BindingTürDetay
0storage, read_writew: array<f32> — master weights
1storage, read_writeg: array<f32> — gradients (zero'd after)
2storage, read_writem_buf: array<f32> — 1st moment
3storage, read_writev_buf: array<f32> — 2nd moment
4storage, readgrad_stats: array<f32> — [gradNorm, clipScale]
5uniformsize: u32
6uniformhp: vec4<f32> — (lr, beta1, beta2, wd)
7uniformbias: vec4<f32> — (beta1_t, beta2_t, eps, _)

adamw_update_f16 (9 binding — Mixed Precision / Fused Mirroring)

BindingTürDetay
0storage, read_writew_mp: array<f32> — master weights
1storage, read_writeg_mp: array<f32> — gradients (zero'd after)
2storage, read_writem_buf_mp: array<f32> — 1st moment
3storage, read_writev_buf_mp: array<f32> — 2nd moment
4storage, readgrad_stats_mp: array<f32> — [gradNorm, clipScale]
5uniformsize_mp: u32
6uniformhp_mp: vec4<f32>
7uniformbias_mp: vec4<f32>
8storage, read_writedst_w16: array<f16> — fp16 formatında forward ağırlık aynası (fused mirror)

adamw_8bit_update (10 binding — 8-bit Quantized State)

BindingTürDetay
0storage, read_writew: array<f32> — master weights
1storage, read_writeg: array<f32> — gradients (zero'd after)
2storage, read_writem_packed: array<u32> — packed int8 1st moment (4 values per u32)
3storage, read_writev_packed: array<u32> — packed int8 2nd moment (4 values per u32)
4storage, read_writem_scale: array<f32> — per-block (256 elements) scale factor
5storage, read_writev_scale: array<f32> — per-block (256 elements) scale factor
6storage, readgrad_stats: array<f32> — [gradNorm, clipScale]
7uniformsize: u32
8uniformhp: vec4<f32> — (lr, b1, b2, wd)
9uniformbias: vec4<f32> — (b1_t, b2_t, eps, _)

beta1_t = beta1^step, beta2_t = beta2^step host'ta hesaplanır.


Dispatch Şekli

reduce_norm_sq

workgroup_size: 256
grid: ceil(total_params / 256)

Total parametre sayısı — örn 116M. Her thread strided loop ile element işler, workgroup tree-reduction, atomic-add.

adamw_update

workgroup_size: 256
grid: ceil(size / 256)

size mega-buffer slice'ın eleman sayısı. Multi-tensor pattern:

  1. Compute beta_t for current step (host)
  2. Bind mega_w/g/m/v sub-range for "wd block" (decay applied)
  3. Dispatch adamw_update for wd block
  4. Bind sub-range for "no-wd block" (norms, biases — wd skipped)
  5. Dispatch adamw_update for no-wd block

Toplam 2 dispatch yerine 74 (per-tensor).


Satır Satır — adamw_update

1) Setup

wgsl
let i = flat_id(gid, nwg);
if (i >= size) { return; }

let lr      = hp.x;
let beta1   = hp.y;
let beta2   = hp.z;
let wd      = hp.w;
let beta1_t = bias.x;
let beta2_t = bias.y;
let eps     = bias.z;
let clip    = grad_stats[1];

Tüm hyperparameters uniform'dan gelir. clip runtime'da hesaplandı (önceki kernel'da).

2) Grad finite-check + clip

wgsl
var grad = g[i];
grad = select(0.0, grad, is_finite(grad));
grad = grad * clip;

NaN/Inf grad → 0'a çek (gradient explosion'dan korunma). Sonra clip ile multiply (norm'u max_norm'un altına).

3) Adam moments

wgsl
let m_old = m_buf[i];
let v_old = v_buf[i];
let w_old = w[i];

let m_new = beta1 * m_old + (1.0 - beta1) * grad;
let v_new = beta2 * v_old + (1.0 - beta2) * grad * grad;

EMA updates:

  • m: gradient'ın exponential moving average (1st moment, "velocity")
  • v: gradient²'nin EMA (2nd moment, "variance")

beta1 ≈ 0.9, beta2 ≈ 0.999 default. m hızlı reaktif (recent grads), v yavaş (long-term grad²).

4) Bias correction

wgsl
let b1_corr = 1.0 / max(1.0 - beta1_t, 1e-12);
let b2_corr = 1.0 / max(1.0 - beta2_t, 1e-12);
let m_hat = m_new * b1_corr;
let v_hat = v_new * b2_corr;

Erken step'lerde m, v ≈ 0 (zero-init'ten geliyor). Bias correction bunu compansate eder. Step ilerledikçe beta_t → 0, correction → 1.

5) Update + decoupled WD

wgsl
let denom = sqrt(v_hat) + eps;
let step  = m_hat / denom + wd * w_old;
let w_new = w_old - lr * step;

wd * w_old — decoupled weight decay. Adam moments-dependent değil.

6) Write back

wgsl
w[i] = w_new;
m_buf[i] = m_new;
v_buf[i] = v_new;
g[i] = 0.0;     ← grad zero for next step

Grad zero fused. Ayrı fill_zero kernel'ı çağrısı yok. Bandwidth tasarrufu.


reduce_norm_sq — Atomic Reduction

wgsl
let grid_size = nwg.x * nwg.y * WG;
var local: f32 = 0.0;
var i = flat_id(gid, nwg);
loop {
    if (i >= n) { break; }
    let v = data[i];
    if (is_finite(v)) { local = local + v * v; }
    i = i + grid_size;
}

sh_red[tid] = local;
workgroupBarrier();
// classical tree reduction
var s = WG / 2u;
loop {
    if (s == 0u) { break; }
    if (tid < s) { sh_red[tid] = sh_red[tid] + sh_red[tid + s]; }
    workgroupBarrier();
    s = s >> 1u;
}

if (tid == 0u) {
    // CAS atomic add to global accumulator
    ...
}

Niye wg_reduce_sum değil tree reduction?

wg_reduce_sum subgroup operations gerektirir. Tüm GPU'larda subgroup destek yok (eski browsers). Tree reduction subgroups extension'a bağımlı değil — portable. Optimizer kernel mobile/older hardware için tasarlandı.

Hot path değil — global norm sadece 1 kez per step. Ekstra barrier'ları kabul ediyoruz.


adamw_8bit_update

Production'da kullanılmayan alternative kernel. Adam state'i (m, v) packed u32 quantized + scale factors olarak tutar:

  • Memory savings: m + v 8-bit ile ~%75 azalır (4 bytes → 1 byte per state)
  • Quality cost: ~%1 worse final loss

8-bit dynamic quantization — block-wide max → scale factor → quantize 8-bit signed. Detay 12_optimizer.wgsl'da ~80+ line kernel.

Bizim default fp32 AdamW. 8-bit fallback gerektiğinde (memory budget düşük) UI'dan seçilebilir.


WGSL-Spesifik Notlar

1. Sub-range binding

WebGPU bindGroupLayout standard. entries[].resource.{buffer, offset, size} ile bir buffer'ın slice'ı bind edilebilir. Bizim engine bunu kullanarak mega-buffer üzerinde per-param view oluşturuyor.

2. atomic<u32> declaration

wgsl
@group(0) @binding(1) var<storage, read_write> result: atomic<u32>;

Reduce kernel'da global accumulator atomic-tagged. Per-parameter scalar (1 element).

3. nwg.x * nwg.y * WG grid size

wgsl
let grid_size = nwg.x * nwg.y * WG;

2D grid'i lineerize ediyor (1D fallback için flat_id kullanımı). Total thread sayısı.

4. select(0.0, grad, is_finite(grad))

WGSL ternary. is_finite true ise grad, değilse 0. Standard NaN/Inf guard.


Performance — Profile snapshot'tan

adamw_update            (× 2)  35.78 ms   → 6.9% step
scale                   (× 1)   8.91 ms   → 1.7% step (gradient clip apply)
cast_f32_to_f16         (×110)  8.59 ms   → 1.7% step (mixed precision sync)
reduce_norm_sq          (× 1)   4.26 ms   → 0.8% step
finalize_grad_stats     (× 1)   0.07 ms
                                         ───────
                                Total ~57 ms = 11% step

adamw_update 35.8 ms / 2 calls = 17.9 ms per call. Yüksek. Multi-tensor AdamW'ın SUB-LINEAR olması beklenirdi ama 17.9 ms her call yine de büyük (mega_w 100M+ element üzerinde dispatch).

Tamamlanan Optimizasyon (perf-doc #B): cast_f32_to_f16 standalone sweep'i tamamen ekarte edilerek adamw_update_f16 kernel'ı ile optimizer içerisine gömülmüştür. Bu sayede katman başına parametreler üzerinde fazladan koşan ~110 dispatch engellenerek ~%1.7 step süresi kazancı doğrudan elde edilmiştir.


Code Review

Bulgu 1: 2 dispatch yerine 1 olamaz mı?

RiskAçıklama
🟡 minorŞu an wd-block + no-wd-block ayrı dispatch. Tek dispatch'te select(wd, 0.0, is_norm_or_bias) yapılabilir. Ek register/uniform gerekir, marjinal speedup.

Bulgu 2: Bias correction precision

RiskAçıklama
🟢 minor1 - beta_t küçük olunca 1/(1-beta_t) büyük olur. Erken step'lerde precision loss riski. max(..., 1e-12) clamp guard.

Bulgu 3: Grad zero fused — efficient

RiskAçıklama
🟢 yokg[i] = 0.0 end-of-kernel. Ayrı fill_zero call gereği yok.

Hızlı Kontrol Listesi

Test SenaryosuDurum
Adam moments default zero-init mi?✅ allocate'ta fill_zero
step = 1 bias correction sonsuz değil mi?✅ max(1e-12) clamp
NaN grad → 0 cleanup?✅ select
Mega-buffer sub-range binding doğru mu?✅ runtime test passing
wd_block vs no_wd_block ayrımı doğru mu?✅ host-side param layout
8-bit kernel parity ile fp32 mı?⚠ %1-2 worse final loss accepted

Pipeline Tamamlandı

Bu son chapter. Forward → Backward → Optimizer akışının her kernel'i incelendi. Aynı corpus üzerinde her training step şu sırayla çalışır:

01 Embed → 02 Norm → 03 Linear (Q/K/V/O) → 04 RoPE → 05 Attention
                  → 02 Norm → 03 Linear (gate/up) → 06 SwiGLU → 03 Linear (down) → residual
... 12 layer ...
                  → 02 Norm → 03 Linear (lm_head) → 07 Loss + dLogits

← 11 Linear bwd ← 09 FFN bwd ← 10 Attn bwd ← 02 Norm bwd ← 01 Embed bwd ...

→ reduce_norm_sq → finalize_grad_stats → adamw_update_f16 (fused) → next step

Her parça kendi md dokümanında. Geri dön index'e: index.md.

WGSL kernel etüdleri · WebGPU üzerinde sıfırdan LLMİstanbul’da Uğur Toprakdeviren tarafından hazırlandı.