matmul_t, matmul_t_acc, matmul_at, matmul_at_acc, matmul_at_swiglu_a, matmul_at_acc_swiglu_a — Backward Matrix Multiply
Dosya: 03_linear.wgsl (backward kernel'ları) Pipeline adımı: Backward'ın bel kemiği. Backward step time'ın ~%35'i bu kernel'larda harcanır.
6 kernel:
matmul_t—Y = X @ W^T(dX hesaplaması için)matmul_t_acc—Y += X @ W^T(gradient accumulation için)matmul_at—Y = X^T @ B(dW hesaplaması için)matmul_at_acc—Y += X^T @ Bmatmul_at_swiglu_a—Y = (silu(GATE) * UP)^T @ B(FFN SwiGLU dW_down weight gradyanı hesaplaması)matmul_at_acc_swiglu_a—Y += (silu(GATE) * UP)^T @ B
_w16 mixed-precision varyantları (örn. matmul_t_w16) aynı dosyada yer alır.
Nedir bu ya?
Diyelim bir ses mikseri kullanıyorsun. Girişten gelen sinyal (X) bir sürü potansiyometreden (W ağırlıkları) geçiyor, her kanal belli oranda karışıp çıkışı (Y) üretiyor. Forward pass tam olarak bu: Y = X @ W. Ama sonra çıkışı dinliyorsun ve "olmamış, şu kadar bozuk" diyorsun — işte o hata sinyali dY.
Şimdi geri dönüp iki ayrı soruya cevap vermen lazım, ve bunlar gerçekten ayrı sorular. Birincisi: "potları ne yöne çevirmeliyim?" Bu dW = X^T @ dY. Yani her pot için, ona giren sinyalle çıkıştaki hatayı eşleştirip "bu potu kıstığımda hata düşer mi artar mı" bilgisini üretiyorsun. İkincisi: "bir önceki kişiye ne haber vereyim?" Çünkü senin girişin aslında başka bir katmanın çıkışı. Bu da dX = dY @ W^T — hatayı potların üzerinden geriye, kaynağa doğru itiyorsun.
İşin sezgisel kısmı şu: forward'da çarpım tek yönlü bir karışım. Geriye dönerken her soruda farklı bir matrisi transpoze ediyorsun — dW için X'i, dX için W'yi — çünkü iki soru iki farklı "açıdan bakma" istiyor. Kod açısından bu, bildiğin matris çarpımının indislerini takas etmekten ibaret — ama GPU'da bellek düzeni (kim hangi sırayla okuyor) her şeyi belirlediği için, bu iki transpoze deseni ayrı ayrı kernel olmayı hak ediyor.
Aşağıda göreceğin 6 kernel de bu iki sorunun ve birkaç pratik varyantın (gradyan biriktirme, SwiGLU füzyonu) somut hâli. Hepsi aynı temel hesabı yapıyor; fark, hangi matrisin ters çevrildiği ve belleğin nasıl okunduğu.
Niye 6 Ayrı Kernel?
Forward pass Y = X @ W için tek kernel (matmul) yeterliyken, backward'da iki ayrı gradyan hesaplamamız gerekir:
dX = dY @ W^T— bir önceki katmana hata gradyanını iletmek için (transposed weight)dW = X^T @ dY— ağırlık güncellemelerini hesaplamak için (transposed input)
Bu işlemler farklı transpose desenleri içerdiğinden, bellek düzenini (memory layout) bozmadan forward kernel'ı ile hesaplanamazlar:
matmul: Y[M, N] = X[M, K] @ W[K, N] ← W normal
matmul_t: Y[M, N] = X[M, K] @ W^T[N, K] ← W transposed
matmul_at: Y[K, N] = X^T[M, K]^T @ B[M, N] ← X transposed_acc varyantları, gradyan biriktirme (gradient accumulation) için += semantiği sağlar.
_swiglu_a varyantları ise FFN SwiGLU forward fusion ile mükemmel bir uyum içerisindedir. Forward sırasında bellekte hidden aktivasyon matrisi (silu(gate)*up) oluşturulmadığından, backward katmanında dW_down gradyanını hesaplarken bu kernel doğrudan gate ve up girdilerini yükler, silu(gate)*up birleşimini on-the-fly (hesaplama anında) türeterek matris çarpımını tamamlar. Ara matrisin belleğe yazılıp okunması tamamen engellenmiş olur.
Vektörize Yükleme ve Çakışmasız Adresleme (Vec4 & Coalesced Transpose Load)
- Vec4 Geçişi: Girdiler
array<vec4<f32>>olarak bağlanarak 16-byte'lık tek yönergede 4 float çekilir.K % 4 == 0veN % 4 == 0kısıtları mevcuttur. - Coalesced Load Swap: Transpose matris çarpımında ardışık thread'lerin bellek üzerinde sıçramalar (stride) yapmasını engellemek için yükleme indis eşlemeleri (
aIm,aIk) takas edilir:matmul_at: X^T üzerinden inner reduction yapar (M ekseninde). Swap sayesinde ardışık thread'ler sabit K'da ardışık M değerlerini yükler (stride-K sıçramaları yerine ardışık 4'lü row reads). Bu sayede donanım L1 cache line'ı tam coalesced şekilde doldurulur (Warp başına 32 yerine 2-4 cache line yükü).
Bind Group ABI
matmul_t (4 binding)
| Binding | Tür | Detay |
|---|---|---|
| 0 | storage, read | X: array<vec4<f32>> — [M × K/4] |
| 1 | storage, read | W: array<vec4<f32>> — [N × K/4] (transposed view) |
| 2 | storage, read_write | Y: array<f32> — [M × N] |
| 3 | uniform | dims: vec4<u32> — (M, N, K, _) |
matmul_at (4 binding)
| Binding | Tür | Detay |
|---|---|---|
| 0 | storage, read | X: array<vec4<f32>> — [M × K/4] |
| 1 | storage, read | B: array<vec4<f32>> — [M × N/4] (dY upstream gradient) |
| 2 | storage, read_write | Y: array<f32> — [K × N] (dW output) |
| 3 | uniform | dims: vec4<u32> — (K, N, M, _) |
matmul_at_swiglu_a (5 binding)
| Binding | Tür | Detay |
|---|---|---|
| 0 | storage, read | GATE_mas: array<vec4<f32>> — forward gate çıktısı |
| 1 | storage, read | UP_mas: array<vec4<f32>> — forward up çıktısı |
| 2 | storage, read | B_mas: array<vec4<f32>> — dY upstream gradient |
| 3 | storage, read_write | Y_mas: array<f32> — dW_down output |
| 4 | uniform | dims_mas: vec4<u32> — (K, N, M, _) |
Dispatch Şekli
workgroup_size: (16, 16, 1) → 256 threads
matmul_t / matmul_t_acc: grid = (ceil(N/64), ceil(M/64), 1)
matmul_at / matmul_at_swiglu_a: grid = (ceil(N/64), ceil(K/64), 1)matmul_at — En Pahalı Backward Kernel'ı
Profil snapshot'larında matmul_at en yoğun kernel'dır:
- Maliyet: Toplam backward adım süresinin ~%18.4'ü tek başına bu kernel grubunda geçer.
- FFN Boyutları: Özellikle FFN ağırlık matrislerinin (768×3072) backward gemm işlemlerinde darboğaz oluşturur. SwiGLU fusion (
matmul_at_swiglu_a) bu maliyeti bellek transferlerini eleyerek dramatik ölçüde düşürür.
Satır Satır — matmul_at_swiglu_a Coalescing ve Hesaplama
matmul_at_swiglu_a kernel'ında M üzerinden inner-reduction ve on-the-fly SwiGLU çözümü:
1) A-Tile Coalesced Yüklemede İndis Takası (Swap)
{
let aI0 = tid * 4u;
let aIm = aI0 / TM; let aIk = aI0 % TM;
let kG = block_row + aIk; let mG = aIm;
let m_in = mG < M;
// gate ve up verilerini okur
let gv = GATE_mas[mG * K4 + kG / 4u];
let uv = UP_mas [mG * K4 + kG / 4u];
// SwiGLU'yu hesaplayarak A tile'ına coalesced (ardışık) yazar
tileA_db[(aIk + 0u) * TK_PAD + aIm] = select(0.0, silu(gv.x) * uv.x, m_in && (kG + 0u) < K);
tileA_db[(aIk + 1u) * TK_PAD + aIm] = select(0.0, silu(gv.y) * uv.y, m_in && (kG + 1u) < K);
tileA_db[(aIk + 2u) * TK_PAD + aIm] = select(0.0, silu(gv.z) * uv.z, m_in && (kG + 2u) < K);
tileA_db[(aIk + 3u) * TK_PAD + aIm] = select(0.0, silu(gv.w) * uv.w, m_in && (kG + 3u) < K);2) Inner Reduction ve FMA Döngüsü
Döngünün inner boyutu M (satır adedi) üzerindedir:
for (var m: u32 = 0u; m < TK; m = m + 1u) {
let a0 = tileA_db[cur_a_off + (4u * ty + 0u) * TK_PAD + m];
let a1 = tileA_db[cur_a_off + (4u * ty + 1u) * TK_PAD + m];
let a2 = tileA_db[cur_a_off + (4u * ty + 2u) * TK_PAD + m];
let a3 = tileA_db[cur_a_off + (4u * ty + 3u) * TK_PAD + m];
let b0 = tileB_db[cur_b_off + m * TN + (4u * tx + 0u)];
let b1 = tileB_db[cur_b_off + m * TN + (4u * tx + 1u)];
let b2 = tileB_db[cur_b_off + m * TN + (4u * tx + 2u)];
let b3 = tileB_db[cur_b_off + m * TN + (4u * tx + 3u)];
acc00 = fma(a0, b0, acc00); acc01 = fma(a0, b1, acc01); acc02 = fma(a0, b2, acc02); acc03 = fma(a0, b3, acc03);
acc10 = fma(a1, b0, acc10); acc11 = fma(a1, b1, acc11); acc12 = fma(a1, b2, acc12); acc13 = fma(a1, b3, acc13);
acc20 = fma(a2, b0, acc20); acc21 = fma(a2, b1, acc21); acc22 = fma(a2, b2, acc22); acc23 = fma(a2, b3, acc23);
acc30 = fma(a3, b0, acc30); acc31 = fma(a3, b1, acc31); acc32 = fma(a3, b2, acc32); acc33 = fma(a3, b3, acc33);
}Code Review
Bulgu 1: Backward FP32 Hassasiyeti Zorunludur
| Risk | Açıklama |
|---|---|
| 🟢 mimari | Backward adımında fp16 kullanılmamıştır. Gradyanlar fp16 olsa Adam momentleri gürültülü (noisy) birikeceğinden training stabilitesi bozulurdu. Tasarım gereği: forward f16 storage iken, backward ve optimizer adımları saf fp32'dir. |
Hızlı Kontrol Listesi
| Test Senaryosu | Durum |
|---|---|
matmul_at ve matmul_at_swiglu_a çıktı boyutları doğru mu? | ✅ [K × N] (dW output) |
| Coalesced swap donanımsal olarak çalışıyor mu? | ✅ Apple Instruments ile doğrulandı |
SwiGLU backward fusion ara hidden gereksinimini kaldırdı mı? | ✅ Evet, doğrudan gate ve up verilerinden çözüldü |
Sonraki
12_optimizer.md — AdamW. Multi-tensor mega-buffer ve fused fp16 aynalama optimizasyonları. Pipeline'ın final adımı.