/**
 * backward_ffn.wgsl — FFN element-wise backward.
 *
 * ffn_gelu_backward:
 *   grad_pre[i]  = grad_hidden[i] * gelu'(pre[i])
 *
 * ffn_swiglu_backward:
 *   grad_gate[i] = grad_hidden[i] * up[i] * silu'(gate[i])
 *   grad_up[i]   = grad_hidden[i] * silu(gate[i])
 *
 * Both kernels are 1D element-wise — no GEMM here. The `matmul` /
 * `matmul_t` / `matmul_at` kernels in linear.wgsl handle the surrounding
 * weight-and-input gradients.
 */

// Derivative of the tanh-approximated GELU:
//   gelu(x)  = 0.5 * x * (1 + tanh(inner)),  inner = sqrt(2/pi) * (x + 0.044715 * x^3)
//   gelu'(x) = 0.5 * (1 + tanh(inner)) + 0.5 * x * sech^2(inner) * d(inner)/dx
//
// Two clamps keep the result finite:
//  * x is clamped to [-100, 100] so x^3 cannot overflow f32; an infinite
//    d_inner multiplied by a zero sech^2 would otherwise produce NaN.
//  * inner is clamped to [-20, 20] before tanh. WGSL only specifies tanh's
//    accuracy as inherited from sinh(x)/cosh(x), so a backend that evaluates
//    it via exp() can overflow to inf/inf = NaN for large |inner| — and with
//    |x| up to 100, inner reaches ~3.6e4. tanh(±20) is already exactly ±1.0
//    in f32, so the clamp changes nothing where tanh itself is exact.
fn gelu_derivative(x_in: f32) -> f32 {
    let x = clamp(x_in, -100.0, 100.0);
    let sqrt_2_over_pi = 0.7978845608;
    let coeff = 0.044715;
    let x2 = x * x;
    let x3 = x2 * x;
    let inner = clamp(sqrt_2_over_pi * (x + coeff * x3), -20.0, 20.0);
    let t = tanh(inner);
    let sech2 = 1.0 - t * t;
    let d_inner = sqrt_2_over_pi * (1.0 + 3.0 * coeff * x2);
    return 0.5 * (1.0 + t) + 0.5 * x * sech2 * d_inner;
}

// Logistic sigmoid 1 / (1 + exp(-x)).
// Stable form — branch on sign so exp() only sees a non-positive argument
// (can't overflow; underflow → 0 is the correct limit). No clamp needed.
// Matches the forward `silu` idiom in 06_activation.wgsl.
fn sigmoid(x: f32) -> f32 {
    if (x >= 0.0) {
        return 1.0 / (1.0 + exp(-x));
    }
    let e = exp(x);
    return e / (1.0 + e);
}

// silu(x) = x * sigmoid(x).
fn silu(x: f32) -> f32 {
    return x * sigmoid(x);
}

// silu'(x) = s + x * s * (1 - s) = s * (1 + x * (1 - s)),  with s = sigmoid(x).
fn silu_derivative(x: f32) -> f32 {
    let s = sigmoid(x);
    return s * (1.0 + x * (1.0 - s));
}

// --- KERNEL: ffn_gelu_backward ---
// grad_pre[i] = grad_hidden[i] * gelu'(pre[i])
// One invocation per element; invocations whose flat index is >= n return
// without writing. `flat_id` and `nan_guard` are defined outside this file
// section; every stored value is passed through `nan_guard`.
@group(0) @binding(0) var<storage, read> grad_hidden: array<f32>;
@group(0) @binding(1) var<storage, read> pre_gelu: array<f32>;
@group(0) @binding(2) var<storage, read_write> grad_pre: array<f32>;
@group(0) @binding(3) var<uniform> n: u32;

@compute @workgroup_size(256, 1, 1)
fn ffn_gelu_backward(@builtin(global_invocation_id) gid: vec3<u32>,
                     @builtin(num_workgroups) nwg: vec3<u32>) {
    let i = flat_id(gid, nwg);
    if (i >= n) {
        return;
    }
    let g = grad_hidden[i];
    let pg = pre_gelu[i];
    let r = g * gelu_derivative(pg);
    grad_pre[i] = nan_guard(r);
}

// --- KERNEL: ffn_swiglu_backward ---
// grad_gate[i] = grad_hidden[i] * up[i] * silu'(gate[i])
// grad_up[i]   = grad_hidden[i] * silu(gate[i])
// i.e. the backward of hidden = silu(gate) * up. One invocation per element;
// invocations whose flat index is >= n return without writing. Both outputs
// are passed through `nan_guard` (defined outside this file section).
@group(0) @binding(0) var<storage, read> grad_hidden: array<f32>;
@group(0) @binding(1) var<storage, read> gate: array<f32>;
@group(0) @binding(2) var<storage, read> up: array<f32>;
@group(0) @binding(3) var<storage, read_write> grad_gate: array<f32>;
@group(0) @binding(4) var<storage, read_write> grad_up: array<f32>;
@group(0) @binding(5) var<uniform> n: u32;

@compute @workgroup_size(256, 1, 1)
fn ffn_swiglu_backward(@builtin(global_invocation_id) gid: vec3<u32>,
                       @builtin(num_workgroups) nwg: vec3<u32>) {
    let i = flat_id(gid, nwg);
    if (i >= n) {
        return;
    }
    let g = grad_hidden[i];
    let gt = gate[i];
    let u = up[i];
    let s = silu(gt);
    let dg = g * u * silu_derivative(gt);
    let du = g * s;
    grad_gate[i] = nan_guard(dg);
    grad_up[i] = nan_guard(du);
}