llm.istanbul·Etüt
TR EN
Workbench →

rmsnorm — Root-Mean-Square Normalization

Dosya: 02_norm.wgsl Pipeline adımı: Forward'da her layer'da 2 kez (input + post-attention residual'dan önce ve FFN'den önce), backward'da iki kez.

3 kernel: rmsnorm_forward, rmsnorm_backward, rmsnorm_forward_w16 (mixed precision varyantı).


Nedir bu ya?

Bir müzik çalma listesi düşün: bir şarkı kulağını patlatıyor, sonraki neredeyse duyulmuyor. Sinir bozucu, değil mi? Bu yüzden çalarlar "ses eşitleme" (normalize) yapar — her parçanın genel ses seviyesini aynı hizaya çeker. Ama melodiyi, ritmi, enstrümanların birbirine oranını bozmaz; sadece toplam yüksekliği standart bir seviyeye ölçekler.

RMSNorm da modelin içinde dolaşan vektörlere tam bunu yapıyor. Bir layer'dan çıkan sayı vektörünün "yüksekliği" (büyüklüğü) bazen fırlar, bazen sönükleşir — ve bu dengesizlik eğitimi zorlaştırır. RMSNorm her vektörü alıp kendi RMS'ine (root-mean-square, yani değerlerin karelerinin ortalamasının karekökü — istatistikteki standart sapmanın akrabası) bölüyor. Sonuç: vektörün yönü/şekli aynı kalır, ama büyüklüğü tutarlı bir ölçeğe oturur.

Tek ek dokunuş: bölümden sonra bir w vektörüyle çarpıyoruz. Bu w öğrenilen bir kazanç ayarı (gain knob) — model "şu boyutu biraz daha yüksek tut, şunu kıs" diyebiliyor. Yani ham normalize değil, "normalize et, sonra öğrendiğin oranlara göre tekrar kur".

Belki LayerNorm'u duymuşsundur — o da benzer iş yapar ama önce vektörün ortalamasını çıkarır (sinyali sıfır etrafında ortalar). RMSNorm o adımı atlıyor: ortalama çıkarma yok, sadece büyüklüğe göre bölme. Daha az işlem, daha az parametre, ve pratikte (LLaMA, Mistral, Falcon) aynı işi görüyor. Asıl ilginç kısım ise bunu GPU'da bir vektörün tüm elemanlarını tek tek (seri olarak) toplamadan (reduction) nasıl hızlı yaptığın — aşağıda oraya geliyoruz.


Ne Yapar?

Vektörü kendi RMS'ine bölüp learnable scalar w ile çarpar. LayerNorm'un mean-substract'sız basit versiyonu — LLaMA'nın standardı.

rms_i = sqrt((1/d) · Σ_k x_i[k]²  +  ε)
y_i[d] = x_i[d] / rms_i · w[d]

x_i row vector (token bazlı), her row bağımsız işleniyor. Bir workgroup = bir row. WG=256 thread sıra paylaşıp reduction yapar.

LayerNorm'dan farkı:

  • Mean substract yok (faster, daha az ALU)
  • Bias yok (sadece scale γ, shift β yok)

LayerNorm:

y = (x - μ) / σ · γ + β

RMSNorm:

y = x / RMS · w

LLaMA, Mistral, Falcon hepsi RMSNorm kullanıyor. Pretrain'de tutarlı, daha az parametre.


Matematiksel Tanım — Backward

L = loss. Verilen dy = ∂L/∂y, hesaplanacak:

  • dx = ∂L/∂x — bir önceki layer'a iletmek için
  • dw += ∂L/∂w — weight güncellemesi için (cross-row accumulation)

Türev:

y[d] = x[d] · rms_inv · w[d]
where rms_inv = (mean(x²) + ε)^(-½)

Chain rule:

∂y[d]/∂x[d'] = (δ[d=d'] · rms_inv  +  x[d] · ∂rms_inv/∂x[d']) · w[d]

ve ∂rms_inv/∂x[d'] = -rms_inv³ · x[d'] / D (D = d_model)

Sonuç (bir tek row için):

A = Σ_k w[k] · x[k] · dy[k]                  ← scalar, workgroup reduction
dx[d] = w[d] · dy[d] · rms_inv
      - x[d] · rms_inv³ · A / D                ← coupling term
dw[d] += x[d] · rms_inv · dy[d]               ← cross-row, atomic add

A her row için hesaplanır. dw bütün row'ların toplamı (atomic accumulation).


Bind Group ABI

rmsnorm_forward (6 binding)

BindingTürDetay
0storage, readx: array<f32>[seq × d_model] input
1storage, readw: array<f32>[d_model] learnable scale
2storage, read_writey: array<f32>[seq × d_model] output
3storage, read_writerms_out: array<f32>[seq_len] saved for backward
4uniformdims: vec4<u32>(seq_len, d_model, _, _)
5uniformparams: vec4<f32>(eps, _, _, _)

rms_out neden saklanıyor? Backward'da yeniden hesaplamamak için. Forward sırasında nasılsa toplam alındı; sonucu sakla, backward'da reuse et. ~%30 backward speedup.

rmsnorm_backward (7 binding)

BindingTürDetay
0storage, readx
1storage, readw
2storage, readdy — upstream gradient
3storage, readrms_in — saved rms_inv
4storage, read_writedx — gradient to next layer back
5storage, read_writedw: array<atomic<u32>> — bit-cast f32, scatter-add
6uniformdims

rmsnorm_forward_w16 (6 binding)

Aynı; sadece w_w16: array<f16> ve hesaplamada f32(w_w16[i]) cast.


Dispatch Şekli

workgroup_size: 256
total threads:  seq_len workgroups × 256 threads

Bir workgroup = bir row. Her WG kendi row'unun reduction'ını yapıyor; row'lar bağımsız (independent parallel).

Örnek (seq=512, d=768): 512 WG × 256 thread.

Niye row başına bir WG? Reduction (sum of squares) workgroup-içi. Cross-row reduction gerekmediği için row'lar kolayca paralelleşir. Eğer d_model > WG (örn d=2048, WG=256) → her thread d_model/WG = 8 element işler, dağıtılmış reduction.


Satır Satır — rmsnorm_forward

1) Setup

wgsl
fn rmsnorm_forward(@builtin(workgroup_id) wgid: vec3<u32>,
                   @builtin(local_invocation_index) tid: u32) {
    let row = wgid.x;
    let seq_len = dims.x;
    let d_model = dims.y;
    let eps = params.x;
    if (row >= seq_len) { return; }

    let base = row * d_model;
  • workgroup_id (Metal'in threadgroup_position_in_grid analoğu) — her WG'nin grid içindeki ID'si.
  • local_invocation_index (Metal'in thread_index_in_threadgroup) — WG içindeki thread ID, 0..255.
  • base — bu row'un x/y array'lerinde başlangıç offset'i.

2) Phase 1 — Sum of squares

wgsl
var local_sum = 0.0;
for (var i = tid; i < d_model; i = i + WG) {
    let v = x[base + i];
    local_sum = local_sum + v * v;
}
let total = wg_reduce_sum(tid, local_sum);

Strided loop:

  • Thread tid kendi paylaşına düşen elementleri toplar: x[base + tid], x[base + tid + 256], x[base + tid + 512], ...
  • Eğer d_model = 768, her thread 3 element işler (0/256/512, 1/257/513, ..., 255/511/767)
  • Eğer d_model = 256, her thread 1 element
  • Eğer d_model < 256, son birkaç thread loop'a hiç girmez (loop condition tid < d_model çoğu thread'de false)

local_sum thread-private. wg_reduce_sum(tid, local_sum) → tüm 256 thread'in toplamı her thread'e döner (subgroup reduction'la, 00_shared.md'da detay).

3) Phase 2 — Normalize and scale

wgsl
let rms_inv = inverseSqrt(total / f32(d_model) + eps);
if (tid == 0u) { rms_out[row] = rms_inv; }
for (var i = tid; i < d_model; i = i + WG) {
    y[base + i] = x[base + i] * rms_inv * w[i];
}
  • inverseSqrt(...)1/sqrt(...). WGSL built-in, hardware fast-path (rsqrt instruction).
  • tid == 0u ise rms_inv row'un saved-state buffer'ına yaz. Diğer thread'ler bunu yapmaz (single writer, race yok).
  • Tekrar strided loop: y[base + i] = x[base + i] * rms_inv * w[i].

Niye rms_inv her thread'de aynı? Çünkü wg_reduce_sum tüm thread'lere aynı total döndürüyor. Sonra inverseSqrt(same total + eps) → aynı rms_inv her thread'de. Race yok, deterministic.


Satır Satır — rmsnorm_backward

1) Setup + reduce A

wgsl
let row = wgid.x;
// ... (aynı)
let rms_inv = rms_in[row];      // load saved value

var local_a = 0.0;
for (var i = tid; i < d_model; i = i + WG) {
    local_a = local_a + w[i] * x[base + i] * dy[base + i];
}
let A = wg_reduce_sum(tid, local_a);

A = Σ_k w[k]·x[k]·dy[k] workgroup-wide reduction.

2) coeff precompute

wgsl
let coeff = rms_inv * rms_inv * rms_inv * A / f32(d_model);

rms_inv³ · A / D — backward formula'sının coupling term'ünün katsayısı. Her thread tarafından hesaplanır (her şey scalar uniform).

3) Per-element dx and atomic dw accumulation

wgsl
for (var i = tid; i < d_model; i = i + WG) {
    let xi = x[base + i];
    let dyi = dy[base + i];
    let wi = w[i];
    dx[base + i] = wi * dyi * rms_inv - xi * coeff;
    let dw_val = xi * rms_inv * dyi;
    if (is_finite(dw_val) && dw_val != 0.0) {
        var old_bits = atomicLoad(&dw[i]);
        loop {
            let new_bits = bitcast<u32>(bitcast<f32>(old_bits) + dw_val);
            let res = atomicCompareExchangeWeak(&dw[i], old_bits, new_bits);
            if (res.exchanged) { break; }
            old_bits = res.old_value;
        }
    }
}

Per-thread iş:

  • Strided loop, her thread kendi i'sini işler
  • dx[base + i] — her slot tek thread tarafından dokunuluyor (race yok, atomic gerekmez)
  • dw[i]cross-row aynı slot! Row 0'ın thread'i dw[i]'a yazıyorsa, row 1'in thread'i de yazıyor. Bu yüzden CAS atomic add gerek.

is_finite + != 0.0 kontrolü: Sıfır gradient zaten add'in identity'si — atomic skip et, contention azalt. NaN/Inf gradient'i tamamen drop et.


Mixed-Precision Variant

rmsnorm_forward_w16 farkı:

wgsl
@group(0) @binding(1) var<storage, read>       w_w16: array<f16>;
// ...
y_w16[base + i] = x_w16[base + i] * rms_inv * f32(w_w16[i]);

w array'i f16 storage; cast-load. Hesap fp32. Output yine fp32.

Niye x_w16 da f32? Çünkü x aslında bir önceki layer'ın output'u; biz layer-level f16 yapmadık (çok agresif olur). f16 sadece weight tarafında — embedding, norm w, matmul W. Activation hep fp32.


WGSL-Spesifik Notlar

1. inverseSqrt built-in

Metal'in rsqrt analoğu. Hardware rsqrt instruction'a optimize olur (Apple GPU'da single cycle).

2. Atomic CAS pattern aynı (bkz 01_embedding.md)

3. wg_reduce_sum her thread'e aynı total dönderiyor

Bu önemli — rms_inv = inverseSqrt(total / D + eps) her thread'de aynı, sonra her thread strided loop'ta aynı katsayıyla normalize ediyor. Determinism garantili.

4. var local_a = 0.0 thread-private

Her thread kendi accumulator'ı tutar. Workgroup memory'de değil, register'da. WG=256 → 256 ayrı register accumulator. ALU bandwidth'i yeter.


Code Review

Bulgu 1: Workgroup memory kullanmıyor

RiskAçıklama
🟢 yokReduction subgroup-based, sh_red 1KB sadece. Diğer WG memory yok. Doğru.

Bulgu 2: eps her thread'de aynı f32 cast

RiskAçıklama
🟢 yokparams.x uniform, GPU constant cache'te. Her thread aynı değer okur, broadcast. Marjinal.

Bulgu 3: dy * 0.0 skip atomik için doğru ama...

RiskAçıklama
🟡 minorEğer dy[i] = 0.0 ama x[i]·rms_inv ≠ 0 ise yine de dw_val = 0. Skip doğru. Ama dy[i] != 0 && x[i] = 0 durumunda da skip yapacak — bu da doğru. NaN için filter güvenli.

Bulgu 4: coeff sadece tid==0 hesaplansa bandwidth tasarrufu olur mu?

RiskAçıklama
🟢 yokcoeff = rms_inv³ · A / Drms_inv ve A her thread'de aynı, coeff her thread'de aynı sonucu üretir. Tüm thread'lerin paralel hesaplaması redundant ama maliyet 1 register, bedava. Sadece tid==0 yapsa shared memory broadcast gerek; daha pahalı.

Hızlı Kontrol Listesi

Test SenaryosuDurum
d_model > WG doğru reduction yapıyor mu?✅ strided loop
d_model < WG (örn d=64) doğru çalışıyor mu?✅ aktif olmayan thread'ler reduce'a 0 katar
rms_inv saved/loaded doğru mu?✅ row başına 1 f32
dw cross-row accumulation race-free mi?✅ CAS atomic add
eps = 0 durumunda NaN var mı?⚠ formal test yok, ama input non-zero garantili
rmsnorm_forward_w16 ile rmsnorm_forward aynı sonuç mu?⚠ ε'a göre minik fark olabilir, gözden geçirilmedi

Sonraki

03_linear_forward.md — modelin en pahalı işlemi: Y = X @ W matmul. Forward variants ve 64×64 tile algoritması.

WGSL kernel etüdleri · WebGPU üzerinde sıfırdan LLMİstanbul’da Uğur Toprakdeviren tarafından hazırlandı.