diagnostic(off, chromium.subgroup_matrix_uniformity);
diagnostic(off, subgroup_uniformity);
enable f16;
enable subgroups;
enable chromium_experimental_subgroup_matrix;

#ifdef KV_F32
#define KV_TYPE f32
#else
#define KV_TYPE f16
#endif

// Default values
#define HEAD_DIM_QK 64
#define HEAD_DIM_V 64

// The number of rows/columns/k in a subgroup matrix. MxK * KxN = MxN
// Note that the "K" here does not correspond to the K in attention's Q/K/V, it's just the common dimension.
#define SG_MAT_M 8
#define SG_MAT_N 8
#define SG_MAT_K 8

// Each workgroup processes one subgroup matrix of Q rows
#define Q_TILE SG_MAT_M
#define KV_TILE 16
#define WG_SIZE 64

// Number of subgroup-matrix-width blocks that span the KV tile. SG_MAT_N must divide KV_TILE.
#define KV_BLOCKS (KV_TILE / SG_MAT_N)
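// With the defaults above, KV_BLOCKS = 16 / 8 = 2.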

// Quantization constants/helpers
#define BLOCK_SIZE 32
#define BLOCKS_K ((HEAD_DIM_QK + BLOCK_SIZE - 1) / BLOCK_SIZE)
#define BLOCKS_V ((HEAD_DIM_V + BLOCK_SIZE - 1) / BLOCK_SIZE)
// number of quantized elements processed per thread
#if defined(KV_Q4_0)
#define NQ 16
// Q4_0 has 32 elements, 1 f16 for scale, 8 f16 for 4-bit weights
#define F16_PER_BLOCK 9
#define WEIGHTS_PER_F16 4
#elif defined(KV_Q8_0)
#define NQ 8
// Q8_0 has 32 elements, 1 f16 for scale, 16 f16 for 8-bit weights
#define F16_PER_BLOCK 17
#define WEIGHTS_PER_F16 2
#endif
#define F16_PER_THREAD (NQ / WEIGHTS_PER_F16)
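
// Block layout assumed by the dequantizers below:
//   Q4_0: one f16 scale d followed by 16 bytes of 4-bit weights; a byte's low
//         nibble holds element i (0..15) and its high nibble element i + 16,
//         and each value dequantizes as (nibble - 8) * d.
//   Q8_0: one f16 scale d followed by 32 signed 8-bit weights, value = q * d.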

// OK not to guard these with #if blocks; the compiler removes them if unused.
fn get_byte(value: u32, index: u32) -> u32 {
    return (value >> (index * 8)) & 0xFF;
}

// Extracts byte `index` and sign-extends it to an i32.
fn get_byte_i32(value: u32, index: u32) -> i32 {
    return bitcast<i32>(((value >> (index * 8)) & 0xFF) << 24) >> 24;
}
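
// Example: get_byte(0x11223344u, 0u) == 0x44u; get_byte_i32 sign-extends, so a
// byte of 0xFF comes back as -1.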

struct Params {
    offset_q: u32,
    offset_k: u32,
    offset_v: u32,
    offset_mask: u32,
    offset_sinks: u32,
    offset_dst: u32,

    // shapes of Q/K/V
    n_heads: u32,
    seq_len_q: u32,
    seq_len_kv: u32,

    // strides (in elements)
    stride_q1: u32,
    stride_q2: u32,
    stride_q3: u32,
    stride_k1: u32,
    stride_k2: u32,
    stride_k3: u32,
    stride_v1: u32,
    stride_v2: u32,
    stride_v3: u32,
    stride_mask3: u32,

    // repeat factor for K/V heads, e.g., MHA vs. MQA vs. GQA
    q_per_kv: u32,

    // softmax params
    scale: f32,
    max_bias: f32,
    logit_softcap: f32,
    n_head_log2: f32,
    m0: f32,
    m1: f32,
};

@group(0) @binding(0) var<storage, read_write> Q: array<f32>;
@group(0) @binding(1) var<storage, read_write> K: array<KV_TYPE>;
@group(0) @binding(2) var<storage, read_write> V: array<KV_TYPE>;

#if defined(MASK) && defined(SINKS)
@group(0) @binding(3) var<storage, read_write> mask: array<f16>;
@group(0) @binding(4) var<storage, read_write> sinks: array<f32>;
#define DST_BINDING 5
#define PARAMS_BINDING 6
#elif defined(MASK)
@group(0) @binding(3) var<storage, read_write> mask: array<f16>;
#define DST_BINDING 4
#define PARAMS_BINDING 5
#elif defined(SINKS)
@group(0) @binding(3) var<storage, read_write> sinks: array<f32>;
#define DST_BINDING 4
#define PARAMS_BINDING 5
#else
#define DST_BINDING 3
#define PARAMS_BINDING 4
#endif

@group(0) @binding(DST_BINDING) var<storage, read_write> dst: array<vec4<f32>>;
@group(0) @binding(PARAMS_BINDING) var<uniform> params: Params;

// A very negative value used as a stand-in for -infinity.
const FLOAT_MIN: f32 = -1.0e9;

// Shared memory for the Q_TILE rows of Q processed by this workgroup
var<workgroup> q_shmem: array<f16, Q_TILE * HEAD_DIM_QK>;

#ifndef KV_DIRECT
const kv_shmem_size = KV_TILE * max(HEAD_DIM_QK, HEAD_DIM_V);
// we can reuse the same shmem for K and V since we only need one at a time
var<workgroup> kv_shmem: array<f16, kv_shmem_size>;
#endif
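
// When KV_DIRECT is defined there is no staging buffer: the subgroup matrix
// loads in the loops below read K/V directly from the global buffers.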

var<workgroup> o_shmem: array<f16, Q_TILE * HEAD_DIM_V>; // output shmem

#ifdef MASK
// storage for mask values
var<workgroup> mask_shmem: array<f16, Q_TILE * KV_TILE>;
#endif

// storage for output of Q*K^T scores for online softmax (S matrix from paper)
// also storage for diagonal matrix during online softmax (P matrix from paper)
// note that we reuse the same storage for both since we only need one at a time
var<workgroup> inter_shmem: array<f16, Q_TILE * KV_TILE>;

// Storage for row max and exp sum during online softmax
var<workgroup> row_max_shmem: array<f32, Q_TILE>;
var<workgroup> exp_sum_shmem: array<f32, Q_TILE>;
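
// Per-row state follows the flash-attention online softmax recurrence: with
// running max m and exp-sum l, a new tile of scores s updates
//   m' = max(m, max(s))
//   l' = l * exp(m - m') + sum(exp(s - m'))
//   O' = O * exp(m - m') + P * V, where P = exp(s - m')
// and the final output is O / l.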

fn calc_softmax_term(kv_idx: u32, q_tile_row: u32, slope: f32) -> f32 {
    var v = select(FLOAT_MIN,
                   f32(inter_shmem[kv_idx + q_tile_row * KV_TILE]) * params.scale,
                   kv_idx < KV_TILE);
#ifdef LOGIT_SOFTCAP
    v = params.logit_softcap * tanh(v);
#endif
#ifdef MASK
    let mask_val = select(0.0, f32(mask_shmem[q_tile_row * KV_TILE + kv_idx]), kv_idx < KV_TILE);
    let mask_term = slope * mask_val;
    v += mask_term;
#endif
    return v;
}

fn load_f32x4(buf: ptr<storage, array<vec4<f32>>, read_write>, scalar_index: u32) -> vec4<f32> {
    return (*buf)[scalar_index >> 2u];
}

fn load_kvx4(buf: ptr<storage, array<vec4<KV_TYPE>>, read_write>, scalar_index: u32) -> vec4<KV_TYPE> {
    return (*buf)[scalar_index >> 2u];
}
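
// Note: these helpers return the vec4 containing element `scalar_index`, so
// callers are expected to pass indices that are multiples of 4.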

@compute @workgroup_size(WG_SIZE)
fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(subgroup_id) subgroup_id: u32,
    @builtin(subgroup_size) subgroup_size: u32,
    @builtin(num_subgroups) num_subgroups: u32,
    @builtin(subgroup_invocation_id) sg_inv_id: u32) {

    // initialize row max and exp sum for online softmax
    for (var i = local_id.x; i < Q_TILE; i += WG_SIZE) {
        row_max_shmem[i] = FLOAT_MIN;
        exp_sum_shmem[i] = 0.0;
    }

    for (var i = local_id.x; i < Q_TILE * HEAD_DIM_V; i += WG_SIZE) {
        o_shmem[i] = 0.0;
    }

    // workgroups per head/batch
    let wg_per_head = (params.seq_len_q + Q_TILE - 1u) / Q_TILE;
    let wg_per_batch = wg_per_head * params.n_heads;

    let dst2_stride = HEAD_DIM_V * params.n_heads;
    let dst3_stride = dst2_stride * params.seq_len_q;

    // batch index
    let batch_idx = wg_id.x / wg_per_batch;
    let q_batch_offset = params.offset_q + batch_idx * params.stride_q3;
    let k_batch_offset = params.offset_k + batch_idx * params.stride_k3;
    let v_batch_offset = params.offset_v + batch_idx * params.stride_v3;
    let dst_batch_offset = params.offset_dst + batch_idx * dst3_stride;
    let wg_in_batch = wg_id.x % wg_per_batch;

    // head index
    let head_idx = wg_in_batch / wg_per_head;
    let q_head_offset = q_batch_offset + head_idx * params.stride_q2;
    let k_head_idx = head_idx / params.q_per_kv;
    let v_head_idx = k_head_idx;
    let k_head_offset = k_batch_offset + k_head_idx * params.stride_k2;
    let v_head_offset = v_batch_offset + v_head_idx * params.stride_v2;

    // starting Q row for this workgroup
    let wg_in_head = wg_in_batch % wg_per_head;
    let q_row_start = wg_in_head * Q_TILE;

#ifdef MASK
    // mask offset
    let mask_global_offset = params.offset_mask + batch_idx * params.stride_mask3 + q_row_start * params.seq_len_kv;
#endif

    // note that the output is permuted, the layout is [head_dim_v, n_heads, seq_len_q, batch_size]
    let dst_global_offset = dst_batch_offset + q_row_start * dst2_stride + head_idx * HEAD_DIM_V;

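    // ALiBi slope: heads below n_head_log2 use m0^(head + 1), the rest use
    // m1^(2 * (head - n_head_log2) + 1); max_bias <= 0 disables the bias and
    // leaves the slope at 1.0.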
    let head = f32(head_idx);
    let slope = select(1.0,
                       select(pow(params.m1, 2.0 * (head - params.n_head_log2) + 1.0),
                              pow(params.m0, head + 1.0),
                              head < params.n_head_log2),
                       params.max_bias > 0);

    // load q tile into shared memory
    for (var elem_idx = local_id.x; elem_idx < Q_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE) {
        let q_row = elem_idx / HEAD_DIM_QK;
        let q_col = elem_idx % HEAD_DIM_QK;
        let head_q_row = q_row_start + q_row;
        let global_q_row_offset = q_head_offset + head_q_row * params.stride_q1;
        q_shmem[elem_idx] = f16(select(
            0.0,
            Q[global_q_row_offset + q_col],
            head_q_row < params.seq_len_q && q_col < HEAD_DIM_QK));
    }

    for (var kv_tile = 0u; kv_tile < params.seq_len_kv; kv_tile += KV_TILE) {
      // clear inter_shmem to ensure zero-initialized accumulators
      for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) {
          inter_shmem[elem_idx] = 0.0;
      }

      // load k tile into shared memory
#if defined(KV_Q4_0)
      for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) {
          let blck_idx = elem_idx / BLOCK_SIZE;
          let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16;
          let k_row = blck_idx / BLOCKS_K;
          let global_k_row = kv_tile + k_row;
          let block_k = blck_idx % BLOCKS_K;
          let row_offset = k_row * HEAD_DIM_QK;

          if (global_k_row < params.seq_len_kv) {
              let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k;
              let base_idx = global_block_idx * F16_PER_BLOCK;
              let d = K[base_idx]; // scale
              for (var j = 0u; j < F16_PER_THREAD; j += 2) {
                  let q_0 = K[base_idx + 1u + block_offset + j];
                  let q_1 = K[base_idx + 1u + block_offset + j + 1];
                  let q_packed = bitcast<u32>(vec2(q_0, q_1));
                  for (var k = 0u; k < 4u; k++) {
                      let q_byte = get_byte(q_packed, k);
                      let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
                      let q_lo = (f16(q_byte & 0xF) - 8.0) * d;
                      let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
                      kv_shmem[row_offset + idx] = q_lo;
                      kv_shmem[row_offset + idx + 16u] = q_hi;
                  }
              }
          }
      }
#elif defined(KV_Q8_0)
      for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) {
          let blck_idx = elem_idx / BLOCK_SIZE;
          let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16;
          let k_row = blck_idx / BLOCKS_K;
          let global_k_row = kv_tile + k_row;
          let block_k = blck_idx % BLOCKS_K;
          let row_offset = k_row * HEAD_DIM_QK;

          if (global_k_row < params.seq_len_kv) {
              let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k;
              let base_idx = global_block_idx * F16_PER_BLOCK;
              let d = K[base_idx]; // scale
              for (var j = 0u; j < F16_PER_THREAD; j += 2) {
                  let q_0 = K[base_idx + 1u + block_offset + j];
                  let q_1 = K[base_idx + 1u + block_offset + j + 1];
                  let q_packed = bitcast<u32>(vec2(q_0, q_1));
                  for (var k = 0u; k < 4u; k++) {
                      let q_byte = get_byte_i32(q_packed, k);
                      let q_val = f16(q_byte) * d;
                      let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
                      kv_shmem[row_offset + idx] = q_val;
                  }
              }
          }
      }
#elif defined(KV_DIRECT)
      // Direct global loads for KV
#else
      for (var elem_idx = local_id.x; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE) {
          let k_row = elem_idx / HEAD_DIM_QK;
          let k_col = elem_idx % HEAD_DIM_QK;
          let global_k_row = kv_tile + k_row;
          let global_k_row_offset = k_head_offset + global_k_row * params.stride_k1;
          kv_shmem[elem_idx] = f16(select(
              0.0,
              K[global_k_row_offset + k_col],
              global_k_row < params.seq_len_kv && k_col < HEAD_DIM_QK));
      }
#endif

      workgroupBarrier();

      // accumulate q block * k block into registers across the entire KV tile
      // TODO: this loop seems to be the current largest bottleneck
      // this bracket exists to scope the lifetime of variables, reducing register pressure
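      // The loop is software-pipelined: the Q/K submatrices for the next step
      // are loaded before the multiply-accumulate for the current one issues,
      // two K-steps per iteration plus an odd tail and a final accumulate.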
      {
#ifdef KV_DIRECT
          let k_block_row = kv_tile + subgroup_id * SG_MAT_N;
          var k_global_offset = k_head_offset + k_block_row * params.stride_k1;
#else
          var k_block_offset = subgroup_id * SG_MAT_N * HEAD_DIM_QK;
#endif
          for (var kv_block = subgroup_id; kv_block < KV_BLOCKS; kv_block += num_subgroups) {
              let inter_offset = kv_block * SG_MAT_N;
              var acc: subgroup_matrix_result<f16, SG_MAT_M, SG_MAT_N> = subgroupMatrixLoad<subgroup_matrix_result<f16, SG_MAT_M, SG_MAT_N>>(&inter_shmem, inter_offset, false, KV_TILE);

              var q_cur = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(&q_shmem, 0u, false, HEAD_DIM_QK);

#ifdef KV_DIRECT
              var k_cur = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&K, k_global_offset + 0u, true, params.stride_k1);
#else
              var k_cur = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&kv_shmem, k_block_offset + 0u, true, HEAD_DIM_QK);
#endif

              var t: u32 = 1u;
              for (; t + 1u < HEAD_DIM_QK / SG_MAT_K; t += 2u) {
                  let h0 = t * SG_MAT_K;
                  var q0 = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(&q_shmem, h0, false, HEAD_DIM_QK);
#ifdef KV_DIRECT
                  var k0 = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&K, k_global_offset + h0, true, params.stride_k1);
#else
                  var k0 = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&kv_shmem, k_block_offset + h0, true, HEAD_DIM_QK);
#endif
                  acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc);
                  q_cur = q0;
                  k_cur = k0;

                  let h1 = (t + 1u) * SG_MAT_K;
                  var q1g = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(&q_shmem, h1, false, HEAD_DIM_QK);
#ifdef KV_DIRECT
                  var k1g = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&K, k_global_offset + h1, true, params.stride_k1);
#else
                  var k1g = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&kv_shmem, k_block_offset + h1, true, HEAD_DIM_QK);
#endif
                  acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc);
                  q_cur = q1g;
                  k_cur = k1g;
              }

              // handle odd tail
              if (t < HEAD_DIM_QK / SG_MAT_K) {
                  let h = t * SG_MAT_K;
                  var qn = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(&q_shmem, h, false, HEAD_DIM_QK);
#ifdef KV_DIRECT
                  var kn = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&K, k_global_offset + h, true, params.stride_k1);
#else
                  var kn = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&kv_shmem, k_block_offset + h, true, HEAD_DIM_QK);
#endif
                  acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc);
                  q_cur = qn;
                  k_cur = kn;
              }

              acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc);

#ifdef KV_DIRECT
              k_global_offset += num_subgroups * SG_MAT_N * params.stride_k1;
#else
              k_block_offset += num_subgroups * SG_MAT_N * HEAD_DIM_QK;
#endif
              subgroupMatrixStore(&inter_shmem, inter_offset, acc, false, KV_TILE);
          }
      }

#ifdef MASK
      // load mask tile into shared memory for this KV block
      // TODO: optimize and skip if mask is -INF for the entire tile
      for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) {
          let mask_row = elem_idx / KV_TILE;
          let mask_col = elem_idx % KV_TILE;
          let global_q_row = q_row_start + mask_row;
          let global_k_col = kv_tile + mask_col;
          let mask_in_bounds = global_q_row < params.seq_len_q && global_k_col < params.seq_len_kv;
          let mask_idx = mask_global_offset + mask_row * params.seq_len_kv + global_k_col;
          mask_shmem[elem_idx] = select(0.0, mask[mask_idx], mask_in_bounds);
      }
#endif

      workgroupBarrier();

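      // Each subgroup owns one Q row at a time; its lanes fan out across the
      // KV positions in the tile, so subgroupMax/subgroupAdd reduce over KV.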
      // online softmax
      for (var q_tile_row = subgroup_id; q_tile_row < Q_TILE; q_tile_row += num_subgroups) {
          let global_q_row = q_row_start + q_tile_row;
          if (global_q_row >= params.seq_len_q) {
              break;
          }

          // initialize running max for this row
          var prev_max = row_max_shmem[q_tile_row];
          var final_max = prev_max;
          // pass 1: compute final max across the full KV tile in chunks
          for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) {
              let kv_idx = kv_offset + sg_inv_id;
              let softmax_term = calc_softmax_term(kv_idx, q_tile_row, slope);
              final_max = subgroupMax(max(final_max, softmax_term));
          }

          var total_exp_term: f32 = 0.0;
          // pass 2: compute exp sum and write P using final_max
          for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) {
              let kv_idx = kv_offset + sg_inv_id;
              let softmax_term = calc_softmax_term(kv_idx, q_tile_row, slope);
              let cur_p = select(0.0,
                                 exp(softmax_term - final_max),
                                 kv_tile + kv_idx < params.seq_len_kv && kv_idx < KV_TILE);
              total_exp_term += subgroupAdd(cur_p);
              if (kv_idx < KV_TILE) {
                  inter_shmem[kv_idx + q_tile_row * KV_TILE] = f16(cur_p);
              }
          }

          let cur_exp = exp(prev_max - final_max);

          if (sg_inv_id == 0) {
              row_max_shmem[q_tile_row] = final_max;
              exp_sum_shmem[q_tile_row] = exp_sum_shmem[q_tile_row] * cur_exp + total_exp_term;
          }

          // rescale the accumulated output to the new row max
          for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) {
              let idx = q_tile_row * HEAD_DIM_V + elem_idx;
              o_shmem[idx] = f16(f32(o_shmem[idx]) * cur_exp);
          }
      }

      // load v tile into shared memory
#if defined(KV_Q4_0)
      for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) {
          let blck_idx = elem_idx / BLOCK_SIZE;
          let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16;
          let v_row = blck_idx / BLOCKS_V;
          let global_v_row = kv_tile + v_row;
          let block_k = blck_idx % BLOCKS_V;
          let row_offset = v_row * HEAD_DIM_V;

          if (global_v_row < params.seq_len_kv) {
              let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k;
              let base_idx = global_block_idx * F16_PER_BLOCK;
              let d = V[base_idx]; // scale
              for (var j = 0u; j < F16_PER_THREAD; j += 2) {
                  let q_0 = V[base_idx + 1u + block_offset + j];
                  let q_1 = V[base_idx + 1u + block_offset + j + 1];
                  let q_packed = bitcast<u32>(vec2(q_0, q_1));
                  for (var k = 0u; k < 4u; k++) {
                      let q_byte = get_byte(q_packed, k);
                      let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
                      let q_lo = (f16(q_byte & 0xF) - 8.0) * d;
                      let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
                      kv_shmem[row_offset + idx] = q_lo;
                      kv_shmem[row_offset + idx + 16u] = q_hi;
                  }
              }
          }
      }
#elif defined(KV_Q8_0)
      for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) {
          let blck_idx = elem_idx / BLOCK_SIZE;
          let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16;
          let v_row = blck_idx / BLOCKS_V;
          let global_v_row = kv_tile + v_row;
          let block_k = blck_idx % BLOCKS_V;
          let row_offset = v_row * HEAD_DIM_V;

          if (global_v_row < params.seq_len_kv) {
              let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k;
              let base_idx = global_block_idx * F16_PER_BLOCK;
              let d = V[base_idx]; // scale
              for (var j = 0u; j < F16_PER_THREAD; j += 2) {
                  let q_0 = V[base_idx + 1u + block_offset + j];
                  let q_1 = V[base_idx + 1u + block_offset + j + 1];
                  let q_packed = bitcast<u32>(vec2(q_0, q_1));
                  for (var k = 0u; k < 4u; k++) {
                      let q_byte = get_byte_i32(q_packed, k);
                      let q_val = f16(q_byte) * d;
                      let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
                      kv_shmem[row_offset + idx] = q_val;
                  }
              }
          }
      }
#elif defined(KV_DIRECT)
      // Direct global loads for KV
#else
      for (var elem_idx = local_id.x; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE) {
          let v_row = elem_idx / HEAD_DIM_V;
          let v_col = elem_idx % HEAD_DIM_V;
          let global_v_row = kv_tile + v_row;
          let global_v_row_offset = v_head_offset + global_v_row * params.stride_v1;
          kv_shmem[elem_idx] = f16(select(
              0.0,
              V[global_v_row_offset + v_col],
              global_v_row < params.seq_len_kv && v_col < HEAD_DIM_V));
      }
#endif

      workgroupBarrier();

      // we have P (Q_TILE x KV_TILE) in inter_shmem and V (KV_TILE x head_dim_v) in kv_shmem
      // we want to compute O += P * V across the full KV tile
      for (var head_dim_block = subgroup_id * SG_MAT_N;
           head_dim_block < HEAD_DIM_V;
           head_dim_block += num_subgroups * SG_MAT_N) {
              // load O submatrix from shared memory
              var o_sg_mat: subgroup_matrix_result<f16, SG_MAT_M, SG_MAT_N> = subgroupMatrixLoad<subgroup_matrix_result<f16, SG_MAT_M, SG_MAT_N>>(
                  &o_shmem,
                  head_dim_block,
                  false,
                  HEAD_DIM_V
              );
              for (var kv_block = 0u; kv_block < KV_BLOCKS; kv_block++) {
                  let p_offset = kv_block * SG_MAT_N;
                  var p_sg_mat: subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K> = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(
                      &inter_shmem,
                      p_offset,
                      false,
                      KV_TILE
                  );

                  // load V submatrix from global or shared memory
#ifdef KV_DIRECT
                  let v_block_row = kv_tile + kv_block * SG_MAT_N;
                  let v_global_offset = v_head_offset + v_block_row * params.stride_v1 + head_dim_block;
                  var v_sg_mat: subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N> = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(
                      &V,
                      v_global_offset,
                      false,
                      params.stride_v1
                  );
#else
                  let v_block_offset = kv_block * SG_MAT_N * HEAD_DIM_V;
                  var v_sg_mat: subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N> = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(
                      &kv_shmem,
                      v_block_offset + head_dim_block,
                      false,
                      HEAD_DIM_V
                  );
#endif
                  // O += P * V
                  o_sg_mat = subgroupMatrixMultiplyAccumulate(p_sg_mat, v_sg_mat, o_sg_mat);
              }
              // store O back to shared memory
              subgroupMatrixStore(&o_shmem, head_dim_block, o_sg_mat, false, HEAD_DIM_V);
      }
      workgroupBarrier();
    }

#ifdef SINKS
    // add sinks (applied once after processing all KV tiles)
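    // a sink acts as one extra logit per head that joins the softmax
    // normalization without contributing a value: only the row max and exp
    // sum are updated, and the accumulated output is rescaled to the new max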
    for (var q_tile_row = subgroup_id;
         q_tile_row < Q_TILE;
         q_tile_row += num_subgroups) {
            // no need to process rows beyond seq_len_q
            let global_q_row = q_row_start + q_tile_row;
            if (global_q_row >= params.seq_len_q) {
                break;
            }

            var prev_max = row_max_shmem[q_tile_row];

            // for non-sink threads, exp(FLOAT_MIN) effectively zeroes out their contribution to the sum
            let sink_val = select(FLOAT_MIN, sinks[params.offset_sinks + head_idx], sg_inv_id == 0);
            let new_max = subgroupMax(max(prev_max, sink_val));
            let max_exp = exp(prev_max - new_max);
            let sink_exp = exp(sink_val - new_max);

            let sink_exp_sum = subgroupAdd(sink_exp);

            if (sg_inv_id == 0) {
                exp_sum_shmem[q_tile_row] = exp_sum_shmem[q_tile_row] * max_exp + sink_exp_sum;
            }

            for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) {
                let idx = q_tile_row * HEAD_DIM_V + elem_idx;
                let val = f32(o_shmem[idx]) * max_exp;
                o_shmem[idx] = f16(val);
            }
    }
    workgroupBarrier();
#endif
    for (var q_tile_row = subgroup_id;
        q_tile_row < Q_TILE;
        q_tile_row += num_subgroups) {

        let global_q_row = q_row_start + q_tile_row;
        if (global_q_row >= params.seq_len_q) { break; }

        let exp_sum = exp_sum_shmem[q_tile_row];
        let scale = select(0.0, 1.0 / exp_sum, exp_sum != 0.0);

        let row_base: u32 = dst_global_offset + q_tile_row * dst2_stride;

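        // assumes HEAD_DIM_V is a multiple of 4 and the dst offsets are
        // vec4-aligned, so each group of four f32 results is one dst element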
        for (var elem_base = sg_inv_id * 4u;
            elem_base < HEAD_DIM_V;
            elem_base += subgroup_size * 4u) {

            let i0 = q_tile_row * HEAD_DIM_V + (elem_base + 0u);
            let i1 = q_tile_row * HEAD_DIM_V + (elem_base + 1u);
            let i2 = q_tile_row * HEAD_DIM_V + (elem_base + 2u);
            let i3 = q_tile_row * HEAD_DIM_V + (elem_base + 3u);

            let v = vec4<f32>(
                f32(o_shmem[i0]) * scale,
                f32(o_shmem[i1]) * scale,
                f32(o_shmem[i2]) * scale,
                f32(o_shmem[i3]) * scale
            );

            let dst_vec_index: u32 = (row_base + elem_base) >> 2u;
            dst[dst_vec_index] = v;
        }
    }
}