embed_lookup and embed_backward — Token ID → Vector
File: 01_embedding.wgsl
Pipeline step: 0 (the first step of the forward pass, turning tokens into numeric vectors).
3 kernels: embed_lookup, embed_backward, embed_lookup_w16 (mixed precision variant).
Wait, what is this?
Say you've got the word "cat". A computer doesn't understand "cat" — it only understands numbers. The tokenizer handed "cat" an ID: 4213. But 4213 on its own is a dumb number — the fact that it's smaller than 4214 ("dog", say) doesn't mean "cat < dog"; the IDs are arbitrary.
That's where embedding comes in: it turns each token ID into a vector of numbers that actually carries meaning. 4213 → [0.2, -1.1, 0.7, …] (768 numbers, say). That vector is everything the model knows about "cat", squeezed into one row. As training goes on, the vectors for "cat" and "dog" drift closer together, while "cat" and "truck" drift apart.
So what does the kernel do? It's really just a giant lookup table. Picture a [vocab × 768] table where each row is one token's vector. embed_lookup literally pulls table[token_id] — array indexing, like grabbing a value out of a hash map. The only twist: it does this for thousands of tokens at once, on the GPU.
The backward side runs in reverse: we write the "nudge this token's vector in this direction to lower the loss" signal back into the table. One subtlety — if the same token shows up 10 times in a sentence, 10 threads try to write to the same row at once, so the writes have to be atomic (more below).
What Does It Do?
Forward: embed_lookup
The model's first forward step. For each token ID, it reads a row from the embedding table:
out[s, d] = table[tokens[s], d]
If tokens[s] >= vocab_size (corrupt input) → out[s, d] = 0.
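A CPU reference makes the lookup rule concrete. The sketch below is a plain-Python model of the kernel's semantics, not the WGSL itself; the name `embed_lookup_ref` and the flat-list layout are assumptions made for illustration.

```python
def embed_lookup_ref(tokens, table, d_model, vocab_size):
    """Reference model: out[s, d] = table[tokens[s], d], zero row for OOB ids."""
    out = []
    for t in tokens:
        if t >= vocab_size:
            # Out-of-range token: emit a zero row instead of reading past the table.
            out.extend([0.0] * d_model)
        else:
            # Row-major table: row t starts at offset t * d_model.
            out.extend(table[t * d_model:(t + 1) * d_model])
    return out


# Tiny table: vocab_size = 3, d_model = 2.
table = [0.1, 0.2,   # token 0
         1.1, 1.2,   # token 1
         2.1, 2.2]   # token 2
result = embed_lookup_ref([2, 0, 7], table, d_model=2, vocab_size=3)
```

Token 7 is outside the three-row vocab, so its output row is all zeros while the other two positions copy their table rows verbatim.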
Backward: embed_backward
Scatter-add gradient. Each thread, for one output position, adds into the gradient row of the source token. Because the same token can appear at multiple positions, an atomic CAS-add is required:
grad_table[tokens[s], d] += grad_out[s, d] ∀ (s, d)
Mixed-precision: embed_lookup_w16
Reads the table from f16 storage, output is f32. There is no f16 in backward — the gradient is always fp32 (numerical stability).
Mathematical Definition
Forward
out[s, d] = table[tokens[s], d]   if tokens[s] < vocab_size
out[s, d] = 0                     otherwise
Backward
∂L/∂table[t, d] = Σ_s [tokens[s] == t] · ∂L/∂out[s, d]
This is literally a "sparse scatter-add": if token t occurs at k distinct positions, then k distinct contributions are added into that row. Threads writing to the same slot race with each other → atomics are needed.
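The backward sum can be checked against a sequential reference, where no race exists and a plain `+=` is enough. This is a minimal sketch under assumptions: `embed_backward_ref` is a hypothetical name, and skipping out-of-range tokens is assumed to mirror the forward bounds check.

```python
def embed_backward_ref(tokens, grad_out, d_model, vocab_size):
    """Reference model: grad_table[tokens[s], d] += grad_out[s, d] for all (s, d)."""
    grad_table = [0.0] * (vocab_size * d_model)
    for s, t in enumerate(tokens):
        if t >= vocab_size:
            continue  # assumption: out-of-range tokens have no row to update
        for d in range(d_model):
            grad_table[t * d_model + d] += grad_out[s * d_model + d]
    return grad_table


# Token 1 appears at positions 0 and 2, so its row receives two contributions.
tokens = [1, 0, 1]
grad_out = [1.0, 2.0,    # position 0 (token 1)
            0.5, 0.5,    # position 1 (token 0)
            3.0, 4.0]    # position 2 (token 1)
grad_table = embed_backward_ref(tokens, grad_out, d_model=2, vocab_size=2)
```

Row 1 ends up holding the sum of positions 0 and 2, which is exactly the Σ_s term in the definition with k = 2.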
Bind Group ABI
embed_lookup (4 bindings)
| Binding | Type | Detail |
|---|---|---|
| 0 | storage, read | tokens: array<u32> — [seq_len] |
| 1 | storage, read | table: array<f32> — [vocab × d_model] row-major |
| 2 | storage, read_write | out: array<f32> — [seq × d_model] |
| 3 | uniform | dims: vec4<u32> — (seq_len, d_model, vocab_size, _) |
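On the host side, the dims uniform is just 16 bytes: four u32 lanes. The sketch below shows one way to build that payload; `pack_dims` is a hypothetical helper, writing 0 into the unused fourth lane is an assumption (the table only shows `_`), and little-endian byte order is assumed.

```python
import struct


def pack_dims(seq_len, d_model, vocab_size):
    """Pack (seq_len, d_model, vocab_size, 0) as four little-endian u32 values."""
    return struct.pack("<4I", seq_len, d_model, vocab_size, 0)


dims_bytes = pack_dims(512, 768, 16384)
```

The resulting buffer matches the vec4<u32> layout lane for lane: x = seq_len, y = d_model, z = vocab_size.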
embed_backward (4 bindings)
| Binding | Type | Detail |
|---|---|---|
| 0 | storage, read | tokens: array<u32> |
| 1 | storage, read | grad_out: array<f32> — gradient from norm/upstream |
| 2 | storage, read_write | grad_table: array<atomic<u32>> — bit-cast f32 |
| 3 | uniform | dims: vec4<u32> |
Caution: The grad_table declaration is array<atomic<u32>>. The data is actually f32, but WGSL has no atomic<f32>, hence the bit-cast trick (details below).
embed_lookup_w16 (4 bindings, _w16 suffixed)
| Binding | Type | Detail |
|---|---|---|
| 0 | storage, read | tokens_w16: array<u32> |
| 1 | storage, read | table_w16: array<f16> ← f16 storage |
| 2 | storage, read_write | out_w16: array<f32> ← output is still f32 |
| 3 | uniform | dims_w16: vec4<u32> |
Dispatch Shape
workgroup_size: 256
total threads: ceil(seq_len × d_model / 256) workgroups × 256
One thread = one (s, d) pair. That is, one thread for each element of the embedding output.
Example (seq=512, d=768): 393K threads, 1536 WG.
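The dispatch arithmetic is easy to get off by one, so here is a small sketch of it. `dispatch_shape` is a hypothetical helper; it only models the counting, not how the workgroups are spread across dispatch dimensions.

```python
WORKGROUP_SIZE = 256


def dispatch_shape(seq_len, d_model):
    """Return (workgroups, launched_threads, idle_threads) for one thread per (s, d)."""
    total = seq_len * d_model
    workgroups = (total + WORKGROUP_SIZE - 1) // WORKGROUP_SIZE  # ceil division
    launched = workgroups * WORKGROUP_SIZE
    return workgroups, launched, launched - total


shape_768 = dispatch_shape(512, 768)   # divides evenly, no idle threads
shape_odd = dispatch_shape(3, 100)     # 300 elements, last workgroup is partly idle
```

When seq_len × d_model is not a multiple of 256, the surplus threads in the last workgroup are the ones the bounds check has to turn away.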
Line by Line — embed_lookup
1) Bindings + entry
@compute @workgroup_size(256, 1, 1)
fn embed_lookup(@builtin(global_invocation_id) gid: vec3<u32>,
@builtin(num_workgroups) nwg: vec3<u32>) {
let i = flat_id(gid, nwg);
let seq_len = dims.x;
let d_model = dims.y;
let vocab_size = dims.z;
let total = seq_len * d_model;
if (i >= total) { return; }
- i is the global flat index, in [0, seq_len × d_model)
- dims.x/y/z are unpacked into named values; dims is passed as vec4<u32> because that satisfies 16-byte alignment
- Bounds check: the last WG can contain idle threads
2) (s, d) decode
let s = i / d_model;
let d = i % d_model;
Flat index → 2D (sequence position, dimension within embedding).
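Performance note: / d_model and % d_model share one integer division; compilers typically derive both the quotient and the remainder from it. Because d_model is read from the dims uniform at runtime, the compiler cannot turn the division into a bit-shift and mask on its own, even when d_model = 2^k (1024, but not 768). That rewrite needs the divisor to be a compile-time constant.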
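The decode, and the shift-and-mask form that is equivalent when d_model is a power of two, can be sketched as follows. Both function names are hypothetical; this only illustrates the index arithmetic, not what any particular shader compiler emits.

```python
def decode(i, d_model):
    """Flat index -> (sequence position s, embedding dimension d)."""
    return i // d_model, i % d_model


def decode_pow2(i, log2_d_model):
    """Same mapping when d_model == 2**log2_d_model, using shift and mask."""
    return i >> log2_d_model, i & ((1 << log2_d_model) - 1)


pair_768 = decode(769, 768)        # one full row of 768, plus one element
pair_1024 = decode_pow2(1025, 10)  # d_model = 1024 = 2**10
```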
3) Token lookup + bounds
let t = tokens[s];
if (t >= vocab_size) {
out[i] = 0.0;
return;
}
- If the token is out of range (e.g. a padding ID that is not in our vocab) → zero out the output.
- This is defensive: it should never trigger on a healthy corpus, but on a vocab fingerprint mismatch it guarantees the kernel never reads past the end of the table.
4) Table lookup
out[i] = table[t * d_model + d];
Row-major: table[t * d_model + d] is the d-th dim of token t.
Memory pattern: Adjacent threads write to out[i] and out[i+1] → coalesced. The table[t * d_model + d] read depends on t: when adjacent threads share the same s (only d varies) the read is coalesced, but when they sit on different s (and so possibly different t) the reads land in different rows and are scattered. In practice, the 256 threads of a WG take consecutive i in the range base..base+255 → mostly the same s, different d → coalesced.
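To see how often a workgroup really stays inside one row, the sketch below lists the sequence positions covered by one workgroup. It assumes flat_id hands each workgroup 256 consecutive flat indices; `rows_touched` is a hypothetical helper.

```python
WORKGROUP_SIZE = 256


def rows_touched(workgroup_index, d_model):
    """Sequence positions s covered by the 256 consecutive flat indices of one workgroup."""
    base = workgroup_index * WORKGROUP_SIZE
    first_s = base // d_model
    last_s = (base + WORKGROUP_SIZE - 1) // d_model
    return list(range(first_s, last_s + 1))


rows_768 = rows_touched(4, 768)   # 768 is a multiple of 256
rows_384 = rows_touched(1, 384)   # 384 is not, so a workgroup can straddle two rows
```

With d_model = 768 every workgroup reads from a single table row; with d_model = 384 some workgroups split their reads across two rows, which is still two contiguous runs rather than a fully scattered pattern.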
Line by Line — embed_backward (atomic CAS-add)
1) The first part is the same as embed_lookup
let i = flat_id(gid, nwg);
// ... bounds check, decode, token lookup
let val = grad_out[i];
if (val == 0.0 || !is_finite(val)) { return; }
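let dst_idx = t * d_model + d;
is_finite skips NaN/Inf gradients, so one bad upstream value cannot poison the accumulated sum. Exact zeros are skipped as well: adding them would change nothing, and returning early saves an atomic.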
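The effect of the early-out is easiest to see on a handful of values. A minimal sketch, assuming the filter behaves like Python's math.isfinite; `should_accumulate` is a hypothetical name.

```python
import math


def should_accumulate(val):
    """Mirror of the kernel's early-out: skip exact zeros and non-finite values."""
    return val != 0.0 and math.isfinite(val)


grads = [0.5, 0.0, float("nan"), -0.25, float("inf")]
filtered_sum = sum(g for g in grads if should_accumulate(g))
unfiltered_sum = sum(grads)
```

Without the filter a single NaN turns the whole sum into NaN; with it, only the two well-behaved values contribute.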
2) CAS-loop f32 atomic add
var old_bits = atomicLoad(&grad_table[dst_idx]);
loop {
let new_bits = bitcast<u32>(bitcast<f32>(old_bits) + val);
let res = atomicCompareExchangeWeak(&grad_table[dst_idx], old_bits, new_bits);
if (res.exchanged) { break; }
old_bits = res.old_value;
}
Why so complicated? WGSL's atomic functions only operate on i32 and u32. There is no float atomic add. The solution:
- Bit-cast the f32 as u32 (same 32-bit pattern, different interpretation)
- atomicLoad → old_bits (the current value as u32)
- Compute new_bits = bits(old_as_f32 + val), i.e. u32 → f32 cast → add → f32 → u32 cast
- atomicCompareExchangeWeak:
  - if the slot is still old_bits → write new_bits, exchanged = true
  - if another thread slipped in and changed it → exchanged = false, the real value is returned in res.old_value
- if exchanged, break out of the loop; otherwise retry with the new old_bits
This is a lock-free retry pattern. As long as there is no race, it finishes in a single iteration; if there is a race with k threads writing to the same slot, it degrades to O(k²) total attempts.
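The retry logic can be traced on the CPU. Python has no GPU atomics, so in this sketch the `Slot` class is a hypothetical stand-in for atomic<u32>, and the `interfere` callback fakes a competing thread that writes between the load and the exchange.

```python
import struct


def f32_to_bits(x):
    """Models bitcast<u32>(f32): same 32 bits, read as an unsigned int."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def bits_to_f32(b):
    """Models bitcast<f32>(u32): same 32 bits, read as a float."""
    return struct.unpack("<f", struct.pack("<I", b))[0]


class Slot:
    """One u32 cell with a compare-exchange, standing in for atomic<u32>."""

    def __init__(self):
        self.bits = 0  # all-zero bits are 0.0 as f32

    def compare_exchange(self, expected, new):
        old = self.bits
        if old == expected:
            self.bits = new
            return old, True
        return old, False


def atomic_add_f32(slot, val, interfere=None):
    """CAS-loop float add; returns how many exchange attempts were needed."""
    attempts = 0
    old_bits = slot.bits  # atomicLoad
    while True:
        new_bits = f32_to_bits(bits_to_f32(old_bits) + val)
        if interfere is not None:
            interfere(slot)  # a competing writer sneaks in exactly once
            interfere = None
        old_value, exchanged = slot.compare_exchange(old_bits, new_bits)
        attempts += 1
        if exchanged:
            return attempts
        old_bits = old_value  # retry against the value that is really there


slot = Slot()
first = atomic_add_f32(slot, 1.5)
second = atomic_add_f32(slot, 0.25, interfere=lambda s: atomic_add_f32(s, 2.0))
total = bits_to_f32(slot.bits)
```

The uncontended add finishes in one attempt. The contended one fails once, picks up the competitor's value from the failed exchange, and succeeds on the retry, so no contribution is lost.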
How much contention in practice?
- In a Turkish corpus, the most frequent token "▁ve" is at ~2% frequency. So in a 512-seq it occurs ~10 times.
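- Each occurrence is one thread per dimension → for every d, 10 threads write to the same slot.
- If one CAS wins per round, that is 10 + 9 + … + 1 = 55 attempts (45 of them failed) in the worst case, well inside the 10² bound; since the threads retry in parallel, the practical latency is ~10 cycles/atomic.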
- Not a significant bottleneck.
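A back-of-the-envelope model of the worst case: assume the k contending threads run in lockstep rounds and exactly one exchange succeeds per round. That assumption is a simplification (it ignores spurious failures of the weak exchange), and `worst_case_cas_attempts` is a hypothetical helper.

```python
def worst_case_cas_attempts(k):
    """Total CAS attempts when k threads hit one slot and exactly one wins per round."""
    attempts = 0
    pending = k
    while pending > 0:
        attempts += pending   # every pending thread tries once this round
        pending -= 1          # exactly one of them succeeds
    return attempts


attempts_10 = worst_case_cas_attempts(10)
failed_10 = attempts_10 - 10
```

Under this model the total is k(k+1)/2, which grows quadratically but stays below k².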
3) Why bitcast<f32>(0) is actually 0.0
In WGSL, bitcast<f32>(0u) → 0.0f (a zero-init grad_table means zero f32). fill_zero is used for init — zero f32 = zero u32, the bit-pattern is identical.
embed_lookup_w16 Difference
out_w16[i] = f32(table_w16[t * d_model + d]);
The only difference: table_w16 is array<f16> and there's an f32 cast on read. The table is half the size (25.2 MB → 12.6 MB for 16K×384), saving bandwidth.
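The bandwidth saving comes with a precision cost: what the forward pass reads is the f16-rounded weight widened back to f32. A minimal sketch of that round trip, using Python's half-precision struct format as a stand-in for the f16 storage; `round_to_f16` is a hypothetical helper.

```python
import struct


def round_to_f16(x):
    """Round a value to IEEE 754 half precision and widen it back to a Python float."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


exact = round_to_f16(0.5)     # small powers of two survive unchanged
rounded = round_to_f16(0.1)   # 0.1 is not representable, so it moves slightly
error = abs(rounded - 0.1)
```

This is also why embed_lookup_w16 and embed_lookup are not expected to agree bit for bit; a regression test would compare them within a tolerance.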
Why is there no backward for the _w16 version?
The backward gradient is always f32 (the mixed precision standard). If you added the f32 grad onto an f16 weight and wrote it back, you'd get precision loss. Instead:
- Keep an f32 master copy
- Make an f16 mirror for the forward pass (cast_f32_to_f16 kernel)
- The backward writes to the f32 master
That's why embed_lookup and embed_lookup_w16 are separate kernels, while there is only one embed_backward.
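The precision argument can be made concrete with one weight and many small updates. This sketch emulates f16 and f32 storage with Python's struct module; the step size and step count are made-up numbers for illustration, and the plain addition stands in for whatever the optimizer actually does.

```python
import struct


def to_f16(x):
    """Round to half precision (emulates storing into an f16 buffer)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_f32(x):
    """Round to single precision (emulates storing into an f32 buffer)."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


STEP = 1e-4    # hypothetical per-step weight update
STEPS = 100

w_f16_only = to_f16(1.0)   # weight kept only in f16
w_master = to_f32(1.0)     # f32 master copy
for _ in range(STEPS):
    w_f16_only = to_f16(w_f16_only + STEP)   # update is rounded away every time
    w_master = to_f32(w_master + STEP)       # update accumulates in f32

w_mirror = to_f16(w_master)  # f16 mirror refreshed from the master for the forward pass
```

Near 1.0 the f16 grid spacing is about 0.001, so each 1e-4 step rounds back to the starting value and the f16-only weight never moves. The f32 master reaches roughly 1.01, and the mirror cast from it reflects that progress.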
WGSL-Specific Notes
1. No atomic<f32>
Not in the spec. The CAS-bit-cast pattern is the known workaround. Some native APIs can do a float atomic add directly (Vulkan through the optional VK_EXT_shader_atomic_float extension, newer Metal versions through atomic<float>), but support is not uniform across the backends WebGPU targets, so the portable lowest common denominator is this.
2. bitcast<T>(x) — non-mutating
bitcast<u32>(1.0f) → 0x3f800000. Same 32 bits, different interpretation. Cost: none (a no-op on CPU/GPU).
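The same reinterpretation can be reproduced on the CPU with Python's struct module, which is handy when inspecting a grad_table readback. The helper names are hypothetical; they model the semantics of bitcast, not its cost.

```python
import struct


def bitcast_u32(x):
    """Reinterpret an f32 bit pattern as u32 (models bitcast<u32>(x))."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def bitcast_f32(bits):
    """Reinterpret a u32 bit pattern as f32 (models bitcast<f32>(bits))."""
    return struct.unpack("<f", struct.pack("<I", bits))[0]


one_bits = bitcast_u32(1.0)
zero_value = bitcast_f32(0)
neg_zero_bits = bitcast_u32(-0.0)
round_trip = bitcast_f32(bitcast_u32(-1.1))
```

Note that only positive zero has the all-zero pattern; negative zero has the sign bit set, so it compares equal to 0.0 as a float but not as raw bits.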
3. atomicLoad and atomicCompareExchangeWeak — pointer required
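atomicLoad(&grad_table[dst_idx])
The atomic built-ins take a pointer to the atomic as their first argument, so &grad_table[dst_idx] (a pointer into storage) is exactly what they expect. User-declared functions are more restricted: in baseline WGSL their pointer parameters are limited to the function and private address spaces, so a storage pointer like this one cannot be handed to a helper unless the unrestricted_pointer_parameters language feature is available.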
4. array<atomic<u32>> — type level wrap
The storage buffer must be marked atomic at declaration level:
@group(0) @binding(2) var<storage, read_write> grad_table: array<atomic<u32>>;
If you write this as array<u32>, atomicLoad/atomicCompareExchangeWeak will give a compile error.
5. loop { ... break; } — instead of a for-construct
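In WGSL, loop {} is an infinite loop, exited with break. WGSL also has for and while, but in a CAS retry the exit condition is only known in the middle of the body (after the exchange), so loop with an explicit break is the idiomatic form.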
Code Review
Finding 1: In the forward pass, out[i] = 0 is silent for OOB tokens
| Risk | Explanation |
|---|---|
| 🟡 medium | The tokens[s] >= vocab_size case does a silent zero-fill. If the corpus is corrupt, you can't tell which position it occurred at. Vocab fingerprint validation is done beforehand (engine.js:_checkVocabMatch), but it's good as a runtime safety net. |
Decision: Keep the defensive zeroing. The preflight check catches vocab mismatches.
Finding 2: is_finite filter only in backward
| Risk | Explanation |
|---|---|
| 🟢 none | In the forward pass, table[t][d] already comes from f32 storage, is zero at init, and is later written by AdamW — finite is guaranteed. In the backward pass, grad_out[i] comes from upstream, with a higher risk of NaN/Inf. Correct design. |
Finding 3: No atomic CAS failure metric
| Risk | Explanation |
|---|---|
| 🟢 none, but an observation | There is no debug counter to learn the actual CAS retry count. It could be hard to catch a performance regression; but since there's no noticeable slowdown in practice, skip it. |
Quick Checklist
| Test Scenario | Status |
|---|---|
| Token ID >= vocab_size → is output 0? | ✅ code check |
| Does the backward gradient skip NaN? | ✅ is_finite |
| Is there a coalesced read for vocab=16384, d=384? | ✅ memory pattern |
| When the same token occurs 10 times, is the sum correct? | ⚠ no unit test |
| Does embed_lookup_w16 produce the same result as embed_lookup? | ⚠ no regression test |
| Does it avoid crashing for an OOB token (vocab_size + 1)? | ✅ |
Next
02_norm.md: RMSNorm forward + backward. Normalizes the embedding output by its root mean square.