00_infrastructure.wgsl — Element-wise Utility Kernels
File: 00_infrastructure.wgsl
Pipeline step: none. These are glue kernels, used between the forward, backward, and optimizer stages.
Wait, what is this?
Picture an array with millions of numbers in it. And you want to do the most boring thing imaginable: zero them all out. Or multiply each one by 2. Or "add this list on top of that one". In JavaScript these are one-liners: arr.fill(0), arr.map(x => x * 2), a[i] += b[i]. The kind of for loops you write without thinking.
This file is exactly those — but on the GPU. The model's heavy lifting (matrix multiply, attention) lives in other files; this is the drawer of little helpers that wire those heavy ops together. fill_zero zeros an array, scale multiplies every element by a number, axpy adds one array (scaled) on top of another, copy copies, clamp pushes overflowing values back inside a bound. They all share the same shape — "apply the same simple thing to every element" — i.e. element-wise.
In a word: these are the spoons and knives of the kitchen. They're not the stars, but no dish reaches the plate without them. Zeroing a gradient buffer every step, updating moments in the Adam optimizer, doing gradient clipping — each one calls one of these tiny kernels. Glue that does little but shows up everywhere.
So why does a one-line loop earn its own file on the GPU? Because spreading "touch every element" across thousands of threads means figuring out which thread looks at which element, and dropping the leftover threads past the end of the array. The little machinery behind the GPU side of something this simple is below.
What Does It Do?
Provides 6 small element-wise kernels — the model's "glue" operations:
| Kernel | Operation | Typical use |
|---|---|---|
| fill_zero | dst[i] = 0 | Zeroing out gradient buffers |
| fill_const | dst[i] = value | Init / debug |
| scale | dst[i] *= alpha | Gradient clipping (grad *= clip_scale) |
| axpy | dst[i] += alpha * src[i] | Residual update, Adam moment update |
| copy | dst[i] = src[i] | Tensor copying |
| clamp_inplace | dst[i] = clamp(dst[i], -max_abs, max_abs) | Overflow guard (NaN inputs are not reliably removed) |
All follow a single pattern: a logically 1D dispatch (folded into 2D only when the workgroup count exceeds the per-dimension limit), WG=256, global index via flat_id, i < n bounds check, no atomic functions.
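As a rough illustration of how these kernels combine, the Adam first-moment update m ← β1·m + (1 − β1)·g can be written as one scale followed by one axpy. The sketch below runs on the CPU in JavaScript; the helper names and the two-step decomposition are assumptions for illustration, not necessarily how the engine sequences its dispatches.

```javascript
// CPU stand-ins for the scale and axpy kernels (illustrative, not the engine's API).
function scale(dst, alpha) {
  for (let i = 0; i < dst.length; i++) dst[i] *= alpha;
}

function axpy(dst, src, alpha) {
  for (let i = 0; i < dst.length; i++) dst[i] += alpha * src[i];
}

// Adam first moment: m = beta1 * m + (1 - beta1) * g, as scale + axpy.
function updateFirstMoment(m, g, beta1) {
  scale(m, beta1);
  axpy(m, g, 1 - beta1);
}

const m = new Float32Array([1, 2, 4]);
const g = new Float32Array([3, 2, 0]);
updateFirstMoment(m, g, 0.5);
console.log(Array.from(m)); // [2, 2, 2]
```

The values are chosen so every intermediate result is exact in f32; real gradients will round.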
Mathematical Definitions
fill_zero: dst[i] = 0 ∀ i ∈ [0, n)
fill_const: dst[i] = c ∀ i ∈ [0, n)
scale: dst[i] ← α · dst[i] ∀ i ∈ [0, n)
axpy: dst[i] ← dst[i] + α · src[i] ∀ i ∈ [0, n) (BLAS-1 axpy)
copy: dst[i] = src[i] ∀ i ∈ [0, n)
clamp_inplace: dst[i] ← max(-c, min(c, dst[i])) ∀ i ∈ [0, n)

The name axpy comes from Level 1 BLAS: "a · x plus y". It is one of the oldest routines in numerical computing.
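For unit-testing GPU results, the six definitions translate directly into CPU loops. The following is a minimal reference sketch in JavaScript; the ref object and its function signatures are hypothetical and exist only for this example.

```javascript
// Hypothetical CPU reference for the six kernels; each loop mirrors one definition.
const ref = {
  fill_zero(dst, n) { for (let i = 0; i < n; i++) dst[i] = 0; },
  fill_const(dst, n, c) { for (let i = 0; i < n; i++) dst[i] = c; },
  scale(dst, n, alpha) { for (let i = 0; i < n; i++) dst[i] = alpha * dst[i]; },
  axpy(dst, src, n, alpha) { for (let i = 0; i < n; i++) dst[i] = dst[i] + alpha * src[i]; },
  copy(dst, src, n) { for (let i = 0; i < n; i++) dst[i] = src[i]; },
  clamp_inplace(dst, n, c) {
    for (let i = 0; i < n; i++) dst[i] = Math.max(-c, Math.min(c, dst[i]));
  },
};

const x = new Float32Array([1, -8, 3, 10]);
const y = new Float32Array([2, 2, 2, 2]);
ref.axpy(x, y, 4, 0.5);      // x = [2, -7, 4, 11]
ref.clamp_inplace(x, 4, 5);  // x = [2, -5, 4, 5]
console.log(Array.from(x));
```

JavaScript evaluates each expression in double precision and rounds to f32 only on the store, so GPU output is best compared against this reference with a small tolerance rather than bit for bit.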
Bind Group ABI
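Every kernel puts its bindings on @group(0) in the same order: dst first, then src (if any), then n, then the scalar parameter (if any). That gives 2–4 bindings per kernel.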
fill_zero (2 bindings)
| Binding | Type | Detail |
|---|---|---|
| 0 | storage, read_write | dst: array<f32> — destination to be written |
| 1 | uniform | n: u32 — element count |
fill_const, scale (3 bindings)
| Binding | Type | Detail |
|---|---|---|
| 0 | storage, read_write | dst |
| 1 | uniform | n: u32 |
| 2 | uniform | value or alpha: f32 |
axpy (4 bindings)
| Binding | Type | Detail |
|---|---|---|
| 0 | storage, read_write | dst |
| 1 | storage, read | src |
| 2 | uniform | n: u32 |
| 3 | uniform | alpha: f32 |
copy (3 bindings)
| Binding | Type | Detail |
|---|---|---|
| 0 | storage, read_write | dst |
| 1 | storage, read | src |
| 2 | uniform | n: u32 |
clamp_inplace (3 bindings)
| Binding | Type | Detail |
|---|---|---|
| 0 | storage, read_write | dst |
| 1 | uniform | n: u32 |
| 2 | uniform | max_abs: f32 |
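On the host, each table above becomes the entries array of a bind group. The sketch below builds that array for any of the six kernels, following the dst, src, n, scalar order. The bindGroupEntries helper and its argument names are hypothetical; the { binding, resource: { buffer } } entry shape is the one WebGPU's createBindGroup takes.

```javascript
// Hypothetical helper: builds bind group entries in the order dst, [src], n, [scalar].
// The arguments are GPUBuffer objects in real code; any object works for this sketch.
function bindGroupEntries({ dst, src, n, scalar }) {
  const buffers = [dst];
  if (src !== undefined) buffers.push(src);       // axpy, copy
  buffers.push(n);                                // every kernel
  if (scalar !== undefined) buffers.push(scalar); // fill_const, scale, axpy, clamp_inplace
  return buffers.map((buffer, binding) => ({ binding, resource: { buffer } }));
}

// axpy layout: 0 = dst, 1 = src, 2 = n, 3 = alpha.
const entries = bindGroupEntries({
  dst: { label: "dst" },
  src: { label: "src" },
  n: { label: "n" },
  scalar: { label: "alpha" },
});
console.log(entries.map((e) => `${e.binding}:${e.resource.buffer.label}`).join(" "));
// 0:dst 1:src 2:n 3:alpha
```

Deriving the binding index from the array position keeps the host side and the tables from drifting apart.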
Dispatch Shape
workgroup_size: 256
total threads: ceil(n / 256) workgroups × 256
threadgroup mem: 0 (none of them use workgroup memory)

Host (engine.dispatch1D or similar):
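// One workgroup per 256 elements, rounded up.
const wgCount = Math.ceil(n / 256);
// Fold into a second dimension when one dimension cannot hold every workgroup.
const dim = wgCount <= MAX_WG_DIM
  ? [wgCount, 1, 1]
  : [MAX_WG_DIM, Math.ceil(wgCount / MAX_WG_DIM), 1];
pass.dispatchWorkgroups(...dim);

flat_id(gid, nwg) linearizes 1D and 2D dispatches the same way, so the kernels do not need to know which one was used. See 00_shared.md.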
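To make the 2D fallback concrete, the sketch below computes the dispatch dimensions and then linearizes a thread coordinate back to a flat index. MAX_WG_DIM is assumed to be 65535 (WebGPU's default maxComputeWorkgroupsPerDimension limit), and the flatId formula is an assumption about what flat_id does; the authoritative definition is in 00_shared.md.

```javascript
const WG_SIZE = 256;
const MAX_WG_DIM = 65535; // assumed: WebGPU's default per-dimension workgroup limit

// Same folding rule as the host snippet above.
function dispatchDims(n) {
  const wgCount = Math.ceil(n / WG_SIZE);
  return wgCount <= MAX_WG_DIM
    ? [wgCount, 1, 1]
    : [MAX_WG_DIM, Math.ceil(wgCount / MAX_WG_DIM), 1];
}

// Assumed linearization: treat each y row as a full stripe of nwg.x * 256 threads.
function flatId(gid, nwg) {
  return gid.y * nwg.x * WG_SIZE + gid.x;
}

console.log(dispatchDims(1027));     // [5, 1, 1]
console.log(dispatchDims(20000000)); // [65535, 2, 1]

const [x, y] = dispatchDims(20000000);
// The first thread of the second row continues right after the last thread of the first.
console.log(flatId({ x: 0, y: 1 }, { x, y })); // 16776960
```

Note that the folded grid can launch many more threads than n, which is one more reason every kernel needs the bounds check.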
Line-by-Line Explanation (axpy — representative)
axpy is the richest example; the others simplify the template.
1) Bind group declarations
@group(0) @binding(0) var<storage, read_write> dst: array<f32>;
@group(0) @binding(1) var<storage, read> src: array<f32>;
@group(0) @binding(2) var<uniform> n: u32;
@group(0) @binding(3) var<uniform> alpha: f32;

- dst and src are separate bindings. The same buffer cannot be bound as both storage, read and storage, read_write in one dispatch, and the shader compiler is free to assume the two do not alias. If dst === src is desired, this kernel cannot be used; a different kernel is needed.
- n and alpha are uniforms: small read-only values that play the role push constants play in other APIs. Because they are read-only, drivers can typically serve them from the GPU's constant cache.
2) Kernel signature
@compute @workgroup_size(256, 1, 1)
fn axpy(@builtin(global_invocation_id) gid: vec3<u32>,
@builtin(num_workgroups) nwg: vec3<u32>) {

- @workgroup_size(256, 1, 1): 256 threads per workgroup, a compile-time constant given as a kernel attribute. Unlike Metal's [[max_total_threads_per_threadgroup(N)]], which is only a hint, in WGSL this is a guarantee.
- global_invocation_id (Metal's thread_position_in_grid): this thread's global coordinate within the entire dispatch grid.
- num_workgroups: the x, y, z values passed to the dispatchWorkgroups(x, y, z) call.
3) Linearize
let i = flat_id(gid, nwg);

The result is the thread's flat 1D index. Details are in 00_shared.md.
4) Bounds check
if (i >= n) { return; }

With WG=256, n is usually not a multiple of 256. For example, with n=1027, 4×256 = 1024 threads are not enough, so 5×256 = 1280 are dispatched and the last 253 threads are out of bounds. This check makes those threads exit before they touch memory.
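The padding arithmetic is easy to check on the host. This small sketch (the paddedThreads helper is hypothetical) reports how many threads a given n dispatches and how many of them the bounds check turns away, assuming the plain 1D case.

```javascript
const WG_SIZE = 256;

// For a 1D dispatch: how many threads launch, and how many fall past the end of the array.
function paddedThreads(n) {
  const workgroups = Math.ceil(n / WG_SIZE);
  const dispatched = workgroups * WG_SIZE;
  return { workgroups, dispatched, idle: dispatched - n };
}

console.log(paddedThreads(1027)); // { workgroups: 5, dispatched: 1280, idle: 253 }
console.log(paddedThreads(1024)); // { workgroups: 4, dispatched: 1024, idle: 0 }
```

In the 1D case the idle count is always below 256, so the wasted work is negligible for large arrays.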
5) Element-wise operation
dst[i] = dst[i] + alpha * src[i];

A single line. dst[i] and src[i] are two separate loads from global memory, * and + are ALU operations, and the result goes out in a single store.
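Memory pattern: adjacent threads read adjacent memory, so access is coalesced. A group of 32 threads reading consecutive f32 values covers one contiguous 128-byte span, which the hardware can serve in very few memory transactions (exact warp and cache-line sizes vary by GPU). This is the bandwidth-friendly pattern for an element-wise kernel.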
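Putting steps 3 to 5 together, the kernel can be emulated on the CPU by running the body once per dispatched thread. This is only a sketch of the execution model in JavaScript: the loop visits the threads one after another, whereas the GPU runs them in parallel, and axpyThread and emulateAxpy are hypothetical names.

```javascript
const WG_SIZE = 256;

// Body of one GPU thread: bounds check, then the single element-wise update.
// Returns true if the thread did work, false if it exited at the bounds check.
function axpyThread(i, dst, src, n, alpha) {
  if (i >= n) return false;
  dst[i] = dst[i] + alpha * src[i];
  return true;
}

// Emulate a 1D dispatch by visiting every thread index in order.
function emulateAxpy(dst, src, n, alpha) {
  const threads = Math.ceil(n / WG_SIZE) * WG_SIZE;
  let active = 0;
  for (let i = 0; i < threads; i++) {
    if (axpyThread(i, dst, src, n, alpha)) active++;
  }
  return { threads, active };
}

const n = 1027;
const dst = new Float32Array(n).fill(1);
const src = new Float32Array(n).fill(2);
console.log(emulateAxpy(dst, src, n, 0.5)); // { threads: 1280, active: 1027 }
console.log(dst[0], dst[n - 1]);            // 2 2
```

Because thread i only ever touches index i, the visiting order does not matter, which is why the sequential loop and the parallel GPU run give the same result.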
Variants — What's the Difference?
fill_zero vs fill_const
fill_zero is a dedicated kernel because filling with 0 is very common (gradient zeroing every step), and dropping the value: f32 uniform binding means one fewer binding → faster bind group setup. Does it matter in practice? Marginal — but a natural optimization for the hot path.
scale vs axpy
axpy requires 2 buffers (dst + src). If you only want dst *= alpha, use scale (1 buffer). Fewer reads = less memory traffic.
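The traffic difference can be put in numbers by counting f32 loads and stores per element. The counts below are read off the kernel formulas; the ACCESSES table and bytesMoved helper are hypothetical, and the figures ignore caching and the tiny uniform reads.

```javascript
// f32 loads and stores per element, read off each kernel's formula.
const ACCESSES = {
  fill_zero:     { loads: 0, stores: 1 },
  fill_const:    { loads: 0, stores: 1 },
  scale:         { loads: 1, stores: 1 },
  axpy:          { loads: 2, stores: 1 },
  copy:          { loads: 1, stores: 1 },
  clamp_inplace: { loads: 1, stores: 1 },
};

// Total bytes moved through global memory for n elements (4 bytes per f32).
function bytesMoved(kernel, n) {
  const { loads, stores } = ACCESSES[kernel];
  return (loads + stores) * 4 * n;
}

const n = 1 << 20; // about one million elements
console.log(bytesMoved("scale", n)); // 8388608
console.log(bytesMoved("axpy", n));  // 12582912
```

For the same n, scale moves two thirds of the bytes that axpy does, and fill_zero moves one third.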
clamp_inplace
Applies a symmetric clamp (between -max_abs and +max_abs). If you need an asymmetric clamp (e.g. ReLU = clamp(x, 0, +Inf)), this kernel does not fit; a separate kernel with independent lower and upper bounds is needed.
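For comparison, an asymmetric clamp only needs independent lower and upper bounds. The sketch below is a CPU version in JavaScript; clampRange is a hypothetical name, and no such kernel is claimed to exist in the file.

```javascript
// Hypothetical asymmetric clamp: lo and hi are independent bounds.
function clampRange(dst, lo, hi) {
  for (let i = 0; i < dst.length; i++) {
    dst[i] = Math.min(hi, Math.max(lo, dst[i]));
  }
}

const sym = new Float32Array([-3, -1, 0, 2, 7]);
clampRange(sym, -2, 2);        // same result as clamp_inplace with max_abs = 2
console.log(Array.from(sym));  // [-2, -1, 0, 2, 2]

const relu = new Float32Array([-3, -1, 0, 2, 7]);
clampRange(relu, 0, Infinity); // ReLU as a one-sided clamp
console.log(Array.from(relu)); // [0, 0, 0, 2, 7]
```

The symmetric kernel is the special case lo = -max_abs, hi = +max_abs, which is why it gets by with a single scalar uniform.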
No Atomics — An Important Detail
The comment at the top of the file:
// All buffers are f32. Atomics are used for multi-writer gradient
// accumulation (CAS-add via bitcast — WGSL has no native atomic<f32>).

But no kernel in this file uses atomics: each thread writes to a single slot (dst[i]), and each slot is touched by exactly one thread, so there is no race.
Atomic functions are used elsewhere:
- attention_backward (pre-split version): CAS-add to dK/dV
- embed_backward: the same token ID at multiple positions, hence atomic
- reduce_norm_sq: atomicAdd to a scalar
This utility file is in the "single-writer" pattern — no atomics needed, faster.
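One way to see the difference is to count how many threads write each output slot. In the sketch below, the element-wise pattern maps thread i to slot i, while an embedding-style scatter maps each position to the slot of its token ID. The token IDs are made-up example data and writersPerSlot is a hypothetical helper.

```javascript
// Count how many threads target each output slot.
// targets[t] is the slot index that thread t writes.
function writersPerSlot(targets, slotCount) {
  const counts = new Array(slotCount).fill(0);
  for (const slot of targets) counts[slot]++;
  return counts;
}

// Element-wise kernel: thread i writes dst[i], so every slot has exactly one writer.
const elementWise = [0, 1, 2, 3];
console.log(writersPerSlot(elementWise, 4)); // [1, 1, 1, 1]

// Embedding backward: one thread per position, each adding into its token's row.
// Token 2 appears three times, so three threads would race on the same row.
const tokenIds = [2, 0, 2, 2]; // made-up example data
console.log(writersPerSlot(tokenIds, 4));    // [1, 0, 3, 0]
```

Any slot with a count above 1 needs atomics or a different decomposition; a count of at most 1 everywhere is what lets this file skip them.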
WGSL-Specific Notes
1. Buffer aliasing forbidden
@group(0) @binding(0) var<storage, read_write> dst: array<f32>;
@group(0) @binding(1) var<storage, read> src: array<f32>;

dst and src must be different buffers. Binding the same GPU buffer to both bindings raises a validation error in the WebGPU runtime.
If you want an "in-place axpy" (x += alpha * x), this kernel cannot be used. Either add a new kernel (scale_in_place_axpy: x *= 1+alpha) or reuse scale with alpha set to 1+α, since x + α·x = (1+α)·x.
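The scale trick rests on the identity x + α·x = (1 + α)·x. Here is a quick CPU check in JavaScript, with hypothetical helper names and values chosen so the arithmetic is exact in f32:

```javascript
// CPU stand-in for the scale kernel.
function scale(dst, alpha) {
  for (let i = 0; i < dst.length; i++) dst[i] *= alpha;
}

// What an aliased axpy (dst === src) would compute, written as a plain loop.
function inPlaceAxpy(x, alpha) {
  for (let i = 0; i < x.length; i++) x[i] += alpha * x[i];
}

const alpha = 0.5;
const a = new Float32Array([1, 2, -4]);
const b = new Float32Array([1, 2, -4]);
inPlaceAxpy(a, alpha);
scale(b, 1 + alpha);
console.log(Array.from(a)); // [1.5, 3, -6]
console.log(Array.from(b)); // [1.5, 3, -6]
```

With arbitrary values the two forms can differ in the last bit, because the rounding happens at different points.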
2. array<f32> without size
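var<storage, read_write> dst: array<f32>;

In WGSL, the length of this storage array is determined at runtime from the size of the bound buffer. Runtime-sized arrays are only allowed in the storage address space; elsewhere, such as workgroup memory, the fixed-length form array<f32, N> is required.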
3. Uniform buffer alignment
@group(0) @binding(2) var<uniform> n: u32;
@group(0) @binding(3) var<uniform> alpha: f32;

Each uniform is a separate binding, so axpy has 4 bindings in total and the host needs a separate uniform buffer (or buffer range) for each scalar. The two scalars could instead be packed into one struct:
struct Params { n: u32, alpha: f32 }
@group(0) @binding(2) var<uniform> params: Params;

Our code's stylistic preference, though, is to keep single-value uniforms separate. The binding count is higher, but the code is more readable.
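If the struct form were used, the host would pack both fields into one 8-byte buffer: n as a u32 at byte offset 0 and alpha as an f32 at byte offset 4. The sketch below shows that packing in JavaScript. packParams is a hypothetical helper, and the resulting ArrayBuffer is the kind of data one would pass to queue.writeBuffer.

```javascript
// Hypothetical helper: packs struct Params { n: u32, alpha: f32 } for upload.
// The two 4-byte scalars sit back to back: n at offset 0, alpha at offset 4.
function packParams(n, alpha) {
  const bytes = new ArrayBuffer(8);
  const view = new DataView(bytes);
  view.setUint32(0, n, true);      // true = little-endian
  view.setFloat32(4, alpha, true);
  return bytes;
}

const packed = packParams(1027, 0.5);
const check = new DataView(packed);
console.log(packed.byteLength);         // 8
console.log(check.getUint32(0, true));  // 1027
console.log(check.getFloat32(4, true)); // 0.5
```

One packed buffer means one upload and one binding instead of two, at the cost of the params.n / params.alpha indirection in the shader.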
4. i32 vs u32 index
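WGSL accepts either i32 or u32 as an array index, but it never converts between the two implicitly. Metal, being C++-based, converts integer types silently, so the question rarely comes up there. Here global_invocation_id and n are both u32, so i should stay u32 too; otherwise casts like i32(i) start appearing in the bounds check and the index expression.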
Code Review
Finding 1: No atomics, but the name/comment implies atomics
| Risk | Explanation |
|---|---|
| 🟢 none | The file comment says "Atomics are used for multi-writer gradient accumulation" but there are no atomics in this file. The comment is a top-of-file annotation that describes the pipeline in general, not these infrastructure kernels. It can mislead a first-time reader, but it is not a problem in practice. |
Finding 2: Why are scale and clamp_inplace limited to f32?
| Risk | Explanation |
|---|---|
| 🟡 maybe later | In mixed precision, scale would fail if called on an f16 weight. Right now scale is only used in gradient clipping (gradients are always f32). But it's not future-proof. |
Mitigation: If f16 grad is ever needed, a scale_w16 variant gets added. Not needed for now.
Finding 3: fill_zero vs fill_const(0)
| Risk | Explanation |
|---|---|
| 🟢 none | Two separate kernels — fill_const(0) would also work, but there would be a redundant value binding in the bind group. On the hot path fill_zero should be preferred, and the code is written that way. |
Quick Checklist
| Test Scenario | Status |
|---|---|
| Does it work correctly for an n that is not a multiple of 256? | ✅ bounds check present |
| Does it avoid crashing for n=0? | ✅ wgCount is 0, so no workgroup is dispatched; any thread that did run would see i >= 0, which is always true, and return immediately (correct no-op) |
| What happens if axpy aliases the same buffer? | ✅ raises a runtime error (validation at the engine level) |
| What does clamp_inplace do with NaN? | ⚠ WGSL does not guarantee a finite result for clamp(NaN, -c, c), so the value may stay NaN. That may be undesirable, but this kernel is not a NaN guard anyway |
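The NaN row is easy to reproduce on the CPU, with one caveat: this shows JavaScript's Math.min and Math.max, which propagate NaN, while WGSL's clamp may behave differently on a given GPU. The second function is a hypothetical guard, not something the file provides, that replaces NaN before clamping.

```javascript
// Symmetric clamp written the same way as the definition: max(-c, min(c, x)).
function clampSym(x, c) {
  return Math.max(-c, Math.min(c, x));
}

// Hypothetical NaN-aware variant: map NaN to 0 first, then clamp.
function clampSymNoNaN(x, c) {
  return clampSym(Number.isNaN(x) ? 0 : x, c);
}

console.log(clampSym(1e30, 5));      // 5
console.log(clampSym(-Infinity, 5)); // -5
console.log(clampSym(NaN, 5));       // NaN
console.log(clampSymNoNaN(NaN, 5));  // 0
```

Infinities are pulled back inside the bound, but NaN passes straight through the plain clamp, so a real NaN guard needs an explicit check like the one in the second function.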
Next
01_embedding.md — extracting the actual vector from a token ID. The model's first forward step.