cross_entropy ve sum_losses — Fused Loss + dLogits
Dosya: 07_loss.wgsl Pipeline adımı: Forward'ın son adımı + backward'ın başlangıcı. lm_head logits'i alıp loss + gradient'i tek geçişte üretir.
İki kernel:
cross_entropy— her row için loss + dLogits in-placesum_losses— per-row losses → scalar total
Nedir bu ya?
Modeli, her adımda "sıradaki kelime ne?" sorusunu çözen bir öğrenci gibi düşün. Ama bu öğrenci tek bir cevap vermiyor — vocab'daki her kelimeye bir bahis koyuyor: "sıradaki kelime %60 ihtimalle 'ev', %15 'araba', %0.001 'muz'…". Bu bahisler toplamı 1 olan bir olasılık dağılımı. İşte softmax'in yaptığı bu: ham logit skorlarını olasılığa çeviriyor.
Loss da şunu ölçen bir sınav puanı: doğru cevap neydi, ve model o doğru cevaba ne kadar olasılık vermişti? Doğru kelimeye %95 dediyse harika, puan (loss) neredeyse sıfır. Ama doğru kelimeye %2 deyip yanında kendinden emin bir şekilde başka bir şeye %90 dediyse — işte o zaman ağır ceza yer. Cross-entropy aslında "model ne kadar şaşırdı?" sorusunun ölçüsü: -log(doğru kelimeye verilen olasılık). Emin ve haklıysan log neredeyse 0; emin ama haksızsan log patlıyor.
"Log-softmax fused" kısmı pratik bir numara. Normalde önce softmax (olasılıkları üret), sonra log al, sonra doğru olanı seç derdin — üç ayrı adım, taşma riski (exp çok büyük sayılarda patlar). Bunun yerine kernel hepsini tek geçişte, sayısal olarak güvenli şekilde (row_max çıkararak) yapıyor. Üstüne, loss'u hesaplarken backward'ın ihtiyacı olan gradient'i de aynı tabloya yazıyor — iki iş bir oturuşta.
Ve burası önemli: bu kernel forward'ın bittiği, backward'ın başladığı yer. "Ne kadar yanlıştık" sayısını ürettiği an, "o hatayı düzeltmek için her logit'i hangi yöne itmeliyiz" sinyalini de çıkarıyor. Geri yayılımın ilk dominosu burada devriliyor.
Ne Yapar?
cross_entropy
Standard cross-entropy + label smoothing (α) + z-loss (β) fused:
softmax: p[v] = exp(l[v] - row_max) / Σ_v' exp(l[v'] - row_max)
ce_loss[s] = -log p[tgt[s]]
= LSE - l[tgt[s]] where LSE = row_max + log Σ exp(l - row_max)
with label smoothing (α):
loss[s] = LSE - (1-α)·l[tgt] - α·mean(l)
with z-loss (β):
loss[s] += β·LSE²
dLogits[s, v] = (p[v] - target_dist[v]) / norm
+ 2β·LSE·p[v] / norm
where target_dist = (1-α)·one_hot + α/V uniformnorm = (valid token count), host'tan inv_norm = 1/norm gönderilir.
sum_losses
Per-row losses array'ini tek scalar'a indiriyor (1 WG ile).
Niye In-place dLogits?
Backward'da dLogits lm_head'in bir önceki step'i — dLogits[s, v] vocab_size boyutunda ayrı buffer olsa seq × vocab × 4B byte memory tutardı. Örnekle:
- seq=512, vocab=16384 → 32 MB ekstra activation memory
- vocab=50K (LLaMA-3 boyutu) → 100 MB
logits zaten [seq × vocab] aynı boyutta. In-place yaz:
- Phase 1-3: logits'i sadece oku (max, denom, target lookup)
- Phase 4: her thread kendi paylaşına düşen
v'lerin logits'ini okur, dLogits'i hesaplar, aynı offset'e yazar
Read-before-write per thread, no cross-thread alias → race-free.
32-100 MB activation memory tasarrufu. Önemli.
Bind Group ABI
cross_entropy (5 binding)
| Binding | Tür | Detay |
|---|---|---|
| 0 | storage, read_write | logits: array<f32> — [seq × vocab] (in-place dLogits) |
| 1 | storage, read | tgts: array<u32> — [seq] target token IDs |
| 2 | storage, read_write | losses: array<f32> — [seq] per-row loss |
| 3 | uniform | dims: vec4<u32> — (seq_len, vocab_size, _, _) |
| 4 | uniform | params: vec4<f32> — (inv_norm, alpha, beta_zl, inv_V) |
sum_losses (3 binding)
| Binding | Tür | Detay |
|---|---|---|
| 0 | storage, read | losses |
| 1 | storage, read_write | out_sum: array<f32> — [1] |
| 2 | uniform | n: u32 |
Dispatch Şekli
cross_entropy
workgroup_size: 256
grid: (seq_len, 1, 1)1 WG = 1 row (1 query position'un loss'u + dLogits'i).
sum_losses
workgroup_size: 256
grid: (1, 1, 1) ← TEK workgroupSingle WG, strided loop ile tüm seq_len losses'i topla.
Satır Satır — cross_entropy
1) Setup + ignore check
let s = wgid.x;
let row_base = s * vocab_size;
let tgt = tgts[s];
if (tgt >= vocab_size) {
if (tid == 0u) { losses[s] = 0.0; }
for (var v = tid; v < vocab_size; v = v + WG) {
logits[row_base + v] = 0.0;
}
return;
}Out-of-range target (örn padding token, vocab dışı) → loss=0 ve dLogits=0. Bu row gradient'a katkı sağlamaz. Padding-aware training için.
2) Phase 1 — Row max
var local_max: f32 = NEG_INF;
for (var v = tid; v < vocab_size; v = v + WG) {
local_max = max(local_max, logits[row_base + v]);
}
let row_max = wg_reduce_max(tid, local_max);Strided loop + workgroup reduce. vocab_size = 16384, WG=256 → her thread 64 element işler.
row_max numerical stability için. exp(l - row_max) her zaman ≤ 1, overflow yok.
3) Phase 2a — Denom (sum of exp)
var local_sum: f32 = 0.0;
for (var v = tid; v < vocab_size; v = v + WG) {
local_sum = local_sum + exp(logits[row_base + v] - row_max);
}
let denom = wg_reduce_sum(tid, local_sum);denom = Σ exp(l - max). Standard log-sum-exp setup.
4) Phase 2b — Sum of logits (label smoothing only)
var sum_l: f32 = 0.0;
if (alpha > 0.0) {
workgroupBarrier();
var local_l: f32 = 0.0;
for (var v = tid; v < vocab_size; v = v + WG) {
local_l = local_l + logits[row_base + v];
}
sum_l = wg_reduce_sum(tid, local_l);
}alpha = 0 ise tamamen atla — seq_len × vocab_size extra read kazanılır. alpha uniform → tüm thread'ler aynı yola gider, branch divergence yok. Barrier safe.
mean_l = sum_l / V label smoothing'in regularization terimi.
5) Phase 3 — Loss
let inv_denom = 1.0 / max(denom, 1e-30);
let lse = row_max + log(max(denom, 1e-30));
let mean_l = sum_l * inv_V;
if (tid == 0u) {
let l_t = logits[row_base + tgt];
let ce = lse - (1.0 - alpha) * l_t - alpha * mean_l;
let zl = beta_zl * lse * lse;
losses[s] = ce + zl;
}lse = log Σ exp(l)— numerically stablece = -log p[tgt]simplified:lse - l[tgt](when α=0)- Label smoothing:
(1-α)·l_t + α·mean_lile mix - Z-loss:
β·lse²LSE'i 0'a doğru regularize, training stability
max(denom, 1e-30) — division-by-zero koruması (extreme case).
Sadece tid == 0 yazar (single writer, no race).
6) Phase 4 — dLogits in-place
let zl_grad_factor = 2.0 * beta_zl * lse;
for (var v = tid; v < vocab_size; v = v + WG) {
let p = exp(logits[row_base + v] - row_max) * inv_denom;
let one_hot = select(0.0, 1.0, v == tgt);
let tgt_dist = (1.0 - alpha) * one_hot + alpha * inv_V;
let gr = (p - tgt_dist) + zl_grad_factor * p;
logits[row_base + v] = gr * inv_norm;
}Her thread'in iç döngüsü:
- Read
logits[row_base + v]→ computep[v](softmax probability) - Compute
target_dist[v](= label-smoothed target) - Compute
gradient = (p - target) + z_loss_term - Write back
logits[row_base + v] = gr * inv_norm
Race-free: thread t v ∈ {t, t+WG, t+2WG, ...} element'lerini sahipleniyor. Hiçbir thread başka thread'in slot'una yazmıyor.
Read-before-write: her thread aynı v için önce reads, sonra writes — sırayla. Aynı v'ye iki thread dokunmuyor. Aliasing race yok.
Satır Satır — sum_losses
@compute @workgroup_size(256, 1, 1)
fn sum_losses(@builtin(local_invocation_index) tid: u32) {
var local_sum: f32 = 0.0;
for (var i = tid; i < n; i = i + WG) {
local_sum = local_sum + losses[i];
}
let total = wg_reduce_sum(tid, local_sum);
if (tid == 0u) { out_sum[0] = total; }
}Tek WG, strided loop, reduce. losses[seq_len] array'ini tek scalar'a topla.
Host bu scalar'ı valid_count'a böler → average loss display'i için.
Loss Components — Detay
Label Smoothing (α)
Hard target [0, 0, ..., 1, ..., 0] yerine soft target [α/V, α/V, ..., 1-α+α/V, ..., α/V].
Etki:
- Confidence calibration — model ekstra confident olmayı öğrenmiyor
- Gradient'a tüm class'lardan biraz katkı geliyor — generalization yardımı
- Pretrain'de α=0 yaygın, fine-tune'da α=0.05-0.1
Z-loss (β)
LSE'in büyümesini cezalandırır. β·LSE² minimize edilince LSE → 0'a yaklaşır, yani logits dağılımı 0 etrafında kompakt kalır.
Etki:
- Training stability (logit overflow nadir)
- "Logit drift"i engeller (uzun training'de logits sürüklenmeli)
PaLM, Switch Transformer kullandı. Bizim default β=1e-4 (çok hafif).
WGSL-Spesifik Notlar
1. Uniform branch + workgroupBarrier OK
if (alpha > 0.0) {
workgroupBarrier();
// ...
}Barrier uniform CF gerektirir. alpha uniform variable → tüm thread'ler aynı branch'e girer, barrier safe. Eğer non-uniform olsaydı deadlock olurdu (bazı thread'ler barrier'da bekler, diğerleri atlar).
2. select(false_val, true_val, cond)
let one_hot = select(0.0, 1.0, v == tgt);Read: "if v == tgt then 1.0 else 0.0".
3. In-place storage aliasing
WebGPU runtime aynı buffer'ı read_write olarak tek binding'de kabul eder. read + read_write overlap'i yasak. Bizim kernel doğru — sadece read_write logits var, ayrı dLogits yok.
Code Review
Bulgu 1: vocab_size = 0 edge case
| Risk | Açıklama |
|---|---|
| 🟢 yok | vocab_size = 0 ise Σ exp = 0, denom = 0, 1/max(denom, 1e-30) = 1e30. Sonra row_max = -∞, log(0+) = -∞ → lse = -∞ + 0 = -∞. Sonuç loss = -∞. Pratik anlamsız ama crash yok. Production'da vocab=0 hiç olmaz. |
Bulgu 2: tid == 0u losses[s] = 0 ignore'da
| Risk | Açıklama |
|---|---|
| 🟢 yok | İdeal. tid==0 single-writer. Diğer thread'ler dLogits sıfırlama yapıyor (paralel). |
Bulgu 3: In-place strategy backward correctness
| Risk | Açıklama |
|---|---|
| 🟢 yok | Backward'da dLogits reading yapılır → matmul kernel'larıyla lm_head weight gradient hesaplanır. dLogits logits'in offsetinde. Doğru semantically. |
Hızlı Kontrol Listesi
| Test Senaryosu | Durum |
|---|---|
tgt >= vocab row ignored mi? | ✅ |
| α=0 fast path? | ✅ branch atlar |
| β=0 z-loss yok? | ✅ multiply by 0 = 0 |
| In-place race condition? | ✅ per-thread ownership, read-before-write |
seq_len = 0 crash? | ✅ workgroup count = 0, no dispatch |
| LSE matches numpy reference? | ⚠ formal test yok |
Sonraki
08_cast.md — Mixed precision için f32 ↔ f16 dönüşüm. Optimizer çıkışında çalışır.