
rope_q and rope_k — Rotary Position Embedding (Forward)

File: 04_rope.wgsl
Pipeline step: Right after the Q and K projections, right before attention.

Forward kernels:

  • rope_q — rotation applied to Query
  • rope_k — rotation applied to Key (with n_kv_heads, for GQA)

The backward kernels (rope_q_backward, rope_k_backward) are covered in 10_backward_attention.md, since they work together with the rest of the attention backward pass.


Wait, what is this?

Say you've got a sentence and the model needs to know where each word sits in it. "The dog bit the man" and "The man bit the dog" are the same words, but the order flips the meaning entirely. Here's the catch: attention dumps the word vectors into a bag and compares them against each other — and the bag has no notion of order. You have to somehow smuggle position into the vector itself.

RoPE's trick is this: picture each vector as the hand of a clock. The further along a token sits in the sentence, the more you spin that hand. Word 1 turns a little, word 5 turns more, word 50 turns a lot. The angle itself encodes the position — there's no separate position table, you embed the info by simply turning the vector.

The neat part shows up at comparison time. When you line up two words' clock hands (the dot product attention runs), what matters isn't the absolute direction of either hand but the angle between them — i.e. how many positions apart they are. Spin both hands by the same extra amount and the gap between them is unchanged. So the model picks up "these two words are 4 steps apart" for free, whether they sit at the start of the sentence or the end.

Splitting the vector into pairs and spinning each pair at a different speed (some slow, some fast tick-tick) instead of one single hand is the part that gets a bit dizzying — but the gist holds: position info = amount of rotation.


What Does It Do?

LLaMA-style rotary position embedding. It rotates each consecutive (2i, 2i+1) pair of the Q and K vectors with a 2D rotation whose angle depends on the token's position. This is how position information is fed into the model: the positional signal lives in the vector itself, with no extra embedding table.

The math, following the classic logic:

Pair i (within a head, dims 2i and 2i+1):
  freq  = base^(-2i / head_dim)             ← frequency from the pair index
  angle = (s + pos_offset) · freq           ← angle from the position
  q[2i]'   = q[2i] · cos(angle) - q[2i+1] · sin(angle)
  q[2i+1]' = q[2i] · sin(angle) + q[2i+1] · cos(angle)

This is multiplication by a 2D rotation matrix:

[q'[2i]  ]   [ cos -sin]   [q[2i]  ]
[q'[2i+1]] = [ sin  cos] · [q[2i+1]]
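As a quick sanity check, the pair update can be sketched on the CPU. This is a minimal Python sketch, not the kernel itself; the helper name rotate_pair is hypothetical and only mirrors the two formulas above.

```python
import math

def rotate_pair(x0: float, x1: float, angle: float) -> tuple[float, float]:
    """Rotate the 2D point (x0, x1) counter-clockwise by `angle` radians."""
    c, s = math.cos(angle), math.sin(angle)
    return (x0 * c - x1 * s, x0 * s + x1 * c)

# A quarter turn maps (1, 0) onto (0, 1), up to floating-point error.
y0, y1 = rotate_pair(1.0, 0.0, math.pi / 2)

# A rotation never changes the length of the pair.
norm_before = math.hypot(3.0, 4.0)
norm_after = math.hypot(*rotate_pair(3.0, 4.0, 0.7))
```

Because the length is preserved, RoPE changes only the direction of each pair, never the magnitude of Q or K.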

The pairs i = 0, 1, 2, ... each rotate at a different frequency:

  • i=0 (lowest dims): freq=1 → highest frequency, one radian per token, so the pair wraps around every 2π ≈ 6.3 positions
  • i=head_dim/2-1 (highest dims): freq=base^(-1+2/head_dim) → lowest frequency, a slow sweep that only wraps after tens of thousands of positions at base=10000
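To make the two endpoints concrete, the whole spectrum can be listed on the host. This is a small Python sketch under the formula above; rope_freqs is a hypothetical helper name, and head_dim=64 with base=10000 are example values only.

```python
def rope_freqs(head_dim: int, base: float = 10000.0) -> list[float]:
    """freq[i] = base^(-2i / head_dim) for pair i = 0 .. head_dim/2 - 1."""
    return [base ** (-2.0 * i / head_dim) for i in range(head_dim // 2)]

freqs = rope_freqs(64)
first, last = freqs[0], freqs[-1]
# first is the pair that advances one radian per token;
# last is base^(-1 + 2/head_dim), the slowest-turning pair.
```

The list is strictly decreasing, so the pair index alone tells you how fast a pair turns.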

Thanks to this spectrum, the model captures relative position naturally inside the attention dot product:

Q_s · K_s' = (rotated q_s) · (rotated k_s') = g(q_unrotated, k_unrotated, s - s')

In other words, for fixed unrotated q and k, the Q⋅K score does not depend on the absolute positions s and s', only on the relative position s−s'. Detail: RoFormer paper.
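The relative-position claim is easy to check numerically. The sketch below is a plain-Python reference for a single head (the names rope and dot are hypothetical, and the q and k values are arbitrary): it rotates the same q and k at two different absolute positions with the same gap and compares the scores.

```python
import math

def rope(vec, pos, base=10000.0):
    """Rotate each (2i, 2i+1) pair of one head vector by pos * base^(-2i/d)."""
    d = len(vec)
    out = list(vec)
    for i in range(d // 2):
        angle = pos * base ** (-2.0 * i / d)
        c, s = math.cos(angle), math.sin(angle)
        out[2 * i] = vec[2 * i] * c - vec[2 * i + 1] * s
        out[2 * i + 1] = vec[2 * i] * s + vec[2 * i + 1] * c
    return out

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

q = [0.3, -1.2, 0.8, 0.5, -0.7, 1.1, 0.2, -0.4]
k = [1.0, 0.4, -0.6, 0.9, 0.3, -0.8, 0.5, 0.1]

# Same gap of 4 positions, very different absolute positions.
score_near = dot(rope(q, 7), rope(k, 3))
score_far = dot(rope(q, 107), rope(k, 103))
```

The two scores agree up to floating-point rounding, which is exactly the property attention relies on.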


Bind Group ABI

rope_q (3 bindings)

| Binding | Type | Detail |
|---|---|---|
| 0 | storage, read_write | Q: array<f32>, [seq_len × n_heads × head_dim] row-major, rotated in place |
| 1 | uniform | dims: vec4<u32>(seq_len, n_heads, head_dim, pos_offset) |
| 2 | uniform | params: vec4<f32>(rope_base, _, _, _) |

rope_k (3 bindings)

Same thing; n_kv_heads in place of n_heads. For GQA, n_kv_heads < n_heads.

Why pos_offset? When decoding with a KV cache, only the new tokens go through the kernel, so pos_offset carries the number of tokens already in the cache and each new token is rotated by its true absolute position. In training, pos_offset = 0.
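A minimal sketch of the KV-cache case, assuming a single head and a hypothetical CPU helper rope_positions: rotating only the new tokens with pos_offset set to the cache length should give the same rows as rotating the whole sequence at once.

```python
import math

def rope_positions(x, pos_offset=0, base=10000.0):
    """x: list of token vectors (one head). Token s is rotated as position s + pos_offset."""
    out = []
    for s, vec in enumerate(x):
        d = len(vec)
        row = list(vec)
        for i in range(d // 2):
            angle = (s + pos_offset) * base ** (-2.0 * i / d)
            c, sn = math.cos(angle), math.sin(angle)
            row[2 * i] = vec[2 * i] * c - vec[2 * i + 1] * sn
            row[2 * i + 1] = vec[2 * i] * sn + vec[2 * i + 1] * c
        out.append(row)
    return out

# 5 tokens with head_dim = 4; the values are arbitrary.
tokens = [[float(t + j) for j in range(4)] for t in range(5)]

full = rope_positions(tokens)                        # whole sequence, pos_offset = 0
cached, new = tokens[:3], tokens[3:]                 # 3 tokens already in the cache
step = rope_positions(new, pos_offset=len(cached))   # rotate only the new tokens
```

Here step matches full[3:], so cached keys never need to be re-rotated.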

What is rope_base?

In LLaMA-2 the default is 10000. It sets the spread of the frequency spectrum: a larger base stretches the low-frequency end to longer wavelengths, so positions stay distinguishable over a longer sequence.

Longer-context models use larger bases: LLaMA-3 uses 500000, and later Mistral releases use 1000000. This kernel takes base as a runtime parameter, so it is adjustable.
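One way to see what a larger base buys is to look at the slowest pair's wavelength, i.e. how many tokens pass before it completes a full turn. The helper below is a hypothetical Python sketch; head_dim=64 and the two bases are example values only.

```python
import math

def longest_wavelength(head_dim: int, base: float) -> float:
    """Tokens needed for the slowest pair (i = head_dim/2 - 1) to complete one full turn."""
    slowest_freq = base ** (-2.0 * (head_dim // 2 - 1) / head_dim)
    return 2.0 * math.pi / slowest_freq

w_10k = longest_wavelength(64, 10000.0)
w_1m = longest_wavelength(64, 1000000.0)
```

The wavelength grows roughly in proportion to the base, which is why raising the base is a common knob for longer context windows.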


Dispatch Shape

One thread = one (s, h, p) triple (s: seq pos, h: head, p: pair index).

threads per kernel:  seq_len × n_heads × (head_dim/2)
workgroup_size:      256
total WG:            ceil(threads / 256)

Example (seq=512, n_heads=12, head_dim=64):

  • pairs per token = n_heads × head_dim/2 = 12 × 32 = 384
  • total threads = 512 × 384 = 196,608
  • WG = 768
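The arithmetic above can be reproduced on the host before dispatching. This is a sketch only; rope_dispatch is a hypothetical helper name, and the workgroup size of 256 is taken from the text.

```python
def rope_dispatch(seq_len: int, n_heads: int, head_dim: int, wg_size: int = 256):
    """Return (total threads, workgroup count) for one RoPE kernel launch."""
    threads = seq_len * n_heads * (head_dim // 2)     # one thread per (s, h, p)
    workgroups = (threads + wg_size - 1) // wg_size   # integer ceil division
    return threads, workgroups

threads, workgroups = rope_dispatch(512, 12, 64)
```

When threads is not a multiple of 256, the last workgroup contains padding threads, which is why the kernel starts with a bounds check.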

Line by Line — rope_q

1) Index decode

wgsl
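let i = flat_id(gid, nwg);              // flat thread index across the dispatch
let half_d = head_dim / 2u;             // pairs per head
let pairs_per_pos = n_heads * half_d;   // pairs per token position
let total = seq_len * pairs_per_pos;
if (i >= total) { return; }             // padding threads exit early

let s = i / pairs_per_pos;              // sequence position
let rest = i % pairs_per_pos;
let h = rest / half_d;                  // head index
let p = rest % half_d;                  // pair index within the head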

Flat index → 3D (s, h, p). p = pair index within the head (0..head_dim/2-1).
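The decode step can be mirrored in Python to convince yourself that every flat index maps to exactly one (s, h, p). The names decode and encode are hypothetical; encode is simply the inverse mapping, added for the round-trip check.

```python
def decode(i: int, n_heads: int, head_dim: int):
    """Flat thread index -> (s, h, p), same arithmetic as the kernel."""
    half_d = head_dim // 2
    pairs_per_pos = n_heads * half_d
    s, rest = divmod(i, pairs_per_pos)
    h, p = divmod(rest, half_d)
    return s, h, p

def encode(s: int, h: int, p: int, n_heads: int, head_dim: int) -> int:
    """Inverse of decode: (s, h, p) -> flat thread index."""
    half_d = head_dim // 2
    return (s * n_heads + h) * half_d + p

# Thread 1000 with n_heads=12, head_dim=64 lands on s=2, h=7, p=8.
example = decode(1000, n_heads=12, head_dim=64)
```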

2) Compute the frequency

wgsl
let freq = rope_freq(p, head_dim, base);
// freq = exp(-log(base) * 2*p / head_dim) = base^(-2p/head_dim)

The exp(-log(base) * 2p/d) form is mathematically the same as base^(-2p/d). WGSL's pow(b, x) is not guaranteed to be any more accurate (the spec derives its accuracy from exp2(x * log2(b))), so the explicit exp/log form loses nothing and is well defined for any base > 0.
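The equivalence of the two forms can be checked on the CPU. Note the assumption: Python computes in f64 while the kernel runs in f32, so this sketch only confirms the algebra, not GPU precision. rope_freq here is a Python mirror of the WGSL helper's formula, and rope_freq_pow is a hypothetical comparison function.

```python
import math

def rope_freq(p: int, head_dim: int, base: float) -> float:
    """exp/log form: exp(-log(base) * 2p / head_dim)."""
    return math.exp(-math.log(base) * (2.0 * p) / head_dim)

def rope_freq_pow(p: int, head_dim: int, base: float) -> float:
    """Direct power form: base^(-2p / head_dim)."""
    return base ** (-2.0 * p / head_dim)

# Largest relative difference between the two forms over all 32 pairs.
max_rel_diff = max(
    abs(rope_freq(p, 64, 10000.0) - rope_freq_pow(p, 64, 10000.0))
    / rope_freq_pow(p, 64, 10000.0)
    for p in range(32)
)
```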

3) Angle + sin/cos

wgsl
let angle = f32(s + pos_offset) * freq;
let c = cos(angle);
let sn = sin(angle);

pos_offset = 0 in training. In the decoder it is the KV cache offset.

cos/sin are WGSL built-ins. On Apple GPUs they are hardware-accelerated; Metal exposes the fast variants as fast::sin and fast::cos.

4) Pair indices

wgsl
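let row_base = s * (n_heads * head_dim);   // start of token s
let head_base = row_base + h * head_dim;   // start of head h within that token
let i0 = head_base + 2u * p;               // even element of pair p
let i1 = i0 + 1u;                          // odd element of pair p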

Q[i0] = Q[s, h, 2p], Q[i1] = Q[s, h, 2p+1].

5) Rotation in-place

wgsl
let q0 = Q[i0];
let q1 = Q[i1];
Q[i0] = q0 * c - q1 * sn;
Q[i1] = q0 * sn + q1 * c;

A classic 2D rotation. Copy q0 and q1 into registers first, then write. Otherwise, if Q[i0] were written first and then re-read to compute Q[i1], the second formula would see the new value instead of the old one. This is a read-after-write hazard inside a single thread, not a race between threads, since each pair belongs to exactly one thread.

Why in-place? Q and K are large tensors (seq × n_heads × d_head). A separate output buffer would double their memory footprint while moving the same number of bytes. RoPE is an element-wise rotation that is not fused into mainpass; it runs as a separate kernel, and rotating in place keeps that extra pass down to one read and one write per element.
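Putting the five steps together, a CPU reference for the forward kernel might look like the sketch below. It is a hypothetical Python helper (rope_inplace) that loops over (s, h, p) instead of decoding a flat thread index, and it computes in f64, so it is a testing aid under those assumptions rather than a copy of the WGSL.

```python
import math

def rope_inplace(buf, seq_len, n_heads, head_dim, pos_offset=0, base=10000.0):
    """Rotate a flat row-major [seq_len, n_heads, head_dim] list in place."""
    half_d = head_dim // 2
    for s in range(seq_len):
        for h in range(n_heads):
            head_base = s * (n_heads * head_dim) + h * head_dim
            for p in range(half_d):
                freq = math.exp(-math.log(base) * 2.0 * p / head_dim)
                angle = (s + pos_offset) * freq
                c, sn = math.cos(angle), math.sin(angle)
                i0 = head_base + 2 * p
                i1 = i0 + 1
                x0, x1 = buf[i0], buf[i1]   # read both values before writing
                buf[i0] = x0 * c - x1 * sn
                buf[i1] = x0 * sn + x1 * c

seq_len, n_heads, head_dim = 3, 2, 4
q = [float(i % 7) - 3.0 for i in range(seq_len * n_heads * head_dim)]
before = list(q)
rope_inplace(q, seq_len, n_heads, head_dim)
```

Two properties are worth asserting against a GPU readback: position 0 is left untouched (angle 0), and the total squared norm of the buffer is unchanged.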


How rope_k Differs

Only the number of heads:

  • rope_q n_heads (e.g. 12)
  • rope_k n_kv_heads (e.g. 4 — GQA)

Everything else is the same. Q and K are rotated the same way: same angle for the same s, same pair index, same frequency. This match is critical: when attention computes the Q·K dot product, the two rotations of each pair combine into a single rotation by the position difference.

Q_s = R(s) · q_unrotated
K_s' = R(s') · k_unrotated
Q_s · K_s' = q · R(s)^T · R(s') · k = q · R(s' - s) · k

R(s'-s) depends only on the relative offset s'-s. This is RoPE's main strength.
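The key step in that derivation, R(s)^T · R(s') = R(s' - s), can be verified for a single pair with plain 2×2 matrices. The sketch below uses hypothetical helper names (rot, matmul, transpose) and an arbitrary frequency and pair of positions.

```python
import math

def rot(angle):
    """2x2 rotation matrix for `angle` radians."""
    c, s = math.cos(angle), math.sin(angle)
    return [[c, -s], [s, c]]

def matmul(a, b):
    return [[sum(a[i][k] * b[k][j] for k in range(2)) for j in range(2)] for i in range(2)]

def transpose(a):
    return [[a[j][i] for j in range(2)] for i in range(2)]

freq = 0.25            # arbitrary frequency of one pair
s, s_prime = 11, 4     # query and key positions

lhs = matmul(transpose(rot(s * freq)), rot(s_prime * freq))
rhs = rot((s_prime - s) * freq)
max_err = max(abs(lhs[i][j] - rhs[i][j]) for i in range(2) for j in range(2))
```

The identity only holds if Q and K use the same freq for the same pair, which is why rope_q and rope_k must share the formula.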


RoPE Backward — Important note

In my WGSL code the backward kernels live inside 10_backward_attention.wgsl:

  • rope_q_backward (line 499)
  • rope_k_backward (line 544)

Backward formula:

∂L/∂q[2i]   = (∂L/∂q'[2i])  · cos + (∂L/∂q'[2i+1]) · sin
∂L/∂q[2i+1] = (∂L/∂q'[2i+1]) · cos - (∂L/∂q'[2i])  · sin

So it is the transpose of the forward (= the inverse of the rotation matrix = the same matrix with -angle). Detail in the 10_backward_attention.md section.
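A per-pair sketch of the backward formula, with two checks that could serve as a starting point for a formal test: a forward-then-backward round trip, and a finite-difference gradient check. The function names and the toy loss are hypothetical; the actual kernels live in 10_backward_attention.wgsl.

```python
import math

def rope_pair_forward(x0, x1, angle):
    c, s = math.cos(angle), math.sin(angle)
    return x0 * c - x1 * s, x0 * s + x1 * c

def rope_pair_backward(g0, g1, angle):
    """Map (dL/dq'[2i], dL/dq'[2i+1]) to (dL/dq[2i], dL/dq[2i+1])."""
    c, s = math.cos(angle), math.sin(angle)
    return g0 * c + g1 * s, g1 * c - g0 * s

angle = 5 * 10000.0 ** (-2.0 * 3 / 64)   # position 5, pair 3, head_dim 64
x = (0.75, -1.5)

# Backward is the inverse rotation, so fwd followed by bwd recovers the input.
roundtrip = rope_pair_backward(*rope_pair_forward(*x, angle), angle)

# Toy loss L = 2*q'[0] + 3*q'[1], so the upstream gradients are (2, 3).
def loss(x0, x1):
    y0, y1 = rope_pair_forward(x0, x1, angle)
    return 2.0 * y0 + 3.0 * y1

eps = 1e-6
fd0 = (loss(x[0] + eps, x[1]) - loss(x[0] - eps, x[1])) / (2 * eps)
fd1 = (loss(x[0], x[1] + eps) - loss(x[0], x[1] - eps)) / (2 * eps)
g = rope_pair_backward(2.0, 3.0, angle)
```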


WGSL-Specific Notes

1. cos(angle) precision

WGSL does not promise correctly rounded cos/sin. The spec only bounds the absolute error (2^-11 for f32) for inputs in [-π, π], and RoPE angles grow far beyond that range at long positions. On Apple Metal, fast-math may lower the precision further. For our use this is sufficient: a small rotation error shows up as noise in the attention output, and the gradient signal-to-noise ratio is still high.

2. exp and log overhead

Every thread computes exp(-log(base) * 2p/d). p differs per thread, but log(base) is constant. The compiler can hoist log(base), but it is not guaranteed. Alternative: pass a precomputed table from the host (e.g. a freq[p] lookup table). Our code computes it at runtime, for simplicity.

3. In-place write race

wgsl
let q0 = Q[i0];
let q1 = Q[i1];
Q[i0] = q0 * c - q1 * sn;  // safe — q0/q1 already in registers
Q[i1] = q0 * sn + q1 * c;

If I had written Q[i0] first and then read Q[i0] again while computing Q[i1], I would have lost the old value of q0 and used the value I had just written. The correct ordering: two reads, then two writes.


Code Review

Finding 1: freq lookup table optimization opportunity

| Risk | Description |
|---|---|
| 🟢 none, but optim | Every thread with the same p computes the same freq. Across the dispatch, seq_len × n_heads threads share each p. A lookup table freq[head_dim/2] precomputed on the host and read by the kernel (optionally staged in workgroup memory) replaces the per-thread exp with a single load. A marginal speedup, but elegant. |
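If this optimization were adopted, the host side could be as small as the sketch below. It is an assumption-laden illustration: build_freq_table is a hypothetical helper, the table is packed as little-endian f32 ready for a buffer upload, and how the kernel would bind it is not part of the current ABI.

```python
import math
import struct

def build_freq_table(head_dim: int, base: float = 10000.0) -> bytes:
    """Pack freq[p] = base^(-2p/head_dim), p = 0 .. head_dim/2 - 1, as little-endian f32."""
    freqs = [math.exp(-math.log(base) * 2.0 * p / head_dim) for p in range(head_dim // 2)]
    return struct.pack(f"<{len(freqs)}f", *freqs)

table = build_freq_table(64)             # 32 entries, 4 bytes each
unpacked = struct.unpack("<32f", table)  # what the kernel would read back
```

The table only depends on head_dim and base, so it would need rebuilding only when either of those changes.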

Finding 2: pos_offset is always 0 (training)

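| Risk | Description |
|---|---|
| 🟢 none | pos_offset exists for decode/inference. In training it is a needless add (s + 0u). Because pos_offset is a runtime uniform the compiler cannot fold it away, but one integer add is negligible next to the exp, sin, and cos. |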

Quick Checklist

| Test Scenario | Status |
|---|---|
| Do Q and K compute the same angle for the same s and p? | ✅ identical formula |
| Is GQA n_kv_heads passed correctly? | ✅ separate dims.y |
| Is pos_offset consistent with the KV cache in the decoder? | ✅ inference logic |
| Does rope_base = 10000 (LLaMA default) work? | |
| Forward + backward identity test (Q → fwd → bwd → Q)? | ⚠ no formal test; the only test is attention end-to-end |
| head_dim even-number guarantee? | ⚠ the kernel does not check; if odd, the last dim (head_dim-1) is left unrotated |
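Since the kernel itself does not check head_dim, one option is a host-side guard before dispatch. The sketch below is hypothetical (neither function exists in the project); unrotated_dims simply makes the odd-head_dim failure mode visible.

```python
def check_rope_dims(seq_len: int, n_heads: int, head_dim: int) -> int:
    """Return the thread count, or raise if head_dim cannot be split into pairs."""
    if head_dim <= 0 or head_dim % 2 != 0:
        raise ValueError(f"head_dim must be a positive even number, got {head_dim}")
    return seq_len * n_heads * (head_dim // 2)

def unrotated_dims(head_dim: int) -> list[int]:
    """Dims of one head that no (2p, 2p+1) pair covers when half_d = head_dim // 2."""
    covered = {d for p in range(head_dim // 2) for d in (2 * p, 2 * p + 1)}
    return [d for d in range(head_dim) if d not in covered]
```

For an even head_dim the second function returns an empty list; for head_dim = 65 it returns [64], the dim the kernel would silently leave unrotated.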

Next

05_attention.md — the conceptual heart of the model: scaled dot-product attention. RoPE'd Q and K, together with V.

WGSL kernel studies · an LLM from scratch on WebGPU
Built in Istanbul by Uğur Toprakdeviren.