/**
 * linear.wgsl — tiled matrix multiply (64×64 output tile, 4×4 sub-tile per thread).
 *
 *   matmul:          Y = X @ W           X is [M, K], W is [K, N], Y is [M, N]
 *   matmul_residual: Y = X @ W + R       (forward fusion: residual added in epilogue)
 *   matmul_t:        Y = X @ W^T         W is [N, K] (rows are the N output features)
 *   matmul_at:       Y[K, N] = X^T @ B   X is [M, K], B is [M, N] (for dW)
 *   matmul_at_acc:   Y[K, N] += X^T @ B  (accumulating variant for grad accum)
 *   matmul_t_acc:    Y[M, N] += X @ W^T  (accumulating fused matmul_t for QKV/Wk/Wv)
 *
 * Tile structure:
 *   - workgroup_size = (16, 16, 1) → 256 threads
 *   - Each workgroup writes a 64×64 output tile (TM × TN); that is 4× more
 *     output than the previous 32×32 tile, so 4× fewer workgroups per matmul.
 *   - Each thread holds a 4×4 register accumulator (16 elements; was 2×2).
 *   - K is tiled at TK=16; loads cooperatively bring 64×16 and 16×64 tiles
 *     into workgroup memory (~8.25 KB per buffer including the tileA padding,
 *     ~16.5 KB with double buffering; ~4 KB before).
 *
 * Why this beats 32×32 / 2×2:
 *   - Each shared-mem value is reused 4× per thread (vs 2×) → ½ the shared-mem
 *     traffic per multiply-accumulate, 2× arithmetic intensity in the inner loop.
 *   - 16 register-resident accumulators per thread (vs 4) → a bigger inner loop
 *     that hides shared-mem load latency better.
 *   - ¼ the workgroups → less per-workgroup launch overhead, and per-workgroup
 *     setup cost is amortized over more output.
 *
 * Coalesced reads on transposed paths: matmul_at swaps the aI mapping so
 * consecutive threads load consecutive K values at fixed M (instead of
 * stride-K jumps); matmul_t / matmul_t_acc do the same for W^T. Without
 * this, the tile load on transposed inputs hits ~32 cache lines per warp;
 * with it, just 2-4. (Apple GPU's L1 happens to absorb the older stride-K
 * pattern fairly well, so the realised gain from coalescing alone is small;
 * the bigger tile is what moves the needle.)
 *
 * fp32 throughout. Mixed-precision (f16 weight storage) variants live in
 * linear_w16.wgsl and mirror the same structure.
*/

const TM: u32 = 64u; // output rows per workgroup
const TN: u32 = 64u; // output cols per workgroup
const TK: u32 = 16u; // K-dim tile width

// tileA padding: with a row stride of 16 floats, threads with different ty
// read addresses that fall in the same one of the 32 shared-memory banks in
// the inner loop. A stride of 17 is coprime with 32, so those reads land in
// different banks and there are no bank conflicts inside the warp. tileB has
// no such pattern (different threads → different N columns) and stays
// unpadded.
const TK_PAD: u32 = 17u;

// Single-buffered tiles.
// NOTE(review): none of the double-buffered kernels in this file reference
// these (they use tileA_db / tileB_db); presumably another variant does —
// confirm, otherwise they are dead declarations.
var<workgroup> tileA: array<f32, 1088>; // TM × TK_PAD = 64 × 17 (padded for bank conflicts)
var<workgroup> tileB: array<f32, 1024>; // TK × TN     = 16 × 64

// Double-buffered tiles (perf-doc #6), used by the pipelined kernels
// (matmul, matmul_residual, matmul_t, matmul_t_acc, matmul_at). Two copies
// of each tile let the prefetch of K-tile t+1 go into one half while the
// FMA loop reads tile t from the other half, so the kernels need only one
// workgroupBarrier per K-tile.
//
// Size: 2 × 1088 + 2 × 1024 = 4224 floats = 16,896 bytes (~16.5 KB) of
// workgroup memory. That fits under Apple's 32 KB budget, but leaves room
// for only one such workgroup per 32 KB.
// NOTE(review): an earlier note estimated ~3 WGs/SM from an 8.5 KB total;
// that total was wrong — re-check occupancy.
//
// Each KERNEL is compiled as its own module, and Tint DCE drops any of these
// arrays a kernel does not reference, so they only count against the
// workgroup-memory budget of the kernels that use them.
const TA_DB_HALF: u32 = 1088u; // offset of the 2nd A buffer = TM × TK_PAD
const TB_DB_HALF: u32 = 1024u; // offset of the 2nd B buffer = TK × TN
var<workgroup> tileA_db: array<f32, 2176>; // 2 × 64 × 17
var<workgroup> tileB_db: array<f32, 2048>; // 2 × 16 × 64

// Stable SiLU: silu(x) = x * sigmoid(x). Branches on the sign of x so exp()
// only ever sees a non-positive argument and cannot overflow.
// Used by the `*_swiglu_a` matmul variants whose A-tile load fuses
// `silu(gate[r,k]) * up[r,k]` instead of materializing `hidden`.
// Mirrors the canonical version in 06_activation.wgsl; redeclared here
// because each KERNEL is its own compiled module (no cross-file imports
// in WGSL). Tint DCEs it out of kernels that don't reference it.
fn silu(x: f32) -> f32 {
    var sig: f32;
    if (x >= 0.0) {
        // -x <= 0, so exp(-x) is in (0, 1].
        sig = 1.0 / (1.0 + exp(-x));
    } else {
        // x < 0, so exp(x) is in (0, 1); algebraically equal to the branch above.
        let e = exp(x);
        sig = e / (1.0 + e);
    }
    return x * sig;
}

// --- KERNEL: matmul ---
// Y[M, N] = X[M, K] @ W[K, N]
//
// vec4 tile loads: X and W are bound as `array<vec4<f32>>` so each thread
// fetches 4 consecutive floats in a single 16-byte load instead of 4 scalars.
// Constraints (host MUST satisfy): K % 4 == 0 AND N % 4 == 0.
Production // configs satisfy both (K = d_model | d_ff | head_dim, N = same set, all // multiples of 64). Buffers themselves keep f32 byte layout — only the // binding type changes. WebGPU mandates robust buffer access: an OOB vec4 // index is bounds-checked — the load never traps and never reads outside // the buffer binding. The returned value is impl-defined (zero, or some // other in-binding element) — NOT guaranteed zero. Correctness doesn't // rely on the value: the per-element `select` guards below discard any // element whose logical index is past K (end-of-row) or past M/N // (cross-row), so whatever robustness hands back is thrown away. @group(0) @binding(0) var X: array>; @group(0) @binding(1) var W: array>; @group(0) @binding(2) var Y: array; @group(0) @binding(3) var dims: vec4; // (M, N, K, _) @compute @workgroup_size(16, 16, 1) fn matmul(@builtin(workgroup_id) wgid: vec3, @builtin(local_invocation_id) lid: vec3) { let M = dims.x; let N = dims.y; let K = dims.z; let K4 = K / 4u; let N4 = N / 4u; let tx = lid.x; let ty = lid.y; let tid = ty * 16u + tx; let block_row = wgid.y * TM; let block_col = wgid.x * TN; var acc00: f32 = 0.0; var acc01: f32 = 0.0; var acc02: f32 = 0.0; var acc03: f32 = 0.0; var acc10: f32 = 0.0; var acc11: f32 = 0.0; var acc12: f32 = 0.0; var acc13: f32 = 0.0; var acc20: f32 = 0.0; var acc21: f32 = 0.0; var acc22: f32 = 0.0; var acc23: f32 = 0.0; var acc30: f32 = 0.0; var acc31: f32 = 0.0; var acc32: f32 = 0.0; var acc33: f32 = 0.0; let nTiles = (K + TK - 1u) / TK; // ── Prologue: load tile 0 into buffer 0 ── { let aI0 = tid * 4u; let aIm = aI0 / TK; let aIk = aI0 % TK; let axr = block_row + aIm; let axc = aIk; // kBase = 0 let row_in = axr < M; let xv = X[axr * K4 + axc / 4u]; tileA_db[aIm * TK_PAD + aIk + 0u] = select(0.0, xv.x, row_in && (axc + 0u) < K); tileA_db[aIm * TK_PAD + aIk + 1u] = select(0.0, xv.y, row_in && (axc + 1u) < K); tileA_db[aIm * TK_PAD + aIk + 2u] = select(0.0, xv.z, row_in && (axc + 2u) < K); tileA_db[aIm * 
TK_PAD + aIk + 3u] = select(0.0, xv.w, row_in && (axc + 3u) < K); let bI0 = tid * 4u; let bIk = bI0 / TN; let bIn = bI0 % TN; let bwr = bIk; // kBase = 0 let bwc = block_col + bIn; let bwr_in = bwr < K; let wv = W[bwr * N4 + bwc / 4u]; tileB_db[bIk * TN + bIn + 0u] = select(0.0, wv.x, bwr_in && (bwc + 0u) < N); tileB_db[bIk * TN + bIn + 1u] = select(0.0, wv.y, bwr_in && (bwc + 1u) < N); tileB_db[bIk * TN + bIn + 2u] = select(0.0, wv.z, bwr_in && (bwc + 2u) < N); tileB_db[bIk * TN + bIn + 3u] = select(0.0, wv.w, bwr_in && (bwc + 3u) < N); } workgroupBarrier(); // ── Pipelined loop: prefetch tile (t+1) and compute on tile t in parallel ── // Different buffers → no data dependency between prefetch and compute, // so the GPU can hide global-memory latency behind FMA work. Single // barrier per iteration (was 2 in the un-buffered version). for (var t: u32 = 0u; t < nTiles; t = t + 1u) { let parity = t & 1u; let cur_a_off = parity * TA_DB_HALF; let cur_b_off = parity * TB_DB_HALF; // Prefetch tile (t+1) into the OTHER buffer (skip on last iter). 
if (t + 1u < nTiles) { let nxt_a_off = (1u - parity) * TA_DB_HALF; let nxt_b_off = (1u - parity) * TB_DB_HALF; let nxt_kBase = (t + 1u) * TK; let aI0 = tid * 4u; let aIm = aI0 / TK; let aIk = aI0 % TK; let axr = block_row + aIm; let axc = nxt_kBase + aIk; let row_in = axr < M; let xv = X[axr * K4 + axc / 4u]; tileA_db[nxt_a_off + aIm * TK_PAD + aIk + 0u] = select(0.0, xv.x, row_in && (axc + 0u) < K); tileA_db[nxt_a_off + aIm * TK_PAD + aIk + 1u] = select(0.0, xv.y, row_in && (axc + 1u) < K); tileA_db[nxt_a_off + aIm * TK_PAD + aIk + 2u] = select(0.0, xv.z, row_in && (axc + 2u) < K); tileA_db[nxt_a_off + aIm * TK_PAD + aIk + 3u] = select(0.0, xv.w, row_in && (axc + 3u) < K); let bI0 = tid * 4u; let bIk = bI0 / TN; let bIn = bI0 % TN; let bwr = nxt_kBase + bIk; let bwc = block_col + bIn; let bwr_in = bwr < K; let wv = W[bwr * N4 + bwc / 4u]; tileB_db[nxt_b_off + bIk * TN + bIn + 0u] = select(0.0, wv.x, bwr_in && (bwc + 0u) < N); tileB_db[nxt_b_off + bIk * TN + bIn + 1u] = select(0.0, wv.y, bwr_in && (bwc + 1u) < N); tileB_db[nxt_b_off + bIk * TN + bIn + 2u] = select(0.0, wv.z, bwr_in && (bwc + 2u) < N); tileB_db[nxt_b_off + bIk * TN + bIn + 3u] = select(0.0, wv.w, bwr_in && (bwc + 3u) < N); } // 4×4 sub-tile per thread on the CURRENT buffer. 
for (var k: u32 = 0u; k < TK; k = k + 1u) { let a0 = tileA_db[cur_a_off + (4u * ty + 0u) * TK_PAD + k]; let a1 = tileA_db[cur_a_off + (4u * ty + 1u) * TK_PAD + k]; let a2 = tileA_db[cur_a_off + (4u * ty + 2u) * TK_PAD + k]; let a3 = tileA_db[cur_a_off + (4u * ty + 3u) * TK_PAD + k]; let b0 = tileB_db[cur_b_off + k * TN + (4u * tx + 0u)]; let b1 = tileB_db[cur_b_off + k * TN + (4u * tx + 1u)]; let b2 = tileB_db[cur_b_off + k * TN + (4u * tx + 2u)]; let b3 = tileB_db[cur_b_off + k * TN + (4u * tx + 3u)]; acc00 = fma(a0, b0, acc00); acc01 = fma(a0, b1, acc01); acc02 = fma(a0, b2, acc02); acc03 = fma(a0, b3, acc03); acc10 = fma(a1, b0, acc10); acc11 = fma(a1, b1, acc11); acc12 = fma(a1, b2, acc12); acc13 = fma(a1, b3, acc13); acc20 = fma(a2, b0, acc20); acc21 = fma(a2, b1, acc21); acc22 = fma(a2, b2, acc22); acc23 = fma(a2, b3, acc23); acc30 = fma(a3, b0, acc30); acc31 = fma(a3, b1, acc31); acc32 = fma(a3, b2, acc32); acc33 = fma(a3, b3, acc33); } workgroupBarrier(); } let or = block_row + 4u * ty; let oc = block_col + 4u * tx; // Row 0 if (or + 0u < M && oc + 0u < N) { Y[(or + 0u) * N + (oc + 0u)] = acc00; } if (or + 0u < M && oc + 1u < N) { Y[(or + 0u) * N + (oc + 1u)] = acc01; } if (or + 0u < M && oc + 2u < N) { Y[(or + 0u) * N + (oc + 2u)] = acc02; } if (or + 0u < M && oc + 3u < N) { Y[(or + 0u) * N + (oc + 3u)] = acc03; } // Row 1 if (or + 1u < M && oc + 0u < N) { Y[(or + 1u) * N + (oc + 0u)] = acc10; } if (or + 1u < M && oc + 1u < N) { Y[(or + 1u) * N + (oc + 1u)] = acc11; } if (or + 1u < M && oc + 2u < N) { Y[(or + 1u) * N + (oc + 2u)] = acc12; } if (or + 1u < M && oc + 3u < N) { Y[(or + 1u) * N + (oc + 3u)] = acc13; } // Row 2 if (or + 2u < M && oc + 0u < N) { Y[(or + 2u) * N + (oc + 0u)] = acc20; } if (or + 2u < M && oc + 1u < N) { Y[(or + 2u) * N + (oc + 1u)] = acc21; } if (or + 2u < M && oc + 2u < N) { Y[(or + 2u) * N + (oc + 2u)] = acc22; } if (or + 2u < M && oc + 3u < N) { Y[(or + 2u) * N + (oc + 3u)] = acc23; } // Row 3 if (or + 3u < M && oc + 0u < N) { 
Y[(or + 3u) * N + (oc + 0u)] = acc30; } if (or + 3u < M && oc + 1u < N) { Y[(or + 3u) * N + (oc + 1u)] = acc31; } if (or + 3u < M && oc + 2u < N) { Y[(or + 3u) * N + (oc + 2u)] = acc32; } if (or + 3u < M && oc + 3u < N) { Y[(or + 3u) * N + (oc + 3u)] = acc33; } } // --- KERNEL: matmul_residual --- // Y[M, N] = X[M, K] @ W[K, N] + R[M, N] // vec4 tile loads — host MUST satisfy K % 4 == 0 AND N % 4 == 0. @group(0) @binding(0) var X_mr: array>; @group(0) @binding(1) var W_mr: array>; @group(0) @binding(2) var R_mr: array; @group(0) @binding(3) var Y_mr: array; @group(0) @binding(4) var dims_mr: vec4; @compute @workgroup_size(16, 16, 1) fn matmul_residual(@builtin(workgroup_id) wgid: vec3, @builtin(local_invocation_id) lid: vec3) { let M = dims_mr.x; let N = dims_mr.y; let K = dims_mr.z; let K4 = K / 4u; let N4 = N / 4u; let tx = lid.x; let ty = lid.y; let tid = ty * 16u + tx; let block_row = wgid.y * TM; let block_col = wgid.x * TN; var acc00: f32 = 0.0; var acc01: f32 = 0.0; var acc02: f32 = 0.0; var acc03: f32 = 0.0; var acc10: f32 = 0.0; var acc11: f32 = 0.0; var acc12: f32 = 0.0; var acc13: f32 = 0.0; var acc20: f32 = 0.0; var acc21: f32 = 0.0; var acc22: f32 = 0.0; var acc23: f32 = 0.0; var acc30: f32 = 0.0; var acc31: f32 = 0.0; var acc32: f32 = 0.0; var acc33: f32 = 0.0; let nTiles = (K + TK - 1u) / TK; // Prologue: tile 0 → buffer 0 { let aI0 = tid * 4u; let aIm = aI0 / TK; let aIk = aI0 % TK; let axr = block_row + aIm; let axc = aIk; let row_in = axr < M; let xv = X_mr[axr * K4 + axc / 4u]; tileA_db[aIm * TK_PAD + aIk + 0u] = select(0.0, xv.x, row_in && (axc + 0u) < K); tileA_db[aIm * TK_PAD + aIk + 1u] = select(0.0, xv.y, row_in && (axc + 1u) < K); tileA_db[aIm * TK_PAD + aIk + 2u] = select(0.0, xv.z, row_in && (axc + 2u) < K); tileA_db[aIm * TK_PAD + aIk + 3u] = select(0.0, xv.w, row_in && (axc + 3u) < K); let bI0 = tid * 4u; let bIk = bI0 / TN; let bIn = bI0 % TN; let bwr = bIk; let bwc = block_col + bIn; let bwr_in = bwr < K; let wv = W_mr[bwr * N4 + bwc 
/ 4u]; tileB_db[bIk * TN + bIn + 0u] = select(0.0, wv.x, bwr_in && (bwc + 0u) < N); tileB_db[bIk * TN + bIn + 1u] = select(0.0, wv.y, bwr_in && (bwc + 1u) < N); tileB_db[bIk * TN + bIn + 2u] = select(0.0, wv.z, bwr_in && (bwc + 2u) < N); tileB_db[bIk * TN + bIn + 3u] = select(0.0, wv.w, bwr_in && (bwc + 3u) < N); } workgroupBarrier(); for (var t: u32 = 0u; t < nTiles; t = t + 1u) { let parity = t & 1u; let cur_a_off = parity * TA_DB_HALF; let cur_b_off = parity * TB_DB_HALF; if (t + 1u < nTiles) { let nxt_a_off = (1u - parity) * TA_DB_HALF; let nxt_b_off = (1u - parity) * TB_DB_HALF; let nxt_kBase = (t + 1u) * TK; let aI0 = tid * 4u; let aIm = aI0 / TK; let aIk = aI0 % TK; let axr = block_row + aIm; let axc = nxt_kBase + aIk; let row_in = axr < M; let xv = X_mr[axr * K4 + axc / 4u]; tileA_db[nxt_a_off + aIm * TK_PAD + aIk + 0u] = select(0.0, xv.x, row_in && (axc + 0u) < K); tileA_db[nxt_a_off + aIm * TK_PAD + aIk + 1u] = select(0.0, xv.y, row_in && (axc + 1u) < K); tileA_db[nxt_a_off + aIm * TK_PAD + aIk + 2u] = select(0.0, xv.z, row_in && (axc + 2u) < K); tileA_db[nxt_a_off + aIm * TK_PAD + aIk + 3u] = select(0.0, xv.w, row_in && (axc + 3u) < K); let bI0 = tid * 4u; let bIk = bI0 / TN; let bIn = bI0 % TN; let bwr = nxt_kBase + bIk; let bwc = block_col + bIn; let bwr_in = bwr < K; let wv = W_mr[bwr * N4 + bwc / 4u]; tileB_db[nxt_b_off + bIk * TN + bIn + 0u] = select(0.0, wv.x, bwr_in && (bwc + 0u) < N); tileB_db[nxt_b_off + bIk * TN + bIn + 1u] = select(0.0, wv.y, bwr_in && (bwc + 1u) < N); tileB_db[nxt_b_off + bIk * TN + bIn + 2u] = select(0.0, wv.z, bwr_in && (bwc + 2u) < N); tileB_db[nxt_b_off + bIk * TN + bIn + 3u] = select(0.0, wv.w, bwr_in && (bwc + 3u) < N); } for (var k: u32 = 0u; k < TK; k = k + 1u) { let a0 = tileA_db[cur_a_off + (4u * ty + 0u) * TK_PAD + k]; let a1 = tileA_db[cur_a_off + (4u * ty + 1u) * TK_PAD + k]; let a2 = tileA_db[cur_a_off + (4u * ty + 2u) * TK_PAD + k]; let a3 = tileA_db[cur_a_off + (4u * ty + 3u) * TK_PAD + k]; let b0 = 
tileB_db[cur_b_off + k * TN + (4u * tx + 0u)]; let b1 = tileB_db[cur_b_off + k * TN + (4u * tx + 1u)]; let b2 = tileB_db[cur_b_off + k * TN + (4u * tx + 2u)]; let b3 = tileB_db[cur_b_off + k * TN + (4u * tx + 3u)]; acc00 = fma(a0, b0, acc00); acc01 = fma(a0, b1, acc01); acc02 = fma(a0, b2, acc02); acc03 = fma(a0, b3, acc03); acc10 = fma(a1, b0, acc10); acc11 = fma(a1, b1, acc11); acc12 = fma(a1, b2, acc12); acc13 = fma(a1, b3, acc13); acc20 = fma(a2, b0, acc20); acc21 = fma(a2, b1, acc21); acc22 = fma(a2, b2, acc22); acc23 = fma(a2, b3, acc23); acc30 = fma(a3, b0, acc30); acc31 = fma(a3, b1, acc31); acc32 = fma(a3, b2, acc32); acc33 = fma(a3, b3, acc33); } workgroupBarrier(); } let or = block_row + 4u * ty; let oc = block_col + 4u * tx; // Fused residual add in epilogue: Y = X@W + R. if (or + 0u < M && oc + 0u < N) { Y_mr[(or + 0u) * N + (oc + 0u)] = acc00 + R_mr[(or + 0u) * N + (oc + 0u)]; } if (or + 0u < M && oc + 1u < N) { Y_mr[(or + 0u) * N + (oc + 1u)] = acc01 + R_mr[(or + 0u) * N + (oc + 1u)]; } if (or + 0u < M && oc + 2u < N) { Y_mr[(or + 0u) * N + (oc + 2u)] = acc02 + R_mr[(or + 0u) * N + (oc + 2u)]; } if (or + 0u < M && oc + 3u < N) { Y_mr[(or + 0u) * N + (oc + 3u)] = acc03 + R_mr[(or + 0u) * N + (oc + 3u)]; } if (or + 1u < M && oc + 0u < N) { Y_mr[(or + 1u) * N + (oc + 0u)] = acc10 + R_mr[(or + 1u) * N + (oc + 0u)]; } if (or + 1u < M && oc + 1u < N) { Y_mr[(or + 1u) * N + (oc + 1u)] = acc11 + R_mr[(or + 1u) * N + (oc + 1u)]; } if (or + 1u < M && oc + 2u < N) { Y_mr[(or + 1u) * N + (oc + 2u)] = acc12 + R_mr[(or + 1u) * N + (oc + 2u)]; } if (or + 1u < M && oc + 3u < N) { Y_mr[(or + 1u) * N + (oc + 3u)] = acc13 + R_mr[(or + 1u) * N + (oc + 3u)]; } if (or + 2u < M && oc + 0u < N) { Y_mr[(or + 2u) * N + (oc + 0u)] = acc20 + R_mr[(or + 2u) * N + (oc + 0u)]; } if (or + 2u < M && oc + 1u < N) { Y_mr[(or + 2u) * N + (oc + 1u)] = acc21 + R_mr[(or + 2u) * N + (oc + 1u)]; } if (or + 2u < M && oc + 2u < N) { Y_mr[(or + 2u) * N + (oc + 2u)] = acc22 + R_mr[(or + 2u) * N 
+ (oc + 2u)]; } if (or + 2u < M && oc + 3u < N) { Y_mr[(or + 2u) * N + (oc + 3u)] = acc23 + R_mr[(or + 2u) * N + (oc + 3u)]; } if (or + 3u < M && oc + 0u < N) { Y_mr[(or + 3u) * N + (oc + 0u)] = acc30 + R_mr[(or + 3u) * N + (oc + 0u)]; } if (or + 3u < M && oc + 1u < N) { Y_mr[(or + 3u) * N + (oc + 1u)] = acc31 + R_mr[(or + 3u) * N + (oc + 1u)]; } if (or + 3u < M && oc + 2u < N) { Y_mr[(or + 3u) * N + (oc + 2u)] = acc32 + R_mr[(or + 3u) * N + (oc + 2u)]; } if (or + 3u < M && oc + 3u < N) { Y_mr[(or + 3u) * N + (oc + 3u)] = acc33 + R_mr[(or + 3u) * N + (oc + 3u)]; } } // --- KERNEL: matmul_t --- // Y[M, N] = X[M, K] @ W^T where W is stored row-major as [N, K]. // B-tile holds W^T: tileB[k_local, n_local] = W[block_col+n_local, kBase+k_local]. // vec4 tile loads — host MUST satisfy K % 4 == 0 (X row stride = K, W row stride = K). @group(0) @binding(0) var X_mt: array>; @group(0) @binding(1) var W_mt: array>; @group(0) @binding(2) var Y_mt: array; @group(0) @binding(3) var dims_mt: vec4; @compute @workgroup_size(16, 16, 1) fn matmul_t(@builtin(workgroup_id) wgid: vec3, @builtin(local_invocation_id) lid: vec3) { let M = dims_mt.x; let N = dims_mt.y; let K = dims_mt.z; let K4 = K / 4u; let tx = lid.x; let ty = lid.y; let tid = ty * 16u + tx; let block_row = wgid.y * TM; let block_col = wgid.x * TN; var acc00: f32 = 0.0; var acc01: f32 = 0.0; var acc02: f32 = 0.0; var acc03: f32 = 0.0; var acc10: f32 = 0.0; var acc11: f32 = 0.0; var acc12: f32 = 0.0; var acc13: f32 = 0.0; var acc20: f32 = 0.0; var acc21: f32 = 0.0; var acc22: f32 = 0.0; var acc23: f32 = 0.0; var acc30: f32 = 0.0; var acc31: f32 = 0.0; var acc32: f32 = 0.0; var acc33: f32 = 0.0; let nTiles = (K + TK - 1u) / TK; // Prologue: tile 0 → buffer 0 { let aI0 = tid * 4u; let aIm = aI0 / TK; let aIk = aI0 % TK; let axr = block_row + aIm; let axc = aIk; let row_in = axr < M; let xv = X_mt[axr * K4 + axc / 4u]; tileA_db[aIm * TK_PAD + aIk + 0u] = select(0.0, xv.x, row_in && (axc + 0u) < K); tileA_db[aIm * TK_PAD + 
aIk + 1u] = select(0.0, xv.y, row_in && (axc + 1u) < K); tileA_db[aIm * TK_PAD + aIk + 2u] = select(0.0, xv.z, row_in && (axc + 2u) < K); tileA_db[aIm * TK_PAD + aIk + 3u] = select(0.0, xv.w, row_in && (axc + 3u) < K); let bI0 = tid * 4u; let bIn = bI0 / TK; let bIk = bI0 % TK; let nG = block_col + bIn; let kG = bIk; let n_in = nG < N; let wv = W_mt[nG * K4 + kG / 4u]; tileB_db[(bIk + 0u) * TN + bIn] = select(0.0, wv.x, n_in && (kG + 0u) < K); tileB_db[(bIk + 1u) * TN + bIn] = select(0.0, wv.y, n_in && (kG + 1u) < K); tileB_db[(bIk + 2u) * TN + bIn] = select(0.0, wv.z, n_in && (kG + 2u) < K); tileB_db[(bIk + 3u) * TN + bIn] = select(0.0, wv.w, n_in && (kG + 3u) < K); } workgroupBarrier(); for (var t: u32 = 0u; t < nTiles; t = t + 1u) { let parity = t & 1u; let cur_a_off = parity * TA_DB_HALF; let cur_b_off = parity * TB_DB_HALF; if (t + 1u < nTiles) { let nxt_a_off = (1u - parity) * TA_DB_HALF; let nxt_b_off = (1u - parity) * TB_DB_HALF; let nxt_kBase = (t + 1u) * TK; let aI0 = tid * 4u; let aIm = aI0 / TK; let aIk = aI0 % TK; let axr = block_row + aIm; let axc = nxt_kBase + aIk; let row_in = axr < M; let xv = X_mt[axr * K4 + axc / 4u]; tileA_db[nxt_a_off + aIm * TK_PAD + aIk + 0u] = select(0.0, xv.x, row_in && (axc + 0u) < K); tileA_db[nxt_a_off + aIm * TK_PAD + aIk + 1u] = select(0.0, xv.y, row_in && (axc + 1u) < K); tileA_db[nxt_a_off + aIm * TK_PAD + aIk + 2u] = select(0.0, xv.z, row_in && (axc + 2u) < K); tileA_db[nxt_a_off + aIm * TK_PAD + aIk + 3u] = select(0.0, xv.w, row_in && (axc + 3u) < K); let bI0 = tid * 4u; let bIn = bI0 / TK; let bIk = bI0 % TK; let nG = block_col + bIn; let kG = nxt_kBase + bIk; let n_in = nG < N; let wv = W_mt[nG * K4 + kG / 4u]; tileB_db[nxt_b_off + (bIk + 0u) * TN + bIn] = select(0.0, wv.x, n_in && (kG + 0u) < K); tileB_db[nxt_b_off + (bIk + 1u) * TN + bIn] = select(0.0, wv.y, n_in && (kG + 1u) < K); tileB_db[nxt_b_off + (bIk + 2u) * TN + bIn] = select(0.0, wv.z, n_in && (kG + 2u) < K); tileB_db[nxt_b_off + (bIk + 3u) * TN + bIn] 
= select(0.0, wv.w, n_in && (kG + 3u) < K); } for (var k: u32 = 0u; k < TK; k = k + 1u) { let a0 = tileA_db[cur_a_off + (4u * ty + 0u) * TK_PAD + k]; let a1 = tileA_db[cur_a_off + (4u * ty + 1u) * TK_PAD + k]; let a2 = tileA_db[cur_a_off + (4u * ty + 2u) * TK_PAD + k]; let a3 = tileA_db[cur_a_off + (4u * ty + 3u) * TK_PAD + k]; let b0 = tileB_db[cur_b_off + k * TN + (4u * tx + 0u)]; let b1 = tileB_db[cur_b_off + k * TN + (4u * tx + 1u)]; let b2 = tileB_db[cur_b_off + k * TN + (4u * tx + 2u)]; let b3 = tileB_db[cur_b_off + k * TN + (4u * tx + 3u)]; acc00 = fma(a0, b0, acc00); acc01 = fma(a0, b1, acc01); acc02 = fma(a0, b2, acc02); acc03 = fma(a0, b3, acc03); acc10 = fma(a1, b0, acc10); acc11 = fma(a1, b1, acc11); acc12 = fma(a1, b2, acc12); acc13 = fma(a1, b3, acc13); acc20 = fma(a2, b0, acc20); acc21 = fma(a2, b1, acc21); acc22 = fma(a2, b2, acc22); acc23 = fma(a2, b3, acc23); acc30 = fma(a3, b0, acc30); acc31 = fma(a3, b1, acc31); acc32 = fma(a3, b2, acc32); acc33 = fma(a3, b3, acc33); } workgroupBarrier(); } let or = block_row + 4u * ty; let oc = block_col + 4u * tx; if (or + 0u < M && oc + 0u < N) { Y_mt[(or + 0u) * N + (oc + 0u)] = acc00; } if (or + 0u < M && oc + 1u < N) { Y_mt[(or + 0u) * N + (oc + 1u)] = acc01; } if (or + 0u < M && oc + 2u < N) { Y_mt[(or + 0u) * N + (oc + 2u)] = acc02; } if (or + 0u < M && oc + 3u < N) { Y_mt[(or + 0u) * N + (oc + 3u)] = acc03; } if (or + 1u < M && oc + 0u < N) { Y_mt[(or + 1u) * N + (oc + 0u)] = acc10; } if (or + 1u < M && oc + 1u < N) { Y_mt[(or + 1u) * N + (oc + 1u)] = acc11; } if (or + 1u < M && oc + 2u < N) { Y_mt[(or + 1u) * N + (oc + 2u)] = acc12; } if (or + 1u < M && oc + 3u < N) { Y_mt[(or + 1u) * N + (oc + 3u)] = acc13; } if (or + 2u < M && oc + 0u < N) { Y_mt[(or + 2u) * N + (oc + 0u)] = acc20; } if (or + 2u < M && oc + 1u < N) { Y_mt[(or + 2u) * N + (oc + 1u)] = acc21; } if (or + 2u < M && oc + 2u < N) { Y_mt[(or + 2u) * N + (oc + 2u)] = acc22; } if (or + 2u < M && oc + 3u < N) { Y_mt[(or + 2u) * N + (oc + 3u)] 
= acc23; } if (or + 3u < M && oc + 0u < N) { Y_mt[(or + 3u) * N + (oc + 0u)] = acc30; } if (or + 3u < M && oc + 1u < N) { Y_mt[(or + 3u) * N + (oc + 1u)] = acc31; } if (or + 3u < M && oc + 2u < N) { Y_mt[(or + 3u) * N + (oc + 2u)] = acc32; } if (or + 3u < M && oc + 3u < N) { Y_mt[(or + 3u) * N + (oc + 3u)] = acc33; } } // --- KERNEL: matmul_t_acc --- // Y[M, N] += X[M, K] @ W^T (same as matmul_t but accumulating). // Used by backward to fold dQ/dK/dV @ W^T contributions into dx_norm // without an extra axpy + scratch buffer. // vec4 tile loads — host MUST satisfy K % 4 == 0. @group(0) @binding(0) var X_mta: array>; @group(0) @binding(1) var W_mta: array>; @group(0) @binding(2) var Y_mta: array; @group(0) @binding(3) var dims_mta: vec4; @compute @workgroup_size(16, 16, 1) fn matmul_t_acc(@builtin(workgroup_id) wgid: vec3, @builtin(local_invocation_id) lid: vec3) { let M = dims_mta.x; let N = dims_mta.y; let K = dims_mta.z; let K4 = K / 4u; let tx = lid.x; let ty = lid.y; let tid = ty * 16u + tx; let block_row = wgid.y * TM; let block_col = wgid.x * TN; var acc00: f32 = 0.0; var acc01: f32 = 0.0; var acc02: f32 = 0.0; var acc03: f32 = 0.0; var acc10: f32 = 0.0; var acc11: f32 = 0.0; var acc12: f32 = 0.0; var acc13: f32 = 0.0; var acc20: f32 = 0.0; var acc21: f32 = 0.0; var acc22: f32 = 0.0; var acc23: f32 = 0.0; var acc30: f32 = 0.0; var acc31: f32 = 0.0; var acc32: f32 = 0.0; var acc33: f32 = 0.0; let nTiles = (K + TK - 1u) / TK; { let aI0 = tid * 4u; let aIm = aI0 / TK; let aIk = aI0 % TK; let axr = block_row + aIm; let axc = aIk; let row_in = axr < M; let xv = X_mta[axr * K4 + axc / 4u]; tileA_db[aIm * TK_PAD + aIk + 0u] = select(0.0, xv.x, row_in && (axc + 0u) < K); tileA_db[aIm * TK_PAD + aIk + 1u] = select(0.0, xv.y, row_in && (axc + 1u) < K); tileA_db[aIm * TK_PAD + aIk + 2u] = select(0.0, xv.z, row_in && (axc + 2u) < K); tileA_db[aIm * TK_PAD + aIk + 3u] = select(0.0, xv.w, row_in && (axc + 3u) < K); let bI0 = tid * 4u; let bIn = bI0 / TK; let bIk = bI0 % TK; 
let nG = block_col + bIn; let kG = bIk; let n_in = nG < N; let wv = W_mta[nG * K4 + kG / 4u]; tileB_db[(bIk + 0u) * TN + bIn] = select(0.0, wv.x, n_in && (kG + 0u) < K); tileB_db[(bIk + 1u) * TN + bIn] = select(0.0, wv.y, n_in && (kG + 1u) < K); tileB_db[(bIk + 2u) * TN + bIn] = select(0.0, wv.z, n_in && (kG + 2u) < K); tileB_db[(bIk + 3u) * TN + bIn] = select(0.0, wv.w, n_in && (kG + 3u) < K); } workgroupBarrier(); for (var t: u32 = 0u; t < nTiles; t = t + 1u) { let parity = t & 1u; let cur_a_off = parity * TA_DB_HALF; let cur_b_off = parity * TB_DB_HALF; if (t + 1u < nTiles) { let nxt_a_off = (1u - parity) * TA_DB_HALF; let nxt_b_off = (1u - parity) * TB_DB_HALF; let nxt_kBase = (t + 1u) * TK; let aI0 = tid * 4u; let aIm = aI0 / TK; let aIk = aI0 % TK; let axr = block_row + aIm; let axc = nxt_kBase + aIk; let row_in = axr < M; let xv = X_mta[axr * K4 + axc / 4u]; tileA_db[nxt_a_off + aIm * TK_PAD + aIk + 0u] = select(0.0, xv.x, row_in && (axc + 0u) < K); tileA_db[nxt_a_off + aIm * TK_PAD + aIk + 1u] = select(0.0, xv.y, row_in && (axc + 1u) < K); tileA_db[nxt_a_off + aIm * TK_PAD + aIk + 2u] = select(0.0, xv.z, row_in && (axc + 2u) < K); tileA_db[nxt_a_off + aIm * TK_PAD + aIk + 3u] = select(0.0, xv.w, row_in && (axc + 3u) < K); let bI0 = tid * 4u; let bIn = bI0 / TK; let bIk = bI0 % TK; let nG = block_col + bIn; let kG = nxt_kBase + bIk; let n_in = nG < N; let wv = W_mta[nG * K4 + kG / 4u]; tileB_db[nxt_b_off + (bIk + 0u) * TN + bIn] = select(0.0, wv.x, n_in && (kG + 0u) < K); tileB_db[nxt_b_off + (bIk + 1u) * TN + bIn] = select(0.0, wv.y, n_in && (kG + 1u) < K); tileB_db[nxt_b_off + (bIk + 2u) * TN + bIn] = select(0.0, wv.z, n_in && (kG + 2u) < K); tileB_db[nxt_b_off + (bIk + 3u) * TN + bIn] = select(0.0, wv.w, n_in && (kG + 3u) < K); } for (var k: u32 = 0u; k < TK; k = k + 1u) { let a0 = tileA_db[cur_a_off + (4u * ty + 0u) * TK_PAD + k]; let a1 = tileA_db[cur_a_off + (4u * ty + 1u) * TK_PAD + k]; let a2 = tileA_db[cur_a_off + (4u * ty + 2u) * TK_PAD + k]; let 
a3 = tileA_db[cur_a_off + (4u * ty + 3u) * TK_PAD + k]; let b0 = tileB_db[cur_b_off + k * TN + (4u * tx + 0u)]; let b1 = tileB_db[cur_b_off + k * TN + (4u * tx + 1u)]; let b2 = tileB_db[cur_b_off + k * TN + (4u * tx + 2u)]; let b3 = tileB_db[cur_b_off + k * TN + (4u * tx + 3u)]; acc00 = fma(a0, b0, acc00); acc01 = fma(a0, b1, acc01); acc02 = fma(a0, b2, acc02); acc03 = fma(a0, b3, acc03); acc10 = fma(a1, b0, acc10); acc11 = fma(a1, b1, acc11); acc12 = fma(a1, b2, acc12); acc13 = fma(a1, b3, acc13); acc20 = fma(a2, b0, acc20); acc21 = fma(a2, b1, acc21); acc22 = fma(a2, b2, acc22); acc23 = fma(a2, b3, acc23); acc30 = fma(a3, b0, acc30); acc31 = fma(a3, b1, acc31); acc32 = fma(a3, b2, acc32); acc33 = fma(a3, b3, acc33); } workgroupBarrier(); } let or = block_row + 4u * ty; let oc = block_col + 4u * tx; // Accumulating epilogue: Y[i] += acc. if (or + 0u < M && oc + 0u < N) { let i = (or + 0u) * N + (oc + 0u); Y_mta[i] = Y_mta[i] + acc00; } if (or + 0u < M && oc + 1u < N) { let i = (or + 0u) * N + (oc + 1u); Y_mta[i] = Y_mta[i] + acc01; } if (or + 0u < M && oc + 2u < N) { let i = (or + 0u) * N + (oc + 2u); Y_mta[i] = Y_mta[i] + acc02; } if (or + 0u < M && oc + 3u < N) { let i = (or + 0u) * N + (oc + 3u); Y_mta[i] = Y_mta[i] + acc03; } if (or + 1u < M && oc + 0u < N) { let i = (or + 1u) * N + (oc + 0u); Y_mta[i] = Y_mta[i] + acc10; } if (or + 1u < M && oc + 1u < N) { let i = (or + 1u) * N + (oc + 1u); Y_mta[i] = Y_mta[i] + acc11; } if (or + 1u < M && oc + 2u < N) { let i = (or + 1u) * N + (oc + 2u); Y_mta[i] = Y_mta[i] + acc12; } if (or + 1u < M && oc + 3u < N) { let i = (or + 1u) * N + (oc + 3u); Y_mta[i] = Y_mta[i] + acc13; } if (or + 2u < M && oc + 0u < N) { let i = (or + 2u) * N + (oc + 0u); Y_mta[i] = Y_mta[i] + acc20; } if (or + 2u < M && oc + 1u < N) { let i = (or + 2u) * N + (oc + 1u); Y_mta[i] = Y_mta[i] + acc21; } if (or + 2u < M && oc + 2u < N) { let i = (or + 2u) * N + (oc + 2u); Y_mta[i] = Y_mta[i] + acc22; } if (or + 2u < M && oc + 3u < N) { let i = (or + 
2u) * N + (oc + 3u); Y_mta[i] = Y_mta[i] + acc23; } if (or + 3u < M && oc + 0u < N) { let i = (or + 3u) * N + (oc + 0u); Y_mta[i] = Y_mta[i] + acc30; } if (or + 3u < M && oc + 1u < N) { let i = (or + 3u) * N + (oc + 1u); Y_mta[i] = Y_mta[i] + acc31; } if (or + 3u < M && oc + 2u < N) { let i = (or + 3u) * N + (oc + 2u); Y_mta[i] = Y_mta[i] + acc32; } if (or + 3u < M && oc + 3u < N) { let i = (or + 3u) * N + (oc + 3u); Y_mta[i] = Y_mta[i] + acc33; } } // --- KERNEL: matmul_at --- // Y[K, N] = X^T @ B where X is [M, K] (so X^T is [K, M]), B is [M, N]. // dims layout: (K, N, M, _). // // A-tile holds X^T: tileA[k_local, m_local] = X[m_global, k_global]. // vec4 X load (4 K-consecutive values at fixed m). vec4 B load too. // Host MUST satisfy K % 4 == 0 AND N % 4 == 0. @group(0) @binding(0) var X_mat: array>; @group(0) @binding(1) var B_mat: array>; @group(0) @binding(2) var Y_mat: array; @group(0) @binding(3) var dims_mat: vec4; // (K, N, M, _) @compute @workgroup_size(16, 16, 1) fn matmul_at(@builtin(workgroup_id) wgid: vec3, @builtin(local_invocation_id) lid: vec3) { let K = dims_mat.x; let N = dims_mat.y; let M = dims_mat.z; let K4 = K / 4u; let N4 = N / 4u; let tx = lid.x; let ty = lid.y; let tid = ty * 16u + tx; let block_row = wgid.y * TM; // K-dim let block_col = wgid.x * TN; // N-dim var acc00: f32 = 0.0; var acc01: f32 = 0.0; var acc02: f32 = 0.0; var acc03: f32 = 0.0; var acc10: f32 = 0.0; var acc11: f32 = 0.0; var acc12: f32 = 0.0; var acc13: f32 = 0.0; var acc20: f32 = 0.0; var acc21: f32 = 0.0; var acc22: f32 = 0.0; var acc23: f32 = 0.0; var acc30: f32 = 0.0; var acc31: f32 = 0.0; var acc32: f32 = 0.0; var acc33: f32 = 0.0; let nTiles = (M + TK - 1u) / TK; // Prologue: M-tile 0 → buffer 0 { let aI0 = tid * 4u; let aIm = aI0 / TM; let aIk = aI0 % TM; let kG = block_row + aIk; let mG = aIm; let m_in = mG < M; let xv = X_mat[mG * K4 + kG / 4u]; tileA_db[(aIk + 0u) * TK_PAD + aIm] = select(0.0, xv.x, m_in && (kG + 0u) < K); tileA_db[(aIk + 1u) * TK_PAD + aIm] 
= select(0.0, xv.y, m_in && (kG + 1u) < K); tileA_db[(aIk + 2u) * TK_PAD + aIm] = select(0.0, xv.z, m_in && (kG + 2u) < K); tileA_db[(aIk + 3u) * TK_PAD + aIm] = select(0.0, xv.w, m_in && (kG + 3u) < K); let bI0 = tid * 4u; let bIm = bI0 / TN; let bIn = bI0 % TN; let mGB = bIm; let nG = block_col + bIn; let mGB_in = mGB < M; let bv = B_mat[mGB * N4 + nG / 4u]; tileB_db[bIm * TN + bIn + 0u] = select(0.0, bv.x, mGB_in && (nG + 0u) < N); tileB_db[bIm * TN + bIn + 1u] = select(0.0, bv.y, mGB_in && (nG + 1u) < N); tileB_db[bIm * TN + bIn + 2u] = select(0.0, bv.z, mGB_in && (nG + 2u) < N); tileB_db[bIm * TN + bIn + 3u] = select(0.0, bv.w, mGB_in && (nG + 3u) < N); } workgroupBarrier(); for (var t: u32 = 0u; t < nTiles; t = t + 1u) { let parity = t & 1u; let cur_a_off = parity * TA_DB_HALF; let cur_b_off = parity * TB_DB_HALF; if (t + 1u < nTiles) { let nxt_a_off = (1u - parity) * TA_DB_HALF; let nxt_b_off = (1u - parity) * TB_DB_HALF; let nxt_mBase = (t + 1u) * TK; let aI0 = tid * 4u; let aIm = aI0 / TM; let aIk = aI0 % TM; let kG = block_row + aIk; let mG = nxt_mBase + aIm; let m_in = mG < M; let xv = X_mat[mG * K4 + kG / 4u]; tileA_db[nxt_a_off + (aIk + 0u) * TK_PAD + aIm] = select(0.0, xv.x, m_in && (kG + 0u) < K); tileA_db[nxt_a_off + (aIk + 1u) * TK_PAD + aIm] = select(0.0, xv.y, m_in && (kG + 1u) < K); tileA_db[nxt_a_off + (aIk + 2u) * TK_PAD + aIm] = select(0.0, xv.z, m_in && (kG + 2u) < K); tileA_db[nxt_a_off + (aIk + 3u) * TK_PAD + aIm] = select(0.0, xv.w, m_in && (kG + 3u) < K); let bI0 = tid * 4u; let bIm = bI0 / TN; let bIn = bI0 % TN; let mGB = nxt_mBase + bIm; let nG = block_col + bIn; let mGB_in = mGB < M; let bv = B_mat[mGB * N4 + nG / 4u]; tileB_db[nxt_b_off + bIm * TN + bIn + 0u] = select(0.0, bv.x, mGB_in && (nG + 0u) < N); tileB_db[nxt_b_off + bIm * TN + bIn + 1u] = select(0.0, bv.y, mGB_in && (nG + 1u) < N); tileB_db[nxt_b_off + bIm * TN + bIn + 2u] = select(0.0, bv.z, mGB_in && (nG + 2u) < N); tileB_db[nxt_b_off + bIm * TN + bIn + 3u] = select(0.0, 
bv.w, mGB_in && (nG + 3u) < N); } for (var m: u32 = 0u; m < TK; m = m + 1u) { let a0 = tileA_db[cur_a_off + (4u * ty + 0u) * TK_PAD + m]; let a1 = tileA_db[cur_a_off + (4u * ty + 1u) * TK_PAD + m]; let a2 = tileA_db[cur_a_off + (4u * ty + 2u) * TK_PAD + m]; let a3 = tileA_db[cur_a_off + (4u * ty + 3u) * TK_PAD + m]; let b0 = tileB_db[cur_b_off + m * TN + (4u * tx + 0u)]; let b1 = tileB_db[cur_b_off + m * TN + (4u * tx + 1u)]; let b2 = tileB_db[cur_b_off + m * TN + (4u * tx + 2u)]; let b3 = tileB_db[cur_b_off + m * TN + (4u * tx + 3u)]; acc00 = fma(a0, b0, acc00); acc01 = fma(a0, b1, acc01); acc02 = fma(a0, b2, acc02); acc03 = fma(a0, b3, acc03); acc10 = fma(a1, b0, acc10); acc11 = fma(a1, b1, acc11); acc12 = fma(a1, b2, acc12); acc13 = fma(a1, b3, acc13); acc20 = fma(a2, b0, acc20); acc21 = fma(a2, b1, acc21); acc22 = fma(a2, b2, acc22); acc23 = fma(a2, b3, acc23); acc30 = fma(a3, b0, acc30); acc31 = fma(a3, b1, acc31); acc32 = fma(a3, b2, acc32); acc33 = fma(a3, b3, acc33); } workgroupBarrier(); } let or = block_row + 4u * ty; let oc = block_col + 4u * tx; if (or + 0u < K && oc + 0u < N) { Y_mat[(or + 0u) * N + (oc + 0u)] = acc00; } if (or + 0u < K && oc + 1u < N) { Y_mat[(or + 0u) * N + (oc + 1u)] = acc01; } if (or + 0u < K && oc + 2u < N) { Y_mat[(or + 0u) * N + (oc + 2u)] = acc02; } if (or + 0u < K && oc + 3u < N) { Y_mat[(or + 0u) * N + (oc + 3u)] = acc03; } if (or + 1u < K && oc + 0u < N) { Y_mat[(or + 1u) * N + (oc + 0u)] = acc10; } if (or + 1u < K && oc + 1u < N) { Y_mat[(or + 1u) * N + (oc + 1u)] = acc11; } if (or + 1u < K && oc + 2u < N) { Y_mat[(or + 1u) * N + (oc + 2u)] = acc12; } if (or + 1u < K && oc + 3u < N) { Y_mat[(or + 1u) * N + (oc + 3u)] = acc13; } if (or + 2u < K && oc + 0u < N) { Y_mat[(or + 2u) * N + (oc + 0u)] = acc20; } if (or + 2u < K && oc + 1u < N) { Y_mat[(or + 2u) * N + (oc + 1u)] = acc21; } if (or + 2u < K && oc + 2u < N) { Y_mat[(or + 2u) * N + (oc + 2u)] = acc22; } if (or + 2u < K && oc + 3u < N) { Y_mat[(or + 2u) * N + (oc + 3u)] 
= acc23; }
    if (or + 3u < K && oc + 0u < N) { Y_mat[(or + 3u) * N + (oc + 0u)] = acc30; }
    if (or + 3u < K && oc + 1u < N) { Y_mat[(or + 3u) * N + (oc + 1u)] = acc31; }
    if (or + 3u < K && oc + 2u < N) { Y_mat[(or + 3u) * N + (oc + 2u)] = acc32; }
    if (or + 3u < K && oc + 3u < N) { Y_mat[(or + 3u) * N + (oc + 3u)] = acc33; }
}

// --- KERNEL: matmul_at_acc ---
// Y[K, N] += X^T @ B (same as matmul_at but accumulating).
// vec4 tile loads — host MUST satisfy K % 4 == 0 AND N % 4 == 0.
//
// X is [M, K] and B is [M, N]; the reduction runs over M, TK rows at a
// time. Each workgroup owns the 64×64 output tile whose top-left element is
// Y[wgid.y * TM, wgid.x * TN]; thread (tx, ty) owns the 4×4 sub-tile at
// local rows 4*ty..4*ty+3 and local cols 4*tx..4*tx+3. The epilogue reads Y
// and adds the partial product instead of overwriting it, so whatever Y
// holds before the dispatch is part of the result (zero it first to get a
// plain X^T @ B).
@group(0) @binding(0) var X_maa: array>; // X [M, K], indexed in vec4 units along K
@group(0) @binding(1) var B_maa: array>; // B [M, N], indexed in vec4 units along N
@group(0) @binding(2) var Y_maa: array; // Y [K, N], read-modify-write in the epilogue
@group(0) @binding(3) var dims_maa: vec4; // (K, N, M, _)

@compute @workgroup_size(16, 16, 1)
fn matmul_at_acc(@builtin(workgroup_id) wgid: vec3, @builtin(local_invocation_id) lid: vec3) {
    let K = dims_maa.x;
    let N = dims_maa.y;
    let M = dims_maa.z;
    // Row strides in vec4 units; exact only because K % 4 == 0 and N % 4 == 0.
    let K4 = K / 4u;
    let N4 = N / 4u;
    let tx = lid.x;
    let ty = lid.y;
    let tid = ty * 16u + tx; // flat thread index, 0..255
    // Output tile origin: rows run over K (columns of X), cols over N.
    let block_row = wgid.y * TM;
    let block_col = wgid.x * TN;

    // 4×4 register accumulator: accIJ is output (block_row + 4*ty + I, block_col + 4*tx + J).
    var acc00: f32 = 0.0; var acc01: f32 = 0.0; var acc02: f32 = 0.0; var acc03: f32 = 0.0;
    var acc10: f32 = 0.0; var acc11: f32 = 0.0; var acc12: f32 = 0.0; var acc13: f32 = 0.0;
    var acc20: f32 = 0.0; var acc21: f32 = 0.0; var acc22: f32 = 0.0; var acc23: f32 = 0.0;
    var acc30: f32 = 0.0; var acc31: f32 = 0.0; var acc32: f32 = 0.0; var acc33: f32 = 0.0;

    // The reduction dimension is M, so tiles are counted along M.
    let nTiles = (M + TK - 1u) / TK;

    // Prologue: reduction tile 0 → buffer half 0 (offset 0).
    {
        // A tile (X^T slice): one vec4 load of 4 consecutive K values from
        // X row aIm (0..15), scattered down column aIm so that
        // tileA_db[k_local * TK_PAD + m_local] = X[m, k]. Consecutive threads
        // read consecutive K addresses. The load itself is unconditional;
        // select() zeroes any element with m >= M or k >= K.
        let aI0 = tid * 4u;
        let aIm = aI0 / TM;
        let aIk = aI0 % TM;
        let kG = block_row + aIk;
        let mG = aIm;
        let m_in = mG < M;
        let xv = X_maa[mG * K4 + kG / 4u];
        tileA_db[(aIk + 0u) * TK_PAD + aIm] = select(0.0, xv.x, m_in && (kG + 0u) < K);
        tileA_db[(aIk + 1u) * TK_PAD + aIm] = select(0.0, xv.y, m_in && (kG + 1u) < K);
        tileA_db[(aIk + 2u) * TK_PAD + aIm] = select(0.0, xv.z, m_in && (kG + 2u) < K);
        tileA_db[(aIk + 3u) * TK_PAD + aIm] = select(0.0, xv.w, m_in && (kG + 3u) < K);

        // B tile: TK rows (M) × TN cols (N), row-major with stride TN; each
        // thread copies one vec4 from B row bIm. Same select() masking.
        let bI0 = tid * 4u;
        let bIm = bI0 / TN;
        let bIn = bI0 % TN;
        let mGB = bIm;
        let nG = block_col + bIn;
        let mGB_in = mGB < M;
        let bv = B_maa[mGB * N4 + nG / 4u];
        tileB_db[bIm * TN + bIn + 0u] = select(0.0, bv.x, mGB_in && (nG + 0u) < N);
        tileB_db[bIm * TN + bIn + 1u] = select(0.0, bv.y, mGB_in && (nG + 1u) < N);
        tileB_db[bIm * TN + bIn + 2u] = select(0.0, bv.z, mGB_in && (nG + 2u) < N);
        tileB_db[bIm * TN + bIn + 3u] = select(0.0, bv.w, mGB_in && (nG + 3u) < N);
    }
    workgroupBarrier();

    // Double-buffered main loop. Iteration t computes on half (t & 1) while
    // prefetching tile t + 1 into the other half. The single barrier at the
    // end of each iteration (a) makes the prefetched half visible to every
    // thread and (b) guarantees all reads of half (t & 1) are finished before
    // iteration t + 1 prefetches into that same half.
    for (var t: u32 = 0u; t < nTiles; t = t + 1u) {
        let parity = t & 1u;
        let cur_a_off = parity * TA_DB_HALF;
        let cur_b_off = parity * TB_DB_HALF;
        if (t + 1u < nTiles) {
            // Prefetch tile t + 1 into the other half; same layout as the prologue.
            let nxt_a_off = (1u - parity) * TA_DB_HALF;
            let nxt_b_off = (1u - parity) * TB_DB_HALF;
            let nxt_mBase = (t + 1u) * TK;
            let aI0 = tid * 4u;
            let aIm = aI0 / TM;
            let aIk = aI0 % TM;
            let kG = block_row + aIk;
            let mG = nxt_mBase + aIm;
            let m_in = mG < M;
            let xv = X_maa[mG * K4 + kG / 4u];
            tileA_db[nxt_a_off + (aIk + 0u) * TK_PAD + aIm] = select(0.0, xv.x, m_in && (kG + 0u) < K);
            tileA_db[nxt_a_off + (aIk + 1u) * TK_PAD + aIm] = select(0.0, xv.y, m_in && (kG + 1u) < K);
            tileA_db[nxt_a_off + (aIk + 2u) * TK_PAD + aIm] = select(0.0, xv.z, m_in && (kG + 2u) < K);
            tileA_db[nxt_a_off + (aIk + 3u) * TK_PAD + aIm] = select(0.0, xv.w, m_in && (kG + 3u) < K);
            let bI0 = tid * 4u;
            let bIm = bI0 / TN;
            let bIn = bI0 % TN;
            let mGB = nxt_mBase + bIm;
            let nG = block_col + bIn;
            let mGB_in = mGB < M;
            let bv = B_maa[mGB * N4 + nG / 4u];
            tileB_db[nxt_b_off + bIm * TN + bIn + 0u] = select(0.0, bv.x, mGB_in && (nG + 0u) < N);
            tileB_db[nxt_b_off + bIm * TN + bIn + 1u] = select(0.0, bv.y, mGB_in && (nG + 1u) < N);
            tileB_db[nxt_b_off + bIm * TN + bIn + 2u] = select(0.0, bv.z, mGB_in && (nG + 2u) < N);
            tileB_db[nxt_b_off + bIm * TN + bIn + 3u] = select(0.0, bv.w, mGB_in && (nG + 3u) < N);
        }
        // Rank-1 updates over the TK reduction steps of the current tile:
        // aI = X^T[row 4*ty+I, m] (i.e. X[m, k]), bJ = B[m, col 4*tx+J].
        for (var m: u32 = 0u; m < TK; m = m + 1u) {
            let a0 = tileA_db[cur_a_off + (4u * ty + 0u) * TK_PAD + m];
            let a1 = tileA_db[cur_a_off + (4u * ty + 1u) * TK_PAD + m];
            let a2 = tileA_db[cur_a_off + (4u * ty + 2u) * TK_PAD + m];
            let a3 = tileA_db[cur_a_off + (4u * ty + 3u) * TK_PAD + m];
            let b0 = tileB_db[cur_b_off + m * TN + (4u * tx + 0u)];
            let b1 = tileB_db[cur_b_off + m * TN + (4u * tx + 1u)];
            let b2 = tileB_db[cur_b_off + m * TN + (4u * tx + 2u)];
            let b3 = tileB_db[cur_b_off + m * TN + (4u * tx + 3u)];
            acc00 = fma(a0, b0, acc00); acc01 = fma(a0, b1, acc01); acc02 = fma(a0, b2, acc02); acc03 = fma(a0, b3, acc03);
            acc10 = fma(a1, b0, acc10); acc11 = fma(a1, b1, acc11); acc12 = fma(a1, b2, acc12); acc13 = fma(a1, b3, acc13);
            acc20 = fma(a2, b0, acc20); acc21 = fma(a2, b1, acc21); acc22 = fma(a2, b2, acc22); acc23 = fma(a2, b3, acc23);
            acc30 = fma(a3, b0, acc30); acc31 = fma(a3, b1, acc31); acc32 = fma(a3, b2, acc32); acc33 = fma(a3, b3, acc33);
        }
        workgroupBarrier();
    }

    // Epilogue: bounds-checked accumulate into Y (rows < K, cols < N). Each
    // output element is written by exactly one thread of one workgroup, so
    // the read-modify-write needs no atomics within this dispatch.
    let or = block_row + 4u * ty;
    let oc = block_col + 4u * tx;
    if (or + 0u < K && oc + 0u < N) { let i = (or + 0u) * N + (oc + 0u); Y_maa[i] = Y_maa[i] + acc00; }
    if (or + 0u < K && oc + 1u < N) { let i = (or + 0u) * N + (oc + 1u); Y_maa[i] = Y_maa[i] + acc01; }
    if (or + 0u < K && oc + 2u < N) { let i = (or + 0u) * N + (oc + 2u); Y_maa[i] = Y_maa[i] + acc02; }
    if (or + 0u < K && oc + 3u < N) { let i = (or + 0u) * N + (oc + 3u); Y_maa[i] = Y_maa[i] + acc03; }
    if (or + 1u < K && oc + 0u < N) { let i = (or + 1u) * N + (oc + 0u); Y_maa[i] = Y_maa[i] + acc10; }
    if (or + 1u < K && oc + 1u < N) { let i = (or + 1u) * N + (oc + 1u); Y_maa[i] = Y_maa[i] + acc11; }
    if (or + 1u < K && oc + 2u < N) { let i = (or + 1u) * N + (oc + 2u); Y_maa[i] = Y_maa[i] + acc12; }
    if (or + 1u < K && oc + 3u < N) { let i = (or + 1u) * N + (oc + 3u); Y_maa[i] = Y_maa[i] + acc13; }
    if (or + 2u < K && oc + 0u < N) { let i = (or + 2u) * N + (oc + 0u); Y_maa[i] = Y_maa[i] + acc20; }
    if (or + 2u < K && oc + 1u < N) { let i = (or + 2u) * N + (oc + 1u); Y_maa[i] = Y_maa[i] + acc21; }
    if (or + 2u < K && oc + 2u < N) { let i = (or + 2u) * N + (oc + 2u); Y_maa[i] = Y_maa[i] + acc22; }
    if (or + 2u < K && oc + 3u < N) { let i = (or + 2u) * N + (oc + 3u); Y_maa[i] = Y_maa[i] + acc23; }
    if (or + 3u < K && oc + 0u < N) { let i = (or + 3u) * N + (oc + 0u); Y_maa[i] = Y_maa[i] + acc30; }
    if (or + 3u < K && oc + 1u < N) { let i = (or + 3u) * N + (oc + 1u); Y_maa[i] = Y_maa[i] + acc31; }
    if (or + 3u < K && oc + 2u < N) { let i = (or + 3u) * N + (oc + 2u); Y_maa[i] = Y_maa[i] + acc32; }
    if (or + 3u < K && oc + 3u < N) { let i = (or + 3u) * N + (oc + 3u); Y_maa[i] = Y_maa[i] + acc33; }
}

// ════════════════════════════════════════════════════════════
// Mixed-precision variants (f16 weight storage, f32 acc + I/O)
// ════════════════════════════════════════════════════════════
//
// These use the same 64×64 output tile / 4×4 per-thread sub-tile structure
// as the fp32 kernels, but the W binding holds f16 vec4s instead of f32
// vec4s. The cast to f32 happens at load time, when W is written into the
// workgroup tile; the inner loop and accumulators stay fp32
// (training-stable). Output Y is also f32 — Phase 2 stores only the weights
// in f16. Phase 3 will add fp16 activations.
//
// Compiled only when adapter has `shader-f16` feature (engine prepends
// `enable f16;` to the shared preamble dynamically).

// --- KERNEL: matmul_w16 ---
// Y[M, N] = X[M, K] @ W[K, N] where W is f16
// vec4 tile loads — host MUST satisfy K % 4 == 0 AND N % 4 == 0.
// Bindings: X [M, K] (f32 vec4s along K), W [K, N] (f16 vec4s along N),
// Y [M, N] (f32, overwritten).
@group(0) @binding(0) var X_w16: array>;
@group(0) @binding(1) var W_w16: array>;
@group(0) @binding(2) var Y_w16: array;
@group(0) @binding(3) var dims_w16: vec4; // (M, N, K, _)

@compute @workgroup_size(16, 16, 1)
fn matmul_w16(@builtin(workgroup_id) wgid: vec3, @builtin(local_invocation_id) lid: vec3) {
    let M = dims_w16.x;
    let N = dims_w16.y;
    let K = dims_w16.z;
    // Row strides in vec4 units; exact only because K % 4 == 0 and N % 4 == 0.
    let K4 = K / 4u;
    let N4 = N / 4u;
    let tx = lid.x;
    let ty = lid.y;
    let tid = ty * 16u + tx; // flat thread index, 0..255
    // Top-left corner of this workgroup's 64×64 output tile.
    let block_row = wgid.y * TM;
    let block_col = wgid.x * TN;

    // 4×4 register accumulator: accIJ is output (block_row + 4*ty + I, block_col + 4*tx + J).
    var acc00: f32 = 0.0; var acc01: f32 = 0.0; var acc02: f32 = 0.0; var acc03: f32 = 0.0;
    var acc10: f32 = 0.0; var acc11: f32 = 0.0; var acc12: f32 = 0.0; var acc13: f32 = 0.0;
    var acc20: f32 = 0.0; var acc21: f32 = 0.0; var acc22: f32 = 0.0; var acc23: f32 = 0.0;
    var acc30: f32 = 0.0; var acc31: f32 = 0.0; var acc32: f32 = 0.0; var acc33: f32 = 0.0;
    let nTiles = (K + TK - 1u) / TK;

    // Prologue: K-tile 0 → buffer half 0 (offset 0).
    {
        // A tile: TM rows × TK cols of X, row-major with padded stride TK_PAD.
        // Each thread does one vec4 load (4 consecutive K values) from X row
        // block_row + aIm; 4 consecutive threads cover one row's TK-wide slice.
        // The load is unconditional; select() zeroes out-of-range elements.
        let aI0 = tid * 4u;
        let aIm = aI0 / TK;
        let aIk = aI0 % TK;
        let axr = block_row + aIm;
        let axc = aIk;
        let row_in = axr < M;
        let xv = X_w16[axr * K4 + axc / 4u];
        tileA_db[aIm * TK_PAD + aIk + 0u] = select(0.0, xv.x, row_in && (axc + 0u) < K);
        tileA_db[aIm * TK_PAD + aIk + 1u] = select(0.0, xv.y, row_in && (axc + 1u) < K);
        tileA_db[aIm * TK_PAD + aIk + 2u] = select(0.0, xv.z, row_in && (axc + 2u) < K);
        tileA_db[aIm * TK_PAD + aIk + 3u] = select(0.0, xv.w, row_in && (axc + 3u) < K);

        // B tile: TK rows × TN cols of W, row-major with stride TN. Each thread
        // loads one f16 vec4 of W row bwr and widens each lane to f32.
        let bI0 = tid * 4u;
        let bIk = bI0 / TN;
        let bIn = bI0 % TN;
        let bwr = bIk;
        let bwc = block_col + bIn;
        let bwr_in = bwr < K;
        let wv = W_w16[bwr * N4 + bwc / 4u];
        tileB_db[bIk * TN + bIn + 0u] = select(0.0, f32(wv.x), bwr_in && (bwc + 0u) < N);
        tileB_db[bIk * TN + bIn + 1u] = select(0.0, f32(wv.y), bwr_in && (bwc + 1u) < N);
        tileB_db[bIk * TN + bIn + 2u] = select(0.0, f32(wv.z), bwr_in && (bwc + 2u) < N);
        tileB_db[bIk * TN + bIn + 3u] = select(0.0, f32(wv.w), bwr_in && (bwc + 3u) < N);
    }
    workgroupBarrier();

    // Double-buffered main loop: compute on half (t & 1), prefetch tile t + 1
    // into the other half. The end-of-iteration barrier publishes the
    // prefetch and ensures half (t & 1) is no longer being read before the
    // next iteration prefetches into it.
    for (var t: u32 = 0u; t < nTiles; t = t + 1u) {
        let parity = t & 1u;
        let cur_a_off = parity * TA_DB_HALF;
        let cur_b_off = parity * TB_DB_HALF;
        if (t + 1u < nTiles) {
            // Prefetch K-tile t + 1; same layout as the prologue.
            let nxt_a_off = (1u - parity) * TA_DB_HALF;
            let nxt_b_off = (1u - parity) * TB_DB_HALF;
            let nxt_kBase = (t + 1u) * TK;
            let aI0 = tid * 4u;
            let aIm = aI0 / TK;
            let aIk = aI0 % TK;
            let axr = block_row + aIm;
            let axc = nxt_kBase + aIk;
            let row_in = axr < M;
            let xv = X_w16[axr * K4 + axc / 4u];
            tileA_db[nxt_a_off + aIm * TK_PAD + aIk + 0u] = select(0.0, xv.x, row_in && (axc + 0u) < K);
            tileA_db[nxt_a_off + aIm * TK_PAD + aIk + 1u] = select(0.0, xv.y, row_in && (axc + 1u) < K);
            tileA_db[nxt_a_off + aIm * TK_PAD + aIk + 2u] = select(0.0, xv.z, row_in && (axc + 2u) < K);
            tileA_db[nxt_a_off + aIm * TK_PAD + aIk + 3u] = select(0.0, xv.w, row_in && (axc + 3u) < K);
            let bI0 = tid * 4u;
            let bIk = bI0 / TN;
            let bIn = bI0 % TN;
            let bwr = nxt_kBase + bIk;
            let bwc = block_col + bIn;
            let bwr_in = bwr < K;
            let wv = W_w16[bwr * N4 + bwc / 4u];
            tileB_db[nxt_b_off + bIk * TN + bIn + 0u] = select(0.0, f32(wv.x), bwr_in && (bwc + 0u) < N);
            tileB_db[nxt_b_off + bIk * TN + bIn + 1u] = select(0.0, f32(wv.y), bwr_in && (bwc + 1u) < N);
            tileB_db[nxt_b_off + bIk * TN + bIn + 2u] = select(0.0, f32(wv.z), bwr_in && (bwc + 2u) < N);
            tileB_db[nxt_b_off + bIk * TN + bIn + 3u] = select(0.0, f32(wv.w), bwr_in && (bwc + 3u) < N);
        }
        // Rank-1 updates: aI = X[row 4*ty+I, k], bJ = W[k, col 4*tx+J] (both tile-local).
        for (var k: u32 = 0u; k < TK; k = k + 1u) {
            let a0 = tileA_db[cur_a_off + (4u * ty + 0u) * TK_PAD + k];
            let a1 = tileA_db[cur_a_off + (4u * ty + 1u) * TK_PAD + k];
            let a2 = tileA_db[cur_a_off + (4u * ty + 2u) * TK_PAD + k];
            let a3 = tileA_db[cur_a_off + (4u * ty + 3u) * TK_PAD + k];
            let b0 = tileB_db[cur_b_off + k * TN + (4u * tx + 0u)];
            let b1 = tileB_db[cur_b_off + k * TN + (4u * tx + 1u)];
            let b2 = tileB_db[cur_b_off + k * TN + (4u * tx + 2u)];
            let b3 = tileB_db[cur_b_off + k * TN + (4u * tx + 3u)];
            acc00 = fma(a0, b0, acc00); acc01 = fma(a0, b1, acc01); acc02 = fma(a0, b2, acc02); acc03 = fma(a0, b3, acc03);
            acc10 = fma(a1, b0, acc10); acc11 = fma(a1, b1, acc11); acc12 = fma(a1, b2, acc12); acc13 = fma(a1, b3, acc13);
            acc20 = fma(a2, b0, acc20); acc21 = fma(a2, b1, acc21); acc22 = fma(a2, b2, acc22); acc23 = fma(a2, b3, acc23);
            acc30 = fma(a3, b0, acc30); acc31 = fma(a3, b1, acc31); acc32 = fma(a3, b2, acc32); acc33 = fma(a3, b3, acc33);
        }
        workgroupBarrier();
    }

    // Epilogue: bounds-checked store of the 4×4 sub-tile into Y (overwrites).
    let or = block_row + 4u * ty;
    let oc = block_col + 4u * tx;
    if (or + 0u < M && oc + 0u < N) { Y_w16[(or + 0u) * N + (oc + 0u)] = acc00; }
    if (or + 0u < M && oc + 1u < N) { Y_w16[(or + 0u) * N + (oc + 1u)] = acc01; }
    if (or + 0u < M && oc + 2u < N) { Y_w16[(or + 0u) * N + (oc + 2u)] = acc02; }
    if (or + 0u < M && oc + 3u < N) { Y_w16[(or + 0u) * N + (oc + 3u)] = acc03; }
    if (or + 1u < M && oc + 0u < N) { Y_w16[(or + 1u) * N + (oc + 0u)] = acc10; }
    if (or + 1u < M && oc + 1u < N) { Y_w16[(or + 1u) * N + (oc + 1u)] = acc11; }
    if (or + 1u < M && oc + 2u < N) { Y_w16[(or + 1u) * N + (oc + 2u)] = acc12; }
    if (or + 1u < M && oc + 3u < N) { Y_w16[(or + 1u) * N + (oc + 3u)] = acc13; }
    if (or + 2u < M && oc + 0u < N) { Y_w16[(or + 2u) * N + (oc + 0u)] = acc20; }
    if (or + 2u < M && oc + 1u < N) { Y_w16[(or + 2u) * N + (oc + 1u)] = acc21; }
    if (or + 2u < M && oc + 2u < N) { Y_w16[(or + 2u) * N + (oc + 2u)] = acc22; }
    if (or + 2u < M && oc + 3u < N) { Y_w16[(or + 2u) * N + (oc + 3u)] = acc23; }
    if (or + 3u < M && oc + 0u < N) { Y_w16[(or + 3u) * N + (oc + 0u)] = acc30; }
    if (or + 3u < M && oc + 1u < N) { Y_w16[(or + 3u) * N + (oc + 1u)] = acc31; }
    if (or + 3u < M && oc + 2u < N) { Y_w16[(or + 3u) * N + (oc + 2u)] = acc32; }
    if (or + 3u < M && oc + 3u < N) { Y_w16[(or + 3u) * N + (oc + 3u)] = acc33; }
}

// --- KERNEL: matmul_residual_w16 ---
// Y[M, N] = X[M, K] @ W[K, N] + R[M, N] where W is f16
// vec4 tile loads — host MUST satisfy K % 4 == 0 AND N % 4 == 0.
// Bindings: X [M, K] (f32 vec4s along K), W [K, N] (f16 vec4s along N),
// R [M, N] residual (read once per output element), Y [M, N] (overwritten).
@group(0) @binding(0) var X_mrw: array>;
@group(0) @binding(1) var W_mrw: array>;
@group(0) @binding(2) var R_mrw: array;
@group(0) @binding(3) var Y_mrw: array;
@group(0) @binding(4) var dims_mrw: vec4; // (M, N, K, _)

@compute @workgroup_size(16, 16, 1)
fn matmul_residual_w16(@builtin(workgroup_id) wgid: vec3, @builtin(local_invocation_id) lid: vec3) {
    let M = dims_mrw.x;
    let N = dims_mrw.y;
    let K = dims_mrw.z;
    // Row strides in vec4 units; exact only because K % 4 == 0 and N % 4 == 0.
    let K4 = K / 4u;
    let N4 = N / 4u;
    let tx = lid.x;
    let ty = lid.y;
    let tid = ty * 16u + tx; // flat thread index, 0..255
    let block_row = wgid.y * TM;
    let block_col = wgid.x * TN;

    // 4×4 register accumulator: accIJ is output (block_row + 4*ty + I, block_col + 4*tx + J).
    var acc00: f32 = 0.0; var acc01: f32 = 0.0; var acc02: f32 = 0.0; var acc03: f32 = 0.0;
    var acc10: f32 = 0.0; var acc11: f32 = 0.0; var acc12: f32 = 0.0; var acc13: f32 = 0.0;
    var acc20: f32 = 0.0; var acc21: f32 = 0.0; var acc22: f32 = 0.0; var acc23: f32 = 0.0;
    var acc30: f32 = 0.0; var acc31: f32 = 0.0; var acc32: f32 = 0.0; var acc33: f32 = 0.0;
    let nTiles = (K + TK - 1u) / TK;

    // Prologue: K-tile 0 → buffer half 0. Tile layouts are identical to
    // matmul_w16 (A row-major with stride TK_PAD, B row-major with stride TN,
    // W widened from f16 to f32 on load, out-of-range lanes zeroed by select()).
    {
        let aI0 = tid * 4u;
        let aIm = aI0 / TK;
        let aIk = aI0 % TK;
        let axr = block_row + aIm;
        let axc = aIk;
        let row_in = axr < M;
        let xv = X_mrw[axr * K4 + axc / 4u];
        tileA_db[aIm * TK_PAD + aIk + 0u] = select(0.0, xv.x, row_in && (axc + 0u) < K);
        tileA_db[aIm * TK_PAD + aIk + 1u] = select(0.0, xv.y, row_in && (axc + 1u) < K);
        tileA_db[aIm * TK_PAD + aIk + 2u] = select(0.0, xv.z, row_in && (axc + 2u) < K);
        tileA_db[aIm * TK_PAD + aIk + 3u] = select(0.0, xv.w, row_in && (axc + 3u) < K);
        let bI0 = tid * 4u;
        let bIk = bI0 / TN;
        let bIn = bI0 % TN;
        let bwr = bIk;
        let bwc = block_col + bIn;
        let bwr_in = bwr < K;
        let wv = W_mrw[bwr * N4 + bwc / 4u];
        tileB_db[bIk * TN + bIn + 0u] = select(0.0, f32(wv.x), bwr_in && (bwc + 0u) < N);
        tileB_db[bIk * TN + bIn + 1u] = select(0.0, f32(wv.y), bwr_in && (bwc + 1u) < N);
        tileB_db[bIk * TN + bIn + 2u] = select(0.0, f32(wv.z), bwr_in && (bwc + 2u) < N);
        tileB_db[bIk * TN + bIn + 3u] = select(0.0, f32(wv.w), bwr_in && (bwc + 3u) < N);
    }
    workgroupBarrier();

    // Double-buffered main loop: compute on half (t & 1), prefetch tile t + 1
    // into the other half; the end-of-iteration barrier publishes the prefetch
    // and retires all reads of half (t & 1) before it is refilled.
    for (var t: u32 = 0u; t < nTiles; t = t + 1u) {
        let parity = t & 1u;
        let cur_a_off = parity * TA_DB_HALF;
        let cur_b_off = parity * TB_DB_HALF;
        if (t + 1u < nTiles) {
            let nxt_a_off = (1u - parity) * TA_DB_HALF;
            let nxt_b_off = (1u - parity) * TB_DB_HALF;
            let nxt_kBase = (t + 1u) * TK;
            let aI0 = tid * 4u;
            let aIm = aI0 / TK;
            let aIk = aI0 % TK;
            let axr = block_row + aIm;
            let axc = nxt_kBase + aIk;
            let row_in = axr < M;
            let xv = X_mrw[axr * K4 + axc / 4u];
            tileA_db[nxt_a_off + aIm * TK_PAD + aIk + 0u] = select(0.0, xv.x, row_in && (axc + 0u) < K);
            tileA_db[nxt_a_off + aIm * TK_PAD + aIk + 1u] = select(0.0, xv.y, row_in && (axc + 1u) < K);
            tileA_db[nxt_a_off + aIm * TK_PAD + aIk + 2u] = select(0.0, xv.z, row_in && (axc + 2u) < K);
            tileA_db[nxt_a_off + aIm * TK_PAD + aIk + 3u] = select(0.0, xv.w, row_in && (axc + 3u) < K);
            let bI0 = tid * 4u;
            let bIk = bI0 / TN;
            let bIn = bI0 % TN;
            let bwr = nxt_kBase + bIk;
            let bwc = block_col + bIn;
            let bwr_in = bwr < K;
            let wv = W_mrw[bwr * N4 + bwc / 4u];
            tileB_db[nxt_b_off + bIk * TN + bIn + 0u] = select(0.0, f32(wv.x), bwr_in && (bwc + 0u) < N);
            tileB_db[nxt_b_off + bIk * TN + bIn + 1u] = select(0.0, f32(wv.y), bwr_in && (bwc + 1u) < N);
            tileB_db[nxt_b_off + bIk * TN + bIn + 2u] = select(0.0, f32(wv.z), bwr_in && (bwc + 2u) < N);
            tileB_db[nxt_b_off + bIk * TN + bIn + 3u] = select(0.0, f32(wv.w), bwr_in && (bwc + 3u) < N);
        }
        for (var k: u32 = 0u; k < TK; k = k + 1u) {
            let a0 = tileA_db[cur_a_off + (4u * ty + 0u) * TK_PAD + k];
            let a1 = tileA_db[cur_a_off + (4u * ty + 1u) * TK_PAD + k];
            let a2 = tileA_db[cur_a_off + (4u * ty + 2u) * TK_PAD + k];
            let a3 = tileA_db[cur_a_off + (4u * ty + 3u) * TK_PAD + k];
            let b0 = tileB_db[cur_b_off + k * TN + (4u * tx + 0u)];
            let b1 = tileB_db[cur_b_off + k * TN + (4u * tx + 1u)];
            let b2 = tileB_db[cur_b_off + k * TN + (4u * tx + 2u)];
            let b3 = tileB_db[cur_b_off + k * TN + (4u * tx + 3u)];
            acc00 = fma(a0, b0, acc00); acc01 = fma(a0, b1, acc01); acc02 = fma(a0, b2, acc02); acc03 = fma(a0, b3, acc03);
            acc10 = fma(a1, b0, acc10); acc11 = fma(a1, b1, acc11); acc12 = fma(a1, b2, acc12); acc13 = fma(a1, b3, acc13);
            acc20 = fma(a2, b0, acc20); acc21 = fma(a2, b1, acc21); acc22 = fma(a2, b2, acc22); acc23 = fma(a2, b3, acc23);
            acc30 = fma(a3, b0, acc30); acc31 = fma(a3, b1, acc31); acc32 = fma(a3, b2, acc32); acc33 = fma(a3, b3, acc33);
        }
        workgroupBarrier();
    }

    // Epilogue: Y = acc + R at the same [M, N] index, bounds-checked.
    let or = block_row + 4u * ty;
    let oc = block_col + 4u * tx;
    if (or + 0u < M && oc + 0u < N) { Y_mrw[(or + 0u) * N + (oc + 0u)] = acc00 + R_mrw[(or + 0u) * N + (oc + 0u)]; }
    if (or + 0u < M && oc + 1u < N) { Y_mrw[(or + 0u) * N + (oc + 1u)] = acc01 + R_mrw[(or + 0u) * N + (oc + 1u)]; }
    if (or + 0u < M && oc + 2u < N) { Y_mrw[(or + 0u) * N + (oc + 2u)] = acc02 + R_mrw[(or + 0u) * N + (oc + 2u)]; }
    if (or + 0u < M && oc + 3u < N) { Y_mrw[(or + 0u) * N + (oc + 3u)] = acc03 + R_mrw[(or + 0u) * N + (oc + 3u)]; }
    if (or + 1u < M && oc + 0u < N) { Y_mrw[(or + 1u) * N + (oc + 0u)] = acc10 + R_mrw[(or + 1u) * N + (oc + 0u)]; }
    if (or + 1u < M && oc + 1u < N) { Y_mrw[(or + 1u) * N + (oc + 1u)] = acc11 + R_mrw[(or + 1u) * N + (oc + 1u)]; }
    if (or + 1u < M && oc + 2u < N) { Y_mrw[(or + 1u) * N + (oc + 2u)] = acc12 + R_mrw[(or + 1u) * N + (oc + 2u)]; }
    if (or + 1u < M && oc + 3u < N) { Y_mrw[(or + 1u) * N + (oc + 3u)] = acc13 + R_mrw[(or + 1u) * N + (oc + 3u)]; }
    if (or + 2u < M && oc + 0u < N) { Y_mrw[(or + 2u) * N + (oc + 0u)] = acc20 + R_mrw[(or + 2u) * N + (oc + 0u)]; }
    if (or + 2u < M && oc + 1u < N) { Y_mrw[(or + 2u) * N + (oc + 1u)] = acc21 + R_mrw[(or + 2u) * N + (oc + 1u)]; }
    if (or + 2u < M && oc + 2u < N) { Y_mrw[(or + 2u) * N + (oc + 2u)] = acc22 + R_mrw[(or + 2u) * N + (oc + 2u)]; }
    if (or + 2u < M && oc + 3u < N) { Y_mrw[(or + 2u) * N + (oc + 3u)] = acc23 + R_mrw[(or + 2u) * N + (oc + 3u)]; }
    if (or + 3u < M && oc + 0u < N) { Y_mrw[(or + 3u) * N + (oc + 0u)] = acc30 + R_mrw[(or + 3u) * N + (oc + 0u)]; }
    if (or + 3u < M && oc + 1u < N) { Y_mrw[(or + 3u) * N + (oc + 1u)] = acc31 + R_mrw[(or + 3u) * N + (oc + 1u)]; }
    if (or + 3u < M && oc + 2u < N) { Y_mrw[(or + 3u) * N + (oc + 2u)] = acc32 + R_mrw[(or + 3u) * N + (oc + 2u)]; }
    if (or + 3u < M && oc + 3u < N) { Y_mrw[(or + 3u) * N + (oc + 3u)] = acc33 + R_mrw[(or + 3u) * N + (oc + 3u)]; }
}

// --- KERNEL: matmul_t_w16 ---
// Y[M, N] = X[M, K] @ W^T where W stored as [N, K] in f16.
// vec4 tile loads — host MUST satisfy K % 4 == 0 (X and W row strides = K).
// N needs no alignment here: W is indexed by whole rows of N, and Y is
// stored one scalar at a time.
@group(0) @binding(0) var X_mtw: array>;
@group(0) @binding(1) var W_mtw: array>;
@group(0) @binding(2) var Y_mtw: array;
@group(0) @binding(3) var dims_mtw: vec4; // (M, N, K, _)

@compute @workgroup_size(16, 16, 1)
fn matmul_t_w16(@builtin(workgroup_id) wgid: vec3, @builtin(local_invocation_id) lid: vec3) {
    let M = dims_mtw.x;
    let N = dims_mtw.y;
    let K = dims_mtw.z;
    // Both X and W have row stride K, so only K4 is needed.
    let K4 = K / 4u;
    let tx = lid.x;
    let ty = lid.y;
    let tid = ty * 16u + tx; // flat thread index, 0..255
    let block_row = wgid.y * TM;
    let block_col = wgid.x * TN;

    // 4×4 register accumulator: accIJ is output (block_row + 4*ty + I, block_col + 4*tx + J).
    var acc00: f32 = 0.0; var acc01: f32 = 0.0; var acc02: f32 = 0.0; var acc03: f32 = 0.0;
    var acc10: f32 = 0.0; var acc11: f32 = 0.0; var acc12: f32 = 0.0; var acc13: f32 = 0.0;
    var acc20: f32 = 0.0; var acc21: f32 = 0.0; var acc22: f32 = 0.0; var acc23: f32 = 0.0;
    var acc30: f32 = 0.0; var acc31: f32 = 0.0; var acc32: f32 = 0.0; var acc33: f32 = 0.0;
    let nTiles = (K + TK - 1u) / TK;

    // Prologue: K-tile 0 → buffer half 0.
    {
        // A tile: same row-major, TK_PAD-stride layout as matmul_w16.
        let aI0 = tid * 4u;
        let aIm = aI0 / TK;
        let aIk = aI0 % TK;
        let axr = block_row + aIm;
        let axc = aIk;
        let row_in = axr < M;
        let xv = X_mtw[axr * K4 + axc / 4u];
        tileA_db[aIm * TK_PAD + aIk + 0u] = select(0.0, xv.x, row_in && (axc + 0u) < K);
        tileA_db[aIm * TK_PAD + aIk + 1u] = select(0.0, xv.y, row_in && (axc + 1u) < K);
        tileA_db[aIm * TK_PAD + aIk + 2u] = select(0.0, xv.z, row_in && (axc + 2u) < K);
        tileA_db[aIm * TK_PAD + aIk + 3u] = select(0.0, xv.w, row_in && (axc + 3u) < K);

        // B tile (W^T slice): each thread does one f16 vec4 load of 4
        // consecutive K values from W row nG, widens them to f32 and scatters
        // them down column bIn, so tileB_db[k_local * TN + n_local] = W[n, k].
        // 4 consecutive threads cover one W row's TK-wide K slice, which keeps
        // the global reads contiguous even though the tile is transposed.
        let bI0 = tid * 4u;
        let bIn = bI0 / TK;
        let bIk = bI0 % TK;
        let nG = block_col + bIn;
        let kG = bIk;
        let n_in = nG < N;
        let wv = W_mtw[nG * K4 + kG / 4u];
        tileB_db[(bIk + 0u) * TN + bIn] = select(0.0, f32(wv.x), n_in && (kG + 0u) < K);
        tileB_db[(bIk + 1u) * TN + bIn] = select(0.0, f32(wv.y), n_in && (kG + 1u) < K);
        tileB_db[(bIk + 2u) * TN + bIn] = select(0.0, f32(wv.z), n_in && (kG + 2u) < K);
        tileB_db[(bIk + 3u) * TN + bIn] = select(0.0, f32(wv.w), n_in && (kG + 3u) < K);
    }
    workgroupBarrier();

    // Double-buffered main loop; identical scheme to the other kernels.
    for (var t: u32 = 0u; t < nTiles; t = t + 1u) {
        let parity = t & 1u;
        let cur_a_off = parity * TA_DB_HALF;
        let cur_b_off = parity * TB_DB_HALF;
        if (t + 1u < nTiles) {
            let nxt_a_off = (1u - parity) * TA_DB_HALF;
            let nxt_b_off = (1u - parity) * TB_DB_HALF;
            let nxt_kBase = (t + 1u) * TK;
            let aI0 = tid * 4u;
            let aIm = aI0 / TK;
            let aIk = aI0 % TK;
            let axr = block_row + aIm;
            let axc = nxt_kBase + aIk;
            let row_in = axr < M;
            let xv = X_mtw[axr * K4 + axc / 4u];
            tileA_db[nxt_a_off + aIm * TK_PAD + aIk + 0u] = select(0.0, xv.x, row_in && (axc + 0u) < K);
            tileA_db[nxt_a_off + aIm * TK_PAD + aIk + 1u] = select(0.0, xv.y, row_in && (axc + 1u) < K);
            tileA_db[nxt_a_off + aIm * TK_PAD + aIk + 2u] = select(0.0, xv.z, row_in && (axc + 2u) < K);
            tileA_db[nxt_a_off + aIm * TK_PAD + aIk + 3u] = select(0.0, xv.w, row_in && (axc + 3u) < K);
            let bI0 = tid * 4u;
            let bIn = bI0 / TK;
            let bIk = bI0 % TK;
            let nG = block_col + bIn;
            let kG = nxt_kBase + bIk;
            let n_in = nG < N;
            let wv = W_mtw[nG * K4 + kG / 4u];
            tileB_db[nxt_b_off + (bIk + 0u) * TN + bIn] = select(0.0, f32(wv.x), n_in && (kG + 0u) < K);
            tileB_db[nxt_b_off + (bIk + 1u) * TN + bIn] = select(0.0, f32(wv.y), n_in && (kG + 1u) < K);
            tileB_db[nxt_b_off + (bIk + 2u) * TN + bIn] = select(0.0, f32(wv.z), n_in && (kG + 2u) < K);
            tileB_db[nxt_b_off + (bIk + 3u) * TN + bIn] = select(0.0, f32(wv.w), n_in && (kG + 3u) < K);
        }
        // Because tileB already holds W^T, the inner loop is the plain matmul one.
        for (var k: u32 = 0u; k < TK; k = k + 1u) {
            let a0 = tileA_db[cur_a_off + (4u * ty + 0u) * TK_PAD + k];
            let a1 = tileA_db[cur_a_off + (4u * ty + 1u) * TK_PAD + k];
            let a2 = tileA_db[cur_a_off + (4u * ty + 2u) * TK_PAD + k];
            let a3 = tileA_db[cur_a_off + (4u * ty + 3u) * TK_PAD + k];
            let b0 = tileB_db[cur_b_off + k * TN + (4u * tx + 0u)];
            let b1 = tileB_db[cur_b_off + k * TN + (4u * tx + 1u)];
            let b2 = tileB_db[cur_b_off + k * TN + (4u * tx + 2u)];
            let b3 = tileB_db[cur_b_off + k * TN + (4u * tx + 3u)];
            acc00 = fma(a0, b0, acc00); acc01 = fma(a0, b1, acc01); acc02 = fma(a0, b2, acc02); acc03 = fma(a0, b3, acc03);
            acc10 = fma(a1, b0, acc10); acc11 = fma(a1, b1, acc11); acc12 = fma(a1, b2, acc12); acc13 = fma(a1, b3, acc13);
            acc20 = fma(a2, b0, acc20); acc21 = fma(a2, b1, acc21); acc22 = fma(a2, b2, acc22); acc23 = fma(a2, b3, acc23);
            acc30 = fma(a3, b0, acc30); acc31 = fma(a3, b1, acc31); acc32 = fma(a3, b2, acc32); acc33 = fma(a3, b3, acc33);
        }
        workgroupBarrier();
    }

    // Epilogue: bounds-checked store into Y [M, N] (overwrites).
    let or = block_row + 4u * ty;
    let oc = block_col + 4u * tx;
    if (or + 0u < M && oc + 0u < N) { Y_mtw[(or + 0u) * N + (oc + 0u)] = acc00; }
    if (or + 0u < M && oc + 1u < N) { Y_mtw[(or + 0u) * N + (oc + 1u)] = acc01; }
    if (or + 0u < M && oc + 2u < N) { Y_mtw[(or + 0u) * N + (oc + 2u)] = acc02; }
    if (or + 0u < M && oc + 3u < N) { Y_mtw[(or + 0u) * N + (oc + 3u)] = acc03; }
    if (or + 1u < M && oc + 0u < N) { Y_mtw[(or + 1u) * N + (oc + 0u)] = acc10; }
    if (or + 1u < M && oc + 1u < N) { Y_mtw[(or + 1u) * N + (oc + 1u)] = acc11; }
    if (or + 1u < M && oc + 2u < N) { Y_mtw[(or + 1u) * N + (oc + 2u)] = acc12; }
    if (or + 1u < M && oc + 3u < N) { Y_mtw[(or + 1u) * N + (oc + 3u)] = acc13; }
    if (or + 2u < M && oc + 0u < N) { Y_mtw[(or + 2u) * N + (oc + 0u)] = acc20; }
    if (or + 2u < M && oc + 1u < N) { Y_mtw[(or + 2u) * N + (oc + 1u)] = acc21; }
    if (or + 2u < M && oc + 2u < N) { Y_mtw[(or + 2u) * N + (oc + 2u)] = acc22; }
    if (or + 2u < M && oc + 3u < N) { Y_mtw[(or + 2u) * N + (oc + 3u)] = acc23; }
    if (or + 3u < M && oc + 0u < N) { Y_mtw[(or + 3u) * N + (oc + 0u)] = acc30; }
    if (or + 3u < M && oc + 1u < N) { Y_mtw[(or + 3u) * N + (oc + 1u)] = acc31; }
    if (or + 3u < M && oc + 2u < N) { Y_mtw[(or + 3u) * N + (oc + 2u)] = acc32; }
    if (or + 3u < M && oc + 3u < N) { Y_mtw[(or + 3u) * N + (oc + 3u)] = acc33; }
}

// --- KERNEL: matmul_residual_swiglu_a ---
// Y[M, N] = (silu(GATE) * UP)[M, K] @ W[K, N] + R[M, N]
//
// Fused FFN down-projection: collapses the 2-step
//   hidden = silu(gate) * up     (swiglu_combine,  dispatch 1)
//   Y      = hidden @ W_down + R (matmul_residual, dispatch 2)
// into a single dispatch. The A-tile load reads gate/up at the same index
// and computes silu*up on the fly — the S×F `hidden` tensor is never
// materialized. Saves 1 dispatch plus the round trip of `hidden` through
// memory (S*F elements written + S*F elements read) per layer per forward
// pass. Trade-off: silu(gate) * up is recomputed by every workgroup column
// (ceil(N / TN) times per element), and each A-tile load reads two tensors
// (GATE and UP) instead of one. Backward analogue lives in matmul_at_swiglu_a.
//
// FFN dims: M=S, K=F (=d_ff), N=D. host MUST satisfy K%4 == 0 AND N%4 == 0.
@group(0) @binding(0) var GATE_mrs: array>;
@group(0) @binding(1) var UP_mrs: array>;
@group(0) @binding(2) var W_mrs: array>;
@group(0) @binding(3) var R_mrs: array;
@group(0) @binding(4) var Y_mrs: array;
@group(0) @binding(5) var dims_mrs: vec4; // (M, N, K, _)

@compute @workgroup_size(16, 16, 1)
fn matmul_residual_swiglu_a(@builtin(workgroup_id) wgid: vec3, @builtin(local_invocation_id) lid: vec3) {
    let M = dims_mrs.x;
    let N = dims_mrs.y;
    let K = dims_mrs.z;
    // Row strides in vec4 units; exact only because K % 4 == 0 and N % 4 == 0.
    let K4 = K / 4u;
    let N4 = N / 4u;
    let tx = lid.x;
    let ty = lid.y;
    let tid = ty * 16u + tx; // flat thread index, 0..255
    let block_row = wgid.y * TM;
    let block_col = wgid.x * TN;

    // 4×4 register accumulator: accIJ is output (block_row + 4*ty + I, block_col + 4*tx + J).
    var acc00: f32 = 0.0; var acc01: f32 = 0.0; var acc02: f32 = 0.0; var acc03: f32 = 0.0;
    var acc10: f32 = 0.0; var acc11: f32 = 0.0; var acc12: f32 = 0.0; var acc13: f32 = 0.0;
    var acc20: f32 = 0.0; var acc21: f32 = 0.0; var acc22: f32 = 0.0; var acc23: f32 = 0.0;
    var acc30: f32 = 0.0; var acc31: f32 = 0.0; var acc32: f32 = 0.0; var acc33: f32 = 0.0;
    let nTiles = (K + TK - 1u) / TK;

    // Prologue: tile 0 → buffer 0. A-tile element = silu(gate) * up (fused).
    // silu() is evaluated on every loaded lane, including out-of-range ones;
    // select() then discards those results, so they never reach the tile.
    {
        let aI0 = tid * 4u;
        let aIm = aI0 / TK;
        let aIk = aI0 % TK;
        let axr = block_row + aIm;
        let axc = aIk;
        let row_in = axr < M;
        let gv = GATE_mrs[axr * K4 + axc / 4u];
        let uv = UP_mrs [axr * K4 + axc / 4u];
        tileA_db[aIm * TK_PAD + aIk + 0u] = select(0.0, silu(gv.x) * uv.x, row_in && (axc + 0u) < K);
        tileA_db[aIm * TK_PAD + aIk + 1u] = select(0.0, silu(gv.y) * uv.y, row_in && (axc + 1u) < K);
        tileA_db[aIm * TK_PAD + aIk + 2u] = select(0.0, silu(gv.z) * uv.z, row_in && (axc + 2u) < K);
        tileA_db[aIm * TK_PAD + aIk + 3u] = select(0.0, silu(gv.w) * uv.w, row_in && (axc + 3u) < K);
        // B tile: W [K, N] row-major, stride TN (fp32 W, no conversion).
        let bI0 = tid * 4u;
        let bIk = bI0 / TN;
        let bIn = bI0 % TN;
        let bwr = bIk;
        let bwc = block_col + bIn;
        let bwr_in = bwr < K;
        let wv = W_mrs[bwr * N4 + bwc / 4u];
        tileB_db[bIk * TN + bIn + 0u] = select(0.0, wv.x, bwr_in && (bwc + 0u) < N);
        tileB_db[bIk * TN + bIn + 1u] = select(0.0, wv.y, bwr_in && (bwc + 1u) < N);
        tileB_db[bIk * TN + bIn + 2u] = select(0.0, wv.z, bwr_in && (bwc + 2u) < N);
        tileB_db[bIk * TN + bIn + 3u] = select(0.0, wv.w, bwr_in && (bwc + 3u) < N);
    }
    workgroupBarrier();

    // Double-buffered main loop: compute on half (t & 1), prefetch (with the
    // same silu fusion) tile t + 1 into the other half; the end-of-iteration
    // barrier publishes the prefetch and retires reads of half (t & 1).
    for (var t: u32 = 0u; t < nTiles; t = t + 1u) {
        let parity = t & 1u;
        let cur_a_off = parity * TA_DB_HALF;
        let cur_b_off = parity * TB_DB_HALF;
        if (t + 1u < nTiles) {
            let nxt_a_off = (1u - parity) * TA_DB_HALF;
            let nxt_b_off = (1u - parity) * TB_DB_HALF;
            let nxt_kBase = (t + 1u) * TK;
            let aI0 = tid * 4u;
            let aIm = aI0 / TK;
            let aIk = aI0 % TK;
            let axr = block_row + aIm;
            let axc = nxt_kBase + aIk;
            let row_in = axr < M;
            let gv = GATE_mrs[axr * K4 + axc / 4u];
            let uv = UP_mrs [axr * K4 + axc / 4u];
            tileA_db[nxt_a_off + aIm * TK_PAD + aIk + 0u] = select(0.0, silu(gv.x) * uv.x, row_in && (axc + 0u) < K);
            tileA_db[nxt_a_off + aIm * TK_PAD + aIk + 1u] = select(0.0, silu(gv.y) * uv.y, row_in && (axc + 1u) < K);
            tileA_db[nxt_a_off + aIm * TK_PAD + aIk + 2u] = select(0.0, silu(gv.z) * uv.z, row_in && (axc + 2u) < K);
            tileA_db[nxt_a_off + aIm * TK_PAD + aIk + 3u] = select(0.0, silu(gv.w) * uv.w, row_in && (axc + 3u) < K);
            let bI0 = tid * 4u;
            let bIk = bI0 / TN;
            let bIn = bI0 % TN;
            let bwr = nxt_kBase + bIk;
            let bwc = block_col + bIn;
            let bwr_in = bwr < K;
            let wv = W_mrs[bwr * N4 + bwc / 4u];
            tileB_db[nxt_b_off + bIk * TN + bIn + 0u] = select(0.0, wv.x, bwr_in && (bwc + 0u) < N);
            tileB_db[nxt_b_off + bIk * TN + bIn + 1u] = select(0.0, wv.y, bwr_in && (bwc + 1u) < N);
            tileB_db[nxt_b_off + bIk * TN + bIn + 2u] = select(0.0, wv.z, bwr_in && (bwc + 2u) < N);
            tileB_db[nxt_b_off + bIk * TN + bIn + 3u] = select(0.0, wv.w, bwr_in && (bwc + 3u) < N);
        }
        for (var k: u32 = 0u; k < TK; k = k + 1u) {
            let a0 = tileA_db[cur_a_off + (4u * ty + 0u) * TK_PAD + k];
            let a1 = tileA_db[cur_a_off + (4u * ty + 1u) * TK_PAD + k];
            let a2 = tileA_db[cur_a_off + (4u * ty + 2u) * TK_PAD + k];
            let a3 = tileA_db[cur_a_off + (4u * ty + 3u) * TK_PAD + k];
            let b0 = tileB_db[cur_b_off + k * TN + (4u * tx + 0u)];
            let b1 = tileB_db[cur_b_off + k * TN + (4u * tx + 1u)];
            let b2 = tileB_db[cur_b_off + k * TN + (4u * tx + 2u)];
            let b3 = tileB_db[cur_b_off + k * TN + (4u * tx + 3u)];
            acc00 = fma(a0, b0, acc00); acc01 = fma(a0, b1, acc01); acc02 = fma(a0, b2, acc02); acc03 = fma(a0, b3, acc03);
            acc10 = fma(a1, b0, acc10); acc11 = fma(a1, b1, acc11); acc12 = fma(a1, b2, acc12); acc13 = fma(a1, b3, acc13);
            acc20 = fma(a2, b0, acc20); acc21 = fma(a2, b1, acc21); acc22 = fma(a2, b2, acc22); acc23 = fma(a2, b3, acc23);
            acc30 = fma(a3, b0, acc30); acc31 = fma(a3, b1, acc31); acc32 = fma(a3, b2, acc32); acc33 = fma(a3, b3, acc33);
        }
        workgroupBarrier();
    }

    // Epilogue: Y = acc + R at the same [M, N] index, bounds-checked.
    let or = block_row + 4u * ty;
    let oc = block_col + 4u * tx;
    if (or + 0u < M && oc + 0u < N) { Y_mrs[(or + 0u) * N + (oc + 0u)] = acc00 + R_mrs[(or + 0u) * N + (oc + 0u)]; }
    if (or + 0u < M && oc + 1u < N) { Y_mrs[(or + 0u) * N + (oc + 1u)] = acc01 + R_mrs[(or + 0u) * N + (oc + 1u)]; }
    if (or + 0u < M && oc + 2u < N) { Y_mrs[(or + 0u) * N + (oc + 2u)] = acc02 + R_mrs[(or + 0u) * N + (oc + 2u)]; }
    if (or + 0u < M && oc + 3u < N) { Y_mrs[(or + 0u) * N + (oc + 3u)] = acc03 + R_mrs[(or + 0u) * N + (oc + 3u)]; }
    if (or + 1u < M && oc + 0u < N) { Y_mrs[(or + 1u) * N + (oc + 0u)] = acc10 + R_mrs[(or + 1u) * N + (oc + 0u)]; }
    if (or + 1u < M && oc + 1u < N) { Y_mrs[(or + 1u) * N + (oc + 1u)] = acc11 + R_mrs[(or + 1u) * N + (oc + 1u)]; }
    if (or + 1u < M && oc + 2u < N) { Y_mrs[(or + 1u) * N + (oc + 2u)] = acc12 + R_mrs[(or + 1u) * N + (oc + 2u)]; }
    if (or + 1u < M && oc + 3u < N) { Y_mrs[(or + 1u) * N + (oc + 3u)] = acc13 + R_mrs[(or + 1u) * N + (oc + 3u)]; }
    if (or + 2u < M && oc + 0u < N) { Y_mrs[(or + 2u) * N + (oc + 0u)] = acc20 + R_mrs[(or + 2u) * N + (oc + 0u)]; }
    if (or + 2u < M && oc + 1u < N) { Y_mrs[(or + 2u) * N + (oc + 1u)] = acc21 + R_mrs[(or + 2u) * N + (oc + 1u)]; }
    if (or + 2u < M && oc + 2u < N) { Y_mrs[(or + 2u) * N + (oc + 2u)] = acc22 + R_mrs[(or + 2u) * N + (oc + 2u)]; }
    if (or + 2u < M && oc + 3u < N) { Y_mrs[(or + 2u) * N + (oc + 3u)] = acc23 + R_mrs[(or + 2u) * N + (oc + 3u)]; }
    if (or + 3u < M && oc + 0u < N) { Y_mrs[(or + 3u) * N + (oc + 0u)] = acc30 + R_mrs[(or + 3u) * N + (oc + 0u)]; }
    if (or + 3u < M && oc + 1u < N) { Y_mrs[(or + 3u) * N + (oc + 1u)] = acc31 + R_mrs[(or + 3u) * N + (oc + 1u)]; }
    if (or + 3u < M && oc + 2u < N) { Y_mrs[(or + 3u) * N + (oc + 2u)] = acc32 + R_mrs[(or + 3u) * N + (oc + 2u)]; }
    if (or + 3u < M && oc + 3u < N) { Y_mrs[(or + 3u) * N + (oc + 3u)] = acc33 + R_mrs[(or + 3u) * N + (oc + 3u)]; }
}

// --- KERNEL: matmul_residual_swiglu_a_w16 ---
// As matmul_residual_swiglu_a but W is f16 (mixed-precision forward).
// Fused forward with f16 W: Y[M, N] = (silu(GATE) * UP) @ W + R.
// The [M, K] activation silu(GATE) * UP is never written to memory: each
// element is rebuilt inside the A-tile load. W lanes are widened with f32()
// in the B-tile load, so workgroup memory and the inner loop stay f32.
// Host MUST satisfy K % 4 == 0 AND N % 4 == 0: K4 / N4 below are used as row
// strides (in vec4 units) for GATE/UP and W. M needs no alignment; it is only
// bounds-checked.
@group(0) @binding(0) var GATE_mrsw: array>; // gate, [M, K], read 4 k-lanes at a time
@group(0) @binding(1) var UP_mrsw: array>; // up, [M, K], read 4 k-lanes at a time
@group(0) @binding(2) var W_mrsw: array>; // W, [K, N], read 4 n-lanes at a time (f16 lanes)
@group(0) @binding(3) var R_mrsw: array; // residual, [M, N] row-major
@group(0) @binding(4) var Y_mrsw: array; // output, [M, N] row-major
@group(0) @binding(5) var dims_mrsw: vec4; // (M, N, K, _)
@compute @workgroup_size(16, 16, 1)
fn matmul_residual_swiglu_a_w16(@builtin(workgroup_id) wgid: vec3, @builtin(local_invocation_id) lid: vec3) {
    let M = dims_mrsw.x;
    let N = dims_mrsw.y;
    let K = dims_mrsw.z;
    // Row strides in vec4 units (exact only because K and N are multiples of 4).
    let K4 = K / 4u;
    let N4 = N / 4u;
    // Flat invocation index in [0, 256).
    let tx = lid.x;
    let ty = lid.y;
    let tid = ty * 16u + tx;
    // Top-left corner of this workgroup's TM x TN (64 x 64) output tile.
    let block_row = wgid.y * TM;
    let block_col = wgid.x * TN;
    // 4x4 register accumulator: accIJ is output
    // (block_row + 4*ty + I, block_col + 4*tx + J).
    var acc00: f32 = 0.0; var acc01: f32 = 0.0; var acc02: f32 = 0.0; var acc03: f32 = 0.0;
    var acc10: f32 = 0.0; var acc11: f32 = 0.0; var acc12: f32 = 0.0; var acc13: f32 = 0.0;
    var acc20: f32 = 0.0; var acc21: f32 = 0.0; var acc22: f32 = 0.0; var acc23: f32 = 0.0;
    var acc30: f32 = 0.0; var acc31: f32 = 0.0; var acc32: f32 = 0.0; var acc33: f32 = 0.0;
    // Number of TK-wide K tiles; a trailing partial tile is zero-padded by the
    // masked loads.
    let nTiles = (K + TK - 1u) / TK;
    // Prologue: load K-tile 0 into half 0 of tileA_db / tileB_db.
    {
        // A tile (TM rows x TK k-columns = 1024 elements): each of the 256
        // invocations loads 4 consecutive k values of one row.
        let aI0 = tid * 4u;
        let aIm = aI0 / TK;
        let aIk = aI0 % TK;
        let axr = block_row + aIm;
        let axc = aIk;
        let row_in = axr < M;
        // The vec4 loads are unconditional; lanes outside [0, M) x [0, K) are
        // replaced by 0.0 via select(), so padding contributes nothing.
        // NOTE(review): relies on WGSL's defined out-of-bounds behaviour for
        // storage reads (no fault); the loaded value is discarded anyway.
        let gv = GATE_mrsw[axr * K4 + axc / 4u];
        let uv = UP_mrsw [axr * K4 + axc / 4u];
        // silu(gate) * up is recomputed here by every workgroup that shares
        // this wgid.y (one per 64-column N tile); that is the price of not
        // storing hidden.
        tileA_db[aIm * TK_PAD + aIk + 0u] = select(0.0, silu(gv.x) * uv.x, row_in && (axc + 0u) < K);
        tileA_db[aIm * TK_PAD + aIk + 1u] = select(0.0, silu(gv.y) * uv.y, row_in && (axc + 1u) < K);
        tileA_db[aIm * TK_PAD + aIk + 2u] = select(0.0, silu(gv.z) * uv.z, row_in && (axc + 2u) < K);
        tileA_db[aIm * TK_PAD + aIk + 3u] = select(0.0, silu(gv.w) * uv.w, row_in && (axc + 3u) < K);
        // B tile (TK k-rows x TN columns = 1024 elements): 4 consecutive
        // columns of one W row per invocation, widened to f32.
        let bI0 = tid * 4u;
        let bIk = bI0 / TN;
        let bIn = bI0 % TN;
        let bwr = bIk;
        let bwc = block_col + bIn;
        let bwr_in = bwr < K;
        let wv = W_mrsw[bwr * N4 + bwc / 4u];
        tileB_db[bIk * TN + bIn + 0u] = select(0.0, f32(wv.x), bwr_in && (bwc + 0u) < N);
        tileB_db[bIk * TN + bIn + 1u] = select(0.0, f32(wv.y), bwr_in && (bwc + 1u) < N);
        tileB_db[bIk * TN + bIn + 2u] = select(0.0, f32(wv.z), bwr_in && (bwc + 2u) < N);
        tileB_db[bIk * TN + bIn + 3u] = select(0.0, f32(wv.w), bwr_in && (bwc + 3u) < N);
    }
    // The prologue tiles must be visible to all invocations before the first
    // compute step.
    workgroupBarrier();
    for (var t: u32 = 0u; t < nTiles; t = t + 1u) {
        // Even t computes from half 0, odd t from half 1.
        let parity = t & 1u;
        let cur_a_off = parity * TA_DB_HALF;
        let cur_b_off = parity * TB_DB_HALF;
        // Prefetch K-tile t+1 into the other half. One barrier per iteration
        // (at the bottom) is enough: the half written here was last read in
        // iteration t-1, which ended with a barrier, and the half read below
        // was filled before that same barrier (by iteration t-1's prefetch,
        // or by the prologue when t == 0).
        if (t + 1u < nTiles) {
            let nxt_a_off = (1u - parity) * TA_DB_HALF;
            let nxt_b_off = (1u - parity) * TB_DB_HALF;
            let nxt_kBase = (t + 1u) * TK;
            // Same A-tile mapping and masking as the prologue, shifted to nxt_kBase.
            let aI0 = tid * 4u;
            let aIm = aI0 / TK;
            let aIk = aI0 % TK;
            let axr = block_row + aIm;
            let axc = nxt_kBase + aIk;
            let row_in = axr < M;
            let gv = GATE_mrsw[axr * K4 + axc / 4u];
            let uv = UP_mrsw [axr * K4 + axc / 4u];
            tileA_db[nxt_a_off + aIm * TK_PAD + aIk + 0u] = select(0.0, silu(gv.x) * uv.x, row_in && (axc + 0u) < K);
            tileA_db[nxt_a_off + aIm * TK_PAD + aIk + 1u] = select(0.0, silu(gv.y) * uv.y, row_in && (axc + 1u) < K);
            tileA_db[nxt_a_off + aIm * TK_PAD + aIk + 2u] = select(0.0, silu(gv.z) * uv.z, row_in && (axc + 2u) < K);
            tileA_db[nxt_a_off + aIm * TK_PAD + aIk + 3u] = select(0.0, silu(gv.w) * uv.w, row_in && (axc + 3u) < K);
            // Same B-tile mapping and masking as the prologue, shifted to nxt_kBase.
            let bI0 = tid * 4u;
            let bIk = bI0 / TN;
            let bIn = bI0 % TN;
            let bwr = nxt_kBase + bIk;
            let bwc = block_col + bIn;
            let bwr_in = bwr < K;
            let wv = W_mrsw[bwr * N4 + bwc / 4u];
            tileB_db[nxt_b_off + bIk * TN + bIn + 0u] = select(0.0, f32(wv.x), bwr_in && (bwc + 0u) < N);
            tileB_db[nxt_b_off + bIk * TN + bIn + 1u] = select(0.0, f32(wv.y), bwr_in && (bwc + 1u) < N);
            tileB_db[nxt_b_off + bIk * TN + bIn + 2u] = select(0.0, f32(wv.z), bwr_in && (bwc + 2u) < N);
            tileB_db[nxt_b_off + bIk * TN + bIn + 3u] = select(0.0, f32(wv.w), bwr_in && (bwc + 3u) < N);
        }
        // TK rank-1 updates of the 4x4 accumulator from the current half.
        // a0..a3: A-tile rows 4*ty + 0..3; b0..b3: B-tile columns 4*tx + 0..3.
        for (var k: u32 = 0u; k < TK; k = k + 1u) {
            let a0 = tileA_db[cur_a_off + (4u * ty + 0u) * TK_PAD + k];
            let a1 = tileA_db[cur_a_off + (4u * ty + 1u) * TK_PAD + k];
            let a2 = tileA_db[cur_a_off + (4u * ty + 2u) * TK_PAD + k];
            let a3 = tileA_db[cur_a_off + (4u * ty + 3u) * TK_PAD + k];
            let b0 = tileB_db[cur_b_off + k * TN + (4u * tx + 0u)];
            let b1 = tileB_db[cur_b_off + k * TN + (4u * tx + 1u)];
            let b2 = tileB_db[cur_b_off + k * TN + (4u * tx + 2u)];
            let b3 = tileB_db[cur_b_off + k * TN + (4u * tx + 3u)];
            acc00 = fma(a0, b0, acc00); acc01 = fma(a0, b1, acc01); acc02 = fma(a0, b2, acc02); acc03 = fma(a0, b3, acc03);
            acc10 = fma(a1, b0, acc10); acc11 = fma(a1, b1, acc11); acc12 = fma(a1, b2, acc12); acc13 = fma(a1, b3, acc13);
            acc20 = fma(a2, b0, acc20); acc21 = fma(a2, b1, acc21); acc22 = fma(a2, b2, acc22); acc23 = fma(a2, b3, acc23);
            acc30 = fma(a3, b0, acc30); acc31 = fma(a3, b1, acc31); acc32 = fma(a3, b2, acc32); acc33 = fma(a3, b3, acc33);
        }
        // Finish all reads of the current half before the next iteration's
        // prefetch overwrites it, and publish this iteration's prefetch.
        workgroupBarrier();
    }
    // Epilogue: add the residual and store each in-bounds output element.
    let or = block_row + 4u * ty;
    let oc = block_col + 4u * tx;
    if (or + 0u < M && oc + 0u < N) { Y_mrsw[(or + 0u) * N + (oc + 0u)] = acc00 + R_mrsw[(or + 0u) * N + (oc + 0u)]; }
    if (or + 0u < M && oc + 1u < N) { Y_mrsw[(or + 0u) * N + (oc + 1u)] = acc01 + R_mrsw[(or + 0u) * N + (oc + 1u)]; }
    if (or + 0u < M && oc + 2u < N) { Y_mrsw[(or + 0u) * N + (oc + 2u)] = acc02 + R_mrsw[(or + 0u) * N + (oc + 2u)]; }
    if (or + 0u < M && oc + 3u < N) { Y_mrsw[(or + 0u) * N + (oc + 3u)] = acc03 + R_mrsw[(or + 0u) * N + (oc + 3u)]; }
    if (or + 1u < M && oc + 0u < N) { Y_mrsw[(or + 1u) * N + (oc + 0u)] = acc10 + R_mrsw[(or + 1u) * N + (oc + 0u)]; }
    if (or + 1u < M && oc + 1u < N) { Y_mrsw[(or + 1u) * N + (oc + 1u)] = acc11 + R_mrsw[(or + 1u) * N + (oc + 1u)]; }
    if (or + 1u < M && oc + 2u < N) { Y_mrsw[(or + 1u) * N + (oc + 2u)] = acc12 + R_mrsw[(or + 1u) * N + (oc + 2u)]; }
    if (or + 1u < M && oc + 3u < N) { Y_mrsw[(or + 1u) * N + (oc + 3u)] = acc13 + R_mrsw[(or + 1u) * N + (oc + 3u)]; }
    if (or + 2u < M && oc + 0u < N) { Y_mrsw[(or + 2u) * N + (oc + 0u)] = acc20 + R_mrsw[(or + 2u) * N + (oc + 0u)]; }
    if (or + 2u < M && oc + 1u < N) { Y_mrsw[(or + 2u) * N + (oc + 1u)] = acc21 + R_mrsw[(or + 2u) * N + (oc + 1u)]; }
    if (or + 2u < M && oc + 2u < N) { Y_mrsw[(or + 2u) * N + (oc + 2u)] = acc22 + R_mrsw[(or + 2u) * N + (oc + 2u)]; }
    if (or + 2u < M && oc + 3u < N) { Y_mrsw[(or + 2u) * N + (oc + 3u)] = acc23 + R_mrsw[(or + 2u) * N + (oc + 3u)]; }
    if (or + 3u < M && oc + 0u < N) { Y_mrsw[(or + 3u) * N + (oc + 0u)] = acc30 + R_mrsw[(or + 3u) * N + (oc + 0u)]; }
    if (or + 3u < M && oc + 1u < N) { Y_mrsw[(or + 3u) * N + (oc + 1u)] = acc31 + R_mrsw[(or + 3u) * N + (oc + 1u)]; }
    if (or + 3u < M && oc + 2u < N) { Y_mrsw[(or + 3u) * N + (oc + 2u)] = acc32 + R_mrsw[(or + 3u) * N + (oc + 2u)]; }
    if (or + 3u < M && oc + 3u < N) { Y_mrsw[(or + 3u) * N + (oc + 3u)] = acc33 + R_mrsw[(or + 3u) * N + (oc + 3u)]; }
}

// --- KERNEL: matmul_at_swiglu_a ---
// Y[K, N] = (silu(GATE) * UP)^T @ B where the "X" of matmul_at is the
// would-be hidden = silu(gate)*up, [M, K]. B is [M, N]. dims: (K, N, M, _).
//
// Backward analogue of matmul_residual_swiglu_a: computes dW_down without
// ever materializing hidden. Replaces the matmul_at over L.hidden in the
// FFN backward — combined with the forward fusion, the [S, F] hidden tensor
// is eliminated end-to-end (no separate recompute pass or buffer: hidden is
// rebuilt on the fly from the saved gate/up inside each A-tile load).
//
// FFN-bwd dims: K=F, N=D, M=S. host MUST satisfy K%4 == 0 AND N%4 == 0.
@group(0) @binding(0) var GATE_mas: array>; // gate, [M, K], read 4 k-lanes at a time
@group(0) @binding(1) var UP_mas: array>; // up, [M, K], read 4 k-lanes at a time
@group(0) @binding(2) var B_mas: array>; // B, [M, N], read 4 n-lanes at a time
@group(0) @binding(3) var Y_mas: array; // output, [K, N] row-major (overwritten)
@group(0) @binding(4) var dims_mas: vec4; // (K, N, M, _)
@compute @workgroup_size(16, 16, 1)
fn matmul_at_swiglu_a(@builtin(workgroup_id) wgid: vec3, @builtin(local_invocation_id) lid: vec3) {
    let K = dims_mas.x;
    let N = dims_mas.y;
    let M = dims_mas.z;
    // Row strides (vec4 units) of GATE/UP and of B. M needs no alignment;
    // it is only used in bounds checks and the tile count.
    let K4 = K / 4u;
    let N4 = N / 4u;
    // Flat invocation index in [0, 256).
    let tx = lid.x;
    let ty = lid.y;
    let tid = ty * 16u + tx;
    let block_row = wgid.y * TM; // K-dim (= F)
    let block_col = wgid.x * TN; // N-dim (= D)
    // 4x4 register accumulator: accIJ is output
    // Y[block_row + 4*ty + I, block_col + 4*tx + J].
    var acc00: f32 = 0.0; var acc01: f32 = 0.0; var acc02: f32 = 0.0; var acc03: f32 = 0.0;
    var acc10: f32 = 0.0; var acc11: f32 = 0.0; var acc12: f32 = 0.0; var acc13: f32 = 0.0;
    var acc20: f32 = 0.0; var acc21: f32 = 0.0; var acc22: f32 = 0.0; var acc23: f32 = 0.0;
    var acc30: f32 = 0.0; var acc31: f32 = 0.0; var acc32: f32 = 0.0; var acc33: f32 = 0.0;
    // The reduction runs over M, TK rows per tile.
    let nTiles = (M + TK - 1u) / TK;
    // Prologue: M-tile 0. X-tile element = silu(gate[mG, kG]) * up[mG, kG].
    {
        // Transposed mapping: aIm = tid*4 / TM (0..15), aIk = tid*4 % TM
        // (0..60). The 16 invocations sharing an aIm read 16 consecutive vec4s
        // of one GATE/UP row (contiguous global addresses). The tile is stored
        // as tileA_db[kLocal * TK_PAD + mLocal], i.e. [output row][reduction
        // index], which is the layout the inner loop below reads.
        let aI0 = tid * 4u;
        let aIm = aI0 / TM;
        let aIk = aI0 % TM;
        let kG = block_row + aIk;
        let mG = aIm;
        let m_in = mG < M;
        // Unconditional loads; out-of-range lanes are zeroed by select().
        // NOTE(review): relies on WGSL's defined out-of-bounds behaviour for
        // storage reads (no fault); the loaded value is discarded anyway.
        // silu(gate) * up is recomputed by every workgroup that shares this
        // wgid.y (one per 64-column N tile).
        let gv = GATE_mas[mG * K4 + kG / 4u];
        let uv = UP_mas [mG * K4 + kG / 4u];
        tileA_db[(aIk + 0u) * TK_PAD + aIm] = select(0.0, silu(gv.x) * uv.x, m_in && (kG + 0u) < K);
        tileA_db[(aIk + 1u) * TK_PAD + aIm] = select(0.0, silu(gv.y) * uv.y, m_in && (kG + 1u) < K);
        tileA_db[(aIk + 2u) * TK_PAD + aIm] = select(0.0, silu(gv.z) * uv.z, m_in && (kG + 2u) < K);
        tileA_db[(aIk + 3u) * TK_PAD + aIm] = select(0.0, silu(gv.w) * uv.w, m_in && (kG + 3u) < K);
        // B tile (TK m-rows x TN columns): 4 consecutive columns of one B row
        // per invocation, stored untransposed.
        let bI0 = tid * 4u;
        let bIm = bI0 / TN;
        let bIn = bI0 % TN;
        let mGB = bIm;
        let nG = block_col + bIn;
        let mGB_in = mGB < M;
        let bv = B_mas[mGB * N4 + nG / 4u];
        tileB_db[bIm * TN + bIn + 0u] = select(0.0, bv.x, mGB_in && (nG + 0u) < N);
        tileB_db[bIm * TN + bIn + 1u] = select(0.0, bv.y, mGB_in && (nG + 1u) < N);
        tileB_db[bIm * TN + bIn + 2u] = select(0.0, bv.z, mGB_in && (nG + 2u) < N);
        tileB_db[bIm * TN + bIn + 3u] = select(0.0, bv.w, mGB_in && (nG + 3u) < N);
    }
    // The prologue tiles must be visible to all invocations before the first
    // compute step.
    workgroupBarrier();
    for (var t: u32 = 0u; t < nTiles; t = t + 1u) {
        // Even t computes from half 0, odd t from half 1.
        let parity = t & 1u;
        let cur_a_off = parity * TA_DB_HALF;
        let cur_b_off = parity * TB_DB_HALF;
        // Prefetch M-tile t+1 into the other half. The single bottom-of-loop
        // barrier suffices: the half written here was last read in iteration
        // t-1 (which ended with a barrier), and the half read below was filled
        // before that barrier (by t-1's prefetch, or the prologue for t == 0).
        if (t + 1u < nTiles) {
            let nxt_a_off = (1u - parity) * TA_DB_HALF;
            let nxt_b_off = (1u - parity) * TB_DB_HALF;
            let nxt_mBase = (t + 1u) * TK;
            // Same transposed A-tile mapping as the prologue, shifted to nxt_mBase.
            let aI0 = tid * 4u;
            let aIm = aI0 / TM;
            let aIk = aI0 % TM;
            let kG = block_row + aIk;
            let mG = nxt_mBase + aIm;
            let m_in = mG < M;
            let gv = GATE_mas[mG * K4 + kG / 4u];
            let uv = UP_mas [mG * K4 + kG / 4u];
            tileA_db[nxt_a_off + (aIk + 0u) * TK_PAD + aIm] = select(0.0, silu(gv.x) * uv.x, m_in && (kG + 0u) < K);
            tileA_db[nxt_a_off + (aIk + 1u) * TK_PAD + aIm] = select(0.0, silu(gv.y) * uv.y, m_in && (kG + 1u) < K);
            tileA_db[nxt_a_off + (aIk + 2u) * TK_PAD + aIm] = select(0.0, silu(gv.z) * uv.z, m_in && (kG + 2u) < K);
            tileA_db[nxt_a_off + (aIk + 3u) * TK_PAD + aIm] = select(0.0, silu(gv.w) * uv.w, m_in && (kG + 3u) < K);
            // Same B-tile mapping as the prologue, shifted to nxt_mBase.
            let bI0 = tid * 4u;
            let bIm = bI0 / TN;
            let bIn = bI0 % TN;
            let mGB = nxt_mBase + bIm;
            let nG = block_col + bIn;
            let mGB_in = mGB < M;
            let bv = B_mas[mGB * N4 + nG / 4u];
            tileB_db[nxt_b_off + bIm * TN + bIn + 0u] = select(0.0, bv.x, mGB_in && (nG + 0u) < N);
            tileB_db[nxt_b_off + bIm * TN + bIn + 1u] = select(0.0, bv.y, mGB_in && (nG + 1u) < N);
            tileB_db[nxt_b_off + bIm * TN + bIn + 2u] = select(0.0, bv.z, mGB_in && (nG + 2u) < N);
            tileB_db[nxt_b_off + bIm * TN + bIn + 3u] = select(0.0, bv.w, mGB_in && (nG + 3u) < N);
        }
        // TK rank-1 updates over the reduction index m.
        // a0..a3: X^T rows (K indices) 4*ty + 0..3; b0..b3: B columns 4*tx + 0..3.
        for (var m: u32 = 0u; m < TK; m = m + 1u) {
            let a0 = tileA_db[cur_a_off + (4u * ty + 0u) * TK_PAD + m];
            let a1 = tileA_db[cur_a_off + (4u * ty + 1u) * TK_PAD + m];
            let a2 = tileA_db[cur_a_off + (4u * ty + 2u) * TK_PAD + m];
            let a3 = tileA_db[cur_a_off + (4u * ty + 3u) * TK_PAD + m];
            let b0 = tileB_db[cur_b_off + m * TN + (4u * tx + 0u)];
            let b1 = tileB_db[cur_b_off + m * TN + (4u * tx + 1u)];
            let b2 = tileB_db[cur_b_off + m * TN + (4u * tx + 2u)];
            let b3 = tileB_db[cur_b_off + m * TN + (4u * tx + 3u)];
            acc00 = fma(a0, b0, acc00); acc01 = fma(a0, b1, acc01); acc02 = fma(a0, b2, acc02); acc03 = fma(a0, b3, acc03);
            acc10 = fma(a1, b0, acc10); acc11 = fma(a1, b1, acc11); acc12 = fma(a1, b2, acc12); acc13 = fma(a1, b3, acc13);
            acc20 = fma(a2, b0, acc20); acc21 = fma(a2, b1, acc21); acc22 = fma(a2, b2, acc22); acc23 = fma(a2, b3, acc23);
            acc30 = fma(a3, b0, acc30); acc31 = fma(a3, b1, acc31); acc32 = fma(a3, b2, acc32); acc33 = fma(a3, b3, acc33);
        }
        // Finish all reads of the current half before the next prefetch
        // overwrites it, and publish this iteration's prefetch.
        workgroupBarrier();
    }
    // Epilogue: Y is [K, N]; rows are bounds-checked against K and columns
    // against N. Plain store (overwrites Y).
    let or = block_row + 4u * ty;
    let oc = block_col + 4u * tx;
    if (or + 0u < K && oc + 0u < N) { Y_mas[(or + 0u) * N + (oc + 0u)] = acc00; }
    if (or + 0u < K && oc + 1u < N) { Y_mas[(or + 0u) * N + (oc + 1u)] = acc01; }
    if (or + 0u < K && oc + 2u < N) { Y_mas[(or + 0u) * N + (oc + 2u)] = acc02; }
    if (or + 0u < K && oc + 3u < N) { Y_mas[(or + 0u) * N + (oc + 3u)] = acc03; }
    if (or + 1u < K && oc + 0u < N) { Y_mas[(or + 1u) * N + (oc + 0u)] = acc10; }
    if (or + 1u < K && oc + 1u < N) { Y_mas[(or + 1u) * N + (oc + 1u)] = acc11; }
    if (or + 1u < K && oc + 2u < N) { Y_mas[(or + 1u) * N + (oc + 2u)] = acc12; }
    if (or + 1u < K && oc + 3u < N) { Y_mas[(or + 1u) * N + (oc + 3u)] = acc13; }
    if (or + 2u < K && oc + 0u < N) { Y_mas[(or + 2u) * N + (oc + 0u)] = acc20; }
    if (or + 2u < K && oc + 1u < N) { Y_mas[(or + 2u) * N + (oc + 1u)] = acc21; }
    if (or + 2u < K && oc + 2u < N) { Y_mas[(or + 2u) * N + (oc + 2u)] = acc22; }
    if (or + 2u < K && oc + 3u < N) { Y_mas[(or + 2u) * N + (oc + 3u)] = acc23; }
    if (or + 3u < K && oc + 0u < N) { Y_mas[(or + 3u) * N + (oc + 0u)] = acc30; }
    if (or + 3u < K && oc + 1u < N) { Y_mas[(or + 3u) * N + (oc + 1u)] = acc31; }
    if (or + 3u < K && oc + 2u < N) { Y_mas[(or + 3u) * N + (oc + 2u)] = acc32; }
    if (or + 3u < K && oc + 3u < N) { Y_mas[(or + 3u) * N + (oc + 3u)] = acc33; }
}

// --- KERNEL: matmul_at_acc_swiglu_a ---
// Accumulating variant of matmul_at_swiglu_a: Y[K, N] += (silu(GATE)*UP)^T @ B.
// Used by the FFN backward when grad accumulation is on (mmAt='matmul_at_acc' path).
// Identical to matmul_at_swiglu_a except the epilogue does Y[i] = Y[i] + acc.
@group(0) @binding(0) var GATE_maas: array>; // gate, [M, K], read 4 k-lanes at a time
@group(0) @binding(1) var UP_maas: array>; // up, [M, K], read 4 k-lanes at a time
@group(0) @binding(2) var B_maas: array>; // B, [M, N], read 4 n-lanes at a time
@group(0) @binding(3) var Y_maas: array; // output, [K, N] row-major (read and accumulated into)
@group(0) @binding(4) var dims_maas: vec4; // (K, N, M, _)
@compute @workgroup_size(16, 16, 1)
fn matmul_at_acc_swiglu_a(@builtin(workgroup_id) wgid: vec3, @builtin(local_invocation_id) lid: vec3) {
    let K = dims_maas.x;
    let N = dims_maas.y;
    let M = dims_maas.z;
    // Row strides (vec4 units) of GATE/UP and of B; requires K % 4 == 0 and
    // N % 4 == 0. M is only bounds-checked.
    let K4 = K / 4u;
    let N4 = N / 4u;
    // Flat invocation index in [0, 256).
    let tx = lid.x;
    let ty = lid.y;
    let tid = ty * 16u + tx;
    // Output tile origin: rows index K, columns index N.
    let block_row = wgid.y * TM;
    let block_col = wgid.x * TN;
    // 4x4 register accumulator: accIJ is the contribution to
    // Y[block_row + 4*ty + I, block_col + 4*tx + J].
    var acc00: f32 = 0.0; var acc01: f32 = 0.0; var acc02: f32 = 0.0; var acc03: f32 = 0.0;
    var acc10: f32 = 0.0; var acc11: f32 = 0.0; var acc12: f32 = 0.0; var acc13: f32 = 0.0;
    var acc20: f32 = 0.0; var acc21: f32 = 0.0; var acc22: f32 = 0.0; var acc23: f32 = 0.0;
    var acc30: f32 = 0.0; var acc31: f32 = 0.0; var acc32: f32 = 0.0; var acc33: f32 = 0.0;
    // The reduction runs over M, TK rows per tile.
    let nTiles = (M + TK - 1u) / TK;
    // Prologue: M-tile 0 into half 0. X-tile element = silu(gate[mG, kG]) * up[mG, kG].
    {
        // Transposed mapping: 16 invocations sharing aIm read consecutive
        // vec4s of one GATE/UP row. Stored as tileA_db[kLocal * TK_PAD + mLocal]
        // ([output row][reduction index]) for the inner loop.
        let aI0 = tid * 4u;
        let aIm = aI0 / TM;
        let aIk = aI0 % TM;
        let kG = block_row + aIk;
        let mG = aIm;
        let m_in = mG < M;
        // Unconditional loads; out-of-range lanes are zeroed by select().
        // NOTE(review): relies on WGSL's defined out-of-bounds behaviour for
        // storage reads (no fault); the loaded value is discarded anyway.
        let gv = GATE_maas[mG * K4 + kG / 4u];
        let uv = UP_maas [mG * K4 + kG / 4u];
        tileA_db[(aIk + 0u) * TK_PAD + aIm] = select(0.0, silu(gv.x) * uv.x, m_in && (kG + 0u) < K);
        tileA_db[(aIk + 1u) * TK_PAD + aIm] = select(0.0, silu(gv.y) * uv.y, m_in && (kG + 1u) < K);
        tileA_db[(aIk + 2u) * TK_PAD + aIm] = select(0.0, silu(gv.z) * uv.z, m_in && (kG + 2u) < K);
        tileA_db[(aIk + 3u) * TK_PAD + aIm] = select(0.0, silu(gv.w) * uv.w, m_in && (kG + 3u) < K);
        // B tile (TK m-rows x TN columns), stored untransposed.
        let bI0 = tid * 4u;
        let bIm = bI0 / TN;
        let bIn = bI0 % TN;
        let mGB = bIm;
        let nG = block_col + bIn;
        let mGB_in = mGB < M;
        let bv = B_maas[mGB * N4 + nG / 4u];
        tileB_db[bIm * TN + bIn + 0u] = select(0.0, bv.x, mGB_in && (nG + 0u) < N);
        tileB_db[bIm * TN + bIn + 1u] = select(0.0, bv.y, mGB_in && (nG + 1u) < N);
        tileB_db[bIm * TN + bIn + 2u] = select(0.0, bv.z, mGB_in && (nG + 2u) < N);
        tileB_db[bIm * TN + bIn + 3u] = select(0.0, bv.w, mGB_in && (nG + 3u) < N);
    }
    // The prologue tiles must be visible to all invocations before the first
    // compute step.
    workgroupBarrier();
    for (var t: u32 = 0u; t < nTiles; t = t + 1u) {
        // Even t computes from half 0, odd t from half 1.
        let parity = t & 1u;
        let cur_a_off = parity * TA_DB_HALF;
        let cur_b_off = parity * TB_DB_HALF;
        // Prefetch M-tile t+1 into the other half. The single bottom-of-loop
        // barrier suffices: the half written here was last read in iteration
        // t-1 (which ended with a barrier), and the half read below was filled
        // before that barrier (by t-1's prefetch, or the prologue for t == 0).
        if (t + 1u < nTiles) {
            let nxt_a_off = (1u - parity) * TA_DB_HALF;
            let nxt_b_off = (1u - parity) * TB_DB_HALF;
            let nxt_mBase = (t + 1u) * TK;
            let aI0 = tid * 4u;
            let aIm = aI0 / TM;
            let aIk = aI0 % TM;
            let kG = block_row + aIk;
            let mG = nxt_mBase + aIm;
            let m_in = mG < M;
            let gv = GATE_maas[mG * K4 + kG / 4u];
            let uv = UP_maas [mG * K4 + kG / 4u];
            tileA_db[nxt_a_off + (aIk + 0u) * TK_PAD + aIm] = select(0.0, silu(gv.x) * uv.x, m_in && (kG + 0u) < K);
            tileA_db[nxt_a_off + (aIk + 1u) * TK_PAD + aIm] = select(0.0, silu(gv.y) * uv.y, m_in && (kG + 1u) < K);
            tileA_db[nxt_a_off + (aIk + 2u) * TK_PAD + aIm] = select(0.0, silu(gv.z) * uv.z, m_in && (kG + 2u) < K);
            tileA_db[nxt_a_off + (aIk + 3u) * TK_PAD + aIm] = select(0.0, silu(gv.w) * uv.w, m_in && (kG + 3u) < K);
            let bI0 = tid * 4u;
            let bIm = bI0 / TN;
            let bIn = bI0 % TN;
            let mGB = nxt_mBase + bIm;
            let nG = block_col + bIn;
            let mGB_in = mGB < M;
            let bv = B_maas[mGB * N4 + nG / 4u];
            tileB_db[nxt_b_off + bIm * TN + bIn + 0u] = select(0.0, bv.x, mGB_in && (nG + 0u) < N);
            tileB_db[nxt_b_off + bIm * TN + bIn + 1u] = select(0.0, bv.y, mGB_in && (nG + 1u) < N);
            tileB_db[nxt_b_off + bIm * TN + bIn + 2u] = select(0.0, bv.z, mGB_in && (nG + 2u) < N);
            tileB_db[nxt_b_off + bIm * TN + bIn + 3u] = select(0.0, bv.w, mGB_in && (nG + 3u) < N);
        }
        // TK rank-1 updates over the reduction index m.
        // a0..a3: X^T rows (K indices) 4*ty + 0..3; b0..b3: B columns 4*tx + 0..3.
        for (var m: u32 = 0u; m < TK; m = m + 1u) {
            let a0 = tileA_db[cur_a_off + (4u * ty + 0u) * TK_PAD + m];
            let a1 = tileA_db[cur_a_off + (4u * ty + 1u) * TK_PAD + m];
            let a2 = tileA_db[cur_a_off + (4u * ty + 2u) * TK_PAD + m];
            let a3 = tileA_db[cur_a_off + (4u * ty + 3u) * TK_PAD + m];
            let b0 = tileB_db[cur_b_off + m * TN + (4u * tx + 0u)];
            let b1 = tileB_db[cur_b_off + m * TN + (4u * tx + 1u)];
            let b2 = tileB_db[cur_b_off + m * TN + (4u * tx + 2u)];
            let b3 = tileB_db[cur_b_off + m * TN + (4u * tx + 3u)];
            acc00 = fma(a0, b0, acc00); acc01 = fma(a0, b1, acc01); acc02 = fma(a0, b2, acc02); acc03 = fma(a0, b3, acc03);
            acc10 = fma(a1, b0, acc10); acc11 = fma(a1, b1, acc11); acc12 = fma(a1, b2, acc12); acc13 = fma(a1, b3, acc13);
            acc20 = fma(a2, b0, acc20); acc21 = fma(a2, b1, acc21); acc22 = fma(a2, b2, acc22); acc23 = fma(a2, b3, acc23);
            acc30 = fma(a3, b0, acc30); acc31 = fma(a3, b1, acc31); acc32 = fma(a3, b2, acc32); acc33 = fma(a3, b3, acc33);
        }
        // Finish all reads of the current half before the next prefetch
        // overwrites it, and publish this iteration's prefetch.
        workgroupBarrier();
    }
    // Epilogue: read-modify-write accumulate into Y[K, N]. Each output element
    // maps to exactly one invocation of this dispatch ((wgid, lid) ->
    // (or + I, oc + J) is one-to-one), so the non-atomic update has no race
    // within the dispatch.
    let or = block_row + 4u * ty;
    let oc = block_col + 4u * tx;
    if (or + 0u < K && oc + 0u < N) { let i = (or + 0u) * N + (oc + 0u); Y_maas[i] = Y_maas[i] + acc00; }
    if (or + 0u < K && oc + 1u < N) { let i = (or + 0u) * N + (oc + 1u); Y_maas[i] = Y_maas[i] + acc01; }
    if (or + 0u < K && oc + 2u < N) { let i = (or + 0u) * N + (oc + 2u); Y_maas[i] = Y_maas[i] + acc02; }
    if (or + 0u < K && oc + 3u < N) { let i = (or + 0u) * N + (oc + 3u); Y_maas[i] = Y_maas[i] + acc03; }
    if (or + 1u < K && oc + 0u < N) { let i = (or + 1u) * N + (oc + 0u); Y_maas[i] = Y_maas[i] + acc10; }
    if (or + 1u < K && oc + 1u < N) { let i = (or + 1u) * N + (oc + 1u); Y_maas[i] = Y_maas[i] + acc11; }
    if (or + 1u < K && oc + 2u < N) { let i = (or + 1u) * N + (oc + 2u); Y_maas[i] = Y_maas[i] + acc12; }
    if (or + 1u < K && oc + 3u < N) { let i = (or + 1u) * N + (oc + 3u); Y_maas[i] = Y_maas[i] + acc13; }
    if (or + 2u < K && oc + 0u < N) { let i = (or + 2u) * N + (oc + 0u); Y_maas[i] = Y_maas[i] + acc20; }
    if (or + 2u < K && oc + 1u < N) { let i = (or + 2u) * N + (oc + 1u); Y_maas[i] = Y_maas[i] + acc21; }
    if (or + 2u < K && oc + 2u < N) { let i = (or + 2u) * N + (oc + 2u); Y_maas[i] = Y_maas[i] + acc22; }
    if (or + 2u < K && oc + 3u < N) { let i = (or + 2u) * N + (oc + 3u); Y_maas[i] = Y_maas[i] + acc23; }
    if (or + 3u < K && oc + 0u < N) { let i = (or + 3u) * N + (oc + 0u); Y_maas[i] = Y_maas[i] + acc30; }
    if (or + 3u < K && oc + 1u < N) { let i = (or + 3u) * N + (oc + 1u); Y_maas[i] = Y_maas[i] + acc31; }
    if (or + 3u < K && oc + 2u < N) { let i = (or + 3u) * N + (oc + 2u); Y_maas[i] = Y_maas[i] + acc32; }
    if (or + 3u < K && oc + 3u < N) { let i = (or + 3u) * N + (oc + 3u); Y_maas[i] = Y_maas[i] + acc33; }
}