
adamw_update, reduce_norm_sq, finalize_grad_stats — AdamW Optimizer

File: 12_optimizer.wgsl
Pipeline step: The final step of the pipeline. Updates the weights after backward finishes.

5 kernels:

  • reduce_norm_sq: Σ g², global gradient norm reduction
  • finalize_grad_stats: gradNorm = √Σ g², clipScale = min(1, max_norm / gradNorm)
  • adamw_update: fused AdamW + clip + grad zero (pure fp32)
  • adamw_update_f16: fused AdamW + clip + grad zero + f16 forward mirror cast
  • adamw_8bit_update: 8-bit quantized state variant (for memory savings)

Additionally, our WGSL has multi-tensor support — a single kernel call operates over the mega-buffer, addressing each parameter through sub-range bindings.


Wait, what is this?

After all the forward + backward noise, you're left with exactly one thing per weight: a gradient — an arrow saying "nudge this number in that direction and the loss drops a bit". The optimizer's only job is to take those arrows and actually push the weights. The dumbest version is w = w - lr * grad: take a tiny step in the direction the arrow points. That's SGD, and it works, but it's like walking a mountain trail blindfolded — you zigzag on every step.

AdamW gives that walk two kinds of memory. Picture rolling a ball down a slope: the ball doesn't blindly obey a single gradient arrow, it carries momentum — it looks at which way it's been heading over the last few steps (that's m). Second, it tracks how jumpy each parameter's gradient is (that's v): a weight that keeps thrashing left and right gets cautious, small steps, while a calm, steady one gets generous ones. Both are "exponential moving averages" — running averages that slowly forget old data, like a "7-day average" readout in some app.

There's one more subtlety: weight decay (gently pulling weights toward zero to rein in overfitting). The "W" in AdamW is exactly this — it doesn't blend that pull into the gradient math, it applies it as a separate, clean term right at the end ("decoupled"). That blending was the mistake old Adam made; AdamW fixes it.

Everything else is bookkeeping and speed. The model has ~74 separate weight tensors; instead of firing a kernel at each one, we lay them all side by side in one giant buffer (multi-tensor), and we refresh the fp16 mirror in the very same pass to dodge an extra full sweep (fused mirroring). The real fiddly work starts below.


Multi-tensor AdamW — The Most Important Optimization

Classic (per-tensor) AdamW:

for each parameter tensor (~74 in our model):
    adamw_update kernel call
74 dispatches / step → enormous overhead

Our multi-tensor approach:

1 kernel call → mega_w/g/m/v contiguous buffers
sub-range bindings: per-param view via offset+size

Mega-buffer:

  • mega_w: array<f32> — ALL weights concatenated into a single buffer
  • mega_g, mega_m, mega_v — same layout

Sub-range binding:

javascript
{ buffer: mega_w, offset: param_offset[i], size: param_size[i] }

WebGPU sees this as a single param tensor, but physically it is a slice of the mega-buffer.
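A minimal host-side sketch of how such a layout could be computed. The helper name `layoutMegaBuffer`, the `params` shape, and the `decay` flag are hypothetical, not taken from the engine. The 256-byte alignment is the WebGPU default for `minStorageBufferOffsetAlignment`, applied here only to the start of each block.

```javascript
// Hypothetical layout helper: packs parameters into one flat f32 buffer,
// decay-enabled tensors first, then the no-decay tensors (norms, biases).
const BYTES_PER_F32 = 4;
const OFFSET_ALIGN = 256; // WebGPU default minStorageBufferOffsetAlignment

function alignUp(n, a) {
  return Math.ceil(n / a) * a;
}

function layoutMegaBuffer(params) {
  const blocks = { wd: { offset: 0, size: 0 }, noWd: { offset: 0, size: 0 } };
  const offsets = {};
  let cursor = 0; // byte cursor into the mega-buffer
  for (const [key, wantDecay] of [["wd", true], ["noWd", false]]) {
    cursor = alignUp(cursor, OFFSET_ALIGN); // each block starts on an aligned offset
    blocks[key].offset = cursor;
    for (const p of params.filter((q) => q.decay === wantDecay)) {
      offsets[p.name] = cursor;
      cursor += p.numel * BYTES_PER_F32;
    }
    blocks[key].size = cursor - blocks[key].offset;
  }
  return { totalBytes: cursor, offsets, blocks };
}

const layout = layoutMegaBuffer([
  { name: "wq", numel: 100, decay: true },
  { name: "norm", numel: 10, decay: false },
  { name: "wk", numel: 100, decay: true },
]);
```

Each block then maps to one bind group entry of the form `{ buffer: mega_w, offset: layout.blocks.wd.offset, size: layout.blocks.wd.size }`, and the same offsets apply to mega_g, mega_m and mega_v because they share the layout.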

Result:

  • 74 → 2 dispatches (wd block + no-wd block)
  • 50 ms → 13.7 ms optimizer phase
  • −36 ms / step

AdamW Math (Decoupled Weight Decay)

Classic AdamW:

m  ← β₁·m + (1-β₁)·g·clip
v  ← β₂·v + (1-β₂)·(g·clip)²
m̂  = m / (1 - β₁ᵗ)        ← bias correction
v̂  = v / (1 - β₂ᵗ)
w  ← w - lr·( m̂ / (√v̂ + ε) + λ·w )
                          ↑ decoupled weight decay

g·clip = grad × clipScale (gradient clipping fused). λ (weight_decay) is decoupled — it does not mix into the momentum-driven update, it is a separate term.

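Why decoupled? In classic Adam with L2 regularization, λ·w is added to the gradient, so it flows through m and v and is rescaled by 1/√v̂ like any other gradient component. The effective decay then differs per parameter and no longer matches the λ you set. AdamW (Loshchilov & Hutter, 2017) fixes this: λ·w is subtracted directly from the weight, scaled only by lr.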
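To make the difference concrete, here is a small sketch of a single first step with a zero gradient, so that only the decay term acts. The function names and hyperparameter values are illustrative assumptions. With L2 folded into the gradient, the normalization turns the decay into a step of roughly lr regardless of λ; with decoupled decay the step is lr·λ·w.

```javascript
// One optimizer step (t = 1, m = v = 0) with zero gradient.
const lr = 0.01, beta1 = 0.9, beta2 = 0.999, eps = 1e-8, lambda = 0.1;

// Adam with L2 regularization: lambda*w is added to the gradient and then
// normalized by sqrt(v_hat) like any other gradient component.
function coupledStep(w, g) {
  const gEff = g + lambda * w;
  const mHat = ((1 - beta1) * gEff) / (1 - beta1);
  const vHat = ((1 - beta2) * gEff * gEff) / (1 - beta2);
  return w - lr * (mHat / (Math.sqrt(vHat) + eps));
}

// AdamW: lambda*w bypasses the moments and is applied directly.
function decoupledStep(w, g) {
  const mHat = ((1 - beta1) * g) / (1 - beta1);
  const vHat = ((1 - beta2) * g * g) / (1 - beta2);
  return w - lr * (mHat / (Math.sqrt(vHat) + eps) + lambda * w);
}

const w0 = 2.0;
const coupled = coupledStep(w0, 0);     // moves by about lr, whatever lambda is
const decoupled = decoupledStep(w0, 0); // moves by lr * lambda * w0
```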


Pipeline — 3 Steps

1. reduce_norm_sq — Global gradient norm

input:  mega_g (all gradients concatenated)
output: norm_sq scalar (atomic accumulator)

for each gradient element:
  if finite: accumulate g²
return Σ g²

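Atomic CAS-add into a scalar slot: WGSL has no atomic f32 add, so the float sum is kept as raw bits in an atomic<u32> and updated with a compare-and-swap loop. Only one add per workgroup reaches this slot and the buffer is a single scalar, so the contention is acceptable.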
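The kernel's actual CAS loop is elided in the source, so the following is only a CPU analogy of the idea using JavaScript's `Atomics` API: read the current bits, add in float space, and retry until the compare-exchange succeeds. The names `atomicAddF32`, `toBits` and `fromBits` are hypothetical.

```javascript
// CPU analogy of a float add built from an integer compare-and-swap.
const slot = new Uint32Array(new SharedArrayBuffer(4)); // accumulator holding f32 bits

// Scratch views used to reinterpret f32 <-> u32 bits.
const scratchF32 = new Float32Array(1);
const scratchU32 = new Uint32Array(scratchF32.buffer);
function toBits(x) { scratchF32[0] = x; return scratchU32[0]; }
function fromBits(b) { scratchU32[0] = b; return scratchF32[0]; }

function atomicAddF32(target, index, value) {
  while (true) {
    const oldBits = Atomics.load(target, index);
    const newBits = toBits(fromBits(oldBits) + value);
    // Succeeds only if nobody changed the slot since we read it.
    if (Atomics.compareExchange(target, index, oldBits, newBits) === oldBits) {
      return;
    }
  }
}

atomicAddF32(slot, 0, 1.5);
atomicAddF32(slot, 0, 2.25);
const total = fromBits(Atomics.load(slot, 0));
```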

2. finalize_grad_stats

gradNorm  = √Σg²
clipScale = min(1, max_grad_norm / max(gradNorm, 1e-6))

Single thread. The result lives in the grad_stats[0..1] array.

max_grad_norm is a uniform from the host (e.g. 1.0). clipScale ≤ 1. If gradNorm ≤ max_grad_norm → clipScale = 1 (no clipping). Otherwise clipScale < 1 → every gradient is scaled down so the global norm equals max_grad_norm.
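Steps 1 and 2 together can be checked on the CPU with a few lines. This is a reference sketch following the formulas above, not the engine's code. The function name `gradStats` is hypothetical, and the sum is done in double precision rather than f32.

```javascript
// CPU reference for reduce_norm_sq + finalize_grad_stats.
function gradStats(grads, maxGradNorm) {
  let sumSq = 0;
  for (const g of grads) {
    if (Number.isFinite(g)) sumSq += g * g; // mirrors the finite check in the reduction
  }
  const gradNorm = Math.sqrt(sumSq);
  const clipScale = Math.min(1, maxGradNorm / Math.max(gradNorm, 1e-6));
  return [gradNorm, clipScale]; // same order as grad_stats[0..1]
}

// Norm 5 exceeds max_grad_norm 1.0, so gradients are scaled by 0.2.
const [norm, scale] = gradStats([3, 4, NaN], 1.0);
// A small gradient is left untouched.
const [, smallScale] = gradStats([0.3, 0.4], 1.0);
```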

3. adamw_update

Each thread, for one element:

1. read grad, finite-check, multiply by clip_scale
2. update m, v
3. apply bias correction
4. compute weight update
5. write w_new, m_new, v_new
6. zero grad (for next iter)

adamw_update operates over mega_w/g/m/v. With sub-range binding, a single dispatch updates all tensors.
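The six steps translate into a short CPU reference that is handy for checking the kernel's output on small buffers. This is a sketch under the formulas in this chapter. The function name `adamwUpdateRef` and the `hp` object shape are hypothetical, and it runs in double precision unless Float32Array inputs are passed.

```javascript
// CPU reference for one adamw_update pass over flat arrays (updated in place).
function adamwUpdateRef(w, g, m, v, clipScale, hp, step) {
  const { lr, beta1, beta2, wd, eps } = hp;
  const b1Corr = 1 / Math.max(1 - Math.pow(beta1, step), 1e-12);
  const b2Corr = 1 / Math.max(1 - Math.pow(beta2, step), 1e-12);
  for (let i = 0; i < w.length; i++) {
    // 1. read grad, finite-check, apply clip
    const grad = (Number.isFinite(g[i]) ? g[i] : 0) * clipScale;
    // 2. update moments
    m[i] = beta1 * m[i] + (1 - beta1) * grad;
    v[i] = beta2 * v[i] + (1 - beta2) * grad * grad;
    // 3. bias correction
    const mHat = m[i] * b1Corr;
    const vHat = v[i] * b2Corr;
    // 4-5. weight update with decoupled weight decay, written back
    w[i] -= lr * (mHat / (Math.sqrt(vHat) + eps) + wd * w[i]);
    // 6. zero the gradient for the next iteration
    g[i] = 0;
  }
}

const w = new Float64Array([1, -1]);
const g = new Float64Array([0.5, NaN]);
const m = new Float64Array(2);
const v = new Float64Array(2);
adamwUpdateRef(w, g, m, v, 1.0, { lr: 0.1, beta1: 0.9, beta2: 0.999, wd: 0, eps: 1e-8 }, 1);
```

On the first step the corrected update is close to the sign of the gradient, so w[0] moves by about lr, while the NaN gradient leaves w[1] unchanged.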


Bind Group ABI

adamw_update (8 bindings — Pure fp32)

| Binding | Type | Detail |
| --- | --- | --- |
| 0 | storage, read_write | w: array<f32>, master weights |
| 1 | storage, read_write | g: array<f32>, gradients (zero'd after) |
| 2 | storage, read_write | m_buf: array<f32>, 1st moment |
| 3 | storage, read_write | v_buf: array<f32>, 2nd moment |
| 4 | storage, read | grad_stats: array<f32>, [gradNorm, clipScale] |
| 5 | uniform | size: u32 |
| 6 | uniform | hp: vec4<f32>, (lr, beta1, beta2, wd) |
| 7 | uniform | bias: vec4<f32>, (beta1_t, beta2_t, eps, _) |

adamw_update_f16 (9 bindings — Mixed Precision / Fused Mirroring)

| Binding | Type | Detail |
| --- | --- | --- |
| 0 | storage, read_write | w_mp: array<f32>, master weights |
| 1 | storage, read_write | g_mp: array<f32>, gradients (zero'd after) |
| 2 | storage, read_write | m_buf_mp: array<f32>, 1st moment |
| 3 | storage, read_write | v_buf_mp: array<f32>, 2nd moment |
| 4 | storage, read | grad_stats_mp: array<f32>, [gradNorm, clipScale] |
| 5 | uniform | size_mp: u32 |
| 6 | uniform | hp_mp: vec4<f32> |
| 7 | uniform | bias_mp: vec4<f32> |
| 8 | storage, read_write | dst_w16: array<f16>, forward weight mirror in fp16 format (fused mirror) |

adamw_8bit_update (10 bindings — 8-bit Quantized State)

| Binding | Type | Detail |
| --- | --- | --- |
| 0 | storage, read_write | w: array<f32>, master weights |
| 1 | storage, read_write | g: array<f32>, gradients (zero'd after) |
| 2 | storage, read_write | m_packed: array<u32>, packed int8 1st moment (4 values per u32) |
| 3 | storage, read_write | v_packed: array<u32>, packed int8 2nd moment (4 values per u32) |
| 4 | storage, read_write | m_scale: array<f32>, per-block (256 elements) scale factor |
| 5 | storage, read_write | v_scale: array<f32>, per-block (256 elements) scale factor |
| 6 | storage, read | grad_stats: array<f32>, [gradNorm, clipScale] |
| 7 | uniform | size: u32 |
| 8 | uniform | hp: vec4<f32>, (lr, b1, b2, wd) |
| 9 | uniform | bias: vec4<f32>, (b1_t, b2_t, eps, _) |

beta1_t = beta1^step, beta2_t = beta2^step are computed on the host.
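A sketch of how the host could pack the two vec4 uniforms each step, following the (lr, beta1, beta2, wd) and (beta1_t, beta2_t, eps, _) layouts from the tables. The function name `makeOptimizerUniforms` is hypothetical, and the buffer names in the usage note are placeholders.

```javascript
// Packs the hp and bias uniforms for a given 1-based optimizer step.
function makeOptimizerUniforms(step, { lr, beta1, beta2, wd, eps }) {
  const hp = new Float32Array([lr, beta1, beta2, wd]);
  const bias = new Float32Array([
    Math.pow(beta1, step), // beta1_t
    Math.pow(beta2, step), // beta2_t
    eps,
    0,                     // unused padding lane
  ]);
  return { hp, bias };
}

const uniforms = makeOptimizerUniforms(2, {
  lr: 3e-4, beta1: 0.9, beta2: 0.999, wd: 0.1, eps: 1e-8,
});
```

The two arrays are 16 bytes each and could be uploaded with `device.queue.writeBuffer(hpBuffer, 0, uniforms.hp)` and the equivalent call for `bias`, assuming `hpBuffer` and its sibling were created with uniform usage.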


Dispatch Shape

reduce_norm_sq

workgroup_size: 256
grid: ceil(total_params / 256)

Total parameter count — e.g. 116M. Each thread processes elements with a strided loop, then a workgroup tree-reduction, then atomic-add.
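At this scale a plain 1D dispatch does not fit: 116M elements at 256 threads per workgroup need 453,125 workgroups, while WebGPU's default `maxComputeWorkgroupsPerDimension` is 65,535. That is presumably why the kernels linearize a 2D grid through flat_id. The sketch below shows one way to split the count into (x, y). The helper name `dispatchShape` is hypothetical and the engine may choose the shape differently.

```javascript
const WG = 256;
const MAX_WG_PER_DIM = 65535; // WebGPU default maxComputeWorkgroupsPerDimension

// Returns [x, y] workgroup counts with x * y >= ceil(numElements / WG).
function dispatchShape(numElements) {
  const total = Math.ceil(numElements / WG);
  if (total <= MAX_WG_PER_DIM) return [total, 1];
  const y = Math.ceil(total / MAX_WG_PER_DIM);
  const x = Math.ceil(total / y);
  return [x, y];
}

const [x, y] = dispatchShape(116e6);
// The few surplus threads at the end exit through the bounds check in the kernel.
```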

adamw_update

workgroup_size: 256
grid: ceil(size / 256)

size is the element count of the mega-buffer slice. Multi-tensor pattern:

  1. Compute beta_t for the current step (host)
  2. Bind mega_w/g/m/v sub-range for the "wd block" (decay applied)
  3. Dispatch adamw_update for the wd block
  4. Bind sub-range for the "no-wd block" (norms, biases — wd skipped)
  5. Dispatch adamw_update for the no-wd block

A total of 2 dispatches instead of 74 (per-tensor).
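Steps 2 to 5 could be encoded roughly like this. `setPipeline`, `setBindGroup` and `dispatchWorkgroups` are the standard GPUComputePassEncoder methods. The function name `encodeAdamW` and the `blocks` structure (one pre-built bind group plus an [x, y] workgroup count per block) are assumptions for illustration.

```javascript
// `pass` is expected to behave like a GPUComputePassEncoder.
// `blocks` is e.g. [{ bindGroup: wdBindGroup, workgroups: [x, y] },
//                   { bindGroup: noWdBindGroup, workgroups: [x, y] }].
function encodeAdamW(pass, pipeline, blocks) {
  pass.setPipeline(pipeline);
  let dispatches = 0;
  for (const { bindGroup, workgroups } of blocks) {
    const [x, y] = workgroups;
    if (x * y === 0) continue; // skip an empty block
    pass.setBindGroup(0, bindGroup);
    pass.dispatchWorkgroups(x, y);
    dispatches++;
  }
  return dispatches;
}
```

The no-wd bind group would carry an hp uniform with wd set to 0, so the same pipeline serves both blocks.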


Line by Line — adamw_update

1) Setup

wgsl
let i = flat_id(gid, nwg);
if (i >= size) { return; }

let lr      = hp.x;
let beta1   = hp.y;
let beta2   = hp.z;
let wd      = hp.w;
let beta1_t = bias.x;
let beta2_t = bias.y;
let eps     = bias.z;
let clip    = grad_stats[1];

All hyperparameters come from uniforms. clip was computed at runtime (in the previous kernel).

2) Grad finite-check + clip

wgsl
var grad = g[i];
grad = select(0.0, grad, is_finite(grad));
grad = grad * clip;

NaN/Inf grad → pull to 0 (protection against gradient explosion). Then multiply by clip (to bring the norm below max_norm).

3) Adam moments

wgsl
let m_old = m_buf[i];
let v_old = v_buf[i];
let w_old = w[i];

let m_new = beta1 * m_old + (1.0 - beta1) * grad;
let v_new = beta2 * v_old + (1.0 - beta2) * grad * grad;

EMA updates:

  • m: the exponential moving average of the gradient (1st moment, "velocity")
  • v: the EMA of gradient² (2nd moment, "variance")

beta1 ≈ 0.9, beta2 ≈ 0.999 by default. m is fast-reacting (recent grads), v is slow (long-term grad²).
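One way to put numbers on "fast" and "slow" is the half-life of each average: the number of steps after which an old gradient's weight in the EMA has dropped to one half. The helper name `emaHalfLife` is hypothetical; the formula is just β^n = 0.5 solved for n.

```javascript
// Steps until a past sample's weight in the EMA falls to 50%.
function emaHalfLife(beta) {
  return Math.log(0.5) / Math.log(beta);
}

const mHalfLife = emaHalfLife(0.9);   // about 6.6 steps
const vHalfLife = emaHalfLife(0.999); // about 693 steps
```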

4) Bias correction

wgsl
let b1_corr = 1.0 / max(1.0 - beta1_t, 1e-12);
let b2_corr = 1.0 / max(1.0 - beta2_t, 1e-12);
let m_hat = m_new * b1_corr;
let v_hat = v_new * b2_corr;

In early steps m, v ≈ 0 (coming from the zero-init). Bias correction compensates for this. As the step advances, beta_t → 0 and correction → 1.
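A quick numeric check of that behavior, assuming the default betas and a constant gradient of 1. The helper names are hypothetical. The raw first moment after one step is only 0.1, and the correction factor of 10 brings it back to 1.

```javascript
const BETA1 = 0.9, BETA2 = 0.999;

// Same clamp as the kernel: 1 / max(1 - beta^step, 1e-12).
function correction(beta, step) {
  return 1 / Math.max(1 - Math.pow(beta, step), 1e-12);
}

// EMA of a constant gradient g = 1, starting from zero-init.
function rawMoment(beta, steps) {
  let m = 0;
  for (let t = 0; t < steps; t++) m = beta * m + (1 - beta) * 1;
  return m;
}

const c1First = correction(BETA1, 1);   // about 10
const c2First = correction(BETA2, 1);   // about 1000
const c1Late = correction(BETA1, 100);  // within 0.01% of 1
const corrected = rawMoment(BETA1, 1) * c1First; // about 1, the true gradient
```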

5) Update + decoupled WD

wgsl
let denom = sqrt(v_hat) + eps;
let step  = m_hat / denom + wd * w_old;
let w_new = w_old - lr * step;

wd * w_old — decoupled weight decay. It does not depend on the Adam moments.

6) Write back

wgsl
w[i] = w_new;
m_buf[i] = m_new;
v_buf[i] = v_new;
g[i] = 0.0;     // zero the grad for the next step

Grad zero fused. There is no separate fill_zero kernel call. Bandwidth savings.


reduce_norm_sq — Atomic Reduction

wgsl
let grid_size = nwg.x * nwg.y * WG;
var local: f32 = 0.0;
var i = flat_id(gid, nwg);
loop {
    if (i >= n) { break; }
    let v = data[i];
    if (is_finite(v)) { local = local + v * v; }
    i = i + grid_size;
}

sh_red[tid] = local;
workgroupBarrier();
// classical tree reduction
var s = WG / 2u;
loop {
    if (s == 0u) { break; }
    if (tid < s) { sh_red[tid] = sh_red[tid] + sh_red[tid + s]; }
    workgroupBarrier();
    s = s >> 1u;
}

if (tid == 0u) {
    // CAS atomic add to global accumulator
    ...
}

Why tree reduction instead of wg_reduce_sum?

wg_reduce_sum requires subgroup operations. Not all GPUs support subgroups (older browsers). Tree reduction does not depend on the subgroups extension — it is portable. The optimizer kernel was designed for mobile/older hardware.

This is not a hot path — the global norm is needed only once per step. We accept the extra barriers.
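The reduction loop above can be emulated on the CPU to see how many barrier rounds it costs. This is an illustrative sketch: the function name `treeReduce` is hypothetical, and the inner loop stands in for the threads of one workgroup running in lockstep.

```javascript
// CPU emulation of the shared-memory tree reduction for one workgroup.
// `values.length` plays the role of WG and must be a power of two.
function treeReduce(values) {
  const sh = Float32Array.from(values); // stands in for sh_red
  let rounds = 0;
  for (let s = sh.length >> 1; s > 0; s >>= 1) {
    for (let tid = 0; tid < s; tid++) {
      sh[tid] += sh[tid + s]; // thread tid folds in its partner at tid + s
    }
    rounds++; // one workgroupBarrier() per round on the GPU
  }
  return { sum: sh[0], rounds };
}

const result = treeReduce(new Array(256).fill(1));
// 256 partial sums collapse in log2(256) = 8 rounds.
```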


adamw_8bit_update

An alternative kernel not used in production. It keeps the Adam state (m, v) as int8 values packed four per u32, plus one f32 scale factor per 256-element block:

  • Memory savings: m + v shrink by ~75% with 8-bit (4 bytes → 1 byte per state)
  • Quality cost: ~1% worse final loss

8-bit dynamic quantization — block-wide max → scale factor → quantize to 8-bit signed. Details are in the ~80+ line kernel in 12_optimizer.wgsl.
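The max → scale → quantize chain can be sketched on the CPU as follows. This assumes a plain linear absmax mapping to [-127, 127], which is a simplification: the real kernel may use a different code mapping, and the packing of four int8 values into one u32 is left out. All function names are hypothetical.

```javascript
const BLOCK = 256; // elements per scale factor, as in the binding table

// Quantizes f32 values to int8 with one absmax-derived scale per block.
function quantizeBlockwise(values) {
  const q = new Int8Array(values.length);
  const scales = new Float32Array(Math.ceil(values.length / BLOCK));
  for (let b = 0; b < scales.length; b++) {
    const start = b * BLOCK;
    const end = Math.min(start + BLOCK, values.length);
    let absMax = 0;
    for (let i = start; i < end; i++) absMax = Math.max(absMax, Math.abs(values[i]));
    scales[b] = absMax / 127;
    const scale = scales[b]; // use the stored f32 scale so dequantization matches
    for (let i = start; i < end; i++) {
      const r = scale === 0 ? 0 : Math.round(values[i] / scale);
      q[i] = Math.max(-127, Math.min(127, r));
    }
  }
  return { q, scales };
}

function dequantizeBlockwise({ q, scales }) {
  const out = new Float32Array(q.length);
  for (let i = 0; i < q.length; i++) out[i] = q[i] * scales[Math.floor(i / BLOCK)];
  return out;
}

const original = [1, -0.5, 0.25, 0];
const packed = quantizeBlockwise(original);
const restored = dequantizeBlockwise(packed);
```

The round-trip error per element is bounded by about half a scale step, which is where the quality cost of the 8-bit state comes from.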

Our default is fp32 AdamW. When the 8-bit fallback is needed (low memory budget) it can be selected from the UI.


WGSL-Specific Notes

1. Sub-range binding

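This is standard WebGPU, set on the bind group rather than the bind group layout: each createBindGroup entry takes resource: { buffer, offset, size }, which binds a slice of a buffer. For storage bindings the offset must be a multiple of the device's minStorageBufferOffsetAlignment limit (256 bytes by default). Our engine uses this to form a per-param view over the mega-buffer.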

2. atomic<u32> declaration

wgsl
@group(0) @binding(1) var<storage, read_write> result: atomic<u32>;

In the reduce kernel the global accumulator is declared atomic. It is a single-element scalar. WGSL atomics exist only for u32 and i32, so the f32 sum is stored as raw bits in the u32.

3. nwg.x * nwg.y * WG grid size

wgsl
let grid_size = nwg.x * nwg.y * WG;

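Total thread count of the dispatch: workgroups in x times workgroups in y times WG threads per workgroup. It is the stride of the reduction loop and matches the 2D grid that flat_id linearizes (a 1D dispatch is simply the nwg.y = 1 case).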

4. select(0.0, grad, is_finite(grad))

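WGSL's counterpart to a ternary, with the argument order select(falseValue, trueValue, condition): if is_finite(grad) is true the result is grad, otherwise 0.0. Standard NaN/Inf guard.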


Performance — From a Profile Snapshot

adamw_update            (× 2)  35.78 ms   → 6.9% step
scale                   (× 1)   8.91 ms   → 1.7% step (gradient clip apply)
cast_f32_to_f16         (×110)  8.59 ms   → 1.7% step (mixed precision sync)
reduce_norm_sq          (× 1)   4.26 ms   → 0.8% step
finalize_grad_stats     (× 1)   0.07 ms
                                         ───────
                                Total ~57 ms = 11% step

adamw_update is 35.8 ms / 2 calls = 17.9 ms per call. High. Multi-tensor AdamW was expected to be SUB-LINEAR, but 17.9 ms per call is still large (a dispatch over mega_w with 100M+ elements).
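A back-of-the-envelope bandwidth estimate helps judge that number. The sketch assumes the 116M example parameter count quoted earlier applies to this profile, and that each element costs four f32 reads (w, g, m, v) plus four f32 writes. Both are assumptions, not measured values.

```javascript
// Rough memory traffic of the fp32 adamw_update over the whole mega-buffer.
const elements = 116e6;              // assumed: example parameter count from this chapter
const bytesPerElement = (4 + 4) * 4; // 4 reads + 4 writes, 4 bytes each
const totalBytes = elements * bytesPerElement;

const seconds = 35.78e-3;            // adamw_update total from the profile
const gbPerSecond = totalBytes / seconds / 1e9; // roughly 104 GB/s
```

Under these assumptions the kernel moves about 3.7 GB per step, so its time is set mostly by memory bandwidth rather than by dispatch count. Whether that is close to the device's peak depends on the GPU the profile was taken on, which the snapshot does not state.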

Completed Optimization (perf-doc #B): The standalone cast_f32_to_f16 sweep was eliminated entirely and embedded into the optimizer via the adamw_update_f16 kernel. This avoids the ~110 extra dispatches running over the per-layer parameters, directly yielding a ~1.7% step-time gain.


Code Review

Finding 1: Could it be 1 dispatch instead of 2?

| Risk | Description |
| --- | --- |
| 🟡 minor | Right now wd-block + no-wd-block are separate dispatches. In a single dispatch one could do select(wd, 0.0, is_norm_or_bias). It needs an extra register/uniform, marginal speedup. |

Finding 2: Bias correction precision

| Risk | Description |
| --- | --- |
| 🟢 minor | When 1 - beta_t is small, 1/(1-beta_t) becomes large. Risk of precision loss in early steps. max(..., 1e-12) clamp guard. |

Finding 3: Grad zero fused — efficient

| Risk | Description |
| --- | --- |
| 🟢 none | g[i] = 0.0 at end-of-kernel. No need for a separate fill_zero call. |

Quick Checklist

| Test Scenario | Status |
| --- | --- |
| Are Adam moments default zero-init? | ✅ fill_zero at allocation |
| At step = 1, is bias correction not infinite? | ✅ max(1e-12) clamp |
| NaN grad → 0 cleanup? | ✅ select |
| Is the mega-buffer sub-range binding correct? | ✅ runtime test passing |
| Is the wd_block vs no_wd_block split correct? | ✅ host-side param layout |
| Is the 8-bit kernel at parity with fp32? | ⚠ 1-2% worse final loss accepted |

Pipeline Complete

This is the last chapter. Every kernel of the Forward → Backward → Optimizer flow has been examined. On the same corpus, each training step runs in the following order:

01 Embed → 02 Norm → 03 Linear (Q/K/V/O) → 04 RoPE → 05 Attention
                  → 02 Norm → 03 Linear (gate/up) → 06 SwiGLU → 03 Linear (down) → residual
... 12 layer ...
                  → 02 Norm → 03 Linear (lm_head) → 07 Loss + dLogits

← 11 Linear bwd ← 09 FFN bwd ← 10 Attn bwd ← 02 Norm bwd ← 01 Embed bwd ...

→ reduce_norm_sq → finalize_grad_stats → adamw_update_f16 (fused) → next step

Each piece is in its own md document. Go back to the index: index.md.

WGSL kernel studies · an LLM from scratch on WebGPU
Built in Istanbul by Uğur Toprakdeviren.