llm.istanbul·Study

matmul and matmul_residual — Tiled Matrix Multiply (Forward)

File: 03_linear.wgsl (forward kernels)

Pipeline step: nearly every layer. The attention Q/K/V/O projections, the FFN gate/up/down projections and lm_head all run here; they are the most expensive operations in the model.

Forward kernels:

  • matmul: Y = X @ W (with double-buffering and vec4 loads)
  • matmul_residual: Y = X @ W + R (residual fused)
  • matmul_residual_swiglu_a: Y = (silu(GATE) * UP) @ W + R (SwiGLU fused into the FFN down-projection matmul)
  • matmul_w16: f16-weight version of matmul (mixed precision)
  • matmul_residual_w16: f16-weight version of matmul_residual
  • matmul_residual_swiglu_a_w16: fused FFN forward with f16 weights
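A slow, obviously correct CPU reference is handy for checking GPU output. Below is a minimal TypeScript sketch. It assumes row-major Float32Array buffers with the same [M, K], [K, N] and [M, N] layouts the bindings use. `matmulRef` is a hypothetical test helper, not part of 03_linear.wgsl.

```typescript
// Naive CPU reference (hypothetical helper) for matmul and matmul_residual.
// X: [M, K], W: [K, N], optional R: [M, N], all row-major. Returns Y: [M, N].
function matmulRef(
  X: Float32Array,
  W: Float32Array,
  M: number,
  N: number,
  K: number,
  R?: Float32Array,
): Float32Array {
  const Y = new Float32Array(M * N);
  for (let m = 0; m < M; m++) {
    for (let n = 0; n < N; n++) {
      let acc = 0;
      for (let k = 0; k < K; k++) acc += X[m * K + k] * W[k * N + n];
      // matmul_residual adds R in the same pass instead of launching a separate kernel.
      Y[m * N + n] = R ? acc + R[m * N + n] : acc;
    }
  }
  return Y;
}

// 2x2 example: [[1, 2], [3, 4]] @ [[5, 6], [7, 8]] = [[19, 22], [43, 50]]
console.log(Array.from(matmulRef(
  new Float32Array([1, 2, 3, 4]), new Float32Array([5, 6, 7, 8]), 2, 2, 2,
)));
```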

The backward kernels (matmul_t, matmul_t_acc, matmul_at, matmul_at_acc, matmul_at_swiglu_a, matmul_at_acc_swiglu_a) are examined in 11_backward_linear.md.


Wait, what is this?

Picture a restaurant kitchen. For every plate (an output row), you take the ingredients you've got on hand (one row of the input matrix X) and match them against a whole stack of recipes (the columns of the weight matrix W), weighing and mixing as you go: "this much of that, that much of this," summed up for each recipe. That's exactly what a matrix multiply is: thousands of tiny "weigh-and-sum" operations (FMA, i.e. fused multiply-add) happening at once. There are so many of them that a transformer spends more than half its time right here; it's the busiest kitchen in the whole model.

And the bottleneck is moving ingredients around. A cook's actual job is cooking, but in a crowded kitchen most of the time goes to walking to the pantry and hauling stuff back. The GPU is the same: the multiply-add math is cheap, and the expensive part is pulling data out of memory. So instead of carrying one item at a time, we move it by the crate; that's the vec4 load, grabbing four numbers in a single trip. We also chop the work into 64×64 "counter sections" (tiles) and stage only the slice of ingredients each counter needs right now on the fast shared shelf (shared memory).

Here's the slick part: while one cook is cooking with what's on the counter, an assistant has already fetched the next batch from the pantry and laid it out on the empty counter beside them. The moment the current batch is done, there's no waiting — the next one is ready. That's what "double-buffering" means: two counters, one cooking while the other is being prepped, so the memory wait hides behind the math.

One bonus on top: in the FFN layer we fold two separate jobs (the SwiGLU mix + the following multiply) into a single kitchen. Normally you'd dump the intermediate mix into a bowl and pick it back up again; instead we mix it on the fly as it lands on the counter and feed it straight into the multiply. The bowl never gets dirty, and a mountain of memory traffic vanishes. Below you'll walk through how all of this is set up line by line, offsets and barriers and all.


What Does It Do?

Classic dense matrix multiply:

Y[M, N] = X[M, K] @ W[K, N]

In the forward pass:

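  • M = seq_len (token count, 512)
  • K = input dim: d_model (768) for the attention projections, FFN gate/up and lm_head; d_ff (3072) for the FFN down projection
  • N = output dim: 768 for attention Q/K/V/O and FFN down, 3072 for FFN gate/up, 16384 for lm_head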
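To see why these projections dominate, count FLOPs. Each output element of Y[M, N] needs K multiply-adds, so one matmul costs 2·M·N·K floating-point operations. Below is a back-of-the-envelope TypeScript sketch using the shapes above: M = 512, d_model = 768, d_ff = 3072, vocabulary 16384. The FFN down projection maps d_ff back to d_model. `matmulFlops` is a hypothetical helper.

```typescript
// FLOPs of one dense matmul Y[M,N] = X[M,K] @ W[K,N]: one multiply + one add per (m, n, k).
function matmulFlops(M: number, N: number, K: number): number {
  return 2 * M * N * K;
}

const M = 512, dModel = 768, dFF = 3072, vocab = 16384;
const rows: Array<[string, number]> = [
  ["attn Q/K/V/O (each)", matmulFlops(M, dModel, dModel)],
  ["FFN gate/up (each)", matmulFlops(M, dFF, dModel)],
  ["FFN down", matmulFlops(M, dModel, dFF)],
  ["lm_head", matmulFlops(M, vocab, dModel)],
];
for (const [label, flops] of rows) {
  console.log(`${label}: ${(flops / 1e9).toFixed(2)} GFLOP`);
}
```

With these shapes, one layer's seven projections add up to about 9.66 GFLOP, while lm_head alone is about 12.88 GFLOP.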

matmul_residual additionally adds a residual R[M, N] in the epilogue (Y = X @ W + R). This fuses the layer's residual connection into the same dispatch.

matmul_residual_swiglu_a fuses two consecutive steps of the FFN layer: the element-wise SwiGLU combination silu(GATE) * UP, and the W_down matrix multiply that follows it (hidden @ W_down + R). The hidden intermediate matrix is never written to memory. This saves a dispatch and a substantial amount of memory bandwidth.
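The fused kernel computes the same result as running SwiGLU and then matmul_residual, but each hidden value is produced at the moment it is consumed. Below is a CPU sketch of that data flow. It assumes row-major f32 buffers: GATE and UP are [M, d_ff], W_down is [d_ff, N], and R and Y are [M, N]. `silu` and `swigluDownRef` are hypothetical helpers, with silu(x) = x · sigmoid(x).

```typescript
// silu(x) = x * sigmoid(x) = x / (1 + e^-x)
function silu(x: number): number {
  return x / (1 + Math.exp(-x));
}

// Y = (silu(GATE) * UP) @ Wdown + R without materializing the [M, dFF] hidden matrix:
// every hidden value is formed right where the matmul consumes it.
function swigluDownRef(
  GATE: Float32Array, UP: Float32Array, Wdown: Float32Array, R: Float32Array,
  M: number, N: number, dFF: number,
): Float32Array {
  const Y = new Float32Array(M * N);
  for (let m = 0; m < M; m++) {
    for (let n = 0; n < N; n++) {
      let acc = 0;
      for (let k = 0; k < dFF; k++) {
        const h = silu(GATE[m * dFF + k]) * UP[m * dFF + k]; // on-the-fly hidden value
        acc += h * Wdown[k * N + n];
      }
      Y[m * N + n] = acc + R[m * N + n];
    }
  }
  return Y;
}
```

This scalar version recomputes silu once per output column. The GPU kernel instead computes each hidden value once per tile load and shares it through tileA_db with all 64 output columns of that workgroup's tile.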

This kernel group is the hottest path in the pipeline: more than 50% of the runtime is spent here.


Algorithm — 64×64 Output Tile, 4×4 Sub-tile per Thread

Tile Hierarchy

Output Y[M, N]
  ↓ 64×64 tiles
  Each workgroup writes ONE 64×64 tile
    ↓ 16×16 thread workgroup
    Each thread writes 4×4 sub-tile (16 register accumulators)
      ↓ K-dim tiling: TK=16 elements per inner loop iteration
      Cooperative tile-load 64×16 (A) and 16×64 (B) into shared memory
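The hierarchy can be checked mechanically. Thread (tx, ty) of workgroup (wgid.x, wgid.y) owns rows block_row + 4·ty + i and columns block_col + 4·tx + j of Y, which is the same indexing as the acc{i}{j} accumulators in the loop further down. Below is a TypeScript sketch that confirms the 256 threads cover a 64×64 tile exactly once. The helper names are hypothetical.

```typescript
const TM = 64, TN = 64, WG = 16, SUB = 4; // tile size, threads per side, sub-tile size

// Output coordinates [row, col] owned by thread (tx, ty) of workgroup (wgx, wgy).
function ownedElements(wgx: number, wgy: number, tx: number, ty: number): Array<[number, number]> {
  const out: Array<[number, number]> = [];
  for (let i = 0; i < SUB; i++) {
    for (let j = 0; j < SUB; j++) {
      out.push([wgy * TM + SUB * ty + i, wgx * TN + SUB * tx + j]);
    }
  }
  return out;
}

// True if every element of the workgroup's 64x64 tile is owned by exactly one thread.
function coversTileOnce(wgx: number, wgy: number): boolean {
  const hits = new Uint8Array(TM * TN);
  for (let ty = 0; ty < WG; ty++) {
    for (let tx = 0; tx < WG; tx++) {
      for (const [r, c] of ownedElements(wgx, wgy, tx, ty)) {
        const lr = r - wgy * TM, lc = c - wgx * TN; // tile-local coordinates
        if (lr < 0 || lr >= TM || lc < 0 || lc >= TN) return false;
        hits[lr * TN + lc]++;
      }
    }
  }
  for (let i = 0; i < hits.length; i++) if (hits[i] !== 1) return false;
  return true;
}

console.log(coversTileOnce(0, 0), coversTileOnce(3, 5)); // true true
```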

Numbers

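| Parameter | Value | Description |
|---|---|---|
| TM | 64 | Output tile rows |
| TN | 64 | Output tile columns |
| TK | 16 | K-dim block per inner-loop iteration |
| TK_PAD | 17 | Row stride of tileA (bank-conflict avoidance) |
| Workgroup | 16 × 16 = 256 threads | |
| Normal mode (single-buffer) | | |
| Tile A | 64 × 17 × 4 B = 4,352 B (4.25 KiB) | f32 |
| Tile B | 16 × 64 × 4 B = 4,096 B (4 KiB) | f32 |
| Double-buffered mode (double-buffered matmul) | | |
| Tile A_db | 2 × 64 × 17 × 4 B = 8,704 B (8.5 KiB) | f32, double-buffered |
| Tile B_db | 2 × 16 × 64 × 4 B = 8,192 B (8 KiB) | f32, double-buffered |
| Total workgroup memory (db) | 16,896 B (16.5 KiB) | Above WebGPU's default 16,384 B limit, so a higher limit must be requested; fits within the Apple GPU maximum of 32 KB |
| Per-thread accumulators | 16 f32 registers | 4×4 sub-tile |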
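The byte counts in the table follow directly from the tile shapes. Below is a quick TypeScript check. The 16,384-byte constant is WebGPU's default maxComputeWorkgroupStorageSize, and the double-buffered layout slightly exceeds it.

```typescript
const TM = 64, TN = 64, TK = 16, TK_PAD = 17, F32_BYTES = 4;

const tileABytes = TM * TK_PAD * F32_BYTES;                // 64 rows, padded stride 17
const tileBBytes = TK * TN * F32_BYTES;                    // 16 x 64, no padding
const doubleBufferedBytes = 2 * tileABytes + 2 * tileBBytes;

const WEBGPU_DEFAULT_WG_STORAGE = 16384; // default maxComputeWorkgroupStorageSize, in bytes

console.log(tileABytes, tileBBytes, doubleBufferedBytes);  // 4352 4096 16896
console.log(doubleBufferedBytes / 1024, "KiB; over default:",
  doubleBufferedBytes > WEBGPU_DEFAULT_WG_STORAGE);         // 16.5 KiB; over default: true
```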

Key Optimizations

1. Vec4 Tile Loads (16-Byte Vectorized Access)

The X and W inputs are bound as array<vec4<f32>>; the _w16 variants bind their f16 weights as array<vec4<f16>>. Each thread fetches four consecutive elements with one vectorized load (16 bytes for f32, 8 bytes for f16) instead of four separate scalar loads.

  • Constraint: the host side must guarantee K % 4 == 0 and N % 4 == 0. This ensures every row of X (K elements) and of W (N elements) starts on a vec4 boundary.
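Binding a row-major matrix as array<vec4<f32>> changes how an element is addressed. Element (r, c) of a [rows, cols] matrix sits in vec4 number r · cols/4 + ⌊c/4⌋, at lane c mod 4 (x, y, z, w). This is what X[axr * K4 + axc / 4u] computes in the kernel, and it is only valid when cols is a multiple of 4. Below is a TypeScript sketch of the indexing and a host-side guard. Both function names are hypothetical.

```typescript
// Where element (r, c) of a row-major [rows, cols] matrix lives when the buffer is
// bound as array<vec4<f32>>: vec4 index r * (cols / 4) + floor(c / 4), lane c % 4
// (lanes 0..3 = .x, .y, .z, .w). The kernel's X[axr * K4 + axc / 4u] is this formula.
function vec4Slot(r: number, c: number, cols: number): { index: number; lane: number } {
  if (cols % 4 !== 0) throw new Error(`cols = ${cols} is not a multiple of 4`);
  return { index: r * (cols / 4) + Math.floor(c / 4), lane: c % 4 };
}

// Host-side guard before dispatch: X is [M, K] and W is [K, N], both read as vec4 rows.
function assertVec4Compatible(K: number, N: number): void {
  if (K % 4 !== 0 || N % 4 !== 0) {
    throw new Error(`vec4 matmul needs K % 4 == 0 and N % 4 == 0 (got K=${K}, N=${N})`);
  }
}

assertVec4Compatible(768, 3072); // d_model and d_ff: fine
console.log(vec4Slot(2, 5, 8));  // { index: 5, lane: 1 }
```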

2. Double-Buffering

In the matmul and matmul_w16 kernels, two double-sized workgroup buffers named tileA_db and tileB_db are declared.

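  • Working principle: at each step t of the loop, the threads first issue the loads for tile t + 1 into the idle half of the buffers (the prefetch). They then run the 4×4 FMA loop on tile t from the other half. WGSL has no asynchronous copy instruction, so any overlap comes from the GPU scheduling other threads' FMAs while these loads are in flight.
  • Gain: load latency can overlap with computation, and each K-tile needs one workgroupBarrier() instead of two. A single-buffered loop needs one barrier after loading and another before the tile is overwritten.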
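Tabulating the parity rule shows why one barrier per iteration is enough. In iteration t, threads read half t & 1 and write half 1 - (t & 1), so no half is read and written in the same iteration. The half that iteration t prefetches is exactly the half iteration t + 1 reads after the end-of-iteration barrier. Below is a TypeScript sketch of that schedule. The helper names are hypothetical.

```typescript
type Step = { t: number; readHalf: number; writeHalf: number | null };

// Which half of tileA_db/tileB_db iteration t reads (cur) and fills for t + 1 (nxt).
function schedule(nTiles: number): Step[] {
  const steps: Step[] = [];
  for (let t = 0; t < nTiles; t++) {
    const parity = t & 1;
    steps.push({ t, readHalf: parity, writeHalf: t + 1 < nTiles ? 1 - parity : null });
  }
  return steps;
}

// Conditions under which a single workgroupBarrier() at the end of each iteration is race-free.
function oneBarrierSuffices(steps: Step[]): boolean {
  for (let i = 0; i < steps.length; i++) {
    const s = steps[i];
    if (s.writeHalf === null) continue;                      // last tile: nothing to prefetch
    if (s.writeHalf === s.readHalf) return false;            // would overwrite data still being read
    if (steps[i + 1].readHalf !== s.writeHalf) return false; // next iteration must read the prefetch
  }
  return true;
}

const nTiles = Math.ceil(768 / 16); // K = d_model = 768, TK = 16 -> 48 tiles
console.log(oneBarrierSuffices(schedule(nTiles)));                       // true
console.log("loop barriers:", nTiles, "vs single-buffered:", 2 * nTiles); // 48 vs 96
```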

3. SwiGLU Forward Fusion

In the matmul_residual_swiglu_a and matmul_residual_swiglu_a_w16 kernels, silu(gate) * up is computed on the fly as the A tile is loaded. The result is written straight into tileA_db.

  • Gain: the FFN layer's intermediate activation matrix ([seq_len, d_ff]) is never written to memory or read back. With seq_len = 512, d_ff = 3072 and f32 activations, that is 6 MiB written and 6 MiB read per layer in the forward pass.
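The saving can be put in numbers with a simple traffic model. Unfused, the hidden matrix is written once by a SwiGLU kernel and read once by the next matmul. Fused, neither happens, while GATE and UP are read either way. Below is a TypeScript sketch assuming f32 activations and the seq_len = 512, d_ff = 3072 shapes from above. It counts each full-matrix pass once and ignores cache effects. `hiddenTrafficSaved` is a hypothetical helper.

```typescript
// Simplified traffic model (an assumption): each full-matrix pass counted once, caches ignored.
// Unfused: a SwiGLU kernel writes hidden, then the W_down matmul reads it back.
// Fused: hidden only ever lives in tileA_db, so both passes disappear.
function hiddenTrafficSaved(seqLen: number, dFF: number, bytesPerElem: number): number {
  const hiddenBytes = seqLen * dFF * bytesPerElem;
  return hiddenBytes /* write */ + hiddenBytes /* read back */;
}

const savedPerLayer = hiddenTrafficSaved(512, 3072, 4); // f32 activations
console.log(savedPerLayer / (1024 * 1024), "MiB per layer per forward pass"); // 12 MiB per layer per forward pass
```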

Bind Group ABI

matmul (4 bindings)

| Binding | Type | Detail |
|---|---|---|
| 0 | storage, read | X: `array<vec4<f32>>`, [M × K/4], row-major |
| 1 | storage, read | W: `array<vec4<f32>>`, [K × N/4], row-major |
| 2 | storage, read_write | Y: `array<f32>`, [M × N] |
| 3 | uniform | dims: `vec4<u32>(M, N, K, _)` |

matmul_residual (5 bindings)

Same as above + R: array<f32> [M × N] residual input.

matmul_residual_swiglu_a (6 bindings)

| Binding | Type | Detail |
|---|---|---|
| 0 | storage, read | GATE_mrs: `array<vec4<f32>>`, FFN gate projection output |
| 1 | storage, read | UP_mrs: `array<vec4<f32>>`, FFN up projection output |
| 2 | storage, read | W_mrs: `array<vec4<f32>>`, W_down weight matrix |
| 3 | storage, read | R_mrs: `array<f32>`, layer input residual |
| 4 | storage, read_write | Y_mrs: `array<f32>`, final layer output |
| 5 | uniform | dims_mrs: `vec4<u32>(M, N, K, _)`, where K = d_ff |
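On the host, the ABI tables translate into bind group layout entries. Below is a TypeScript sketch that builds them as plain objects, so it runs without a GPU. The numeric 4 is the value of GPUShaderStage.COMPUTE, and the helper names are hypothetical. matmul_residual is left out because the table does not state which binding index R takes.

```typescript
// Numeric value of GPUShaderStage.COMPUTE (VERTEX = 1, FRAGMENT = 2, COMPUTE = 4).
const COMPUTE_STAGE = 4;

type BufferBindingType = "read-only-storage" | "storage" | "uniform";

// One layout entry per binding, in binding order, all visible to the compute stage.
function layoutEntries(types: BufferBindingType[]) {
  return types.map((type, binding) => ({ binding, visibility: COMPUTE_STAGE, buffer: { type } }));
}

// matmul: X (read), W (read), Y (read_write), dims (uniform)
const matmulLayout = layoutEntries(["read-only-storage", "read-only-storage", "storage", "uniform"]);

// matmul_residual_swiglu_a: GATE, UP, W_down, R (read), Y (read_write), dims (uniform)
const swigluLayout = layoutEntries([
  "read-only-storage", "read-only-storage", "read-only-storage", "read-only-storage",
  "storage", "uniform",
]);

// dims: vec4<u32>(M, N, K, _) as a 16-byte uniform; the unused fourth lane is 0.
function packDims(M: number, N: number, K: number): Uint32Array {
  return new Uint32Array([M, N, K, 0]);
}

console.log(matmulLayout.length, swigluLayout.length, packDims(512, 768, 768).byteLength); // 4 6 16
```

On a real device these entries would be passed to device.createBindGroupLayout({ entries: matmulLayout }). The dims buffer would be filled with device.queue.writeBuffer(dimsBuffer, 0, packDims(M, N, K)), where dimsBuffer is a hypothetical uniform GPUBuffer.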

Dispatch Shape

workgroup_size: (16, 16, 1) → 256 threads
grid: (ceil(N/64), ceil(M/64), 1) workgroups
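The grid maps x to output columns and y to output rows, matching block_col = wgid.x * TN and block_row = wgid.y * TM in the kernel. Below is a TypeScript sketch. `dispatchGrid` is a hypothetical helper.

```typescript
const TILE = 64; // TM = TN = 64

// One workgroup per 64x64 output tile: x walks columns (N), y walks rows (M).
function dispatchGrid(M: number, N: number): [number, number, number] {
  return [Math.ceil(N / TILE), Math.ceil(M / TILE), 1];
}

console.log(dispatchGrid(512, 768));   // [ 12, 8, 1 ]  attention projections
console.log(dispatchGrid(512, 16384)); // [ 256, 8, 1 ] lm_head
```

On a GPUComputePassEncoder this becomes pass.dispatchWorkgroups(...dispatchGrid(M, N)).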

Line by Line — Double-Buffered matmul

The critical sections of the double-buffered, vec4-loaded matmul kernel:

1) Entry and Load Structure

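```wgsl
@compute @workgroup_size(16, 16, 1)
fn matmul(@builtin(workgroup_id) wgid: vec3<u32>,
          @builtin(local_invocation_id) lid: vec3<u32>) {
    let M = dims.x; let N = dims.y; let K = dims.z;
    let K4 = K / 4u;                  // row length of X in vec4 units
    let N4 = N / 4u;                  // row length of W in vec4 units
    let tx = lid.x; let ty = lid.y;
    let tid = ty * 16u + tx;          // linear thread index, 0..255

    let block_row = wgid.y * TM;      // first output row of this workgroup's tile
    let block_col = wgid.x * TN;      // first output column of this workgroup's tile
    let nTiles = (K + TK - 1u) / TK;  // number of TK-wide K blocks walked by the loop

    // 4x4 register sub-tile: accIJ accumulates Y[block_row + 4*ty + I, block_col + 4*tx + J]
    var acc00 = 0.0; var acc01 = 0.0; var acc02 = 0.0; var acc03 = 0.0;
    var acc10 = 0.0; var acc11 = 0.0; var acc12 = 0.0; var acc13 = 0.0;
    var acc20 = 0.0; var acc21 = 0.0; var acc22 = 0.0; var acc23 = 0.0;
    var acc30 = 0.0; var acc31 = 0.0; var acc32 = 0.0; var acc33 = 0.0;
```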
2) Vectorized A / B Tile Initial Load (Prologue)

Before entering the loop, the data for the first K-tile (t = 0) is loaded into half 0 of the buffers with 16-byte vec4 reads:

```wgsl
    {
        // Each of the 256 threads stages 4 consecutive floats of A and 4 of B:
        // 256 x 4 = 1024 = 64 x 16 (tile A) = 16 x 64 (tile B). Tile t = 0 goes into half 0.
        let aI0 = tid * 4u;
        let aIm = aI0 / TK; let aIk = aI0 % TK;      // row within the A tile, first k column
        let axr = block_row + aIm; let axc = aIk;
        let row_in = axr < M;
        // A single vec4 read pulls 4 floats at once; select() zeroes lanes outside X
        // (an out-of-range read is made safe by WebGPU's robust buffer access)
        let xv = X[axr * K4 + axc / 4u];
        tileA_db[aIm * TK_PAD + aIk + 0u] = select(0.0, xv.x, row_in && (axc + 0u) < K);
        tileA_db[aIm * TK_PAD + aIk + 1u] = select(0.0, xv.y, row_in && (axc + 1u) < K);
        tileA_db[aIm * TK_PAD + aIk + 2u] = select(0.0, xv.z, row_in && (axc + 2u) < K);
        tileA_db[aIm * TK_PAD + aIk + 3u] = select(0.0, xv.w, row_in && (axc + 3u) < K);

        let bI0 = tid * 4u;
        let bIk = bI0 / TN; let bIn = bI0 % TN;      // k row within the B tile, first column
        let bwr = bIk; let bwc = block_col + bIn;
        let bwr_in = bwr < K;
        let wv = W[bwr * N4 + bwc / 4u];
        tileB_db[bIk * TN + bIn + 0u] = select(0.0, wv.x, bwr_in && (bwc + 0u) < N);
        tileB_db[bIk * TN + bIn + 1u] = select(0.0, wv.y, bwr_in && (bwc + 1u) < N);
        tileB_db[bIk * TN + bIn + 2u] = select(0.0, wv.z, bwr_in && (bwc + 2u) < N);
        tileB_db[bIk * TN + bIn + 3u] = select(0.0, wv.w, bwr_in && (bwc + 3u) < N);
    }
    workgroupBarrier();   // tile 0 is fully staged before any thread starts computing
```
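The per-thread indexing of the prologue can be replayed on the CPU to confirm that every slot of both tiles is written exactly once. It also confirms that the highest padded tileA_db slot stays inside the first half (64 × 17 = 1088 floats). Below is a TypeScript sketch with a hypothetical helper name.

```typescript
const TM = 64, TN = 64, TK = 16, TK_PAD = 17, THREADS = 256;

// Replays the prologue's per-thread indexing: thread tid stores 4 consecutive floats
// of the 64x16 A tile and 4 of the 16x64 B tile (one vec4 read each).
function checkLoadCoverage(): { aOk: boolean; bOk: boolean; maxASlot: number } {
  const aHits = new Uint8Array(TM * TK);
  const bHits = new Uint8Array(TK * TN);
  let maxASlot = 0;
  for (let tid = 0; tid < THREADS; tid++) {
    const i0 = tid * 4;
    const aIm = Math.floor(i0 / TK), aIk = i0 % TK; // A: row in tile, first k (a multiple of 4)
    const bIk = Math.floor(i0 / TN), bIn = i0 % TN; // B: k row, first column (a multiple of 4)
    for (let l = 0; l < 4; l++) {
      aHits[aIm * TK + aIk + l]++;
      bHits[bIk * TN + bIn + l]++;
      maxASlot = Math.max(maxASlot, aIm * TK_PAD + aIk + l); // padded slot in tileA_db
    }
  }
  const once = (h: Uint8Array): boolean => {
    for (let i = 0; i < h.length; i++) if (h[i] !== 1) return false;
    return true;
  };
  return { aOk: once(aHits), bOk: once(bHits), maxASlot };
}

console.log(checkLoadCoverage()); // { aOk: true, bOk: true, maxASlot: 1086 }
```

Because aIk and bIn are always multiples of 4, each vec4 read starts on a vec4 boundary as long as K and N are multiples of 4.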

3) Prefetch and Compute Loop

Inside the loop, each thread first loads tile t + 1 into the idle half (offsets nxt_a_off, nxt_b_off). It then multiplies the tile already staged in the current half (cur_a_off, cur_b_off):

```wgsl
    for (var t: u32 = 0u; t < nTiles; t = t + 1u) {
        let parity = t & 1u;
        let cur_a_off = parity * TA_DB_HALF;   // half computed on this iteration
        let cur_b_off = parity * TB_DB_HALF;

        // Prefetch tile t + 1 into the other half (skipped on the last tile)
        if (t + 1u < nTiles) {
            let nxt_a_off = (1u - parity) * TA_DB_HALF;
            let nxt_b_off = (1u - parity) * TB_DB_HALF;
            let nxt_kBase = (t + 1u) * TK;

            let aI0 = tid * 4u;
            let aIm = aI0 / TK; let aIk = aI0 % TK;
            let axr = block_row + aIm; let axc = nxt_kBase + aIk;
            let row_in = axr < M;
            let xv = X[axr * K4 + axc / 4u];
            tileA_db[nxt_a_off + aIm * TK_PAD + aIk + 0u] = select(0.0, xv.x, row_in && (axc + 0u) < K);
            tileA_db[nxt_a_off + aIm * TK_PAD + aIk + 1u] = select(0.0, xv.y, row_in && (axc + 1u) < K);
            tileA_db[nxt_a_off + aIm * TK_PAD + aIk + 2u] = select(0.0, xv.z, row_in && (axc + 2u) < K);
            tileA_db[nxt_a_off + aIm * TK_PAD + aIk + 3u] = select(0.0, xv.w, row_in && (axc + 3u) < K);

            let bI0 = tid * 4u;
            let bIk = bI0 / TN; let bIn = bI0 % TN;
            let bwr = nxt_kBase + bIk; let bwc = block_col + bIn;
            let bwr_in = bwr < K;
            let wv = W[bwr * N4 + bwc / 4u];
            tileB_db[nxt_b_off + bIk * TN + bIn + 0u] = select(0.0, wv.x, bwr_in && (bwc + 0u) < N);
            tileB_db[nxt_b_off + bIk * TN + bIn + 1u] = select(0.0, wv.y, bwr_in && (bwc + 1u) < N);
            tileB_db[nxt_b_off + bIk * TN + bIn + 2u] = select(0.0, wv.z, bwr_in && (bwc + 2u) < N);
            tileB_db[nxt_b_off + bIk * TN + bIn + 3u] = select(0.0, wv.w, bwr_in && (bwc + 3u) < N);
        }

        // TK steps x 16 FMAs = 256 FMAs per thread on the current buffer
        for (var k: u32 = 0u; k < TK; k = k + 1u) {
            let a0 = tileA_db[cur_a_off + (4u * ty + 0u) * TK_PAD + k];
            let a1 = tileA_db[cur_a_off + (4u * ty + 1u) * TK_PAD + k];
            let a2 = tileA_db[cur_a_off + (4u * ty + 2u) * TK_PAD + k];
            let a3 = tileA_db[cur_a_off + (4u * ty + 3u) * TK_PAD + k];
            let b0 = tileB_db[cur_b_off + k * TN + (4u * tx + 0u)];
            let b1 = tileB_db[cur_b_off + k * TN + (4u * tx + 1u)];
            let b2 = tileB_db[cur_b_off + k * TN + (4u * tx + 2u)];
            let b3 = tileB_db[cur_b_off + k * TN + (4u * tx + 3u)];
            acc00 = fma(a0, b0, acc00); acc01 = fma(a0, b1, acc01); acc02 = fma(a0, b2, acc02); acc03 = fma(a0, b3, acc03);
            acc10 = fma(a1, b0, acc10); acc11 = fma(a1, b1, acc11); acc12 = fma(a1, b2, acc12); acc13 = fma(a1, b3, acc13);
            acc20 = fma(a2, b0, acc20); acc21 = fma(a2, b1, acc21); acc22 = fma(a2, b2, acc22); acc23 = fma(a2, b3, acc23);
            acc30 = fma(a3, b0, acc30); acc31 = fma(a3, b1, acc31); acc32 = fma(a3, b2, acc32); acc33 = fma(a3, b3, acc33);
        }

        // One barrier per tile: makes the prefetched half visible to iteration t + 1 and
        // guarantees every thread has finished reading cur before t + 1 overwrites it
        workgroupBarrier();
    }
```
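To check the indexing end to end, the whole workgroup can be emulated on the CPU. Because workgroupBarrier() splits the loop into phases, running all 256 threads through one phase before starting the next is a faithful model. And since the prefetch writes only the idle half, doing all prefetch stores before all FMAs gives the same result as any interleaving.

This is a TypeScript sketch, not the project's test code, and it makes two assumptions. First, the epilogue (a bounds-checked write of each 4×4 sub-tile) is assumed, because the kernel's epilogue is not shown above. Second, the vec4 read plus per-lane select() is modeled as per-element guarded reads, which yield the same values.

```typescript
const TM = 64, TN = 64, TK = 16, TK_PAD = 17;
const TA_DB_HALF = TM * TK_PAD; // 1088 floats per tileA_db half
const TB_DB_HALF = TK * TN;     // 1024 floats per tileB_db half

// CPU emulation of the double-buffered kernel over the whole dispatch grid.
// X: [M, K], W: [K, N], row-major; assumes K % 4 == 0 and N % 4 == 0 like the kernel.
function emulateMatmul(X: Float32Array, W: Float32Array, M: number, N: number, K: number): Float32Array {
  const Y = new Float32Array(M * N);
  const nTiles = Math.ceil(K / TK);
  for (let wgy = 0; wgy < Math.ceil(M / TM); wgy++) {
    for (let wgx = 0; wgx < Math.ceil(N / TN); wgx++) {
      const blockRow = wgy * TM, blockCol = wgx * TN;
      const tileA = new Float32Array(2 * TA_DB_HALF); // stands in for var<workgroup> tileA_db
      const tileB = new Float32Array(2 * TB_DB_HALF); // stands in for var<workgroup> tileB_db
      const acc = new Float32Array(256 * 16);         // 16 accumulators per thread

      // One phase: all 256 threads stage K-tile t into the halves at aOff / bOff (0 outside the matrix).
      const loadTile = (t: number, aOff: number, bOff: number): void => {
        for (let tid = 0; tid < 256; tid++) {
          const i0 = tid * 4;
          const aIm = Math.floor(i0 / TK), aIk = i0 % TK;
          const bIk = Math.floor(i0 / TN), bIn = i0 % TN;
          for (let l = 0; l < 4; l++) {
            const axr = blockRow + aIm, axc = t * TK + aIk + l;
            tileA[aOff + aIm * TK_PAD + aIk + l] = axr < M && axc < K ? X[axr * K + axc] : 0;
            const bwr = t * TK + bIk, bwc = blockCol + bIn + l;
            tileB[bOff + bIk * TN + bIn + l] = bwr < K && bwc < N ? W[bwr * N + bwc] : 0;
          }
        }
      };

      loadTile(0, 0, 0); // prologue into half 0, followed by workgroupBarrier()
      for (let t = 0; t < nTiles; t++) {
        const parity = t & 1;
        const curA = parity * TA_DB_HALF, curB = parity * TB_DB_HALF;
        if (t + 1 < nTiles) loadTile(t + 1, (1 - parity) * TA_DB_HALF, (1 - parity) * TB_DB_HALF);
        for (let tid = 0; tid < 256; tid++) {
          const tx = tid % 16, ty = Math.floor(tid / 16);
          for (let k = 0; k < TK; k++) {
            for (let i = 0; i < 4; i++) {
              for (let j = 0; j < 4; j++) {
                acc[tid * 16 + i * 4 + j] +=
                  tileA[curA + (4 * ty + i) * TK_PAD + k] * tileB[curB + k * TN + 4 * tx + j];
              }
            }
          }
        }
        // workgroupBarrier() ends the iteration
      }

      // Assumed epilogue: each thread writes its 4x4 sub-tile, skipping rows/cols outside Y.
      for (let tid = 0; tid < 256; tid++) {
        const tx = tid % 16, ty = Math.floor(tid / 16);
        for (let i = 0; i < 4; i++) {
          for (let j = 0; j < 4; j++) {
            const r = blockRow + 4 * ty + i, c = blockCol + 4 * tx + j;
            if (r < M && c < N) Y[r * N + c] = acc[tid * 16 + i * 4 + j];
          }
        }
      }
    }
  }
  return Y;
}
```

Comparing its output with a naive triple loop on a ragged shape such as M = 70, N = 68, K = 20 exercises all three masks and the partially filled second K-tile.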

WGSL-Specific Notes

1. var<workgroup> and the Double-Buffer Cost

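The size of a var<workgroup> array must be fixed when the pipeline is created. Here that is 2 × 1088 (tileA_db) + 2 × 1024 (tileB_db) = 4224 floats, or 16,896 bytes (16.5 KiB) of workgroup memory. That is slightly more than WebGPU's default maxComputeWorkgroupStorageSize of 16,384 bytes, so the device has to be requested with a higher limit. On Apple GPUs, which expose 32 KB, the double buffers use about half of what is available.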
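Requesting the limit is a one-time host-side step. Below is a TypeScript sketch that works against any object with a maxComputeWorkgroupStorageSize field, such as adapter.limits. `workgroupLimitFor` and the `MATMUL_DB_WORKGROUP_BYTES` constant are hypothetical names.

```typescript
// The subset of GPUSupportedLimits this helper needs (a structural type, so a
// plain mock object works in tests as well as a real adapter.limits).
interface WorkgroupLimits {
  maxComputeWorkgroupStorageSize: number;
}

const MATMUL_DB_WORKGROUP_BYTES = 16896; // (2 * 1088 + 2 * 1024) floats * 4 bytes

// Returns the requiredLimits object for adapter.requestDevice(), or throws if the
// adapter cannot provide enough workgroup memory for the double-buffered tiles.
function workgroupLimitFor(
  adapterLimits: WorkgroupLimits,
  needed: number = MATMUL_DB_WORKGROUP_BYTES,
): WorkgroupLimits {
  if (adapterLimits.maxComputeWorkgroupStorageSize < needed) {
    throw new Error(
      `adapter offers ${adapterLimits.maxComputeWorkgroupStorageSize} B of workgroup memory, kernel needs ${needed} B`,
    );
  }
  return { maxComputeWorkgroupStorageSize: needed };
}

console.log(workgroupLimitFor({ maxComputeWorkgroupStorageSize: 32768 })); // { maxComputeWorkgroupStorageSize: 16896 }
```

With a real adapter the call would be adapter.requestDevice({ requiredLimits: workgroupLimitFor(adapter.limits) }).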

2. fma() and Vector Optimizations

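The FMAs inside the loop (acc = fma(a, b, acc)) normally compile to the GPU's native fused multiply-add instruction. Note that WGSL defines fma's accuracy as inherited from a * b + c, so a single rounding is not guaranteed. The vectorized bindings cut the number of load instructions per tile by 75%, because one vec4 load replaces four scalar loads. The bytes moved stay the same, but fewer, wider requests use the memory path more efficiently.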


Code Review

Finding 1: Cost of Bounds Masking with select

| Risk | Description |
|---|---|
| 🟢 none | In select(0.0, xv.x, cond), the read is issued even when it falls outside the buffer, but WebGPU's robust buffer access keeps it safe and select() discards the value. For the shapes used in training, M and N are multiples of 64 (and K of 16), so the masks are always true. The only cost is the extra comparisons and selects, which are cheap next to the loads. |
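Whether the masks ever zero anything is purely a property of the shapes. Rows overhang when M % 64 ≠ 0, columns when N % 64 ≠ 0, and the last K-tile when K % 16 ≠ 0. Below is a TypeScript check for a given configuration. `masksCanZero` is a hypothetical helper.

```typescript
// True if some select() mask in the tile loads can be false for this shape:
// rows overhang when M % 64 != 0, columns when N % 64 != 0, the last K-tile when K % 16 != 0.
function masksCanZero(M: number, N: number, K: number): boolean {
  return M % 64 !== 0 || N % 64 !== 0 || K % 16 !== 0;
}

console.log(masksCanZero(512, 768, 768));   // false: attention projections
console.log(masksCanZero(512, 768, 3072));  // false: FFN down
console.log(masksCanZero(512, 16384, 768)); // false: lm_head
```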

Quick Checklist

| Test scenario | Status |
|---|---|
| Are K and N multiples of 4? (vec4 requirement) | ✅ Verified on the host side |
| Does the double-buffered prefetch run race-free? | workgroupBarrier() correctness |
| Did SwiGLU forward fusion remove the intermediate buffer? | hidden buffer eliminated |
| Were bank conflicts prevented? | TK_PAD = 17 |
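Why 17? A rough model shows the effect on the tileA reads in the FMA loop. The model is an assumption, not a description of any particular GPU: 32 banks of 4-byte words (bank = word index mod 32), SIMD groups of 32 consecutive tid values, and same-word reads served as a broadcast.

Within one SIMD group ty takes two values, so the group reads row 4·ty + i from two rows 4 apart. With a stride of 16 those rows are 64 words apart and land on the same bank. With a stride of 17 they are 68 words apart and do not. Below is a TypeScript sketch that counts the worst case. `tileAConflictDegree` is a hypothetical helper.

```typescript
// Simplified shared-memory model (an assumption): 32 banks of 4-byte words,
// bank = word % 32, SIMD groups of 32 consecutive tid, same-word reads broadcast.
const BANKS = 32;

// Worst number of distinct words mapped to one bank when a SIMD group starting at
// simdBase executes tileA_db[(4*ty + i) * pad + k], over all i in 0..3 and k in 0..15.
function tileAConflictDegree(pad: number, simdBase: number): number {
  let worst = 1;
  for (let k = 0; k < 16; k++) {
    for (let i = 0; i < 4; i++) {
      const wordsPerBank: number[][] = [];
      for (let b = 0; b < BANKS; b++) wordsPerBank.push([]);
      for (let tid = simdBase; tid < simdBase + 32; tid++) {
        const ty = Math.floor(tid / 16);
        const word = (4 * ty + i) * pad + k;
        const bucket = wordsPerBank[word % BANKS];
        if (bucket.indexOf(word) < 0) bucket.push(word); // same word = broadcast, counted once
        worst = Math.max(worst, bucket.length);
      }
    }
  }
  return worst;
}

console.log(tileAConflictDegree(16, 0), tileAConflictDegree(17, 0)); // 2 1
```

Under the same model, the padding does not touch the tileB reads (k * TN + 4 * tx + j). There, tx and tx + 8 still fall on the same bank.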

Next

04_rope.md — Rotary Position Embedding. Applies position-based rotation to the Q and K vectors.

WGSL kernel studies · an LLM from scratch on WebGPU. Built in Istanbul by Uğur Toprakdeviren.