/**
 * @file bpe.wgsl
 * @brief GPU BPE kernels — training + inference (WebGPU/WGSL)
 *
 * Single-file kernel collection. Each kernel is delimited by
 * "// --- KERNEL: ---" markers; the JS host (engine.js) splits at
 * those markers and compiles each kernel as its own GPUShaderModule,
 * prepending the shared preamble (everything before the first marker).
 *
 * Training kernels (run during BPETrainer.train):
 *  1. bpe_word_boundary            — GPU pre-tokenization (word boundaries)
 *  2. bpe_clear_table              — Hash table reset
 *  3. bpe_find_max_pair4           — Block-level max reduction (4 elem/thread)
 *  4. bpe_find_max_pair_final_det  — Final max reduction (deterministic)
 *  5. bpe_setup_merge              — GPU-driven merge orchestrator
 *  6. bpe_pair_count_b             — Batched two-level pair counting
 *  7. bpe_merge_reduce_b           — FUSED merge mark + reduce
 *  8. bpe_scan_pass1_local         — Hierarchical scan: per-chunk local scan
 *  9. bpe_scan_pass2_chunks        — Hierarchical scan: chunk_sums + update_count
 * 10. bpe_scan_pass3_apply         — Hierarchical scan: apply chunk offsets
 * 11. bpe_finalize_compact_b       — Fused scan + merge_apply + scatter
 *
 * Inference kernels (run during TrieTokenizer encode):
 * 12. trie_tokenizer_chunked       — Chunked greedy longest-match tokenization
 * 13. trie_prefix_sum              — GPU exclusive prefix sum over chunks
 * 14. trie_tokenizer_compact       — Cooperative compaction
 *
 * Training pipeline (10 dispatches/iteration, batched):
 *   clear_table → pair_count → find_max4 → find_max_final_det → setup_merge
 *   → merge_reduce (FUSED)
 *   → scan_pass1 → scan_pass2 → scan_pass3
 *   → finalize_compact (FUSED merge_apply)
 */
// ════════════════════════════════════════════════════════════
// SHARED UTILITIES (prepended to every kernel module)
// ════════════════════════════════════════════════════════════
// Enable subgroup ops module-wide. Required by bpe_finalize_compact_b's
// subgroup-cooperative scan. Must precede any declarations per WGSL spec.
enable subgroups;

const WORKGROUP_SIZE: u32 = 256u;
const MAX_PROBE: u32 = 128u;
const MAX_WG_DIM: u32 = 65535u;         // WebGPU maxComputeWorkgroupsPerDimension
const INVALID_TOKEN: u32 = 0xFFFFFFFFu;
const WORD_START_BIT: u32 = 0x10000u;   // bit 16 = word-start flag
const TOKEN_MASK: u32 = 0xFFFFu;        // lower 16 bits = token ID

/// Pack two token IDs into one u32 pair key: `a` in the high 16 bits,
/// `b` in the low 16 bits. Both inputs are expected to be <= TOKEN_MASK.
fn pack_pair(a: u32, b: u32) -> u32 { return (a << 16u) | b; }

/// First token of a packed pair key (high 16 bits).
fn unpack_first(pair: u32) -> u32 { return pair >> 16u; }

/// Second token of a packed pair key (low 16 bits).
fn unpack_second(pair: u32) -> u32 { return pair & TOKEN_MASK; }

// GPU-driven iteration state for batched training.
// Written by bpe_setup_merge, read by all batched kernels.
struct IterState {
    symbol_count: u32,   // [0] Current valid symbol count
    table_size: u32,     // [1] Hash table size (constant)
    early_stop: u32,     // [2] 1=stop, 0=continue
    next_token_id: u32,  // [3] Next token ID to assign
    symbol_a: u32,       // [4] First symbol of best pair
    symbol_b: u32,       // [5] Second symbol of best pair
    new_symbol: u32,     // [6] Merged symbol ID
    max_count: u32,      // [7] Count of best pair
    merges_done: u32,    // [8] Merges completed in this batch
    max_symbols: u32,    // [9] Max dispatch size (initial count)
    _pad1: u32,
    _pad2: u32,
}

/// Murmur3 32-bit integer finalizer (3 xor-shifts + 2 multiplies).
/// Returns the raw hash; callers apply `& table_mask` for power-of-2 tables.
fn pair_hash(pair: u32) -> u32 {
    var x = pair;
    x = (x ^ (x >> 16u)) * 0x7feb352du;
    x = (x ^ (x >> 15u)) * 0x846ca68bu;
    return x ^ (x >> 16u);
}

/// Linearize a 2D dispatch grid into a flat 1D thread index.
/// The host may split large dispatches into (X, Y, 1) workgroups with
/// X*Y >= the needed count, so kernels must still bounds-check the result.
/// Use flat_id() instead of gid.x to support more than 65535 workgroups
/// (>16M threads at WORKGROUP_SIZE = 256).
fn flat_id(gid: vec3, nwg: vec3) -> u32 {
    return gid.x + gid.y * nwg.x * WORKGROUP_SIZE;
}

// Local hash table constants for bpe_pair_count_b.
// 1024 slots in shared memory (8 KB total: 2×1024×4 B). A 2048-slot
// version was tried but regressed by ~70% on Apple M GPUs: doubling
// shared mem per WG halved occupancy (4 WGs/SM → 2 WGs/SM), which
// killed latency hiding. The original 25% load factor + Murmur3 +
// quadratic probing already converges fast.
// NOTE: bpe_pair_count_b's workgroup arrays hard-code 1024 — keep in sync.
const LOCAL_TABLE_SIZE: u32 = 1024u;
const LOCAL_TABLE_MASK: u32 = LOCAL_TABLE_SIZE - 1u;  // power-of-2 modulo
const LOCAL_MAX_PROBE: u32 = 64u;

// Hierarchical scan constants (used by bpe_scan_pass1/2/3).
// SCAN_PER_THREAD is derived, so only SCAN_CHUNK_SIZE has to be kept in
// sync with the workgroup<> array sizes below the markers.
const SCAN_CHUNK_SIZE: u32 = 2048u;
const SCAN_PER_THREAD: u32 = SCAN_CHUNK_SIZE / WORKGROUP_SIZE;

// Subgroup constants — assumed 32-lane subgroups (Apple M*, NVIDIA,
// Intel, AMD-on-Chrome). Used by bpe_finalize_compact_b's
// subgroup-cooperative exclusive scan.
const SUBGROUP_SIZE: u32 = 32u;
const NUM_SUBGROUPS: u32 = WORKGROUP_SIZE / SUBGROUP_SIZE;

// Compile-time checks for invariants the kernels rely on.
// Masked hash-table indexing requires a power-of-2 table; the clear loop in
// bpe_pair_count_b steps by WORKGROUP_SIZE.
const_assert (LOCAL_TABLE_SIZE & LOCAL_TABLE_MASK) == 0u;
const_assert LOCAL_TABLE_SIZE % WORKGROUP_SIZE == 0u;
// The Blelloch up/down-sweeps require a power-of-2 chunk that splits evenly
// across the workgroup.
const_assert (SCAN_CHUNK_SIZE & (SCAN_CHUNK_SIZE - 1u)) == 0u;
const_assert SCAN_CHUNK_SIZE % WORKGROUP_SIZE == 0u;
// The cross-subgroup pass scans all subgroup totals inside ONE subgroup.
const_assert WORKGROUP_SIZE % SUBGROUP_SIZE == 0u;
const_assert NUM_SUBGROUPS <= SUBGROUP_SIZE;

/// Deterministic comparison: higher count wins; ties broken by smaller pair_id.
/// Ensures identical vocabulary output regardless of GPU scheduling order.
fn is_better(count_new: u32, pair_new: u32, count_old: u32, pair_old: u32) -> bool {
    return count_new > count_old || (count_new == count_old && pair_new < pair_old);
}

// ════════════════════════════════════════════════════════════
// TRAINING KERNELS
// ════════════════════════════════════════════════════════════

// --- KERNEL: bpe_word_boundary ---
//
// GPU pre-tokenization: classify each symbol's byte value into a character
// class (letter/digit/space/punct), then compare adjacent classes to detect
// word boundaries. When a boundary is found, bit 16 (WORD_START_BIT) is set
// on the symbol at the start of the new word.
//
// This ensures bpe_pair_count never counts pairs across word boundaries,
// preventing multi-word token formation (e.g. "yakınlık▁ve▁" → 1 token).
//
// Character classes:
// 0 = letter (ASCII a-z, A-Z, and every byte >= 0x80 — i.e. all UTF-8
//     lead/continuation bytes, so Turkish ğışçöü, Arabic, etc. are letters)
// 1 = digit (0-9)
// 2 = space (0x20 = ▁)
// 3 = punctuation / other
// 4 = newline (always a word boundary)
struct WordBoundaryParams { symbol_count: u32, _pad: u32 }
@group(0) @binding(0) var symbols: array;
@group(0) @binding(1) var params: WordBoundaryParams;

/// Map a byte/token value to its character class (see the table above).
/// The ranges tested are disjoint, so the order of the checks does not
/// matter. `(tok - base) <= span` relies on u32 wrap-around: values below
/// `base` wrap to huge numbers and fail the test, so one compare covers a range.
fn char_class(tok: u32) -> u32 {
    if (tok == 0x0Au) { return 4u; }              // newline
    if (tok == 0x20u) { return 2u; }              // space
    if ((tok - 0x30u) <= 9u) { return 1u; }       // '0'..'9'
    let high_byte = tok >= 0x80u;                 // any UTF-8 lead/continuation byte
    let ascii_lower = (tok - 0x61u) <= 25u;       // 'a'..'z'
    let ascii_upper = (tok - 0x41u) <= 25u;       // 'A'..'Z'
    if (high_byte || ascii_lower || ascii_upper) { return 0u; }
    return 3u;                                    // punctuation / other
}

/**
 * @compute Kernel: bpe_word_boundary
 *
 * One thread per symbol. Sets WORD_START_BIT on every symbol that begins a
 * new word so that later pair counting never spans two words.
 *
 * A symbol starts a word when:
 *  1. it is the first symbol of the sequence; or
 *  2. its class differs from the previous symbol's class — except that a
 *     letter/digit directly after a space continues the word (the space
 *     attaches to the word that follows it); or
 *  3. it is a space whose predecessor is not a space; or
 *  4. it or its predecessor is a newline.
 *
 * Symbols that start a word are rewritten as (value & TOKEN_MASK) | WORD_START_BIT;
 * all other symbols are not written.
 */
@compute @workgroup_size(256)
fn bpe_word_boundary(
    @builtin(global_invocation_id) gid: vec3,
    @builtin(num_workgroups) nwg: vec3
) {
    let idx = flat_id(gid, nwg);
    if (idx >= params.symbol_count) { return; }

    let cur = symbols[idx] & TOKEN_MASK;
    if (idx == 0u) {
        symbols[idx] = cur | WORD_START_BIT;
        return;
    }

    let cur_cls = char_class(cur);
    let prev_cls = char_class(symbols[idx - 1u] & TOKEN_MASK);

    let class_changed = cur_cls != prev_cls;
    let space_joins_next_word = prev_cls == 2u && cur_cls <= 1u;
    let opens_space_run = cur_cls == 2u && prev_cls != 2u;
    let touches_newline = cur_cls == 4u || prev_cls == 4u;

    let starts_word = touches_newline
        || opens_space_run
        || (class_changed && !space_joins_next_word);

    if (starts_word) {
        symbols[idx] = cur | WORD_START_BIT;
    }
}

// --- KERNEL: bpe_clear_table ---
struct ClearParams { table_size: u32, _pad: u32 }
@group(0) @binding(0) var pair_counts: array;
@group(0) @binding(1) var pair_ids: array;
@group(0) @binding(2) var params: ClearParams;

/// Zero one slot of the global pair hash table (both count and key) per thread.
@compute @workgroup_size(256)
fn bpe_clear_table(@builtin(global_invocation_id) gid: vec3, @builtin(num_workgroups) nwg: vec3) {
    let slot = flat_id(gid, nwg);
    if (slot < params.table_size) {
        pair_counts[slot] = 0u;
        pair_ids[slot] = 0u;
    }
}

// --- KERNEL: bpe_find_max_pair4 ---
//
// 4 elements/thread max reduction with deterministic tie-breaking.
// Each thread loads 4 hash table entries, performs thread-local max,
// then enters shared-memory reduction.
//
// Benefits:
// - 4× fewer workgroups → 4× fewer block_max entries
// - Deterministic: same count → smaller pair_id wins (reproducible)
// - Same occupancy (256 threads/wg), better utilization
//
// Coverage: 256 threads × 4 elements = 1024 entries/workgroup.
struct FindMaxParams { table_size: u32, _pad: u32 }
@group(0) @binding(0) var pair_counts: array;
@group(0) @binding(1) var pair_ids: array;
@group(0) @binding(2) var block_max_counts: array;
@group(0) @binding(3) var block_max_pair_ids: array;
@group(0) @binding(4) var params: FindMaxParams;
var sh_c: array;
var sh_p: array;

/// Per-workgroup winner (count, pair_id) over 1024 consecutive table slots.
/// Thread t scans slots [4t, 4t+3] (t = flat thread id), then a tree
/// reduction in workgroup memory picks the best entry using is_better().
/// The result goes to block_max_*[flat workgroup id].
@compute @workgroup_size(256)
fn bpe_find_max_pair4(
    @builtin(global_invocation_id) gid: vec3,
    @builtin(local_invocation_id) lid: vec3,
    @builtin(workgroup_id) tgid: vec3,
    @builtin(num_workgroups) nwg: vec3
) {
    let me = lid.x;
    let out_slot = tgid.x + tgid.y * nwg.x;
    let first_slot = flat_id(gid, nwg) * 4u;

    // Thread-local best of this thread's 4 slots.
    var top_count: u32 = 0u;
    var top_pair: u32 = 0u;
    for (var k: u32 = 0u; k < 4u; k += 1u) {
        let slot = first_slot + k;
        if (slot >= params.table_size) { continue; }
        let cand_count = pair_counts[slot];
        let cand_pair = pair_ids[slot];
        if (is_better(cand_count, cand_pair, top_count, top_pair)) {
            top_count = cand_count;
            top_pair = cand_pair;
        }
    }
    sh_c[me] = top_count;
    sh_p[me] = top_pair;
    workgroupBarrier();

    // Halving tree reduction; every thread reaches every barrier.
    var half: u32 = WORKGROUP_SIZE / 2u;
    while (half > 0u) {
        if (me < half) {
            let other = me + half;
            if (is_better(sh_c[other], sh_p[other], sh_c[me], sh_p[me])) {
                sh_c[me] = sh_c[other];
                sh_p[me] = sh_p[other];
            }
        }
        workgroupBarrier();
        half = half >> 1u;
    }

    if (me == 0u) {
        block_max_counts[out_slot] = sh_c[0];
        block_max_pair_ids[out_slot] = sh_p[0];
    }
}

// --- KERNEL: bpe_find_max_pair_final_det ---
//
// Final max reduction with deterministic tie-breaking.
// Same count → smaller pair_id wins → reproducible vocabulary.
struct FinalParams { block_count: u32, _pad: u32 }
@group(0) @binding(0) var block_max_counts: array;
@group(0) @binding(1) var block_max_pair_ids: array;
@group(0) @binding(2) var max_count: array;
@group(0) @binding(3) var max_pair_id: array;
@group(0) @binding(4) var params: FinalParams;
var sh_c: array;
var sh_p: array;

/// Single-workgroup reduction of all per-block winners into max_count[0] /
/// max_pair_id[0]. Each thread first folds blocks me, me+256, me+512, …,
/// then a tree reduction in workgroup memory picks the overall best.
@compute @workgroup_size(256)
fn bpe_find_max_pair_final_det(@builtin(local_invocation_id) lid: vec3) {
    let me = lid.x;

    var top_count: u32 = 0u;
    var top_pair: u32 = 0u;
    for (var blk: u32 = me; blk < params.block_count; blk += WORKGROUP_SIZE) {
        let cand_count = block_max_counts[blk];
        let cand_pair = block_max_pair_ids[blk];
        if (is_better(cand_count, cand_pair, top_count, top_pair)) {
            top_count = cand_count;
            top_pair = cand_pair;
        }
    }
    sh_c[me] = top_count;
    sh_p[me] = top_pair;
    workgroupBarrier();

    var half: u32 = WORKGROUP_SIZE / 2u;
    while (half > 0u) {
        if (me < half) {
            let other = me + half;
            if (is_better(sh_c[other], sh_p[other], sh_c[me], sh_p[me])) {
                sh_c[me] = sh_c[other];
                sh_p[me] = sh_p[other];
            }
        }
        workgroupBarrier();
        half = half >> 1u;
    }

    if (me == 0u) {
        max_count[0] = sh_c[0];
        max_pair_id[0] = sh_p[0];
    }
}

// ════════════════════════════════════════════════════════════
// BATCHED TRAINING KERNELS — GPU-driven merge loop
//
// These "_b" variants read iteration parameters from a shared
// IterState storage buffer instead of per-dispatch uniforms,
// enabling N merges to be encoded in a single command buffer
// with zero CPU readbacks between iterations.
// ════════════════════════════════════════════════════════════

// --- KERNEL: bpe_setup_merge ---
//
// Orchestrator (single thread). Runs after findMax each iteration.
// Reads best pair, writes merge params, increments token counter,
// checks early stop, logs the merge for CPU reconstruction.
@group(0) @binding(0) var sm_max_count: array;
@group(0) @binding(1) var sm_max_pair_id: array;
@group(0) @binding(2) var state: IterState;
@group(0) @binding(3) var merge_log: array;

/// Turn this iteration's winning pair into merge parameters in `state`.
/// Stops the batch (early_stop = 1) when the best count is below 2 or the
/// 16-bit token-ID space is exhausted. Otherwise it fills symbol_a/b,
/// max_count and new_symbol, appends [pair, new_token_id, count] to
/// merge_log at slot merges_done, and advances next_token_id and merges_done.
@compute @workgroup_size(1)
fn bpe_setup_merge(@builtin(global_invocation_id) gid: vec3) {
    if (state.early_stop != 0u) { return; }

    let best_count = sm_max_count[0];
    let token_id = state.next_token_id;
    if (best_count < 2u || token_id > TOKEN_MASK) {
        state.early_stop = 1u;
        return;
    }

    let best_pair = sm_max_pair_id[0];
    state.symbol_a = unpack_first(best_pair);
    state.symbol_b = unpack_second(best_pair);
    state.max_count = best_count;
    state.new_symbol = token_id;

    // Merge log entry for CPU vocab reconstruction: [pair, newTokenId, count].
    let entry = state.merges_done * 3u;
    merge_log[entry] = best_pair;
    merge_log[entry + 1u] = token_id;
    merge_log[entry + 2u] = best_count;

    state.next_token_id = token_id + 1u;
    state.merges_done = state.merges_done + 1u;
}

// --- KERNEL: bpe_pair_count_b ---
@group(0) @binding(0) var symbols: array;
@group(0) @binding(1) var pair_counts: array>;
@group(0) @binding(2) var pair_ids: array>;
@group(0) @binding(3) var state: IterState;
// Sizes match LOCAL_TABLE_SIZE in the preamble — keep in sync.
var local_ids: array, 1024>;
var local_counts: array, 1024>;

/// Add `cnt` to the global count of pair key `pid`, claiming an empty slot
/// of the global table if the key is not present yet. Uses quadratic
/// (triangular-number) probing over at most MAX_PROBE slots. If every
/// probed slot holds another key, the count is dropped (the table is undersized).
fn pc_global_add(pid: u32, cnt: u32) {
    let table_mask = state.table_size - 1u;
    let home = pair_hash(pid) & table_mask;
    for (var probe: u32 = 0u; probe < MAX_PROBE; probe++) {
        let idx = (home + (probe * (probe + 1u)) / 2u) & table_mask;
        var owner: u32 = 0u;
        loop {
            let r = atomicCompareExchangeWeak(&pair_ids[idx], 0u, pid);
            if (r.exchanged) { owner = pid; break; }
            if (r.old_value != 0u) { owner = r.old_value; break; }
            // exchanged == false but old_value == 0: the weak CAS failed
            // spuriously on a still-empty slot. Retry the SAME slot. Moving on
            // would let this pair claim a second slot and split its count.
        }
        if (owner == pid) {
            atomicAdd(&pair_counts[idx], cnt);
            return;
        }
    }
}

/// Add 1 to the workgroup-local count of `pid`, claiming an empty local slot
/// if needed. Returns false if LOCAL_MAX_PROBE probed slots all hold other
/// keys; the caller must then count the pair elsewhere.
fn pc_local_add(pid: u32) -> bool {
    let home = pair_hash(pid);
    for (var probe: u32 = 0u; probe < LOCAL_MAX_PROBE; probe++) {
        let idx = (home + (probe * (probe + 1u)) / 2u) & LOCAL_TABLE_MASK;
        var owner: u32 = 0u;
        loop {
            let r = atomicCompareExchangeWeak(&local_ids[idx], 0u, pid);
            if (r.exchanged) { owner = pid; break; }
            if (r.old_value != 0u) { owner = r.old_value; break; }
            // Spurious weak-CAS failure on an empty slot: retry this slot.
        }
        if (owner == pid) {
            atomicAdd(&local_counts[idx], 1u);
            return true;
        }
    }
    return false;
}

/// Batched two-level pair counting.
/// Phase 1: each thread counts the adjacent pair (id, id+1) in a
///          workgroup-local hash table. Pairs that cross a word boundary or
///          contain token 0 are skipped.
/// Phase 2: the local table is flushed into the global table with one
///          atomicAdd per distinct local key.
@compute @workgroup_size(256)
fn bpe_pair_count_b(
    @builtin(global_invocation_id) gid: vec3,
    @builtin(local_invocation_id) lid: vec3,
    @builtin(num_workgroups) nwg: vec3
) {
    // Clear 1024 entries with 256 threads (4 entries/thread).
    for (var s: u32 = 0u; s < LOCAL_TABLE_SIZE; s += WORKGROUP_SIZE) {
        atomicStore(&local_ids[lid.x + s], 0u);
        atomicStore(&local_counts[lid.x + s], 0u);
    }
    workgroupBarrier();
    if (state.early_stop != 0u) { return; }

    // Phase 1: Aggregate into the LOCAL shared-memory table.
    let id = flat_id(gid, nwg);
    if (id + 1u < state.symbol_count) {
        let raw_b = symbols[id + 1u];
        if ((raw_b & WORD_START_BIT) == 0u) {
            let a = symbols[id] & TOKEN_MASK;
            let b = raw_b & TOKEN_MASK;
            if (a != 0u && b != 0u) {
                let pid = pack_pair(a, b);
                // Local table saturated along this probe sequence: count
                // straight into the global table rather than losing the pair.
                if (!pc_local_add(pid)) {
                    pc_global_add(pid, 1u);
                }
            }
        }
    }
    workgroupBarrier();

    // Phase 2: Flush local table → global table.
    for (var slot: u32 = lid.x; slot < LOCAL_TABLE_SIZE; slot += WORKGROUP_SIZE) {
        let cnt = atomicLoad(&local_counts[slot]);
        if (cnt == 0u) { continue; }
        let pid = atomicLoad(&local_ids[slot]);
        if (pid == 0u) { continue; }
        pc_global_add(pid, cnt);
    }
}

// --- KERNEL: bpe_merge_reduce_b ---
//
// FUSED merge + prefix_sum_reduce: eliminates 1 dispatch + N×4 byte global read.
//
// Phase 1: Each thread marks A-side merge intent and its own (B-side) validity;
//          symbols are only read.
//          The valid bit stays in register — NO separate valid_mask read pass
//          needed for the reduction.
// Phase 2: Workgroup-local sum reduction of valid bits → block_sums.
//
// valid_mask is still WRITTEN because finalize_compact_b reads it.
//
// Bindings (superset of original merge + reduce):
//   0: symbols    — only READ here (neighbours id-1, id, id+1); the merged
//                   symbol is materialized later by bpe_finalize_compact_b
//   1: valid_mask (read_write) — written for finalize_compact to consume
//   2: block_sums (read_write) — workgroup reduction output
//   3: state      (read)       — IterState with symbol_a, symbol_b, new_symbol, count
@group(0) @binding(0) var symbols: array;
@group(0) @binding(1) var valid_mask: array;
@group(0) @binding(2) var block_sums: array;
@group(0) @binding(3) var state: IterState;
var sh_reduce: array;

// A cooperative shared-memory cache for symbols[id-1/id/id+1] was tried
// here and reverted: Apple's L1/L2 already coalesces these consecutive
// reads, so the redundant memory traffic was a phantom — the cache plus
// the extra workgroupBarrier was a net ~22% regression.

/// Offset of position `id` within a run of `sym` tokens linked without a word
/// boundary, counted from the run's first element.
/// Precondition: (id-1, id) is such a link, so the result is >= 1.
fn mr_run_offset(id: u32, sym: u32) -> u32 {
    var offset: u32 = 1u;
    var j: u32 = id - 1u;
    while (j > 0u
        && (symbols[j] & WORD_START_BIT) == 0u
        && (symbols[j - 1u] & TOKEN_MASK) == sym) {
        offset += 1u;
        j -= 1u;
    }
    return offset;
}

@compute @workgroup_size(256)
fn bpe_merge_reduce_b(
    @builtin(global_invocation_id) gid: vec3,
    @builtin(local_invocation_id) lid: vec3,
    @builtin(workgroup_id) tgid: vec3,
    @builtin(num_workgroups) nwg: vec3
) {
    // NOTE: Cannot early-return before workgroupBarrier() — WGSL requires
    // uniform control flow at barriers. Use guard flag instead.
    let stopped = state.early_stop != 0u;
    let id = flat_id(gid, nwg);
    let block_idx = tgid.x + tgid.y * nwg.x;

    // ── Phase 1: Merge MARK pass (read-only on symbols) ──
    //
    // Race-free design: this kernel ONLY READS from `symbols`. The merged
    // symbol is materialized later inside `bpe_finalize_compact_b` (which
    // runs after the scan and gets cross-workgroup ordering for free via
    // WebGPU's storage-buffer hazard tracking between dispatches).
    //
    // Bit-packed valid_mask, both bits consumed by finalize_compact_b:
    //   bit 0 — self-validity flag (drives the per-block scan + scatter gate)
    //   bit 1 — A-side merge intent (selects merged-token vs. raw on scatter)
    var valid: u32 = 0u;    // default 0 for out-of-bounds or stopped threads
    var merge_a: u32 = 0u;  // bit 1 → A-side merge intent
    if (!stopped && id < state.symbol_count) {
        let n = state.symbol_count;
        let sym_a = state.symbol_a;
        let sym_b = state.symbol_b;
        let raw = symbols[id];
        let tok = raw & TOKEN_MASK;

        // A-side: (id, id+1) is the winning pair inside one word. Explicit
        // bounds checks are used (not select(), which evaluates both operands
        // and would read symbols[id + 1u] past the end).
        if (id + 1u < n) {
            let raw_next = symbols[id + 1u];
            if ((raw_next & WORD_START_BIT) == 0u
                && tok == sym_a
                && (raw_next & TOKEN_MASK) == sym_b) {
                merge_a = 2u;
            }
        }

        // B-side: (id-1, id) is the winning pair inside one word, so this
        // symbol is absorbed into its left neighbour's merge.
        var consumed = false;
        if (id > 0u
            && (raw & WORD_START_BIT) == 0u
            && tok == sym_b
            && (symbols[id - 1u] & TOKEN_MASK) == sym_a) {
            consumed = true;
            // For a == b, a run "a a a ..." must be merged left to right.
            // Merges start at even offsets within the run and absorb the odd
            // ones; a trailing even-offset element survives unmerged. Without
            // this, every interior element would be dropped ("aaa" → "aa").
            if (sym_a == sym_b) {
                consumed = (mr_run_offset(id, sym_a) & 1u) == 1u;
            }
        }

        valid = select(1u, 0u, consumed);
        // A consumed position may still carry bit 1; finalize ignores bit 1
        // whenever bit 0 is clear.
        valid_mask[id] = valid | merge_a;
    }

    // ── Phase 2: Workgroup reduction (replaces bpe_prefix_sum_reduce_b) ──
    // The valid bit is already in a register — NO global read of valid_mask.
    // All threads (including stopped ones) participate in barriers.
    sh_reduce[lid.x] = valid;
    workgroupBarrier();
    for (var s: u32 = 128u; s > 0u; s >>= 1u) {
        if (lid.x < s) {
            sh_reduce[lid.x] += sh_reduce[lid.x + s];
        }
        workgroupBarrier();
    }
    if (lid.x == 0u && !stopped) {
        block_sums[block_idx] = sh_reduce[0];
    }
}

// ════════════════════════════════════════════════════════════
// HIERARCHICAL BLELLOCH SCAN — 3 passes
//
// Replaces the old sequential _b kernel (one workgroup, one thread,
// O(N) iterations) and the limited _par kernel (single WG, capped at
// 256 blocks). Handles up to 1B symbols in 2 levels:
//
//   pass1: each WG scans SCAN_CHUNK_SIZE consecutive block_sums in
//          shared memory; emits per-chunk total to chunk_sums[wg_id]
//   pass2: single WG scans chunk_sums and stages new symbol_count +
//          indirect dispatch (mirrors the old _par kernel's tail)
//   pass3: each thread adds its chunk's offset back to data[]
//
// Capacity: pass2 handles up to SCAN_CHUNK_SIZE chunks → max
// SCAN_CHUNK_SIZE² = 4M block_sums = ~1B symbols. For BPE training
// at 256B-corpus scale this is plenty.
// (SCAN_CHUNK_SIZE / SCAN_PER_THREAD live in the shared preamble.)
// ════════════════════════════════════════════════════════════

// --- KERNEL: bpe_scan_pass1_local ---
//
// Workgroup-local exclusive scan of SCAN_CHUNK_SIZE consecutive elements.
// Bindings:
//   0: data       (rw) — block_sums input/output (in-place exclusive scan)
//   1: chunk_sums (rw) — per-WG chunk total, written by thread 0
//   2: state      (r)  — IterState (reads symbol_count + early_stop)
@group(0) @binding(0) var data1: array;
@group(0) @binding(1) var chunk_sums1: array;
@group(0) @binding(2) var state1: IterState;
var sh_scan1: array;

@compute @workgroup_size(WORKGROUP_SIZE)
fn bpe_scan_pass1_local(
    @builtin(local_invocation_id) lid: vec3,
    @builtin(workgroup_id) wgid: vec3,
    @builtin(num_workgroups) nwg: vec3
) {
    let stopped = state1.early_stop != 0u;
    let block_count = select(
        (state1.symbol_count + WORKGROUP_SIZE - 1u) / WORKGROUP_SIZE,
        0u,
        stopped
    );
    let wg_flat = wgid.x + wgid.y * nwg.x;
    let base = wg_flat * SCAN_CHUNK_SIZE;

    // Load — 8 elements per thread, strided
    for (var i: u32 = 0u; i < SCAN_PER_THREAD; i++) {
        let local_idx = lid.x + i * WORKGROUP_SIZE;
        let global_idx = base + local_idx;
        var v: u32 = 0u;
        if (global_idx < block_count) {
            v = data1[global_idx];
        }
        sh_scan1[local_idx] = v;
    }
    workgroupBarrier();

    // Up-sweep (reduce)
    for (var d: u32 = 1u; d < SCAN_CHUNK_SIZE; d <<= 1u) {
        let stride = d << 1u;
        var i: u32 = lid.x;
        while (i < SCAN_CHUNK_SIZE / stride) {
            let idx = (i + 1u) * stride - 1u;
            sh_scan1[idx] += sh_scan1[idx - d];
            i += WORKGROUP_SIZE;
        }
        workgroupBarrier();
    }

    // Save total + clear root
    var total: u32 = 0u;
    if (lid.x == 0u) {
        total = sh_scan1[SCAN_CHUNK_SIZE - 1u];
        sh_scan1[SCAN_CHUNK_SIZE - 1u] = 0u;
    }
    workgroupBarrier();

    // Down-sweep
    for (var d: u32 = SCAN_CHUNK_SIZE >> 1u; d > 0u; d >>= 1u) {
        let stride = d << 1u;
        var i: u32 = lid.x;
        while (i < SCAN_CHUNK_SIZE / stride) {
            let idx = (i + 1u) * stride - 1u;
            let t = sh_scan1[idx - d];
            sh_scan1[idx - d] = sh_scan1[idx];
            sh_scan1[idx] += t;
            i += WORKGROUP_SIZE;
        }
        workgroupBarrier();
    }

    // Write scanned values back
    for (var i: u32 = 0u; i < SCAN_PER_THREAD; i++) {
        let local_idx = lid.x + i * WORKGROUP_SIZE;
        let global_idx = base + local_idx;
        if (global_idx < block_count) {
            data1[global_idx] = sh_scan1[local_idx];
        }
    }

    // Write per-chunk total (thread 0 only)
    if (lid.x == 0u) {
        chunk_sums1[wg_flat] = total;
    }
}

// --- KERNEL: bpe_scan_pass2_chunks ---
//
// Single-WG exclusive scan of chunk_sums + stage new symbol_count and
// indirect dispatch (replaces the old _par kernel's update_count tail).
// Bindings:
//   0: chunk_sums (rw) — scanned in place
//   1: state      (rw) — symbol_count gets the new total
//   2: indirect   (rw) — [wgX, wgY, wgZ] for next iteration's pair_count
//
// Capacity: only chunk_sums[0 .. SCAN_CHUNK_SIZE) are loaded, scanned and
// written back; chunk_count is not clamped. Inputs with more than
// SCAN_CHUNK_SIZE chunks are therefore not supported (their extra chunks
// would be missing from `total`).
@group(0) @binding(0) var chunk_sums2: array;
@group(0) @binding(1) var state2: IterState;
@group(0) @binding(2) var indirect2: array;
var sh_scan2: array;

@compute @workgroup_size(WORKGROUP_SIZE) fn bpe_scan_pass2_chunks(@builtin(local_invocation_id) lid: vec3) {
    let stopped = state2.early_stop != 0u;
    // Block count is derived from the PRE-merge symbol_count; this pass only
    // overwrites symbol_count at the very end (thread 0).
    let block_count = select(
        (state2.symbol_count + WORKGROUP_SIZE - 1u) / WORKGROUP_SIZE,
        0u,
        stopped
    );
    let chunk_count = (block_count + SCAN_CHUNK_SIZE - 1u) / SCAN_CHUNK_SIZE;

    // Load chunk_sums (8/thread); slots past chunk_count are zero-filled.
    for (var i: u32 = 0u; i < SCAN_PER_THREAD; i++) {
        let idx = lid.x + i * WORKGROUP_SIZE;
        var v: u32 = 0u;
        if (idx < chunk_count) {
            v = chunk_sums2[idx];
        }
        sh_scan2[idx] = v;
    }
    workgroupBarrier();

    // Up-sweep
    for (var d: u32 = 1u; d < SCAN_CHUNK_SIZE; d <<= 1u) {
        let stride = d << 1u;
        var i: u32 = lid.x;
        while (i < SCAN_CHUNK_SIZE / stride) {
            let idx = (i + 1u) * stride - 1u;
            sh_scan2[idx] += sh_scan2[idx - d];
            i += WORKGROUP_SIZE;
        }
        workgroupBarrier();
    }

    // Grand total (= number of surviving symbols) is captured by thread 0
    // before the root is cleared for the exclusive down-sweep.
    var total: u32 = 0u;
    if (lid.x == 0u) {
        total = sh_scan2[SCAN_CHUNK_SIZE - 1u];
        sh_scan2[SCAN_CHUNK_SIZE - 1u] = 0u;
    }
    workgroupBarrier();

    // Down-sweep
    for (var d: u32 = SCAN_CHUNK_SIZE >> 1u; d > 0u; d >>= 1u) {
        let stride = d << 1u;
        var i: u32 = lid.x;
        while (i < SCAN_CHUNK_SIZE / stride) {
            let idx = (i + 1u) * stride - 1u;
            let t = sh_scan2[idx - d];
            sh_scan2[idx - d] = sh_scan2[idx];
            sh_scan2[idx] += t;
            i += WORKGROUP_SIZE;
        }
        workgroupBarrier();
    }

    // Write back
    for (var i: u32 = 0u; i < SCAN_PER_THREAD; i++) {
        let idx = lid.x + i * WORKGROUP_SIZE;
        if (idx < chunk_count) {
            chunk_sums2[idx] = sh_scan2[idx];
        }
    }

    // Stage new symbol_count + indirect (thread 0 only, skip on early_stop)
    //
    // NOTE(review): symbol_count is overwritten with the post-merge count
    // here. bpe_scan_pass3_apply and bpe_finalize_compact_b run after this
    // pass and both derive their ranges from `state.symbol_count`. pass1
    // scanned ceil(old_count / 256) blocks, but pass3 would then offset only
    // ceil(new_count / 256) blocks, and finalize would only scatter
    // positions < new_count. Confirm the host binds a pre-update snapshot of
    // IterState for those passes; otherwise trailing symbols are skipped.
    if (lid.x == 0u && !stopped) {
        // _pad1 mirrors the new count; its reader is not visible in this
        // file (presumably the host) — confirm.
        state2._pad1 = total;
        state2.symbol_count = total;
        // Indirect args: ceil(total / WORKGROUP_SIZE) workgroups, at least 1,
        // split into (X, Y) when X would exceed MAX_WG_DIM (X*Y may overshoot;
        // kernels bounds-check via flat_id).
        let total_wg = (total + WORKGROUP_SIZE - 1u) / WORKGROUP_SIZE;
        if (total_wg <= MAX_WG_DIM) {
            indirect2[0] = max(total_wg, 1u);
            indirect2[1] = 1u;
        } else {
            indirect2[0] = MAX_WG_DIM;
            indirect2[1] = (total_wg + MAX_WG_DIM - 1u) / MAX_WG_DIM;
        }
        indirect2[2] = 1u;
    }
}

// --- KERNEL: bpe_scan_pass3_apply ---
//
// Add per-chunk offsets back to data[]. Each thread handles one element:
//   data[i] += chunk_sums[i / SCAN_CHUNK_SIZE]
// Bindings:
//   0: data       (rw)
//   1: chunk_sums (r)
//   2: state      (r)
@group(0) @binding(0) var data3: array;
@group(0) @binding(1) var chunk_sums3: array;
@group(0) @binding(2) var state3: IterState;

@compute @workgroup_size(WORKGROUP_SIZE) fn bpe_scan_pass3_apply(
    @builtin(global_invocation_id) gid: vec3,
    @builtin(num_workgroups) nwg: vec3
) {
    if (state3.early_stop != 0u) { return; }
    // NOTE(review): this pass runs after bpe_scan_pass2_chunks, which has
    // already replaced state.symbol_count with the post-merge count, so this
    // range may be smaller than the one pass1 scanned — confirm the host
    // binding.
    let block_count = (state3.symbol_count + WORKGROUP_SIZE - 1u) / WORKGROUP_SIZE;
    let id = flat_id(gid, nwg);
    if (id >= block_count) { return; }
    let chunk_idx = id / SCAN_CHUNK_SIZE;
    data3[id] += chunk_sums3[chunk_idx];
}

// --- KERNEL: bpe_finalize_compact_b ---
//
// Fused prefix_sum_finalize + merge_apply + compact:
//   1. Read valid_mask once, extract both the validity bit (bit 0) and
//      the A-side merge intent (bit 1) into registers.
//   2. Subgroup-cooperative exclusive scan over validity bits
//      (subgroupExclusiveAdd plus a small cross-subgroup pass through
//      workgroup memory) → gives the in-block output offset.
//   3. Scatter to output_symbols. A-side merge positions emit the
//      newly-formed token; non-merge valid positions emit the original
//      symbol unchanged. B-side / invalid positions are dropped.
//
// This absorbs what was previously a separate bpe_merge_apply dispatch
// (whose only job was to overwrite symbols[A-side] in place before
// compaction). One fewer dispatch + one fewer global symbols read.
//
// Bindings:
//   0: valid_mask     (read) — bit 0 = valid, bit 1 = A-side merge intent
//   1: block_sums     (read) — per-block exclusive output offsets (block_sums
//                              after the three scan passes)
//   2: input_symbols  (read) — source symbol buffer (pre-merge)
//   3: output_symbols (rw)   — destination symbol buffer (post-merge, compacted)
//   4: state          (read) — IterState with symbol_count + new_symbol
@group(0) @binding(0) var valid_mask: array;
@group(0) @binding(1) var block_sums: array;
@group(0) @binding(2) var input_symbols: array;
@group(0) @binding(3) var output_symbols: array;
@group(0) @binding(4) var state: IterState;

// Cross-subgroup partial sums + exclusive offsets (8 entries for 8 subgroups).
var sh_sg_excl: array;

@compute @workgroup_size(256) fn bpe_finalize_compact_b(
    @builtin(global_invocation_id) gid: vec3,
    @builtin(local_invocation_id) lid: vec3,
    @builtin(workgroup_id) tgid: vec3,
    @builtin(num_workgroups) nwg: vec3
) {
    let fid = flat_id(gid, nwg);
    let block_idx = tgid.x + tgid.y * nwg.x;
    // NOTE(review): sg_id / lane are derived from lid.x with the compile-time
    // SUBGROUP_SIZE (32), not from the subgroup_size /
    // subgroup_invocation_id builtins. The scan below is only correct if the
    // device really runs 32-wide subgroups laid out as consecutive lid.x
    // ranges. Confirm the host checks the adapter's subgroup size before
    // selecting this kernel.
    let sg_id = lid.x / SUBGROUP_SIZE;   // 0..7
    let lane = lid.x % SUBGROUP_SIZE;    // 0..31

    // Load mask once: bit 0 = valid, bit 1 = A-side merge intent.
    // NOTE(review): state.symbol_count may already be the post-merge count
    // written by bpe_scan_pass2_chunks — confirm this pass sees the count the
    // mask was built for.
    var mask: u32 = 0u;
    if (fid < state.symbol_count) {
        mask = valid_mask[fid];
    }
    let v = mask & 1u;

    // ── Subgroup-cooperative exclusive scan over 256 elements ──
    //
    // Replaces the 8-stage Blelloch up-sweep + 8-stage down-sweep + 17
    // workgroupBarriers with 2 barriers + 3 subgroup ops. Logic:
    //   1. Each subgroup runs subgroupExclusiveAdd → in-subgroup offsets,
    //      and subgroupAdd → subgroup totals.
    //   2. Lane 0 of each subgroup writes its total to sh_sg_excl.
    //   3. ALL subgroups exclusive-scan the 8 totals (uniform CF;
    //      redundant-but-cheap work — saves the non-uniform-CF rejection
    //      that single-subgroup gating would trigger).
    //   4. Subgroup 0's first 8 lanes write the cross-subgroup offsets
    //      back to sh_sg_excl. There is no barrier between step 3's reads
    //      and these writes, so other subgroups may read already-overwritten
    //      entries; their cross_excl is discarded, so this does not affect
    //      the result. Subgroup 0's own reads feed the subgroupExclusiveAdd
    //      that produces the values it writes, so they come first.
    //   5. Each thread's final exclusive offset = sh_sg_excl[sg_id] +
    //      in-subgroup local exclusive sum.
    let local_excl = subgroupExclusiveAdd(v);
    let sg_total = subgroupAdd(v);
    if (lane == 0u) {
        sh_sg_excl[sg_id] = sg_total;
    }
    workgroupBarrier();

    var t: u32 = 0u;
    if (lane < NUM_SUBGROUPS) {
        t = sh_sg_excl[lane];
    }
    let cross_excl = subgroupExclusiveAdd(t);
    if (sg_id == 0u && lane < NUM_SUBGROUPS) {
        sh_sg_excl[lane] = cross_excl;
    }
    workgroupBarrier();

    let final_excl = sh_sg_excl[sg_id] + local_excl;

    // ── Scatter with on-the-fly merge ──
    // Valid positions land in the compacted output at
    // block offset + in-block exclusive offset. A-side merge positions emit
    // `new_symbol | flag`, where flag is the WORD_START_BIT carried over from
    // the original A-side symbol; non-merge positions copy the input symbol
    // verbatim. Branchless via select().
    if (fid < state.symbol_count && v == 1u) {
        let dest = block_sums[block_idx] + final_excl;
        let raw = input_symbols[fid];
        let merged = state.new_symbol | (raw & WORD_START_BIT);
        let is_a_merge = (mask & 2u) != 0u;
        output_symbols[dest] = select(raw, merged, is_a_merge);
    }
}

// ════════════════════════════════════════════════════════════
// INFERENCE KERNELS — trie tokenizer
// ════════════════════════════════════════════════════════════

// --- KERNEL: trie_tokenizer_chunked ---
//
// Chunked greedy longest-match tokenization using a pre-compiled binary trie.
//
// Optimization: cache the top of the trie in shared (workgroup) memory.
// Root (node 0) is visited on EVERY token match restart, so its edges are
// expanded into a 256-entry byte → child-node lookup table (root_lut), and
// the metadata of each depth-1 node (firstChild, numChildren, tokenId) is
// cached alongside it, indexed by the root byte. This removes global
// pointer-chasing for the two most frequently visited trie levels.
const MAX_CACHED_EDGES: u32 = 256u;
struct TrieParams { input_length: u32, chunk_size: u32, max_tokens_per_chunk: u32, _pad: u32 }
@group(0) @binding(0) var input: array;         // packed: 4 bytes per u32 (LE)
@group(0) @binding(1) var nodes: array;         // 3 x u32 per node: [firstChild, numChildren (low 16 bits), tokenId]
@group(0) @binding(2) var edges: array;         // 2 x u32 per edge: [label (low 8 bits), child node]
@group(0) @binding(3) var token_output: array;
@group(0) @binding(4) var chunk_counts: array;
@group(0) @binding(5) var params: TrieParams;

/// Return input byte `pos` from the packed buffer (4 bytes per u32,
/// little-endian: byte 0 in bits [0:7], byte 3 in bits [24:31]).
fn read_byte(pos: u32) -> u32 {
    let word = input[pos / 4u];
    let bit_offset = (pos % 4u) * 8u;
    return extractBits(word, bit_offset, 8u);
}

// Root lookup table in workgroup memory: root_lut[b] is the root's child
// reached by byte b, or INVALID_TOKEN if there is none. A single indexed read
// replaces a binary search over the root's edge list.
var root_lut: array;
var cached_root_fc: u32;  // root firstChild
var cached_root_nc: u32;  // root numChildren

// Depth-1 metadata cache, indexed by the root byte that leads to the
// depth-1 node. Every match starts at the root and then steps to depth 1,
// so caching that node's firstChild / numChildren / tokenId saves 3 global
// reads per token. Cost: 3KB shared memory (256 × 3 × 4B).
var d1_fc: array;   // depth-1 firstChild
var d1_nc: array;   // depth-1 numChildren
var d1_tid: array;  // depth-1 tokenId

/// Binary search (lower_bound) for the edge labelled `sym` among the `num`
/// edges starting at edge index `first`. Returns the child node index, or
/// INVALID_TOKEN if no edge carries that label. The edges of a node are
/// assumed sorted by label (required for lower_bound) — confirm in the trie builder.
/// The loop body uses select() instead of an if, so it has no
/// data-dependent branch. It runs at most ⌊log₂(num)⌋ + 1 iterations; the
/// exact count can still differ slightly between threads.
fn find_child_global(first: u32, num: u32, sym: u32) -> u32 {
    var lo: u32 = 0u;
    var n: u32 = num;
    while (n > 0u) {
        let half = n >> 1u;
        let mid = lo + half;
        let less = (edges[(first + mid) * 2u] & 0xFFu) < sym;
        lo = select(lo, mid + 1u, less);
        n = select(half, n - half - 1u, less);
    }
    if (lo < num) {
        let slot = (first + lo) * 2u;
        if ((edges[slot] & 0xFFu) == sym) {
            return edges[slot + 1u];
        }
    }
    return INVALID_TOKEN;
}

/// Greedy longest-match tokenization of one input chunk per thread.
/// Thread `id` tokenizes bytes [id*chunk_size, min(+chunk_size, input_length))
/// and writes up to max_tokens_per_chunk token IDs to
/// token_output[id*max_tokens_per_chunk ..]. Matches never cross the chunk
/// end. A byte with no trie match is emitted as its raw byte value. The
/// number of tokens written goes to chunk_counts[id].
@compute @workgroup_size(256)
fn trie_tokenizer_chunked(
    @builtin(global_invocation_id) gid: vec3,
    @builtin(local_invocation_id) lid: vec3
) {
    // ── Build shared-memory caches: 256 threads fill 256 slots ──
    // Step 1: Initialize all LUT + depth-1 cache slots.
    root_lut[lid.x] = INVALID_TOKEN;
    d1_fc[lid.x] = 0u;
    d1_nc[lid.x] = 0u;
    d1_tid[lid.x] = INVALID_TOKEN;
    if (lid.x == 0u) {
        cached_root_fc = nodes[0];               // root.firstChild
        cached_root_nc = nodes[1] & 0xFFFFu;     // root.numChildren
    }
    workgroupBarrier();

    // Step 2: Scatter valid root edges into the LUT + populate the depth-1 cache.
    let nc = min(cached_root_nc, MAX_CACHED_EDGES);
    let fc = cached_root_fc;
    if (lid.x < nc) {
        let sym = edges[(fc + lid.x) * 2u] & 0xFFu;
        let d1_node = edges[(fc + lid.x) * 2u + 1u];        // depth-1 node index
        root_lut[sym] = d1_node;                            // byte → node (O(1))
        d1_fc[sym] = nodes[d1_node * 3u];                   // cache firstChild
        d1_nc[sym] = nodes[d1_node * 3u + 1u] & 0xFFFFu;    // cache numChildren
        d1_tid[sym] = nodes[d1_node * 3u + 2u];             // cache tokenId
    }
    workgroupBarrier();

    // ── Main tokenization loop ──
    let id = gid.x;
    let cs = id * params.chunk_size;
    if (cs >= params.input_length) {
        // Surplus thread (no chunk). Zero its count slot only if that slot
        // exists: under WebGPU's robust buffer access rules an out-of-bounds
        // write is not guaranteed to be dropped and may land on a valid
        // element (e.g. the last real chunk's count).
        if (id < arrayLength(&chunk_counts)) {
            chunk_counts[id] = 0u;
        }
        return;
    }
    let ce = min(cs + params.chunk_size, params.input_length);
    let ob = id * params.max_tokens_per_chunk;
    var tw: u32 = 0u;
    var pos = cs;

    // Register cache for packed input — avoids re-reading the same u32
    // from global memory when sequential bytes fall within one word.
    var cached_word_idx: u32 = 0xFFFFFFFFu;
    var cached_word: u32 = 0u;

    while (pos < ce && tw < params.max_tokens_per_chunk) {
        var cn: u32 = 0u;              // current node (0 = root)
        var lmt: u32 = INVALID_TOKEN;  // longest-match token so far
        var lmp = pos;                 // input position just past that match
        var wp = pos;                  // walk position
        var depth: u32 = 0u;
        var rb: u32 = 0u;              // root byte — selects the depth-1 cache entry

        while (wp < ce) {
            // Read byte with register cache.
            let word_idx = wp >> 2u;
            if (word_idx != cached_word_idx) {
                cached_word = input[word_idx];
                cached_word_idx = word_idx;
            }
            let bv = extractBits(cached_word, (wp & 3u) * 8u, 8u);

            var nn: u32;
            if (cn == 0u) {
                // Depth 0 → 1: O(1) root LUT.
                nn = root_lut[bv];
                rb = bv;
            } else if (depth == 1u) {
                // Depth 1 → 2: shared-memory cached metadata.
                nn = find_child_global(d1_fc[rb], d1_nc[rb], bv);
            } else {
                // Depth 2+: global memory lookup.
                let nfc = nodes[cn * 3u];
                let nnc = nodes[cn * 3u + 1u] & 0xFFFFu;
                nn = find_child_global(nfc, nnc, bv);
            }
            if (nn == INVALID_TOKEN) { break; }
            cn = nn;
            wp++;
            depth++;

            // tokenId of the node just entered: cached at depth 1, global
            // otherwise. An if/else is used rather than select(), which
            // evaluates both operands and would issue the global read even
            // when the cached value is used.
            var ti: u32;
            if (depth == 1u) {
                ti = d1_tid[rb];
            } else {
                ti = nodes[cn * 3u + 2u];
            }
            if (ti != INVALID_TOKEN) {
                lmt = ti;
                lmp = wp;
            }
        }

        if (lmt != INVALID_TOKEN) {
            token_output[ob + tw] = lmt;
            tw++;
            pos = lmp;
        } else {
            // Fallback: emit the raw byte value as the token ID.
            token_output[ob + tw] = read_byte(pos);
            tw++;
            pos++;
        }
    }
    chunk_counts[id] = tw;
}

// --- KERNEL: trie_prefix_sum ---
//
// GPU-side exclusive prefix sum over chunk_counts.
// Eliminates the CPU roundtrip that previously required:
//   1. mapAsync readback of chunk_counts (numChunks × 4 bytes)
//   2. CPU prefix sum loop
//   3. Upload of chunk_offsets (numChunks × 4 bytes)
//
// Single-thread sequential scan is sufficient because:
//   - ~100K chunks for 50MB input = ~0.1ms on GPU (author's estimate)
//   - The real win is removing 2 PCIe transfers + 1 GPU fence
//
// Outputs total_tokens[0] so the host only reads back 4 bytes
// to allocate the correctly-sized compact buffer.
// Uniform parameters for trie_prefix_sum: number of chunk_counts entries to scan.
struct PrefixSumTrieParams {
    num_chunks: u32,
    _pad: u32
}

@group(0) @binding(0) var chunk_counts: array;
@group(0) @binding(1) var chunk_offsets: array;
@group(0) @binding(2) var total_tokens: array;
@group(0) @binding(3) var params: PrefixSumTrieParams;

/// Exclusive scan: chunk_offsets[i] = sum of chunk_counts[0..i) for
/// i < params.num_chunks; total_tokens[0] receives the grand total.
/// Only invocation 0 does any work; all others return immediately.
@compute @workgroup_size(1)
fn trie_prefix_sum(@builtin(global_invocation_id) gid: vec3) {
    if (gid.x != 0u) {
        return;
    }
    var sum: u32 = 0u;
    for (var i: u32 = 0u; i < params.num_chunks; i++) {
        chunk_offsets[i] = sum;   // exclusive: written before adding chunk i
        sum += chunk_counts[i];
    }
    total_tokens[0] = sum;
}

// --- KERNEL: trie_tokenizer_compact ---
//
// Cooperative compaction: 1 workgroup (256 threads) = 1 chunk.
// All threads in a workgroup write to consecutive addresses → coalesced
// memory access. Previous design (1 thread = 1 chunk) caused stride-512
// writes across the workgroup, destroying memory bandwidth.

// Uniform parameters for trie_tokenizer_compact: source stride of each chunk
// in chunked_tokens.
struct CompactTrieParams {
    max_tokens_per_chunk: u32,
    _pad: u32
}

@group(0) @binding(0) var chunked_tokens: array;
@group(0) @binding(1) var chunk_counts: array;
@group(0) @binding(2) var chunk_offsets: array;
@group(0) @binding(3) var compact_output: array;
@group(0) @binding(4) var params: CompactTrieParams;

/// Copy chunk `chunk_id`'s tokens from
/// chunked_tokens[chunk_id * max_tokens_per_chunk ..] to
/// compact_output[chunk_offsets[chunk_id] ..], chunk_counts[chunk_id] tokens.
@compute @workgroup_size(WORKGROUP_SIZE)
fn trie_tokenizer_compact(
    @builtin(workgroup_id) wid: vec3,
    @builtin(local_invocation_id) lid: vec3,
    @builtin(num_workgroups) nwg: vec3
) {
    // Linearize 2D workgroup grid (needed when numChunks > 65535).
    let chunk_id = wid.x + wid.y * nwg.x;

    // An X×Y split of the grid can launch more workgroups than there are
    // chunks. Reading chunk_counts / chunk_offsets past their end is an
    // invalid memory reference in WGSL (the load may return any value in the
    // buffer), which could make a surplus workgroup copy data over a real
    // chunk's output. Any valid chunk index is in bounds, so this never
    // skips real work.
    if (chunk_id >= min(arrayLength(&chunk_counts), arrayLength(&chunk_offsets))) {
        return;
    }

    let cnt = chunk_counts[chunk_id];
    if (cnt == 0u) {
        return;
    }
    let sb = chunk_id * params.max_tokens_per_chunk;
    let db = chunk_offsets[chunk_id];

    // Consecutive lid.x → consecutive addresses (coalesced). The stride
    // matches the workgroup size declared above.
    for (var i = lid.x; i < cnt; i += WORKGROUP_SIZE) {
        compact_output[db + i] = chunked_tokens[sb + i];
    }
}