rmsnorm — Root-Mean-Square Normalization
File: 02_norm.wgsl
Pipeline step: twice per layer in the forward pass (on the layer input before attention, and on the post-attention residual before the FFN), and twice in the backward pass.
3 kernels: rmsnorm_forward, rmsnorm_backward, rmsnorm_forward_w16 (mixed-precision variant).
Wait, what is this?
Picture a music playlist: one track blasts your ears off, the next is barely audible. Annoying, right? That's why players run "loudness normalization" — they pull every track's overall volume into line. But they don't touch the melody, the rhythm, or how the instruments sit relative to each other; they only scale the total loudness to a standard level.
RMSNorm does exactly that to the vectors flowing through the model. The "loudness" (magnitude) of the vector coming out of a layer sometimes spikes and sometimes fades, and that imbalance makes training harder. RMSNorm divides each vector by its own RMS (root-mean-square: the square root of the mean of the squared values, a cousin of the standard deviation from statistics). The result: the vector's direction/shape stays the same, but its magnitude settles onto a consistent scale.
There's one extra touch: after the division, we multiply by a w vector. That w is a learned gain knob — the model gets to say "keep this dimension a bit louder, turn that one down." So it's not a raw normalize; it's "normalize, then dial it back to whatever proportions you've learned."
You may have heard of LayerNorm — it does a similar job, but it first subtracts the vector's mean (centers the signal around zero). RMSNorm skips that step: no mean subtraction, just division by magnitude. Fewer ops, fewer parameters, and in practice (LLaMA, Mistral, Falcon) it does the same job. The genuinely interesting part is how you make this fast on a GPU without summing every element of a vector serially (the reduction) — we get to that below.
What Does It Do?
It divides each row vector by its own RMS and multiplies by the learnable per-dimension scale vector w. It is a simplified LayerNorm without mean subtraction, and it is the standard normalization in LLaMA.
$$\mathrm{rms}_i = \sqrt{\frac{1}{D}\sum_{k} x_i[k]^2 + \varepsilon}$$
$$y_i[d] = \frac{x_i[d]}{\mathrm{rms}_i}\, w[d]$$
where D = d_model. x_i is a row vector (one token), and each row is processed independently. One workgroup = one row: a WG of 256 threads shares the row and performs the reduction.
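A minimal CPU reference for the forward formula can serve as a baseline when checking kernel output. It is a plain-Python sketch, and several details are assumptions. The name `rmsnorm_forward_ref` is hypothetical. The `eps = 1e-5` default is illustrative, since the kernel reads ε from `params.x`. The sketch also computes in float64 rather than the kernel's f32.

```python
import math


def rmsnorm_forward_ref(x, w, eps=1e-5):
    """Hypothetical CPU reference: y_i[d] = x_i[d] / rms_i * w[d].

    x: list of rows (seq_len x d_model); w: list of length d_model.
    Returns (y, rms_inv), where rms_inv[i] = 1 / rms_i is the per-row
    value the kernel stores for the backward pass.
    """
    d_model = len(w)
    y, rms_inv = [], []
    for row in x:
        if len(row) != d_model:
            raise ValueError("row length must equal d_model")
        mean_sq = sum(v * v for v in row) / d_model  # (1/D) * sum_k x_i[k]^2
        r = 1.0 / math.sqrt(mean_sq + eps)           # rms_inv = 1 / rms_i
        rms_inv.append(r)
        y.append([v * r * wd for v, wd in zip(row, w)])  # rows are independent
    return y, rms_inv


x = [[3.0, 4.0], [1.0, -1.0]]
w = [1.0, 0.5]
y, r = rmsnorm_forward_ref(x, w, eps=0.0)
print(y[1])  # [1.0, -0.5]: this row's RMS is exactly 1
```

With w = 1 and ε = 0, every output row has RMS 1. The gain w then rescales each dimension.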
The difference from LayerNorm:
- No mean subtraction (faster, fewer ALU ops)
- No bias (only scale γ, no shift β)
LayerNorm:
$$y = \frac{x - \mu}{\sigma}\,\gamma + \beta$$
RMSNorm:
$$y = \frac{x}{\mathrm{RMS}(x)}\,w$$
LLaMA, Mistral, and Falcon all use RMSNorm. It behaves consistently during pretraining and needs fewer parameters.
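A small Python comparison shows that RMSNorm is LayerNorm minus the centering step. The sketch assumes population variance for σ, ε = 0, γ = w = 1, and β = 0. The function names are illustrative. On a zero-mean vector the two formulas agree exactly; once the mean is non-zero they diverge.

```python
import math


def layernorm(x, gamma, beta, eps=0.0):
    mu = sum(x) / len(x)
    var = sum((v - mu) ** 2 for v in x) / len(x)  # population variance
    sigma = math.sqrt(var + eps)
    return [(v - mu) / sigma * g + b for v, g, b in zip(x, gamma, beta)]


def rmsnorm(x, w, eps=0.0):
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)  # no mean subtraction
    return [v / rms * wd for v, wd in zip(x, w)]


ones, zeros = [1.0] * 4, [0.0] * 4

x_centered = [2.0, -1.0, -1.0, 0.0]  # mean is exactly 0
print(layernorm(x_centered, ones, zeros) == rmsnorm(x_centered, ones))  # True

x_shifted = [3.0, 1.0, 2.0, 2.0]     # mean is 2
# 0.0 True: the centered value vanishes, the uncentered one does not
print(layernorm(x_shifted, ones, zeros)[2], rmsnorm(x_shifted, ones)[2] > 0)
```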
Mathematical Definition — Backward
L = loss. Given dy = ∂L/∂y, we compute:
- dx = ∂L/∂x, passed back to the previous layer
- dw += ∂L/∂w, for the weight update (accumulated across rows)
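Derivative:
$$y[d] = x[d]\,\mathrm{rms\_inv}\,w[d], \qquad \mathrm{rms\_inv} = \left(\operatorname{mean}(x^2) + \varepsilon\right)^{-1/2}$$
Chain rule:
$$\frac{\partial y[d]}{\partial x[d']} = \left(\delta_{dd'}\,\mathrm{rms\_inv} + x[d]\,\frac{\partial\,\mathrm{rms\_inv}}{\partial x[d']}\right) w[d]$$
and
$$\frac{\partial\,\mathrm{rms\_inv}}{\partial x[d']} = -\frac{\mathrm{rms\_inv}^3\,x[d']}{D} \qquad (D = \text{d\_model})$$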
Result (for a single row):
$$A = \sum_k w[k]\,x[k]\,dy[k] \qquad \text{(scalar, workgroup reduction)}$$
$$dx[d] = w[d]\,dy[d]\,\mathrm{rms\_inv} - \underbrace{x[d]\,\mathrm{rms\_inv}^3\,\frac{A}{D}}_{\text{coupling term}}$$
$$dw[d] \mathrel{+}= x[d]\,\mathrm{rms\_inv}\,dy[d] \qquad \text{(cross-row, atomic add)}$$
A is computed once per row. dw is the sum over all rows (atomic accumulation).
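The per-row backward formulas can be checked numerically. This Python sketch computes A, coeff, dx, and the per-row dw contribution exactly as written above. It then compares them with central finite differences of a surrogate loss, L = Σ_d dy[d]·y[d], whose gradient with respect to y is dy by construction. The function names and test values are hypothetical, and the sketch uses float64 instead of f32.

```python
import math


def rms_inv_of(row, eps):
    return 1.0 / math.sqrt(sum(v * v for v in row) / len(row) + eps)


def rmsnorm_backward_row(x, w, dy, rms_inv):
    """Hypothetical per-row reference; returns (dx, dw_row)."""
    D = len(x)
    A = sum(wk * xk * dyk for wk, xk, dyk in zip(w, x, dy))  # workgroup reduction in the kernel
    coeff = rms_inv ** 3 * A / D                              # coupling coefficient
    dx = [wd * dyd * rms_inv - xd * coeff for wd, dyd, xd in zip(w, dy, x)]
    dw_row = [xd * rms_inv * dyd for xd, dyd in zip(x, dy)]  # added into dw across rows
    return dx, dw_row


def surrogate_loss(x, w, dy, eps):
    # L = sum_d dy[d] * y[d], so dL/dy = dy exactly.
    r = rms_inv_of(x, eps)
    return sum(dyd * xd * r * wd for dyd, xd, wd in zip(dy, x, w))


x = [0.5, -1.2, 2.0, 0.3]
w = [1.1, 0.9, -0.4, 2.0]
dy = [0.2, -0.7, 0.1, 1.5]
eps = 1e-5
dx, dw_row = rmsnorm_backward_row(x, w, dy, rms_inv_of(x, eps))

h = 1e-6
for d in range(len(x)):
    xp, xm = x[:], x[:]
    xp[d] += h
    xm[d] -= h
    num_dx = (surrogate_loss(xp, w, dy, eps) - surrogate_loss(xm, w, dy, eps)) / (2 * h)
    wp, wm = w[:], w[:]
    wp[d] += h
    wm[d] -= h
    num_dw = (surrogate_loss(x, wp, dy, eps) - surrogate_loss(x, wm, dy, eps)) / (2 * h)
    assert abs(num_dx - dx[d]) < 1e-6 and abs(num_dw - dw_row[d]) < 1e-6
print("analytic dx and dw match finite differences")
```

With ε = 0, Σ_d x[d]·dx[d] = 0. The output is invariant to rescaling x, so the gradient has no component along x, and the coupling term is what removes that component.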
Bind Group ABI
rmsnorm_forward (6 bindings)
| Binding | Type | Detail |
|---|---|---|
| 0 | storage, read | x: array<f32> — [seq × d_model] input |
| 1 | storage, read | w: array<f32> — [d_model] learnable scale |
| 2 | storage, read_write | y: array<f32> — [seq × d_model] output |
| 3 | storage, read_write | rms_out: array<f32>, [seq_len] per-row rms_inv, saved for backward |
| 4 | uniform | dims: vec4<u32> — (seq_len, d_model, _, _) |
| 5 | uniform | params: vec4<f32> — (eps, _, _, _) |
Why is rms_out saved? To avoid recomputing it in the backward pass. The sum of squares is already computed in the forward pass, so the kernel stores the resulting rms_inv and the backward pass reuses it. ~30% backward speedup.
rmsnorm_backward (7 bindings)
| Binding | Type | Detail |
|---|---|---|
| 0 | storage, read | x |
| 1 | storage, read | w |
| 2 | storage, read | dy — upstream gradient |
| 3 | storage, read | rms_in: the saved rms_inv (rms_out from the forward pass) |
| 4 | storage, read_write | dx: gradient passed back to the previous layer |
| 5 | storage, read_write | dw: array<atomic<u32>> — bit-cast f32, scatter-add |
| 6 | uniform | dims |
rmsnorm_forward_w16 (6 bindings)
Same layout as rmsnorm_forward. Only binding 1 changes: it becomes w_w16: array<f16>, read with an f32(w_w16[i]) cast in the computation.
Dispatch Shape
workgroup_size: 256
total threads: seq_len workgroups × 256 threads
One workgroup = one row. Each WG performs the reduction for its own row; rows are independent and run in parallel.
Example (seq=512, d=768): 512 WG × 256 threads.
Why one WG per row? The reduction (sum of squares) happens within the workgroup. Since no cross-row reduction is needed, rows parallelize easily. If d_model > WG (e.g. d=2048, WG=256), each thread processes d_model/WG = 8 elements (⌈d_model/WG⌉ in general), a distributed reduction.
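The dispatch arithmetic can be sketched in a few lines of Python. `dispatch_shape` is a hypothetical helper that mirrors the strided loop `for (var i = tid; i < d_model; i = i + WG)`. It returns three numbers: the workgroup count, the most elements any thread handles (⌈d_model / WG⌉), and how many threads stay idle.

```python
WG = 256  # workgroup_size used by all three kernels


def dispatch_shape(seq_len, d_model, wg=WG):
    """Hypothetical helper: (workgroups, max elements per thread, idle threads)."""
    per_thread = [len(range(tid, d_model, wg)) for tid in range(wg)]  # strided-loop trip counts
    return seq_len, max(per_thread), per_thread.count(0)


print(dispatch_shape(512, 768))   # (512, 3, 0)
print(dispatch_shape(1, 2048))    # (1, 8, 0)
print(dispatch_shape(1, 64))      # (1, 1, 192)
```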
Line by Line — rmsnorm_forward
1) Setup
@compute @workgroup_size(WG)  // WG = 256u, assumed to be a module-scope const
fn rmsnorm_forward(@builtin(workgroup_id) wgid: vec3<u32>,
@builtin(local_invocation_index) tid: u32) {
let row = wgid.x;               // one workgroup per row
let seq_len = dims.x;
let d_model = dims.y;
let eps = params.x;
if (row >= seq_len) { return; } // row is uniform across the WG, so the whole WG exits together
let base = row * d_model;       // start of this row in x and y
- workgroup_id (the analog of Metal's threadgroup_position_in_grid): each WG's ID within the grid.
- local_invocation_index (the analog of Metal's thread_index_in_threadgroup): the thread's ID within the WG, 0..255.
- base: this row's starting offset in the x/y arrays.
2) Phase 1 — Sum of squares
var local_sum = 0.0;
for (var i = tid; i < d_model; i = i + WG) {
let v = x[base + i];
local_sum = local_sum + v * v;
}
let total = wg_reduce_sum(tid, local_sum);
Strided loop:
- Thread tid sums the elements that fall to its share: x[base + tid], x[base + tid + 256], x[base + tid + 512], ...
- If d_model = 768, each thread processes 3 elements (0/256/512, 1/257/513, ..., 255/511/767).
- If d_model = 256, each thread processes 1 element.
- If d_model < 256, threads with tid >= d_model never enter the loop (the condition i < d_model is already false at i = tid) and contribute 0 to the reduction. For d_model = 64 that is 192 of the 256 threads.
local_sum is thread-private. wg_reduce_sum(tid, local_sum) → the sum across all 256 threads is returned to every thread (via subgroup reduction, detailed in 00_shared.md).
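Phase 1 has two halves: strided per-thread partial sums, then a workgroup reduction that hands the same total to every thread. Both can be modeled sequentially in Python. This is a sketch under assumptions. The real `wg_reduce_sum` uses subgroup operations (see 00_shared.md), while the hypothetical `wg_reduce_sum_model` below uses a plain halving tree. The tree produces the same kind of broadcast result, but its summation order, and therefore the last bits of rounding, may differ.

```python
WG = 256


def wg_reduce_sum_model(partials):
    """Hypothetical sequential stand-in for wg_reduce_sum (tree, not subgroup)."""
    buf = list(partials)
    stride = len(buf) // 2
    while stride > 0:
        for t in range(stride):        # threads t < stride are active this step
            buf[t] += buf[t + stride]
        stride //= 2                   # a shared-memory GPU version needs a barrier here
    return [buf[0]] * len(partials)    # every thread receives the same total


def phase1_sum_of_squares(row, wg=WG):
    local_sum = [0.0] * wg             # one private accumulator per thread
    for tid in range(wg):
        for i in range(tid, len(row), wg):   # for (var i = tid; i < d_model; i = i + WG)
            local_sum[tid] += row[i] * row[i]
    return wg_reduce_sum_model(local_sum)


row = [float(k % 7) - 3.0 for k in range(768)]
totals = phase1_sum_of_squares(row)
print(len(set(totals)), totals[0] == sum(v * v for v in row))  # 1 True
```

The inputs are small integers, so every sum is exact and the tree total matches the direct sum bit for bit.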
3) Phase 2 — Normalize and scale
let rms_inv = inverseSqrt(total / f32(d_model) + eps);
if (tid == 0u) { rms_out[row] = rms_inv; }
for (var i = tid; i < d_model; i = i + WG) {
y[base + i] = x[base + i] * rms_inv * w[i];
}
- inverseSqrt(v) → 1/sqrt(v). A WGSL built-in with a hardware fast path (an rsqrt instruction).
- If tid == 0u, write rms_inv into the row's saved-state buffer. Other threads skip this (single writer, no race).
- Strided loop again: y[base + i] = x[base + i] * rms_inv * w[i].
Why is rms_inv the same on every thread? Because wg_reduce_sum returns the same total to all threads, so inverseSqrt(total / f32(d_model) + eps) produces the same rms_inv on every thread. No race, deterministic.
Line by Line — rmsnorm_backward
1) Setup + reduce A
let row = wgid.x;
// ... (same)
let rms_inv = rms_in[row]; // load saved value
var local_a = 0.0;
for (var i = tid; i < d_model; i = i + WG) {
local_a = local_a + w[i] * x[base + i] * dy[base + i];
}
let A = wg_reduce_sum(tid, local_a);
A = Σ_k w[k]·x[k]·dy[k] is a workgroup-wide reduction.
2) coeff precompute
let coeff = rms_inv * rms_inv * rms_inv * A / f32(d_model);
rms_inv³ · A / D is the coefficient of the coupling term in the backward formula. Every thread computes it redundantly, because rms_inv, A and d_model are identical across the workgroup.
3) Per-element dx and atomic dw accumulation
for (var i = tid; i < d_model; i = i + WG) {
let xi = x[base + i];
let dyi = dy[base + i];
let wi = w[i];
dx[base + i] = wi * dyi * rms_inv - xi * coeff;
let dw_val = xi * rms_inv * dyi;
if (is_finite(dw_val) && dw_val != 0.0) {  // is_finite: helper defined outside this excerpt (WGSL has no built-in isFinite)
var old_bits = atomicLoad(&dw[i]);
loop {
let new_bits = bitcast<u32>(bitcast<f32>(old_bits) + dw_val);
let res = atomicCompareExchangeWeak(&dw[i], old_bits, new_bits);
if (res.exchanged) { break; }
old_bits = res.old_value;
}
}
}
Per-thread work:
- Strided loop; each thread processes its own indices i.
- dx[base + i]: each slot is touched by a single thread (no race, no atomic needed).
- dw[i]: the same slot is shared across rows! Row 0's thread writes to dw[i], and row 1's thread writes to it too. That is why a CAS atomic add is needed.
The is_finite + != 0.0 check: a zero gradient is the additive identity, so skipping the atomic leaves dw unchanged and reduces contention. NaN/Inf gradients are dropped entirely rather than accumulated into dw.
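WGSL has no atomic f32 add. The kernel therefore stores f32 bit patterns in atomic<u32> and retries a compare-exchange until its update lands. The Python model below mirrors that loop, including the is_finite / != 0 filter, and uses `struct` for the f32 ↔ u32 bitcasts. `AtomicU32Model` and its one-shot `interfere` hook are hypothetical. They simulate another row's add landing between the load and the CAS, which forces exactly one retry.

```python
import math
import struct


def f32_to_bits(f):
    return struct.unpack("<I", struct.pack("<f", f))[0]   # bitcast<u32>(f32)


def bits_to_f32(b):
    return struct.unpack("<f", struct.pack("<I", b))[0]   # bitcast<f32>(u32)


class AtomicU32Model:
    """Hypothetical single-threaded stand-in for atomic<u32>."""

    def __init__(self, bits=0, interfere=None):
        self.bits = bits
        self.interfere = interfere  # one-shot callback: another thread's write

    def load(self):
        return self.bits

    def compare_exchange_weak(self, expected, new):
        if self.interfere is not None:
            hook, self.interfere = self.interfere, None
            hook(self)  # a competing add lands between our load and our CAS
        old = self.bits
        if old == expected:
            self.bits = new
            return old, True
        return old, False


def accumulate_dw(cell, dw_val):
    """Mirrors the kernel: skip zero/non-finite, then CAS-loop f32 add. Returns attempts."""
    if not math.isfinite(dw_val) or dw_val == 0.0:
        return 0
    old_bits = cell.load()
    attempts = 0
    while True:
        attempts += 1
        new_bits = f32_to_bits(bits_to_f32(old_bits) + dw_val)  # add, round to f32
        old_value, exchanged = cell.compare_exchange_weak(old_bits, new_bits)
        if exchanged:
            return attempts
        old_bits = old_value  # retry from the value the other thread wrote


def other_row_adds_quarter(cell):
    cell.bits = f32_to_bits(bits_to_f32(cell.bits) + 0.25)


cell = AtomicU32Model(f32_to_bits(0.0), interfere=other_row_adds_quarter)
print(accumulate_dw(cell, 0.5), bits_to_f32(cell.load()))  # 2 0.75
```

The all-zero bit pattern is +0.0f, so a zeroed buffer is a valid starting point for the accumulation.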
Mixed-Precision Variant
The difference in rmsnorm_forward_w16:
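enable f16;  // must appear at the top of the module; the device must expose the "shader-f16" feature
@group(0) @binding(1) var<storage, read> w_w16: array<f16>;
// ...
y_w16[base + i] = x_w16[base + i] * rms_inv * f32(w_w16[i]);
The w array is stored as f16 and cast to f32 on load. The computation runs in fp32 and the output stays fp32.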
Why is x_w16 still f32? Because x is the output of the previous layer, and activations were not moved to f16 (that would be too aggressive). f16 is used only on the weight side: embedding, norm w, matmul W. Activations are always fp32.
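Only w is stored in f16, so the numeric gap between rmsnorm_forward_w16 and rmsnorm_forward should come mainly from rounding w to half precision. A quick Python sketch estimates that gap using `struct`'s half-float format (`'e'`). The helper `to_f16` is hypothetical: it stands in for storing w as f16 and reading it back with f32(w_w16[i]). The test values are illustrative, and the sketch computes in float64 rather than fp32.

```python
import math
import struct


def to_f16(v):
    """Round to the nearest IEEE binary16 value and widen back (like f32(w_w16[i]))."""
    return struct.unpack("<e", struct.pack("<e", v))[0]


def rmsnorm_row(x, w, eps=1e-5):
    r = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v * r * wd for v, wd in zip(x, w)]


x = [0.5, -1.2, 2.0, 0.3]
w = [1.1, 0.9, 0.37, 2.0]
w16 = [to_f16(v) for v in w]

y = rmsnorm_row(x, w)
y16 = rmsnorm_row(x, w16)
# binary16 keeps 11 significant bits, so each weight moves by at most 2**-11
# relative, and each output inherits the same relative bound.
for a, b in zip(y, y16):
    assert abs(a - b) <= 2.0 ** -11 * abs(a) + 1e-12
print(max(abs(a - b) for a, b in zip(y, y16)))
```

Weights that are exactly representable in f16 (like 2.0) give identical outputs. The rest differ by at most about 2^-11 relative.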
WGSL-Specific Notes
1. inverseSqrt built-in
The analog of Metal's rsqrt. It typically lowers to a hardware rsqrt instruction (single cycle on Apple GPUs).
2. The atomic CAS pattern is the same as in 01_embedding.md (WGSL atomics exist only for i32/u32, so f32 adds go through bitcast + atomicCompareExchangeWeak)
3. wg_reduce_sum returns the same total to every thread
This matters — rms_inv = inverseSqrt(total / D + eps) is the same on every thread, then every thread normalizes with the same coefficient in the strided loop. Determinism guaranteed.
4. var local_a = 0.0 is thread-private
Each thread keeps its own accumulator in a register rather than in workgroup memory. With WG=256, that means 256 independent register accumulators and no shared-memory traffic until the final wg_reduce_sum.
Code Review
Finding 1: Minimal workgroup memory use
| Risk | Explanation |
|---|---|
| 🟢 none | The reduction is subgroup-based; the only workgroup memory is sh_red (1 KB). Correct. |
Finding 2: Every thread reads the same eps
| Risk | Explanation |
|---|---|
| 🟢 none | params.x lives in a uniform buffer, typically served from the GPU constant cache. Every thread reads the same value (a broadcast). Negligible cost. |
Finding 3: The dw_val == 0 skip is correct for the atomic, but...
| Risk | Explanation |
|---|---|
| 🟡 minor | If dy[i] = 0.0, dw_val = 0 regardless of x[i]·rms_inv, so the skip is correct. If dy[i] ≠ 0 but x[i] = 0, dw_val is also 0 and is skipped, which is correct too. The is_finite filter keeps NaN out of dw, but it drops such values silently. |
Finding 4: Would computing coeff only on tid==0 save bandwidth?
| Risk | Explanation |
|---|---|
| 🟢 none | coeff = rms_inv³ · A / D. rms_inv and A are the same on every thread, so every thread computes the same coeff. The redundant computation costs a few ALU ops and one register, which is effectively free. Computing it only on tid==0 would require a workgroup-memory broadcast plus a barrier, which is more expensive. |
Quick Checklist
| Test Scenario | Status |
|---|---|
| Does d_model > WG reduce correctly? | ✅ strided loop |
| Does d_model < WG (e.g. d=64) work correctly? | ✅ inactive threads contribute 0 to the reduce |
| Is rms_inv saved/loaded correctly? | ✅ 1 f32 per row |
| Is dw cross-row accumulation race-free? | ✅ CAS atomic add |
| Is there a NaN when eps = 0? | ⚠ no formal test, but input is guaranteed non-zero |
| Does rmsnorm_forward_w16 give the same result as rmsnorm_forward? | ⚠ there may be a tiny difference from rounding w to f16; not reviewed |
Next
03_linear_forward.md — the model's most expensive operation: the Y = X @ W matmul. Forward variants and the 64×64 tile algorithm.