llm.istanbul·Study

rmsnorm — Root-Mean-Square Normalization

File: 02_norm.wgsl
Pipeline step: twice per layer in the forward pass (once before attention and once before the FFN, whose input is the residual sum of the layer input and the attention output), and twice per layer in the backward pass.

3 kernels: rmsnorm_forward, rmsnorm_backward, rmsnorm_forward_w16 (mixed precision variant).


Wait, what is this?

Picture a music playlist: one track blasts your ears off, the next is barely audible. Annoying, right? That's why players run "loudness normalization" — they pull every track's overall volume into line. But they don't touch the melody, the rhythm, or how the instruments sit relative to each other; they only scale the total loudness to a standard level.

RMSNorm does exactly that to the vectors flowing through the model. The "loudness" (magnitude) of a number vector coming out of a layer sometimes spikes, sometimes fades — and that imbalance makes training harder. RMSNorm takes each vector and divides it by its own RMS (root-mean-square, i.e. the square root of the mean of the squared values — a cousin of the standard deviation from stats). The result: the vector's direction/shape stays the same, but its magnitude settles onto a consistent scale.

There's one extra touch: after the division, we multiply by a w vector. That w is a learned gain knob — the model gets to say "keep this dimension a bit louder, turn that one down." So it's not a raw normalize; it's "normalize, then dial it back to whatever proportions you've learned."

You may have heard of LayerNorm — it does a similar job, but it first subtracts the vector's mean (centers the signal around zero). RMSNorm skips that step: no mean subtraction, just division by magnitude. Fewer ops, fewer parameters, and in practice (LLaMA, Mistral, Falcon) it does the same job. The genuinely interesting part is how you make this fast on a GPU without summing every element of a vector serially (the reduction) — we get to that below.


What Does It Do?

It divides each row vector by its own RMS and multiplies it element-wise by the learnable scale vector w (one gain per dimension). Think of it as LayerNorm without the mean subtraction; it is LLaMA's standard normalization.

rms_i = sqrt((1/D) · Σ_k x_i[k]²  +  ε)        (D = d_model)
y_i[d] = x_i[d] / rms_i · w[d]
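To make the two formulas concrete, here is a minimal CPU reference for a single row in Python. It is a sketch for checking GPU output, not the kernel itself; the name rmsnorm_row is hypothetical, and the default eps = 1e-5 is an assumption (the kernel reads eps from params.x).

```python
import math

def rmsnorm_row(x, w, eps=1e-5):
    """Hypothetical CPU reference for one row: returns (y, rms_inv)."""
    D = len(x)                                    # D = d_model
    mean_sq = sum(v * v for v in x) / D           # (1/D) * sum_k x[k]^2
    rms_inv = 1.0 / math.sqrt(mean_sq + eps)      # 1 / rms_i
    y = [x[d] * rms_inv * w[d] for d in range(D)]
    return y, rms_inv

# With w = 1 and eps = 0 the output row has RMS 1 (up to rounding).
y, rms_inv = rmsnorm_row([3.0, 4.0], [1.0, 1.0], eps=0.0)
print([round(v, 4) for v in y])   # [0.8485, 1.1314]
```

Scaling the whole input row by a constant leaves y unchanged when eps = 0, which is the loudness-normalization property from the playlist analogy.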

x_i is one row vector (one token), and each row is processed independently. One workgroup handles one row: its 256 threads share the row and cooperate on the reduction.

The difference from LayerNorm:

  • No mean subtraction (faster, fewer ALU ops)
  • No bias (only a scale, w here, playing the role of γ; no shift β)

LayerNorm:

y = (x - μ) / σ · γ + β

RMSNorm:

y = x / RMS · w
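A quick way to see the relationship between the two formulas: with γ = w, β = 0, and the population (1/D) variance, LayerNorm and RMSNorm produce the same output whenever the row already has zero mean, and diverge once it does not. A small Python sketch; both function names are hypothetical and eps = 1e-5 is assumed.

```python
import math

def layernorm_row(x, gamma, beta, eps=1e-5):
    D = len(x)
    mu = sum(x) / D                                # mean
    var = sum((v - mu) ** 2 for v in x) / D        # population variance
    inv_sigma = 1.0 / math.sqrt(var + eps)
    return [(x[d] - mu) * inv_sigma * gamma[d] + beta[d] for d in range(D)]

def rmsnorm_row(x, w, eps=1e-5):
    D = len(x)
    rms_inv = 1.0 / math.sqrt(sum(v * v for v in x) / D + eps)
    return [x[d] * rms_inv * w[d] for d in range(D)]

w = [1.0, 0.5, 2.0, 1.5]               # gamma for LayerNorm, w for RMSNorm
beta = [0.0] * 4                       # no shift, so only the centering differs
centered = [1.0, -2.0, 3.0, -2.0]      # mean 0
shifted = [v + 5.0 for v in centered]  # mean 5

for row in (centered, shifted):
    ln, rn = layernorm_row(row, w, beta), rmsnorm_row(row, w)
    print(max(abs(a - b) for a, b in zip(ln, rn)))
```

The first printed difference is zero (up to rounding); the second is clearly non-zero, because RMSNorm does not remove the +5 offset.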

LLaMA, Mistral, and Falcon all use RMSNorm. It behaves consistently during pretraining and has fewer parameters (no β).


Mathematical Definition — Backward

Let L be the loss. Given the upstream gradient dy = ∂L/∂y, we compute:

  • dx = ∂L/∂x — to pass back to the previous layer
  • dw += ∂L/∂w — for the weight update (cross-row accumulation)

Derivative:

y[d] = x[d] · rms_inv · w[d]
where rms_inv = (mean(x²) + ε)^(-½)

Chain rule:

∂y[d]/∂x[d'] = (δ[d=d'] · rms_inv  +  x[d] · ∂rms_inv/∂x[d']) · w[d]

and ∂rms_inv/∂x[d'] = -rms_inv³ · x[d'] / D (D = d_model)
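The derivative of rms_inv can be checked numerically before it is used. This Python sketch (hypothetical helper names, eps = 1e-5 assumed) compares -rms_inv³ · x[d'] / D against a central finite difference for every coordinate.

```python
def rms_inv_of(x, eps=1e-5):
    # rms_inv = (mean(x^2) + eps)^(-1/2)
    return (sum(v * v for v in x) / len(x) + eps) ** -0.5

def max_grad_error(x, eps=1e-5, h=1e-6):
    D = len(x)
    r = rms_inv_of(x, eps)
    worst = 0.0
    for j in range(D):
        xp, xm = list(x), list(x)
        xp[j] += h
        xm[j] -= h
        numeric = (rms_inv_of(xp, eps) - rms_inv_of(xm, eps)) / (2 * h)
        analytic = -r ** 3 * x[j] / D        # formula from the text
        worst = max(worst, abs(numeric - analytic))
    return worst

print(max_grad_error([0.3, -1.2, 2.0, 0.7]))   # should be close to zero
```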

Result (for a single row):

A = Σ_k w[k] · x[k] · dy[k]                  ← scalar, workgroup reduction
dx[d] = w[d] · dy[d] · rms_inv
      - x[d] · rms_inv³ · A / D                ← coupling term
dw[d] += x[d] · rms_inv · dy[d]               ← cross-row, atomic add

A is computed for each row. dw is the sum over all rows (atomic accumulation).
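The per-row result can be verified end to end with a finite-difference check. The sketch below assumes a test loss L = Σ_d g[d] · y[d], so that dy = g, and compares the formulas for dx and this row's dw contribution against numerical gradients. All names are hypothetical, eps = 1e-5 is assumed, and dw is shown for one row only (on the GPU, rows are summed with the atomic add).

```python
def forward(x, w, eps=1e-5):
    D = len(x)
    rms_inv = (sum(v * v for v in x) / D + eps) ** -0.5
    return [x[d] * rms_inv * w[d] for d in range(D)], rms_inv

def backward(x, w, dy, rms_inv):
    """One row of the backward formulas: returns (dx, this row's dw contribution)."""
    D = len(x)
    A = sum(w[k] * x[k] * dy[k] for k in range(D))       # scalar; a WG reduction on GPU
    coeff = rms_inv ** 3 * A / D                          # coupling coefficient
    dx = [w[d] * dy[d] * rms_inv - x[d] * coeff for d in range(D)]
    dw_row = [x[d] * rms_inv * dy[d] for d in range(D)]   # summed across rows on GPU
    return dx, dw_row

def loss(x, w, g):
    # Test loss L = sum_d g[d] * y[d], so dy = dL/dy = g.
    y, _ = forward(x, w)
    return sum(gd * yd for gd, yd in zip(g, y))

def numeric_grad(f, v, h=1e-6):
    # Central finite differences, one coordinate at a time.
    grad = []
    for j in range(len(v)):
        vp, vm = list(v), list(v)
        vp[j] += h
        vm[j] -= h
        grad.append((f(vp) - f(vm)) / (2 * h))
    return grad

x = [0.5, -1.5, 2.0, 0.25]
w = [1.0, 0.8, 1.2, 0.5]
g = [0.3, -0.7, 0.1, 0.9]
_, rms_inv = forward(x, w)
dx, dw_row = backward(x, w, g, rms_inv)
num_dx = numeric_grad(lambda xx: loss(xx, w, g), x)
num_dw = numeric_grad(lambda ww: loss(x, ww, g), w)
print(max(abs(a - b) for a, b in zip(dx, num_dx)))      # should be close to zero
print(max(abs(a - b) for a, b in zip(dw_row, num_dw)))  # should be close to zero
```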


Bind Group ABI

rmsnorm_forward (6 bindings)

Binding  Type                 Detail
0        storage, read        x: array<f32>[seq_len × d_model], input
1        storage, read        w: array<f32>[d_model], learnable scale
2        storage, read_write  y: array<f32>[seq_len × d_model], output
3        storage, read_write  rms_out: array<f32>[seq_len], rms_inv per row, saved for backward
4        uniform              dims: vec4<u32>(seq_len, d_model, _, _)
5        uniform              params: vec4<f32>(eps, _, _, _)

Why save rms_out? So the backward pass does not have to recompute it. Forward already takes the sum of squares, so it stores the resulting rms_inv once per row and backward simply loads it. ~30% backward speedup.
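On the host side, each uniform binding in the rmsnorm_forward table is a single 16-byte vec4. A minimal Python sketch of the byte layout, assuming the unused lanes (_) are zero-filled and eps = 1e-5; WGSL stores host-shareable values little-endian, which matches struct's "<" prefix. The helper names are hypothetical.

```python
import struct

F32_BYTES = 4

def pack_dims(seq_len, d_model):
    # binding 4, dims: vec4<u32>(seq_len, d_model, _, _); unused lanes set to 0
    return struct.pack("<4I", seq_len, d_model, 0, 0)

def pack_params(eps):
    # binding 5, params: vec4<f32>(eps, _, _, _); eps is rounded to f32 here
    return struct.pack("<4f", eps, 0.0, 0.0, 0.0)

def storage_sizes(seq_len, d_model):
    # Byte sizes of the f32 storage buffers in the rmsnorm_forward table.
    return {
        "x": seq_len * d_model * F32_BYTES,
        "w": d_model * F32_BYTES,
        "y": seq_len * d_model * F32_BYTES,
        "rms_out": seq_len * F32_BYTES,
    }

print(len(pack_dims(512, 768)), len(pack_params(1e-5)))   # 16 16
print(storage_sizes(512, 768)["rms_out"])                  # 2048
```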

rmsnorm_backward (7 bindings)

Binding  Type                 Detail
0        storage, read        x
1        storage, read        w
2        storage, read        dy: upstream gradient
3        storage, read        rms_in: rms_inv saved by the forward pass
4        storage, read_write  dx: gradient passed back to the previous layer
5        storage, read_write  dw: array<atomic<u32>>, f32 values bit-cast to u32, scatter-add
6        uniform              dims

rmsnorm_forward_w16 (6 bindings)

Same layout as rmsnorm_forward, except that binding 1 is w_w16: array<f16>, and the computation converts each weight with f32(w_w16[i]).


Dispatch Shape

workgroup_size: 256
total threads:  seq_len workgroups × 256 threads

One workgroup = one row. Each WG performs the reduction for its own row; rows are independent, so the workgroups need no coordination and can run in parallel.

Example (seq=512, d=768): 512 WG × 256 threads.

Why one WG per row? The reduction (sum of squares) is within the workgroup. Since no cross-row reduction is needed, rows parallelize easily. If d_model > WG (e.g. d=2048, WG=256) → each thread processes d_model/WG = 8 elements, a distributed reduction.
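The dispatch arithmetic above in a few lines of Python. This is a hypothetical helper, not code from the project; in WebGPU host code the grid itself corresponds to dispatchWorkgroups(seq_len), since the kernel reads its row from wgid.x.

```python
import math

WG = 256  # workgroup_size from Dispatch Shape

def dispatch_shape(seq_len, d_model, wg=WG):
    workgroups = (seq_len, 1, 1)               # one workgroup per row, along x
    total_threads = seq_len * wg               # every WG launches all wg threads
    max_per_thread = math.ceil(d_model / wg)   # strided-loop trips of the busiest thread
    return workgroups, total_threads, max_per_thread

print(dispatch_shape(512, 768))    # ((512, 1, 1), 131072, 3)
print(dispatch_shape(512, 2048))   # ((512, 1, 1), 131072, 8)
```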


Line by Line — rmsnorm_forward

1) Setup

wgsl
@compute @workgroup_size(256)   // WG = 256, one workgroup per row
fn rmsnorm_forward(@builtin(workgroup_id) wgid: vec3<u32>,
                   @builtin(local_invocation_index) tid: u32) {
    let row = wgid.x;                 // this workgroup's row
    let seq_len = dims.x;
    let d_model = dims.y;
    let eps = params.x;
    if (row >= seq_len) { return; }   // same for every thread of the WG, so the whole WG exits together

    let base = row * d_model;         // start of this row in x and y
    // ... Phase 1 and Phase 2 below complete the function body

  • workgroup_id (the analog of Metal's threadgroup_position_in_grid): each WG's ID within the grid.
  • local_invocation_index (the analog of Metal's thread_index_in_threadgroup): the thread ID within the WG, 0..255.
  • base: this row's starting offset in the x/y arrays.

2) Phase 1 — Sum of squares

wgsl
var local_sum = 0.0;
for (var i = tid; i < d_model; i = i + WG) {
    let v = x[base + i];
    local_sum = local_sum + v * v;
}
let total = wg_reduce_sum(tid, local_sum);

Strided loop:

  • Thread tid sums the elements that fall to its share: x[base + tid], x[base + tid + 256], x[base + tid + 512], ...
  • If d_model = 768, each thread processes 3 elements (0/256/512, 1/257/513, ..., 255/511/767)
  • If d_model = 256, each thread processes 1 element
  • If d_model < 256, threads with tid >= d_model never enter the loop (the condition i < d_model is already false at i = tid) and keep local_sum = 0.0
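The strided loop can be simulated on the CPU to check these bullet points: every index of the row is visited by exactly one thread, the per-thread partial sums add up to the full sum of squares, and with d_model = 64 the 192 threads with tid ≥ 64 contribute nothing. A Python sketch with hypothetical helper names, assuming WG = 256:

```python
WG = 256

def thread_indices(tid, d_model, wg=WG):
    # Indices visited by: for (var i = tid; i < d_model; i = i + WG)
    return list(range(tid, d_model, wg))

def strided_partials(row, wg=WG):
    # Each thread's local_sum; threads whose loop never runs keep 0.0.
    return [sum((row[i] * row[i] for i in thread_indices(t, len(row), wg)), 0.0)
            for t in range(wg)]

print(thread_indices(0, 768), thread_indices(255, 768))   # [0, 256, 512] [255, 511, 767]
print(sum(1 for t in range(WG) if not thread_indices(t, 64)))   # 192
```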

local_sum is thread-private. wg_reduce_sum(tid, local_sum) → the sum across all 256 threads is returned to every thread (via subgroup reduction, detailed in 00_shared.md).
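The real wg_reduce_sum is described in 00_shared.md and is not reproduced here. As a model of its contract only, this Python sketch assumes a subgroup width of 32 (actual widths vary by GPU): each subgroup sums its lanes, the per-subgroup sums are combined, and the same total is returned to every thread. The name wg_reduce_sum_model is hypothetical.

```python
SUBGROUP = 32  # assumed subgroup width; real GPUs differ

def wg_reduce_sum_model(partials, subgroup=SUBGROUP):
    """Model of the contract only: every thread gets the workgroup total."""
    # Stage 1: each subgroup sums its own lanes (in WGSL, subgroupAdd could do this).
    sg_sums = [sum(partials[s:s + subgroup], 0.0)
               for s in range(0, len(partials), subgroup)]
    # Stage 2: combine the per-subgroup sums into one total.
    total = sum(sg_sums, 0.0)
    # Broadcast: every thread receives the same value.
    return [total] * len(partials)

out = wg_reduce_sum_model([float(t) for t in range(256)])
print(len(out), out[0], out[255])   # 256 32640.0 32640.0
```

Because the additions happen in a different order than a serial loop, a GPU total can differ from a CPU reference in the last bits, so comparisons should use a tolerance.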

3) Phase 2 — Normalize and scale

wgsl
let rms_inv = inverseSqrt(total / f32(d_model) + eps);
if (tid == 0u) { rms_out[row] = rms_inv; }
for (var i = tid; i < d_model; i = i + WG) {
    y[base + i] = x[base + i] * rms_inv * w[i];
}

  • inverseSqrt(...) computes 1/sqrt(...). A WGSL built-in with a hardware fast path (rsqrt instruction).
  • If tid == 0u, write rms_inv into the row's saved-state buffer. Other threads do not do this (single writer, no race).
  • Strided loop again: y[base + i] = x[base + i] * rms_inv * w[i].

Why is rms_inv the same on every thread? Because wg_reduce_sum returns the same total to all threads, so inverseSqrt(total / f32(d_model) + eps) yields the same rms_inv on every thread. No race, deterministic.


Line by Line — rmsnorm_backward

1) Setup + reduce A

wgsl
let row = wgid.x;
// ... (same)
let rms_inv = rms_in[row];      // load saved value

var local_a = 0.0;
for (var i = tid; i < d_model; i = i + WG) {
    local_a = local_a + w[i] * x[base + i] * dy[base + i];
}
let A = wg_reduce_sum(tid, local_a);

A = Σ_k w[k]·x[k]·dy[k] is a workgroup-wide reduction.

2) coeff precompute

wgsl
let coeff = rms_inv * rms_inv * rms_inv * A / f32(d_model);

rms_inv³ · A / D is the coefficient of the coupling term in the backward formula. Every thread computes it itself: rms_inv, A, and d_model are scalars with the same value on every thread of the workgroup.

3) Per-element dx and atomic dw accumulation

wgsl
for (var i = tid; i < d_model; i = i + WG) {
    let xi = x[base + i];
    let dyi = dy[base + i];
    let wi = w[i];
    dx[base + i] = wi * dyi * rms_inv - xi * coeff;
    let dw_val = xi * rms_inv * dyi;
    if (is_finite(dw_val) && dw_val != 0.0) {
        var old_bits = atomicLoad(&dw[i]);
        loop {
            let new_bits = bitcast<u32>(bitcast<f32>(old_bits) + dw_val);
            let res = atomicCompareExchangeWeak(&dw[i], old_bits, new_bits);
            if (res.exchanged) { break; }
            old_bits = res.old_value;
        }
    }
}

Per-thread work:

  • Strided loop, each thread processes its own i
  • dx[base + i] — each slot is touched by a single thread (no race, no atomic needed)
  • dw[i] is the same slot across rows! Threads in row 0 and row 1 (different workgroups) both add into dw[i], possibly at the same time. That is why a CAS atomic add is needed.

The is_finite + != 0.0 check: adding 0.0 leaves dw unchanged, so skipping it saves an atomic and reduces contention. NaN/Inf contributions are dropped entirely.
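A single-threaded Python sketch of the bit-cast CAS add, to show the shape of the loop. AtomicU32 is a hypothetical stand-in for atomic<u32> (it never contends on its own, so the retry path only runs if compare_exchange is made to fail), and struct does the bitcast<u32>/bitcast<f32> conversions. Values outside the f32 range are out of scope; struct.pack raises OverflowError for them.

```python
import math
import struct

def f32_to_bits(v):
    return struct.unpack("<I", struct.pack("<f", v))[0]   # like bitcast<u32>(v)

def bits_to_f32(b):
    return struct.unpack("<f", struct.pack("<I", b))[0]   # like bitcast<f32>(b)

class AtomicU32:
    """Single-threaded stand-in for atomic<u32> with a compare-exchange."""
    def __init__(self, bits=0):
        self.bits = bits                   # 0 is the bit pattern of 0.0f
    def load(self):
        return self.bits
    def compare_exchange(self, expected, desired):
        old = self.bits
        exchanged = old == expected
        if exchanged:
            self.bits = desired
        return old, exchanged

def atomic_add_f32(cell, value):
    if not math.isfinite(value) or value == 0.0:   # same skip as the kernel
        return
    value = bits_to_f32(f32_to_bits(value))        # dw_val is an f32 in the kernel
    old_bits = cell.load()
    while True:
        # f64 add of two f32 values, rounded once to f32: same as an f32 add
        new_bits = f32_to_bits(bits_to_f32(old_bits) + value)
        old, exchanged = cell.compare_exchange(old_bits, new_bits)
        if exchanged:
            break
        old_bits = old                   # another writer won; retry from its value

dw0 = AtomicU32()
for contribution in (0.5, 0.0, float("nan"), 0.25, -0.125):   # one per row
    atomic_add_f32(dw0, contribution)
print(bits_to_f32(dw0.load()))   # 0.625
```

On the GPU, the order in which rows win the CAS changes from run to run, and float addition is not associative, so the last bits of dw can vary between runs even though the accumulation is race-free.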


Mixed-Precision Variant

The difference in rmsnorm_forward_w16:

wgsl
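enable f16;   // must be at the top of the module; WGSL needs it (and the "shader-f16" device feature) for f16
@group(0) @binding(1) var<storage, read>       w_w16: array<f16>;
// ...
y_w16[base + i] = x_w16[base + i] * rms_inv * f32(w_w16[i]);   // f16 weight widened to f32, math in f32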

The w array is stored as f16 and converted to f32 on load. The computation runs in fp32, and the output is still fp32.

Why is x_w16 still f32? Because x is the previous layer's output, an activation, and we did not move activations to f16 (that would be too aggressive). f16 is used only on the weight side: embedding, norm w, matmul W. Activations are always fp32.
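How large is the difference from storing w as f16? Since y[d] is linear in w[d], the relative change in y[d] equals the relative rounding error of w[d] (ignoring fp32 rounding of the product), and for weights in f16's normal range that is at most 2^-11. A Python sketch using struct's "e" (IEEE binary16) format; the weight values are made up for illustration:

```python
import struct

def to_f16(v):
    # Round-trip through IEEE binary16, like storing w in array<f16>.
    return struct.unpack("<e", struct.pack("<e", v))[0]

w = [0.1, 1.3, 0.977, 2.5]
w16 = [to_f16(v) for v in w]
rel_err = [abs(a - b) / abs(a) for a, b in zip(w, w16)]
print(max(rel_err) <= 2 ** -11)   # True
```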


WGSL-Specific Notes

1. inverseSqrt built-in

The analog of Metal's rsqrt. Optimizes to a hardware rsqrt instruction (single cycle on Apple GPU).

2. The atomic CAS pattern is the same as in 01_embedding.md

3. wg_reduce_sum returns the same total to every thread

This matters — rms_inv = inverseSqrt(total / D + eps) is the same on every thread, then every thread normalizes with the same coefficient in the strided loop. Determinism guaranteed.

4. var local_a = 0.0 is thread-private

Each thread keeps its own accumulator in a register, not in workgroup memory: WG=256 means 256 separate register accumulators. No barriers or shared-memory traffic are needed until the final wg_reduce_sum.


Code Review

Finding 1: Uses almost no workgroup memory

Risk     Explanation
🟢 none  The reduction is subgroup-based, and sh_red (1 KB) is the only workgroup-memory buffer. Correct.

Finding 2: Every thread reads eps from the same uniform

Risk     Explanation
🟢 none  params.x is a uniform, served from the GPU constant cache. Every thread reads the same value (a broadcast). Marginal cost.

Finding 3: The zero/non-finite skip on the dw atomic

Risk      Explanation
🟡 minor  If dy[i] = 0.0, then dw_val = 0 regardless of x[i]·rms_inv, so the skip is correct. If dy[i] != 0 but x[i] = 0, dw_val is also 0 and the skip is again correct. The filter also keeps NaN/Inf out of dw. The caveat is that those values are dropped silently, while dx has no such filter, so a NaN still propagates through dx.

Finding 4: Would computing coeff only on tid==0 save work?

Risk     Explanation
🟢 none  coeff = rms_inv³ · A / D. rms_inv and A are the same on every thread, so every thread gets the same coeff. Having all threads compute it is redundant, but the cost is a few ALU ops and one register, effectively free. Computing it only on tid==0 would need a workgroup-memory broadcast plus a barrier, which is more expensive.

Quick Checklist

Test scenario                                                       Status
Does d_model > WG reduce correctly?                                 ✅ strided loop
Does d_model < WG (e.g. d=64) work correctly?                       ✅ inactive threads contribute 0 to the reduction
Is rms_inv saved/loaded correctly?                                  ✅ 1 f32 per row
Is dw cross-row accumulation race-free?                             ✅ CAS atomic add
Is there a NaN when eps = 0?                                        ⚠ no formal test; relies on input rows never being all-zero
Does rmsnorm_forward_w16 give the same result as rmsnorm_forward?   ⚠ may differ slightly (ε-level, from storing w as f16); not reviewed

Next

03_linear_forward.md — the model's most expensive operation: the Y = X @ W matmul. Forward variants and the 64×64 tile algorithm.

WGSL kernel studies · an LLM from scratch on WebGPU · Built in Istanbul by Uğur Toprakdeviren.