bpe — GPU BPE Tokenizer Pipeline (Training + Inference)
File: bpe.wgsl
Pipeline step: GPU-side pre-tokenization, training steps (pair counting, reduction, scan, compaction) and Trie-based inference.
Feature: Subgroup cooperative scan, O(1) Root LUT caching, and workgroup-local shared hash tables that minimize global atomic contention.
Two large pipelines are gathered into a single file:
- Training Pipeline (9 Dispatches): clear_table→pair_count→find_max4→find_max_final_det→setup_merge→merge_reduce→scan_pass1→scan_pass2→scan_pass3→finalize_compact
- Inference Pipeline (3 Dispatches): trie_tokenizer_chunked→trie_prefix_sum→trie_tokenizer_compact
Wait, what is this?
Say you want to shrink a piece of text without losing a single character. You're in an editor typing the same little phrase over and over: "and so on", "and so on"... At some point you go "alright, let me invent a shorthand for this, I'll just write aso and move on". That's exactly what BPE (Byte Pair Encoding) does: it finds the most frequent adjacent pair in the text and hands it one new symbol. Then it looks again, finds the new most-frequent pair, gives that a symbol too. The loop spins thousands of times and you end up with a "vocabulary" — common stuff becomes one piece, rare stuff stays letter by letter.
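As a rough CPU-side illustration of that loop (not the GPU implementation), the sketch below counts adjacent pairs, merges the most frequent one, and repeats. The function names `mostFrequentPair`, `mergePair` and `trainBpe` are hypothetical, and the tie-breaking rule is an assumption made for this example.

```javascript
// Toy CPU version of the BPE training loop (illustrative only).
function mostFrequentPair(symbols) {
  const counts = new Map();
  let best = null;
  let bestCount = 0;
  for (let i = 0; i + 1 < symbols.length; i++) {
    // "\u0000" separates the two sides so multi-character symbols stay unambiguous.
    const key = symbols[i] + "\u0000" + symbols[i + 1];
    const c = (counts.get(key) || 0) + 1;
    counts.set(key, c);
    // Assumed tie-break: the pair that first reaches the highest count wins.
    if (c > bestCount) {
      bestCount = c;
      best = [symbols[i], symbols[i + 1]];
    }
  }
  return { pair: best, count: bestCount };
}

function mergePair(symbols, a, b) {
  const out = [];
  for (let i = 0; i < symbols.length; i++) {
    if (i + 1 < symbols.length && symbols[i] === a && symbols[i + 1] === b) {
      out.push(a + b); // one new symbol replaces the pair
      i++;             // skip the B side of the merged pair
    } else {
      out.push(symbols[i]);
    }
  }
  return out;
}

function trainBpe(text, numMerges) {
  let symbols = Array.from(text);
  const merges = [];
  for (let m = 0; m < numMerges; m++) {
    const { pair, count } = mostFrequentPair(symbols);
    if (pair === null || count < 2) break; // nothing worth merging
    merges.push(pair);
    symbols = mergePair(symbols, pair[0], pair[1]);
  }
  return { symbols, merges };
}

console.log(trainBpe("abababab", 2).symbols); // [ 'abab', 'abab' ]
```

Joining the resulting symbols always reproduces the input text, which is the "without losing a single character" property described above.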
The training side is that loop running on the GPU. But there's a catch: "find the most frequent pair" means counting millions of character pairs. One thread ticking them off one by one would take forever. So instead millions of threads count at once — but when they all try to write into the same tally box (the global hash table) they trip over each other (contention). The fix is to have everyone keep a little tally at their own desk first (a workgroup-local hash table), then hand only the totals over to the shared ledger in one move.
The inference side is the reverse: the vocabulary is ready now, a fresh piece of text shows up, and you want to split it into as few pieces as possible. The trick here is to lay the vocabulary out as a Trie (prefix tree), the same shape a phone directory or autocomplete runs on. At each character you step one level down the tree, hunting for the "longest matching word" (greedy longest-match): "c", "ca", "cat"... however far you can walk, that's your token.
Everything else is the hardware dance around those two ideas: subgroup scans that cut barriers from 17 down to 2, tricks that cache the tree's root in shared memory so you never wait on global memory, branchless searches that keep threads in lockstep and dodge divergence. So "what's being counted / what's being looked up" is simple; the real work is doing it on the GPU without drowning.
What Does It Do?
1. GPU Pre-tokenization (bpe_word_boundary)
During BPE training, to prevent merges from forming across word boundaries (for example "yakınlık▁ve" → a single token), Turkish-compatible GPT-4 style pre-tokenization is applied.
The character class (Letter, Digit, Space, Punct, Newline) of each symbol is determined. At each class transition, the symbol that starts the new run is tagged by setting bit 16, WORD_START_BIT (0x10000). When the BPE pair counting kernel sees this bit, it does not count the pair that spans that boundary.
2. Two-Level Pair Counting (bpe_pair_count_b)
If millions of threads perform atomic insertions (atomicAdd) directly into the global hash table, contention on the same slots serializes the atomics and performance drops sharply.
To prevent this, a two-level count is performed:
- Threads accumulate data locally into a 1024-slot local shared hash table (`local_ids` and `local_counts`) within their own workgroup, using Murmur3 hashing + quadratic probing.
- Once the local reduction finishes, only the workgroup totals are transferred to the global hash table in a single move. This minimizes global atomic collisions.
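The two levels can be sketched on the CPU as follows. This is a single-threaded illustration under stated assumptions, not the kernel itself: the `EMPTY` sentinel, the triangular form of quadratic probing, the use of the Murmur3 finalizer (`fmix32`) as the hash, and the helper names `tableAdd` and `countWorkgroup` are all assumptions. Pair IDs are treated as already-packed `u32` values.

```javascript
const LOCAL_TABLE_SIZE = 1024;                 // slot count stated in the text
const LOCAL_TABLE_MASK = LOCAL_TABLE_SIZE - 1; // power-of-two mask
const EMPTY = 0xffffffff;                      // assumed "empty slot" sentinel

// Murmur3 32-bit finalizer (fmix32).
function fmix32(h) {
  h ^= h >>> 16;
  h = Math.imul(h, 0x85ebca6b);
  h ^= h >>> 13;
  h = Math.imul(h, 0xc2b2ae35);
  h ^= h >>> 16;
  return h >>> 0;
}

// Insert-or-increment with triangular (quadratic) probing: offsets 0, 1, 3, 6, ...
// On a power-of-two table this sequence visits every slot exactly once.
function tableAdd(ids, counts, mask, pairId, amount) {
  let slot = fmix32(pairId) & mask;
  for (let i = 1; i <= mask + 1; i++) {
    if (ids[slot] === EMPTY) ids[slot] = pairId; // claim a free slot
    if (ids[slot] === pairId) {
      counts[slot] += amount;
      return true;
    }
    slot = (slot + i) & mask;
  }
  return false; // table full; a real kernel needs an overflow path
}

// One "workgroup": tally locally, then flush each distinct pair once.
function countWorkgroup(pairIds, globalIds, globalCounts) {
  const localIds = new Uint32Array(LOCAL_TABLE_SIZE).fill(EMPTY);
  const localCounts = new Uint32Array(LOCAL_TABLE_SIZE);
  for (const id of pairIds) {
    tableAdd(localIds, localCounts, LOCAL_TABLE_MASK, id, 1);
  }
  let globalUpdates = 0;
  for (let s = 0; s < LOCAL_TABLE_SIZE; s++) {
    if (localIds[s] === EMPTY) continue;
    // globalIds.length is assumed to be a power of two.
    tableAdd(globalIds, globalCounts, globalIds.length - 1, localIds[s], localCounts[s]);
    globalUpdates++;
  }
  return globalUpdates; // number of global-table updates actually issued
}
```

With the input `[7, 7, 7, 9, 7, 9]`, six pair occurrences turn into two global updates (one per distinct pair), which is where the reduction in global atomic traffic comes from.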
3. Subgroup Cooperative Compaction (bpe_finalize_compact_b)
For a 256-element workgroup, a traditional Blelloch exclusive scan uses an 8-stage up-sweep, an 8-stage down-sweep and 17 workgroupBarrier calls.
bpe_finalize_compact_b replaces this sequence with a subgroup-cooperative scan:
- In-subgroup scans (`subgroupExclusiveAdd` and `subgroupAdd`) run at the hardware level (barrier-free).
- Only the subgroup totals are written to shared memory (`sh_sg_excl`) and scanned at the workgroup level.
- The 256-element exclusive scan is completed with only 2 barriers (instead of 17) and 4 subgroup instructions.
- With branchless `select()` and `WORD_START_BIT` carries, symbols are scattered to the destination in a single dispatch.
4. Trie-Based Fast Inference (trie_tokenizer_chunked)
To hide global-memory latency during trie lookups (a binary search over each node's sorted edges), the following optimizations are applied:
- Root LUT caching: The Trie's root node (Node 0) is queried at the start of every token match. For the byte-level trie, the root node's 256 children are cached in shared memory as `root_lut`. The first-level binary search is skipped entirely, giving O(1) direct access.
- Depth-1 Cache: The `firstChild`, `numChildren` and `tokenId` values of the nodes directly below the root (depth 1) are also cached in shared memory, eliminating 3 global memory reads per token.
- Branchless Search: In non-root node searches, `select` is used to prevent warp divergence (SIMD mask split). All threads run in lockstep (uniform execution).
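The branchless search idea can be illustrated with a lower-bound routine whose loop count depends only on the number of children, never on the comparison results. This is a JavaScript sketch of the general technique, not the shader's code: `select` here mimics WGSL's `select(f, t, cond)`, and `lowerBoundUniform` and `findChildIndex` are hypothetical names.

```javascript
// select(f, t, cond) as in WGSL: returns t when cond is true, otherwise f.
const select = (f, t, cond) => (cond ? t : f);

// Index of the first element >= key in sortedBytes[0..n).
// Every caller with the same n executes the same number of iterations.
function lowerBoundUniform(sortedBytes, n, key) {
  if (n === 0) return 0;
  let base = 0;
  let len = n;
  while (len > 1) {
    const half = len >>> 1;
    // Move the window start forward instead of branching on the comparison.
    base = select(base, base + half, sortedBytes[base + half - 1] < key);
    len -= half;
  }
  return base + select(0, 1, sortedBytes[base] < key);
}

// Returns the position of `key` among the sorted children, or -1 if absent.
function findChildIndex(sortedBytes, n, key) {
  const i = lowerBoundUniform(sortedBytes, n, key);
  return i < n && sortedBytes[i] === key ? i : -1;
}

const children = [0x61, 0x63, 0x65, 0x67]; // 'a', 'c', 'e', 'g'
console.log(findChildIndex(children, 4, 0x63)); // 1
console.log(findChildIndex(children, 4, 0x64)); // -1
```

On a GPU the conditional expression inside `select` compiles to a data select rather than a jump, so threads searching for different keys in nodes of the same size follow the same instruction stream.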
Bind Group ABI
bpe_pair_count_b (4 bindings)
Counts batched pairs using the local shared hash table.
| Binding | Type | Detail |
|---|---|---|
| 0 | storage, read | symbols: array<u32> — Input symbol array |
| 1 | storage, read_write | pair_counts: array<atomic<u32>> — Global pair counts |
| 2 | storage, read_write | pair_ids: array<atomic<u32>> — Global pair IDs (packed pair) |
| 3 | storage, read | state: IterState — symbol_count, table_size, etc. iteration state |
bpe_merge_reduce_b (4 bindings — Fused Merge Mark & Reduce)
Masks the best pair and sums valid symbol totals within the same workgroup via reduction.
| Binding | Type | Detail |
|---|---|---|
| 0 | storage, read_write | symbols: array<u32> — Writes A-side merge marks |
| 1 | storage, read_write | valid_mask: array<u32> — bit 0 = valid, bit 1 = merge intent |
| 2 | storage, read_write | block_sums: array<u32> — Valid symbol total of each block |
| 3 | storage, read | state: IterState — symbol_a, symbol_b, new_symbol |
bpe_finalize_compact_b (5 bindings — Fused Scan + Merge Apply + Scatter)
Resolves offsets with a subgroup exclusive scan and scatters symbols to the destination by compacting them.
| Binding | Type | Detail |
|---|---|---|
| 0 | storage, read | valid_mask: array<u32> — bit 0 = valid, bit 1 = merge intent |
| 1 | storage, read | block_sums: array<u32> — exclusive scanned block offsets |
| 2 | storage, read | input_symbols: array<u32> — Source symbol buffer |
| 3 | storage, read_write | output_symbols: array<u32> — Destination compacted symbol buffer |
| 4 | storage, read | state: IterState |
trie_tokenizer_chunked (6 bindings — Trie-based Inference)
Tokenizes the character sequence as greedy longest-match using a pre-compiled binary trie.
| Binding | Type | Detail |
|---|---|---|
| 0 | storage, read | input: array<u32> — packed UTF-8 bytes (4 bytes per u32) |
| 1 | storage, read | nodes: array<u32> — Trie node info (firstChild, numChildren, tokenId) |
| 2 | storage, read | edges: array<u32> — Trie edge info (byte, childNodeIndex) |
| 3 | storage, read_write | token_output: array<u32> — Output token IDs |
| 4 | storage, read_write | chunk_counts: array<u32> — Number of tokens each chunk produced |
| 5 | uniform | params: TrieParams — input_length, chunk_size, max_tokens_per_chunk |
Line-by-Line Analysis
1) bpe_word_boundary — Byte Classification
```wgsl
// Classes: 0 = Letter, 1 = Digit, 2 = Space, 3 = Punct, 4 = Newline.
fn char_class(tok: u32) -> u32 {
    if (tok == 0x0Au) { return 4u; }          // '\n'
    if (tok == 0x20u) { return 2u; }          // ' '
    if (tok - 0x30u <= 9u) { return 1u; }     // '0'..'9'
    if (tok >= 0x80u) { return 0u; }          // UTF-8 lead and continuation bytes
    if (tok - 0x61u <= 25u) { return 0u; }    // 'a'..'z'
    if (tok - 0x41u <= 25u) { return 0u; }    // 'A'..'Z'
    return 3u;                                // everything else
}
```
- Branch-friendly subtraction trick: The `(tok - base) <= range` comparison tests a whole ASCII range with a single unsigned comparison. When `tok` is below `base`, the u32 subtraction wraps around to a very large value and the test fails. By reducing SIMD mask divergence, it keeps all of the hardware's threads in lockstep.
- All bytes at or above `0x80u` (UTF-8 lead and continuation bytes) are marked directly as class 0 (Letter). This way Turkish characters (ı, ş, ğ, ç, ö, ü, etc.) are automatically placed into the letter class.
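A JavaScript port of the classifier, together with a simplified version of the boundary tagging, might look like the sketch below. The tagging rule used here (a symbol starts a word when its class differs from the previous symbol's class, and the first symbol always starts one) is an assumption for illustration, and `tagWordStarts` is a hypothetical name; the actual kernel may apply additional GPT-4 style rules.

```javascript
const WORD_START_BIT = 0x10000; // bit 16

// Port of char_class: 0 = Letter, 1 = Digit, 2 = Space, 3 = Punct, 4 = Newline.
function charClass(tok) {
  if (tok === 0x0a) return 4;
  if (tok === 0x20) return 2;
  // ">>> 0" reproduces u32 wraparound, so values below the base fail the test.
  if (((tok - 0x30) >>> 0) <= 9) return 1;
  if (tok >= 0x80) return 0;
  if (((tok - 0x61) >>> 0) <= 25) return 0;
  if (((tok - 0x41) >>> 0) <= 25) return 0;
  return 3;
}

// Assumed rule: tag the first symbol of every run of equal class.
function tagWordStarts(bytes) {
  const out = new Uint32Array(bytes.length);
  for (let i = 0; i < bytes.length; i++) {
    const isStart = i === 0 || charClass(bytes[i]) !== charClass(bytes[i - 1]);
    out[i] = isStart ? (bytes[i] | WORD_START_BIT) >>> 0 : bytes[i];
  }
  return out;
}

// "ab 12" -> a* b ' '* 1* 2   (* marks WORD_START_BIT)
const tagged = tagWordStarts([0x61, 0x62, 0x20, 0x31, 0x32]);
console.log(Array.from(tagged, (s) => s.toString(16)));
// [ '10061', '62', '10020', '10031', '32' ]
```

Without the `>>> 0` conversion, JavaScript's signed subtraction would classify `!` (0x21) as a digit, which shows why the trick depends on unsigned wraparound.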
2) bpe_finalize_compact_b — Subgroup Cooperative Scan
```wgsl
let local_excl = subgroupExclusiveAdd(v);       // exclusive scan inside the subgroup
let sg_total = subgroupAdd(v);                  // subgroup-wide total
if (lane == 0u) { sh_sg_excl[sg_id] = sg_total; }
workgroupBarrier();                             // barrier 1: totals are visible
var t: u32 = 0u;
if (lane < NUM_SUBGROUPS) { t = sh_sg_excl[lane]; }
let cross_excl = subgroupExclusiveAdd(t);       // scan over the subgroup totals
if (sg_id == 0u && lane < NUM_SUBGROUPS) { sh_sg_excl[lane] = cross_excl; }
workgroupBarrier();                             // barrier 2: base offsets are visible
let final_excl = sh_sg_excl[sg_id] + local_excl;
```
- Instead of a traditional 256-element Blelloch exclusive scan, hardware-level subgroup intrinsics are used:
  - With `subgroupExclusiveAdd(v)`, the in-subgroup exclusive sum is computed without a barrier.
  - Each subgroup's total (`sg_total`) is written to shared memory (`sh_sg_excl`).
  - All threads run one more exclusive subgroup scan over these 8 subgroup totals in shared memory (`cross_excl`); only subgroup 0 writes the result back.
  - Each thread then adds its own subgroup's base offset to its local offset.
- Cost: 2 barriers instead of 17, plus a few hardware subgroup operations, so most of the synchronization overhead of the scan disappears.
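The three phases can be checked against a plain serial scan with a small CPU emulation. This is a sketch under assumptions: a subgroup width of 32 is assumed (which gives the 8 subgroups mentioned above for 256 threads), and `subgroupExclusiveAdd` below is a sequential stand-in for the hardware intrinsic.

```javascript
const WG_SIZE = 256;
const SG_SIZE = 32; // assumed subgroup width: 256 / 32 = 8 subgroups
const NUM_SUBGROUPS = WG_SIZE / SG_SIZE;

// Sequential stand-in for subgroupExclusiveAdd over one subgroup's lanes.
// `total` plays the role of subgroupAdd.
function subgroupExclusiveAdd(values) {
  const excl = new Array(values.length);
  let running = 0;
  for (let lane = 0; lane < values.length; lane++) {
    excl[lane] = running;
    running += values[lane];
  }
  return { excl, total: running };
}

function workgroupExclusiveScan(v) {
  const localExcl = new Array(WG_SIZE);
  const shSgExcl = new Array(NUM_SUBGROUPS);

  // Phase 1: per-subgroup scan; "lane 0" publishes the subgroup total.
  for (let sg = 0; sg < NUM_SUBGROUPS; sg++) {
    const base = sg * SG_SIZE;
    const { excl, total } = subgroupExclusiveAdd(v.slice(base, base + SG_SIZE));
    for (let lane = 0; lane < SG_SIZE; lane++) localExcl[base + lane] = excl[lane];
    shSgExcl[sg] = total;
  }
  // (workgroupBarrier #1 would go here)

  // Phase 2: exclusive scan over the subgroup totals, written back in place.
  const cross = subgroupExclusiveAdd(shSgExcl).excl;
  for (let sg = 0; sg < NUM_SUBGROUPS; sg++) shSgExcl[sg] = cross[sg];
  // (workgroupBarrier #2 would go here)

  // Phase 3: subgroup base offset + offset inside the subgroup.
  return v.map((_, tid) => shSgExcl[Math.floor(tid / SG_SIZE)] + localExcl[tid]);
}

// Example: a validity mask where every third symbol was merged away.
const valid = Array.from({ length: WG_SIZE }, (_, i) => (i % 3 === 0 ? 0 : 1));
const offsets = workgroupExclusiveScan(valid);
console.log(offsets.slice(0, 6)); // [ 0, 0, 1, 2, 2, 3 ]
```

Each entry of `offsets` is the compacted destination index for that thread's symbol, which is what the scatter step consumes.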
3) trie_tokenizer_chunked — Root LUT and Depth-1 Cache
```wgsl
let nc = min(cached_root_nc, MAX_CACHED_EDGES);
let fc = cached_root_fc;
if (lid.x < nc) {
    let sym = edges[(fc + lid.x) * 2u] & 0xFFu;       // edge byte
    let d1_node = edges[(fc + lid.x) * 2u + 1u];      // depth-1 node index
    root_lut[sym] = d1_node;
    d1_fc[sym] = nodes[d1_node * 3u];                 // firstChild
    d1_nc[sym] = nodes[d1_node * 3u + 1u] & 0xFFFFu;  // numChildren
    d1_tid[sym] = nodes[d1_node * 3u + 2u];           // tokenId
}
workgroupBarrier();
```
- Each token match starts from the Trie root node (Node 0).
- All of the root's branches (`root_lut`) and the metadata of the depth-1 nodes (`d1_fc`, `d1_nc`, `d1_tid`) are cached in shared memory.
- In the search loop, when `cn == 0u` (the transition from depth 0 to 1), the lookup becomes O(1): `nn = root_lut[bv];`
- In the transition from depth 1 to 2, the node metadata is read from the shared cache instead of chasing pointers through global memory: `nn = find_child_global(d1_fc[rb], d1_nc[rb], bv);`

These two steps remove the global node reads at the two most frequently visited search levels.
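To make the lookup path concrete, here is a CPU sketch of greedy longest-match over the flat layout implied by the indexing above (3 `u32` values per node, 2 per edge). The sentinels `NO_TOKEN` and `NO_CHILD`, the helper names, and the handling of unknown bytes are assumptions; `findChild` is a plain binary search standing in for `find_child_global`.

```javascript
const NO_TOKEN = 0xffffffff; // assumed sentinel: node is not a token
const NO_CHILD = 0xffffffff; // assumed sentinel: no such edge

// Binary search over a node's sorted edge range.
function findChild(edges, fc, nc, byte) {
  let lo = 0;
  let hi = nc;
  while (lo < hi) {
    const mid = (lo + hi) >>> 1;
    if ((edges[(fc + mid) * 2] & 0xff) < byte) lo = mid + 1;
    else hi = mid;
  }
  const hit = lo < nc && (edges[(fc + lo) * 2] & 0xff) === byte;
  return hit ? edges[(fc + lo) * 2 + 1] : NO_CHILD;
}

// Mirror of the caching prologue: one table entry per root edge.
function buildRootCache(nodes, edges) {
  const rootLut = new Uint32Array(256).fill(NO_CHILD);
  const d1Fc = new Uint32Array(256);
  const d1Nc = new Uint32Array(256);
  const d1Tid = new Uint32Array(256).fill(NO_TOKEN);
  const fc = nodes[0];
  const nc = nodes[1] & 0xffff;
  for (let i = 0; i < nc; i++) {
    const sym = edges[(fc + i) * 2] & 0xff;
    const d1 = edges[(fc + i) * 2 + 1];
    rootLut[sym] = d1;
    d1Fc[sym] = nodes[d1 * 3];
    d1Nc[sym] = nodes[d1 * 3 + 1] & 0xffff;
    d1Tid[sym] = nodes[d1 * 3 + 2];
  }
  return { rootLut, d1Fc, d1Nc, d1Tid };
}

// Greedy longest match starting at `start`; returns [tokenId, matchedLength].
function longestMatch(bytes, start, nodes, edges, cache) {
  const rb = bytes[start];
  if (cache.rootLut[rb] === NO_CHILD) return [NO_TOKEN, 0]; // depth 0 -> 1 in O(1)
  let bestTid = cache.d1Tid[rb];
  let bestLen = bestTid === NO_TOKEN ? 0 : 1;
  let fc = cache.d1Fc[rb]; // depth-1 metadata comes from the cache
  let nc = cache.d1Nc[rb];
  for (let i = start + 1; i < bytes.length; i++) {
    const node = findChild(edges, fc, nc, bytes[i]);
    if (node === NO_CHILD) break;
    fc = nodes[node * 3];
    nc = nodes[node * 3 + 1] & 0xffff;
    const tid = nodes[node * 3 + 2];
    if (tid !== NO_TOKEN) { bestTid = tid; bestLen = i - start + 1; }
  }
  return [bestTid, bestLen];
}

function tokenize(bytes, nodes, edges) {
  const cache = buildRootCache(nodes, edges);
  const out = [];
  for (let pos = 0; pos < bytes.length; ) {
    const [tid, len] = longestMatch(bytes, pos, nodes, edges, cache);
    if (len === 0) { pos++; continue; } // unknown byte: skipped in this sketch
    out.push(tid);
    pos += len;
  }
  return out;
}

// Hand-built trie for the vocabulary { "a": 0, "b": 1, "ab": 2 }.
const nodes = Uint32Array.from([0, 2, NO_TOKEN, 2, 1, 0, 3, 0, 1, 3, 0, 2]);
const edges = Uint32Array.from([0x61, 1, 0x62, 2, 0x62, 3]);
console.log(tokenize([0x61, 0x62, 0x62], nodes, edges)); // [ 2, 1 ]  ("ab", "b")
```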
Code Review
Finding 1: bpe_pair_count_b Local Table Sizing and Occupancy Relationship
| Risk | Description |
|---|---|
| 🟡 perf | LOCAL_TABLE_SIZE is set to 1024 slots (8 KB of shared memory). Raising it to 2048 slots (16 KB) halved occupancy on Apple M GPUs (4 WG/SM → 2 WG/SM), and the resulting loss of latency hiding slowed the step time by 70%. The current balance of 1024 slots (25% fill + Murmur3 + quadratic probing) gives the best occupancy profile of the sizes tried. |
Finding 2: trie_tokenizer_compact Coalesced Memory Writes
| Risk | Description |
|---|---|
| 🟢 Resolved | The old design's 1 thread = 1 chunk mapping made consecutive threads write with stride-512 jumps over memory, which broke coalescing. The trie_tokenizer_compact kernel uses a 1 workgroup = 1 chunk mapping instead: 256 threads cooperatively copy to consecutive write indices (compact_output[db + i]), so the writes coalesce. |
Finding 3: Eliminating the PCIe Bottleneck (trie_prefix_sum)
| Risk | Description |
|---|---|
| 🟢 Resolved | In the old structure, the prefix sum needed to compact the tokenizer outputs was read back to the CPU via mapAsync, scanned there and uploaded again. trie_prefix_sum now runs on the GPU as a single-thread (workgroup_size(1)) kernel. This removes the CPU roundtrip, eliminating 2 PCIe transfer latencies and 1 GPU fence wait. |
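The data flow behind Findings 2 and 3 can be sketched on the CPU as an exclusive scan over the per-chunk counts followed by a gather into a dense output. One assumption is made about layout: each chunk is taken to own a fixed-stride region of `maxTokensPerChunk` entries in `token_output`. The function name `prefixSumAndCompact` is hypothetical.

```javascript
function prefixSumAndCompact(tokenOutput, chunkCounts, maxTokensPerChunk) {
  // Stand-in for trie_prefix_sum: a serial exclusive scan, as a
  // workgroup_size(1) kernel would perform it.
  const offsets = new Uint32Array(chunkCounts.length);
  let total = 0;
  for (let c = 0; c < chunkCounts.length; c++) {
    offsets[c] = total;
    total += chunkCounts[c];
  }

  // Stand-in for trie_tokenizer_compact: each chunk copies its tokens to
  // consecutive destination indices starting at its offset.
  const compact = new Uint32Array(total);
  for (let c = 0; c < chunkCounts.length; c++) {
    const src = c * maxTokensPerChunk; // assumed fixed-stride chunk region
    const db = offsets[c];             // destination base for this chunk
    for (let i = 0; i < chunkCounts[c]; i++) {
      compact[db + i] = tokenOutput[src + i];
    }
  }
  return { offsets, total, compact };
}

const tokenOutput = [10, 11, 0, 0, 20, 0, 0, 0, 30, 31, 32, 0];
const result = prefixSumAndCompact(tokenOutput, [2, 1, 3], 4);
console.log(Array.from(result.offsets)); // [ 0, 2, 3 ]
console.log(Array.from(result.compact)); // [ 10, 11, 20, 30, 31, 32 ]
```

In the kernel version, the inner copy loop is what the 256 threads of a workgroup share, so neighbouring threads write neighbouring `compact` indices.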
Quick Checklist
| Test Scenario | Status |
|---|---|
| Are word boundaries scanned correctly? | ✅ `symbols[id] \| WORD_START_BIT` |
| What if subgroup operations are not supported by the hardware? | ⚠️ Requires a training fallback mechanism on Safari/WebKit |
| Is warp divergence prevented in the trie search? | ✅ select()-based branchless lower_bound |
| Is the local hash table modulo power-of-2? | ✅ LOCAL_TABLE_MASK = 1023 |
| Can the Blelloch scan scale up to 1B symbols? | ✅ Hierarchical 3-pass scan |
| Is the root node LUT cache safe? | ✅ 256-entry byte LUT |
Host Architecture and Coordination (trie.js & tokenizer.js)
The speed of the BPE tokenizer comes not only from the GPU shaders, but also from the memory management and asynchronous execution architecture on the host (JS) side.
1. BFS-Based Cache-Friendly Trie Compilation (compileVocabToTrie)
After the in-memory trie is built from the raw word list (byte sequences), it is serialized as a flat binary (v3 format).
- BFS Ordering (Breadth-First Search): The tree is flattened in BFS order, so a node's children are placed consecutively in memory. Accesses to these consecutive indices improve the hit rate of the GPU's L1/L2 caches.
- Sorted Children: Child edges are sorted by symbol byte value in ascending order. This ordering lets the `find_child_global` helper function on the GPU side perform a uniform binary search in non-root node searches.
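A minimal sketch of such a flattening step is shown below. It is not the actual `compileVocabToTrie` and does not reproduce the v3 binary format; it only follows the array layout implied by the shader indexing (3 `u32` per node: firstChild, numChildren, tokenId; 2 `u32` per edge: byte, childNodeIndex). The `NO_TOKEN` sentinel and the function name are assumptions.

```javascript
const NO_TOKEN = 0xffffffff; // assumed sentinel for "not a token"

// vocab: array of [byteArray, tokenId] pairs.
function compileVocabToTrieSketch(vocab) {
  // 1) Build an ordinary pointer-based trie.
  const root = { children: new Map(), tokenId: NO_TOKEN };
  for (const [bytes, tokenId] of vocab) {
    let node = root;
    for (const b of bytes) {
      if (!node.children.has(b)) {
        node.children.set(b, { children: new Map(), tokenId: NO_TOKEN });
      }
      node = node.children.get(b);
    }
    node.tokenId = tokenId;
  }

  // 2) Flatten in BFS order. `order` doubles as the BFS queue, and a node's
  //    position in it is its final index.
  const order = [root];
  const nodes = [];
  const edges = [];
  for (let i = 0; i < order.length; i++) {
    const node = order[i];
    // Sort children by byte value so the GPU can binary-search them.
    const sorted = [...node.children.entries()].sort((x, y) => x[0] - y[0]);
    const firstChild = edges.length / 2;
    nodes.push(firstChild, sorted.length, node.tokenId);
    for (const [byte, child] of sorted) {
      edges.push(byte, order.length); // the child receives the next BFS index
      order.push(child);
    }
  }
  return { nodes: Uint32Array.from(nodes), edges: Uint32Array.from(edges) };
}

// Vocabulary { "b": 1, "a": 0, "ab": 2 }, deliberately unsorted on input.
const trie = compileVocabToTrieSketch([[[0x62], 1], [[0x61], 0], [[0x61, 0x62], 2]]);
console.log(Array.from(trie.edges)); // [ 97, 1, 98, 2, 98, 3 ]
```

Because children are appended to the queue in sorted order while their parent is being emitted, each node's edges and child nodes end up in one contiguous run.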
2. Persistent Buffer Pooling
In WebGPU, creating and destroying a GPUBuffer on every tokenization call adds GPU allocation overhead and garbage-collection pressure, which is a major source of latency.
TrieTokenizer solves this problem with an amortized O(1) buffer pool:
- Persistent buffers (`#inputBuf`, `#tokenBuf`, `#countsBuf`, `#offsetsBuf`, `#totalBuf`, `#compactBuf`) are built once, sized for the input text, and reused across all subsequent `encode` calls.
- Only when the newly arriving text is larger than the current pool capacity is the pool grown by 1.5x and reallocated.
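The growth policy can be sketched independently of WebGPU by injecting the allocator. Everything here is illustrative: `BufferPoolSketch` is a hypothetical class, and the exact rule (new capacity is the larger of the requested size and 1.5x the old capacity) is an assumption about how "grown by 1.5x" is applied.

```javascript
class BufferPoolSketch {
  // `allocate(sizeInBytes)` is supplied by the caller; in a browser it could
  // wrap device.createBuffer({ size, usage }).
  constructor(allocate) {
    this.allocate = allocate;
    this.capacity = 0;
    this.buffer = null;
    this.allocations = 0;
  }

  acquire(neededBytes) {
    if (neededBytes > this.capacity) {
      // Assumed policy: grow to at least 1.5x the old capacity.
      const grown = Math.ceil(this.capacity * 1.5);
      this.capacity = Math.max(neededBytes, grown);
      if (this.buffer && typeof this.buffer.destroy === "function") {
        this.buffer.destroy(); // release the old GPU allocation
      }
      this.buffer = this.allocate(this.capacity);
      this.allocations++;
    }
    return this.buffer; // reused as-is when it is already large enough
  }
}

const pool = new BufferPoolSketch((size) => ({ size }));
pool.acquire(1000); // first call allocates 1000 bytes
pool.acquire(800);  // fits: no allocation
pool.acquire(1200); // too small: grows to 1500 bytes
pool.acquire(1400); // fits again
console.log(pool.capacity, pool.allocations); // 1500 2
```

Geometric growth is what makes the cost amortized O(1): a sequence of slowly increasing inputs triggers only a logarithmic number of reallocations.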
3. Dynamic Slicing (Multi-pass Slicing)
WebGPU has a hardware maxBufferSize limit. A very large text cannot be loaded onto the GPU all at once, because the buffers it needs would exceed that limit.
- `TrieTokenizer` checks the incoming text size against the GPU limits.
- If the input exceeds the limits, it splits the text into safe slices (`sliceSize = Math.floor(maxInputPerPass / chunkSize) * chunkSize`) and processes them with consecutive dispatches.
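The slice formula rounds the per-pass budget down to a whole number of chunks, so every slice except possibly the last contains only complete chunks. A small sketch of the resulting plan, where `planSlices` is a hypothetical helper and the numeric values are made up for illustration:

```javascript
function planSlices(inputLength, maxInputPerPass, chunkSize) {
  // Largest multiple of chunkSize that still fits into one pass.
  const sliceSize = Math.floor(maxInputPerPass / chunkSize) * chunkSize;
  if (sliceSize <= 0) {
    throw new RangeError("maxInputPerPass must hold at least one chunk");
  }
  const slices = [];
  for (let offset = 0; offset < inputLength; offset += sliceSize) {
    slices.push({ offset, length: Math.min(sliceSize, inputLength - offset) });
  }
  return slices;
}

console.log(planSlices(2500, 1100, 512));
// [ { offset: 0, length: 1024 },
//   { offset: 1024, length: 1024 },
//   { offset: 2048, length: 452 } ]
```

Here a budget of 1100 bytes with 512-byte chunks yields 1024-byte slices, so slice boundaries always coincide with chunk boundaries.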
4. Single Submit and Zero-CPU Latency (Single Submit Command Encoding)
One of the most critical optimizations is in the asynchronous command communication between CPU and GPU:
- The 3 separate passes (`trie_tokenizer_chunked` → `trie_prefix_sum` → `trie_tokenizer_compact`) are recorded into a single `CommandEncoder` and sent to the GPU with a single submit (`device.queue.submit`).
- This way the CPU never waits for the GPU during intermediate steps (zero CPU-GPU roundtrips). The GPU runs the commands back to back.
- Smart DMA Copy: Before compaction, the size of the output is not known, because the number of tokens each chunk produces varies. To avoid stalling on CPU-GPU synchronization, the 4-byte `total_tokens` count is computed on the GPU first and read back to the host via mapAsync, which completes almost immediately because the work has already finished. Then a `copyBufferToBuffer` sized to that count is issued to read back only the compacted token sequence, so no memory is wasted on padding.
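The encoding side of this pattern can be sketched with standard WebGPU calls. The function `submitTriePasses` and the shape of its arguments are hypothetical, and using one compute pass per kernel is an assumption (recording all three dispatches inside one compute pass is also valid WebGPU). The pipelines, bind groups and buffers are assumed to be created elsewhere.

```javascript
// passes:    [{ pipeline, bindGroup, workgroups }, ...] in execution order
// totalCopy: { src, dst } where src holds the 4-byte total_tokens value and
//            dst is a mappable staging buffer
function submitTriePasses(device, passes, totalCopy) {
  const encoder = device.createCommandEncoder();
  for (const p of passes) {
    const pass = encoder.beginComputePass();
    pass.setPipeline(p.pipeline);
    pass.setBindGroup(0, p.bindGroup);
    pass.dispatchWorkgroups(p.workgroups);
    pass.end();
  }
  // Queue the 4-byte total_tokens copy in the same command buffer.
  encoder.copyBufferToBuffer(totalCopy.src, 0, totalCopy.dst, 0, 4);
  // One submit for everything: no CPU wait between the three kernels.
  device.queue.submit([encoder.finish()]);
}
```

After this call the host would map the staging buffer with `mapAsync(GPUMapMode.READ)`, read the token count, and only then encode the second, exactly sized copy of the compacted tokens.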
Next
Return to the index: index.md.