adamw_update, reduce_norm_sq, finalize_grad_stats — AdamW Optimizer
File: 12_optimizer.wgsl
Pipeline step: the final step of the pipeline. Updates the weights after the backward pass finishes.
5 kernels:
- reduce_norm_sq: Σ g², the global gradient norm reduction
- finalize_grad_stats: gradNorm = √Σ g², clipScale = min(1, max_norm / gradNorm)
- adamw_update: fused AdamW + clip + grad zero (pure fp32)
- adamw_update_f16: fused AdamW + clip + grad zero + f16 forward mirror cast
- adamw_8bit_update: 8-bit quantized state variant (for memory savings)
Additionally, our WGSL has multi-tensor support — a single kernel call operates over the mega-buffer, addressing each parameter through sub-range bindings.
Wait, what is this?
After all the forward + backward noise, you're left with exactly one thing per weight: a gradient — an arrow saying "nudge this number in that direction and the loss drops a bit". The optimizer's only job is to take those arrows and actually push the weights. The dumbest version is w = w - lr * grad: take a tiny step in the direction the arrow points. That's SGD, and it works, but it's like walking a mountain trail blindfolded — you zigzag on every step.
AdamW gives that walk two kinds of memory. Picture rolling a ball down a slope: the ball doesn't blindly obey a single gradient arrow, it carries momentum — it looks at which way it's been heading over the last few steps (that's m). Second, it tracks how jumpy each parameter's gradient is (that's v): a weight that keeps thrashing left and right gets cautious, small steps, while a calm, steady one gets generous ones. Both are "exponential moving averages" — running averages that slowly forget old data, like a "7-day average" readout in some app.
There's one more subtlety: weight decay (gently pulling weights toward zero to rein in overfitting). The "W" in AdamW is exactly this — it doesn't blend that pull into the gradient math, it applies it as a separate, clean term right at the end ("decoupled"). That blending was the mistake old Adam made; AdamW fixes it.
Everything else is bookkeeping and speed. The model has ~74 separate weight tensors; instead of firing a kernel at each one, we lay them all side by side in one giant buffer (multi-tensor), and we refresh the fp16 mirror in the very same pass to dodge an extra full sweep (fused mirroring). The real fiddly work starts below.
Multi-tensor AdamW — The Most Important Optimization
Classic (per-tensor) AdamW:
for each parameter tensor (~74 in our model):
    adamw_update kernel call
74 dispatches / step → enormous overhead

Our multi-tensor approach:
1 kernel call → mega_w/g/m/v contiguous buffers
sub-range bindings: per-param view via offset+size

Mega-buffer:
- mega_w: array<f32>, ALL weights concatenated into a single buffer
- mega_g, mega_m, mega_v: same layout
Sub-range binding:
{ buffer: mega_w, offset: param_offset[i], size: param_size[i] }
WebGPU sees this as a single param tensor, but physically it is a slice of the mega-buffer.
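As a rough illustration of where param_offset[i] and param_size[i] could come from, here is a minimal host-side sketch in TypeScript. `planMegaBuffer`, `ParamSlice`, the example tensor names, and the 256-byte alignment (WebGPU's default `minStorageBufferOffsetAlignment`) are assumptions made for illustration; the engine's actual layout code is not shown in this document.

```typescript
// Hypothetical helper: lays parameter tensors end to end in one buffer and
// returns the byte offset and byte size that a sub-range binding would use.
interface ParamSlice {
  name: string;
  offset: number; // byte offset into the mega-buffer
  size: number;   // byte size of this parameter's view
}

function planMegaBuffer(
  params: { name: string; elements: number }[],
  alignBytes: number = 256,    // assumed: default minStorageBufferOffsetAlignment
  bytesPerElement: number = 4, // f32
): { slices: ParamSlice[]; totalBytes: number } {
  const slices: ParamSlice[] = [];
  let cursor = 0;
  for (const p of params) {
    // Round the start of every slice up to the binding alignment.
    const offset = Math.ceil(cursor / alignBytes) * alignBytes;
    const size = p.elements * bytesPerElement;
    slices.push({ name: p.name, offset, size });
    cursor = offset + size;
  }
  return { slices, totalBytes: cursor };
}

// Made-up tensor names and sizes, only to exercise the helper.
const plan = planMegaBuffer([
  { name: "embed", elements: 1000 },
  { name: "norm0", elements: 96 },
  { name: "wq0", elements: 300 },
]);
```

Each resulting slice would be passed as `{ buffer: mega_w, offset: slice.offset, size: slice.size }`. Aligning every slice leaves small padding gaps between tensors; whether the real layout pads this way is not stated here.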
Result:
- 74 → 2 dispatches (wd block + no-wd block)
- 50 ms → 13.7 ms optimizer phase
- −36 ms / step
AdamW Math (Decoupled Weight Decay)
Classic AdamW:
m ← β₁·m + (1-β₁)·g·clip
v ← β₂·v + (1-β₂)·(g·clip)²
m̂ = m / (1 - β₁ᵗ) ← bias correction
v̂ = v / (1 - β₂ᵗ)
w ← w - lr·( m̂ / (√v̂ + ε) + λ·w )
↑ the λ·w term is the decoupled weight decay
g·clip = grad × clipScale (gradient clipping fused).
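λ (weight_decay) is decoupled: it never enters m or v, it is applied to the weight as a separate term.
Why decoupled? In classic Adam with L2 regularization, λ·w is added to the gradient, so it flows through m and v and gets rescaled by the adaptive denominator like any other gradient component. The pull toward zero then depends on each parameter's gradient history instead of being a uniform decay. AdamW (Loshchilov & Hutter, 2017) fixes this: lr·λ·w is subtracted directly from the weight.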
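To make the difference concrete, the sketch below takes a single first step (t = 1, m = v = 0) for a weight whose loss gradient is zero, so weight decay is the only force acting on it. The function names and the hyperparameter values are illustrative assumptions, not taken from the engine.

```typescript
// Illustrative hyperparameters (assumed values).
const lr = 0.1, beta1 = 0.9, beta2 = 0.999, eps = 1e-8;

// Classic Adam + L2: lambda * w is folded into the gradient,
// so it passes through the adaptive ratio m_hat / (sqrt(v_hat) + eps).
function adamL2FirstStep(w: number, g: number, lambda: number): number {
  const gTotal = g + lambda * w;
  const mHat = ((1 - beta1) * gTotal) / (1 - beta1);
  const vHat = ((1 - beta2) * gTotal * gTotal) / (1 - beta2);
  return w - lr * (mHat / (Math.sqrt(vHat) + eps));
}

// AdamW: lambda * w is a separate term outside the adaptive ratio.
function adamwFirstStep(w: number, g: number, lambda: number): number {
  const mHat = ((1 - beta1) * g) / (1 - beta1);
  const vHat = ((1 - beta2) * g * g) / (1 - beta2);
  return w - lr * (mHat / (Math.sqrt(vHat) + eps) + lambda * w);
}

// Same weight, zero loss gradient, two different decay strengths.
const l2Weak = adamL2FirstStep(2, 0, 0.01);
const l2Strong = adamL2FirstStep(2, 0, 0.1);
const wWeak = adamwFirstStep(2, 0, 0.01);
const wStrong = adamwFirstStep(2, 0, 0.1);
```

With the L2 variant both calls land at roughly 1.9: the adaptive ratio normalizes λ·w away, so a 10× larger λ makes no difference on this step. With AdamW the weight moves by exactly lr·λ·w, giving about 1.998 and 1.98.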
Pipeline — 3 Steps
1. reduce_norm_sq — Global gradient norm
input: mega_g (all gradients concatenated)
output: norm_sq scalar (atomic accumulator)
for each gradient element:
if finite: accumulate g²
return Σ g²
Atomic CAS-add into a scalar slot. The target is a single scalar, so however contended it gets, the buffer itself stays tiny.
2. finalize_grad_stats
gradNorm = √Σg²
clipScale = min(1, max_grad_norm / max(gradNorm, 1e-6))
Single thread. The result is written to grad_stats[0] (gradNorm) and grad_stats[1] (clipScale).
max_grad_norm is a uniform from the host (e.g. 1.0). clipScale ≤ 1 always. If gradNorm ≤ max_grad_norm, clipScale = 1 (no clipping). Otherwise clipScale < 1 and every gradient is scaled down so the global norm equals max_grad_norm.
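The first two pipeline steps can be checked against a small CPU reference. This is a sketch under the formulas given above, not the kernel itself; `gradStats` is a hypothetical name.

```typescript
// CPU reference for reduce_norm_sq followed by finalize_grad_stats.
function gradStats(grads: Float32Array, maxGradNorm: number): [number, number] {
  let normSq = 0;
  for (let i = 0; i < grads.length; i++) {
    const g = grads[i];
    // Non-finite elements are skipped, mirroring the kernel's finite check.
    if (isFinite(g)) {
      normSq += g * g;
    }
  }
  const gradNorm = Math.sqrt(normSq);
  const clipScale = Math.min(1, maxGradNorm / Math.max(gradNorm, 1e-6));
  return [gradNorm, clipScale]; // same order as grad_stats[0], grad_stats[1]
}

// A 3-4-5 triangle with a NaN thrown in: the NaN is ignored.
const [norm, clip] = gradStats(new Float32Array([3, NaN, 4]), 1.0);
```

Here the norm is 5 and the clip scale is 0.2, so after scaling the gradient vector has norm 1.0, exactly max_grad_norm.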
3. adamw_update
Each thread, for one element:
1. read grad, finite-check, multiply by clip_scale
2. update m, v
3. apply bias correction
4. compute weight update
5. write w_new, m_new, v_new
6. zero grad (for next iter)
adamw_update operates over mega_w/g/m/v. With sub-range binding, a single dispatch updates all tensors.
Bind Group ABI
adamw_update (8 bindings — Pure fp32)
| Binding | Type | Detail |
|---|---|---|
| 0 | storage, read_write | w: array<f32> — master weights |
| 1 | storage, read_write | g: array<f32> — gradients (zero'd after) |
| 2 | storage, read_write | m_buf: array<f32> — 1st moment |
| 3 | storage, read_write | v_buf: array<f32> — 2nd moment |
| 4 | storage, read | grad_stats: array<f32> — [gradNorm, clipScale] |
| 5 | uniform | size: u32 |
| 6 | uniform | hp: vec4<f32> — (lr, beta1, beta2, wd) |
| 7 | uniform | bias: vec4<f32> — (beta1_t, beta2_t, eps, _) |
adamw_update_f16 (9 bindings — Mixed Precision / Fused Mirroring)
| Binding | Type | Detail |
|---|---|---|
| 0 | storage, read_write | w_mp: array<f32> — master weights |
| 1 | storage, read_write | g_mp: array<f32> — gradients (zero'd after) |
| 2 | storage, read_write | m_buf_mp: array<f32> — 1st moment |
| 3 | storage, read_write | v_buf_mp: array<f32> — 2nd moment |
| 4 | storage, read | grad_stats_mp: array<f32> — [gradNorm, clipScale] |
| 5 | uniform | size_mp: u32 |
| 6 | uniform | hp_mp: vec4<f32> |
| 7 | uniform | bias_mp: vec4<f32> |
| 8 | storage, read_write | dst_w16: array<f16> — forward weight mirror in fp16 format (fused mirror) |
adamw_8bit_update (10 bindings — 8-bit Quantized State)
| Binding | Type | Detail |
|---|---|---|
| 0 | storage, read_write | w: array<f32> — master weights |
| 1 | storage, read_write | g: array<f32> — gradients (zero'd after) |
| 2 | storage, read_write | m_packed: array<u32> — packed int8 1st moment (4 values per u32) |
| 3 | storage, read_write | v_packed: array<u32> — packed int8 2nd moment (4 values per u32) |
| 4 | storage, read_write | m_scale: array<f32> — per-block (256 elements) scale factor |
| 5 | storage, read_write | v_scale: array<f32> — per-block (256 elements) scale factor |
| 6 | storage, read | grad_stats: array<f32> — [gradNorm, clipScale] |
| 7 | uniform | size: u32 |
| 8 | uniform | hp: vec4<f32> — (lr, b1, b2, wd) |
| 9 | uniform | bias: vec4<f32> — (b1_t, b2_t, eps, _) |
beta1_t = beta1^step, beta2_t = beta2^step are computed on the host.
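A minimal sketch of how the host might fill the bias uniform each step. `makeBiasUniform` is a hypothetical helper, and the assumption that step counts from 1 follows from the bias-correction formula rather than from anything stated about the host code.

```typescript
// Hypothetical host-side helper: packs the bias uniform (beta1_t, beta2_t, eps, _).
function makeBiasUniform(
  step: number,
  beta1: number,
  beta2: number,
  eps: number,
): Float32Array {
  // step is assumed to count from 1. At step 0 both powers would be 1,
  // making 1 - beta_t equal to 0 in the kernel's bias correction.
  return new Float32Array([Math.pow(beta1, step), Math.pow(beta2, step), eps, 0]);
}

const biasStep1 = makeBiasUniform(1, 0.9, 0.999, 1e-8);
const biasStep1000 = makeBiasUniform(1000, 0.9, 0.999, 1e-8);
```

By step 1000, beta1^step has underflowed to 0 in f32 while beta2^step is still about 0.37, which shows how much longer the second-moment correction stays active. The array would typically be uploaded with `queue.writeBuffer` before the dispatches.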
Dispatch Shape
reduce_norm_sq
workgroup_size: 256
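grid: ceil(total_params / 256)
total_params is the total parameter count, e.g. 116M. Each thread accumulates its elements in a strided loop; the workgroup then combines the partial sums with a tree reduction, and thread 0 atomic-adds the result to the global accumulator.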
adamw_update
workgroup_size: 256
grid: ceil(size / 256)
size is the element count of the mega-buffer slice. Multi-tensor pattern:
- Compute beta_t for the current step (host)
- Bind mega_w/g/m/v sub-range for the "wd block" (decay applied)
- Dispatch adamw_update for the wd block
- Bind sub-range for the "no-wd block" (norms, biases; wd skipped)
- Dispatch adamw_update for the no-wd block
A total of 2 dispatches instead of 74 (per-tensor).
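One detail the grid formula hides: 116M elements at 256 threads per workgroup is 453,125 workgroups, more than WebGPU's default limit of 65,535 workgroups per dimension, which is presumably why the kernels linearize a 2D grid. The sketch below shows one way a host could pick the grid shape. `dispatchShape` and `flatIndex` are hypothetical, and the row-major formula is only an assumption about what flat_id(gid, nwg) computes.

```typescript
// Hypothetical host-side helper: picks an (x, y) workgroup grid for n elements.
// Assumes the default WebGPU limit maxComputeWorkgroupsPerDimension = 65535.
function dispatchShape(
  n: number,
  wgSize: number = 256,
  maxPerDim: number = 65535,
): [number, number] {
  const total = Math.ceil(n / wgSize); // workgroups needed
  if (total <= maxPerDim) {
    return [Math.max(total, 1), 1]; // plain 1D dispatch
  }
  const y = Math.ceil(total / maxPerDim);
  const x = Math.ceil(total / y); // spread evenly so few workgroups sit idle
  return [x, y];
}

// Assumed flat index: row-major over the 2D grid of threads.
// gidX and gidY are global invocation ids, nwgX is the workgroup count in x.
function flatIndex(gidX: number, gidY: number, nwgX: number, wgSize: number = 256): number {
  return gidY * nwgX * wgSize + gidX;
}

const shape = dispatchShape(116_000_000);
```

The kernel's `if (i >= size) { return; }` guard covers the few surplus threads that a rounded-up 2D grid creates.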
Line by Line — adamw_update
1) Setup
let i = flat_id(gid, nwg);
if (i >= size) { return; }
let lr = hp.x;
let beta1 = hp.y;
let beta2 = hp.z;
let wd = hp.w;
let beta1_t = bias.x;
let beta2_t = bias.y;
let eps = bias.z;
let clip = grad_stats[1];
All hyperparameters come from uniforms. clip is the exception: it is read from grad_stats[1], which finalize_grad_stats computed at runtime in the previous kernel.
2) Grad finite-check + clip
var grad = g[i];
grad = select(0.0, grad, is_finite(grad));
grad = grad * clip;
A NaN/Inf grad is replaced with 0, so one bad element cannot poison m, v, and the weight. The result is then multiplied by clip (to bring the global norm down to max_norm).
3) Adam moments
let m_old = m_buf[i];
let v_old = v_buf[i];
let w_old = w[i];
let m_new = beta1 * m_old + (1.0 - beta1) * grad;
let v_new = beta2 * v_old + (1.0 - beta2) * grad * grad;
EMA updates:
- m: the exponential moving average of the gradient (1st moment, "velocity")
- v: the EMA of gradient² (2nd moment, "variance")
beta1 ≈ 0.9, beta2 ≈ 0.999 by default. m is fast-reacting: it averages over roughly the last 1/(1-beta1) = 10 steps. v is slow: it averages grad² over roughly 1/(1-beta2) = 1000 steps.
4) Bias correction
let b1_corr = 1.0 / max(1.0 - beta1_t, 1e-12);
let b2_corr = 1.0 / max(1.0 - beta2_t, 1e-12);
let m_hat = m_new * b1_corr;
let v_hat = v_new * b2_corr;
m and v start at zero, so in early steps both EMAs are biased toward 0. Dividing by (1 - beta_t) compensates for this. As the step count grows, beta_t → 0 and the correction factor → 1.
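A short worked case shows what the correction buys at the very first step. Assuming zero-initialized moments and writing g_1 for the clipped gradient at step 1:

```latex
\begin{aligned}
m_1 &= \beta_1 \cdot 0 + (1-\beta_1)\, g_1 = (1-\beta_1)\, g_1,
  &\quad \hat{m}_1 &= \frac{m_1}{1-\beta_1^{1}} = g_1 \\
v_1 &= \beta_2 \cdot 0 + (1-\beta_2)\, g_1^{2} = (1-\beta_2)\, g_1^{2},
  &\quad \hat{v}_1 &= \frac{v_1}{1-\beta_2^{1}} = g_1^{2} \\
\frac{\hat{m}_1}{\sqrt{\hat{v}_1}+\varepsilon} &= \frac{g_1}{|g_1|+\varepsilon} \approx \operatorname{sign}(g_1)
\end{aligned}
```

Without the correction, the same ratio would be (1-β₁)/√(1-β₂) times that value, about 3.2 for the defaults, so the first steps would be more than three times too large.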
5) Update + decoupled WD
let denom = sqrt(v_hat) + eps;
let step = m_hat / denom + wd * w_old;
let w_new = w_old - lr * step;
wd * w_old is the decoupled weight decay. It does not depend on the Adam moments.
6) Write back
w[i] = w_new;
m_buf[i] = m_new;
v_buf[i] = v_new;
g[i] = 0.0; // grad zero for next step
The grad zeroing is fused: there is no separate fill_zero kernel call, which saves a full pass over the gradient buffer.
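Putting the six steps together, a CPU reference like the following can be handy for parity tests against the GPU output. It is a sketch that mirrors the walkthrough above; `adamwUpdateCpu` and `AdamwParams` are hypothetical names, and the intermediate math runs in f64 rather than the kernel's f32, so results should be compared with a tolerance.

```typescript
// Mirrors the hp and bias uniforms plus grad_stats[1].
interface AdamwParams {
  lr: number; beta1: number; beta2: number; wd: number;
  beta1T: number; beta2T: number; eps: number; clip: number;
}

// CPU reference for one adamw_update dispatch over a slice.
function adamwUpdateCpu(
  w: Float32Array, g: Float32Array, m: Float32Array, v: Float32Array,
  p: AdamwParams,
): void {
  const b1Corr = 1 / Math.max(1 - p.beta1T, 1e-12);
  const b2Corr = 1 / Math.max(1 - p.beta2T, 1e-12);
  for (let i = 0; i < w.length; i++) {
    // 2) finite check, then clip
    const grad = (isFinite(g[i]) ? g[i] : 0) * p.clip;
    // 3) moments
    const mNew = p.beta1 * m[i] + (1 - p.beta1) * grad;
    const vNew = p.beta2 * v[i] + (1 - p.beta2) * grad * grad;
    // 4) + 5) bias correction, adaptive step, decoupled weight decay
    const step = (mNew * b1Corr) / (Math.sqrt(vNew * b2Corr) + p.eps) + p.wd * w[i];
    // 6) write back and zero the gradient
    w[i] = w[i] - p.lr * step;
    m[i] = mNew;
    v[i] = vNew;
    g[i] = 0;
  }
}

const wRef = new Float32Array([1, 1]);
const gRef = new Float32Array([0.5, NaN]);
const mRef = new Float32Array(2);
const vRef = new Float32Array(2);
adamwUpdateCpu(wRef, gRef, mRef, vRef, {
  lr: 0.1, beta1: 0.9, beta2: 0.999, wd: 0,
  beta1T: 0.9, beta2T: 0.999, eps: 1e-8, clip: 1,
});
```

In this first step the finite gradient moves its weight by about lr (1 → 0.9), while the NaN gradient leaves its weight untouched and both gradients end up zeroed.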
reduce_norm_sq — Atomic Reduction
let grid_size = nwg.x * nwg.y * WG;
var local: f32 = 0.0;
var i = flat_id(gid, nwg);
loop {
if (i >= n) { break; }
let v = data[i];
if (is_finite(v)) { local = local + v * v; }
i = i + grid_size;
}
sh_red[tid] = local;
workgroupBarrier();
// classical tree reduction
var s = WG / 2u;
loop {
if (s == 0u) { break; }
if (tid < s) { sh_red[tid] = sh_red[tid] + sh_red[tid + s]; }
workgroupBarrier();
s = s >> 1u;
}
if (tid == 0u) {
// CAS atomic add to global accumulator
...
}
Why tree reduction instead of wg_reduce_sum?
wg_reduce_sum requires subgroup operations, and not every GPU or browser (older ones in particular) exposes them. Tree reduction needs only workgroup shared memory and barriers, so it is portable. The optimizer kernel was designed with mobile/older hardware in mind.
This is not a hot path — the global norm is needed only once per step. We accept the extra barriers.
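The barrier-separated rounds of the tree reduction can be emulated on the CPU to see why a 256-wide workgroup needs only 8 of them. This single-threaded sketch is an illustration, not the kernel; `treeReduce` is a hypothetical name and `sh` stands in for sh_red.

```typescript
// Single-threaded emulation of the shared-memory tree reduction. Each pass of
// the outer loop corresponds to one barrier-separated round in the kernel.
function treeReduce(partials: Float32Array): { sum: number; rounds: number } {
  const sh = new Float32Array(partials); // copy, stands in for sh_red
  let rounds = 0;
  // Assumes the length (WG) is a power of two, as with workgroup_size 256.
  for (let s = sh.length >> 1; s > 0; s >>= 1) {
    // On the GPU, all threads with tid < s run this body in parallel.
    for (let tid = 0; tid < s; tid++) {
      sh[tid] = sh[tid] + sh[tid + s];
    }
    rounds++;
  }
  return { sum: sh[0], rounds };
}

// 256 per-thread partial sums: 0, 1, ..., 255.
const partials = new Float32Array(256);
for (let i = 0; i < partials.length; i++) {
  partials[i] = i;
}
const reduced = treeReduce(partials);
```

Within a round, the threads write only indices below s and read only indices at or above s from other threads, which is why a single barrier per round is enough.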
adamw_8bit_update
An alternative kernel not used in production. It stores the Adam state (m, v) as int8 values packed four per u32, plus per-block scale factors:
- Memory savings: m + v shrink by ~75% with 8-bit (4 bytes → 1 byte per state)
- Quality cost: ~1% worse final loss
8-bit dynamic quantization: block-wide max → scale factor → quantize to 8-bit signed. Details are in the kernel itself (80+ lines) in 12_optimizer.wgsl.
Our default is fp32 AdamW. When the 8-bit fallback is needed (low memory budget) it can be selected from the UI.
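The "block-wide max → scale factor → quantize" idea can be sketched in a few lines. Note the assumptions: the linear mapping onto [-127, 127], the byte-lane order inside each u32, and the helper names are all illustrative, and the real kernel's quantization map may differ.

```typescript
// Sketch of block-wise absmax int8 quantization, 4 values packed per u32.
function quantizeBlock(x: Float32Array): { packed: Uint32Array; scale: number } {
  let absMax = 0;
  for (let i = 0; i < x.length; i++) {
    absMax = Math.max(absMax, Math.abs(x[i]));
  }
  const scale = absMax > 0 ? absMax / 127 : 1; // one scale factor per block
  const packed = new Uint32Array(Math.ceil(x.length / 4));
  for (let i = 0; i < x.length; i++) {
    const q = Math.round(x[i] / scale); // integer in [-127, 127]
    packed[i >> 2] |= (q & 0xff) << (8 * (i & 3)); // byte lane i % 4 (assumed order)
  }
  return { packed, scale };
}

function dequantizeBlock(packed: Uint32Array, scale: number, n: number): Float32Array {
  const out = new Float32Array(n);
  for (let i = 0; i < n; i++) {
    const byte = (packed[i >> 2] >>> (8 * (i & 3))) & 0xff;
    out[i] = ((byte << 24) >> 24) * scale; // sign-extend the int8, then rescale
  }
  return out;
}

const original = new Float32Array([1.0, -0.5, 0.25, -1.0, 0.0]);
const q8 = quantizeBlock(original);
const restored = dequantizeBlock(q8.packed, q8.scale, original.length);
```

For a 256-element block this stores 256 bytes of state plus one 4-byte scale instead of 1024 bytes, about 25.4% of the fp32 footprint, which matches the ~75% saving quoted above. The round-trip error per element is at most half a quantization step (scale / 2).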
WGSL-Specific Notes
1. Sub-range binding
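This is standard WebGPU. In a bind group, entries[].resource can be a { buffer, offset, size } buffer binding, which binds a slice of a buffer. The offset must be a multiple of the device's minStorageBufferOffsetAlignment limit (256 bytes by default). Our engine uses this to form a per-param view over the mega-buffer.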
2. atomic<u32> declaration
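@group(0) @binding(1) var<storage, read_write> result: atomic<u32>;
In the reduce kernel the global accumulator is declared atomic. It is a single scalar (1 element) shared by all workgroups. WGSL atomics exist only for i32 and u32, which is why the f32 sum needs a CAS loop rather than a native float atomic add.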
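The CAS-add pattern behind that atomic<u32> slot can be illustrated on the CPU. This is a single-threaded emulation under the assumption that the kernel stores the f32 bit pattern in the u32 and retries a compare-exchange until it succeeds; the elided kernel body is not reproduced here, and all function names are hypothetical.

```typescript
// Emulates WGSL's atomicCompareExchangeWeak on a plain Uint32Array.
// Single-threaded stand-in: on the GPU this is one indivisible operation.
function compareExchange(slot: Uint32Array, expected: number, replacement: number): boolean {
  if (slot[0] !== expected) {
    return false; // another workgroup got there first
  }
  slot[0] = replacement;
  return true;
}

// Reinterpret bits between f32 and u32, like WGSL bitcast<f32>() and bitcast<u32>().
const scratch = new DataView(new ArrayBuffer(4));
function f32FromBits(bits: number): number {
  scratch.setUint32(0, bits);
  return scratch.getFloat32(0);
}
function bitsFromF32(x: number): number {
  scratch.setFloat32(0, x);
  return scratch.getUint32(0);
}

// Float add through an integer atomic: read, add, try to swap, retry on failure.
function casAddF32(slot: Uint32Array, value: number): void {
  for (;;) {
    const oldBits = slot[0];
    const newBits = bitsFromF32(f32FromBits(oldBits) + value);
    if (compareExchange(slot, oldBits, newBits)) {
      return;
    }
  }
}

const accumulator = new Uint32Array(1); // bit pattern of 0.0
casAddF32(accumulator, 1.5);
casAddF32(accumulator, 2.25);
const total = f32FromBits(accumulator[0]);
```

Because only one thread per workgroup performs this add, the number of contending writers equals the number of workgroups, not the number of gradient elements.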
3. nwg.x * nwg.y * WG grid size
let grid_size = nwg.x * nwg.y * WG;
Linearizes the 2D grid into a total thread count, which is the stride of the reduction loop (flat_id gives the matching linear thread index, including the 1D fallback).
4. select(0.0, grad, is_finite(grad))
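WGSL has no ternary operator; select(f, t, cond) is the built-in equivalent and returns t when cond is true, f otherwise. Here that means grad if is_finite is true, otherwise 0. Standard NaN/Inf guard.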
Performance — From a Profile Snapshot
adamw_update (× 2) 35.78 ms → 6.9% step
scale (× 1) 8.91 ms → 1.7% step (gradient clip apply)
cast_f32_to_f16 (×110) 8.59 ms → 1.7% step (mixed precision sync)
reduce_norm_sq (× 1) 4.26 ms → 0.8% step
finalize_grad_stats (× 1) 0.07 ms
───────
Total ~57 ms = 11% step
adamw_update is 35.8 ms / 2 calls = 17.9 ms per call on average. That is high: multi-tensor AdamW was expected to be SUB-LINEAR, but 17.9 ms per call is still large (a dispatch over mega_w with 100M+ elements).
Completed Optimization (perf-doc #B): The standalone cast_f32_to_f16 sweep was eliminated entirely and embedded into the optimizer via the adamw_update_f16 kernel. This avoids the ~110 extra dispatches running over the per-layer parameters, directly yielding a ~1.7% step-time gain.
Code Review
Finding 1: Could it be 1 dispatch instead of 2?
| Risk | Description |
|---|---|
| 🟡 minor | Right now wd-block + no-wd-block are separate dispatches. In a single dispatch one could do select(wd, 0.0, is_norm_or_bias). It needs an extra register/uniform, marginal speedup. |
Finding 2: Bias correction precision
| Risk | Description |
|---|---|
| 🟢 minor | When 1 - beta_t is small, 1/(1-beta_t) becomes large, so there is a risk of precision loss in early steps. The max(..., 1e-12) clamp guards against division by zero. |
Finding 3: Grad zero fused — efficient
| Risk | Description |
|---|---|
| 🟢 none | g[i] = 0.0 at end-of-kernel. No need for a separate fill_zero call. |
Quick Checklist
| Test Scenario | Status |
|---|---|
| Are Adam moments default zero-init? | ✅ fill_zero at allocation |
| At step = 1, is bias correction not infinite? | ✅ max(1e-12) clamp |
| NaN grad → 0 cleanup? | ✅ select |
| Is the mega-buffer sub-range binding correct? | ✅ runtime test passing |
| Is the wd_block vs no_wd_block split correct? | ✅ host-side param layout |
| Is the 8-bit kernel at parity with fp32? | ⚠ 1-2% worse final loss accepted |
Pipeline Complete
This is the last chapter. Every kernel of the Forward → Backward → Optimizer flow has been examined. On the same corpus, each training step runs in the following order:
01 Embed → 02 Norm → 03 Linear (Q/K/V/O) → 04 RoPE → 05 Attention
→ 02 Norm → 03 Linear (gate/up) → 06 SwiGLU → 03 Linear (down) → residual
... 12 layer ...
→ 02 Norm → 03 Linear (lm_head) → 07 Loss + dLogits
← 11 Linear bwd ← 09 FFN bwd ← 10 Attn bwd ← 02 Norm bwd ← 01 Embed bwd ...
→ reduce_norm_sq → finalize_grad_stats → adamw_update_f16 (fused) → next step
Each piece is in its own md document. Go back to the index: index.md.