llm.istanbul·Etüt
TR EN
Workbench →

cross_entropy ve sum_losses — Fused Loss + dLogits

Dosya: 07_loss.wgsl Pipeline adımı: Forward'ın son adımı + backward'ın başlangıcı. lm_head logits'i alıp loss + gradient'i tek geçişte üretir.

İki kernel:

  • cross_entropy — her row için loss + dLogits in-place
  • sum_losses — per-row losses → scalar total

Nedir bu ya?

Modeli, her adımda "sıradaki kelime ne?" sorusunu çözen bir öğrenci gibi düşün. Ama bu öğrenci tek bir cevap vermiyor — vocab'daki her kelimeye bir bahis koyuyor: "sıradaki kelime %60 ihtimalle 'ev', %15 'araba', %0.001 'muz'…". Bu bahisler toplamı 1 olan bir olasılık dağılımı. İşte softmax'in yaptığı bu: ham logit skorlarını olasılığa çeviriyor.

Loss da şunu ölçen bir sınav puanı: doğru cevap neydi, ve model o doğru cevaba ne kadar olasılık vermişti? Doğru kelimeye %95 dediyse harika, puan (loss) neredeyse sıfır. Ama doğru kelimeye %2 deyip yanında kendinden emin bir şekilde başka bir şeye %90 dediyse — işte o zaman ağır ceza yer. Cross-entropy aslında "model ne kadar şaşırdı?" sorusunun ölçüsü: -log(doğru kelimeye verilen olasılık). Emin ve haklıysan log neredeyse 0; emin ama haksızsan log patlıyor.

"Log-softmax fused" kısmı pratik bir numara. Normalde önce softmax (olasılıkları üret), sonra log al, sonra doğru olanı seç derdin — üç ayrı adım, taşma riski (exp çok büyük sayılarda patlar). Bunun yerine kernel hepsini tek geçişte, sayısal olarak güvenli şekilde (row_max çıkararak) yapıyor. Üstüne, loss'u hesaplarken backward'ın ihtiyacı olan gradient'i de aynı tabloya yazıyor — iki iş bir oturuşta.

Ve burası önemli: bu kernel forward'ın bittiği, backward'ın başladığı yer. "Ne kadar yanlıştık" sayısını ürettiği an, "o hatayı düzeltmek için her logit'i hangi yöne itmeliyiz" sinyalini de çıkarıyor. Geri yayılımın ilk dominosu burada devriliyor.


Ne Yapar?

cross_entropy

Standard cross-entropy + label smoothing (α) + z-loss (β) fused:

softmax: p[v] = exp(l[v] - row_max) / Σ_v' exp(l[v'] - row_max)
ce_loss[s] = -log p[tgt[s]]
            = LSE - l[tgt[s]]                  where LSE = row_max + log Σ exp(l - row_max)

with label smoothing (α):
    loss[s] = LSE - (1-α)·l[tgt] - α·mean(l)

with z-loss (β):
    loss[s] += β·LSE²

dLogits[s, v] = (p[v] - target_dist[v]) / norm
              + 2β·LSE·p[v] / norm
              where target_dist = (1-α)·one_hot + α/V uniform

norm = (valid token count), host'tan inv_norm = 1/norm gönderilir.

sum_losses

Per-row losses array'ini tek scalar'a indiriyor (1 WG ile).


Niye In-place dLogits?

Backward'da dLogits lm_head'in bir önceki step'i — dLogits[s, v] vocab_size boyutunda ayrı buffer olsa seq × vocab × 4B byte memory tutardı. Örnekle:

  • seq=512, vocab=16384 → 32 MB ekstra activation memory
  • vocab=50K (LLaMA-3 boyutu) → 100 MB

logits zaten [seq × vocab] aynı boyutta. In-place yaz:

  1. Phase 1-3: logits'i sadece oku (max, denom, target lookup)
  2. Phase 4: her thread kendi paylaşına düşen v'lerin logits'ini okur, dLogits'i hesaplar, aynı offset'e yazar

Read-before-write per thread, no cross-thread alias → race-free.

32-100 MB activation memory tasarrufu. Önemli.


Bind Group ABI

cross_entropy (5 binding)

BindingTürDetay
0storage, read_writelogits: array<f32>[seq × vocab] (in-place dLogits)
1storage, readtgts: array<u32>[seq] target token IDs
2storage, read_writelosses: array<f32>[seq] per-row loss
3uniformdims: vec4<u32>(seq_len, vocab_size, _, _)
4uniformparams: vec4<f32>(inv_norm, alpha, beta_zl, inv_V)

sum_losses (3 binding)

BindingTürDetay
0storage, readlosses
1storage, read_writeout_sum: array<f32>[1]
2uniformn: u32

Dispatch Şekli

cross_entropy

workgroup_size: 256
grid:           (seq_len, 1, 1)

1 WG = 1 row (1 query position'un loss'u + dLogits'i).

sum_losses

workgroup_size: 256
grid:           (1, 1, 1)        ← TEK workgroup

Single WG, strided loop ile tüm seq_len losses'i topla.


Satır Satır — cross_entropy

1) Setup + ignore check

wgsl
let s = wgid.x;
let row_base = s * vocab_size;
let tgt = tgts[s];

if (tgt >= vocab_size) {
    if (tid == 0u) { losses[s] = 0.0; }
    for (var v = tid; v < vocab_size; v = v + WG) {
        logits[row_base + v] = 0.0;
    }
    return;
}

Out-of-range target (örn padding token, vocab dışı) → loss=0 ve dLogits=0. Bu row gradient'a katkı sağlamaz. Padding-aware training için.

2) Phase 1 — Row max

wgsl
var local_max: f32 = NEG_INF;
for (var v = tid; v < vocab_size; v = v + WG) {
    local_max = max(local_max, logits[row_base + v]);
}
let row_max = wg_reduce_max(tid, local_max);

Strided loop + workgroup reduce. vocab_size = 16384, WG=256 → her thread 64 element işler.

row_max numerical stability için. exp(l - row_max) her zaman ≤ 1, overflow yok.

3) Phase 2a — Denom (sum of exp)

wgsl
var local_sum: f32 = 0.0;
for (var v = tid; v < vocab_size; v = v + WG) {
    local_sum = local_sum + exp(logits[row_base + v] - row_max);
}
let denom = wg_reduce_sum(tid, local_sum);

denom = Σ exp(l - max). Standard log-sum-exp setup.

4) Phase 2b — Sum of logits (label smoothing only)

wgsl
var sum_l: f32 = 0.0;
if (alpha > 0.0) {
    workgroupBarrier();
    var local_l: f32 = 0.0;
    for (var v = tid; v < vocab_size; v = v + WG) {
        local_l = local_l + logits[row_base + v];
    }
    sum_l = wg_reduce_sum(tid, local_l);
}

alpha = 0 ise tamamen atla — seq_len × vocab_size extra read kazanılır. alpha uniform → tüm thread'ler aynı yola gider, branch divergence yok. Barrier safe.

mean_l = sum_l / V label smoothing'in regularization terimi.

5) Phase 3 — Loss

wgsl
let inv_denom = 1.0 / max(denom, 1e-30);
let lse       = row_max + log(max(denom, 1e-30));
let mean_l    = sum_l * inv_V;

if (tid == 0u) {
    let l_t  = logits[row_base + tgt];
    let ce   = lse - (1.0 - alpha) * l_t - alpha * mean_l;
    let zl   = beta_zl * lse * lse;
    losses[s] = ce + zl;
}
  • lse = log Σ exp(l) — numerically stable
  • ce = -log p[tgt] simplified: lse - l[tgt] (when α=0)
  • Label smoothing: (1-α)·l_t + α·mean_l ile mix
  • Z-loss: β·lse² LSE'i 0'a doğru regularize, training stability

max(denom, 1e-30) — division-by-zero koruması (extreme case).

Sadece tid == 0 yazar (single writer, no race).

6) Phase 4 — dLogits in-place

wgsl
let zl_grad_factor = 2.0 * beta_zl * lse;
for (var v = tid; v < vocab_size; v = v + WG) {
    let p = exp(logits[row_base + v] - row_max) * inv_denom;
    let one_hot = select(0.0, 1.0, v == tgt);
    let tgt_dist = (1.0 - alpha) * one_hot + alpha * inv_V;
    let gr       = (p - tgt_dist) + zl_grad_factor * p;
    logits[row_base + v] = gr * inv_norm;
}

Her thread'in iç döngüsü:

  1. Read logits[row_base + v] → compute p[v] (softmax probability)
  2. Compute target_dist[v] (= label-smoothed target)
  3. Compute gradient = (p - target) + z_loss_term
  4. Write back logits[row_base + v] = gr * inv_norm

Race-free: thread t v ∈ {t, t+WG, t+2WG, ...} element'lerini sahipleniyor. Hiçbir thread başka thread'in slot'una yazmıyor.

Read-before-write: her thread aynı v için önce reads, sonra writes — sırayla. Aynı v'ye iki thread dokunmuyor. Aliasing race yok.


Satır Satır — sum_losses

wgsl
@compute @workgroup_size(256, 1, 1)
fn sum_losses(@builtin(local_invocation_index) tid: u32) {
    var local_sum: f32 = 0.0;
    for (var i = tid; i < n; i = i + WG) {
        local_sum = local_sum + losses[i];
    }
    let total = wg_reduce_sum(tid, local_sum);
    if (tid == 0u) { out_sum[0] = total; }
}

Tek WG, strided loop, reduce. losses[seq_len] array'ini tek scalar'a topla.

Host bu scalar'ı valid_count'a böler → average loss display'i için.


Loss Components — Detay

Label Smoothing (α)

Hard target [0, 0, ..., 1, ..., 0] yerine soft target [α/V, α/V, ..., 1-α+α/V, ..., α/V].

Etki:

  • Confidence calibration — model ekstra confident olmayı öğrenmiyor
  • Gradient'a tüm class'lardan biraz katkı geliyor — generalization yardımı
  • Pretrain'de α=0 yaygın, fine-tune'da α=0.05-0.1

Z-loss (β)

LSE'in büyümesini cezalandırır. β·LSE² minimize edilince LSE → 0'a yaklaşır, yani logits dağılımı 0 etrafında kompakt kalır.

Etki:

  • Training stability (logit overflow nadir)
  • "Logit drift"i engeller (uzun training'de logits sürüklenmeli)

PaLM, Switch Transformer kullandı. Bizim default β=1e-4 (çok hafif).


WGSL-Spesifik Notlar

1. Uniform branch + workgroupBarrier OK

wgsl
if (alpha > 0.0) {
    workgroupBarrier();
    // ...
}

Barrier uniform CF gerektirir. alpha uniform variable → tüm thread'ler aynı branch'e girer, barrier safe. Eğer non-uniform olsaydı deadlock olurdu (bazı thread'ler barrier'da bekler, diğerleri atlar).

2. select(false_val, true_val, cond)

wgsl
let one_hot = select(0.0, 1.0, v == tgt);

Read: "if v == tgt then 1.0 else 0.0".

3. In-place storage aliasing

WebGPU runtime aynı buffer'ı read_write olarak tek binding'de kabul eder. read + read_write overlap'i yasak. Bizim kernel doğru — sadece read_write logits var, ayrı dLogits yok.


Code Review

Bulgu 1: vocab_size = 0 edge case

RiskAçıklama
🟢 yokvocab_size = 0 ise Σ exp = 0, denom = 0, 1/max(denom, 1e-30) = 1e30. Sonra row_max = -∞, log(0+) = -∞lse = -∞ + 0 = -∞. Sonuç loss = -∞. Pratik anlamsız ama crash yok. Production'da vocab=0 hiç olmaz.

Bulgu 2: tid == 0u losses[s] = 0 ignore'da

RiskAçıklama
🟢 yokİdeal. tid==0 single-writer. Diğer thread'ler dLogits sıfırlama yapıyor (paralel).

Bulgu 3: In-place strategy backward correctness

RiskAçıklama
🟢 yokBackward'da dLogits reading yapılır → matmul kernel'larıyla lm_head weight gradient hesaplanır. dLogits logits'in offsetinde. Doğru semantically.

Hızlı Kontrol Listesi

Test SenaryosuDurum
tgt >= vocab row ignored mi?
α=0 fast path?✅ branch atlar
β=0 z-loss yok?✅ multiply by 0 = 0
In-place race condition?✅ per-thread ownership, read-before-write
seq_len = 0 crash?✅ workgroup count = 0, no dispatch
LSE matches numpy reference?⚠ formal test yok

Sonraki

08_cast.md — Mixed precision için f32 ↔ f16 dönüşüm. Optimizer çıkışında çalışır.

WGSL kernel etüdleri · WebGPU üzerinde sıfırdan LLMİstanbul’da Uğur Toprakdeviren tarafından hazırlandı.