ffn_gelu_backward and ffn_swiglu_backward — FFN Activation Backward
File: 09_backward_ffn.wgsl
Pipeline step: the element-wise part of the backward FFN. The GEMM part is handled by the matmul_t/_at kernels.
Two kernels:

- ffn_gelu_backward: GeLU activation backward
- ffn_swiglu_backward: SwiGLU combine backward
Wait, what is this?
In the forward pass an activation function (GeLU, SiLU) bends a number around: it squashes some values, lets others through. All you say there is "input was this, output was that." But the backward pass asks something else: "if I'd nudged the input a tiny bit, how much would the output have moved?" The answer to that "how much" is the derivative — the slope of the function at that exact point.
Think of it like a gas pedal. Press the pedal 1 mm — how much faster does the car go? It depends on where the pedal already is. Near idle, 1 mm does basically nothing (slope ~0); halfway down, the same 1 mm gives you a real jolt (steep slope). The derivative is exactly that: "at the current input, how much does a tiny tap show up in the output." Backward just scales the gradient coming from above by that factor — grad_in = grad_out × slope. That's all the chain rule really is.
GeLU has one slope, so you multiply by it directly. SwiGLU is a bit sneakier. Its output is the product of two separate inputs (gate and up), so there are two slopes, one per input. Each input receives the upstream gradient scaled by its own slope. Both gradients are computed in the same kernel, element by element, with each thread oblivious to its neighbors: embarrassingly parallel.
The one subtle bit is how you compute the slope. The sigmoid contains an exponential. If the argument to exp() is large and positive, it overflows to infinity (Inf), which can then turn into NaN (for example Inf/Inf) and poison everything downstream. The fix is to check the sign of x and pick the form of the formula in which exp() only ever sees a non-positive argument. That small but vital trick is what "stable sigmoid" means below.
What Does It Do?
GeLU backward
Forward: hidden = gelu(pre_gelu). Backward:
grad_pre[i] = grad_hidden[i] · gelu'(pre_gelu[i])

SwiGLU backward
Forward: hidden = silu(gate) · up. Backward (chain rule):
grad_gate[i] = grad_hidden[i] · up[i] · silu'(gate[i])
grad_up[i] = grad_hidden[i] · silu(gate[i])

These are two separate gradients, one for each input tensor (gate and up), computed in parallel within the same kernel.
Derivatives
GeLU' (tanh approximation)
Forward:
gelu(x) = 0.5x · (1 + tanh(inner))
inner = √(2/π) · (x + 0.044715·x³)

Derivative (chain + product rule):
gelu'(x) = 0.5 · (1 + tanh(inner)) + 0.5·x · sech²(inner) · d(inner)/dx
sech²(z) = 1 - tanh²(z)
d(inner)/dx = √(2/π) · (1 + 3·0.044715·x²)

Code:

```wgsl
fn gelu_derivative(x_in: f32) -> f32 {
    // d/dx [0.5 * x * (1 + tanh(inner))], where inner = sqrt(2/pi) * (x + 0.044715 * x^3)
    let x = clamp(x_in, -100.0, 100.0);
    let sqrt_2_over_pi = 0.7978845608;
    let coeff = 0.044715;
    let x2 = x * x;
    let x3 = x2 * x;
    let inner = sqrt_2_over_pi * (x + coeff * x3);
    let t = tanh(inner);
    let sech2 = 1.0 - t * t;
    let d_inner = sqrt_2_over_pi * (1.0 + 3.0 * coeff * x2);
    return 0.5 * (1.0 + t) + 0.5 * x * sech2 * d_inner;
}
```

SiLU'
Forward: silu(x) = x · σ(x), where σ(x) = 1/(1+e^-x) is the sigmoid.
Derivative:
silu'(x) = σ(x) + x · σ(x) · (1 - σ(x))
= σ(x) · (1 + x · (1 - σ(x)))

Code:

```wgsl
fn sigmoid(x: f32) -> f32 {
    // Stable sigmoid: branch on sign so exp() only ever sees a non-positive argument
    if (x >= 0.0) {
        return 1.0 / (1.0 + exp(-x));
    }
    let e = exp(x);
    return e / (1.0 + e);
}

fn silu_derivative(x: f32) -> f32 {
    let s = sigmoid(x);
    return s * (1.0 + x * (1.0 - s));
}
```

Bind Group ABI
ffn_gelu_backward (4 bindings)
| Binding | Type | Detail |
|---|---|---|
| 0 | storage, read | grad_hidden: array<f32> — upstream gradient (from the FFN's final matmul) |
| 1 | storage, read | pre_gelu: array<f32> — pre-activation saved during forward |
| 2 | storage, read_write | grad_pre: array<f32> — output, flows to the next matmul backward |
| 3 | uniform | n: u32 |
ffn_swiglu_backward (6 bindings)
| Binding | Type | Detail |
|---|---|---|
| 0 | storage, read | grad_hidden |
| 1 | storage, read | gate — forward gate output |
| 2 | storage, read | up — forward up output |
| 3 | storage, read_write | grad_gate |
| 4 | storage, read_write | grad_up |
| 5 | uniform | n: u32 |
Dispatch Shape
workgroup_size: 256
workgroups: ceil(n / 256), one thread per element
n = seq × d_ff (e.g. 512 × 3072 = 1,572,864 ≈ 1.57M)

Element-wise, no shared memory, no barriers.
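The dispatch arithmetic for the example size is shown below as a small Python sketch. WORKGROUP_SIZE mirrors the kernels. The 65535 cap is WebGPU's default maxComputeWorkgroupsPerDimension limit and is included only as a sanity check on a one-dimensional dispatch. How flat_id folds larger dispatches is not covered here.

```python
WORKGROUP_SIZE = 256   # matches workgroup_size: 256 in both kernels
MAX_PER_DIM = 65535    # WebGPU default maxComputeWorkgroupsPerDimension

def workgroup_count(n: int, wg_size: int = WORKGROUP_SIZE) -> int:
    # ceil(n / wg_size) in integer arithmetic
    return (n + wg_size - 1) // wg_size

seq, d_ff = 512, 3072
n = seq * d_ff                            # one working thread per element
wgs = workgroup_count(n)
idle = wgs * WORKGROUP_SIZE - n           # threads that take the `if (i >= n) { return; }` exit
print(n, wgs, idle, wgs <= MAX_PER_DIM)   # 1572864 6144 0 True
```

Now take an n that is not a multiple of 256, such as n = 1000. It needs 4 workgroups, and the last 24 threads take the early return.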
Line by Line
ffn_gelu_backward
```wgsl
let i = flat_id(gid, nwg);
if (i >= n) { return; }
let g = grad_hidden[i];
let pg = pre_gelu[i];
let r = g * gelu_derivative(pg);
grad_pre[i] = nan_guard(r);
```

Trivial chain rule. nan_guard is a finiteness check. The pre-activation may be NaN (rare), and rather than letting that propagate into the downstream gradient, the result is replaced with 0.
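A CPU reference of ffn_gelu_backward can be handy for comparing GPU output element by element. The sketch below is in Python and is not the project's code. The real nan_guard implementation is not shown in this document, so the version here (any non-finite value becomes 0.0) is an assumption based on the description above and in Finding 2.

```python
import math

def gelu_derivative(x_in: float) -> float:
    # Mirror of the WGSL helper: clamp, then d/dx of 0.5*x*(1 + tanh(inner))
    x = min(max(x_in, -100.0), 100.0)
    sqrt_2_over_pi = 0.7978845608
    coeff = 0.044715
    x2 = x * x
    inner = sqrt_2_over_pi * (x + coeff * x2 * x)
    t = math.tanh(inner)
    sech2 = 1.0 - t * t
    d_inner = sqrt_2_over_pi * (1.0 + 3.0 * coeff * x2)
    return 0.5 * (1.0 + t) + 0.5 * x * sech2 * d_inner

def nan_guard(v: float) -> float:
    # Hypothetical stand-in for the WGSL helper: non-finite -> 0.0
    return v if math.isfinite(v) else 0.0

def ffn_gelu_backward(grad_hidden, pre_gelu):
    # One list element per GPU thread i
    return [nan_guard(g * gelu_derivative(pg)) for g, pg in zip(grad_hidden, pre_gelu)]

grad_pre = ffn_gelu_backward([1.0, 2.0, 1.0], [0.0, 3.0, float("nan")])
print(grad_pre)
```

The third element shows the guard at work. A NaN pre-activation yields 0.0 instead of NaN in grad_pre.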
ffn_swiglu_backward
```wgsl
let i = flat_id(gid, nwg);
if (i >= n) { return; }
let g = grad_hidden[i];
let gt = gate[i];
let u = up[i];
let s = silu(gt);
let dg = g * u * silu_derivative(gt);
let du = g * s;
grad_gate[i] = nan_guard(dg);
grad_up[i] = nan_guard(du);
```

Two gradients in parallel:

grad_gate = g · u · silu'(gate)
grad_up = g · silu(gate)
silu(gate) and silu'(gate) both evaluate sigmoid(gate[i]), so the same sigmoid is computed twice per element. That is a small amount of wasted arithmetic. The kernel is memory-bound, however, so the extra fp32 ALU work is effectively free.
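The SwiGLU kernel ports to Python the same way, as a CPU reference that keeps the double sigmoid evaluation exactly as written. As before, nan_guard here is a hypothetical stand-in whose behavior (non-finite becomes 0.0) is assumed from the description.

```python
import math

def sigmoid(x):
    # Stable form: exp() only ever sees a non-positive argument
    if x >= 0.0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)

def silu(x):
    return x * sigmoid(x)

def silu_derivative(x):
    s = sigmoid(x)
    return s * (1.0 + x * (1.0 - s))

def nan_guard(v):
    # Hypothetical stand-in for the WGSL helper: non-finite -> 0.0
    return v if math.isfinite(v) else 0.0

def ffn_swiglu_backward(grad_hidden, gate, up):
    grad_gate, grad_up = [], []
    for g, gt, u in zip(grad_hidden, gate, up):
        s = silu(gt)                      # sigmoid(gt) evaluated here...
        dg = g * u * silu_derivative(gt)  # ...and again here
        du = g * s
        grad_gate.append(nan_guard(dg))
        grad_up.append(nan_guard(du))
    return grad_gate, grad_up

gg, gu = ffn_swiglu_backward([1.0, 0.5], [0.0, 2.0], [3.0, -1.0])
print(gg, gu)
```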
Why Is pre_gelu Saved?
GeLU is non-linear. To compute the derivative gelu'(x) in the backward pass, you need to know the x value from the forward pass. But during the forward pass the gelu_inplace kernel overwrites x (in-place).
Solution: during the forward pass, store a copy into the pre_gelu buffer. The backward pass reads it.
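Alternative: recompute x in the backward pass via gelu_inverse(hidden). But GeLU is non-injective: it is non-monotonic, with a minimum near x ≈ -0.75, so two different inputs can produce the same output. Inversion is therefore impossible.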
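The sketch below is a quick numerical illustration in Python that the tanh-approximation GeLU is not invertible. Bisection on each side of the minimum (near x ≈ -0.75) finds two different inputs that map to the same output. The target value -0.1 and the search intervals are arbitrary choices.

```python
import math

def gelu(x):
    # tanh approximation, same constants as the forward pass
    inner = 0.7978845608 * (x + 0.044715 * x ** 3)
    return 0.5 * x * (1.0 + math.tanh(inner))

def bisect(f, lo, hi, iters=100):
    # Root of f on [lo, hi]; assumes f(lo) and f(hi) have opposite signs
    flo = f(lo)
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        fm = f(mid)
        if (fm < 0.0) == (flo < 0.0):
            lo, flo = mid, fm
        else:
            hi = mid
    return 0.5 * (lo + hi)

target = -0.1
left = bisect(lambda x: gelu(x) - target, -3.0, -0.75)   # branch left of the minimum
right = bisect(lambda x: gelu(x) - target, -0.75, 0.0)   # branch right of the minimum
print(left, right, gelu(left), gelu(right))
```

The two roots come out near x ≈ -1.5 and x ≈ -0.25. Given only hidden = -0.1, the backward pass could not tell which one the forward pass saw.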
Memory cost: seq × d_ff × 4B of extra activation. seq=512, d_ff=3072 → 6 MB per layer. 12 layers → 72 MB. Significant but necessary.
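The memory estimate is spelled out in Python below. The figures come out exactly in binary megabytes (MiB). In decimal units the per-layer cost is about 6.3 MB.

```python
def pre_gelu_bytes(seq: int, d_ff: int, bytes_per_elem: int = 4) -> int:
    # one fp32 pre-activation per (token, hidden unit)
    return seq * d_ff * bytes_per_elem

MIB = 1024 * 1024
per_layer = pre_gelu_bytes(512, 3072)
total = 12 * per_layer
print(per_layer / MIB, total / MIB)   # 6.0 72.0
```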
SwiGLU does not have this problem — gate and up are already preserved in separate buffers (they were matmul outputs, and the post-activation was not written over them).
WGSL-Specific Notes
1. tanh() and exp() — built-in
Hardware-accelerated. The Apple GPU fast_tanh instruction is ~3 cycles.
2. clamp(x, -100, 100) and the Branching Stable Sigmoid
The clamp(x, -100, 100) guard is kept on the GeLU input so that the x³ term stays bounded and cannot overflow fp32. The sigmoid/silu path, however, no longer clamps at all; the earlier clamp(x, -50, 50) has been removed entirely. Instead, the x >= 0.0 check selects either exp(-x) or exp(x), so the argument to exp() is always zero or negative. Exponential overflow (overflow → Inf) is therefore prevented by construction, not by any hardware guarantee, and the clamp is no longer needed. At the other extreme, exp() of a very negative argument simply underflows to 0.0. That yields the correct limits σ → 1 or σ → 0 stably.
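The small Python experiment below contrasts a single-formula sigmoid with the sign-branching one. Python floats are fp64, so overflow starts near x ≈ 710 rather than fp32's ≈ 88.7. Python's math.exp raises OverflowError instead of returning inf, so exp_ieee is a hypothetical wrapper that mimics IEEE-754 overflow to +Inf.

```python
import math

def exp_ieee(x: float) -> float:
    # Hypothetical helper: return +inf on overflow, like IEEE-754 hardware
    try:
        return math.exp(x)
    except OverflowError:
        return math.inf

def sigmoid_naive(x: float) -> float:
    # e^x / (1 + e^x): large positive x gives inf / inf = nan
    e = exp_ieee(x)
    return e / (1.0 + e)

def sigmoid_stable(x: float) -> float:
    # Branch on sign so exp() only sees a non-positive argument
    if x >= 0.0:
        return 1.0 / (1.0 + exp_ieee(-x))
    e = exp_ieee(x)
    return e / (1.0 + e)

for x in (-1000.0, 0.0, 1000.0):
    print(x, sigmoid_naive(x), sigmoid_stable(x))
```

The other single-formula form, 1/(1 + e^-x), overflows exp() for large negative x instead. The branch avoids both failure modes.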
3. Re-using silu() and sigmoid() helpers
The forward kernel and the backward kernel use the same helpers. This works through WGSL preamble injection: each kernel module is compiled separately, and the helpers are injected via a file-local preamble rather than placed in 00_shared.wgsl. In our code, silu and silu_derivative live in the preamble of this same file.
Code Review
Finding 1: silu(gt) and silu_derivative(gt) compute the same sigmoid(gt)
| Risk | Description |
|---|---|
| 🟢 minor | Optimization opportunity: let s = sigmoid(gt); let silu_v = gt * s; let silu_d = s * (1.0 + gt * (1.0 - s)); computes the sigmoid once instead of twice. That saves roughly 5% of the ALU work and no memory bandwidth. Since fp32 ALU is effectively free here, the difference is imperceptible in practice. |
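The Finding 1 rewrite is sketched in Python below. One sigmoid call feeds both silu and silu'. The names silu_and_derivative and swiglu_backward_elem are hypothetical and exist only for illustration.

```python
import math

def sigmoid(x):
    # Stable, sign-branching sigmoid
    if x >= 0.0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)

def silu_and_derivative(x):
    # Evaluate sigmoid once, derive silu(x) and silu'(x) from it
    s = sigmoid(x)
    return x * s, s * (1.0 + x * (1.0 - s))

def swiglu_backward_elem(g, gt, u):
    silu_v, silu_d = silu_and_derivative(gt)
    return g * u * silu_d, g * silu_v   # (grad_gate, grad_up)

print(swiglu_backward_elem(1.0, 0.0, 3.0))   # (1.5, 0.0)
```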
Finding 2: nan_guard in both kernels
| Risk | Description |
|---|---|
| 🟢 none | Defensive. gelu_derivative or silu_derivative may produce NaN on extreme input; rather than polluting the downstream gradient, pull it to 0. |
Finding 3: No mixed precision in these kernels
| Risk | Description |
|---|---|
| 🟢 none (architectural) | fp32 is the backward standard. pre_gelu, gate, and up come from fp32 activations in the forward pass. Mixed precision is only on the weight side, not on the activation side. |
Quick Checklist
| Test Scenario | Status |
|---|---|
| Is gelu_derivative(0) ≈ 0.5? | ✅ formula |
| Is silu_derivative(0) ≈ 0.5? | ✅ formula |
| Does the derivative turn negative for sufficiently negative input (below the activation's minimum, ≈ -0.75 for GeLU, ≈ -1.28 for SiLU)? | ✅ |
| Is pre_gelu saved during forward? | ✅ host-side |
| Does nan_guard work on NaN injection? | ✅ self-check |
| GeLU forward+backward → identity gradient (small steps)? | ⚠ no formal test |
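A finite-difference check is one way the last row could be covered. The minimal sketch below runs in fp64 Python on the CPU. It compares gelu_derivative and silu_derivative (without the clamp) against central differences of the forward functions on a grid from -8 to 8. The step h and the grid are arbitrary choices, and comparing against the GPU's fp32 output would need a much looser tolerance.

```python
import math

SQRT_2_OVER_PI = 0.7978845608
COEFF = 0.044715

def gelu(x):
    return 0.5 * x * (1.0 + math.tanh(SQRT_2_OVER_PI * (x + COEFF * x ** 3)))

def gelu_derivative(x):
    x2 = x * x
    t = math.tanh(SQRT_2_OVER_PI * (x + COEFF * x2 * x))
    d_inner = SQRT_2_OVER_PI * (1.0 + 3.0 * COEFF * x2)
    return 0.5 * (1.0 + t) + 0.5 * x * (1.0 - t * t) * d_inner

def sigmoid(x):
    if x >= 0.0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)

def silu(x):
    return x * sigmoid(x)

def silu_derivative(x):
    s = sigmoid(x)
    return s * (1.0 + x * (1.0 - s))

def max_fd_error(f, df, xs, h=1e-4):
    # Largest gap between the analytic slope and a central finite difference
    return max(abs(df(x) - (f(x + h) - f(x - h)) / (2.0 * h)) for x in xs)

xs = [i / 10.0 for i in range(-80, 81)]   # -8.0 .. 8.0 in steps of 0.1
print(max_fd_error(gelu, gelu_derivative, xs), max_fd_error(silu, silu_derivative, xs))
```

The same script also confirms the sign behavior from the checklist. For example, gelu_derivative(-1.0) and silu_derivative(-2.0) are both negative.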
Next
10_backward_attention.md — the most complex kernel in the pipeline. 3 variants: streaming, split-short, split-dKdV.