#version 450

#extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_shader_16bit_storage : require

#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require

#extension GL_KHR_shader_subgroup_shuffle : enable
#extension GL_KHR_shader_subgroup_vote : enable

#include "types.glsl"
#include "flash_attn_base.glsl"

const uint32_t HSK_per_thread = HSK / D_split;
const uint32_t HSV_per_thread = HSV / D_split;

const uint32_t cols_per_iter = WorkGroupSize / D_split;
const uint32_t cols_per_thread = Bc / cols_per_iter;

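// Each buffer is bound twice at the same binding so it can be read either as
// scalars or as vec4s; the vec4 aliases are used below for vectorized loads.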
layout (binding = 0) readonly buffer Q {float data_q[];};
layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];};
layout (binding = 1) readonly buffer K {float16_t data_k[];};
layout (binding = 1) readonly buffer KV4 {f16vec4 data_kv4[];};
layout (binding = 2) readonly buffer V {float16_t data_v[];};
layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];};
layout (binding = 3) readonly buffer M {float16_t data_m[];};

// Store the output when doing grouped query attention.
// Rows index by Q's dimension 2, and the first N rows are valid.
D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
{
    uint32_t offset = (iq2 + r) * HSV + c;
    data_o[o_offset + offset] = D_TYPE(elem);
    return elem;
}

shared FLOAT_TYPE tmpsh[WorkGroupSize];
shared vec4 tmpshv4[WorkGroupSize];

shared float masksh[Bc][Br];
shared vec4 Qf[Br][HSK / 4];
void main() {
#ifdef NEEDS_INIT_IQ_SHMEM
    init_iq_shmem(gl_WorkGroupSize);
#endif

    init_indices();

    const uint32_t tid = gl_LocalInvocationIndex;
    const uint32_t d_tid = gl_LocalInvocationIndex % D_split;
    const uint32_t col_tid = gl_LocalInvocationIndex / D_split;

    uint32_t q_offset = gqa_iq1*p.nb01 + (iq2*p.nb02 + iq3*p.nb03) / 4;

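    // Cooperatively load the Br x HSK tile of Q into shared memory as vec4s,
    // pre-multiplied by p.scale so the scale is applied once rather than per
    // dot product.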
    [[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) {
        uint32_t d = (idx + tid) % (HSK / 4);
        uint32_t r = (idx + tid) / (HSK / 4);
        if (r < Br && d < HSK / 4 &&
            i * Br + r < N) {
            Qf[r][d] = vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d]) * p.scale;
        }
    }
    barrier();

    vec4 Of[Br][HSV_per_thread / 4];
    [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
            Of[r][d] = vec4(0.0);
        }
    }

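    // Per-row online-softmax state: Lf is the running sum of exp(S - M) and
    // Mf is the running row max. Each KV tile updates them with the usual
    // recurrence:
    //   M_new = max(M_old, rowmax(S))
    //   L_new = exp(M_old - M_new) * L_old + rowsum(exp(S - M_new))
    //   O_new = exp(M_old - M_new) * O_old + exp(S - M_new) * V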
    float Lf[Br], Mf[Br];

    // Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. when computing Mold-M.
    const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF);

    [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
        Lf[r] = 0;
        Mf[r] = NEG_FLT_MAX_OVER_2;
    }

    float slope[Br];
    [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
        slope[r] = 1.0;
    }

    // ALiBi
    if (p.max_bias > 0.0f) {
        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
            slope[r] = perElemOpComputeSlope(r, col_tid, ACC_TYPE(0), iq2);
        }
    }

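    // Packed mask metadata: judging by the indexing below, each uint32 of
    // data_mask_opt holds 2-bit classifications for 16 consecutive Bc-wide
    // column blocks, letting a block be skipped entirely (all -inf) or its
    // mask load elided (all zero).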
    const uint32_t mo_stride = CEIL_DIV(KV, 16 * Bc);
    // mo_offset will point to the tile starting at row i*Br and col 0
    uint32_t mo_offset = mo_stride * i;

#if BLOCK_SIZE > 1
    uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE;
    uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE;
#else
    uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
    uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
#endif
    uint32_t m_offset = gqa_iq1*KV;
    if (p.nem2 != 1 || p.nem3 != 1) {
        m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
        mo_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * CEIL_DIV(p.nem1, Br) * mo_stride;
    }

    uint32_t mask_opt = 0;
    uint32_t mask_opt_idx = ~0;

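    // Main loop: stream over the KV sequence one Bc-wide block of columns at
    // a time, accumulating O, L and M for this workgroup's Br rows of Q.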
    [[dont_unroll]]
    for (uint32_t j = start_j; j < end_j; ++j) {

        if (USE_MASK_OPT && mask_opt_idx != j / 16) {
            mask_opt_idx = j / 16;
            mask_opt = data_mask_opt[mo_offset + mask_opt_idx];
        }
        uint32_t mask_opt_bits = (mask_opt >> ((j % 16) * 2)) & 0x3;
        if (mask_opt_bits == MASK_OPT_ALL_NEG_INF) {
            // skip this block
            continue;
        }
        // Only load the mask tile if the block is not all zeros
        if (MASK_ENABLE && mask_opt_bits != MASK_OPT_ALL_ZERO) {
            bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;

            [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
                uint32_t c = (idx + tid) % Bc;
                uint32_t r = (idx + tid) / Bc;
                if (idx + tid < Bc * Br) {
                    if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
                        float m = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
                        masksh[c][r] = m;
                    } else {
                        masksh[c][r] = 0.0f;
                    }
                }
            }
            barrier();
        }

        float Sf[Br][cols_per_thread];
        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
            [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
                Sf[r][c] = 0.0;
            }
        }

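        // S = Q * K^T for this tile. Each thread owns cols_per_thread columns
        // and a D_split-wide slice of the head dimension, so each Sf entry is
        // a partial sum that is completed across lanes below.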
        [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
            if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
                continue;
            }
            [[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) {
#if BLOCK_SIZE > 1
                uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
                uint ib = coord / BLOCK_SIZE;
                uint iqs = (coord % BLOCK_SIZE);
                vec4 K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
#else
                vec4 K_Tf = vec4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]);
#endif
                [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
                    Sf[r][c] += dot(Qf[r][d * D_split + d_tid], K_Tf);
                }
            }
        }

        [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
            // Complete the dot products by summing the partial sums across the D_split lanes
            [[unroll]] for (uint s = D_split / 2; s > 0; s >>= 1) {
                [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
                    Sf[r][c] += subgroupShuffleXor(Sf[r][c], s);
                }
            }
        }

        if (LOGIT_SOFTCAP) {
            [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
                [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
                    Sf[r][c] = p.logit_softcap * tanh(Sf[r][c]);
                }
            }
        }

        if (MASK_ENABLE && mask_opt_bits != MASK_OPT_ALL_ZERO) {
            [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
                [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
                    float mvf = masksh[c * cols_per_iter + col_tid][r];

                    Sf[r][c] += slope[r]*mvf;
                }
            }
            barrier();
        }

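        // Online-softmax update for this tile: refresh the running max,
        // compute P = exp(S - M) against the new max, and fold the correction
        // factor eM = exp(Mold - M) into the running denominator L.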
        float rowmaxf[Br], Pf[Br][cols_per_thread], rowsumf[Br], eMf[Br], Moldf[Br];
        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
            rowmaxf[r] = NEG_FLT_MAX_OVER_2;
            [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
                if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
                    continue;
                }
                rowmaxf[r] = max(rowmaxf[r], Sf[r][c]);
            }
            Moldf[r] = Mf[r];

            // M = max(rowmax, Mold)
            // P = e^(S - M)
            // eM = e^(Mold - M)
            Mf[r] = max(rowmaxf[r], Moldf[r]);
            [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
                Pf[r][c] = exp(Sf[r][c] - Mf[r]);
            }
            eMf[r] = exp(Moldf[r] - Mf[r]);

            // Compute sum across row of P
            rowsumf[r] = 0.0;
            [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
                if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
                    continue;
                }
                rowsumf[r] += Pf[r][c];
            }

            Lf[r] = eMf[r]*Lf[r] + rowsumf[r];
        }

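        // Apply the same eM correction to the output accumulated so far, so
        // it stays consistent with the updated max.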
        [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
            [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
                Of[r][d] = eMf[r] * Of[r][d];
            }
        }

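        // O += P * V for this tile, loading V either through the dequantizer
        // (quantized K/V) or directly as f16 vec4s.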
        [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
            if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
                continue;
            }
            [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
#if BLOCK_SIZE > 1
                uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
                uint ib = coord / BLOCK_SIZE;
                uint iqs = (coord % BLOCK_SIZE);
                vec4 Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
#else
                vec4 Vf = vec4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]);
#endif
                [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
                    Of[r][d] += Pf[r][c] * Vf;
                }
            }
        }

        barrier();
    }

    // Reduce across threads:
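    // each thread holds partial (M, L, O) results for the columns it
    // processed. The shared-memory tree reductions below stop at stride
    // D_split so that different head-dimension slices (d_tid) are never mixed
    // together; every thread then reads its slice's combined value back from
    // index d_tid.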
    [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
        float rowmaxf, eMf;

        tmpsh[tid] = Mf[r];
        // Compute max across the row
        barrier();
        [[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) {
            if (tid < s) {
                tmpsh[tid] = max(tmpsh[tid], tmpsh[tid + s]);
            }
            barrier();
        }
        rowmaxf = tmpsh[d_tid];
        barrier();

        float Moldf = Mf[r];

        // M = max(rowmax, Mold)
        // eM = e^(Mold - M)
        Mf[r] = max(rowmaxf, Moldf);
        eMf = exp(Moldf - Mf[r]);

        Lf[r] = eMf*Lf[r];

        tmpsh[tid] = Lf[r];

        // Compute sum across the row
        barrier();
        [[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) {
            if (tid < s) {
                tmpsh[tid] = tmpsh[tid] + tmpsh[tid + s];
            }
            barrier();
        }
        Lf[r] = tmpsh[d_tid];
        barrier();

        [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
            Of[r][d] = eMf * Of[r][d];
            tmpshv4[tid] = Of[r][d];

            barrier();
            [[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) {
                if (tid < s) {
                    Of[r][d] += tmpshv4[tid + s];
                    tmpshv4[tid] = Of[r][d];
                }
                barrier();
            }
            Of[r][d] = tmpshv4[d_tid];
            barrier();
        }
    }

    // If there is split_k, then the split_k resolve shader does the final
    // division by L. Store the intermediate O value and per-row M and L values.
    if (p.k_num > 1) {
        // note: O and Q have swapped coord 1,2.
        uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));

        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
            if (r < N) {
                [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
                    [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
                        perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
                    }
                }
            }
        }

        o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
            if (r < N) {
                perElemOpStoreCol0(r, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
                perElemOpStoreCol0(r, 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N);
            }
        }

        return;
    }

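    // Attention sinks: fold a per-head sink logit into the softmax
    // denominator as one extra attention score. If the sink exceeds the
    // running max, O and L are rescaled to the sink's scale; otherwise the
    // sink just contributes exp(sink - M) to the denominator.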
    if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) {
        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
            float sink = perElemOpGetSink(r, 0u, ACC_TYPE(0), iq2);

            float ms = 1.0f;
            float vs = 1.0f;

            if (sink > Mf[r]) {
                ms = exp(Mf[r] - sink);

                [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
                    Of[r][d] *= ms;
                }
            } else {
                vs = exp(sink - Mf[r]);
            }

            Lf[r] = Lf[r]*ms + vs;
        }
    }

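    // Final normalization: divide the accumulated output by the softmax
    // denominator L, guarding against fully-masked rows where L is zero.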
    float Lfrcp[Br];
    [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
        Lfrcp[r] = (Lf[r] == 0.0) ? 0.0 : (1.0 / Lf[r]);
    }

    [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
            Of[r][d] *= Lfrcp[r];
#if defined(ACC_TYPE_MAX)
            Of[r][d] = clamp(Of[r][d], -vec4(ACC_TYPE_MAX), vec4(ACC_TYPE_MAX));
#endif
        }
    }

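    // Store the final output. The grouped-query path writes via
    // perElemOpGqaStore (rows indexed by Q's dimension 2); the regular path
    // bounds-checks rows against N and writes each component directly.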
    uint32_t o_offset = gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV;

    if (p.gqa_ratio > 1) {
        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
            if (r < N) {
                [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
                    [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
                        perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
                    }
                }
            }
        }
    } else {
        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
            if (i * Br + r < N) {
                [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
                    [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
                        data_o[o_offset + iq2 * HSV + (i * Br + r) * p.ne1 * HSV + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
                    }
                }
            }
        }
    }
}