llm.istanbul·Study

gelu_inplace and swiglu_combine — FFN Activations

File: 06_activation.wgsl
Pipeline step: FFN intermediate computation. Applies the activation to the FFN's intermediate projection: GeLU in place, or (for SwiGLU) SiLU on the gate output multiplied by the up output.

Two kernels:

  • gelu_inplace — GeLU activation, in-place
  • swiglu_combine — Fused silu(gate) * up

Wait, what is this?

You've got raw numbers coming out of a matmul: some positive, some negative, each one a purely linear function of the input. If you pass them to the next layer untouched, your stacked layers collapse into a single matrix multiply, because a product of matrices is just another matrix. No matter how deep the model gets, it can still only draw straight lines. An activation function is what breaks that flatness: it adds a bend, so the model can learn curvy, non-linear stuff.
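A quick plain-Python check of the collapse argument: two stacked linear layers with no activation between them give the same output as one layer whose weight is the product W2 · W1. The matrices and the input are arbitrary illustration values.

```python
def matmul(A, B):
    """Plain-Python matrix product of nested lists (A is rows x k, B is k x cols)."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def matvec(W, x):
    """Apply weight matrix W (a list of rows) to vector x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

W1 = [[1.0, 2.0], [0.5, -1.0], [3.0, 0.0]]   # layer 1: 2 -> 3
W2 = [[1.0, -1.0, 2.0], [0.0, 4.0, 1.0]]     # layer 2: 3 -> 2
x = [0.3, -1.2]

two_layers = matvec(W2, matvec(W1, x))        # layer by layer, no activation in between
one_layer = matvec(matmul(W2, W1), x)         # the same map folded into one matrix

print(two_layers, one_layer)
```

Putting any non-linearity (ReLU, GeLU, SiLU) between the two matvec calls breaks this equality, which is exactly the point of an activation.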

The most intuitive way to picture it: an activation is like a dimmer switch. Not a hard on/off toggle — it gently decides "how much of this signal do I let through?" GeLU does exactly that: instead of cutting small negatives off entirely, it lets a little through, and leaves large positives almost as-is. Think of it as the smoothed, more polite version of the old ReLU ("if it's below 0, just chop it").

SwiGLU takes it one step further. Instead of a single signal, you produce two parallel signals: one is up (the actual content), the other is gate (the gatekeeper). You run the gate through a dimmer (SiLU) and then multiply it by up, so one signal decides how much of the other gets through. It's like a channel on an audio mixer: the first fader is the sound itself, the second is the knob that ducks or boosts it. This learnable gate lets the model say "play up this feature in this context, suppress that one". LLaMA and Mistral rely on it, and Gemma uses its close sibling GeGLU (the same gate with GeLU in place of SiLU).

What the kernels actually do is simple: apply a small math formula to each element. The real subtlety is numerical. exp() overflows to infinity on large inputs, and infinities can combine into NaN (Inf/Inf, for example). So the naive formula isn't enough; as you'll see below, a sign-based branch is used to build a numerically stable SiLU.


What Does It Do?

GeLU

Gaussian Error Linear Unit. Smooth and non-monotone: for negative inputs it dips slightly below zero (a minimum of about -0.17 near x ≈ -0.75) before returning toward 0:

gelu(x) ≈ 0.5 · x · (1 + tanh(√(2/π) · (x + 0.044715 · x³)))

This is the tanh approximation. Exact GeLU is x · Φ(x), where Φ is the standard normal CDF; computing Φ needs erf, which WGSL does not provide as a built-in. The tanh form is cheap and tracks the exact curve closely (absolute error below 10⁻³).
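A small Python sketch comparing the tanh approximation with exact GeLU (written with the standard library's math.erf) over [-6, 6]. It checks the formulas in float64 only, not the kernel's f32 arithmetic, and the 0.01 grid spacing is an arbitrary choice.

```python
import math

SQRT_2_OVER_PI = math.sqrt(2.0 / math.pi)  # ≈ 0.7978845608

def gelu_tanh(x: float) -> float:
    """tanh approximation (the kernel's formula, without its clamp)."""
    return 0.5 * x * (1.0 + math.tanh(SQRT_2_OVER_PI * (x + 0.044715 * x ** 3)))

def gelu_exact(x: float) -> float:
    """Exact GeLU: x * Phi(x), with the normal CDF Phi expressed through erf."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

xs = [i / 100.0 for i in range(-600, 601)]          # [-6, 6] in steps of 0.01
max_abs_err = max(abs(gelu_tanh(x) - gelu_exact(x)) for x in xs)
dip = min(gelu_exact(x) for x in xs)                 # the small negative dip
print(f"max |tanh approx - exact| = {max_abs_err:.2e}, minimum of GeLU = {dip:.4f}")
```

The largest deviations sit in the negative tail (around x ≈ -2.5 to -3), where GeLU is already close to zero.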

SwiGLU

The heart of LLaMA-style FFN. Two separate projections (gate, up) → element-wise multiply of up by SiLU(gate):

silu(x) = x · sigmoid(x)
swiglu_combine(gate, up): out[i] = silu(gate[i]) · up[i]

The full SwiGLU FFN:

hidden = swiglu_combine(W_gate(x), W_up(x))     ← 2 matmuls + this kernel
out = W_down(hidden)                             ← 1 matmul

vs. GeLU FFN:

hidden = gelu(W_up(x))                           ← 1 matmul + gelu
out = W_down(hidden)                             ← 1 matmul

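SwiGLU costs one extra matmul (three weight matrices instead of two) but delivers better quality; the GLU-variant experiments that introduced it report a modest but consistent loss improvement over GeLU/ReLU FFNs. LLaMA and Mistral use SwiGLU; Gemma uses the closely related GeGLU.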
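To make the matmul count concrete, here is a toy plain-Python sketch of both FFN shapes. Weights are lists of rows, there are no biases, and all names are illustrative rather than taken from the codebase.

```python
import math

def matvec(W, x):
    """y = W x for a weight matrix W given as a list of rows."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def gelu(x):
    return 0.5 * x * (1.0 + math.tanh(0.7978845608 * (x + 0.044715 * x ** 3)))

def silu(x):
    # Naive form: math.exp raises OverflowError for x below roughly -709.
    return x / (1.0 + math.exp(-x))

def ffn_gelu(x, W_up, W_down):
    hidden = [gelu(h) for h in matvec(W_up, x)]       # 1 matmul + gelu
    return matvec(W_down, hidden)                      # 1 matmul

def ffn_swiglu(x, W_gate, W_up, W_down):
    g = matvec(W_gate, x)                              # matmul 1
    u = matvec(W_up, x)                                # matmul 2
    hidden = [silu(gi) * ui for gi, ui in zip(g, u)]   # what swiglu_combine does
    return matvec(W_down, hidden)                      # matmul 3
```

With identity weights, ffn_swiglu([1.0, -1.0], I, I, I) reduces to [silu(1) · 1, silu(-1) · (-1)], which makes the gating easy to inspect by hand.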


The Math

GeLU

gelu(x) = 0.5 · x · (1 + tanh(√(2/π) · (x + 0.044715 · x³)))

with constants:
  √(2/π) ≈ 0.7978845608
  0.044715 (paper-derived coefficient)

SiLU (a.k.a. Swish-1)

silu(x) = x · sigmoid(x) = x / (1 + exp(-x))

SwiGLU combine

swiglu_combine(g, u)[i] = silu(g[i]) · u[i]
                       = (g[i] / (1 + exp(-g[i]))) · u[i]

gate (g) and up (u) both come from matmuls. The typical FFN intermediate dim is d_ff = 4 × d_model, so with d_model = 768 each of them is a [seq × d_ff] buffer with d_ff = 3072.


Bind Group ABI

gelu_inplace (2 bindings)

| Binding | Type | Detail |
|---|---|---|
| 0 | storage, read_write | x: array<f32> [seq × d_ff] |
| 1 | uniform | n: u32, total element count |

swiglu_combine (4 bindings)

| Binding | Type | Detail |
|---|---|---|
| 0 | storage, read | gate: array<f32> |
| 1 | storage, read | up: array<f32> |
| 2 | storage, read_write | hidden: array<f32> (output) |
| 3 | uniform | n: u32 |

Dispatch Shape

workgroup_size: 256
threads:        ceil(n / 256) workgroups × 256

n = seq × d_ff (for example 512 × 3072 = 1,572,864 ≈ 1.57M elements), which needs ceil(1,572,864 / 256) = 6144 workgroups.

One thread = one element. Fully parallel, no shared memory, no barriers.
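The kernels get their element index from a flat_id(gid, nwg) helper that this study does not show. A minimal sketch, assuming (hypothetically) that the host spills the workgroup count into a second dimension once it exceeds WebGPU's default limit of 65535 workgroups per dimension, and that flat_id linearizes the grid row by row:

```python
WG_SIZE = 256             # @workgroup_size(256, 1, 1)
MAX_WG_PER_DIM = 65535    # WebGPU default maxComputeWorkgroupsPerDimension

def dispatch_dims(n: int) -> tuple[int, int]:
    """Workgroup grid (x, y) covering n elements; spill into y only when x would exceed the limit."""
    groups = (n + WG_SIZE - 1) // WG_SIZE             # ceil(n / 256)
    if groups <= MAX_WG_PER_DIM:
        return groups, 1
    y = (groups + MAX_WG_PER_DIM - 1) // MAX_WG_PER_DIM
    x = (groups + y - 1) // y                         # ceil(groups / y) <= MAX_WG_PER_DIM
    return x, y

def flat_id(gid_x: int, gid_y: int, nwg_x: int) -> int:
    """Hypothetical flat_id: one grid row holds nwg_x * 256 invocations."""
    return gid_y * (nwg_x * WG_SIZE) + gid_x

print(dispatch_dims(512 * 3072))
```

For the 512 × 3072 example no spilling happens. Whenever the grid overshoots n, the `if (i >= n) { return; }` bounds check absorbs the extra invocations.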


Line by Line

gelu_inplace

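```wgsl
@compute @workgroup_size(256, 1, 1)
fn gelu_inplace(@builtin(global_invocation_id) gid: vec3<u32>,
                @builtin(num_workgroups) nwg: vec3<u32>) {
    let i = flat_id(gid, nwg);   // linear element index (flat_id helper not shown here)
    if (i >= n) { return; }      // tail threads of the last workgroup do nothing
    x[i] = gelu(x[i]);           // overwrite the input in place
}
```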

Trivial — flat_id, bounds check, in-place transform.

gelu() helper:

```wgsl
fn gelu(x: f32) -> f32 {
    let xc = clamp(x, -100.0, 100.0);                            // safety bound, see below
    let inner = 0.7978845608 * (xc + 0.044715 * xc * xc * xc);   // sqrt(2/pi) * (x + 0.044715 x^3)
    return 0.5 * xc * (1.0 + tanh(inner));
}
```

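Why clamp(-100, 100)? The cubic term grows fast: x = 100 → x³ = 1M, x = 300 → 27M, x = 1000 → 1e9, x = 10000 → 1e12. f32 max is ~3.4e38, so x³ itself only overflows once |x| passes ≈ 7e12. Long before that, 0.044715 · x³ makes the tanh argument so large that tanh returns ±1 and the function is saturated. The clamp is therefore a safety bound rather than an early exit: it keeps every intermediate finite, so no Inf can enter the computation and propagate as NaN. One side effect: because the return line uses xc, the output is capped as well, so gelu(x) returns 100 for any x > 100 instead of ≈ x.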

tanh() is a WGSL built-in; the backend lowers it to the platform's own implementation (Metal's on Apple GPUs).
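A Python sketch contrasting the kernel's clamp with a hypothetical alternative (not what 06_activation.wgsl does) that clamps only the value fed into the cubic/tanh term and keeps the raw x in the output, so large positive inputs pass through as ≈ x. The bound of 10 is an assumption: at |x| = 10 the tanh argument is already about ±43.7, where tanh rounds to exactly ±1.

```python
import math

def gelu_kernel(x: float) -> float:
    """Mirror of the WGSL helper: the input itself is clamped to [-100, 100]."""
    xc = max(-100.0, min(x, 100.0))
    inner = 0.7978845608 * (xc + 0.044715 * xc ** 3)
    return 0.5 * xc * (1.0 + math.tanh(inner))

def gelu_clamp_arg(x: float) -> float:
    """Hypothetical variant: clamp only the tanh term's input, keep raw x outside."""
    xc = max(-10.0, min(x, 10.0))            # tanh(inner) is already exactly ±1 at this bound
    inner = 0.7978845608 * (xc + 0.044715 * xc ** 3)
    return 0.5 * x * (1.0 + math.tanh(inner))

print(gelu_kernel(200.0), gelu_clamp_arg(200.0))
```

Inside the normal ~[-3, 3] range both versions are identical; they only differ once |x| exceeds the smaller bound.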

swiglu_combine

```wgsl
@compute @workgroup_size(256, 1, 1)
fn swiglu_combine(@builtin(global_invocation_id) gid: vec3<u32>,
                  @builtin(num_workgroups) nwg: vec3<u32>) {
    let i = flat_id(gid, nwg);
    if (i >= n) { return; }
    hidden[i] = silu(gate[i]) * up[i];   // fused: 2 reads + 1 write per element
}
```

silu() helper:

```wgsl
fn silu(x: f32) -> f32 {
    var sig: f32;
    if (x >= 0.0) {
        sig = 1.0 / (1.0 + exp(-x));   // exp argument <= 0, so exp() is in (0, 1]
    } else {
        let e = exp(x);                // x < 0, so e is in (0, 1)
        sig = e / (1.0 + e);
    }
    return x * sig;
}
```
  • Branched stabilization: instead of the old approach's clamp(x, -50, 50), the code branches on the sign of x so that the argument of exp() is never positive: exp(-x) when x ≥ 0, exp(x) when x < 0. exp() therefore always returns a value in (0, 1] and cannot overflow to Inf, which rules out the Inf/Inf = NaN case that e / (1 + e) would hit for a large positive x. At the other extreme, exp() of a very negative argument underflows to zero, which is harmless: sig simply goes to 0 or 1 as it should. No clamp bounds are needed, and the asymptote silu(x) → x for large positive x is preserved exactly.

In practice, the SwiGLU input (the gate matmul output) sits roughly in [-3, 3] during normal training. The branch matters for the rare cases outside that range: even when an unusual gradient spike pushes activations to extreme values, silu() stays finite and no NaN leaks into the pipeline.
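To see why the branch matters, here is a Python sketch that emulates float32 exp() overflow (anything above ln(FLT_MAX) ≈ 88.72 becomes +inf; Python itself computes in float64) and compares the single-formula e / (1 + e) version with the sign-branched one from the kernel. The emulation helper is an illustration, not part of the codebase, and it does not model f32 rounding.

```python
import math

F32_LOG_MAX = 88.72  # ≈ ln(FLT_MAX); float32 exp() of anything larger overflows

def exp_f32_like(x: float) -> float:
    """Emulate float32 exp overflow: +inf past ln(FLT_MAX), otherwise math.exp."""
    return math.inf if x > F32_LOG_MAX else math.exp(x)

def silu_naive(x: float) -> float:
    """e / (1 + e) for every x: e overflows for large x and inf / inf is NaN."""
    e = exp_f32_like(x)
    return x * (e / (1.0 + e))

def silu_stable(x: float) -> float:
    """Sign-branched form from the kernel: exp() only ever sees a non-positive argument."""
    if x >= 0.0:
        sig = 1.0 / (1.0 + exp_f32_like(-x))
    else:
        e = exp_f32_like(x)
        sig = e / (1.0 + e)
    return x * sig

print(silu_naive(100.0), silu_stable(100.0), silu_stable(-100.0))
```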


Why Fused swiglu_combine?

Classic PyTorch:

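```python
import torch.nn.functional as F

# gate_proj, up_proj: nn.Linear(d_model, d_ff) modules; x: [seq, d_model]
gate_out = gate_proj(x)      # kernel 1: matmul
up_out = up_proj(x)          # kernel 2: matmul
silu_out = F.silu(gate_out)  # kernel 3: elementwise SiLU
hidden = silu_out * up_out   # kernel 4: elementwise multiply
```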

= 4 separate kernel launches: two matmuls, one SiLU, one multiply. The last two are the elementwise steps that can be fused.

Our swiglu_combine:

hidden[i] = silu(gate[i]) · up[i]

= 1 kernel call for the elementwise part (the two matmuls stay as they are). Memory traffic:

  • Classic: SiLU (1 read + 1 write of a gate-sized buffer) + multiply (2 reads + 1 write) = 5 full-buffer passes
  • Fused: 2 reads (gate, up) + 1 write (hidden) = 3 passes → 40% less memory traffic

Plus one fewer dispatch and its launch overhead.
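The same bandwidth arithmetic, sketched in Python for the 512 × 3072 example from the dispatch section. It counts only global-memory reads and writes of the elementwise step (4 bytes per f32) and ignores caching, so treat it as a first-order estimate.

```python
BYTES_PER_F32 = 4

def elementwise_traffic_bytes(n: int) -> dict:
    """Global-memory bytes moved by the elementwise part of SwiGLU (matmuls excluded)."""
    classic = (2 + 3) * n * BYTES_PER_F32   # SiLU: read+write; multiply: 2 reads + 1 write
    fused = 3 * n * BYTES_PER_F32           # swiglu_combine: read gate, read up, write hidden
    return {"classic": classic, "fused": fused, "saving": 1 - fused / classic}

t = elementwise_traffic_bytes(512 * 3072)   # n = seq × d_ff
print(t)
```

That is roughly 31.5 MB versus 18.9 MB per layer; the ratio (40%) does not depend on n.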


WGSL-Specific Notes

1. tanh() and exp() built-ins

Both are core WGSL built-in functions, so every conforming implementation provides them; the backend maps them to the platform's native math routines (Metal's on Apple GPUs).

2. clamp(x, lo, hi) — saturated arithmetic

The behavior of clamp(NaN, a, b) is unspecified in the WGSL spec. On Apple Metal it's "either NaN or a bound, usually a bound". It doesn't matter for our use (gradient finiteness is already checked in backward).

3. In-place vs separate output

The name gelu_inplace implies in-place — it really does overwrite the input. The old value cannot be recovered. If backward needs it, we have to keep an extra copy buffer in forward (our code does this — the pre_act buffer).

swiglu_combine writes a separate output buffer (hidden). For backward, gate and up stay preserved — no extra save.


Performance

From the profile snapshot:

  • swiglu_combine: 12 layers × ~186 µs = 2.2 ms total = 0.4% of the step

Very small. Element-wise, memory-bound, not a hot kernel.
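A rough sanity check of "memory-bound", assuming (the profile does not say) that the snapshot ran with n = 512 × 3072 elements per layer: the fused kernel moves 3 × n × 4 bytes, and dividing by the ~186 µs per-layer time gives the effective bandwidth it achieves.

```python
def effective_bandwidth_gbps(n: int, seconds: float, passes: int = 3) -> float:
    """Effective bandwidth (GB/s) of an elementwise kernel touching `passes` f32 buffers of n elements."""
    bytes_moved = passes * n * 4
    return bytes_moved / seconds / 1e9

# Assumed shape (seq = 512, d_ff = 3072); ~186 µs per layer from the profile snapshot.
bw = effective_bandwidth_gbps(512 * 3072, 186e-6)
print(f"{bw:.1f} GB/s")
```

Comparing this figure with the GPU's peak memory bandwidth shows how close the kernel sits to the memory roofline.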


Code Review

Finding 1: GeLU clamp ±100 is perhaps unnecessary

| Risk | Explanation |
|---|---|
| 🟢 none | x = 100 → x³ = 1M → 0.044715 × 1M = 44715 → tanh(0.798 × 44815 ≈ 35757) = 1.0 (saturated). The clamp is a safety net and is never hit on the hot path. |

Finding 2: gelu_inplace needs pre-act for backward

| Risk | Explanation |
|---|---|
| 🟢 (architectural) | When backward GeLU computes dx = ∂gelu/∂x · dy, it needs the forward x value. In-place GeLU overwrites x, so backward needs a copy of it, which is held host-side in the pre_act buffer. Extra memory, but unavoidable. |
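A Python sketch of the per-element GeLU backward for the tanh form, with the derivative checked against a central finite difference. It makes the dependency explicit: gelu'(x) is a function of the forward input x, and because GeLU is non-monotone, x cannot in general be recovered from gelu(x) alone. Names are illustrative; the actual backward kernel is not shown in this study.

```python
import math

C = 0.7978845608   # sqrt(2/pi)
A = 0.044715

def gelu(x: float) -> float:
    return 0.5 * x * (1.0 + math.tanh(C * (x + A * x ** 3)))

def gelu_grad(x: float) -> float:
    """d/dx of 0.5 x (1 + tanh(u)), u = C (x + A x^3): product rule plus tanh' = 1 - tanh^2."""
    t = math.tanh(C * (x + A * x ** 3))
    return 0.5 * (1.0 + t) + 0.5 * x * (1.0 - t * t) * C * (1.0 + 3.0 * A * x * x)

def gelu_backward(x: float, dy: float) -> float:
    """dx = gelu'(x) * dy, the per-element quantity backward needs (x = saved pre-activation)."""
    return gelu_grad(x) * dy
```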

Finding 3: SwiGLU has no pre-activation save

| Risk | Explanation |
|---|---|
| 🟢 none | SwiGLU backward needs the SiLU derivative. gate[i] is already a separate buffer and is preserved; up[i] is preserved too. No save problem. |
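The SwiGLU backward, sketched per element in Python. With hidden = silu(g) · u, the chain rule gives dL/du = dh · silu(g) and dL/dg = dh · u · silu'(g), where silu'(g) = σ(g) · (1 + g · (1 - σ(g))). Both only need the preserved gate and up values. Function names are illustrative.

```python
import math

def sigmoid(x: float) -> float:
    """Sign-branched sigmoid, same idea as the kernel's silu()."""
    if x >= 0.0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)

def silu(x: float) -> float:
    return x * sigmoid(x)

def silu_grad(x: float) -> float:
    """d/dx [x * s(x)] = s + x s (1 - s) = s (1 + x (1 - s))."""
    s = sigmoid(x)
    return s * (1.0 + x * (1.0 - s))

def swiglu_backward(gate: float, up: float, dh: float) -> tuple[float, float]:
    """Given dL/dhidden, return (dL/dgate, dL/dup) for hidden = silu(gate) * up."""
    d_gate = dh * up * silu_grad(gate)
    d_up = dh * silu(gate)
    return d_gate, d_up
```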

Quick Checklist

| Test scenario | Status |
|---|---|
| Is gelu(0) = 0? | ✅ formula |
| gelu(very_large) no overflow? | ✅ clamp |
| Is silu(0) = 0? | ✅ formula |
| Is swiglu_combine aliasing-safe? (gate/up/hidden separate buffers) | ✅ runtime check |
| Does n = 0 not crash? | ✅ bounds check: every thread returns immediately |
| Does GeLU match the PyTorch reference? | ⚠ no formal comparison test, but loss curves are reasonable |
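For the ⚠ row, a reference comparison can be as small as the Python sketch below: recompute each element in float64 and check the GPU result against a mixed absolute/relative tolerance. Reading the output buffer back from the GPU is not shown; float32 rounding via struct stands in for it here, and the tolerances (atol = 1e-5, rtol = 1e-4) are assumptions to tune.

```python
import math
import struct

def to_f32(v: float) -> float:
    """Round a Python float (float64) to the nearest float32, as a GPU buffer stores it."""
    return struct.unpack("f", struct.pack("f", v))[0]

def gelu_ref(x: float) -> float:
    """float64 reference for the tanh-form GeLU used by the kernel."""
    return 0.5 * x * (1.0 + math.tanh(0.7978845608 * (x + 0.044715 * x ** 3)))

def max_violation(inputs, outputs, ref, atol=1e-5, rtol=1e-4):
    """Largest |out - ref(x)| minus its allowed tolerance; <= 0 means every element passes."""
    worst = -math.inf
    for x, y in zip(inputs, outputs):
        expected = ref(x)
        worst = max(worst, abs(y - expected) - (atol + rtol * abs(expected)))
    return worst

inputs = [i / 10 for i in range(-60, 61)]
outputs = [to_f32(gelu_ref(to_f32(x))) for x in inputs]   # stand-in for a GPU readback
print(max_violation(inputs, outputs, gelu_ref))
```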

Next

07_loss.md — Cross-entropy + sum_losses. The last step of forward, the start of backward.

WGSL kernel studies · an LLM from scratch on WebGPU. Built in Istanbul by Uğur Toprakdeviren.