rmsnorm — Root-Mean-Square Normalization
Dosya: 02_norm.wgsl Pipeline adımı: Forward'da her layer'da 2 kez (input + post-attention residual'dan önce ve FFN'den önce), backward'da iki kez.
3 kernel: rmsnorm_forward, rmsnorm_backward, rmsnorm_forward_w16 (mixed precision varyantı).
Nedir bu ya?
Bir müzik çalma listesi düşün: bir şarkı kulağını patlatıyor, sonraki neredeyse duyulmuyor. Sinir bozucu, değil mi? Bu yüzden çalarlar "ses eşitleme" (normalize) yapar — her parçanın genel ses seviyesini aynı hizaya çeker. Ama melodiyi, ritmi, enstrümanların birbirine oranını bozmaz; sadece toplam yüksekliği standart bir seviyeye ölçekler.
RMSNorm da modelin içinde dolaşan vektörlere tam bunu yapıyor. Bir layer'dan çıkan sayı vektörünün "yüksekliği" (büyüklüğü) bazen fırlar, bazen sönükleşir — ve bu dengesizlik eğitimi zorlaştırır. RMSNorm her vektörü alıp kendi RMS'ine (root-mean-square, yani değerlerin karelerinin ortalamasının karekökü — istatistikteki standart sapmanın akrabası) bölüyor. Sonuç: vektörün yönü/şekli aynı kalır, ama büyüklüğü tutarlı bir ölçeğe oturur.
Tek ek dokunuş: bölümden sonra bir w vektörüyle çarpıyoruz. Bu w öğrenilen bir kazanç ayarı (gain knob) — model "şu boyutu biraz daha yüksek tut, şunu kıs" diyebiliyor. Yani ham normalize değil, "normalize et, sonra öğrendiğin oranlara göre tekrar kur".
Belki LayerNorm'u duymuşsundur — o da benzer iş yapar ama önce vektörün ortalamasını çıkarır (sinyali sıfır etrafında ortalar). RMSNorm o adımı atlıyor: ortalama çıkarma yok, sadece büyüklüğe göre bölme. Daha az işlem, daha az parametre, ve pratikte (LLaMA, Mistral, Falcon) aynı işi görüyor. Asıl ilginç kısım ise bunu GPU'da bir vektörün tüm elemanlarını tek tek (seri olarak) toplamadan (reduction) nasıl hızlı yaptığın — aşağıda oraya geliyoruz.
Ne Yapar?
Vektörü kendi RMS'ine bölüp learnable scalar w ile çarpar. LayerNorm'un mean-substract'sız basit versiyonu — LLaMA'nın standardı.
rms_i = sqrt((1/d) · Σ_k x_i[k]² + ε)
y_i[d] = x_i[d] / rms_i · w[d]x_i row vector (token bazlı), her row bağımsız işleniyor. Bir workgroup = bir row. WG=256 thread sıra paylaşıp reduction yapar.
LayerNorm'dan farkı:
- Mean substract yok (faster, daha az ALU)
- Bias yok (sadece scale γ, shift β yok)
LayerNorm:
y = (x - μ) / σ · γ + βRMSNorm:
y = x / RMS · wLLaMA, Mistral, Falcon hepsi RMSNorm kullanıyor. Pretrain'de tutarlı, daha az parametre.
Matematiksel Tanım — Backward
L = loss. Verilen dy = ∂L/∂y, hesaplanacak:
dx = ∂L/∂x— bir önceki layer'a iletmek içindw += ∂L/∂w— weight güncellemesi için (cross-row accumulation)
Türev:
y[d] = x[d] · rms_inv · w[d]
where rms_inv = (mean(x²) + ε)^(-½)Chain rule:
∂y[d]/∂x[d'] = (δ[d=d'] · rms_inv + x[d] · ∂rms_inv/∂x[d']) · w[d]ve ∂rms_inv/∂x[d'] = -rms_inv³ · x[d'] / D (D = d_model)
Sonuç (bir tek row için):
A = Σ_k w[k] · x[k] · dy[k] ← scalar, workgroup reduction
dx[d] = w[d] · dy[d] · rms_inv
- x[d] · rms_inv³ · A / D ← coupling term
dw[d] += x[d] · rms_inv · dy[d] ← cross-row, atomic addA her row için hesaplanır. dw bütün row'ların toplamı (atomic accumulation).
Bind Group ABI
rmsnorm_forward (6 binding)
| Binding | Tür | Detay |
|---|---|---|
| 0 | storage, read | x: array<f32> — [seq × d_model] input |
| 1 | storage, read | w: array<f32> — [d_model] learnable scale |
| 2 | storage, read_write | y: array<f32> — [seq × d_model] output |
| 3 | storage, read_write | rms_out: array<f32> — [seq_len] saved for backward |
| 4 | uniform | dims: vec4<u32> — (seq_len, d_model, _, _) |
| 5 | uniform | params: vec4<f32> — (eps, _, _, _) |
rms_outneden saklanıyor? Backward'da yeniden hesaplamamak için. Forward sırasında nasılsa toplam alındı; sonucu sakla, backward'da reuse et. ~%30 backward speedup.
rmsnorm_backward (7 binding)
| Binding | Tür | Detay |
|---|---|---|
| 0 | storage, read | x |
| 1 | storage, read | w |
| 2 | storage, read | dy — upstream gradient |
| 3 | storage, read | rms_in — saved rms_inv |
| 4 | storage, read_write | dx — gradient to next layer back |
| 5 | storage, read_write | dw: array<atomic<u32>> — bit-cast f32, scatter-add |
| 6 | uniform | dims |
rmsnorm_forward_w16 (6 binding)
Aynı; sadece w_w16: array<f16> ve hesaplamada f32(w_w16[i]) cast.
Dispatch Şekli
workgroup_size: 256
total threads: seq_len workgroups × 256 threadsBir workgroup = bir row. Her WG kendi row'unun reduction'ını yapıyor; row'lar bağımsız (independent parallel).
Örnek (seq=512, d=768): 512 WG × 256 thread.
Niye row başına bir WG? Reduction (sum of squares) workgroup-içi. Cross-row reduction gerekmediği için row'lar kolayca paralelleşir. Eğer
d_model > WG(örn d=2048, WG=256) → her threadd_model/WG = 8element işler, dağıtılmış reduction.
Satır Satır — rmsnorm_forward
1) Setup
fn rmsnorm_forward(@builtin(workgroup_id) wgid: vec3<u32>,
@builtin(local_invocation_index) tid: u32) {
let row = wgid.x;
let seq_len = dims.x;
let d_model = dims.y;
let eps = params.x;
if (row >= seq_len) { return; }
let base = row * d_model;workgroup_id(Metal'inthreadgroup_position_in_gridanaloğu) — her WG'nin grid içindeki ID'si.local_invocation_index(Metal'inthread_index_in_threadgroup) — WG içindeki thread ID,0..255.base— bu row'un x/y array'lerinde başlangıç offset'i.
2) Phase 1 — Sum of squares
var local_sum = 0.0;
for (var i = tid; i < d_model; i = i + WG) {
let v = x[base + i];
local_sum = local_sum + v * v;
}
let total = wg_reduce_sum(tid, local_sum);Strided loop:
- Thread
tidkendi paylaşına düşen elementleri toplar:x[base + tid],x[base + tid + 256],x[base + tid + 512], ... - Eğer
d_model = 768, her thread 3 element işler (0/256/512, 1/257/513, ..., 255/511/767) - Eğer
d_model = 256, her thread 1 element - Eğer
d_model < 256, son birkaç thread loop'a hiç girmez (loop conditiontid < d_modelçoğu thread'de false)
local_sum thread-private. wg_reduce_sum(tid, local_sum) → tüm 256 thread'in toplamı her thread'e döner (subgroup reduction'la, 00_shared.md'da detay).
3) Phase 2 — Normalize and scale
let rms_inv = inverseSqrt(total / f32(d_model) + eps);
if (tid == 0u) { rms_out[row] = rms_inv; }
for (var i = tid; i < d_model; i = i + WG) {
y[base + i] = x[base + i] * rms_inv * w[i];
}inverseSqrt(...)→1/sqrt(...). WGSL built-in, hardware fast-path (rsqrtinstruction).tid == 0uiserms_invrow'un saved-state buffer'ına yaz. Diğer thread'ler bunu yapmaz (single writer, race yok).- Tekrar strided loop:
y[base + i] = x[base + i] * rms_inv * w[i].
Niye
rms_invher thread'de aynı? Çünküwg_reduce_sumtüm thread'lere aynı total döndürüyor. SonrainverseSqrt(same total + eps)→ aynırms_invher thread'de. Race yok, deterministic.
Satır Satır — rmsnorm_backward
1) Setup + reduce A
let row = wgid.x;
// ... (aynı)
let rms_inv = rms_in[row]; // load saved value
var local_a = 0.0;
for (var i = tid; i < d_model; i = i + WG) {
local_a = local_a + w[i] * x[base + i] * dy[base + i];
}
let A = wg_reduce_sum(tid, local_a);A = Σ_k w[k]·x[k]·dy[k] workgroup-wide reduction.
2) coeff precompute
let coeff = rms_inv * rms_inv * rms_inv * A / f32(d_model);rms_inv³ · A / D — backward formula'sının coupling term'ünün katsayısı. Her thread tarafından hesaplanır (her şey scalar uniform).
3) Per-element dx and atomic dw accumulation
for (var i = tid; i < d_model; i = i + WG) {
let xi = x[base + i];
let dyi = dy[base + i];
let wi = w[i];
dx[base + i] = wi * dyi * rms_inv - xi * coeff;
let dw_val = xi * rms_inv * dyi;
if (is_finite(dw_val) && dw_val != 0.0) {
var old_bits = atomicLoad(&dw[i]);
loop {
let new_bits = bitcast<u32>(bitcast<f32>(old_bits) + dw_val);
let res = atomicCompareExchangeWeak(&dw[i], old_bits, new_bits);
if (res.exchanged) { break; }
old_bits = res.old_value;
}
}
}Per-thread iş:
- Strided loop, her thread kendi
i'sini işler dx[base + i]— her slot tek thread tarafından dokunuluyor (race yok, atomic gerekmez)dw[i]— cross-row aynı slot! Row 0'ın thread'idw[i]'a yazıyorsa, row 1'in thread'i de yazıyor. Bu yüzden CAS atomic add gerek.
is_finite + != 0.0 kontrolü: Sıfır gradient zaten add'in identity'si — atomic skip et, contention azalt. NaN/Inf gradient'i tamamen drop et.
Mixed-Precision Variant
rmsnorm_forward_w16 farkı:
@group(0) @binding(1) var<storage, read> w_w16: array<f16>;
// ...
y_w16[base + i] = x_w16[base + i] * rms_inv * f32(w_w16[i]);w array'i f16 storage; cast-load. Hesap fp32. Output yine fp32.
Niye
x_w16da f32? Çünküxaslında bir önceki layer'ın output'u; biz layer-level f16 yapmadık (çok agresif olur). f16 sadece weight tarafında — embedding, norm w, matmul W. Activation hep fp32.
WGSL-Spesifik Notlar
1. inverseSqrt built-in
Metal'in rsqrt analoğu. Hardware rsqrt instruction'a optimize olur (Apple GPU'da single cycle).
2. Atomic CAS pattern aynı (bkz 01_embedding.md)
3. wg_reduce_sum her thread'e aynı total dönderiyor
Bu önemli — rms_inv = inverseSqrt(total / D + eps) her thread'de aynı, sonra her thread strided loop'ta aynı katsayıyla normalize ediyor. Determinism garantili.
4. var local_a = 0.0 thread-private
Her thread kendi accumulator'ı tutar. Workgroup memory'de değil, register'da. WG=256 → 256 ayrı register accumulator. ALU bandwidth'i yeter.
Code Review
Bulgu 1: Workgroup memory kullanmıyor
| Risk | Açıklama |
|---|---|
| 🟢 yok | Reduction subgroup-based, sh_red 1KB sadece. Diğer WG memory yok. Doğru. |
Bulgu 2: eps her thread'de aynı f32 cast
| Risk | Açıklama |
|---|---|
| 🟢 yok | params.x uniform, GPU constant cache'te. Her thread aynı değer okur, broadcast. Marjinal. |
Bulgu 3: dy * 0.0 skip atomik için doğru ama...
| Risk | Açıklama |
|---|---|
| 🟡 minor | Eğer dy[i] = 0.0 ama x[i]·rms_inv ≠ 0 ise yine de dw_val = 0. Skip doğru. Ama dy[i] != 0 && x[i] = 0 durumunda da skip yapacak — bu da doğru. NaN için filter güvenli. |
Bulgu 4: coeff sadece tid==0 hesaplansa bandwidth tasarrufu olur mu?
| Risk | Açıklama |
|---|---|
| 🟢 yok | coeff = rms_inv³ · A / D — rms_inv ve A her thread'de aynı, coeff her thread'de aynı sonucu üretir. Tüm thread'lerin paralel hesaplaması redundant ama maliyet 1 register, bedava. Sadece tid==0 yapsa shared memory broadcast gerek; daha pahalı. |
Hızlı Kontrol Listesi
| Test Senaryosu | Durum |
|---|---|
d_model > WG doğru reduction yapıyor mu? | ✅ strided loop |
d_model < WG (örn d=64) doğru çalışıyor mu? | ✅ aktif olmayan thread'ler reduce'a 0 katar |
rms_inv saved/loaded doğru mu? | ✅ row başına 1 f32 |
dw cross-row accumulation race-free mi? | ✅ CAS atomic add |
eps = 0 durumunda NaN var mı? | ⚠ formal test yok, ama input non-zero garantili |
rmsnorm_forward_w16 ile rmsnorm_forward aynı sonuç mu? | ⚠ ε'a göre minik fark olabilir, gözden geçirilmedi |
Sonraki
03_linear_forward.md — modelin en pahalı işlemi: Y = X @ W matmul. Forward variants ve 64×64 tile algoritması.