/**
 * activation.wgsl — element-wise activations + their fused FFN combiners.
 *
 *   gelu_inplace:   x = gelu(x)                    (tanh approximation)
 *   silu_inplace:   x = x * sigmoid(x)             (a.k.a. swish)
 *   swiglu_combine: hidden = silu(gate) * up       (per-element)
 *
 * NOTE(review): no `silu_inplace` entry point is defined in this file as
 * reviewed — confirm whether it lives elsewhere or this header is stale.
 */

// Numerically stable logistic sigmoid: 1 / (1 + exp(-z)).
// Branches on sign so exp() only ever sees a non-positive argument
// (range (-inf, 0] → exp in (0, 1], no overflow; underflow → 0 is the
// correct limit). No clamp needed.
fn stable_sigmoid(z: f32) -> f32 {
    if (z >= 0.0) {
        return 1.0 / (1.0 + exp(-z));
    }
    let e = exp(z);
    return e / (1.0 + e);
}

fn gelu(x: f32) -> f32 {
    // Tanh approximation:
    //   0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    // evaluated through the identity 0.5 * (1 + tanh(u)) == sigmoid(2u), so
    // the saturating part goes through stable_sigmoid instead of tanh with a
    // huge argument.
    //
    // Do NOT clamp x itself: gelu(x) → x for large positive x, so clamping x
    // to ±100 and multiplying by the clamped value capped the output at 100.
    // Only the cubic's input is clamped, purely to keep x^3 from overflowing.
    // At |x| = 16 the sigmoid argument is already ~±318, where exp() has
    // underflowed to 0 in f32, so the clamp does not change the result.
    let xp = clamp(x, -16.0, 16.0);
    let u = 0.7978845608 * (xp + 0.044715 * xp * xp * xp);
    return x * stable_sigmoid(2.0 * u);
}

fn silu(x: f32) -> f32 {
    // silu(x) = x * sigmoid(x). Stabilize the EXPONENT, not x itself:
    // silu(x) → x for large |x|, so silu(100) must be ~100, which a clamp
    // on x would break. stable_sigmoid handles the exponent safely.
    return x * stable_sigmoid(x);
}

// NOTE(review): in the reviewed text the declarations below carried no
// template arguments (`var x: array;`, `vec3`), which is not valid WGSL.
// `array<f32>` and `vec3<u32>` follow from how the values are used; the
// address spaces (`storage` / `uniform`) are reconstructed — confirm they
// match the host-side bind group layout.

// --- KERNEL: gelu_inplace ---
@group(0) @binding(0) var<storage, read_write> x: array<f32>;
@group(0) @binding(1) var<uniform> n: u32;

// Applies gelu() in place to x[i] for every index i < n.
// `flat_id` is defined outside this file; presumably it flattens the
// invocation id across the dispatch grid — confirm against its definition.
@compute @workgroup_size(256, 1, 1)
fn gelu_inplace(@builtin(global_invocation_id) gid: vec3<u32>,
                @builtin(num_workgroups) nwg: vec3<u32>) {
    let i = flat_id(gid, nwg);
    // Invocations whose index is at or past n do nothing.
    if (i >= n) {
        return;
    }
    x[i] = gelu(x[i]);
}

// --- KERNEL: swiglu_combine ---
// hidden[i] = silu(gate[i]) * up[i]
@group(0) @binding(0) var<storage, read> gate: array<f32>;
@group(0) @binding(1) var<storage, read> up: array<f32>;
@group(0) @binding(2) var<storage, read_write> hidden: array<f32>;
@group(0) @binding(3) var<uniform> n: u32;

// Writes silu(gate[i]) * up[i] into hidden[i] for every index i < n.
@compute @workgroup_size(256, 1, 1)
fn swiglu_combine(@builtin(global_invocation_id) gid: vec3<u32>,
                  @builtin(num_workgroups) nwg: vec3<u32>) {
    let i = flat_id(gid, nwg);
    // Invocations whose index is at or past n do nothing.
    if (i >= n) {
        return;
    }
    hidden[i] = silu(gate[i]) * up[i];
}