1#version 450
  2
  3#extension GL_EXT_control_flow_attributes : enable
  4#extension GL_EXT_shader_16bit_storage : require
  5
  6#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
  7#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
  8
  9#extension GL_KHR_shader_subgroup_shuffle : enable
 10#extension GL_KHR_shader_subgroup_vote : enable
 11
 12#include "types.glsl"
 13#include "flash_attn_base.glsl"
 14
// Each K/V column is handled by a group of D_split lanes; each lane covers
// HSK / D_split (for K) or HSV / D_split (for V/output) elements of the head
// dimension.
const uint32_t HSK_per_thread = HSK / D_split;
const uint32_t HSV_per_thread = HSV / D_split;

// cols_per_iter: number of K/V columns the whole workgroup covers per step.
// cols_per_thread: number of such steps needed to cover one Bc-wide KV block.
// NOTE(review): assumes Bc is a multiple of WorkGroupSize / D_split - confirm.
const uint32_t cols_per_iter = WorkGroupSize / D_split;
const uint32_t cols_per_thread = Bc / cols_per_iter;


// Q, K and V are each bound twice (same binding index): once with scalar
// elements and once with 4-wide vector elements, so main() can read them
// with vec4 loads. M holds the attention mask as float16.
layout (binding = 0) readonly buffer Q {float data_q[];};
layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];};
layout (binding = 1) readonly buffer K {float16_t data_k[];};
layout (binding = 1) readonly buffer KV4 {f16vec4 data_kv4[];};
layout (binding = 2) readonly buffer V {float16_t data_v[];};
layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];};
layout (binding = 3) readonly buffer M {float16_t data_m[];};
 29
 30// Store the output when doing grouped query attention.
 31// Rows index by Q's dimension 2, and the first N rows are valid.
 32D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
 33{
 34    uint32_t offset = (iq2 + r) * HSV + c;
 35    data_o[o_offset + offset] = D_TYPE(elem);
 36    return elem;
 37}
 38
// Scratch used by the final cross-thread reductions in main(), indexed by
// thread id: tmpsh for the scalar row max / row sum, tmpshv4 for the vec4
// output accumulators.
shared FLOAT_TYPE tmpsh[WorkGroupSize];
shared vec4 tmpshv4[WorkGroupSize];

// masksh: the current Br x Bc mask tile, stored transposed as [column][row].
// Qf: this workgroup's Br query rows as vec4s, pre-multiplied by p.scale.
shared float masksh[Bc][Br];
shared vec4 Qf[Br][HSK / 4];
 44
// Flash-attention forward pass, scalar variant.
//
// A workgroup processes a tile of Br query rows against the KV blocks
// j in [start_j, end_j), each Bc columns wide. Threads form cols_per_iter
// column groups of D_split lanes. d_tid selects the lane, which owns the
// strided head-dimension slice (vec4 index d * D_split + d_tid). col_tid
// selects the K/V column within each cols_per_iter-wide step of a block.
// The softmax is accumulated online: a running row max (Mf) and a running
// row sum (Lf) are kept, and the output accumulator (Of) is rescaled
// whenever the max grows.
//
// NOTE(review): i, N, KV, start_j/end_j, gqa_iq1, iq2/iq3, ik*/iv*, the
// q/k/v/m strides, split_k_index and the *_ENABLE / *_bounds_check flags are
// not declared in this file; presumably they come from flash_attn_base.glsl
// (init_indices()) - confirm there.
void main() {
#ifdef NEEDS_INIT_IQ_SHMEM
    init_iq_shmem(gl_WorkGroupSize);
#endif

    init_indices();

    // d_tid: lane within a D_split group (which head-dimension slice).
    // col_tid: which column group this thread belongs to.
    const uint32_t tid = gl_LocalInvocationIndex;
    const uint32_t d_tid = gl_LocalInvocationIndex % D_split;
    const uint32_t col_tid = gl_LocalInvocationIndex / D_split;

    // Offset into Q in float elements: the nb02/nb03 byte strides are divided
    // by 4. NOTE(review): gqa_iq1*p.nb01 is added without the /4, so p.nb01 is
    // presumably already an element stride - confirm against the host code.
    uint32_t q_offset = gqa_iq1*p.nb01 + (iq2*p.nb02 + iq3*p.nb03) / 4;

    // Cooperatively stage this tile's Br rows of Q in shared memory as vec4s,
    // pre-multiplied by p.scale. Rows with i * Br + r >= N are not written.
    [[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) {
        uint32_t d = (idx + tid) % (HSK / 4);
        uint32_t r = (idx + tid) / (HSK / 4);
        if (r < Br && d < HSK / 4 &&
            i * Br + r < N) {
            Qf[r][d] = vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d]) * p.scale;
        }
    }
    barrier();

    // Per-thread output accumulator: for each of the Br rows, this lane's
    // HSV_per_thread / 4 vec4 slices of the output row.
    vec4 Of[Br][HSV_per_thread / 4];
    [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
            Of[r][d] = vec4(0.0);
        }
    }

    // Running softmax statistics per row: Lf = sum of exponentials, Mf = max.
    float Lf[Br], Mf[Br];

    // Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. when computing Mold-M.
    const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF);

    [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
        Lf[r] = 0;
        Mf[r] = NEG_FLT_MAX_OVER_2;
    }

    // Per-row multiplier applied to the mask; stays 1.0 unless ALiBi is on.
    float slope[Br];
    [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
        slope[r] = 1.0;
    }

    // ALiBi
    if (p.max_bias > 0.0f) {
        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
            slope[r] = perElemOpComputeSlope(r, col_tid, ACC_TYPE(0), iq2);
        }
    }

    // mask_opt stores a 2-bit code per Bc-wide KV block, 16 blocks per uint32
    // word; mo_stride is the number of words per row tile.
    const uint32_t mo_stride = CEIL_DIV(KV, 16 * Bc);
    // mo_offset will point to the tile starting at row i*Br and col 0
    uint32_t mo_offset = mo_stride * i;

    // Base offsets of this head's K and V: in quantized blocks when
    // BLOCK_SIZE > 1, otherwise in float16 elements (byte strides / 2).
#if BLOCK_SIZE > 1
    uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE;
    uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE;
#else
    uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
    uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
#endif
    // Mask offset for this query row. The mask and its mask_opt summary are
    // broadcast across heads/batches via the "% nem2" / "% nem3" indexing.
    uint32_t m_offset = gqa_iq1*KV;
    if (p.nem2 != 1 || p.nem3 != 1) {
        m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
        mo_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * CEIL_DIV(p.nem1, Br) * mo_stride;
    }

    // Cached mask_opt word and the word index it was read from (~0 = none yet).
    // Without USE_MASK_OPT no word is ever loaded and mask_opt stays 0.
    uint32_t mask_opt = 0;
    uint32_t mask_opt_idx = ~0;

    // Main loop over this workgroup's KV blocks.
    [[dont_unroll]]
    for (uint32_t j = start_j; j < end_j; ++j) {

        // Load a new mask_opt word only when entering the next group of 16 blocks.
        if (USE_MASK_OPT && mask_opt_idx != j / 16) {
            mask_opt_idx = j / 16;
            mask_opt = data_mask_opt[mo_offset + mask_opt_idx];
        }
        uint32_t mask_opt_bits = (mask_opt >> ((j % 16) * 2)) & 0x3;
        // A block whose mask is entirely -inf contributes exp(-inf) = 0
        // everywhere. This condition depends only on j and mo_offset, so it is
        // presumably uniform across the workgroup, which keeps the barrier()
        // calls below in uniform control flow.
        if (mask_opt_bits == MASK_OPT_ALL_NEG_INF) {
            // skip this block
            continue;
        }
        // Only load if the block is not all zeros
        if (MASK_ENABLE && mask_opt_bits != MASK_OPT_ALL_ZERO) {
            // Row bounds check against nem1 is only needed without GQA and
            // when nem1 is not a multiple of Br.
            bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;

            // Stage the Br x Bc mask tile into masksh ([col][row]);
            // out-of-range entries are stored as 0.
            [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
                uint32_t c = (idx + tid) % Bc;
                uint32_t r = (idx + tid) / Bc;
                if (idx + tid < Bc * Br) {
                    if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
                        float m = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
                        masksh[c][r] = m;
                    } else {
                        masksh[c][r] = float(0);
                    }
                }
            }
            barrier();
        }

        // S = Q K^T for the columns this thread owns (c * cols_per_iter + col_tid).
        float Sf[Br][cols_per_thread];
        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
            [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
                Sf[r][c] = 0.0;
            }
        }


        // Each lane accumulates a partial dot product over its slice of HSK.
        // Columns past KV keep S = 0 and are excluded further below.
        [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
            if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
                continue;
            }
            [[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) {
#if BLOCK_SIZE > 1
                uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
                uint ib = coord / BLOCK_SIZE;
                uint iqs = (coord % BLOCK_SIZE);
                vec4 K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
#else
                vec4 K_Tf = vec4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]);
#endif
                [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
                    Sf[r][c] += dot(Qf[r][d * D_split + d_tid], K_Tf);
                }
            }
        }

        // XOR-butterfly sum of the partial dot products across the D_split
        // lanes, so every lane of a group ends with the full S values.
        // NOTE(review): assumes each D_split group lies within one subgroup.
        [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
            // Compute sum across the D_split
            [[unroll]] for (uint s = D_split / 2; s > 0; s >>= 1) {
                [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
                    Sf[r][c] += subgroupShuffleXor(Sf[r][c], s);
                }
            }
        }

        // Optional logit soft-capping: S = logit_softcap * tanh(S).
        if (LOGIT_SOFTCAP) {
            [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
                [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
                    Sf[r][c] = p.logit_softcap * tanh(Sf[r][c]);
                }
            }
        }

        // Add the slope-scaled mask. The barrier stops the next block's mask
        // load from overwriting masksh while other threads may still read it.
        if (MASK_ENABLE && mask_opt_bits != MASK_OPT_ALL_ZERO) {
            [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
                [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
                    float mvf = masksh[c * cols_per_iter + col_tid][r];

                    Sf[r][c] += slope[r]*mvf;
                }
            }
            barrier();
        }

        // Online softmax update. Columns past KV are excluded from the row max
        // and the row sum; their Pf is computed but never used.
        float rowmaxf[Br], Pf[Br][cols_per_thread], rowsumf[Br], eMf[Br], Moldf[Br];
        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
            rowmaxf[r] = NEG_FLT_MAX_OVER_2;
            [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
                if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
                    continue;
                }
                rowmaxf[r] = max(rowmaxf[r], Sf[r][c]);
            }
            Moldf[r] = Mf[r];

            // M = max(rowmax, Mold)
            // P = e^(S - M)
            // eM = e^(Mold - M)
            Mf[r] = max(rowmaxf[r], Moldf[r]);
            [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
                Pf[r][c] = exp(Sf[r][c] - Mf[r]);
            }
            eMf[r] = exp(Moldf[r] - Mf[r]);

            // Compute sum across row of P
            rowsumf[r] = 0.0;
            [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
                if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
                    continue;
                }
                rowsumf[r] += Pf[r][c];
            }

            Lf[r] = eMf[r]*Lf[r] + rowsumf[r];
        }

        // Rescale the existing accumulator to the new max before adding this block.
        [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
            [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
                Of[r][d] = eMf[r] * Of[r][d];
            }
        }

        // O += P V over this thread's columns and its head-dimension slice of V.
        [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
            if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
                continue;
            }
            [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
#if BLOCK_SIZE > 1
                uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
                uint ib = coord / BLOCK_SIZE;
                uint iqs = (coord % BLOCK_SIZE);
                vec4 Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
#else
                vec4 Vf = vec4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]);
#endif
                [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
                    Of[r][d] += Pf[r][c] * Vf;
                }
            }
        }

        // NOTE(review): masksh is the only shared memory written in this loop,
        // and the barrier after the mask add already orders its reads before
        // the next write. This barrier looks redundant; confirm before removing.
        barrier();
    }

    // reduce across threads
    //
    // Each column group holds its own M, L and O partials for every row.
    // Combine them with shared-memory tree reductions over strides
    // s >= D_split. Assuming power-of-two sizes, each stride is a multiple of
    // D_split, so tmpsh[d_tid] ends up holding the result for lane d_tid.

    [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
        float rowmaxf, eMf;

        tmpsh[tid] = Mf[r];
        // Compute max across the row
        barrier();
        [[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) {
            if (tid < s) {
                tmpsh[tid] = max(tmpsh[tid], tmpsh[tid + s]);
            }
            barrier();
        }
        rowmaxf = tmpsh[d_tid];
        barrier();

        float Moldf = Mf[r];

        // M = max(rowmax, Mold)
        // eM = e^(Mold - M)
        Mf[r] = max(rowmaxf, Moldf);
        eMf = exp(Moldf - Mf[r]);

        // Rescale this thread's partial L to the workgroup-wide max, then sum.
        Lf[r] = eMf*Lf[r];

        tmpsh[tid] = Lf[r];

        // Compute sum across the row
        barrier();
        [[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) {
            if (tid < s) {
                tmpsh[tid] = tmpsh[tid] + tmpsh[tid + s];
            }
            barrier();
        }
        Lf[r] = tmpsh[d_tid];
        barrier();

        // Rescale and sum the output accumulator the same way, one vec4 at a time.
        [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {

            Of[r][d] = eMf * Of[r][d];
            tmpshv4[tid] = Of[r][d];

            barrier();
            [[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) {
                if (tid < s) {
                    Of[r][d] += tmpshv4[tid + s];
                    tmpshv4[tid] = Of[r][d];
                }
                barrier();
            }
            Of[r][d] = tmpshv4[d_tid];
            barrier();
        }
    }


    // If there is split_k, then the split_k resolve shader does the final
    // division by L. Store the intermediate O value and per-row m and L values.
    if (p.k_num > 1) {
        // note: O and Q have swapped coord 1,2.
        uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));

        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
            if (r < N) {
                [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
                    [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
                        perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
                    }
                }
            }
        }

        // The L and M values are placed after all
        // HSV * ne1 * k_num * ne2 * ne3 unnormalized O elements.
        // For each split, L and M are stored ne1 elements apart.
        o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
            if (r < N) {
                perElemOpStoreCol0(r, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
                perElemOpStoreCol0(r, 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N);
            }
        }

        return;
    }

    // Attention sink: fold one extra logit per row (from perElemOpGetSink)
    // into the softmax denominator. It contributes no value vector. If the
    // sink exceeds the running max, O and L are rescaled to the sink instead.
    if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) {
        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
            float sink = perElemOpGetSink(r, 0u, ACC_TYPE(0), iq2);

            float ms = 1.0f;
            float vs = 1.0f;

            if (sink > Mf[r]) {
                ms = exp(Mf[r] - sink);

                [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
                    Of[r][d] *= ms;
                }
            } else {
                vs = exp(sink - Mf[r]);
            }

            Lf[r] = Lf[r]*ms + vs;
        }
    }

    // Final normalization O /= L. A row with L == 0 is zeroed instead of
    // dividing by zero.
    float Lfrcp[Br];
    [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
        Lfrcp[r] = (Lf[r] == 0.0) ? 0.0 : (1.0 / Lf[r]);
    }

    // When ACC_TYPE_MAX is defined, results are clamped to +/-ACC_TYPE_MAX.
    [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
            Of[r][d] *= Lfrcp[r];
#if defined(ACC_TYPE_MAX)
            Of[r][d] = clamp(Of[r][d], -vec4(ACC_TYPE_MAX), vec4(ACC_TYPE_MAX));
#endif
        }
    }

    uint32_t o_offset = gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV;

    // With GQA the tile's rows are consecutive heads (iq2 + r), handled by
    // perElemOpGqaStore. Otherwise the rows are query positions i * Br + r,
    // ne1 * HSV elements apart, and rows at or past N are skipped.
    if (p.gqa_ratio > 1) {
        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
            if (r < N) {
                [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
                    [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
                        perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
                    }
                }
            }
        }
    } else {
        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
            if (i * Br + r < N) {
                [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
                    [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
                        data_o[o_offset + iq2 * HSV + (i * Br + r) * p.ne1 * HSV + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
                    }
                }
            }
        }
    }
}