
cast_f32_to_f16 and cast_f16_to_f32 — Mixed Precision Conversion

File: 08_cast.wgsl
Pipeline step: runs at the output of the optimizer step (master fp32 → forward fp16 mirror).

Two kernels:

  • cast_f32_to_f16 — saturated clamp + cast (into the fp16 mirror of the forward weights)
  • cast_f16_to_f32 — plain cast

Wait, what is this?

Think of it like this: you've got a RAW photo — huge file, but every detail is intact. While you're editing you always work off that RAW: lossless, solid. But when it's time to put the photo on the web you don't ship that giant file, you export a small JPEG from it. The JPEG loses a bit of detail, but it's half the size and loads way faster. This file does exactly that JPEG-exporting job.

Here RAW = f32 (32-bit, 4 bytes), JPEG = f16 (16-bit, 2 bytes). While the model trains, the "master" copy of the weights always lives in fp32, because you need the precision to accumulate tiny gradient updates. But reading those giant fp32 values on every forward pass chokes the memory bandwidth. So alongside it we keep an fp16 "mirror" — half the size, twice as fast to read. After every optimizer step, cast_f32_to_f16 refreshes that fp16 mirror from the fp32 master.

The cast itself is almost insultingly simple — take a number, re-write it so it fits in 2 bytes. It's no different from dropping a double down to a float. The one extra subtlety: fp16 has a ceiling value it can hold (65504). Cast a number above that and it turns into infinity (Inf) — so first we clamp the value down to that ceiling, much like stopping an overexposed photo from blowing out to pure white.

The reverse direction (cast_f16_to_f32) is totally carefree — going from a small box to a big box loses nothing, like reopening the JPEG. The genuinely interesting part is how this trivial copy turns into a performance headache once it's called 110 separate times — we get to that below.


What Does It Do?

In the mixed precision pipeline the master weights are kept as fp32 (numerical stability — Adam moments + small gradient updates). For the forward pass an fp16 mirror is used (memory bandwidth savings).

Optimizer step:

  1. AdamW updates the master weights in fp32
  2. cast_f32_to_f16 runs for each parameter tensor → updates the fp16 mirror
  3. The next forward pass reads from the fp16 mirror

cast_f16_to_f32 is symmetric — fp16 → fp32. It is actually not used in our hot path (the forward kernels do an inline cast during tile-load); it stays around as a standalone helper.


Why Saturated Clamp (±65504)?

f16 max = 65504 (= 2^15 × (2 - 2^-10)). A larger value turns into ±Inf during the cast.

In healthy training the weights stay |w| < 5 — the clamp never triggers. But:

  • Gradient explosion (rare but it happens)
  • Adam state corruption
  • Numerical instability spike

→ a large value can appear in the master fp32 → naive cast → fp16 Inf → forward NaN propagation.

clamp(±65504) is a defensive net. Cost: 1 instruction per element. It prevents the forward pass from being contaminated with NaN — worth it.
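
The ceiling and the failure mode described above can be checked with plain arithmetic. The sketch below is host-side JavaScript, not the kernel itself: `F16_MAX` and `saturate` are illustrative names, and ordinary JS numbers stand in for GPU values.

```javascript
// Largest finite fp16 value: exponent 2^15, mantissa 1.1111111111b = 2 - 2^-10.
const F16_MAX = 2 ** 15 * (2 - 2 ** -10);

// Same shape as clamp(src[i], -65504.0, 65504.0) in the kernel.
function saturate(x) {
  return Math.min(Math.max(x, -F16_MAX), F16_MAX);
}

console.log(F16_MAX);        // 65504
console.log(saturate(3.5));  // 3.5   (healthy weight, untouched)
console.log(saturate(1e9));  // 65504 (blown-up weight, saturated)

// Why Inf in the mirror is worse than a saturated value:
// a single Inf weight times a zero activation is already NaN.
console.log(Infinity * 0);       // NaN
console.log(saturate(1e9) * 0);  // 0
```

A saturated weight is still wrong, but it stays finite, so the damage does not spread through the forward pass as NaN.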


Bind Group ABI

cast_f32_to_f16 (3 bindings)

| Binding | Type | Detail |
| --- | --- | --- |
| 0 | storage, read | src: array<f32> |
| 1 | storage, read_write | dst: array<f16> |
| 2 | uniform | n: u32 |

cast_f16_to_f32

The element types are reversed (src: array<f16>, dst: array<f32>); the bindings are otherwise the same.
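
As a rough illustration of how this ABI maps onto the WebGPU host API, the entries below describe a matching bind group layout. This is a sketch, not the engine's actual code: the function name `castBindGroupLayoutEntries` is hypothetical, and the constant 4 is the numeric value of `GPUShaderStage.COMPUTE`, spelled out so the snippet also runs outside a browser.

```javascript
const COMPUTE_STAGE = 0x4; // numeric value of GPUShaderStage.COMPUTE

// Hypothetical helper: layout entries for the 3-binding cast kernels.
// Both cast directions can share this layout; only the element types
// declared on the WGSL side differ.
function castBindGroupLayoutEntries() {
  return [
    { binding: 0, visibility: COMPUTE_STAGE, buffer: { type: 'read-only-storage' } }, // src
    { binding: 1, visibility: COMPUTE_STAGE, buffer: { type: 'storage' } },           // dst
    { binding: 2, visibility: COMPUTE_STAGE, buffer: { type: 'uniform' } },           // n
  ];
}

// In a browser:
//   device.createBindGroupLayout({ entries: castBindGroupLayoutEntries() })
console.log(castBindGroupLayoutEntries().map((e) => e.buffer.type));
```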


Dispatch Shape

workgroup_size: 256
threads:        ceil(n / 256) workgroups × 256

1 thread = 1 element. Memory-bound, no shared memory, no barriers.
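
The shape above can be computed on the host as in the sketch below. It also shows one plausible way to spill into a second grid dimension when the workgroup count exceeds 65535 (WebGPU's default `maxComputeWorkgroupsPerDimension`), which is the case the kernel's `flat_id` helper exists for. Both `dispatchShape` and the flattening formula in `flatId` are assumptions for illustration; the source does not show either.

```javascript
const WORKGROUP_SIZE = 256;
const MAX_GROUPS_PER_DIM = 65535; // WebGPU default maxComputeWorkgroupsPerDimension

// Hypothetical helper: [x, y] workgroup counts covering n elements.
function dispatchShape(n) {
  const groups = Math.ceil(n / WORKGROUP_SIZE);
  if (groups <= MAX_GROUPS_PER_DIM) return [groups, 1];
  const y = Math.ceil(groups / MAX_GROUPS_PER_DIM);
  const x = Math.ceil(groups / y);
  return [x, y];
}

// Assumed flattening: rows of (nwgX * 256) invocations stacked along y.
function flatId(gidX, gidY, nwgX) {
  return gidY * nwgX * WORKGROUP_SIZE + gidX;
}

console.log(dispatchShape(1000));       // [ 4, 1 ]
console.log(dispatchShape(20_000_000)); // [ 39063, 2 ]
console.log(flatId(0, 1, 4));           // 1024
```

With a 2D grid the product x × y can slightly overshoot the required count, which is why the kernel still needs its `i >= n` guard.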


Line by Line

cast_f32_to_f16

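```wgsl
@compute @workgroup_size(256, 1, 1)
fn cast_f32_to_f16(@builtin(global_invocation_id) gid: vec3<u32>,
                   @builtin(num_workgroups) nwg: vec3<u32>) {
    // Flatten the (possibly 2D) dispatch into one linear element index.
    let i = flat_id(gid, nwg);
    // Threads past the end of the tensor do nothing.
    if (i >= n) { return; }
    // Saturate to the finite f16 range, then narrow.
    dst[i] = f16(clamp(src[i], -65504.0, 65504.0));
}
```

  • flat_id: linear element index, covering the 2D dispatch fallback
  • clamp: saturates to the f16 range
  • f16(...): built-in WGSL conversion (only available with enable f16;)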

cast_f16_to_f32

```wgsl
dst[i] = f32(src[i]);
```

f16 → f32 is always safe (lossless). No clamp needed.
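
The lossless claim can be verified exhaustively, since fp16 has only 65536 bit patterns. The sketch below decodes each pattern by hand (the decoder `f16BitsToNumber` is written for this example, not taken from the project) and checks that rounding the value to f32 with `Math.fround` leaves it unchanged.

```javascript
// Decode an IEEE 754 binary16 bit pattern (0..0xFFFF) into a JS number.
function f16BitsToNumber(bits) {
  const sign = bits & 0x8000 ? -1 : 1;
  const exp = (bits >> 10) & 0x1f;
  const frac = bits & 0x3ff;
  if (exp === 0) return sign * frac * 2 ** -24;           // zero and subnormals
  if (exp === 0x1f) return frac ? NaN : sign * Infinity;  // NaN and Inf
  return sign * (1 + frac / 1024) * 2 ** (exp - 15);      // normal values
}

let finite = 0;
let lossless = 0;
for (let bits = 0; bits <= 0xffff; bits++) {
  const v = f16BitsToNumber(bits);
  if (!Number.isFinite(v)) continue;
  finite++;
  if (Math.fround(v) === v) lossless++; // f32 represents v exactly
}

console.log(finite, lossless);         // 63488 63488
console.log(f16BitsToNumber(0x7bff));  // 65504, the largest finite fp16
console.log(f16BitsToNumber(0x3c00));  // 1
```

Every finite fp16 value survives the trip to f32 unchanged, which is why the widening kernel needs no clamp.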


cast_f32_to_f16 Hot Spot — 110 Calls and Its Optimization

According to the old profiling output, a separate dispatch was done for each parameter tensor:

cast_f32_to_f16 (×110)  8.585 ms total  →  ~1.7% of step

110 = 12 layers × 9 weights per layer (att Q/K/V/O + FFN gate/up/down + 2 norms) + outliers (embed, lm_head, final norm).

Each dispatch is small (~78 µs on average, i.e. 8.585 ms / 110), but the per-dispatch overhead adds up: 110 dispatches × ~50 µs ≈ 5.5 ms.
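
The two figures quoted above follow directly from the profiler line. A quick check, using only numbers from the text (the ~50 µs per-dispatch overhead is the article's own estimate):

```javascript
const totalMs = 8.585;             // profiled total for cast_f32_to_f16
const dispatches = 110;            // one dispatch per parameter tensor
const overheadPerDispatchUs = 50;  // estimated fixed cost per dispatch

const avgPerDispatchUs = (totalMs * 1000) / dispatches;
const totalOverheadMs = (dispatches * overheadPerDispatchUs) / 1000;

console.log(avgPerDispatchUs.toFixed(1)); // 78.0
console.log(totalOverheadMs);             // 5.5
```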

Tip

Completed Optimization (Fused F16 Mirroring): To remove this dispatch load, the cast_f32_to_f16 step has been fused into the optimizer-stage kernel adamw_update_f16 (or adamw_8bit_update_f16). While that kernel updates the fp32 master weights, it also writes the forward fp16 mirror directly into the dst_w16 buffer within the same dispatch. The standalone cast_f32_to_f16 loop and its 110 separate dispatches have been eliminated entirely.
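
To make the fusion concrete, here is a CPU reference sketch of the idea in JavaScript: one loop performs the AdamW update on the fp32 master and writes the saturated mirror value in the same pass. It is not the actual adamw_update_f16 kernel. The function name, parameter names and hyperparameters are assumptions, the update is the textbook AdamW formula, and the mirror is a `Float32Array` stand-in, so rounding to fp16 precision is not modeled.

```javascript
const F16_MAX = 65504;

// Hypothetical CPU reference for a fused "update + mirror" pass.
// w, m, v: fp32 master weights and Adam moments; g: gradients;
// mirror: stand-in for the fp16 buffer (dst_w16 in the text); t: 1-based step.
function adamwFusedStep(w, m, v, g, mirror, t,
    { lr = 1e-3, beta1 = 0.9, beta2 = 0.999, eps = 1e-8, weightDecay = 0.0 } = {}) {
  const c1 = 1 - beta1 ** t; // bias corrections
  const c2 = 1 - beta2 ** t;
  for (let i = 0; i < w.length; i++) {
    m[i] = beta1 * m[i] + (1 - beta1) * g[i];
    v[i] = beta2 * v[i] + (1 - beta2) * g[i] * g[i];
    const update = (m[i] / c1) / (Math.sqrt(v[i] / c2) + eps);
    w[i] -= lr * (update + weightDecay * w[i]);
    // Fused step: write the saturated mirror value in the same pass,
    // instead of launching a separate cast dispatch afterwards.
    mirror[i] = Math.min(Math.max(w[i], -F16_MAX), F16_MAX);
  }
}

const w = new Float32Array([1.0, 1e9]); // second weight simulates a blow-up
const m = new Float32Array(2);
const v = new Float32Array(2);
const g = new Float32Array([1.0, 0.0]);
const mirror = new Float32Array(2);
adamwFusedStep(w, m, v, g, mirror, 1, { lr: 0.1 });
console.log(mirror[0].toFixed(3), mirror[1]); // 0.900 65504
```

The master keeps its full (here, blown-up) value while the mirror is saturated, which matches the split of responsibilities described above.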


WGSL-Specific Notes

1. enable f16; required

This file is loaded only when the shader-f16 feature is present. In engine.js:

```javascript
if (!adapter.features.has('shader-f16')) {
    throw new Error('...');
}
```

enable f16; is injected into the shared preamble, so the f16 type is usable in every kernel. Safari has no shader-f16, so there this file is not loaded and mixed precision mode is disabled.
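
A slightly fuller sketch of this gate, under the assumption that the preamble is a plain string placed in front of each kernel source. `supportsF16` and `withPreamble` are hypothetical helper names, not functions from engine.js; `adapter.features.has` and the `requiredFeatures` option of `requestDevice` are standard WebGPU.

```javascript
// True when the adapter exposes the WGSL f16 extension.
function supportsF16(adapter) {
  return adapter.features.has('shader-f16');
}

// Hypothetical preamble injection: prepend `enable f16;` once, so every
// kernel compiled from the result can use the f16 type.
function withPreamble(kernelSource, f16Enabled) {
  return f16Enabled ? `enable f16;\n${kernelSource}` : kernelSource;
}

// In a browser the device must also be requested with the feature:
//   const device = await adapter.requestDevice({ requiredFeatures: ['shader-f16'] });

// Mock adapters so the sketch runs without a GPU.
const withF16 = { features: new Set(['shader-f16']) };
const withoutF16 = { features: new Set() };
console.log(supportsF16(withF16), supportsF16(withoutF16));     // true false
console.log(withPreamble('fn main() {}', true).split('\n')[0]); // enable f16;
```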

2. clamp(NaN, -c, c)

The WGSL spec does not pin down the result when the clamped value is NaN. On Apple Metal the observed result is the bound c. So far this has not been an issue: gradient clipping and z_loss already keep NaN out of the weights.
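
Because the GPU result for a NaN input is not guaranteed, a host-side reference or debug check has to pick its own policy. The sketch below is one option, not something the project is known to do: it shows that the min/max formulation in JavaScript propagates NaN, and adds a hypothetical `countNonFinite` scan that could be run on a read-back copy of the master weights.

```javascript
const F16_MAX = 65504;

// min/max clamp: in JavaScript a NaN input stays NaN.
function clampMinMax(x) {
  return Math.min(Math.max(x, -F16_MAX), F16_MAX);
}

// Hypothetical debug helper: count NaN/Inf entries in a weight buffer.
function countNonFinite(values) {
  let count = 0;
  for (const x of values) {
    if (!Number.isFinite(x)) count++;
  }
  return count;
}

console.log(clampMinMax(NaN));                                       // NaN
console.log(clampMinMax(Infinity));                                  // 65504
console.log(countNonFinite(new Float32Array([0.5, NaN, Infinity]))); // 2
```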

3. f16 storage layout

array<f16> stores 16 bits per element, and the largest finite fp16 value is 65504. Our weights stay below ~5, so there is abundant headroom.

Per the WGSL spec, f16 storage is packed (2 bytes per element, no padding). On Apple Metal hardware it is identical to half, with native f16x4 instructions. NVIDIA and Intel desktop GPUs also have hardware support.
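
The packing translates directly into buffer sizes. The helpers below are a hypothetical sketch of the byte counts for a tensor of n elements. Rounding the mirror size up to a multiple of 4 is an assumption made here, because several WebGPU copy and write operations require 4-byte-aligned sizes; the source does not state how the engine sizes its buffers.

```javascript
// Bytes for the fp32 master: 4 per element.
function masterBytes(n) {
  return n * 4;
}

// Bytes for the fp16 mirror: 2 per element, padded to a multiple of 4.
function mirrorBytes(n) {
  return Math.ceil((n * 2) / 4) * 4;
}

const n = 768 * 768; // example tensor size, chosen arbitrarily
console.log(masterBytes(n)); // 2359296
console.log(mirrorBytes(n)); // 1179648
console.log(mirrorBytes(3)); // 8 (6 bytes of data + 2 bytes of padding)
```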


Code Review

Finding 1: Clamp every element — wasted bandwidth?

| Risk | Description |
| --- | --- |
| 🟡 minor | In healthy training no clamp ever triggers. Still, every element pays for 2 comparisons + 2 selects, a marginal ALU overhead on the GPU. Because the kernel is memory-bound, the extra ALU work is negligible in practice. |

Finding 2: Standalone vs fused

| Risk | Description |
| --- | --- |
| 🟢 Resolved | The standalone cast step required 110 dispatches/step. This operation has been fused into adamw_update_f16, zeroing out the extra dispatch load and yielding a ~1.5-2.0% step time improvement. |

Finding 3: Is cast_f16_to_f32 used?

| Risk | Description |
| --- | --- |
| 🟢 none | Currently not called: the forward kernels do an inline f32(weight_w16) at tile-load. The standalone kernel stays around for future use. No bandwidth cost. |

Quick Checklist

| Test Scenario | Status |
| --- | --- |
| Does clamp(weight = 1000) clamp correctly? | ✅ formula ±65504 |
| Is f16(weight) → f32 lossless? | ✅ widening cast |
| Is f16(weight = 1.0) → f32 = 1.0 exact? | ✅ powers of 2 exact |
| Is the f16 mirror periodically synced with the master? | ✅ every step after adamw_update |
| Does it not load on Safari? | shader-f16 feature gate |

Next

09_backward_ffn.md — FFN activation backwards: GeLU + SwiGLU.

WGSL kernel studies · an LLM from scratch on WebGPU. Built in Istanbul by Uğur Toprakdeviren.