cast_f32_to_f16 and cast_f16_to_f32 — Mixed Precision Conversion
File: 08_cast.wgsl
Pipeline step: runs on the optimizer step output (master fp32 → forward fp16 mirror).
Two kernels:
- cast_f32_to_f16: saturating clamp + cast (writes the fp16 mirror of the forward weights)
- cast_f16_to_f32: plain cast
Wait, what is this?
Think of it like this: you've got a RAW photo — huge file, but every detail is intact. While you're editing you always work off that RAW: lossless, solid. But when it's time to put the photo on the web you don't ship that giant file, you export a small JPEG from it. The JPEG loses a bit of detail, but it's half the size and loads way faster. This file does exactly that JPEG-exporting job.
Here RAW = f32 (32-bit, 4 bytes), JPEG = f16 (16-bit, 2 bytes). While the model trains, the "master" copy of the weights always lives in fp32, because you need the precision to accumulate tiny gradient updates. But reading those giant fp32 values on every forward pass chokes the memory bandwidth. So alongside it we keep an fp16 "mirror" — half the size, twice as fast to read. After every optimizer step, cast_f32_to_f16 refreshes that fp16 mirror from the fp32 master.
The cast itself is almost insultingly simple — take a number, re-write it so it fits in 2 bytes. It's no different from dropping a double down to a float. The one extra subtlety: fp16 has a ceiling value it can hold (65504). Cast a number above that and it turns into infinity (Inf) — so first we clamp the value down to that ceiling, much like stopping an overexposed photo from blowing out to pure white.
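To make "clamp, then squeeze into 2 bytes" concrete, here is a CPU-side sketch in JavaScript of what cast_f32_to_f16 computes for one element. It is a sketch under assumptions: the narrowing step follows IEEE 754 round-to-nearest-even (the GPU's exact rounding may differ), and f32ToF16Bits / castF32ToF16Saturated are hypothetical helpers, not functions from engine.js.

```javascript
// CPU reference sketch of the saturating f32 -> f16 cast (hypothetical helpers).
// Assumes IEEE 754 round-to-nearest-even; real GPU rounding may differ.
const F16_MAX = 65504;

// Round a number to f32, then convert it to raw binary16 bits.
function f32ToF16Bits(value) {
  const view = new DataView(new ArrayBuffer(4));
  view.setFloat32(0, value);
  const bits = view.getUint32(0);
  const sign = (bits >>> 16) & 0x8000;
  const exp = (bits >>> 23) & 0xff;
  const mant = bits & 0x7fffff;

  if (exp === 0xff) return sign | 0x7c00 | (mant ? 0x200 : 0); // Inf / NaN
  const e = exp - 127;                 // unbiased exponent
  if (e > 15) return sign | 0x7c00;    // too large: overflows to ±Inf
  if (e >= -14) {                      // normal f16 range
    let h = ((e + 15) << 10) | (mant >>> 13);
    const rem = mant & 0x1fff;         // the 13 mantissa bits being dropped
    if (rem > 0x1000 || (rem === 0x1000 && (h & 1))) h += 1; // ties to even
    return sign | h;                   // a carry may roll over into Inf, as in IEEE
  }
  if (e < -25) return sign;            // below half the smallest subnormal: ±0
  const full = mant | 0x800000;        // restore the implicit leading 1
  const shift = -e - 1;                // 14..24
  let h = full >>> shift;              // subnormal mantissa
  const rem = full & ((1 << shift) - 1);
  const half = 1 << (shift - 1);
  if (rem > half || (rem === half && (h & 1))) h += 1;
  return sign | h;
}

// What the kernel does per element: clamp to ±65504, then narrow.
function castF32ToF16Saturated(x) {
  return f32ToF16Bits(Math.min(Math.max(x, -F16_MAX), F16_MAX));
}

console.log(f32ToF16Bits(70000).toString(16));          // 7c00 (+Inf)
console.log(castF32ToF16Saturated(70000).toString(16)); // 7bff (65504)
```

Without the clamp, 70000 becomes the Inf pattern 0x7c00; with it, the mirror holds 0x7bff, the largest finite f16.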
The reverse direction (cast_f16_to_f32) is totally carefree — going from a small box to a big box loses nothing, like reopening the JPEG. The genuinely interesting part is how this trivial copy turns into a performance headache once it's called 110 separate times — we get to that below.
What Does It Do?
In the mixed precision pipeline the master weights are kept as fp32 (numerical stability — Adam moments + small gradient updates). For the forward pass an fp16 mirror is used (memory bandwidth savings).
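Optimizer step (original flow; the standalone cast has since been fused into the AdamW kernel, see the hot-spot section below):
- AdamW updates the master weights in fp32
- cast_f32_to_f16 runs for each parameter tensor → updates the fp16 mirror
- The next forward pass reads from the fp16 mirror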
cast_f16_to_f32 is symmetric — fp16 → fp32. It is actually not used in our hot path (the forward kernels do an inline cast during tile-load); it stays around as a standalone helper.
Why Saturated Clamp (±65504)?
f16 max = 65504 (= 2^15 × (2 - 2^-10)). A larger value turns into ±Inf during the cast.
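The 65504 figure follows from the IEEE 754 binary16 layout: 1 sign bit, 5 exponent bits with bias 15, and 10 mantissa bits. The largest finite value uses the highest non-reserved exponent field (30, since 31 encodes Inf/NaN) and an all-ones mantissa:

$$
x_{\max} = 2^{30-15}\left(1 + \frac{1023}{1024}\right) = 2^{15}\left(2 - 2^{-10}\right) = 65536 - 32 = 65504
$$

At the top of the range the spacing between neighbouring f16 values is $2^{15-10} = 32$, so under IEEE round-to-nearest-even only magnitudes of $65504 + 32/2 = 65520$ or more overflow to Inf.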
In healthy training the weights stay |w| < 5 — the clamp never triggers. But:
- Gradient explosion (rare but it happens)
- Adam state corruption
- Numerical instability spike
→ a large value can appear in the master fp32 → naive cast → fp16 Inf → forward NaN propagation.
clamp(±65504) is a defensive net. Cost: 1 instruction per element. It prevents the forward pass from being contaminated with NaN — worth it.
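Why one Inf weight is enough to contaminate the forward pass can be seen with plain IEEE arithmetic, which JavaScript numbers follow. The dot product and softmax below are hypothetical stand-ins for the forward kernels, not the engine's code:

```javascript
// One Inf weight times a zero activation already yields NaN (Inf * 0 = NaN),
// and an Inf logit turns softmax into NaN via Inf - Inf.
function dot(w, x) {
  let acc = 0;
  for (let i = 0; i < w.length; i++) acc += w[i] * x[i];
  return acc;
}

function softmax(logits) {
  const m = Math.max(...logits);                   // usual max-subtraction
  const exps = logits.map((z) => Math.exp(z - m)); // exp(Inf - Inf) = NaN
  const sum = exps.reduce((a, b) => a + b, 0);     // NaN poisons the sum
  return exps.map((e) => e / sum);
}

const weights = [Infinity, 0.5];   // a weight that overflowed the naive f16 cast
const activations = [0, 2];        // an activation that happens to be exactly 0

console.log(dot(weights, activations));      // NaN
console.log(softmax([Infinity, 1.0]));       // [ NaN, NaN ]
console.log(dot([65504, 0.5], activations)); // 1 (the clamped weight stays finite)
```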
Bind Group ABI
cast_f32_to_f16 (3 bindings)
| Binding | Type | Detail |
|---|---|---|
| 0 | storage, read | src: array<f32> |
| 1 | storage, read_write | dst: array<f16> |
| 2 | uniform | n: u32 |
cast_f16_to_f32
Type reversed; the rest is the same.
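On the host side, a bind group layout matching this ABI could be described as below. This is a sketch: the helper name is hypothetical and engine.js may build its layouts differently, but binding, visibility and buffer.type are the standard WebGPU GPUBindGroupLayoutDescriptor fields. The literal 0x4 is the value of GPUShaderStage.COMPUTE, used so the snippet also runs outside a browser. The same layout fits cast_f16_to_f32, since only the element types change.

```javascript
// Hypothetical helper: bind group layout descriptor for the cast kernels.
// In a browser the result would be passed to device.createBindGroupLayout(...).
const COMPUTE = 0x4; // value of GPUShaderStage.COMPUTE

function castLayoutDescriptor() {
  return {
    label: "cast layout",
    entries: [
      { binding: 0, visibility: COMPUTE, buffer: { type: "read-only-storage" } }, // src (var<storage, read>)
      { binding: 1, visibility: COMPUTE, buffer: { type: "storage" } },           // dst (var<storage, read_write>)
      { binding: 2, visibility: COMPUTE, buffer: { type: "uniform" } },           // n: u32
    ],
  };
}

console.log(castLayoutDescriptor().entries.map((e) => e.buffer.type));
// [ 'read-only-storage', 'storage', 'uniform' ]
```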
Dispatch Shape
workgroup_size: 256
threads: ceil(n / 256) workgroups × 256
1 thread = 1 element. Memory-bound, no shared memory, no barriers.
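WebGPU caps each dispatch dimension at maxComputeWorkgroupsPerDimension (65535 by default), which is presumably why the kernel goes through flat_id for a 2D fallback on large tensors. The sketch below computes the workgroup counts; dispatchDims and flatId are hypothetical names, and the row-major flattening in flatId is an assumption about what flat_id does, not its actual source.

```javascript
// Sketch of dispatch sizing with a 2D fallback (hypothetical helpers).
const WORKGROUP_SIZE = 256;
const MAX_PER_DIM = 65535; // default maxComputeWorkgroupsPerDimension

// Returns [x, y] workgroup counts that cover at least n threads.
function dispatchDims(n) {
  const groups = Math.ceil(n / WORKGROUP_SIZE);
  if (groups <= MAX_PER_DIM) return [groups, 1];
  const y = Math.ceil(groups / MAX_PER_DIM);
  const x = Math.ceil(groups / y);
  return [x, y];
}

// Assumed equivalent of flat_id(gid, nwg) for workgroup_size(256, 1, 1):
// gid.x spans nwg.x * 256 threads per row, gid.y selects the row.
// Threads with an index >= n hit the kernel's early return.
function flatId(gidX, gidY, nwgX) {
  return gidY * nwgX * WORKGROUP_SIZE + gidX;
}

console.log(dispatchDims(1_000_000));  // [ 3907, 1 ]
console.log(dispatchDims(50_000_000)); // [ 65105, 3 ]
```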
Line by Line
cast_f32_to_f16
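```wgsl
// Declarations per the Bind Group ABI above (group index 0 assumed).
// enable f16; comes from the shared preamble; flat_id is defined outside this excerpt.
@group(0) @binding(0) var<storage, read> src: array<f32>;
@group(0) @binding(1) var<storage, read_write> dst: array<f16>;
@group(0) @binding(2) var<uniform> n: u32;

@compute @workgroup_size(256, 1, 1)
fn cast_f32_to_f16(@builtin(global_invocation_id) gid: vec3<u32>,
                   @builtin(num_workgroups) nwg: vec3<u32>) {
    let i = flat_id(gid, nwg);                       // linear index (2D dispatch fallback)
    if (i >= n) { return; }                          // tail guard for the last workgroup
    dst[i] = f16(clamp(src[i], -65504.0, 65504.0));  // saturate, then narrow
}
```
- flat_id: index flattening for the 2D dispatch fallback
- clamp: saturates to the f16 range
- f16(...): WGSL built-in conversion (only available with enable f16;)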
cast_f16_to_f32
```wgsl
dst[i] = f32(src[i]);
```
f16 → f32 is always safe (lossless). No clamp needed.
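The "always lossless" claim can be checked exhaustively on the CPU, because there are only 65536 f16 bit patterns. This sketch decodes each one with a hypothetical f16BitsToNumber helper (plain IEEE 754 binary16 decoding) and confirms that Math.fround leaves every finite value unchanged, i.e. every f16 value is exactly representable as f32:

```javascript
// Decode raw binary16 bits into a JS number (IEEE 754 half precision).
function f16BitsToNumber(h) {
  const sign = (h & 0x8000) ? -1 : 1;
  const exp = (h >>> 10) & 0x1f;
  const mant = h & 0x3ff;
  if (exp === 0) return sign * mant * 2 ** -24;          // zero / subnormal
  if (exp === 0x1f) return mant ? NaN : sign * Infinity; // Inf / NaN
  return sign * (1 + mant / 1024) * 2 ** (exp - 15);     // normal
}

// Every finite f16 value must survive conversion to f32 unchanged.
let checked = 0;
for (let h = 0; h < 0x10000; h++) {
  const v = f16BitsToNumber(h);
  if (!Number.isFinite(v)) continue;
  if (!Object.is(Math.fround(v), v)) throw new Error(`not exact: 0x${h.toString(16)}`);
  checked++;
}
console.log(checked); // 63488 finite patterns, all exact in f32
```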
cast_f32_to_f16 Hot Spot — 110 Calls and Its Optimization
According to the old profiling output, a separate dispatch was issued for each parameter tensor:
cast_f32_to_f16 (×110)   8.585 ms total → ~1.7% of step
The 110 calls cover the parameter tensors: 12 layers × 9 weights per layer (attention Q/K/V/O + FFN gate/up/down + 2 norms), plus the outliers (embed, lm_head, final norm).
Each dispatch is small (~78 µs on average), but the fixed per-dispatch overhead adds up: ~110 dispatches × ~50 µs ≈ 5.5 ms, roughly two thirds of the 8.585 ms total.
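A quick back-of-the-envelope check of these numbers, using only the figures quoted from the profile (the 50 µs per-dispatch cost is the estimate from the text, not a measured constant):

```javascript
// Consistency check of the old profiling figures (inputs taken from the text).
const calls = 110;
const totalMs = 8.585;             // cast_f32_to_f16 total per step
const shareOfStep = 0.017;         // ~1.7% of step time
const overheadUsPerDispatch = 50;  // estimated fixed cost per dispatch

const avgUsPerDispatch = (totalMs * 1000) / calls;         // average time per dispatch
const overheadMs = (calls * overheadUsPerDispatch) / 1000; // accumulated dispatch overhead
const overheadFraction = overheadMs / totalMs;             // share of the cast time that is overhead
const impliedStepMs = totalMs / shareOfStep;               // step time implied by the 1.7% figure

console.log(avgUsPerDispatch.toFixed(1), overheadMs, overheadFraction.toFixed(2), impliedStepMs.toFixed(0));
// 78.0 5.5 0.64 505
```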
Completed Optimization — Fused F16 Mirroring:
To eliminate this dispatch overhead, the cast_f32_to_f16 step has been fused into the optimizer's adamw_update_f16 (or adamw_8bit_update_f16) kernel. While adamw_update_f16 updates the fp32 master weights, it also writes the forward fp16 mirror directly (into the dst_w16 buffer) in the same dispatch. This removes the standalone cast_f32_to_f16 loop and its 110 separate dispatches entirely.
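A CPU sketch of the fused idea: one pass over the parameters updates the fp32 master and immediately writes the clamped, rounded fp16 value into the mirror, so no second pass (and no second dispatch) is needed. Everything here is an assumption for illustration: the update rule is textbook AdamW with decoupled weight decay, the hyperparameters and field names are hypothetical (w16 stands in for the dst_w16 buffer), the mirror is a Float32Array holding f16-representable values rather than packed 16-bit storage, and roundToF16 emulates IEEE round-to-nearest-even.

```javascript
// Hypothetical CPU model of a fused AdamW step + fp16 mirror write.
const F16_MAX = 65504;

// Round to the nearest f16 value (ties to even), saturating at ±65504.
function roundToF16(x) {
  x = Math.min(Math.max(x, -F16_MAX), F16_MAX);
  if (x === 0 || Number.isNaN(x)) return x;
  const a = Math.abs(x);
  let e = Math.floor(Math.log2(a));
  if (2 ** e > a) e -= 1;             // guard against log2 rounding
  if (2 ** (e + 1) <= a) e += 1;
  const ulp = 2 ** (Math.max(e, -14) - 10); // f16 spacing; subnormals below 2^-14
  const q = a / ulp;                  // exact: ulp is a power of two
  let r = Math.round(q);
  if (q - Math.floor(q) === 0.5 && r % 2 === 1) r -= 1; // ties to even
  return Math.sign(x) * r * ulp;
}

function adamwStepFused(p, t, { lr = 1e-3, beta1 = 0.9, beta2 = 0.999, eps = 1e-8, wd = 0.01 } = {}) {
  const bc1 = 1 - beta1 ** t;         // bias corrections
  const bc2 = 1 - beta2 ** t;
  for (let i = 0; i < p.w32.length; i++) {
    const g = p.grad[i];
    p.m[i] = beta1 * p.m[i] + (1 - beta1) * g;
    p.v[i] = beta2 * p.v[i] + (1 - beta2) * g * g;
    let w = p.w32[i] * (1 - lr * wd); // decoupled weight decay
    w -= (lr * (p.m[i] / bc1)) / (Math.sqrt(p.v[i] / bc2) + eps);
    p.w32[i] = w;                     // fp32 master (Float32Array rounds to f32)
    p.w16[i] = roundToF16(p.w32[i]);  // fp16 mirror, written in the same pass
  }
}

const n = 3;
const params = {
  w32: Float32Array.from([0.1, -2, 1e5]), // 1e5 simulates a blown-up master weight
  grad: Float32Array.from([0.5, -0.5, 0]),
  m: new Float32Array(n), v: new Float32Array(n), w16: new Float32Array(n),
};
adamwStepFused(params, 1);
console.log(params.w16[2]); // 65504: the mirror saturates instead of becoming Inf
```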
WGSL-Specific Notes
1. enable f16; required
This file is loaded only when the shader-f16 feature is present. In engine.js:
```javascript
if (!adapter.features.has('shader-f16')) {
  throw new Error('...');
}
```
enable f16; is injected into the shared preamble → the f16 type is usable in every kernel. Safari has no shader-f16 → this file is not loaded and mixed precision mode is disabled.
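One detail the adapter check alone does not cover: WebGPU only enables the feature on the device if "shader-f16" is listed in requiredFeatures when calling adapter.requestDevice, and a shader module containing enable f16; fails validation on a device without it. A sketch of the check plus preamble injection follows; the helper names are hypothetical and engine.js may structure this differently.

```javascript
// Hypothetical helpers around the real WebGPU calls
// adapter.features.has(...) and adapter.requestDevice({ requiredFeatures }).
function f16DeviceDescriptor(adapter) {
  if (!adapter.features.has("shader-f16")) {
    throw new Error("shader-f16 is not available on this adapter");
  }
  // The feature must be requested explicitly, or the device will not have it.
  return { requiredFeatures: ["shader-f16"] };
}

// Prepend the directive once, in the shared preamble, so every kernel can use f16.
function withF16Preamble(wgslSource) {
  return "enable f16;\n" + wgslSource;
}

// Usage in a browser (inside an async function):
//   const device = await adapter.requestDevice(f16DeviceDescriptor(adapter));
//   const module = device.createShaderModule({ code: withF16Preamble(castSource) });
```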
2. clamp(NaN, -c, c)
The WGSL spec leaves the result unspecified for a NaN input. On Apple Metal it returns c (one of the bounds) in practice. It has not been an issue for us: gradient clipping and z_loss already prevent NaN.
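One way to see why the result is left unspecified: reasonable implementations of clamp disagree on NaN. The sketch contrasts JavaScript's Math.min/Math.max (NaN propagates) with fmin/fmax-style helpers that return the non-NaN operand (IEEE 754-2008 minNum/maxNum semantics), composed in two different orders. Which of these a given GPU follows is not something this sketch establishes.

```javascript
const C = 65504;

// NaN-propagating: Math.max/Math.min return NaN if either argument is NaN.
function clampPropagate(x, lo, hi) {
  return Math.min(Math.max(x, lo), hi);
}

// minNum/maxNum style: a NaN operand is ignored in favor of the other one.
const maxNum = (a, b) => (Number.isNaN(a) ? b : Number.isNaN(b) ? a : Math.max(a, b));
const minNum = (a, b) => (Number.isNaN(a) ? b : Number.isNaN(b) ? a : Math.min(a, b));

function clampLoFirst(x, lo, hi) { return minNum(maxNum(x, lo), hi); }
function clampHiFirst(x, lo, hi) { return maxNum(minNum(x, hi), lo); }

console.log(clampPropagate(NaN, -C, C)); // NaN
console.log(clampLoFirst(NaN, -C, C));   // -65504
console.log(clampHiFirst(NaN, -C, C));   // 65504
```

All three agree on every finite input; they only diverge on NaN, which is exactly the case the spec does not pin down.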
3. f16 storage layout
array<f16> is 16-bit storage with an fp16 max of 65504. Our weights stay below ~5, so there is ample headroom.
WGSL spec: f16 storage is packed (2 bytes per element, no padding). On Apple hardware it maps directly to Metal's half, with native 4-wide f16 vector support. Desktop NVIDIA and Intel GPUs also have hardware f16 support.
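Packed storage means element i occupies bytes 2i and 2i+1, and WGSL's memory layout is little-endian, so two consecutive f16 values share one 32-bit word with the even-indexed element in the low half. A practical consequence, sketched below with hypothetical helper names: WebGPU requires storage buffer bindings to be a multiple of 4 bytes, so an odd-length f16 array needs 2 bytes of padding.

```javascript
// Byte size for n packed f16 values, rounded up to a multiple of 4 bytes.
function f16BufferBytes(n) {
  return Math.ceil((n * 2) / 4) * 4;
}

// Pack raw f16 bit patterns into a little-endian byte buffer (element i at byte 2*i).
function packF16Bits(bits) {
  const buf = new ArrayBuffer(f16BufferBytes(bits.length));
  const view = new DataView(buf);
  bits.forEach((h, i) => view.setUint16(2 * i, h, true));
  return buf;
}

const buf = packF16Bits([0x3c00, 0x4000, 0xbc00]); // 1.0, 2.0, -1.0
const words = new DataView(buf);
console.log(buf.byteLength);                        // 8 (6 bytes of data + 2 padding)
console.log(words.getUint32(0, true).toString(16)); // 40003c00: element 0 in the low half
```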
Code Review
Finding 1: Clamp every element — wasted bandwidth?
| Risk | Description |
|---|---|
| 🟡 minor | In healthy training the clamp never triggers, yet every element still pays for 2 comparisons + 2 selects. This is marginal ALU overhead: the kernel is memory-bound, so the extra ALU work is hidden behind memory traffic. |
Finding 2: Standalone vs fused
| Risk | Description |
|---|---|
| 🟢 Resolved | The standalone cast step required 110 dispatches/step. This operation has been fused into adamw_update_f16, zeroing out the extra dispatch load and yielding a ~1.5-2.0% step time improvement. |
Finding 3: Is cast_f16_to_f32 used?
| Risk | Description |
|---|---|
| 🟢 none | Currently not called — the forward kernels do an inline f32(weight_w16) at tile-load. The standalone kernel stays around for future use. No bandwidth cost. |
Quick Checklist
| Test Scenario | Status |
|---|---|
| Does clamp(weight = 1000) clamp correctly? | ✅ formula ±65504 |
| Is f16(weight) → f32 lossless? | ✅ widening cast |
| Is f16(weight = 1.0) → f32 = 1.0 exact? | ✅ powers of 2 exact |
| Is the f16 mirror periodically synced with the master? | ✅ every step after adamw_update |
| Does it not load on Safari? | ✅ shader-f16 feature gate |
Next
09_backward_ffn.md — FFN activation backwards: GeLU + SwiGLU.