llm.istanbul·Study

bpe — GPU BPE Tokenizer Pipeline (Training + Inference)

File: bpe.wgsl

Pipeline step: GPU-side pre-tokenization, the training steps (pair counting, reduction, scan, compaction) and trie-based inference.

Features: subgroup-cooperative scan, O(1) root LUT caching, and workgroup-local shared hash tables that keep global atomic contention to a minimum.

This file bundles two large pipelines:

  • Training Pipeline (9 Dispatches): clear_table → pair_count → find_max4 → find_max_final_det → setup_merge → merge_reduce → scan_pass1 → scan_pass2 → scan_pass3 → finalize_compact
  • Inference Pipeline (3 Dispatches): trie_tokenizer_chunked → trie_prefix_sum → trie_tokenizer_compact

Wait, what is this?

Say you want to shrink a piece of text without losing a single character. You're in an editor typing the same little phrase over and over: "and so on", "and so on"... At some point you go "alright, let me invent a shorthand for this, I'll just write aso and move on". That's exactly what BPE (Byte Pair Encoding) does: it finds the most frequent adjacent pair in the text and hands it one new symbol. Then it looks again, finds the new most-frequent pair, gives that a symbol too. The loop spins thousands of times and you end up with a "vocabulary" — common stuff becomes one piece, rare stuff stays letter by letter.
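As a concrete reference for the loop in the analogy, here is a minimal CPU sketch in JavaScript (the host language of trie.js and tokenizer.js). It is not the GPU implementation: the function names (countPairs, mergePair, trainBpe), the pair packing and the tie-break rule are assumptions made for illustration.

```javascript
// CPU reference sketch of the BPE training loop (hypothetical, not the GPU kernels).
// Symbols start as byte values 0..255; each merge creates a new ID 256, 257, ...
function countPairs(seq) {
  const counts = new Map();
  for (let i = 0; i + 1 < seq.length; i++) {
    const key = seq[i] * 65536 + seq[i + 1]; // pack the pair (a, b) into one number
    counts.set(key, (counts.get(key) || 0) + 1);
  }
  return counts;
}

function mergePair(seq, a, b, newSym) {
  const out = [];
  for (let i = 0; i < seq.length; i++) {
    if (i + 1 < seq.length && seq[i] === a && seq[i + 1] === b) {
      out.push(newSym); // the pair becomes one symbol
      i++;              // skip the right half of the pair
    } else {
      out.push(seq[i]);
    }
  }
  return out;
}

function trainBpe(text, numMerges) {
  let seq = Array.from(new TextEncoder().encode(text));
  const merges = [];
  for (let m = 0; m < numMerges; m++) {
    let bestKey = -1;
    let bestCount = 0;
    for (const [key, c] of countPairs(seq)) {
      // Most frequent pair wins; ties go to the smaller packed key (assumed rule).
      if (c > bestCount || (c === bestCount && key < bestKey)) {
        bestKey = key;
        bestCount = c;
      }
    }
    if (bestCount < 2) break; // no pair repeats: nothing worth a new symbol
    const a = Math.floor(bestKey / 65536);
    const b = bestKey % 65536;
    const newSym = 256 + merges.length;
    merges.push([a, b, newSym]);
    seq = mergePair(seq, a, b, newSym);
  }
  return { merges, seq };
}
```

On the classic string "aaabdaaabac", three merges create "aa", then "ab", then the combination of those two new symbols.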

The training side is that loop running on the GPU. But there's a catch: "find the most frequent pair" means counting millions of character pairs. One thread ticking them off one by one would take forever. So instead millions of threads count at once — but when they all try to write into the same tally box (the global hash table) they trip over each other (contention). The fix is to have everyone keep a little tally at their own desk first (a workgroup-local hash table), then hand only the totals over to the shared ledger in one move.

The inference side is the reverse: the vocabulary is ready now, a fresh piece of text shows up, and you want to split it into as few pieces as possible. The trick here is to lay the vocabulary out as a Trie (prefix tree), the same shape a phone directory or autocomplete runs on. At each character you step one level down the tree, hunting for the "longest matching word" (greedy longest-match): "c", "ca", "cat"... however far you can walk, that's your token.

Everything else is the hardware dance around those two ideas: subgroup scans that cut barriers from 17 down to 2, tricks that cache the tree's root in shared memory so the first steps of every lookup skip global memory, and branchless searches that keep threads in lockstep and dodge divergence. So "what's being counted / what's being looked up" is simple; the real work is doing it on the GPU without drowning.


What Does It Do?

1. GPU Pre-tokenization (bpe_word_boundary)

During BPE training, merges must not form across word boundaries (for example, "yakınlık▁ve" must never become a single token), so a Turkish-compatible, GPT-4-style pre-tokenization pass runs first. It assigns each symbol a character class (Letter, Digit, Space, Punct, Newline). At every class transition, the symbol that starts the new run is tagged by setting bit 16, WORD_START_BIT (0x10000). When the pair counting kernel sees this bit on the right-hand symbol of a pair, it skips that pair, so nothing is counted across the boundary.
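The tagging rule and its effect on pair counting can be sketched on the CPU. This JavaScript sketch implements only the class-transition rule described above, with a simplified class function; anything else bpe_word_boundary does is not modeled, and leaving position 0 untagged is an assumption. The names tagWordStarts and countPairsWithinWords are hypothetical.

```javascript
const WORD_START_BIT = 0x10000; // bit 16, as in the kernel

// Simplified class function for this sketch: 0 Letter, 1 Digit, 2 Space, 3 Punct, 4 Newline.
function charClass(b) {
  if (b === 0x0a) return 4;
  if (b === 0x20) return 2;
  if (b >= 0x30 && b <= 0x39) return 1;
  if (b >= 0x80) return 0; // UTF-8 bytes (e.g. Turkish letters) count as Letter
  if ((b >= 0x61 && b <= 0x7a) || (b >= 0x41 && b <= 0x5a)) return 0;
  return 3;
}

// Tag the first symbol of every class run (position 0 is left untagged, an assumption).
function tagWordStarts(bytes) {
  const out = new Uint32Array(bytes.length);
  for (let i = 0; i < bytes.length; i++) {
    const boundary = i > 0 && charClass(bytes[i]) !== charClass(bytes[i - 1]);
    out[i] = boundary ? (bytes[i] | WORD_START_BIT) : bytes[i];
  }
  return out;
}

// Count adjacent pairs, skipping any pair whose right symbol starts a new run.
function countPairsWithinWords(symbols) {
  const counts = new Map();
  for (let i = 1; i < symbols.length; i++) {
    if (symbols[i] & WORD_START_BIT) continue; // pair would straddle a boundary
    const key = `${symbols[i - 1] & 0xffff},${symbols[i] & 0xffff}`;
    counts.set(key, (counts.get(key) || 0) + 1);
  }
  return counts;
}
```

For "ık ve", the space and the "v" are tagged, so the pairs ("k", space) and (space, "v") are never counted, while the two UTF-8 bytes of "ı" still form a countable pair.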

2. Two-Level Pair Counting (bpe_pair_count_b)

If millions of threads insert directly into the global hash table with atomicAdd, they contend heavily for the same slots and throughput collapses. To avoid this, counting happens in two levels:

  1. Each thread first inserts its pairs into a 1024-slot hash table in its workgroup's shared memory (local_ids and local_counts), using Murmur3 hashing with quadratic probing.
  2. Once the local reduction finishes, only the workgroup totals are flushed to the global hash table, one update per distinct pair. This keeps global atomic collisions to a minimum.
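The two levels can be modeled on the CPU. In this JavaScript sketch each block of pairsPerGroup pairs plays the role of one workgroup. The Murmur3 finalizer (fmix32), the triangular probe sequence, the EMPTY sentinel and the overflow fallback are assumptions about details the text does not spell out, and globalOps merely counts how many updates would reach the global table.

```javascript
const LOCAL_TABLE_SIZE = 1024;               // power of two, as in the kernel
const LOCAL_TABLE_MASK = LOCAL_TABLE_SIZE - 1;
const EMPTY = 0xffffffff;                    // free-slot sentinel (assumption)
const WORD_START_BIT = 0x10000;

// Murmur3 32-bit finalizer (fmix32): a cheap, well-mixed hash of a packed pair.
function fmix32(h) {
  h ^= h >>> 16;
  h = Math.imul(h, 0x85ebca6b);
  h ^= h >>> 13;
  h = Math.imul(h, 0xc2b2ae35);
  h ^= h >>> 16;
  return h >>> 0;
}

// Open addressing with triangular (quadratic) probing: offsets 0, 1, 3, 6, ...
// visit every slot of a power-of-two table. Returns false only if the table is full.
function localInsert(ids, counts, key) {
  const home = fmix32(key) & LOCAL_TABLE_MASK;
  for (let i = 0; i < LOCAL_TABLE_SIZE; i++) {
    const s = (home + ((i * (i + 1)) >>> 1)) & LOCAL_TABLE_MASK;
    if (ids[s] === key) { counts[s]++; return true; }
    if (ids[s] === EMPTY) { ids[s] = key; counts[s] = 1; return true; }
  }
  return false;
}

function twoLevelPairCount(symbols, pairsPerGroup) {
  const totals = new Map(); // stands in for the global pair_ids / pair_counts table
  let globalOps = 0;        // stands in for global atomic updates
  const addGlobal = (key, n) => {
    totals.set(key, (totals.get(key) || 0) + n);
    globalOps++;
  };
  const numPairs = Math.max(symbols.length - 1, 0);
  for (let start = 0; start < numPairs; start += pairsPerGroup) {
    // Level 1: one workgroup's table in shared memory.
    const ids = new Uint32Array(LOCAL_TABLE_SIZE).fill(EMPTY);
    const counts = new Uint32Array(LOCAL_TABLE_SIZE);
    const end = Math.min(start + pairsPerGroup, numPairs);
    for (let i = start; i < end; i++) {
      if (symbols[i + 1] & WORD_START_BIT) continue; // never count across a word boundary
      const key = (((symbols[i] & 0xffff) << 16) | (symbols[i + 1] & 0xffff)) >>> 0;
      if (!localInsert(ids, counts, key)) addGlobal(key, 1); // overflow fallback (assumption)
    }
    // Level 2: one global update per distinct pair seen by this workgroup.
    for (let s = 0; s < LOCAL_TABLE_SIZE; s++) {
      if (ids[s] !== EMPTY) addGlobal(ids[s], counts[s]);
    }
  }
  return { totals, globalOps };
}
```

With 4,096 symbols alternating 1, 2, 1, 2, ... and 1,024 pairs per group, the 4,095 pairs reach the global table through only 8 updates (2 distinct pairs in each of 4 groups).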

3. Subgroup Cooperative Compaction (bpe_finalize_compact_b)

A classic shared-memory Blelloch exclusive scan over 256 elements needs an 8-stage up-sweep and an 8-stage down-sweep, for a total of 17 workgroupBarrier calls. bpe_finalize_compact_b replaces it with a subgroup-cooperative scan:

  • In-subgroup scans (subgroupExclusiveAdd and subgroupAdd) run at the hardware level (barrier-free).
  • Only the subgroup totals are written to shared memory (sh_sg_excl) and scanned at the workgroup level.
  • The 256-element exclusive scan completes with only 2 barriers instead of 17, plus 4 subgroup instructions.
  • Using branchless select() and carrying WORD_START_BIT through to the output, the kernel applies the merge and scatters the surviving symbols to their destinations in the same dispatch.
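Stripped of the subgroup machinery, the compaction step is an exclusive scan of the valid bits followed by a scatter. This JavaScript sketch uses the valid_mask bit layout from the binding table (bit 0 = valid, bit 1 = merge intent) and a flat scan in place of the block_sums hierarchy. It also assumes, beyond what the text states, that bpe_merge_reduce_b has already cleared the B-side symbol's valid bit and that a merged symbol inherits the A side's WORD_START_BIT. finalizeCompact and exclusiveScan are hypothetical names.

```javascript
const WORD_START_BIT = 0x10000;
const VALID = 1; // valid_mask bit 0: symbol survives
const MERGE = 2; // valid_mask bit 1: symbol is the A side of the merged pair

// Exclusive prefix sum: offsets[i] = flags[0] + ... + flags[i - 1].
function exclusiveScan(flags) {
  const offsets = new Uint32Array(flags.length);
  let run = 0;
  for (let i = 0; i < flags.length; i++) {
    offsets[i] = run;
    run += flags[i];
  }
  return { offsets, total: run };
}

// Apply the merge and scatter surviving symbols to their compacted positions.
function finalizeCompact(inputSymbols, validMask, newSymbol) {
  const { offsets, total } = exclusiveScan(Array.from(validMask, (m) => m & VALID));
  const out = new Uint32Array(total);
  for (let i = 0; i < inputSymbols.length; i++) {
    const m = validMask[i];
    // Merged symbol keeps the A side's word-start tag (assumed carry rule).
    const merged = (newSymbol | (inputSymbols[i] & WORD_START_BIT)) >>> 0;
    const sym = m & MERGE ? merged : inputSymbols[i]; // a select() on the GPU
    if (m & VALID) out[offsets[i]] = sym;             // scatter
  }
  return out;
}
```

Merging ("a", "b") into symbol 256 in the word-start-tagged sequence a b c ▁ d drops the B side and shifts every later symbol one slot to the left.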

4. Trie-Based Fast Inference (trie_tokenizer_chunked)

To avoid global memory latency in the trie search that drives GPU inference, the kernel applies the following optimizations:

  • Root LUT caching: Every token match begins by looking up a child of the trie's root node (Node 0). For the byte-level trie, the root's children (at most 256, one per byte value) are cached in shared memory as root_lut, so the first-level binary search is skipped entirely and becomes an O(1) direct lookup.
  • Depth-1 cache: The firstChild, numChildren and tokenId values of every depth-1 node (each direct child of the root) are also cached in shared memory, saving 3 global memory reads per token.
  • Branchless search: For nodes other than the root, the child lookup is a binary search whose comparisons use select() instead of if/else, so every thread in a warp follows the same instruction stream and the SIMD mask never splits (no warp divergence).
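A branchless child search over sorted edges can be sketched as follows. This is a JavaScript model, not the WGSL helper (whose body the text does not show); the function name findChild and the NOT_FOUND sentinel are hypothetical, and the flat [byte, childNodeIndex] edge layout follows the edges binding of trie_tokenizer_chunked. The loop's trip count depends only on numChildren, and each step updates its state with a select-style ternary rather than a data-dependent branch.

```javascript
const NOT_FOUND = 0xffffffff; // hypothetical "no child" sentinel

// edges is a flat [byte, childNodeIndex, byte, childNodeIndex, ...] array; a node's
// children occupy edge slots fc .. fc + nc - 1, sorted by byte.
function findChild(edges, fc, nc, target) {
  if (nc === 0) return NOT_FOUND;
  let base = 0; // invariant: the last child with byte <= target lies in [base, base + len)
  let len = nc;
  while (len > 1) {            // trip count depends only on nc, not on the data
    const half = len >>> 1;
    const key = edges[(fc + base + half) * 2] & 0xff;
    base = key <= target ? base + half : base; // select(), no data-dependent branch
    len -= half;
  }
  const key = edges[(fc + base) * 2] & 0xff;
  return key === target ? edges[(fc + base) * 2 + 1] : NOT_FOUND;
}
```

Lanes that search nodes with the same fan-out run exactly the same number of iterations, whichever byte each lane is looking for.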

Bind Group ABI

bpe_pair_count_b (4 bindings)

Counts batched pairs using the local shared hash table.

| Binding | Type | Variable | Detail |
|---|---|---|---|
| 0 | storage, read | `symbols: array<u32>` | Input symbol array |
| 1 | storage, read_write | `pair_counts: array<atomic<u32>>` | Global pair counts |
| 2 | storage, read_write | `pair_ids: array<atomic<u32>>` | Global pair IDs (packed pair) |
| 3 | storage, read | `state: IterState` | Iteration state: symbol_count, table_size, etc. |

bpe_merge_reduce_b (4 bindings — Fused Merge Mark & Reduce)

Marks the occurrences of the best pair and, in the same kernel, reduces the number of valid symbols of each workgroup into block_sums.

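| Binding | Type | Variable | Detail |
|---|---|---|---|
| 0 | storage, read_write | `symbols: array<u32>` | Receives the A-side merge marks |
| 1 | storage, read_write | `valid_mask: array<u32>` | Per-symbol flags: bit 0 = valid, bit 1 = merge intent |
| 2 | storage, read_write | `block_sums: array<u32>` | Valid-symbol total of each block |
| 3 | storage, read | `state: IterState` | symbol_a, symbol_b, new_symbol |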

bpe_finalize_compact_b (5 bindings — Fused Scan + Merge Apply + Scatter)

Computes output offsets with a subgroup exclusive scan, applies the merge, and scatters the surviving symbols into a compacted destination buffer.

| Binding | Type | Variable | Detail |
|---|---|---|---|
| 0 | storage, read | `valid_mask: array<u32>` | bit 0 = valid, bit 1 = merge intent |
| 1 | storage, read | `block_sums: array<u32>` | Block offsets after the exclusive scan |
| 2 | storage, read | `input_symbols: array<u32>` | Source symbol buffer |
| 3 | storage, read_write | `output_symbols: array<u32>` | Destination (compacted) symbol buffer |
| 4 | storage, read | `state: IterState` | |

trie_tokenizer_chunked (6 bindings — Trie-based Inference)

Tokenizes the input bytes by greedy longest match against a precompiled trie stored in flat binary form.

| Binding | Type | Variable | Detail |
|---|---|---|---|
| 0 | storage, read | `input: array<u32>` | Packed UTF-8 bytes (4 bytes per u32) |
| 1 | storage, read | `nodes: array<u32>` | Trie node info (firstChild, numChildren, tokenId) |
| 2 | storage, read | `edges: array<u32>` | Trie edge info (byte, childNodeIndex) |
| 3 | storage, read_write | `token_output: array<u32>` | Output token IDs |
| 4 | storage, read_write | `chunk_counts: array<u32>` | Number of tokens each chunk produced |
| 5 | uniform | `params: TrieParams` | input_length, chunk_size, max_tokens_per_chunk |

Line-by-Line Analysis

1) bpe_word_boundary — Byte Classification

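```wgsl
// Map a byte to its pre-tokenization class:
// 0 = Letter, 1 = Digit, 2 = Space, 3 = Punct, 4 = Newline.
fn char_class(tok: u32) -> u32 {
    if (tok == 0x0Au) { return 4u; }          // '\n'
    if (tok == 0x20u) { return 2u; }          // ' '
    if (tok - 0x30u <= 9u) { return 1u; }     // '0'..'9' (tok < 0x30 wraps to a huge u32)
    if (tok >= 0x80u) { return 0u; }          // UTF-8 lead/continuation bytes count as Letter
    if (tok - 0x61u <= 25u) { return 0u; }    // 'a'..'z'
    if (tok - 0x41u <= 25u) { return 0u; }    // 'A'..'Z'
    return 3u;                                // everything else is Punct
}
```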
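  • Branch-light subtraction trick: (tok - base) <= range tests a contiguous range (such as the ten digits) with one unsigned comparison instead of two, because any tok below base wraps around to a huge u32 value and fails the test. Single-compare checks are also easy for the compiler to turn into predicated code, which reduces SIMD mask divergence and helps the hardware's threads stay in lockstep.
  • All bytes at or above 0x80u (UTF-8 lead and continuation bytes) are classified directly as class 0 (Letter). Turkish characters such as ı, ş, ğ, ç, ö and ü therefore land in the letter class automatically; the trade-off is that non-ASCII punctuation (curly quotes, for example) is treated as a letter too.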
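The subtraction trick only works because u32 arithmetic wraps. The JavaScript port below makes that explicit with `>>> 0`; a plain `tok - 0x30` would go negative for bytes below '0' and wrongly pass the `<= 9` test. Apart from that it mirrors char_class line for line (charClass is just the JavaScript spelling of the same function).

```javascript
// JavaScript port of char_class. `>>> 0` reproduces WGSL's u32 wraparound,
// which the single-compare range checks depend on.
function charClass(tok) {
  if (tok === 0x0a) return 4;               // Newline
  if (tok === 0x20) return 2;               // Space
  if (((tok - 0x30) >>> 0) <= 9) return 1;  // Digit: tok < 0x30 wraps to a huge value
  if (tok >= 0x80) return 0;                // UTF-8 lead/continuation byte -> Letter
  if (((tok - 0x61) >>> 0) <= 25) return 0; // 'a'..'z'
  if (((tok - 0x41) >>> 0) <= 25) return 0; // 'A'..'Z'
  return 3;                                 // Punct (everything else)
}
```

For example, charClass(0x2f) ('/') returns 3 because 0x2f - 0x30 wraps to 0xFFFFFFFF; without `>>> 0` the difference would be -1 and '/' would be misread as a digit.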

2) bpe_finalize_compact_b — Subgroup Cooperative Scan

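```wgsl
// Requires `enable subgroups;`. lane = index inside the subgroup, sg_id = subgroup index.
// Step 1: scan inside each subgroup (no barrier needed).
let local_excl = subgroupExclusiveAdd(v);   // sum of v over the lower lanes
let sg_total   = subgroupAdd(v);            // sum of v over the whole subgroup

// Step 2: publish one total per subgroup.
if (lane == 0u) { sh_sg_excl[sg_id] = sg_total; }
workgroupBarrier();

// Step 3: scan the NUM_SUBGROUPS totals with one more subgroup scan.
var t: u32 = 0u;
if (lane < NUM_SUBGROUPS) { t = sh_sg_excl[lane]; }
let cross_excl = subgroupExclusiveAdd(t);

// Only subgroup 0 writes the per-subgroup base offsets back.
if (sg_id == 0u && lane < NUM_SUBGROUPS) { sh_sg_excl[lane] = cross_excl; }
workgroupBarrier();

// Step 4: workgroup-wide offset = subgroup base + offset inside the subgroup.
let final_excl = sh_sg_excl[sg_id] + local_excl;
```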
  • Instead of a traditional 256-element Blelloch exclusive scan, the kernel uses hardware subgroup intrinsics:
    1. subgroupExclusiveAdd(v) computes each thread's exclusive sum within its subgroup, with no barrier.
    2. Lane 0 of each subgroup writes the subgroup total (sg_total) into shared memory (sh_sg_excl).
    3. After the first barrier, every subgroup runs one more subgroupExclusiveAdd over the subgroup totals (8 of them with 32-wide subgroups), producing cross_excl; subgroup 0 writes the result back to sh_sg_excl as per-subgroup base offsets.
    4. After the second barrier, each thread adds its subgroup's base offset to its in-subgroup offset.
  • Cost: 2 barriers instead of 17, plus a few subgroup operations executed in hardware.
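To see why two barriers suffice, here is a sequential JavaScript model of the same four steps. It assumes a 32-wide subgroup (so NUM_SUBGROUPS = 8, matching the text); real subgroup widths vary by GPU, and the emulated intrinsics (sgExclusiveAdd, sgAdd) are plain loops rather than hardware operations. Comments mark where the two workgroupBarrier calls sit.

```javascript
const WG_SIZE = 256;
const SG_SIZE = 32;                      // assumed subgroup width
const NUM_SUBGROUPS = WG_SIZE / SG_SIZE; // 8

// Emulated subgroupExclusiveAdd: each lane gets the sum of the lanes below it.
function sgExclusiveAdd(lanes) {
  const out = new Array(lanes.length);
  let run = 0;
  for (let i = 0; i < lanes.length; i++) {
    out[i] = run;
    run += lanes[i];
  }
  return out;
}

// Emulated subgroupAdd: the subgroup total.
function sgAdd(lanes) {
  return lanes.reduce((a, b) => a + b, 0);
}

function workgroupExclusiveScan(v) {
  const localExcl = new Array(WG_SIZE);
  const shSgExcl = new Array(NUM_SUBGROUPS).fill(0); // models sh_sg_excl
  // Steps 1-2: scan inside each subgroup; lane 0 publishes the subgroup total.
  for (let sg = 0; sg < NUM_SUBGROUPS; sg++) {
    const lanes = v.slice(sg * SG_SIZE, (sg + 1) * SG_SIZE);
    sgExclusiveAdd(lanes).forEach((x, lane) => { localExcl[sg * SG_SIZE + lane] = x; });
    shSgExcl[sg] = sgAdd(lanes);
  }
  // workgroupBarrier() #1
  // Step 3: one subgroup scan over the totals (lanes >= NUM_SUBGROUPS contribute 0);
  // subgroup 0 writes the base offsets back.
  const t = Array.from({ length: SG_SIZE }, (_, lane) => (lane < NUM_SUBGROUPS ? shSgExcl[lane] : 0));
  const crossExcl = sgExclusiveAdd(t);
  for (let lane = 0; lane < NUM_SUBGROUPS; lane++) shSgExcl[lane] = crossExcl[lane];
  // workgroupBarrier() #2
  // Step 4: final offset = subgroup base + offset inside the subgroup.
  return localExcl.map((x, i) => shSgExcl[Math.floor(i / SG_SIZE)] + x);
}
```

The scheme relies on all subgroup totals fitting in a single subgroup for step 3, which holds for 256 threads whenever subgroups are at least 16 lanes wide.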

3) trie_tokenizer_chunked — Root LUT and Depth-1 Cache

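```wgsl
// One thread per root edge copies the root's children into shared memory.
// cached_root_fc / cached_root_nc: the root's firstChild and numChildren, loaded earlier (not shown).
let nc = min(cached_root_nc, MAX_CACHED_EDGES); // root child count, capped at the cache size
let fc = cached_root_fc;                         // index of the root's first edge
if (lid.x < nc) {
    let sym     = edges[(fc + lid.x) * 2u] & 0xFFu; // edge byte (0..255)
    let d1_node = edges[(fc + lid.x) * 2u + 1u];    // depth-1 child node index
    root_lut[sym] = d1_node;                         // byte -> depth-1 node, O(1) later
    d1_fc[sym]    = nodes[d1_node * 3u];             // firstChild
    d1_nc[sym]    = nodes[d1_node * 3u + 1u] & 0xFFFFu; // numChildren (low 16 bits)
    d1_tid[sym]   = nodes[d1_node * 3u + 2u];        // tokenId
}
// root_lut slots for bytes without a root edge are presumably pre-filled with a sentinel (not shown).
workgroupBarrier(); // make the cache visible to every thread before searching
```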
  • Each token match starts from the Trie root node (Node 0).
  • All of the root's branches (root_lut) and the metadata of the depth-1 nodes (d1_fc, d1_nc, d1_tid) are cached in shared memory.
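  • In the search loop, when cn == 0u (the current node is the root, i.e. the step from depth 0 to 1), the lookup is a single shared-memory read:

```wgsl
nn = root_lut[bv];
```

  • In the step from depth 1 to 2, the child range (firstChild, numChildren) comes from the shared cache instead of a dependent global read of the depth-1 node:

```wgsl
nn = find_child_global(d1_fc[rb], d1_nc[rb], bv);
```

    Here cn is the current node, bv the current input byte, nn the next node and rb the byte matched at the root step. The first step needs no global read at all, and the second skips the 3 node-metadata reads, leaving only the edge search in global memory. These are the two most frequently visited levels of the search.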
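The cached lookup path can be modeled in JavaScript against the flat nodes/edges layout from the binding table. The toy vocabulary and its hand-built arrays, the NONE sentinel, the helper names and the linear child scan (the kernel binary-searches) are assumptions for illustration. A byte with no root edge is emitted as NONE here, a case a full byte-level vocabulary never hits.

```javascript
const NONE = 0xffffffff; // hypothetical "no node / no token" sentinel

// Toy vocab {a:0, b:1, c:2, ab:3, abc:4, ca:5} flattened in BFS order.
// nodes: [firstChild, numChildren, tokenId] per node; edges: [byte, childNode] per edge.
const nodes = [
  0, 3, NONE, // 0: root
  3, 1, 0,    // 1: "a"
  4, 0, 1,    // 2: "b"
  4, 1, 2,    // 3: "c"
  5, 1, 3,    // 4: "ab"
  6, 0, 5,    // 5: "ca"
  6, 0, 4,    // 6: "abc"
];
const edges = [0x61, 1, 0x62, 2, 0x63, 3, 0x62, 4, 0x61, 5, 0x63, 6];

// Root LUT + depth-1 cache (filled once per workgroup in the kernel).
function buildRootCache() {
  const rootLut = new Uint32Array(256).fill(NONE);
  const d1Fc = new Uint32Array(256);
  const d1Nc = new Uint32Array(256);
  const d1Tid = new Uint32Array(256).fill(NONE);
  const fc = nodes[0];
  const nc = Math.min(nodes[1], 256);
  for (let i = 0; i < nc; i++) { // one thread per root edge on the GPU
    const sym = edges[(fc + i) * 2] & 0xff;
    const d1 = edges[(fc + i) * 2 + 1];
    rootLut[sym] = d1;
    d1Fc[sym] = nodes[d1 * 3];
    d1Nc[sym] = nodes[d1 * 3 + 1] & 0xffff;
    d1Tid[sym] = nodes[d1 * 3 + 2];
  }
  return { rootLut, d1Fc, d1Nc, d1Tid };
}

// Child lookup in the global arrays (linear scan for brevity).
function childIn(fc, nc, b) {
  for (let i = 0; i < nc; i++) {
    if ((edges[(fc + i) * 2] & 0xff) === b) return edges[(fc + i) * 2 + 1];
  }
  return NONE;
}

// Greedy longest match: walk as deep as possible, emit the deepest token seen.
function tokenize(bytes, cache) {
  const out = [];
  let pos = 0;
  while (pos < bytes.length) {
    const b0 = bytes[pos];
    if (cache.rootLut[b0] === NONE) { out.push(NONE); pos++; continue; } // depth 0 -> 1: O(1)
    let bestTid = cache.d1Tid[b0];
    let bestLen = 1;
    // depth 1 -> 2: the child range comes from the cache, not from nodes[]
    let node = pos + 1 < bytes.length ? childIn(cache.d1Fc[b0], cache.d1Nc[b0], bytes[pos + 1]) : NONE;
    let len = 2;
    while (node !== NONE) {
      const tid = nodes[node * 3 + 2];
      if (tid !== NONE) { bestTid = tid; bestLen = len; } // longest token so far
      if (pos + len >= bytes.length) break;
      node = childIn(nodes[node * 3], nodes[node * 3 + 1] & 0xffff, bytes[pos + len]);
      len++;
    }
    out.push(bestTid);
    pos += bestLen;
  }
  return out;
}
```

On "abcabca" this yields token IDs [4, 4, 0], i.e. abc | abc | a.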

Code Review

Finding 1: bpe_pair_count_b Local Table Sizing and Occupancy Relationship

| Risk | Description |
|---|---|
| 🟡 perf | LOCAL_TABLE_SIZE is 1024 slots (8 KB) of shared memory. Growing it to 2048 slots (16 KB) halved occupancy on Apple M-series GPUs (4 WG/SM → 2 WG/SM), and the resulting loss of latency hiding made the step 70% slower. The current 1024 slots (about 25% fill, with Murmur3 hashing and quadratic probing) give the best occupancy trade-off on this hardware. |

Finding 2: trie_tokenizer_compact Coalesced Memory Writes

| Risk | Description |
|---|---|
| 🟢 Resolved | In the old design, one thread handled one chunk, so neighboring threads wrote 512 elements apart (stride-512) and the writes could not be coalesced. trie_tokenizer_compact instead maps one workgroup to one chunk: its 256 threads copy consecutive indices together (compact_output[db + i]), so each step's writes are fully coalesced. |
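The one-workgroup-per-chunk index mapping can be sketched in JavaScript. The inner loop over lid stands in for the 256 threads of one workgroup, maxPerChunk corresponds to max_tokens_per_chunk in TrieParams, and the assumption that each chunk's tokens start at chunk * maxPerChunk in token_output is ours; compactChunks is a hypothetical name.

```javascript
const WG = 256; // threads per workgroup

// One workgroup per chunk: in each step, lane lid copies element base + lid, so the
// 256 lanes of a step write 256 consecutive addresses of the compacted buffer.
function compactChunks(tokenOutput, chunkCounts, offsets, maxPerChunk) {
  const last = chunkCounts.length - 1;
  const total = last >= 0 ? offsets[last] + chunkCounts[last] : 0;
  const compact = new Uint32Array(total);
  for (let chunk = 0; chunk < chunkCounts.length; chunk++) { // workgroup id
    const src = chunk * maxPerChunk; // this chunk's slice of token_output (assumed layout)
    const db = offsets[chunk];       // destination base from the prefix sum
    for (let base = 0; base < chunkCounts[chunk]; base += WG) {
      for (let lid = 0; lid < WG; lid++) { // these lanes run together on the GPU
        const i = base + lid;
        if (i < chunkCounts[chunk]) compact[db + i] = tokenOutput[src + i];
      }
    }
  }
  return compact;
}
```

In the old one-thread-per-chunk layout, neighboring lanes would instead write at addresses a whole chunk apart in every step.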

Finding 3: Eliminating the PCIe Bottleneck (trie_prefix_sum)

| Risk | Description |
|---|---|
| 🟢 Resolved | In the old structure, the prefix sum needed to compact the tokenizer output ran on the CPU: the counts were read back with mapAsync, scanned, and uploaded again. trie_prefix_sum moves this step to the GPU as a single-invocation (workgroup_size(1)) kernel, removing the CPU roundtrip and with it 2 PCIe transfer latencies and 1 GPU fence wait. |
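The computation itself is short, which makes a single GPU invocation a reasonable choice: its cost grows with the number of chunks rather than the number of tokens, and it stays on the GPU timeline. A JavaScript model of what it computes (the name triePrefixSum and the return shape are hypothetical):

```javascript
// Exclusive scan of per-chunk token counts plus the grand total, as one
// sequential loop (the GPU kernel runs it in a single invocation).
function triePrefixSum(chunkCounts) {
  const offsets = new Uint32Array(chunkCounts.length);
  let running = 0;
  for (let i = 0; i < chunkCounts.length; i++) {
    offsets[i] = running;      // first output slot of chunk i
    running += chunkCounts[i];
  }
  return { offsets, totalTokens: running }; // totalTokens fits in one u32
}
```

For chunk counts [3, 0, 5, 2], the offsets are [0, 3, 3, 8] and the total is 10.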

Quick Checklist

| Test Scenario | Status |
|---|---|
| Are word boundaries tagged correctly? | symbols[id] \| WORD_START_BIT |
| What if the hardware does not support subgroup operations? | ⚠️ Training needs a fallback path on Safari/WebKit |
| Is warp divergence avoided in the trie search? | select()-based branchless lower_bound |
| Is the local hash table size a power of 2 (so modulo becomes a mask)? | LOCAL_TABLE_MASK = 1023 |
| Does the scan scale to 1B symbols? | ✅ Hierarchical 3-pass scan |
| Is the root node LUT cache indexed safely? | ✅ 256-entry LUT, one slot per byte value (sym is masked with & 0xFF) |

Host Architecture and Coordination (trie.js & tokenizer.js)

The BPE tokenizer's speed comes not only from the GPU shaders but also from careful memory management and an asynchronous execution design on the host (JS) side.

1. BFS-Based Cache-Friendly Trie Compilation (compileVocabToTrie)

After the in-memory trie is built from the raw word list (byte sequences), it is serialized into a flat binary (v3 format).

  • BFS Ordering (Breadth-First Search): The tree is flattened in BFS order, so a node's children sit consecutively in memory. A lookup among a node's children therefore touches one contiguous region, which makes good use of the GPU's L1/L2 caches.
  • Sorted Children: Child edges are sorted by symbol byte value from smallest to largest. This sorting enables the find_child_global helper function on the GPU side to perform a uniform binary search in non-root node searches.
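The BFS flattening and sorted children can be sketched as below. This is not compileVocabToTrie itself: the v3 binary format (headers, field packing) is not described here, so the sketch only produces nodes/edges arrays in the layout of the trie_tokenizer_chunked bindings. flattenTrieBfs and NO_TOKEN are hypothetical names.

```javascript
const NO_TOKEN = 0xffffffff; // hypothetical marker for nodes that end no token

// vocab: array of [bytes, tokenId], where bytes is an array of values 0..255.
function flattenTrieBfs(vocab) {
  // 1) Build a pointer-based trie.
  const root = { children: new Map(), tokenId: NO_TOKEN };
  for (const [bytes, id] of vocab) {
    let n = root;
    for (const b of bytes) {
      if (!n.children.has(b)) n.children.set(b, { children: new Map(), tokenId: NO_TOKEN });
      n = n.children.get(b);
    }
    n.tokenId = id;
  }
  // 2) Flatten breadth-first. A node's children are enqueued together, so they get
  //    consecutive node indices, and its edges form one contiguous run.
  const order = [root];
  const nodes = [];
  const edges = [];
  for (let head = 0; head < order.length; head++) {
    const n = order[head];
    const kids = [...n.children.entries()].sort((x, y) => x[0] - y[0]); // ascending byte
    nodes.push(edges.length / 2, kids.length, n.tokenId); // firstChild, numChildren, tokenId
    for (const [b, child] of kids) {
      edges.push(b, order.length); // the child's index is its position in BFS order
      order.push(child);
    }
  }
  return { nodes: Uint32Array.from(nodes), edges: Uint32Array.from(edges) };
}
```

For the toy vocabulary {a, b, c, ab, abc, ca}, the root's three edges occupy slots 0 to 2, and every node's children end up adjacent and sorted, which is exactly what a binary search over edges needs.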

2. Persistent Buffer Pooling

In WebGPU, creating and destroying GPUBuffers on every tokenization call costs GPU allocation time and adds garbage-collection pressure, a major source of latency. TrieTokenizer avoids this with a buffer pool whose reallocation cost is amortized O(1):

  • Persistent buffers (#inputBuf, #tokenBuf, #countsBuf, #offsetsBuf, #totalBuf, #compactBuf) are created once, sized for the input text, and reused by all subsequent encode calls.
  • The pool is grown by 1.5x and reallocated only when a new input exceeds the current capacity.
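The grow-only policy is easy to state in code. In this JavaScript sketch the class name and the allocate/release callbacks are hypothetical; in a browser they could wrap device.createBuffer and GPUBuffer.destroy. The pool reuses its buffer while the request fits and otherwise grows to the larger of the request and 1.5x the old capacity, rounded up to a multiple of 4 bytes because WebGPU buffer copies require 4-byte-aligned sizes.

```javascript
// Grow-only pool for one buffer (hypothetical helper, not TrieTokenizer's code).
class GrowOnlyBufferPool {
  constructor(allocate, release) {
    this.allocate = allocate; // e.g. size => device.createBuffer({ size, usage })
    this.release = release;   // e.g. buf => buf.destroy()
    this.capacity = 0;
    this.buffer = null;
  }

  // Returns a buffer of at least bytesNeeded bytes, reallocating only on growth.
  acquire(bytesNeeded) {
    if (this.buffer !== null && bytesNeeded <= this.capacity) return this.buffer; // reuse
    const grown = Math.max(bytesNeeded, Math.ceil(this.capacity * 1.5));
    const size = Math.ceil(grown / 4) * 4; // keep sizes 4-byte aligned
    if (this.buffer !== null) this.release(this.buffer);
    this.buffer = this.allocate(size);
    this.capacity = size;
    return this.buffer;
  }
}
```

Because capacity grows geometrically, a stream of ever-larger inputs triggers only a logarithmic number of reallocations, which is where the amortized O(1) cost comes from.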

3. Dynamic Slicing (Multi-pass Slicing)

WebGPU enforces a device maxBufferSize limit, so a very large text cannot be uploaded to the GPU in one buffer: the allocation would exceed the limit and fail.

  • TrieTokenizer checks the incoming text size against the device limits.
  • If the input exceeds them, it splits the text into slices whose size is a whole number of chunks (sliceSize = Math.floor(maxInputPerPass / chunkSize) * chunkSize) and processes the slices with consecutive dispatches.
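The slicing rule fits in a few lines of JavaScript. planSlices is a hypothetical helper, and how maxInputPerPass is derived from the device limits is not specified in the text (for storage bindings, device.limits.maxStorageBufferBindingSize is often tighter than maxBufferSize). Rounding the slice size down to a whole number of chunks keeps every chunk inside a single pass.

```javascript
// Hypothetical helper: split inputLength bytes into passes of at most
// maxInputPerPass bytes, each pass a whole number of chunkSize-byte chunks.
function planSlices(inputLength, maxInputPerPass, chunkSize) {
  const sliceSize = Math.floor(maxInputPerPass / chunkSize) * chunkSize;
  if (sliceSize === 0) throw new Error("chunkSize exceeds the per-pass input limit");
  const slices = [];
  for (let start = 0; start < inputLength; start += sliceSize) {
    // the last slice may be shorter, but it still starts on a chunk boundary
    slices.push({ start, length: Math.min(sliceSize, inputLength - start) });
  }
  return slices;
}
```

With a 7-byte limit and 3-byte chunks, a 10-byte input becomes two passes of 6 and 4 bytes.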

4. Single Submit and Zero-CPU Latency (Single Submit Command Encoding)

One of the most critical optimizations is in the asynchronous command communication between CPU and GPU:

  • The 3 separate passes (trie_tokenizer_chunked → trie_prefix_sum → trie_tokenizer_compact) are combined into a single CommandEncoder and sent to the GPU with a single submit (device.queue.submit).
  • This way the CPU never waits for the GPU during intermediate steps (zero CPU-GPU roundtrip). The GPU pushes the commands through the pipeline uninterrupted.
  • Smart DMA Copy: The size of the compacted output is not known before compaction, because each chunk produces a variable number of tokens. Rather than reading back a worst-case buffer, the host first reads the 4-byte total_tokens value computed on the GPU via mapAsync; the map resolves once the submitted passes finish, and only those 4 bytes cross to the host. It then issues an exactly sized copyBufferToBuffer that reads back only the real compacted token sequence, so no bandwidth is spent on unused slots.
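The submission pattern can be sketched with the standard WebGPU API (createCommandEncoder, beginComputePass, dispatchWorkgroups, copyBufferToBuffer, queue.submit, mapAsync). The resource bundle r, its field names, the use of bind group index 0 and the dispatch size of the first pass are hypothetical. The staging buffers are assumed to have MAP_READ | COPY_DST usage and the source buffers COPY_SRC. All three passes and the 4-byte copy go out in one submit; the only CPU wait is the mapAsync on the 4-byte total.

```javascript
// Record the three inference passes plus a 4-byte readback copy, submit once,
// then issue a second copy sized exactly from the GPU-computed token total.
async function runTokenizerPasses(device, r) {
  const encoder = device.createCommandEncoder();
  const passes = [
    [r.chunkedPipeline, r.chunkedBindGroup, r.chunkedGroups], // trie_tokenizer_chunked (size assumed)
    [r.prefixPipeline, r.prefixBindGroup, 1],                 // trie_prefix_sum: one invocation
    [r.compactPipeline, r.compactBindGroup, r.numChunks],     // trie_tokenizer_compact: 1 WG per chunk
  ];
  for (const [pipeline, bindGroup, workgroups] of passes) {
    const pass = encoder.beginComputePass();
    pass.setPipeline(pipeline);
    pass.setBindGroup(0, bindGroup);
    pass.dispatchWorkgroups(workgroups);
    pass.end();
  }
  encoder.copyBufferToBuffer(r.totalBuf, 0, r.totalStaging, 0, 4); // 4-byte total_tokens
  device.queue.submit([encoder.finish()]); // one submit: no CPU wait between passes

  await r.totalStaging.mapAsync(GPUMapMode.READ);
  const total = new Uint32Array(r.totalStaging.getMappedRange())[0];
  r.totalStaging.unmap();
  if (total === 0) return new Uint32Array(0);

  const bytes = total * 4; // exactly the compacted tokens, nothing more
  const copy = device.createCommandEncoder();
  copy.copyBufferToBuffer(r.compactBuf, 0, r.tokenStaging, 0, bytes);
  device.queue.submit([copy.finish()]);
  await r.tokenStaging.mapAsync(GPUMapMode.READ, 0, bytes);
  const tokens = new Uint32Array(r.tokenStaging.getMappedRange(0, bytes).slice(0)); // copy before unmap
  r.tokenStaging.unmap();
  return tokens;
}
```

The second submit is unavoidable in this pattern, because its copy size depends on a value the GPU only produces after the first submit finishes.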

Next

To return to the index: index.md.

WGSL kernel studies · an LLM from scratch on WebGPU · Built in Istanbul by Uğur Toprakdeviren.