rope_q and rope_k — Rotary Position Embedding (Forward)
File: 04_rope.wgsl
Pipeline step: right after the Q and K projections, right before attention.
Forward kernels:
- rope_q: rotation applied to the Query
- rope_k: rotation applied to the Key (with n_kv_heads, for GQA)
The backward kernels (rope_q_backward, rope_k_backward) are covered in 10_backward_attention.md, since they run together with the attention backward pass.
Wait, what is this?
Say you've got a sentence and the model needs to know where each word sits in it. "The dog bit the man" and "The man bit the dog" are the same words, but the order flips the meaning entirely. Here's the catch: attention dumps the word vectors into a bag and compares them against each other — and the bag has no notion of order. You have to somehow smuggle position into the vector itself.
RoPE's trick is this: picture each vector as the hand of a clock. The further along a token sits in the sentence, the more you spin that hand. Word 1 turns a little, word 5 turns more, word 50 turns a lot. The angle itself encodes the position — there's no separate position table, you embed the info by simply turning the vector.
The neat part shows up at comparison time. When attention lines up two words' clock hands (the dot product it computes), what matters isn't the absolute direction of either hand but the angle between them, i.e. how many positions apart they are. Spin both hands by the same extra amount and the gap between them stays the same. So the model picks up "these two words are 4 steps apart" for free, whether they sit at the start of the sentence or at the end.
Splitting the vector into pairs and spinning each pair at a different speed (some slow, some ticking fast) instead of using a single hand is the part that gets a bit dizzying, but the gist holds: position info = amount of rotation.
What Does It Do?
LLaMA-style rotary position embedding. It rotates each consecutive (2i, 2i+1) pair of the Q and K vectors with a 2D rotation whose angle depends on the token's position. This is how position information enters the model: the positional signal lives in the vector itself, with no extra embedding table.
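The math, following the standard RoPE formulation. For pair i inside a head (dimensions 2i and 2i+1) at position s:

$$\text{freq}_i = \text{base}^{-2i/\text{head\_dim}} \qquad \text{(frequency set by the pair index)}$$

$$\theta = (s + \text{pos\_offset}) \cdot \text{freq}_i \qquad \text{(angle set by the position)}$$

$$q'_{2i} = q_{2i}\cos\theta - q_{2i+1}\sin\theta$$

$$q'_{2i+1} = q_{2i}\sin\theta + q_{2i+1}\cos\theta$$

This is a 2D rotation matrix applied to the pair:

$$\begin{bmatrix} q'_{2i} \\ q'_{2i+1} \end{bmatrix} = \begin{bmatrix} \cos\theta & -\sin\theta \\ \sin\theta & \cos\theta \end{bmatrix} \begin{bmatrix} q_{2i} \\ q_{2i+1} \end{bmatrix}$$

The pairs i = 0, 1, ..., head_dim/2 − 1 each rotate at a different frequency:
- i = 0 (lowest dims): freq = 1, so the angle advances one radian per position. This is the highest-frequency, fastest-spinning pair.
- i = head_dim/2 − 1 (highest dims): freq = base^(−1 + 2/head_dim) ≈ 1/base. This is the lowest-frequency pair, which traverses the position axis slowly.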
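The per-pair formulas can be checked on the CPU. The sketch below is a Python mirror, not the WGSL kernel: it uses double precision instead of f32, and `rotate_pair` and `matvec2` are hypothetical helper names. It confirms that the explicit formulas match the 2×2 matrix product, that the rotation preserves the pair's length, and that the frequencies fall from 1 (pair 0) down to about 1/base (last pair).

```python
import math

def rope_freq(p: int, head_dim: int, base: float) -> float:
    # freq_p = base^(-2p / head_dim), a Python mirror of the kernel's helper
    return base ** (-2.0 * p / head_dim)

def rotate_pair(x0: float, x1: float, angle: float) -> tuple[float, float]:
    # Explicit per-pair formulas: (x[2i], x[2i+1]) -> rotated pair
    c, sn = math.cos(angle), math.sin(angle)
    return x0 * c - x1 * sn, x0 * sn + x1 * c

def matvec2(m, v):
    # 2x2 matrix times 2-vector, only used to cross-check rotate_pair
    return (m[0][0] * v[0] + m[0][1] * v[1], m[1][0] * v[0] + m[1][1] * v[1])

head_dim, base = 64, 10000.0
freqs = [rope_freq(p, head_dim, base) for p in range(head_dim // 2)]
print(freqs[0], freqs[-1])  # pair 0 spins fastest, the last pair slowest

angle = 7 * freqs[3]  # position s = 7, pair i = 3, pos_offset = 0
c, sn = math.cos(angle), math.sin(angle)
explicit = rotate_pair(0.5, -1.25, angle)
via_matrix = matvec2([[c, -sn], [sn, c]], (0.5, -1.25))
print(explicit, via_matrix)
```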
Thanks to this spectrum, the model captures relative position naturally inside the attention dot product:
$$Q_s \cdot K_{s'} = \big(R(s)\,q\big)^\top \big(R(s')\,k\big) = q^\top R(s'-s)\,k$$

where R(s) rotates every pair by its angle at position s. In other words, the Q·K score no longer depends on the absolute positions s and s', only on the unrotated vectors and the relative offset s' − s. Details: RoFormer paper.
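A quick numerical check of the relative-position property, as a Python sketch under simple assumptions: one head with head_dim = 8, base = 10000, random q and k, double precision, and a hypothetical `rope` helper. The same offset s' − s = 4 is placed at three different absolute positions, and the attention score comes out the same each time.

```python
import math
import random

def rope(vec, pos, base=10000.0):
    # Rotate every (2i, 2i+1) pair of vec by angle pos * base^(-2i/d).
    d = len(vec)
    out = list(vec)
    for i in range(d // 2):
        a = pos * base ** (-2.0 * i / d)
        c, sn = math.cos(a), math.sin(a)
        x0, x1 = vec[2 * i], vec[2 * i + 1]
        out[2 * i] = x0 * c - x1 * sn
        out[2 * i + 1] = x0 * sn + x1 * c
    return out

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

random.seed(0)
q = [random.uniform(-1, 1) for _ in range(8)]
k = [random.uniform(-1, 1) for _ in range(8)]

# Same relative offset (s' - s = 4) at three different absolute positions.
scores = [dot(rope(q, s), rope(k, s + 4)) for s in (0, 10, 100)]
print(scores)  # all three agree up to floating-point rounding
```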
Bind Group ABI
rope_q (3 bindings)
| Binding | Type | Detail |
|---|---|---|
| 0 | storage, read_write | Q: array<f32>, [seq_len × n_heads × head_dim] row-major, rotated in place |
| 1 | uniform | dims: vec4<u32> — (seq_len, n_heads, head_dim, pos_offset) |
| 2 | uniform | params: vec4<f32> — (rope_base, _, _, _) |
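To make the two uniform layouts concrete, here is a host-side sketch in Python that packs them into bytes. It assumes a little-endian upload (the byte order WebGPU buffers use); `pack_dims` and `pack_params` are hypothetical helper names, not part of any WebGPU API. For rope_k, the second dims lane would carry n_kv_heads instead of n_heads.

```python
import struct

def pack_dims(seq_len: int, n_heads: int, head_dim: int, pos_offset: int) -> bytes:
    # vec4<u32>: four little-endian 32-bit unsigned ints, 16 bytes total.
    return struct.pack("<4I", seq_len, n_heads, head_dim, pos_offset)

def pack_params(rope_base: float) -> bytes:
    # vec4<f32>: only .x (rope_base) is used; the other three lanes are padding.
    return struct.pack("<4f", rope_base, 0.0, 0.0, 0.0)

dims_bytes = pack_dims(512, 12, 64, 0)   # rope_q; rope_k would pass n_kv_heads here
params_bytes = pack_params(10000.0)
print(len(dims_bytes), len(params_bytes))  # 16 16
```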
rope_k (3 bindings)
Same layout, with n_kv_heads in place of n_heads (in K's shape and in dims.y). For GQA, n_kv_heads < n_heads.
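Why pos_offset? When decoding with a KV cache, the new tokens do not start at position 0: the tokens already in the cache occupy positions 0..pos_offset−1, so you pass that count as pos_offset and the kernel rotates local token s at absolute position s + pos_offset. In training, pos_offset = 0.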
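A small Python sketch of why this works, assuming a hypothetical `rope_angle` helper that reproduces the kernel's angle expression: rotating a token at s = 7 during prefill gives exactly the same angles as rotating it alone during decode with s = 0 and pos_offset = 7.

```python
def rope_angle(s: int, pos_offset: int, p: int, head_dim: int, base: float) -> float:
    # Same expression as the kernel: (s + pos_offset) * base^(-2p/head_dim)
    return (s + pos_offset) * base ** (-2.0 * p / head_dim)

head_dim, base = 64, 10000.0
cached = 7  # tokens already in the KV cache

# Prefill: the token sits at s = 7 of the full sequence, pos_offset = 0.
full = [rope_angle(7, 0, p, head_dim, base) for p in range(head_dim // 2)]
# Decode: only the new token is rotated, so s = 0 and pos_offset = cached.
step = [rope_angle(0, cached, p, head_dim, base) for p in range(head_dim // 2)]
print(full == step)  # True
```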
What is rope_base?
LLaMA-2 uses 10000 by default. It sets the spread of the frequency spectrum: a larger base makes the slowest pairs rotate even more slowly (longer wavelengths), so positions stay distinguishable over a longer maximum sequence length.
Models with longer context windows use a larger base: LLaMA-3 uses 500000, and later Mistral releases use 1000000. This kernel reads base from params.x at runtime, so it can be changed without recompiling the shader.
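One rough way to see what rope_base controls is the wavelength of the slowest pair, i.e. how many positions it takes to complete one full turn. This Python sketch assumes head_dim = 64 (the example size used in this section) and a hypothetical `longest_wavelength` helper; it only illustrates the trend, not how any particular model picked its base.

```python
import math

def longest_wavelength(head_dim: int, base: float) -> float:
    # Slowest pair is p = head_dim/2 - 1; one full turn takes 2*pi / freq positions.
    slowest_freq = base ** (-2.0 * (head_dim // 2 - 1) / head_dim)
    return 2.0 * math.pi / slowest_freq

for base in (10000.0, 500000.0, 1000000.0):
    print(f"base={base:>9.0f}  longest wavelength ~ {longest_wavelength(64, base):,.0f} positions")
```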
Dispatch Shape
One thread = one (s, h, p) triple (s: seq pos, h: head, p: pair index).
```
threads per kernel: seq_len × n_heads × (head_dim/2)
workgroup_size:     256
total WG:           ceil(threads / 256)
```

Example (seq=512, n_heads=12, head_dim=64):
- pairs per token = n_heads × head_dim/2 = 12 × 32 = 384
- total threads = 512 × 384 = 196,608
- WG = 196,608 / 256 = 768
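The same arithmetic as a Python sketch, with a hypothetical `rope_dispatch` helper that returns the thread count and the number of workgroups (integer ceiling division). The second call shows rope_k under GQA with n_kv_heads = 4.

```python
WORKGROUP_SIZE = 256  # workgroup_size from the dispatch description above

def rope_dispatch(seq_len: int, n_heads: int, head_dim: int) -> tuple[int, int]:
    # One thread per (s, h, p); returns (total threads, workgroups to dispatch).
    threads = seq_len * n_heads * (head_dim // 2)
    workgroups = (threads + WORKGROUP_SIZE - 1) // WORKGROUP_SIZE  # integer ceil
    return threads, workgroups

print(rope_dispatch(512, 12, 64))  # rope_q: (196608, 768)
print(rope_dispatch(512, 4, 64))   # rope_k with n_kv_heads = 4: (65536, 256)
```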
Line by Line — rope_q
1) Index decode
```wgsl
let i = flat_id(gid, nwg);
let half_d = head_dim / 2u;
let pairs_per_pos = n_heads * half_d;
let total = seq_len * pairs_per_pos;
if (i >= total) { return; }
let s = i / pairs_per_pos;
let rest = i % pairs_per_pos;
let h = rest / half_d;
let p = rest % half_d;
```

Flat index → 3D (s, h, p). p = pair index within the head (0..head_dim/2-1).
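A Python mirror of this decode (the `decode` name is hypothetical) makes the ordering visible: p varies fastest, then h, then s, and every flat id below `total` maps to a distinct (s, h, p) triple.

```python
def decode(i: int, n_heads: int, head_dim: int) -> tuple[int, int, int]:
    # Mirror of the kernel's index decode: flat thread id -> (s, h, p).
    half_d = head_dim // 2
    pairs_per_pos = n_heads * half_d
    s = i // pairs_per_pos
    rest = i % pairs_per_pos
    return s, rest // half_d, rest % half_d

seq_len, n_heads, head_dim = 3, 2, 8
total = seq_len * n_heads * (head_dim // 2)
triples = [decode(i, n_heads, head_dim) for i in range(total)]
print(triples[:5])  # p varies fastest, then h, then s
```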
2) Compute the frequency
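```wgsl
let freq = rope_freq(p, head_dim, base);
// freq = exp(-log(base) * 2*p / head_dim) = base^(-2p/head_dim)
```

The exp(-log(base) · 2p/d) form is mathematically identical to base^(-2p/d). WGSL's pow(b, x) would also work here (base > 0), and the spec states pow's accuracy as inherited from exp2(x · log2(b)), so it is not clearly more stable. exp(-log(base) · x) is the explicit version and always works for base > 0.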
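A CPU sanity check that the two forms agree, written in Python with hypothetical helper names. Note the assumption: this runs in double precision and says nothing about how a particular GPU's f32 exp, log or pow behave.

```python
import math

def freq_exp_log(p: int, head_dim: int, base: float) -> float:
    # Form used in the kernel: exp(-log(base) * 2p / head_dim)
    return math.exp(-math.log(base) * 2.0 * p / head_dim)

def freq_pow(p: int, head_dim: int, base: float) -> float:
    # Equivalent closed form: base^(-2p / head_dim)
    return base ** (-2.0 * p / head_dim)

head_dim, base = 64, 10000.0
worst = max(
    abs(freq_exp_log(p, head_dim, base) - freq_pow(p, head_dim, base)) / freq_pow(p, head_dim, base)
    for p in range(head_dim // 2)
)
print(worst)  # relative difference at double-precision rounding level
```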
3) Angle + sin/cos
```wgsl
let angle = f32(s + pos_offset) * freq;
let c = cos(angle);
let sn = sin(angle);
```

pos_offset = 0 in training. In the decoder it is the KV cache offset.
cos/sin are WGSL built-ins. On Apple GPUs, Metal also offers fast-math variants (the fast:: namespace in MSL) that trade precision for speed; whether they are used depends on how the WebGPU implementation compiles the shader.
4) Pair indices
```wgsl
let row_base = s * (n_heads * head_dim);
let head_base = row_base + h * head_dim;
let i0 = head_base + 2u * p;
let i1 = i0 + 1u;
```

Q[i0] = Q[s, h, 2p], Q[i1] = Q[s, h, 2p+1].
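A Python check of the index arithmetic (the `pair_indices` name is hypothetical): each flat slot of a [seq_len, n_heads, head_dim] row-major buffer is labeled with its logical coordinate, and the computed i0/i1 land on dims 2p and 2p+1 of the right head and position.

```python
def pair_indices(s: int, h: int, p: int, n_heads: int, head_dim: int) -> tuple[int, int]:
    # Same arithmetic as the kernel for a [seq_len, n_heads, head_dim] row-major buffer.
    row_base = s * (n_heads * head_dim)
    head_base = row_base + h * head_dim
    i0 = head_base + 2 * p
    return i0, i0 + 1

seq_len, n_heads, head_dim = 3, 2, 8
# Label each flat slot with its logical (s, h, d) coordinate.
labels = [(s, h, d) for s in range(seq_len) for h in range(n_heads) for d in range(head_dim)]
i0, i1 = pair_indices(2, 1, 3, n_heads, head_dim)
print(labels[i0], labels[i1])  # (2, 1, 6) (2, 1, 7)
```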
5) Rotation in-place
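```wgsl
let q0 = Q[i0];
let q1 = Q[i1];
Q[i0] = q0 * c - q1 * sn;
Q[i1] = q0 * sn + q1 * c;
```

A classic 2D rotation. Load q0 and q1 into registers first, then write. Otherwise, if Q[i1] were computed by reading Q[i0] back after it had been overwritten, it would see the rotated value instead of the original one. This is a read-after-write hazard inside a single thread, not a race between threads: each thread owns its two elements exclusively.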
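Putting steps 1 to 5 together, here is a CPU reference sketch in Python of what one rope_q / rope_k dispatch does. Assumptions: a Python list stands in for the storage buffer, math runs in double precision rather than f32, threads run sequentially in a loop, and `rope_forward_inplace` is a hypothetical name. Something like this could back the forward/backward identity test that the checklist marks as missing.

```python
import math

def rope_forward_inplace(buf, seq_len, n_heads, head_dim, pos_offset, base):
    # CPU mirror of rope_q / rope_k: one loop iteration per GPU thread (s, h, p).
    half_d = head_dim // 2
    pairs_per_pos = n_heads * half_d
    for i in range(seq_len * pairs_per_pos):
        s, rest = divmod(i, pairs_per_pos)
        h, p = divmod(rest, half_d)
        angle = (s + pos_offset) * math.exp(-math.log(base) * 2.0 * p / head_dim)
        c, sn = math.cos(angle), math.sin(angle)
        i0 = s * n_heads * head_dim + h * head_dim + 2 * p
        i1 = i0 + 1
        q0, q1 = buf[i0], buf[i1]   # two reads into locals first...
        buf[i0] = q0 * c - q1 * sn  # ...then two writes
        buf[i1] = q0 * sn + q1 * c

seq_len, n_heads, head_dim = 4, 2, 8
q = [0.1 * (j % 7) - 0.3 for j in range(seq_len * n_heads * head_dim)]
original = list(q)
rope_forward_inplace(q, seq_len, n_heads, head_dim, pos_offset=0, base=10000.0)
```

With pos_offset = 0, the tokens at s = 0 get angle 0 and come out unchanged, which is a cheap first sanity check.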
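Why in-place? Q and K are large tensors (seq × n_heads × d_head). A separate output buffer would need a second tensor of the same size. The memory traffic is the same either way (one read and one write per element), so what in-place saves is memory, not bandwidth. RoPE is an element-wise rotation and is not fused into the main pass; it runs as a separate kernel, which costs one extra read and write of Q and K.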
How rope_k Differs
Only the number of heads:
- rope_q: n_heads (e.g. 12)
- rope_k: n_kv_heads (e.g. 4, GQA)
Everything else is the same. Q and K go through the same formula: the same angle for the same s, the same pair index, the same frequency. This equality is critical: when attention computes the Q·K dot product, the absolute parts of the two rotations cancel and only the rotation by the relative offset remains.
$$Q_s = R(s)\,q, \qquad K_{s'} = R(s')\,k$$

$$Q_s \cdot K_{s'} = q^\top R(s)^\top R(s')\,k = q^\top R(s'-s)\,k$$

R(s'−s) depends only on the relative offset s'−s. This is RoPE's main strength.
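The key step is the matrix identity R(s)ᵀR(s') = R(s'−s) for a single pair. A Python check with plain 2×2 lists, assuming one pair with frequency 0.25 and positions 3 and 11 (arbitrary illustrative values; `R`, `transpose` and `matmul` are hypothetical helpers):

```python
import math

def R(theta: float):
    # 2x2 rotation matrix for one (2i, 2i+1) pair.
    return [[math.cos(theta), -math.sin(theta)], [math.sin(theta), math.cos(theta)]]

def transpose(m):
    return [[m[0][0], m[1][0]], [m[0][1], m[1][1]]]

def matmul(a, b):
    return [[sum(a[r][k] * b[k][c] for k in range(2)) for c in range(2)] for r in range(2)]

freq, s, s_prime = 0.25, 3, 11  # one pair's frequency and two positions
lhs = matmul(transpose(R(s * freq)), R(s_prime * freq))
rhs = R((s_prime - s) * freq)
print(lhs, rhs)  # equal up to rounding
```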
RoPE Backward — Important note
In my WGSL code the backward kernels live inside 10_backward_attention.wgsl:
- rope_q_backward (line 499)
- rope_k_backward (line 544)
Backward formula:
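$$\frac{\partial L}{\partial q_{2i}} = \frac{\partial L}{\partial q'_{2i}}\cos\theta + \frac{\partial L}{\partial q'_{2i+1}}\sin\theta$$

$$\frac{\partial L}{\partial q_{2i+1}} = \frac{\partial L}{\partial q'_{2i+1}}\cos\theta - \frac{\partial L}{\partial q'_{2i}}\sin\theta$$

Here θ is the pair's forward rotation angle. So it is the transpose of the forward rotation (= the inverse of the rotation matrix = the same matrix with −angle). Detail in the 10_backward_attention.md section.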
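A Python sketch of the backward formula for one pair, with hypothetical helper names and double precision. It runs the Q → fwd → bwd → Q identity check from the checklist and compares the analytic gradient with central finite differences for a linear loss L = w · x', whose upstream gradient is simply w.

```python
import math

def rope_pair_forward(x0, x1, angle):
    # Forward: x' = R(angle) x
    c, sn = math.cos(angle), math.sin(angle)
    return x0 * c - x1 * sn, x0 * sn + x1 * c

def rope_pair_backward(g0, g1, angle):
    # Backward: dL/dx = R(angle)^T dL/dx', the formulas above
    c, sn = math.cos(angle), math.sin(angle)
    return g0 * c + g1 * sn, g1 * c - g0 * sn

angle = 5 * 10000.0 ** (-2.0 * 3 / 64)  # s = 5, p = 3, head_dim = 64
x = (0.8, -0.2)

# Identity check: Q -> fwd -> bwd -> Q
x_back = rope_pair_backward(*rope_pair_forward(*x, angle), angle)

# Gradient check: for L = w . x', dL/dx' = w; compare with central differences.
w = (0.3, 1.7)

def loss(v):
    y0, y1 = rope_pair_forward(v[0], v[1], angle)
    return w[0] * y0 + w[1] * y1

eps = 1e-6
numeric = []
for j in range(2):
    plus, minus = list(x), list(x)
    plus[j] += eps
    minus[j] -= eps
    numeric.append((loss(plus) - loss(minus)) / (2 * eps))
analytic = rope_pair_backward(*w, angle)
print(x_back, numeric, analytic)
```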
WGSL-Specific Notes
1. cos(angle) precision
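WGSL does not require cos/sin to be correctly rounded. For f32 the spec only bounds their absolute error (2^-11) inside [−π, π], and angle = (s + pos_offset) · freq is far outside that range for the fast pairs at large positions. On Apple Metal, fast-math can lower precision further. For our use it is still sufficient: this is a relative-position rotation, a small floating error shows up as noise in the attention output, and the gradient signal-to-noise ratio stays high.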
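Besides the cos/sin error, the angle itself is an f32. The Python sketch below emulates f32 spacing with `struct` (an assumption: it reflects IEEE single precision in general, not how a specific GPU rounds) and prints the gap between adjacent f32 values at several angle magnitudes. Near 4096 rad the gap is 2^-11, so a product (s + pos_offset) · freq of that size can carry rounding error of up to half of that before cos/sin even run.

```python
import struct

def f32_ulp(x: float) -> float:
    # Distance from positive float32 x to the next representable float32.
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    nxt = struct.unpack("<f", struct.pack("<I", bits + 1))[0]
    cur = struct.unpack("<f", struct.pack("<f", x))[0]
    return nxt - cur

# Angle magnitudes reached by fast pairs (freq close to 1) at growing positions.
for angle in (1.0, 512.0, 4096.0, 32768.0):
    print(angle, f32_ulp(angle))
```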
2. exp and log overhead
Every thread computes exp(-log(base) · 2p/d). p differs between threads, but log(base) is the same for all of them. A compiler may recognize log(base) as a uniform value and evaluate it once, but that is not guaranteed. Alternative: precompute a freq[p] lookup table on the host and pass it in a buffer. Our code computes it at runtime, for simplicity.
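A host-side sketch of that alternative in Python, with a hypothetical `freq_table_bytes` helper; the kernel would need an extra binding to read it, which the current ABI does not have. The table is packed as tightly strided little-endian f32, which fits an array<f32> in a read-only storage buffer; a uniform buffer would not accept this layout as-is, because WGSL requires a 16-byte element stride for arrays in the uniform address space.

```python
import struct

def freq_table_bytes(head_dim: int, base: float) -> bytes:
    # freq[p] = base^(-2p/head_dim) for p in 0..head_dim/2-1, packed as
    # little-endian f32 for upload as array<f32> in a storage buffer.
    freqs = [base ** (-2.0 * p / head_dim) for p in range(head_dim // 2)]
    return struct.pack(f"<{len(freqs)}f", *freqs)

table = freq_table_bytes(64, 10000.0)
print(len(table))  # 128 bytes = 32 floats
```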
3. In-place write race
```wgsl
let q0 = Q[i0];
let q1 = Q[i1];
Q[i0] = q0 * c - q1 * sn; // safe: q0/q1 already in registers
Q[i1] = q0 * sn + q1 * c;
```

If I had written Q[i0] first and then computed Q[i1] by reading Q[i0] back from memory, I would have used the value I had just written instead of the original q0. The correct ordering: two reads, then two writes.
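The difference is easy to reproduce on the CPU. In this Python sketch (hypothetical function names, a two-element list standing in for Q), the wrong ordering re-reads buf[i0] after overwriting it: the result is no longer a rotation, and the pair's length changes.

```python
import math

def rotate_correct(buf, i0, i1, c, sn):
    q0, q1 = buf[i0], buf[i1]  # read both first
    buf[i0] = q0 * c - q1 * sn
    buf[i1] = q0 * sn + q1 * c

def rotate_wrong(buf, i0, i1, c, sn):
    buf[i0] = buf[i0] * c - buf[i1] * sn
    buf[i1] = buf[i0] * sn + buf[i1] * c  # re-reads buf[i0]: already overwritten

angle = 0.7
c, sn = math.cos(angle), math.sin(angle)
good, bad = [1.0, 0.0], [1.0, 0.0]
rotate_correct(good, 0, 1, c, sn)
rotate_wrong(bad, 0, 1, c, sn)
print(good, bad)
```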
Code Review
Finding 1: freq lookup table optimization opportunity
| Risk | Description |
|---|---|
| 🟢 none (optimization only) | Every thread with the same p computes the same freq: across the whole dispatch, seq_len × n_heads threads share each p. A freq[head_dim/2] lookup table precomputed on the host and loaded into workgroup memory would replace the per-thread exp/log with a single load. A marginal speedup, but elegant. |
Finding 2: pos_offset is always 0 (training)
| Risk | Description |
|---|---|
| 🟢 none | It exists for decode/inference. In training it costs one extra integer add (s + 0u) per thread. The compiler cannot remove it, since pos_offset is a runtime uniform, but the cost is negligible. |
Quick Checklist
| Test Scenario | Status |
|---|---|
| Do Q and K compute the same angle for the same s and p? | ✅ identical formula |
| Is GQA n_kv_heads passed correctly? | ✅ passed in dims.y of rope_k |
| Is pos_offset consistent with the KV cache in the decoder? | ✅ inference logic |
| Does rope_base = 10000 (LLaMA default) work? | ✅ |
| Forward + backward identity test (Q → fwd → bwd → Q)? | ⚠ no formal test; the only test is attention end-to-end |
| head_dim even-number guarantee? | ⚠ the kernel does not check; if odd, the last dim (head_dim-1) is skipped |
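For the last ⚠ row, a minimal host-side guard is one option. This Python sketch uses hypothetical names (`check_rope_dims`, `touched_dims`) and is not part of the existing host code; `touched_dims` shows which dims of one head the kernel actually writes, confirming that an odd head_dim leaves the last dim unrotated.

```python
def check_rope_dims(head_dim: int) -> None:
    # Host-side guard: the kernel assumes head_dim is even but does not check it.
    if head_dim % 2 != 0:
        raise ValueError(f"RoPE needs an even head_dim, got {head_dim}")

def touched_dims(head_dim: int) -> list[int]:
    # Dims written by the kernel for one head: pairs (2p, 2p+1) for p < head_dim // 2.
    return [d for p in range(head_dim // 2) for d in (2 * p, 2 * p + 1)]

print(touched_dims(7))  # [0, 1, 2, 3, 4, 5]: dim 6 is never rotated
```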
Next
05_attention.md: the conceptual heart of the model, scaled dot-product attention, which takes the RoPE'd Q and K together with V.