gelu_inplace and swiglu_combine — FFN Activations
File: 06_activation.wgsl
Pipeline step: FFN intermediate computation. Runs the gate output through the activation and combines it with the up output.
Two kernels:
- gelu_inplace: GeLU activation, in-place
- swiglu_combine: fused silu(gate) * up
Wait, what is this?
You've got raw numbers coming out of a matmul — some positive, some negative, all sitting on a straight line. If you pass them along to the next layer untouched, all your stacked layers collapse into one big multiplier; no matter how deep the model gets, it still draws a straight line. An activation function is the thing that breaks that flatness: it adds a bend, so the model can learn curvy, non-linear stuff.
The most intuitive way to picture it: an activation is like a dimmer switch. Not a hard on/off toggle — it gently decides "how much of this signal do I let through?" GeLU does exactly that: instead of cutting small negatives off entirely, it lets a little through, and leaves large positives almost as-is. Think of it as the smoothed, more polite version of the old ReLU ("if it's below 0, just chop it").
SwiGLU takes it one step further. Instead of a single signal, you produce two parallel signals: one is up (the actual content), the other is gate (the gatekeeper). You run the gate through a dimmer (SiLU) and then multiply it by up, so one signal decides how much of the other gets through. It's just like a channel on an audio mixer: the first fader is the sound itself, the second fader is the knob that ducks or boosts it. This learnable gate gives the model the flexibility to say "play up this feature in this context, suppress that one"; LLaMA and Mistral rely on it, and Gemma uses the closely related GeGLU variant (same gating, with GeLU in place of SiLU).
What the kernels actually do is simple: apply a small math formula per element. The real subtlety isn't there — functions like exp() can blow up to infinity on large inputs and produce NaN. So the naive formula isn't enough; as you'll see below, a sign-based branch is used to build a numerically stable SiLU.
What Does It Do?
GeLU
Gaussian Error Linear Unit. Smooth, non-monotone (there is a tiny dip at negatives):
gelu(x) ≈ 0.5 · x · (1 + tanh(√(2/π) · (x + 0.044715 · x³)))
This is the tanh approximation. Exact GeLU is x · Φ(x), where Φ is the standard normal CDF, which is expensive to evaluate. The tanh form is both fast and ~99% accurate.
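To put a number on the approximation quality, here is a minimal host-side sketch in plain Python that compares the tanh form against the exact x · Φ(x), with Φ written through `math.erf`. The names `gelu_tanh` and `gelu_exact` are illustrative and not part of 06_activation.wgsl. Python floats are f64, so this checks the formula rather than the kernel's f32 rounding.

```python
import math

SQRT_2_OVER_PI = math.sqrt(2.0 / math.pi)  # ≈ 0.7978845608

def gelu_tanh(x: float) -> float:
    """tanh approximation, same formula as in the text."""
    inner = SQRT_2_OVER_PI * (x + 0.044715 * x ** 3)
    return 0.5 * x * (1.0 + math.tanh(inner))

def gelu_exact(x: float) -> float:
    """Exact GeLU: x * Phi(x), with Phi expressed through erf."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

# Largest absolute gap between the two forms on a grid over [-6, 6].
grid = [i / 100.0 for i in range(-600, 601)]
max_abs_err = max(abs(gelu_tanh(x) - gelu_exact(x)) for x in grid)
print(f"max |tanh form - exact| on [-6, 6]: {max_abs_err:.2e}")
```

The same two functions also show the non-monotone dip: both return a more negative value around x = -0.75 than at x = -3.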
SwiGLU
The heart of LLaMA-style FFN. Two separate projections (gate, up) → element-wise multiply of up by SiLU(gate):
silu(x) = x · sigmoid(x)
swiglu_combine(gate, up): out[i] = silu(gate[i]) · up[i]
The full SwiGLU FFN:
hidden = swiglu_combine(W_gate(x), W_up(x)) ← 2 matmuls + this kernel
out = W_down(hidden) ← 1 matmul
vs. GeLU FFN:
hidden = gelu(W_up(x)) ← 1 matmul + gelu
out = W_down(hidden) ← 1 matmul
SwiGLU costs 1 extra matmul but delivers better quality, roughly 5-10% lower loss. LLaMA and Mistral use SwiGLU; Gemma uses the GeGLU variant, and Falcon uses a plain GeLU FFN.
The Math
GeLU
gelu(x) = 0.5 · x · (1 + tanh(√(2/π) · (x + 0.044715 · x³)))
with constants:
√(2/π) ≈ 0.7978845608
0.044715 (paper-derived coefficient)
SiLU (a.k.a. Swish-1)
silu(x) = x · sigmoid(x) = x / (1 + exp(-x))
SwiGLU combine
swiglu_combine(g, u)[i] = silu(g[i]) · u[i]
= (g[i] / (1 + exp(-g[i]))) · u[i]
gate (g) and up (u) come from matmuls. The typical FFN intermediate dim is 4× d_model, so both have size d_ff = 3072.
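As a reference for the formulas above, the combine step can be sketched element by element in plain Python. This is a direct transcription of the math, not the kernel: `silu` here uses the textbook form x / (1 + exp(-x)), which is fine for moderate inputs but is not the stabilized variant the WGSL helper uses.

```python
import math

def silu(x: float) -> float:
    """Textbook SiLU: x * sigmoid(x) = x / (1 + exp(-x)). Moderate inputs only."""
    return x / (1.0 + math.exp(-x))

def swiglu_combine(gate: list[float], up: list[float]) -> list[float]:
    """out[i] = silu(gate[i]) * up[i], one output per input pair."""
    if len(gate) != len(up):
        raise ValueError("gate and up must have the same length")
    return [silu(g) * u for g, u in zip(gate, up)]

# A gate of 0 blocks the signal; a positive gate lets most of `up` through.
hidden = swiglu_combine([0.0, 1.0, -1.0], [5.0, 2.0, 2.0])
print(hidden)
```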
Bind Group ABI
gelu_inplace (2 bindings)
| Binding | Type | Detail |
|---|---|---|
| 0 | storage, read_write | x: array<f32> — [seq × d_ff] |
| 1 | uniform | n: u32 — total element count |
swiglu_combine (4 bindings)
| Binding | Type | Detail |
|---|---|---|
| 0 | storage, read | gate: array<f32> |
| 1 | storage, read | up: array<f32> |
| 2 | storage, read_write | hidden: array<f32> — output |
| 3 | uniform | n: u32 |
Dispatch Shape
workgroup_size: 256
threads: ceil(n / 256) workgroups × 256
n = seq × d_ff (for example 512 × 3072 = 1.57M elements, which gives 6144 workgroups).
One thread = one element. Fully parallel, no shared memory, no barriers.
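The dispatch arithmetic can be sketched on the host side as follows. Two things here are assumptions rather than facts from the source: `dispatch_shape` is a hypothetical helper, and `flat_id` is a guess at a row-major linearization (the real `flat_id` is not shown in this section). The 65535 cap is WebGPU's default `maxComputeWorkgroupsPerDimension`, which is one plausible reason for passing `num_workgroups` into the kernels.

```python
WORKGROUP_SIZE = 256
MAX_WG_PER_DIM = 65535  # WebGPU default for maxComputeWorkgroupsPerDimension

def dispatch_shape(n: int) -> tuple[int, int]:
    """Return (x, y) workgroup counts that cover n elements, one thread each."""
    total_wg = (n + WORKGROUP_SIZE - 1) // WORKGROUP_SIZE  # ceil(n / 256)
    if total_wg <= MAX_WG_PER_DIM:
        return (total_wg, 1)
    # Too many workgroups for one dimension: spill into y.
    y = (total_wg + MAX_WG_PER_DIM - 1) // MAX_WG_PER_DIM
    x = (total_wg + y - 1) // y
    return (x, y)

def flat_id(gid_x: int, gid_y: int, nwg_x: int) -> int:
    """Hypothetical linearization: row-major over the 2D thread grid."""
    return gid_y * (nwg_x * WORKGROUP_SIZE) + gid_x

print(dispatch_shape(512 * 3072))
```

Because the grid is rounded up, some threads land past n. That is why each kernel begins with the `i >= n` bounds check.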
Line by Line
gelu_inplace
@compute @workgroup_size(256, 1, 1)
fn gelu_inplace(@builtin(global_invocation_id) gid: vec3<u32>,
@builtin(num_workgroups) nwg: vec3<u32>) {
let i = flat_id(gid, nwg);
if (i >= n) { return; }
x[i] = gelu(x[i]);
}
Trivial: flat_id, bounds check, in-place transform.
gelu() helper:
fn gelu(x: f32) -> f32 {
let xc = clamp(x, -100.0, 100.0);
let inner = 0.7978845608 * (xc + 0.044715 * xc * xc * xc);
return 0.5 * xc * (1.0 + tanh(inner));
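}
Why clamp(-100, 100)? Not because x³ overflows: x=100 → x³ = 1e6, x=1000 → x³ = 1e9, x=10000 → x³ = 1e12, all far below the f32 max of ~3.4e38 (x³ only overflows once |x| exceeds roughly 7e12). The concern is the tanh argument, 0.7978845608 · (x + 0.044715 · x³), which grows cubically. tanh(very_large) = ±1, so the function is saturated long before |x| reaches 100, and the clamp simply caps how large that argument can get (about 3.6e4 at x = ±100). Whether that cap is sufficient depends on the tanh built-in saturating cleanly, since an exp()-based tanh would already overflow at that magnitude. Two caveats: the clamp is not a reliable NaN filter (clamp(NaN, a, b) is unspecified in WGSL), and because the return value uses xc, any input above 100 comes out as exactly 100 instead of x.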
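To make the effect of the clamp concrete, the helper can be mirrored in plain Python. This is a sketch in f64, not the shader itself; `gelu_clamped` and `tanh_argument` are illustrative names. It shows how large the tanh argument gets at the clamp boundary and what the helper returns for inputs beyond it.

```python
import math

def tanh_argument(x: float) -> float:
    """The value handed to tanh for a given (already clamped) input."""
    return 0.7978845608 * (x + 0.044715 * x * x * x)

def gelu_clamped(x: float) -> float:
    """Mirror of the WGSL helper: clamp first, then the tanh formula on xc."""
    xc = max(-100.0, min(100.0, x))
    return 0.5 * xc * (1.0 + math.tanh(tanh_argument(xc)))

for x in (3.0, 10.0, 100.0, 1000.0):
    xc = min(x, 100.0)
    print(f"x={x:7.1f}  tanh arg={tanh_argument(xc):10.1f}  gelu={gelu_clamped(x):.4f}")
```

Inputs past the boundary come back as the boundary value, because the final multiply uses the clamped xc rather than x.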
tanh() is a WGSL built-in. A native instruction on Apple GPU hardware.
swiglu_combine
@compute @workgroup_size(256, 1, 1)
fn swiglu_combine(@builtin(global_invocation_id) gid: vec3<u32>,
@builtin(num_workgroups) nwg: vec3<u32>) {
let i = flat_id(gid, nwg);
if (i >= n) { return; }
hidden[i] = silu(gate[i]) * up[i];
}
silu() helper:
fn silu(x: f32) -> f32 {
var sig: f32;
if (x >= 0.0) {
sig = 1.0 / (1.0 + exp(-x));
} else {
let e = exp(x);
sig = e / (1.0 + e);
}
return x * sig;
}
- Branched stabilization: instead of the old approach's clamp(x, -50, 50), the code branches on the sign of x. This keeps the argument of exp() non-positive in both branches (exp(-x) for x ≥ 0, exp(x) for x < 0), so exp() can never overflow to Inf and then produce NaN through an Inf/Inf division. The underflow case rounds to zero and stays stable. With this approach there is no need for clamp bounds, and the silu(x) → x asymptote is preserved on large inputs.
In practice, the SwiGLU input (the matmul output) sits in the ~[-3, 3] range during normal training. This structure prevents the pipeline from being contaminated with NaN even during unusual gradient spikes.
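The failure mode that the branch avoids can be reproduced in a small Python sketch. Python's `math.exp` raises on overflow instead of returning Inf, so `exp_f32_like` below is a stand-in that emulates f32 behavior (overflow just above 88.72); it and the function names `silu_unbranched` and `silu_branched` are assumptions for illustration only.

```python
import math

F32_EXP_OVERFLOW = 88.72  # exp(v) exceeds the f32 max (~3.4e38) just above this

def exp_f32_like(v: float) -> float:
    """Emulate f32 exp overflow: return +inf where an f32 exp would."""
    return math.inf if v > F32_EXP_OVERFLOW else math.exp(v)

def silu_unbranched(x: float) -> float:
    """Single-formula variant that uses e / (1 + e) for every x."""
    e = exp_f32_like(x)
    return x * (e / (1.0 + e))  # inf / inf gives nan for large positive x

def silu_branched(x: float) -> float:
    """Sign-branched variant: exp() only ever sees a non-positive argument."""
    if x >= 0.0:
        sig = 1.0 / (1.0 + exp_f32_like(-x))
    else:
        e = exp_f32_like(x)
        sig = e / (1.0 + e)
    return x * sig

for x in (-100.0, 0.0, 2.0, 100.0):
    print(x, silu_unbranched(x), silu_branched(x))
```

On moderate inputs the two variants agree; only the large positive input separates them, where the unbranched form returns NaN and the branched form returns x.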
Why Fused swiglu_combine?
Classic PyTorch:
gate_out = gate_proj(x)
up_out = up_proj(x)
silu_out = F.silu(gate_out)
hidden = silu_out * up_out
= 4 separate kernel calls: two matmuls, then silu and multiply as two element-wise kernels (unless the framework fuses them).
Our swiglu_combine:
hidden[i] = silu(gate[i]) · up[i]
= 1 kernel call for the whole element-wise part. Memory traffic:
- Classic: silu in-place (1 read + 1 write of gate) + multiply (2 reads + 1 write) = 5 buffer passes
- Fused: 2 reads (gate, up) + 1 write (hidden) = 3 buffer passes → 40% bandwidth savings
Plus one fewer dispatch.
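The bandwidth comparison is simple enough to check with a few lines of arithmetic. The sketch below assumes f32 buffers and the 512 × 3072 example size from the dispatch section, and counts only full-buffer reads and writes (no caching effects); `elementwise_traffic_bytes` is an illustrative helper.

```python
BYTES_PER_F32 = 4

def elementwise_traffic_bytes(n: int, passes: int) -> int:
    """Bytes moved when `passes` full-buffer reads/writes touch n f32 values."""
    return n * passes * BYTES_PER_F32

n = 512 * 3072  # seq x d_ff from the dispatch example
unfused = elementwise_traffic_bytes(n, 5)  # silu: 1R + 1W, multiply: 2R + 1W
fused = elementwise_traffic_bytes(n, 3)    # gate read, up read, hidden write
saving = 1.0 - fused / unfused
print(f"unfused {unfused / 2**20:.0f} MiB, fused {fused / 2**20:.0f} MiB, saving {saving:.0%}")
```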
WGSL-Specific Notes
1. tanh() and exp() built-ins
Guaranteed by the WGSL spec. Optimizes down to hardware (native on Apple GPU).
2. clamp(x, lo, hi) — saturated arithmetic
The behavior of clamp(NaN, a, b) is unspecified in the WGSL spec. On Apple Metal it's "either NaN or a bound, usually a bound". It doesn't matter for our use (gradient finiteness is already checked in backward).
3. In-place vs separate output
The name gelu_inplace implies in-place — it really does overwrite the input. The old value cannot be recovered. If backward needs it, we have to keep an extra copy buffer in forward (our code does this — the pre_act buffer).
swiglu_combine writes a separate output buffer (hidden). For backward, gate and up stay preserved — no extra save.
Performance
From the profile snapshot:
swiglu_combine: 12 layers × ~186 µs = 2.2 ms total = 0.4% of the step
Very small. Element-wise, memory-bound, not a hot kernel.
Code Review
Finding 1: GeLU clamp ±100 is perhaps unnecessary
| Risk | Explanation |
|---|---|
| 🟢 none | x=100 → x³=1e6 → tanh argument = 0.7978845608 × (100 + 0.044715 × 1e6) ≈ 35757 → tanh(35757) ≈ 1.0 (saturated). Clamp is safety, never hit on the hot path. |
Finding 2: gelu_inplace needs pre-act for backward
| Risk | Explanation |
|---|---|
| 🟢 (architectural) | When backward GeLU computes dx = ∂gelu/∂x · dy, it needs to know the forward x value. Since in-place GeLU overwrites x, backward needs a copy of it — this is held host-side in the pre_act buffer. Extra memory but unavoidable. |
Finding 3: SwiGLU has no pre-activation save
| Risk | Explanation |
|---|---|
| 🟢 none | SwiGLU backward needs the silu derivative. gate[i] is already a separate buffer, preserved. up[i] is preserved too. No save problem. |
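For reference, the per-element gradients that a SwiGLU backward pass needs can be sketched as below. The backward kernel is not shown in this section, so this is the standard derivative of hidden = silu(gate) · up written in Python, with hypothetical names (`silu_grad`, `swiglu_backward`). It illustrates why having gate and up preserved is enough.

```python
import math

def sigmoid(x: float) -> float:
    """Sign-branched sigmoid, so exp() never sees a positive argument."""
    if x >= 0.0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)

def silu(x: float) -> float:
    return x * sigmoid(x)

def silu_grad(x: float) -> float:
    """d/dx [x * sigmoid(x)] = sig * (1 + x * (1 - sig))."""
    s = sigmoid(x)
    return s * (1.0 + x * (1.0 - s))

def swiglu_backward(gate: float, up: float, d_hidden: float) -> tuple[float, float]:
    """Gradients of hidden = silu(gate) * up with respect to gate and up."""
    d_gate = d_hidden * up * silu_grad(gate)
    d_up = d_hidden * silu(gate)
    return d_gate, d_up

print(swiglu_backward(0.7, -1.3, 1.0))
```

Both outputs depend only on gate, up, and the incoming gradient, so no extra forward-pass buffer is required.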
Quick Checklist
| Test Scenario | Status |
|---|---|
| Is gelu(0) = 0? | ✅ formula |
| gelu(very_large) no overflow? | ✅ clamp |
| Is silu(0) = 0? | ✅ formula |
| Is swiglu_combine aliasing-safe? (gate/up/hidden separate buffers) | ✅ runtime check |
| Does n = 0 not crash? | ✅ bounds check returns early on the first thread |
| GeLU vs PyTorch ref same? | ⚠ no formal comparison test, but loss curves are reasonable |
Next
07_loss.md — Cross-entropy + sum_losses. The last step of forward, the start of backward.