llm.istanbul·Study

cross_entropy and sum_losses — Fused Loss + dLogits

File: 07_loss.wgsl
Pipeline step: The final step of forward + the start of backward. Takes the lm_head logits and produces loss + gradient in a single pass.

Two kernels:

  • cross_entropy — loss + dLogits in-place for each row
  • sum_losses — per-row losses → scalar total

Wait, what is this?

Picture the model as a student answering "what's the next word?" at every step. But this student doesn't give a single answer — it places a bet on every word in the vocab: "the next word is 60% 'house', 15% 'car', 0.001% 'banana'…". Those bets are a probability distribution that sums to 1. That's exactly what softmax does: it turns raw logit scores into probabilities.

Loss is the exam score that asks: what was the right answer, and how much probability did the model put on it? If it said 95% on the correct word, great, the score (loss) is near zero. But if it gave the right word a measly 2% while confidently betting 90% on something else — that's when it gets hit with a heavy penalty. Cross-entropy is really a measure of "how surprised was the model?": -log(probability assigned to the correct word). Confident and right, and the log is near 0; confident but wrong, and the log blows up.

The "log-softmax fused" part is a practical trick. Normally you'd run softmax first (produce the probabilities), then take the log, then pick the right one — three separate steps, with overflow risk (exp blows up on large numbers). Instead the kernel does all of it in one pass, in a numerically safe way (by subtracting row_max). On top of that, while computing the loss it also writes the gradient that backward needs straight into the same table — two jobs in one sitting.
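A small CPU-side sketch in Python (not part of the WGSL kernel; the function names are made up for illustration) shows why subtracting the row maximum matters. The naive three-step version overflows on large logits, while the shifted version gives the same answer as on the small row.

```python
import math

def naive_ce(logits, tgt):
    # softmax -> log -> pick: exp() overflows for large logits
    exps = [math.exp(l) for l in logits]
    return -math.log(exps[tgt] / sum(exps))

def stable_ce(logits, tgt):
    # Shift by the row max so every exponent is <= 0, then use LSE - l[tgt]
    row_max = max(logits)
    denom = sum(math.exp(l - row_max) for l in logits)
    lse = row_max + math.log(denom)
    return lse - logits[tgt]

small = [2.0, 1.0, 0.1]
large = [1000.0, 999.0, 998.1]  # the same row shifted by +998

try:
    naive_large = naive_ce(large, 0)
except OverflowError:
    naive_large = None  # math.exp(1000.0) does not fit in a float

print(naive_ce(small, 0), stable_ce(small, 0), stable_ce(large, 0), naive_large)
```

Softmax is invariant to adding a constant to every logit, so the shifted row must produce the same loss as the small one; only the stable version manages to compute it.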

And here's the key bit: this kernel is where forward ends and backward begins. The moment it produces the "how wrong were we" number, it also emits the "which direction should we push each logit to fix that mistake" signal. The first domino of backprop tips over right here.


What Does It Do?

cross_entropy

Standard cross-entropy + label smoothing (α) + z-loss (β) fused:

softmax: p[v] = exp(l[v] - row_max) / Σ_v' exp(l[v'] - row_max)
ce_loss[s] = -log p[tgt[s]]
            = LSE - l[tgt[s]]                  where LSE = row_max + log Σ exp(l - row_max)

with label smoothing (α):
    loss[s] = LSE - (1-α)·l[tgt] - α·mean(l)

with z-loss (β):
    loss[s] += β·LSE²

dLogits[s, v] = (p[v] - target_dist[v]) / norm
              + 2β·LSE·p[v] / norm
              where target_dist = (1-α)·one_hot + α/V uniform

norm = (valid token count), the host sends inv_norm = 1/norm.
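The gradient formula follows from the loss in a few lines. The derivation below is a sketch in LaTeX; the symbols t (for tgt[s]) and q (for target_dist) are shorthand introduced here, not names from the kernel.

```latex
\begin{aligned}
q_v &= (1-\alpha)\,\mathbb{1}[v = t] + \frac{\alpha}{V}, \qquad \sum_v q_v = 1 \\
\log p_v &= l_v - \mathrm{LSE}, \qquad \frac{\partial\,\mathrm{LSE}}{\partial l_v} = p_v \\
\mathrm{loss} &= -\sum_v q_v \log p_v + \beta\,\mathrm{LSE}^2
  = \mathrm{LSE} - (1-\alpha)\,l_t - \frac{\alpha}{V}\sum_v l_v + \beta\,\mathrm{LSE}^2 \\
\frac{\partial\,\mathrm{loss}}{\partial l_v} &= p_v - q_v + 2\beta\,\mathrm{LSE}\,p_v
\end{aligned}
```

Multiplying the last line by inv_norm gives dLogits[s, v] as written above.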

sum_losses

Reduces the per-row losses array to a single scalar (with 1 WG).


Why In-place dLogits?

In backward, dLogits is the input to the lm_head backward step. If dLogits[s, v] were a separate [seq × vocab] buffer, it would take seq × vocab × 4 bytes of extra memory. For example:

  • seq=512, vocab=16384 → 32 MB of extra activation memory
  • seq=512, vocab=50K (GPT-2 size) → ~100 MB
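The two figures can be reproduced with a few lines of Python. This is only arithmetic; the helper name is hypothetical, and seq=512 is assumed for the 50K-vocab case as well.

```python
def dlogits_bytes(seq_len, vocab_size, bytes_per_elem=4):
    # One f32 (4 bytes) per (position, vocab entry) pair
    return seq_len * vocab_size * bytes_per_elem

small = dlogits_bytes(512, 16384)
large = dlogits_bytes(512, 50_000)

print(small / 2**20)  # size in MiB
print(large / 10**6)  # size in MB
```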

logits is already [seq × vocab], the same size. Write in-place:

  1. Phase 1-3: only read logits (max, denom, target lookup)
  2. Phase 4: each thread reads the logits at the indices v assigned to it, computes dLogits, and writes to the same offset

Read-before-write per thread, no cross-thread alias → race-free.

32-100 MB of activation memory saved. Significant.


Bind Group ABI

cross_entropy (5 bindings)

Binding | Type | Detail
0 | storage, read_write | logits: array<f32>[seq × vocab] (in-place dLogits)
1 | storage, read | tgts: array<u32>[seq] target token IDs
2 | storage, read_write | losses: array<f32>[seq] per-row loss
3 | uniform | dims: vec4<u32>(seq_len, vocab_size, _, _)
4 | uniform | params: vec4<f32>(inv_norm, alpha, beta_zl, inv_V)

sum_losses (3 bindings)

Binding | Type | Detail
0 | storage, read | losses
1 | storage, read_write | out_sum: array<f32>[1]
2 | uniform | n: u32

Dispatch Shape

cross_entropy

workgroup_size: 256
grid:           (seq_len, 1, 1)

1 WG = 1 row (the loss + dLogits of 1 query position).

sum_losses

workgroup_size: 256
grid:           (1, 1, 1)        ← SINGLE workgroup

Single WG, sum all seq_len losses with a strided loop.


Line by Line — cross_entropy

1) Setup + ignore check

wgsl
let s = wgid.x;
let row_base = s * vocab_size;
let tgt = tgts[s];

if (tgt >= vocab_size) {
    if (tid == 0u) { losses[s] = 0.0; }
    for (var v = tid; v < vocab_size; v = v + WG) {
        logits[row_base + v] = 0.0;
    }
    return;
}

An out-of-range target (e.g. a padding token whose ID lies outside the vocab) yields loss=0 and dLogits=0, so the row contributes nothing to the gradient. This is how padding-aware training is supported.

2) Phase 1 — Row max

wgsl
var local_max: f32 = NEG_INF;
for (var v = tid; v < vocab_size; v = v + WG) {
    local_max = max(local_max, logits[row_base + v]);
}
let row_max = wg_reduce_max(tid, local_max);

Strided loop + workgroup reduce. vocab_size = 16384, WG=256 → each thread processes 64 elements.

row_max is for numerical stability. exp(l - row_max) is always ≤ 1, no overflow.
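The strided loop and the workgroup reduce can be simulated on the CPU. The Python sketch below is an illustration only: wg_reduce_max is not shown in this study, so the tree reduction here is an assumed shape for it (it needs a power-of-two lane count), and lane_indices is a hypothetical helper.

```python
WG = 256  # workgroup size, as in the dispatch shape

def lane_indices(tid, n, wg=WG):
    # Indices visited by lane `tid` in the strided loop: tid, tid+WG, ...
    return list(range(tid, n, wg))

def tree_reduce_max(partials):
    # Assumed shape of wg_reduce_max: halve the active range each step.
    vals = list(partials)
    stride = len(vals) // 2
    while stride > 0:
        for i in range(stride):
            vals[i] = max(vals[i], vals[i + stride])
        stride //= 2
    return vals[0]

def row_max(row, wg=WG):
    # Each lane reduces its own strided slice; idle lanes keep -inf
    partials = [
        max((row[v] for v in lane_indices(t, len(row), wg)),
            default=float("-inf"))
        for t in range(wg)
    ]
    return tree_reduce_max(partials)

row = [((i * 37) % 1000) / 10.0 for i in range(16384)]
print(len(lane_indices(0, len(row))), row_max(row) == max(row))
```

With vocab_size = 16384 every lane visits 64 indices, and the lanes together cover each index exactly once.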

3) Phase 2a — Denom (sum of exp)

wgsl
var local_sum: f32 = 0.0;
for (var v = tid; v < vocab_size; v = v + WG) {
    local_sum = local_sum + exp(logits[row_base + v] - row_max);
}
let denom = wg_reduce_sum(tid, local_sum);

denom = Σ exp(l - max). Standard log-sum-exp setup.

4) Phase 2b — Sum of logits (label smoothing only)

wgsl
var sum_l: f32 = 0.0;
if (alpha > 0.0) {
    workgroupBarrier();
    var local_l: f32 = 0.0;
    for (var v = tid; v < vocab_size; v = v + WG) {
        local_l = local_l + logits[row_base + v];
    }
    sum_l = wg_reduce_sum(tid, local_l);
}

If alpha = 0 skip entirely — saving seq_len × vocab_size extra reads. alpha is uniform → all threads take the same path, no branch divergence. Barrier safe.

mean_l = sum_l / V is the regularization term of label smoothing.

5) Phase 3 — Loss

wgsl
let inv_denom = 1.0 / max(denom, 1e-30);
let lse       = row_max + log(max(denom, 1e-30));
let mean_l    = sum_l * inv_V;

if (tid == 0u) {
    let l_t  = logits[row_base + tgt];
    let ce   = lse - (1.0 - alpha) * l_t - alpha * mean_l;
    let zl   = beta_zl * lse * lse;
    losses[s] = ce + zl;
}

  • lse = log Σ exp(l), computed in a numerically stable way
  • ce = -log p[tgt] simplified: lse - l[tgt] (when α=0)
  • Label smoothing: mix with (1-α)·l_t + α·mean_l
  • Z-loss: β·lse² regularizes LSE toward 0, for training stability

max(denom, 1e-30) — division-by-zero protection (extreme case).
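To check that the closed form in Phase 3 really equals cross-entropy against the smoothed target, a CPU reference can compute both. This is a Python sketch under the definitions above; the function names are hypothetical.

```python
import math

def loss_direct(logits, tgt, alpha, beta):
    # -sum_v q[v] * log p[v] + beta * LSE^2, written out term by term
    V = len(logits)
    row_max = max(logits)
    lse = row_max + math.log(sum(math.exp(l - row_max) for l in logits))
    ce = 0.0
    for v, l in enumerate(logits):
        q = (1.0 - alpha) * (1.0 if v == tgt else 0.0) + alpha / V
        ce -= q * (l - lse)  # log p[v] = l[v] - LSE
    return ce + beta * lse * lse

def loss_closed_form(logits, tgt, alpha, beta):
    # Mirrors the kernel: lse - (1-alpha)*l_t - alpha*mean_l + beta*lse^2
    V = len(logits)
    row_max = max(logits)
    denom = sum(math.exp(l - row_max) for l in logits)
    lse = row_max + math.log(max(denom, 1e-30))
    mean_l = sum(logits) / V
    ce = lse - (1.0 - alpha) * logits[tgt] - alpha * mean_l
    return ce + beta * lse * lse

row = [2.0, -1.0, 0.5, 3.0]
print(loss_direct(row, 3, 0.1, 1e-4), loss_closed_form(row, 3, 0.1, 1e-4))
```

The two values agree up to floating-point rounding; the closed form simply avoids materializing p[v] for every v.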

Only tid == 0 writes (single writer, no race).

6) Phase 4 — dLogits in-place

wgsl
let zl_grad_factor = 2.0 * beta_zl * lse;
for (var v = tid; v < vocab_size; v = v + WG) {
    let p = exp(logits[row_base + v] - row_max) * inv_denom;
    let one_hot = select(0.0, 1.0, v == tgt);
    let tgt_dist = (1.0 - alpha) * one_hot + alpha * inv_V;
    let gr       = (p - tgt_dist) + zl_grad_factor * p;
    logits[row_base + v] = gr * inv_norm;
}

The inner loop of each thread:

  1. Read logits[row_base + v] → compute p[v] (softmax probability)
  2. Compute target_dist[v] (= label-smoothed target)
  3. Compute gradient = (p - target) + z_loss_term
  4. Write back logits[row_base + v] = gr * inv_norm

Race-free: thread t owns the elements v ∈ {t, t+WG, t+2WG, ...}. No thread writes to another thread's slot.

Read-before-write: each thread first reads, then writes for the same v — in order. No two threads touch the same v. No aliasing race.
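A CPU reference for Phase 4 is useful for testing. The Python sketch below overwrites the row in place, lane by lane, and checks the result against a finite-difference gradient of the loss. It is an illustration, not the kernel: the lanes run sequentially here, wg=4 is an arbitrary small lane count, and the function names are hypothetical.

```python
import math

def row_loss(logits, tgt, alpha, beta):
    V = len(logits)
    m = max(logits)
    lse = m + math.log(sum(math.exp(l - m) for l in logits))
    return (lse - (1.0 - alpha) * logits[tgt]
            - alpha * sum(logits) / V + beta * lse * lse)

def dlogits_in_place(logits, tgt, alpha, beta, inv_norm, wg=4):
    V = len(logits)
    # Phases 1-3 only read the row
    row_max = max(logits)
    denom = sum(math.exp(l - row_max) for l in logits)
    inv_denom = 1.0 / max(denom, 1e-30)
    lse = row_max + math.log(max(denom, 1e-30))
    zl_grad_factor = 2.0 * beta * lse
    # Phase 4: lane `tid` owns v = tid, tid+wg, ... and reads before it writes
    for tid in range(wg):
        for v in range(tid, V, wg):
            p = math.exp(logits[v] - row_max) * inv_denom
            one_hot = 1.0 if v == tgt else 0.0
            tgt_dist = (1.0 - alpha) * one_hot + alpha / V
            logits[v] = ((p - tgt_dist) + zl_grad_factor * p) * inv_norm
    return logits

row = [2.0, -1.0, 0.5, 3.0, 0.0, 1.5]
grad = dlogits_in_place(list(row), tgt=3, alpha=0.1, beta=1e-4, inv_norm=0.5)

# Central finite difference of the loss, scaled by the same inv_norm
eps = 1e-5
fd = []
for v in range(len(row)):
    hi = list(row); hi[v] += eps
    lo = list(row); lo[v] -= eps
    fd.append((row_loss(hi, 3, 0.1, 1e-4) - row_loss(lo, 3, 0.1, 1e-4))
              / (2 * eps) * 0.5)
print(max(abs(a - b) for a, b in zip(grad, fd)))
```

Because row_max, inv_denom and lse are fixed before the first write, overwriting logits[v] cannot change the value any other lane computes.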


Line by Line — sum_losses

wgsl
@compute @workgroup_size(256, 1, 1)
fn sum_losses(@builtin(local_invocation_index) tid: u32) {
    var local_sum: f32 = 0.0;
    for (var i = tid; i < n; i = i + WG) {
        local_sum = local_sum + losses[i];
    }
    let total = wg_reduce_sum(tid, local_sum);
    if (tid == 0u) { out_sum[0] = total; }
}

Single WG, strided loop, reduce. Sum the losses[seq_len] array into a single scalar.

The host divides this scalar by valid_count to display the average loss.
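The reduction and the host-side average can be sketched in Python. How the host actually obtains valid_count is not shown in this study; counting targets below vocab_size is an assumption made here, and the built-in sum stands in for wg_reduce_sum.

```python
WG = 256

def sum_losses(losses, wg=WG):
    # Each lane accumulates its strided slice, then the partials are reduced
    partials = [0.0] * wg
    for tid in range(wg):
        for i in range(tid, len(losses), wg):
            partials[tid] += losses[i]
    return sum(partials)  # stand-in for wg_reduce_sum

def average_loss(losses, tgts, vocab_size):
    # Assumption: a row is valid when its target is inside the vocab
    valid_count = sum(1 for t in tgts if t < vocab_size)
    total = sum_losses(losses)
    return total / valid_count if valid_count > 0 else 0.0

losses = [1.0, 0.0, 3.0, 0.0]  # ignored rows were written as 0.0
tgts = [5, 99, 2, 99]          # 99 is out of range for vocab_size = 10
print(average_loss(losses, tgts, vocab_size=10))
```

Ignored rows add 0 to the sum and are left out of valid_count, so padding does not dilute the displayed average.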


Loss Components — Detail

Label Smoothing (α)

Instead of the hard target [0, 0, ..., 1, ..., 0], a soft target [α/V, α/V, ..., 1-α+α/V, ..., α/V].

Effect:

  • Confidence calibration — the model does not learn to be overly confident
  • Some contribution to the gradient comes from all classes — helps generalization
  • α=0 is common in pretraining, α=0.05-0.1 in fine-tuning
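A tiny Python example makes the soft target concrete. The helper name and the values V=4, α=0.1 are chosen only for illustration.

```python
def soft_target(tgt, vocab_size, alpha):
    # Every class gets alpha/V; the target class also gets the remaining 1-alpha
    base = alpha / vocab_size
    return [(1.0 - alpha) + base if v == tgt else base
            for v in range(vocab_size)]

q = soft_target(tgt=2, vocab_size=4, alpha=0.1)
print(q, sum(q))
```

The target class ends up with about 0.925 and each other class with 0.025, and the entries still sum to 1, so the result is a valid probability distribution.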

Z-loss (β)

Penalizes the growth of LSE. Minimizing β·LSE² pulls LSE toward 0, i.e. the softmax normalizer Σ exp(l) stays near 1, which keeps the logits from drifting to large magnitudes.

Effect:

  • Training stability (logit overflow becomes rare)
  • Prevents "logit drift" (logits tend to drift over long training)

PaLM, Switch Transformer used it. Our default β=1e-4 (very light).
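Why plain cross-entropy cannot stop logit drift on its own can be seen in a short Python sketch (hypothetical helper names, β=1e-4 as above). Adding a constant to every logit leaves the cross-entropy unchanged but raises the z-loss.

```python
import math

def lse(logits):
    m = max(logits)
    return m + math.log(sum(math.exp(l - m) for l in logits))

def ce(logits, tgt):
    return lse(logits) - logits[tgt]

def z_loss(logits, beta=1e-4):
    return beta * lse(logits) ** 2

row = [2.0, 1.0, 0.1]
drifted = [l + 50.0 for l in row]  # same probabilities, shifted logits

print(ce(row, 0), ce(drifted, 0))
print(z_loss(row), z_loss(drifted))
```

Softmax is shift-invariant, so the first line prints two equal values (up to rounding); only the z-loss term sees the drift and pushes LSE back toward 0.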


WGSL-Specific Notes

1. Uniform branch + workgroupBarrier OK

wgsl
if (alpha > 0.0) {
    workgroupBarrier();
    // ...
}

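A barrier requires uniform control flow. alpha comes from a uniform buffer, so all threads take the same branch and the barrier is safe. If the condition were non-uniform, WGSL's uniformity analysis would reject the shader at creation time; on a GPU without that check, the outcome would be a deadlock (some threads wait at the barrier, others skip it).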

2. select(false_val, true_val, cond)

wgsl
let one_hot = select(0.0, 1.0, v == tgt);

Read: "if v == tgt then 1.0 else 0.0".

3. In-place storage aliasing

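WebGPU accepts a buffer that is bound once as read_write storage and both read and written through that single binding. What is forbidden is binding the same buffer range as read-only storage and as read_write storage in the same dispatch. Our kernel is correct: there is only the read_write logits binding, no separate dLogits.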


Code Review

Finding 1: vocab_size = 0 edge case

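Risk | Description
🟢 none | If vocab_size = 0, the ignore check tgt >= vocab_size is true for every u32 target, so each row writes loss = 0 and returns before any reduction runs. If that path were reached anyway: Σ exp = 0, denom = 0, 1/max(denom, 1e-30) = 1e30 and log(max(denom, 1e-30)) = log(1e-30) ≈ -69; row_max stays at NEG_INF, so lse = NEG_INF - 69, which is still NEG_INF. The loss would be meaningless but there is no crash. In production vocab=0 never happens.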

Finding 2: tid == 0u writes losses[s] = 0 in the ignore path

Risk | Description
🟢 none | Ideal. tid==0 single-writer. The other threads zero out dLogits (in parallel).

Finding 3: In-place strategy backward correctness

Risk | Description
🟢 none | In backward, dLogits is read and the lm_head weight gradient is computed with the matmul kernels. dLogits lives in the logits buffer at the same offsets. Semantically correct.

Quick Checklist

Test Scenario | Status
Is the tgt >= vocab row ignored? |
α=0 fast path? | ✅ skips the branch
β=0 no z-loss? | ✅ multiply by 0 = 0
In-place race condition? | ✅ per-thread ownership, read-before-write
seq_len = 0 crash? | ✅ workgroup count = 0, no dispatch
LSE matches numpy reference? | ⚠ no formal test

Next

08_cast.md — f32 ↔ f16 conversion for mixed precision. Runs at the optimizer output.

WGSL kernel studies · an LLM from scratch on WebGPU
Built in Istanbul by Uğur Toprakdeviren.