#version 450

#extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_shader_16bit_storage : require

#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require

#extension GL_KHR_shader_subgroup_basic : enable
#extension GL_KHR_shader_subgroup_arithmetic : enable
#extension GL_KHR_shader_subgroup_vote : enable
#extension GL_KHR_memory_scope_semantics : enable
#extension GL_KHR_cooperative_matrix : enable

#include "types.glsl"
#include "flash_attn_base.glsl"

// MatBr and MatBc must be N,M values supported by the implementation for a MatBc x MatBr x 16 coopMatMulAdd.
const uint32_t MatBr = 16;
const uint32_t MatBc = 16;

const uint32_t row_split = Bc / MatBc;
const uint32_t rows_per_thread = Br / row_split;
const uint32_t cols_per_iter = gl_WorkGroupSize.x / row_split;
const uint32_t cols_per_thread = Bc / cols_per_iter;
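// Work decomposition: the workgroup's threads are split into row_split groups
// of cols_per_iter threads each. Each group owns rows_per_thread of the Br
// query rows, and each thread in a group covers cols_per_thread of the Bc
// key/value columns per iteration.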

layout (binding = 0) readonly buffer Q {float data_q[];};
layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];};
layout (binding = 1) readonly buffer K {float16_t data_k[];};
layout (binding = 1) readonly buffer KV4 {f16vec4 data_kv4[];};
layout (binding = 2) readonly buffer V {float16_t data_v[];};
layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];};
layout (binding = 3) readonly buffer M {float16_t data_m[];};

// Store the output when doing grouped query attention.
// Rows are indexed by Q's dimension 2, and the first N rows are valid.
D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
{
    uint32_t offset = (iq2 + r) * HSV + c;
    data_o[o_offset + offset] = D_TYPE(elem);
    return elem;
}

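// The "+ 2" on the shared memory strides below pads each row by two f16vec4s,
// presumably to avoid shared memory bank conflicts between rows.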
const uint32_t qstride = HSK_pad / 4 + 2; // in units of f16vec4
shared f16vec4 Qf[Br * qstride];

const uint psh_stride = Br / 4 + 2;
shared f16vec4 Psh[Bc * psh_stride];
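// Psh holds the P = exp(S - M) tile, transposed (Bc rows of Br/4 f16vec4s);
// it is the A operand of the P*V cooperative-matrix multiply below.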

// Avoid padding for HSK == 256 to make it fit in 48KB shmem.
const uint32_t sfshstride = (HSK <= 128) ? (Br / 4 + 2) : Br / 4;
shared ACC_TYPEV4 sfsh[Bc * sfshstride];

const uint32_t kshstride = (K_LOAD_SHMEM != 0 ? HSK_pad : MatBr) / 4 + 2; // in units of f16vec4
const uint v_cols = MatBc / 4 * row_split; // total columns in f16vec4 units: MatBc / 4 vec4s per subgroup times row_split subgroups
const uint vsh_stride = v_cols;
shared f16vec4 ksh[(kshstride >= vsh_stride) ? (Bc * kshstride) : (Bc * vsh_stride)];
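// ksh is reused for both K tiles (first matmul) and V tiles (second matmul),
// so it is sized for whichever layout is larger.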

shared ACC_TYPE slope[Br];

void main() {
#ifdef NEEDS_INIT_IQ_SHMEM
    init_iq_shmem(gl_WorkGroupSize);
#endif

    init_indices();

    const uint32_t tid = gl_LocalInvocationIndex;

    const uint32_t threads_per_rowgroup = gl_WorkGroupSize.x / row_split;
    const uint32_t d_per_thread = (HSV/4 + threads_per_rowgroup - 1) / threads_per_rowgroup;
    const uint32_t row_tid = gl_LocalInvocationIndex / threads_per_rowgroup;
    const uint32_t col_tid = gl_LocalInvocationIndex % threads_per_rowgroup;

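// Maps a thread-local row index r to the row it owns within the Br-row tile
// of this thread's rowgroup.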
#define tile_row(r) (row_tid * rows_per_thread + (r))

    // Zero-initialize shared memory for Q/K when HSK is not a multiple of 16 (HSK_pad > HSK).
    if ((HSK % 16) != 0) {
        [[unroll]] for (uint i = 0; i < Br * qstride; i += gl_WorkGroupSize.x) {
            if (i + tid < Br * qstride) {
                Qf[i + tid] = f16vec4(0);
            }
        }
        [[unroll]] for (uint i = 0; i < Bc * kshstride; i += gl_WorkGroupSize.x) {
            if (i + tid < Bc * kshstride) {
                ksh[i + tid] = f16vec4(0);
            }
        }
        barrier();
    }

    uint32_t q_offset = gqa_iq1*p.nb01 + (iq2*p.nb02+iq3*p.nb03) / 4;

    [[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) {
        uint32_t d = (idx + tid) % (HSK / 4);
        uint32_t r = (idx + tid) / (HSK / 4);
        if (r < Br && d < HSK / 4 &&
            i * Br + r < N) {
            Qf[r * qstride + d] = f16vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale);
        }
    }
    barrier();

    ACC_TYPEV4 Of[rows_per_thread][d_per_thread];
    [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
        [[unroll]] for (uint32_t d = 0; d < d_per_thread; ++d) {
            Of[r][d] = ACC_TYPEV4(0.0);
        }
    }

    float Lf[rows_per_thread], Mf[rows_per_thread];

    // Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. when computing Mold-M.
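    // 0xFEFFFFFF is the IEEE-754 single-precision encoding of -FLT_MAX/2
    // (sign = 1, exponent = 0xFD, mantissa all ones).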
    const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF);

    [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
        Lf[r] = 0;
        Mf[r] = NEG_FLT_MAX_OVER_2;
    }

    // ALiBi
    if (p.max_bias > 0.0f) {
        if (tid < Br) {
            uint r = tid;
            slope[r] = perElemOpComputeSlope(r, col_tid, ACC_TYPE(0), iq2);
        }
    } else {
        if (tid < Br) {
            uint r = tid;
            slope[r] = ACC_TYPE(1.0);
        }
    }

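    // Each 32-bit mask_opt word packs 16 two-bit block classifications, one per
    // Bc-wide block of KV columns, so a row of blocks uses CEIL_DIV(KV, 16 * Bc)
    // words. The codes let the main loop skip all-(-inf) blocks entirely and
    // skip loading the mask for all-zero blocks.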
    const uint32_t mo_stride = CEIL_DIV(KV, 16 * Bc);
    // mo_offset will point to the tile starting at row i*Br and col 0
    uint32_t mo_offset = mo_stride * i;

#if BLOCK_SIZE > 1
    uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE;
    uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE;
#else
    uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
    uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
#endif
    uint32_t m_offset = gqa_iq1*KV;
    if (p.nem2 != 1 || p.nem3 != 1) {
        m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
        mo_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * CEIL_DIV(p.nem1, Br) * mo_stride;
    }

    uint32_t mask_opt = 0;
    uint32_t mask_opt_idx = ~0;

    [[dont_unroll]]
    for (uint32_t j = start_j; j < end_j; ++j) {

        f16vec4 mask_cache[Bc * Br / 4 / WorkGroupSize];
        [[unroll]] for (uint32_t idx = 0; idx < mask_cache.length(); ++idx) {
            mask_cache[idx] = f16vec4(0);
        }

        if (MASK_ENABLE) {

            if (USE_MASK_OPT && mask_opt_idx != j / 16) {
                mask_opt_idx = j / 16;
                mask_opt = data_mask_opt[mo_offset + mask_opt_idx];
            }
            uint32_t mask_opt_bits = (mask_opt >> ((j % 16) * 2)) & 0x3;
            if (mask_opt_bits == MASK_OPT_ALL_NEG_INF) {
                // skip this block
                continue;
            }
            // Only load if the block is not all zeros
            if (mask_opt_bits != MASK_OPT_ALL_ZERO) {
                bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;

                float max_mask = NEG_FLT_MAX_OVER_2;
                [[unroll]] for (uint32_t idx = 0; idx < Bc * Br / 4; idx += gl_WorkGroupSize.x) {
                    uint32_t c = (idx + tid) / (Br / 4);
                    uint32_t r = (idx + tid) % (Br / 4);
                    if (idx + tid < Bc * Br / 4 || idx + gl_WorkGroupSize.x <= Bc * Br / 4) {
                        if ((!KV_bounds_check || j * Bc + c < KV)) {
                            f16vec4 m;
                            if (!nem1_bounds_check || i * Br + r * 4 + 3 < p.nem1) {
                                m = f16vec4(data_m[m_offset + (i * Br + r * 4    ) * m_stride + (j * Bc + c)],
                                            data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)],
                                            data_m[m_offset + (i * Br + r * 4 + 2) * m_stride + (j * Bc + c)],
                                            data_m[m_offset + (i * Br + r * 4 + 3) * m_stride + (j * Bc + c)]);
                                max_mask = max(max(max(max(max_mask, float(m[0])), float(m[1])), float(m[2])), float(m[3]));
                            } else if (i * Br + r * 4 + 2 < p.nem1) {
                                m = f16vec4(data_m[m_offset + (i * Br + r * 4    ) * m_stride + (j * Bc + c)],
                                            data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)],
                                            data_m[m_offset + (i * Br + r * 4 + 2) * m_stride + (j * Bc + c)],
                                            0.0);
                                max_mask = max(max(max(max_mask, float(m[0])), float(m[1])), float(m[2]));
                            } else if (i * Br + r * 4 + 1 < p.nem1) {
                                m = f16vec4(data_m[m_offset + (i * Br + r * 4    ) * m_stride + (j * Bc + c)],
                                            data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)],
                                            0.0,
                                            0.0);
                                max_mask = max(max(max_mask, float(m[0])), float(m[1]));
                            } else if (i * Br + r * 4 < p.nem1) {
                                m = f16vec4(data_m[m_offset + (i * Br + r * 4    ) * m_stride + (j * Bc + c)],
                                            0.0,
                                            0.0,
                                            0.0);
                                max_mask = max(max_mask, float(m[0]));
                            } else {
                                m = f16vec4(0.0);
                            }
                            mask_cache[idx / WorkGroupSize] = m;
                        }
                    }
                }
            }
        }

        if (K_LOAD_SHMEM != 0) {
            [[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) {
                uint32_t d = (idx + tid) % (HSK / 4);
                uint32_t c = (idx + tid) / (HSK / 4);
                if (c < Bc && d < HSK / 4) {
                    f16vec4 K_Tf = f16vec4(0);
                    if (!KV_bounds_check || j * Bc + c < KV) {
#if BLOCK_SIZE > 1
                        uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
                        uint ib = coord / BLOCK_SIZE;
                        uint iqs = (coord % BLOCK_SIZE);
                        K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K));
#else
                        K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);
#endif
                    }

                    ksh[c * kshstride + d] = K_Tf;
                }
            }
            barrier();
        }
        // K * Q^T -> S^T: Bc x HSK_pad * HSK_pad x Br -> Bc x Br
        // Bc is split across the workgroup (row_split subgroups), looping over HSK in chunks of 16: 16 x 16 * 16 x 16 -> 16 x 16
        // This is written transposed in order to allow for N being 8 if implementations need it
        coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator> SfMat = coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator>(0);
        coopmat<float16_t, gl_ScopeSubgroup, MatBc, 16, gl_MatrixUseA> KMat;
        coopmat<float16_t, gl_ScopeSubgroup, 16, MatBr, gl_MatrixUseB> QMat;

        [[unroll]] for (uint32_t d = 0; d < HSK_pad / 16; ++d) {
            if (K_LOAD_SHMEM == 0) {
#if BLOCK_SIZE == 1
            if (KV_bounds_check || d * 16 + 16 > HSK) {
#endif
            barrier();
            [[unroll]] for (uint32_t idx = 0; idx < Bc * MatBr / 4; idx += gl_WorkGroupSize.x) {
                uint32_t col_vec = (idx + tid) % (MatBr / 4);
                uint32_t row = (idx + tid) / (MatBr / 4);
                if (idx + tid < Bc * MatBr / 4) {
                    f16vec4 K_Tf = f16vec4(0);
                    if ((!KV_bounds_check || j * Bc + row < KV) && (HSK == HSK_pad || d * 16 + col_vec * 4 < HSK)) {
#if BLOCK_SIZE > 1
                        uint coord = (j * Bc + row) * k_stride * BLOCK_SIZE + d * 16 + col_vec * 4;
                        uint ib = coord / BLOCK_SIZE;
                        uint iqs = (coord % BLOCK_SIZE);
                        K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K));
#else
                        K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + row) * k_stride / 4 + d * 16 / 4 + col_vec]);
#endif
                    }

                    ksh[row * kshstride + col_vec] = K_Tf;
                }
            }
            barrier();
#if BLOCK_SIZE == 1
            }
#endif

#if BLOCK_SIZE == 1
            if (KV_bounds_check || d * 16 + 16 > HSK)
#endif
            {
                uint coord = (gl_SubgroupID * MatBc) * kshstride;
                coopMatLoad(KMat, ksh, coord, kshstride, gl_CooperativeMatrixLayoutRowMajor);
            }
#if BLOCK_SIZE == 1
            else {
                const uint coord = k_offset / 4 + (j * Bc + gl_SubgroupID * MatBc) * k_stride / 4 + d * 16 / 4;
                coopMatLoad(KMat, data_kv4, coord, k_stride / 4, gl_CooperativeMatrixLayoutRowMajor);
            }
#endif
            } else {
                uint coord = (gl_SubgroupID * MatBc) * kshstride + d * 16 / 4;
                coopMatLoad(KMat, ksh, coord, kshstride, gl_CooperativeMatrixLayoutRowMajor);
            }

            coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor);

            SfMat = coopMatMulAdd(KMat, QMat, SfMat);
        }

        uint coord = gl_SubgroupID * MatBc * sfshstride;
        coopMatStore(SfMat, sfsh, coord, sfshstride, gl_CooperativeMatrixLayoutRowMajor);
        barrier();

        if (LOGIT_SOFTCAP) {
            [[unroll]] for (uint32_t idx = 0; idx < Bc * Br / 4; idx += gl_WorkGroupSize.x) {
                uint32_t c = (idx + tid) / (Br / 4);
                uint32_t r = (idx + tid) % (Br / 4);
                if (idx + tid < Bc * Br / 4 || idx + gl_WorkGroupSize.x <= Bc * Br / 4) {
                    sfsh[c * sfshstride + r] = ACC_TYPEV4(p.logit_softcap * tanh(sfsh[c * sfshstride + r]));
                }
            }
            barrier();
        }

        if (MASK_ENABLE) {
            [[unroll]] for (uint32_t idx = 0; idx < Bc * Br / 4; idx += gl_WorkGroupSize.x) {
                uint32_t c = (idx + tid) / (Br / 4);
                uint32_t r = (idx + tid) % (Br / 4);
                if (idx + tid < Bc * Br / 4 || idx + gl_WorkGroupSize.x <= Bc * Br / 4) {
                    if (!KV_bounds_check || j * Bc + c < KV) {
                        // The nem1 bounds check for the mask was already handled when loading it
                        ACC_TYPEV4 masks = ACC_TYPEV4(mask_cache[idx / WorkGroupSize]);
                        ACC_TYPEV4 slopes = ACC_TYPEV4(slope[r * 4], slope[r * 4 + 1], slope[r * 4 + 2], slope[r * 4 + 3]);
                        sfsh[c * sfshstride + r] += slopes * masks;
                    }
                }
            }
            barrier();
        }

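        // Online softmax update: M_new = max(M_old, rowmax),
        // eM = exp(M_old - M_new), L_new = eM * L_old + sum(exp(S - M_new)),
        // and O is rescaled by eM before accumulating this block's P*V.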
        float eMf[rows_per_thread];
        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
            const uint r_vec  = tile_row(r) / 4;
            const uint r_comp = tile_row(r) % 4;

            float rowmaxf = NEG_FLT_MAX_OVER_2;
            [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
                if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
                    continue;
                }
                rowmaxf = max(rowmaxf, float(sfsh[r_vec + (c * cols_per_iter + col_tid) * sfshstride][r_comp]));
            }
            float Moldf = Mf[r];

            // Compute max across the row
            rowmaxf = subgroupMax(rowmaxf);

            // M = max(rowmax, Mold)
            // P = e^(S - M)
            // eM = e^(Mold - M)
            Mf[r] = max(rowmaxf, Moldf);
            eMf[r] = exp(Moldf - Mf[r]);

            Lf[r] = eMf[r]*Lf[r];
        }

        [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
            const uint d_local = d0 / threads_per_rowgroup;
            [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
                Of[r][d_local] = ACC_TYPE(eMf[r]) * Of[r][d_local];
            }
        }

        // Calculate and store Pf in Psh
        [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
            const uint col = c * cols_per_iter + col_tid;

            [[unroll]] for (uint32_t r = 0; r < rows_per_thread; r += 4) {
                const uint row = tile_row(r);
                if (KV_bounds_check && j * Bc + col >= KV) {
                    Psh[col * psh_stride + row / 4] = f16vec4(0.0f);
                } else {
                    const vec4 mfvec = vec4(Mf[r], Mf[r + 1], Mf[r + 2], Mf[r + 3]);
                    const f16vec4 Pf = f16vec4(exp(vec4(sfsh[row / 4 + col * sfshstride]) - mfvec));
                    [[unroll]] for (uint32_t vec_idx = 0; vec_idx < 4; ++vec_idx) {
                        Lf[r + vec_idx] += Pf[vec_idx];
                    }
                    Psh[col * psh_stride + row / 4] = Pf;
                }
            }
        }

        const uint num_hsv_tiles = (HSV + MatBc * row_split - 1) / (MatBc * row_split); // round up

        // Each subgroup handles a 16-column slice of each HSV tile
        [[unroll]] for (uint32_t hsv_tile = 0; hsv_tile < num_hsv_tiles; ++hsv_tile) {
            const uint hsv_offset = (hsv_tile * row_split + gl_SubgroupID) * 16;

            SfMat = coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator>(0);

            // Preload V tiles for [Bc, 16 * num subgroups]
            const uint v_rows = Bc;
            const uint v_total = v_rows * v_cols;
            const uint v_loads_per_thread = v_total / gl_WorkGroupSize.x;

#if BLOCK_SIZE == 1
            // For f16, only preload into shmem when bounds checking is needed; otherwise V is loaded directly from global memory below
            if (KV_bounds_check) {
#endif
            [[unroll]] for (uint32_t i = 0; i < v_loads_per_thread; ++i) {
                const uint idx = i * gl_WorkGroupSize.x + tid;
                const uint row = idx / v_cols;
                const uint col = idx % v_cols;

                const uint v_row = j * Bc + row;
                const uint v_col = hsv_tile * MatBc * row_split + col * 4;

                const uint coord = v_row * v_stride * BLOCK_SIZE + v_col;
                const uint ib = coord / BLOCK_SIZE;
                const uint iqs = coord % BLOCK_SIZE;

                if (!KV_bounds_check || (v_row < KV && v_col < HSV)) {
#if BLOCK_SIZE > 1
                    ksh[row * vsh_stride + col] = f16vec4(dequantize4(ib, iqs, v_offset, BINDING_IDX_V));
#else
                    ksh[row * vsh_stride + col] = data_vv4[(v_offset + v_row * v_stride + v_col) / 4];
#endif
                } else {
                    ksh[row * vsh_stride + col] = f16vec4(0.0f);
                }
            }
#if BLOCK_SIZE == 1
            }
#endif

            barrier();

            [[unroll]] for (uint32_t bc_chunk = 0; bc_chunk < Bc / MatBc; ++bc_chunk) {
                coopMatLoad(KMat, Psh, bc_chunk * MatBc * psh_stride, psh_stride, gl_CooperativeMatrixLayoutColumnMajor);

#if BLOCK_SIZE == 1
                if (!KV_bounds_check) {
                    // F16 values can be loaded directly from global memory
                    const uint v_tile_row = j * Bc + bc_chunk * MatBc;
                    const uint v_tile_offset = v_offset / 4 + v_tile_row * v_stride / 4 + hsv_offset / 4;
                    coopMatLoad(QMat, data_vv4, v_tile_offset, v_stride / 4, gl_CooperativeMatrixLayoutRowMajor);
                } else
#endif
                {
                    const uint v_tile_offset = bc_chunk * MatBr * v_cols + gl_SubgroupID * (MatBc / 4);
                    coopMatLoad(QMat, ksh, v_tile_offset, vsh_stride, gl_CooperativeMatrixLayoutRowMajor);
                }

                SfMat = coopMatMulAdd(KMat, QMat, SfMat);
            }

            // Store SfMat to sfsh and load into Of
            const uint osh_stride = row_split * MatBc / 4;
            const uint o_offset = gl_SubgroupID * MatBc / 4;
            coopMatStore(SfMat, sfsh, o_offset, osh_stride, gl_CooperativeMatrixLayoutRowMajor);

            barrier();

            const uint hsv_per_tile = row_split * MatBc;
            const uint hsv_base = hsv_tile * hsv_per_tile;
            const uint d_values_per_tile = hsv_per_tile / 4;

            const uint d_start = hsv_tile * d_values_per_tile;
            const uint d_end = min(d_start + d_values_per_tile, HSV / 4);

            [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
                const uint row = tile_row(r);

                [[unroll]] for (uint32_t d_local = 0; d_local < d_per_thread; ++d_local) {
                    const uint d = d_local * threads_per_rowgroup + col_tid;
                    const uint hsv_col = 4 * d;

                    if (hsv_col >= hsv_base && hsv_col < hsv_base + hsv_per_tile && hsv_col < HSV) {
                        const uint local_hsv = (hsv_col - hsv_base) / 4;
                        Of[r][d_local] += ACC_TYPEV4(sfsh[row * osh_stride + local_hsv]);
                    }
                }
            }
        }

        barrier();
    }

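    // Each thread's Lf[r] holds a partial row sum over the columns it
    // processed; reduce across the subgroup to get the full normalizer.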
    [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
        Lf[r] = subgroupAdd(Lf[r]);
    }

    // If there is split_k, then the split_k resolve shader does the final
    // division by L. Store the intermediate O value and per-row m and L values.
    if (p.k_num > 1) {
        // note: O and Q have swapped coord 1,2.
        uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));

        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
            if (tile_row(r) < N) {
                [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
                    const uint d = d0 + col_tid;
                    if (d >= HSV / 4) break;
                    const uint d_local = d0 / threads_per_rowgroup;
                    [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
                        perElemOpGqaStore(tile_row(r), 4 * d + comp, float(Of[r][d_local][comp]), o_offset, iq2, N);
                    }
                }
            }
        }

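        // L and M are stored after all of the split O data: two rows of
        // p.ne1 values per split, with L first and M at +p.ne1.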
        o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
            if (tile_row(r) < N) {
                perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
                perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N);
            }
        }

        return;
    }

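    // Attention sink: an extra per-row logit that joins the softmax
    // normalization without contributing to O. Fold it into L, rescaling O
    // when the sink becomes the new row maximum.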
    if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) {
        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
            float sink = perElemOpGetSink(tile_row(r), 0u, ACC_TYPE(0), iq2);

            float ms = 1.0f;
            float vs = 1.0f;

            if (sink > Mf[r]) {
                ms = exp(Mf[r] - sink);

                [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
                    const uint d_local = d0 / threads_per_rowgroup;
                    Of[r][d_local] *= ACC_TYPE(ms);
                }
            } else {
                vs = exp(sink - Mf[r]);
            }

            Lf[r] = Lf[r]*ms + vs;
        }
    }

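    // Normalize O by 1/L, guarding fully-masked rows where L is zero.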
    float Lfrcp[rows_per_thread];
    [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
        Lfrcp[r] = (Lf[r] == 0.0) ? 0.0 : (1.0 / Lf[r]);
    }

    [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
        const uint d_local = d0 / threads_per_rowgroup;
        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
            Of[r][d_local] *= ACC_TYPE(Lfrcp[r]);
#if defined(ACC_TYPE_MAX)
            Of[r][d_local] = clamp(Of[r][d_local], -ACC_TYPE_MAX, ACC_TYPE_MAX);
#endif
        }
    }

    uint32_t o_offset = gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV;

    if (p.gqa_ratio > 1) {
        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
            if (tile_row(r) < N) {
                [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
                    const uint d = d0 + col_tid;
                    if (d >= HSV / 4) break;
                    const uint d_local = d0 / threads_per_rowgroup;
                    [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
                        perElemOpGqaStore(tile_row(r), 4 * d + comp, float(Of[r][d_local][comp]), o_offset, iq2, N);
                    }
                }
            }
        }
    } else {
        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
            if (i * Br + tile_row(r) < N) {
                [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
                    const uint d = d0 + col_tid;
                    if (d >= HSV / 4) break;
                    const uint d_local = d0 / threads_per_rowgroup;
                    [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
                        data_o[o_offset + iq2 * HSV + (i * Br + tile_row(r)) * p.ne1 * HSV + 4 * d + comp] = D_TYPE(Of[r][d_local][comp]);
                    }
                }
            }
        }
    }
}