llm.istanbul·Study

embed_lookup and embed_backward — Token ID → Vector

File: 01_embedding.wgsl
Pipeline step: 0 (the first step of the forward pass, turning tokens into numeric vectors)

3 kernels: embed_lookup, embed_backward, embed_lookup_w16 (mixed precision variant).


Wait, what is this?

Say you've got the word "cat". A computer doesn't understand "cat" — it only understands numbers. The tokenizer handed "cat" an ID: 4213. But 4213 on its own is a dumb number — the fact that it's smaller than 4214 ("dog", say) doesn't mean "cat < dog"; the IDs are arbitrary.

That's where embedding comes in: it turns each token ID into a vector of numbers that actually carries meaning. 4213 → [0.2, -1.1, 0.7, …] (768 numbers, say). That vector is everything the model knows about "cat", squeezed into one row. As training goes on, the vectors for "cat" and "dog" drift closer together, while "cat" and "truck" drift apart.

So what does the kernel do? It's really just a giant lookup table. Picture a [vocab × 768] table where each row is one token's vector. embed_lookup literally pulls table[token_id] — array indexing, like grabbing a value out of a hash map. The only twist: it does this for thousands of tokens at once, on the GPU.

The backward side runs in reverse: we write the "nudge this token's vector in this direction to lower the loss" signal back into the table. One subtlety — if the same token shows up 10 times in a sentence, 10 threads try to write to the same row at once, so the writes have to be atomic (more below).


What Does It Do?

Forward: embed_lookup

The model's first forward step. For each token ID, it reads a row from the embedding table:

out[s, d] = table[tokens[s], d]

If tokens[s] >= vocab_size (corrupt input) → out[s, d] = 0.
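A minimal CPU sketch of the same forward rule can help when checking GPU output. It is not the kernel itself: the function name embed_lookup_ref and the toy table are made up for illustration, and the buffers are plain flat Python lists in row-major order.

```python
def embed_lookup_ref(tokens, table, d_model, vocab_size):
    """CPU reference for the forward lookup on flat, row-major lists."""
    out = [0.0] * (len(tokens) * d_model)
    for s, t in enumerate(tokens):
        if t >= vocab_size:
            continue  # out-of-range token: this output row stays zero
        row = t * d_model
        out[s * d_model:(s + 1) * d_model] = table[row:row + d_model]
    return out


# Toy table (hypothetical values): 3 tokens, d_model = 2.
table = [0.1, 0.2,   # token 0
         1.1, 1.2,   # token 1
         2.1, 2.2]   # token 2
out = embed_lookup_ref([2, 0, 7], table, d_model=2, vocab_size=3)
print(out)  # [2.1, 2.2, 0.1, 0.2, 0.0, 0.0]
```

Token 7 is outside the 3-token vocabulary, so its row comes back as zeros, matching the rule above.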

Backward: embed_backward

Scatter-add gradient. Each thread handles one (s, d) element and adds it into the gradient row of the source token tokens[s]. Because the same token can appear at multiple positions, several threads can target the same slot, so an atomic CAS-add is required:

grad_table[tokens[s], d] += grad_out[s, d]   ∀ (s, d)

Mixed-precision: embed_lookup_w16

Reads the table from f16 storage; the output is f32. There is no f16 in the backward pass: the gradient is always f32, for numerical stability.


Mathematical Definition

Forward

out[s, d] = table[tokens[s], d]    if tokens[s] < vocab_size
            0                        otherwise

Backward

∂L/∂table[t, d] = Σ_s [tokens[s] == t] · ∂L/∂out[s, d]

This is literally a "sparse scatter-add" — if token t occurs at k distinct positions in total, then k distinct contributions are written to that row. Threads writing to the same slot produce a race condition → atomics are needed.
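The sum over positions can be checked on the CPU with a sequential loop, where no atomics are needed. This is a sketch only: embed_backward_ref is a hypothetical name, and skipping out-of-range tokens is an assumption carried over from the forward rule.

```python
def embed_backward_ref(tokens, grad_out, d_model, vocab_size):
    """Sequential scatter-add: grad_table[tokens[s], d] += grad_out[s, d]."""
    grad_table = [0.0] * (vocab_size * d_model)
    for s, t in enumerate(tokens):
        if t >= vocab_size:
            continue  # assumption: out-of-range tokens contribute nothing
        for d in range(d_model):
            grad_table[t * d_model + d] += grad_out[s * d_model + d]
    return grad_table


tokens = [1, 0, 1]                  # token 1 appears at positions 0 and 2
grad_out = [1.0, 2.0,               # position 0
            10.0, 20.0,             # position 1
            0.5, 0.25]              # position 2
grad_table = embed_backward_ref(tokens, grad_out, d_model=2, vocab_size=3)
print(grad_table)  # [10.0, 20.0, 1.5, 2.25, 0.0, 0.0]
```

Row 1 receives two contributions (k = 2) and row 2 receives none, which is the sparsity the formula describes.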


Bind Group ABI

embed_lookup (4 bindings)

Binding | Type | Detail
0 | storage, read | tokens: array<u32>[seq_len]
1 | storage, read | table: array<f32>[vocab × d_model] row-major
2 | storage, read_write | out: array<f32>[seq × d_model]
3 | uniform | dims: vec4<u32>(seq_len, d_model, vocab_size, _)
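On the host side, the dims uniform is just four u32 values in a row. The sketch below shows one way to build those 16 bytes; pack_dims is a hypothetical helper, the little-endian layout is an assumption that holds on current WebGPU targets, and the example sizes are illustrative.

```python
import struct


def pack_dims(seq_len, d_model, vocab_size):
    """Pack (seq_len, d_model, vocab_size, unused) as four little-endian u32."""
    return struct.pack("<4I", seq_len, d_model, vocab_size, 0)


dims_bytes = pack_dims(512, 768, 16384)
print(len(dims_bytes))                    # 16
print(struct.unpack("<4I", dims_bytes))   # (512, 768, 16384, 0)
```

The fourth value is written as 0 here only because the kernel ignores it.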

embed_backward (4 bindings)

Binding | Type | Detail
0 | storage, read | tokens: array<u32>
1 | storage, read | grad_out: array<f32> (gradient from norm/upstream)
2 | storage, read_write | grad_table: array<atomic<u32>> (bit-cast f32)
3 | uniform | dims: vec4<u32>

Caution: The grad_table declaration is array<atomic<u32>>. The data is actually f32, but WGSL has no atomic<f32> — hence the bit-cast trick (details below).

embed_lookup_w16 (4 bindings, _w16 suffixed)

Binding | Type | Detail
0 | storage, read | tokens_w16: array<u32>
1 | storage, read | table_w16: array<f16> ← f16 storage
2 | storage, read_write | out_w16: array<f32> ← output is still f32
3 | uniform | dims_w16: vec4<u32>

Dispatch Shape

workgroup_size: 256
total threads:  ceil(seq_len × d_model / 256) workgroups × 256

One thread = one (s, d) pair. That is, one thread for each element of the embedding output.

Example (seq=512, d=768): 393K threads, 1536 WG.
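The arithmetic behind that example can be reproduced with a few lines. This is a sketch; dispatch_shape is a hypothetical helper, and the second call uses made-up sizes to show a case where the last workgroup has idle threads.

```python
def dispatch_shape(seq_len, d_model, workgroup_size=256):
    """Return (useful threads, workgroups, idle threads in the last workgroup)."""
    total = seq_len * d_model
    workgroups = (total + workgroup_size - 1) // workgroup_size  # ceil division
    idle = workgroups * workgroup_size - total
    return total, workgroups, idle


print(dispatch_shape(512, 768))  # (393216, 1536, 0)
print(dispatch_shape(5, 100))    # (500, 2, 12)
```

The idle threads in the second case are the ones the kernel's bounds check has to reject.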


Line by Line — embed_lookup

1) Bindings + entry

wgsl
@compute @workgroup_size(256, 1, 1)
fn embed_lookup(@builtin(global_invocation_id) gid: vec3<u32>,
                @builtin(num_workgroups) nwg: vec3<u32>) {
    let i = flat_id(gid, nwg);
    let seq_len = dims.x;
    let d_model = dims.y;
    let vocab_size = dims.z;
    let total = seq_len * d_model;
    if (i >= total) { return; }
  • i is the global flat index, in [0, seq×d_model)
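  • dims.x/y/z are unpacked into named values. dims is declared as vec4<u32> so the uniform occupies one 16-byte-aligned, 16-byte slot; the fourth component is unused.
  • Bounds check: the last workgroup can contain threads past the end of the output, and they return immediately.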

2) (s, d) decode

wgsl
let s = i / d_model;
let d = i % d_model;

Flat index → 2D (sequence position, dimension within embedding).

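Performance note: i / d_model and i % d_model share a single integer division; compilers typically derive the remainder from the quotient rather than dividing twice. When the divisor is a power of two (1024, but not 768) the pair reduces to a shift and a mask, but the compiler can only do that if it knows the value at compile time. Here d_model is read from the dims uniform at runtime, so the reduction applies only if d_model is baked into the shader as a constant.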
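The decode and its shift-and-mask equivalent can be compared directly. This is an illustrative sketch with hypothetical helper names; it shows only that the two forms agree when d_model is a power of two, not what any particular shader compiler emits.

```python
def decode(i, d_model):
    """Flat index -> (sequence position s, dimension d)."""
    return divmod(i, d_model)


def is_pow2(n):
    return n > 0 and (n & (n - 1)) == 0


def decode_pow2(i, d_model):
    """Shift-and-mask decode; valid only when d_model is a power of two."""
    shift = d_model.bit_length() - 1
    return i >> shift, i & (d_model - 1)


print(decode(769, 768))                 # (1, 1)
print(is_pow2(768), is_pow2(1024))      # False True
print(decode_pow2(1025, 1024))          # (1, 1)
```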

3) Token lookup + bounds

wgsl
let t = tokens[s];
if (t >= vocab_size) {
    out[i] = 0.0;
    return;
}
  • If the token is corrupt (e.g. a padding token not in our vocab) → zero out the output.
  • This is defensive: it should never trigger on a healthy corpus, but on a vocab fingerprint mismatch the kernel writes zeros instead of indexing past the end of the table.

4) Table lookup

wgsl
out[i] = table[t * d_model + d];

Row-major: table[t * d_model + d] is the d-th dim of token t.

Memory pattern: adjacent threads write out[i] and out[i+1], so the writes are coalesced. The read from table[t * d_model + d] depends on t: threads that share the same s (only d varies) read consecutive addresses within one table row, which is coalesced, while threads with different s read from different rows, which is scattered. In practice a workgroup's 256 threads take consecutive i in base..base+255, so they mostly share one s and differ only in d, and the read is coalesced.
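How many table rows a single workgroup touches depends on how d_model lines up with the workgroup size. The sketch below counts it on the CPU; the function name is hypothetical and the sizes are only examples.

```python
def max_rows_per_workgroup(seq_len, d_model, workgroup_size=256):
    """Largest number of distinct sequence positions s covered by one workgroup."""
    total = seq_len * d_model
    worst = 0
    for base in range(0, total, workgroup_size):
        last = min(base + workgroup_size, total) - 1
        worst = max(worst, last // d_model - base // d_model + 1)
    return worst


print(max_rows_per_workgroup(512, 768))  # 1
print(max_rows_per_workgroup(512, 384))  # 2
```

With d_model = 768 (exactly three workgroups per row) every workgroup stays inside one row; with d_model = 384 some workgroups straddle two rows, so their reads come from two separate contiguous runs.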


Line by Line — embed_backward (atomic CAS-add)

1) The first part is the same as embed_lookup

wgsl
let i = flat_id(gid, nwg);
// ... bounds check, decode, token lookup
let val = grad_out[i];
if (val == 0.0 || !is_finite(val)) { return; }
let dst_idx = t * d_model + d;

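The is_finite check skips NaN/Inf gradients, so a single bad upstream value cannot poison the accumulated table entry. Zero gradients are skipped as well, since adding 0.0 would cost an atomic operation for no effect.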

2) CAS-loop f32 atomic add

wgsl
var old_bits = atomicLoad(&grad_table[dst_idx]);
loop {
    let new_bits = bitcast<u32>(bitcast<f32>(old_bits) + val);
    let res = atomicCompareExchangeWeak(&grad_table[dst_idx], old_bits, new_bits);
    if (res.exchanged) { break; }
    old_bits = res.old_value;
}

Why so complicated? WGSL's atomic functions only operate on i32 and u32. There is no float atomic add. The solution:

  1. Bit-cast the f32 as u32 (same 32-bit pattern, different interpretation)
  2. atomicLoad → old_bits (the current value as u32)
  3. Compute new_bits = bits(old_as_f32 + val), i.e. u32 → f32 cast → add → f32 → u32 cast
  4. atomicCompareExchangeWeak:
    • if the slot is still old_bits → write new_bits, exchanged = true
    • if another thread slipped in and changed it → exchanged = false, the real value is returned in res.old_value
  5. if exchanged, break out of the loop; otherwise retry with the new old_bits

This is a lock-free retry pattern. As long as there is no race, it finishes in a single iteration; if there is a race with k threads writing to the same slot, it degrades to O(k²) total attempts.
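The retry pattern can be traced on the CPU with a toy single-threaded model. Everything here is a stand-in: AtomicU32 and atomic_add_f32 are hypothetical names, the "other thread" is simulated by a callback that runs just before the first compare-exchange, and f32 rounding is emulated with struct.

```python
import struct


def f32_to_bits(x):
    return struct.unpack("<I", struct.pack("<f", x))[0]


def bits_to_f32(b):
    return struct.unpack("<f", struct.pack("<I", b))[0]


class AtomicU32:
    """Toy stand-in for one atomic<u32> slot (no real concurrency)."""
    def __init__(self, bits=0):
        self.bits = bits

    def load(self):
        return self.bits

    def compare_exchange(self, expected, new):
        old = self.bits
        if old == expected:
            self.bits = new
            return old, True
        return old, False


def atomic_add_f32(slot, val, interfere=None):
    """CAS-loop float add; returns the number of CAS attempts it needed."""
    attempts = 0
    old_bits = slot.load()
    while True:
        new_bits = f32_to_bits(bits_to_f32(old_bits) + val)
        if interfere is not None:
            interfere(slot)      # simulated rival write before our CAS
            interfere = None     # interfere only once
        attempts += 1
        old, exchanged = slot.compare_exchange(old_bits, new_bits)
        if exchanged:
            return attempts
        old_bits = old           # retry from the value that actually won


slot = AtomicU32()                               # zero bits == 0.0
first = atomic_add_f32(slot, 1.5)                # no race: 1 attempt
second = atomic_add_f32(slot, 0.25,
                        interfere=lambda s: atomic_add_f32(s, 2.0))
print(first, second, bits_to_f32(slot.load()))   # 1 2 3.75
```

The second call loses its first exchange to the simulated rival, retries from 3.5, and no contribution is lost.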

How much contention in practice?

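  • In a Turkish corpus, the most frequent token "▁ve" has a frequency of ~2%, so in a 512-token sequence it occurs ~10 times.
  • Each occurrence contributes one thread per dimension d, so ~10 threads write to the same slot of that row.
  • If every retry collides, those 10 threads need 10 + 9 + … + 1 = 55 CAS attempts (45 of them failed), which is where the O(k²) growth comes from. Because the threads retry in parallel, the practical latency is ~10 cycles/atomic.
  • Not a significant bottleneck.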
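A simple worst-case model makes the contention estimate concrete: assume that in every round all pending threads attempt the exchange and exactly one wins. The helper below is hypothetical and ignores spurious failures of the weak exchange.

```python
def worst_case_cas_attempts(k):
    """Total CAS attempts if k threads collide and one succeeds per round."""
    attempts = 0
    pending = k
    while pending > 0:
        attempts += pending   # every pending thread tries once this round
        pending -= 1          # exactly one of them succeeds
    return attempts


seq_len, frequency = 512, 0.02
k = round(seq_len * frequency)          # expected occurrences of the token
total = worst_case_cas_attempts(k)
print(k, total, total - k)              # 10 55 45
```

Under this model the total is k(k+1)/2, which grows quadratically but stays below k² = 100 for k = 10.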

3) Why bitcast<f32>(0) is actually 0.0

In WGSL, bitcast<f32>(0u) = 0.0f, so a zero-initialized grad_table already holds zero f32 values. fill_zero is used for init: zero f32 and zero u32 have the identical bit pattern.
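The same reinterpretation can be checked on the CPU with struct, which views identical bytes as either u32 or f32. This is only a sketch of the bit-level fact; the 16-byte buffer is an arbitrary example.

```python
import struct

zeroed = bytes(16)                               # a zero-filled 16-byte buffer
as_u32 = struct.unpack("<4I", zeroed)
as_f32 = struct.unpack("<4f", zeroed)
print(as_u32)  # (0, 0, 0, 0)
print(as_f32)  # (0.0, 0.0, 0.0, 0.0)

one_bits = struct.unpack("<I", struct.pack("<f", 1.0))[0]
neg_zero_bits = struct.unpack("<I", struct.pack("<f", -0.0))[0]
print(hex(one_bits))       # 0x3f800000
print(hex(neg_zero_bits))  # 0x80000000
```

Only positive zero is the all-zero pattern; -0.0 has the sign bit set, which does not matter for a zero-filled buffer.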


embed_lookup_w16 Difference

wgsl
out_w16[i] = f32(table_w16[t * d_model + d]);

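The only difference: table_w16 is array<f16> and there's an f32 cast on read. The table is half the size (25.2 MB → 12.6 MB for 16K×384), saving bandwidth. Note that array<f16> requires the enable f16; directive in the shader and the shader-f16 device feature.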
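The storage saving and the rounding it costs can both be estimated on the CPU. In this sketch, table_bytes and round_trip_f16 are hypothetical helpers, and struct's "e" format stands in for f16 storage.

```python
import struct


def table_bytes(vocab_size, d_model, bytes_per_element):
    return vocab_size * d_model * bytes_per_element


def round_trip_f16(x):
    """Store x as f16 and read it back as a Python float."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


f32_size = table_bytes(16384, 384, 4)
f16_size = table_bytes(16384, 384, 2)
print(f32_size, f16_size)     # 25165824 12582912
print(round_trip_f16(0.1))    # 0.0999755859375
```

A 16384 × 384 table therefore takes about 25.2 MB as f32 and about 12.6 MB as f16, at the price of roughly three decimal digits of precision per weight.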

Why is there no backward for the _w16 version?

The backward gradient is always f32 (the mixed precision standard). If you added the f32 grad onto an f16 weight and wrote it back, you'd get precision loss. Instead:

  1. Keep an f32 master copy
  2. Make an f16 mirror for the forward pass (cast_f32_to_f16 kernel)
  3. The backward writes to the f32 master

That's why embed_lookup and embed_lookup_w16 are separate kernels, while a single embed_backward serves both.
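A small numeric experiment shows why the updates go to the f32 master. It is a sketch under made-up numbers: a weight of 1.0 receives ten updates of 1e-4, once accumulated directly in f16 and once in an f32 master that is mirrored to f16 afterwards. The helpers to_f16 and to_f32 emulate storage rounding with struct.

```python
import struct


def to_f16(x):
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_f32(x):
    return struct.unpack("<f", struct.pack("<f", x))[0]


step = 1e-4                      # hypothetical per-step update
master = to_f32(1.0)             # f32 master copy
f16_only = to_f16(1.0)           # weight kept and updated only in f16

for _ in range(10):
    master = to_f32(master + step)
    f16_only = to_f16(f16_only + step)   # rounds back to 1.0 every time

mirror = to_f16(master)          # f16 mirror rebuilt from the master
print(f16_only)  # 1.0
print(mirror)    # 1.0009765625
```

Near 1.0 the spacing between f16 values is about 0.001, so each 1e-4 update vanishes when added in f16, while the f32 master accumulates them and the mirror eventually moves.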


WGSL-Specific Notes

1. No atomic<f32>

atomic<f32> is not in the spec; the CAS bit-cast pattern is the known workaround. Some native APIs do offer a float atomic add, but only as an optional or hardware-dependent feature (Vulkan through the VK_EXT_shader_atomic_float extension, recent Metal versions through atomic<float>), so WebGPU exposes only the lowest common denominator.

2. bitcast<T>(x) — non-mutating

bitcast<u32>(1.0f) = 0x3f800000. Same 32 bits, different interpretation. Cost: none (a no-op on CPU/GPU).

3. atomicLoad and atomicCompareExchangeWeak — pointer required

wgsl
atomicLoad(&grad_table[dst_idx])

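The atomic built-ins take a pointer to the atomic as their first argument, which is why &grad_table[dst_idx] is passed here. User-defined functions are more restricted: unless the unrestricted_pointer_parameters language extension is available, their pointer parameters must be in the function or private address space, so a storage pointer like this one could not be handed to a helper. The built-in atomic functions are not subject to that restriction.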

4. array<atomic<u32>> — type level wrap

The storage buffer must be marked atomic at declaration level:

wgsl
@group(0) @binding(2) var<storage, read_write> grad_table: array<atomic<u32>>;

If you write this as array<u32>, atomicLoad and atomicCompareExchangeWeak will fail shader compilation.

5. loop { ... break; } — instead of a for-construct

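In WGSL, loop {} is an infinite loop, exited with break. WGSL also has for and while, but a CAS retry only learns whether it can exit in the middle of the body (after the exchange), so loop with an explicit break is the idiomatic form.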


Code Review

Finding 1: In the forward pass, out[i] = 0 is silent for OOB tokens

Risk | Explanation
🟡 medium | The tokens[s] >= vocab_size case does a silent zero-fill. If the corpus is corrupt, you can't tell which position it occurred at. Vocab fingerprint validation is done beforehand (engine.js:_checkVocabMatch), but it's good as a runtime safety net.

Decision: Keep the defensive zeroing. The preflight check catches vocab mismatches.

Finding 2: is_finite filter only in backward

Risk | Explanation
🟢 none | In the forward pass, table[t][d] already comes from f32 storage, is zero at init, and is later written by AdamW, so finite values are guaranteed. In the backward pass, grad_out[i] comes from upstream, with a higher risk of NaN/Inf. Correct design.

Finding 3: No atomic CAS failure metric

Risk | Explanation
🟢 none, but an observation | There is no debug counter for the actual CAS retry count. That could make a performance regression hard to catch, but since there's no noticeable slowdown in practice, skip it.

Quick Checklist

Test Scenario | Status
Token ID >= vocab_size → is output 0? | ✅ code check
Does the backward gradient skip NaN? | is_finite
Is there a coalesced read for vocab=16384, d=384? | ✅ memory pattern
When the same token occurs 10 times, is the sum correct? | ⚠ no unit test
Does embed_lookup_w16 produce the same result as embed_lookup? | ⚠ no regression test
Does it avoid crashing for an OOB token (vocab_size + 1)? |

Next

02_norm.md: RMSNorm forward + backward. Normalizes the embedding output by its root mean square.

WGSL kernel studies · an LLM from scratch on WebGPU. Built in Istanbul by Uğur Toprakdeviren.