1#version 450
2
3#extension GL_EXT_control_flow_attributes : enable
4#extension GL_EXT_shader_16bit_storage : require
5
6#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
7#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
8#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
9#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
10
11#extension GL_KHR_memory_scope_semantics : enable
12#extension GL_KHR_cooperative_matrix : enable
13#extension GL_NV_cooperative_matrix2 : enable
14#extension GL_EXT_buffer_reference : enable
15#extension GL_KHR_shader_subgroup_ballot : enable
16#extension GL_KHR_shader_subgroup_vote : enable
17#extension GL_EXT_null_initializer : enable
18
19#include "types.glsl"
20#include "dequant_funcs_cm2.glsl"
21#include "flash_attn_base.glsl"
22
// Raw byte views of the four input tensors; elements are reinterpreted by the
// tensor loads below (Q via Q_TYPE, K/V via float16/DEQUANTFUNC depending on
// BLOCK_SIZE, M as float16_t mask values).
layout (binding = 0) readonly buffer Q {uint8_t data_q[];};
layout (binding = 1) readonly buffer K {uint8_t data_k[];};
layout (binding = 2) readonly buffer V {uint8_t data_v[];};
layout (binding = 3) readonly buffer M {uint8_t data_m[];};
27
// Reduction callback for coopMatReduceNV: keep the larger of the two operands.
// (Uses max() rather than a comparison so NaN handling matches GLSL max.)
ACC_TYPE maxReduce(const in ACC_TYPE a, const in ACC_TYPE b) {
    const ACC_TYPE larger = max(a, b);
    return larger;
}
31
// float16 variant of the max-reduction callback for coopMatReduceNV.
float16_t maxReduceFp16(const in float16_t a, const in float16_t b) {
    const float16_t larger = max(a, b);
    return larger;
}
35
// Reduction callback that propagates the first operand and ignores the second;
// used with coopMatReduceNV to broadcast ("smear") one value across a row.
ACC_TYPE smearReduce(const in ACC_TYPE kept, const in ACC_TYPE ignored) {
    return kept;
}
39
// Per-element callback: elements at row >= numRows or col >= numCols are
// padding and are replaced by 'replace'; in-bounds elements pass through.
ACC_TYPE replacePadding(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem, const in ACC_TYPE replace, const in uint32_t numRows, const in uint32_t numCols) {
    const bool in_bounds = (row < numRows) && (col < numCols);
    return in_bounds ? elem : replace;
}
47
// Per-element callback computing e^elem. The row/col parameters are unused but
// required by the coopMatPerElementNV callback signature.
ACC_TYPE Exp(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem)
{
    const ACC_TYPE e = exp(elem);
    return e;
}
52
// Per-element callback: componentwise max of two matrices. row/col are unused
// but required by the coopMatPerElementNV callback signature.
// (Uses max() rather than a comparison so NaN handling matches GLSL max.)
ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem0, const in ACC_TYPE elem1)
{
    const ACC_TYPE larger = max(elem0, elem1);
    return larger;
}
57
// When K/V are quantized (BLOCK_SIZE > 1) the coopMatLoadTensorNV calls need
// the dequantization callback appended to their argument list; otherwise the
// macro expands to nothing.
#if BLOCK_SIZE > 1
#define DECODEFUNC , DEQUANTFUNC
#else
#define DECODEFUNC
#endif
63
// Per-element store callback used for grouped-query attention output.
// Tile rows index Q's dimension 2 starting at iq2; only the first N rows and
// the first HSV columns are written, everything else is discarded.
D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
{
    const bool in_bounds = (r < N) && (c < HSV);
    if (in_bounds) {
        const uint32_t dst = o_offset + (iq2 + r) * HSV + c;
        data_o[dst] = D_TYPE(elem);
    }
    return elem;
}
74
// Flash-attention kernel (NV_cooperative_matrix2 tensor path): each workgroup
// produces one Br x HSV output tile, iterating over Bc-wide K/V blocks with an
// online-softmax running row-max (M) / row-sum (L) update.
void main() {
#ifdef NEEDS_INIT_IQ_SHMEM
    init_iq_shmem(gl_WorkGroupSize);
#endif

    // Sets up i, iq2/iq3, ik2/ik3, iv2/iv3, start_j/end_j, N, KV, the strides,
    // split_k_index, etc. (defined in flash_attn_base.glsl).
    init_indices();

    // Q always clamps out-of-bounds loads to a constant; K/V use the shader's
    // compile-time Clamp mode.
    tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutQ = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
    tensorLayoutNV<2, Clamp> tensorLayoutK = createTensorLayoutNV(2, Clamp);
    tensorLayoutNV<2, Clamp> tensorLayoutV = createTensorLayoutNV(2, Clamp);

    // View that swaps the two dimensions; used below to load K transposed.
    tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0);

#if BLOCK_SIZE > 1
    // Quantized K/V: declare the quant-block size along dim 1 so the
    // DEQUANTFUNC callback decodes whole blocks.
    tensorLayoutK = setTensorLayoutBlockSizeNV(tensorLayoutK, 1, BLOCK_SIZE);
    tensorLayoutV = setTensorLayoutBlockSizeNV(tensorLayoutV, 1, BLOCK_SIZE);
#endif

    tensorLayoutQ = setTensorLayoutDimensionNV(tensorLayoutQ, N, HSK);
    tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, HSK);
    tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, HSV);

    // hint to the compiler that strides are aligned for the aligned variant of the shader
    if (Clamp != gl_CooperativeMatrixClampModeConstantNV)
    {
        q_stride &= ~7;
#if BLOCK_SIZE == 1
        k_stride &= ~7;
        v_stride &= ~7;
#endif
        m_stride &= ~7;
    }
    tensorLayoutQ = setTensorLayoutStrideNV(tensorLayoutQ, q_stride, 1);
    tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1);
    tensorLayoutV = setTensorLayoutStrideNV(tensorLayoutV, v_stride, 1);

    coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseAccumulator> Q;
    coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseA> Qf16;

    uint32_t q_offset = gqa_iq1*p.nb01*4/*sizeof(float)*/ + iq2*p.nb02+iq3*p.nb03;
    coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, HSK_pad));

    // Fold the softmax scale into Q once, rather than scaling S every block.
    Qf16 = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseA>(Q);
    Qf16 *= float16_t(p.scale);

    // O accumulates the un-normalized output. L and M are the running softmax
    // row-sum and row-max; they are kept as Br x Bc accumulators with each
    // row's scalar replicated across the columns.
    coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(0);

    coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> L, M;

    // Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. when computing Mold-M.
    const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF);

    L = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
#if defined(ACC_TYPE_MAX)
    M = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(-ACC_TYPE_MAX / ACC_TYPE(2));
#else
    M = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(NEG_FLT_MAX_OVER_2);
#endif

    // Per-row ALiBi slope factors; stays all-ones when ALiBi is disabled.
    coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> slopeMat = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(1.0);

    // ALiBi
    if (p.max_bias > 0.0f) {
        coopMatPerElementNV(slopeMat, slopeMat, perElemOpComputeSlope, iq2);
    }

    // Mask-optimization metadata: 2 flag bits per Bc-wide block, 16 blocks
    // packed per uint32, used to skip all -inf blocks and avoid loading
    // all-zero mask blocks.
    const uint32_t mo_stride = CEIL_DIV(KV, 16 * Bc);
    // mo_offset will point to the tile starting at row i*Br and col 0
    uint32_t mo_offset = mo_stride * i;

    uint32_t m_offset = gqa_iq1*KV * 2 /*sizeof(float16_t)*/;
    if (p.nem2 != 1 || p.nem3 != 1) {
        m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV * 2 /*sizeof(float16_t)*/;
        mo_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * CEIL_DIV(p.nem1, Br) * mo_stride;
    }

    uint32_t mask_opt = 0;
    uint32_t mask_opt_idx = ~0;

    // Main loop over Bc-wide blocks of K/V.
    [[dont_unroll]]
    for (uint32_t j = start_j; j < end_j; ++j) {

        coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv = coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
        if (MASK_ENABLE) {

            // Refresh the packed per-block flags once every 16 blocks.
            if (USE_MASK_OPT && mask_opt_idx != j / 16) {
                mask_opt_idx = j / 16;
                mask_opt = data_mask_opt[mo_offset + mask_opt_idx];
            }
            uint32_t mask_opt_bits = (mask_opt >> ((j % 16) * 2)) & 0x3;
            if (mask_opt_bits == MASK_OPT_ALL_NEG_INF) {
                // skip this block
                continue;
            }
            // Only load if the block is not all zeros
            if (mask_opt_bits != MASK_OPT_ALL_ZERO) {
                // nem1 clamping only matters for the non-GQA path and when Br
                // doesn't evenly divide nem1.
                bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;

                if (nem1_bounds_check) {
                    tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
                    tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
                    tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
                    tensorLayoutM = setTensorLayoutClampValueNV(tensorLayoutM, 0xfc00); // -inf in float16_t

                    coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
                } else {
                    tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
                    // Don't clamp against nem1 when GQA is enabled
                    uint32_t m_height = p.gqa_ratio > 1 ? ~0 : p.nem1;
                    tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, m_height, KV);
                    tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);

                    coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
                }
            }
        }

        // S = (scale*Q) * K^T for this block.
        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> S = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);

        coopmat<float16_t, gl_ScopeWorkgroup, HSK_pad, Bc, gl_MatrixUseB> K_T;

        uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13;
        coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose DECODEFUNC);
        S = coopMatMulAdd(Qf16, K_T, S);

        // Optional logit soft-capping: S = softcap * tanh(S).
        if (LOGIT_SOFTCAP) {
            [[unroll]]
            for (int k = 0; k < S.length(); ++k) {
                S[k] = ACC_TYPE(p.logit_softcap)*tanh(S[k]);
            }
        }

        // Additive mask, scaled per-row by the ALiBi slope.
        if (MASK_ENABLE) {
            S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
        }

        // Clear padding elements to -inf, so they don't contribute to rowmax
        if (Clamp != 0 &&
            ((j + 1) * Bc > KV ||
             (i + 1) * Br > N)) {

            uint R = ((i + 1) * Br > N) ? (N % Br) : Br;
            uint C = ((j + 1) * Bc > KV) ? (KV % Bc) : Bc;

            coopMatPerElementNV(S, S, replacePadding, ACC_TYPE(NEG_FLT_MAX_OVER_2), R, C);
        }

        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> rowmax, P, rowsum, eM;

        coopMatReduceNV(rowmax, S, gl_CooperativeMatrixReduceRowNV, maxReduce);

        // NOTE(review): FATTN_KQ_MAX_OFFSET is a constant bias added to the
        // row max before the exp() rescale; defined elsewhere — presumably
        // headroom against overflow in exp. Confirm against its definition.
        rowmax += coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(FATTN_KQ_MAX_OFFSET);

        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> Mold = M;

        // M = max(rowmax, Mold)
        // P = e^(S - M)
        // eM = e^(Mold - M)
        coopMatPerElementNV(M, rowmax, Max, Mold);
        coopMatPerElementNV(P, S - M, Exp);
        coopMatPerElementNV(eM, Mold - M, Exp);

        // Clear padding elements to 0, so they don't contribute to rowsum
        if (Clamp != 0 &&
            ((j + 1) * Bc > KV ||
             (i + 1) * Br > N)) {

            uint R = ((i + 1) * Br > N) ? (N % Br) : Br;
            uint C = ((j + 1) * Bc > KV) ? (KV % Bc) : Bc;

            coopMatPerElementNV(P, P, replacePadding, ACC_TYPE(0.0), R, C);
        }

        coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseA> P_A = coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseA>(P);

        // compute rowsum by multiplying by matrix of all ones.
        coopmat<float16_t, gl_ScopeWorkgroup, Bc, Bc, gl_MatrixUseB> One = coopmat<float16_t, gl_ScopeWorkgroup, Bc, Bc, gl_MatrixUseB>(1.0);

        rowsum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0.0);
        rowsum = coopMatMulAdd(P_A, One, rowsum);

        coopmat<float16_t, gl_ScopeWorkgroup, Bc, HSV_pad, gl_MatrixUseB> V;
        uint32_t v_offset = iv2*p.nb22 + iv3*p.nb23;
        coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad) DECODEFUNC);

        // Online-softmax update: rescale the old sum by eM and add this block.
        L = eM*L + rowsum;

        // This is the "diagonal" matrix in the paper, but since we do componentwise
        // multiply rather than matrix multiply it has the diagonal element smeared
        // across the row
        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> eMdiag;

        // resize eM by using smear/reduce
        coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce);

        // O = eM*O + P*V
        O *= coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(eMdiag);
        O = coopMatMulAdd(P_A, V, O);
    }

    // If there is split_k, then the split_k resolve shader does the final
    // division by L. Store the intermediate O value and per-row m and L values.
    if (p.k_num > 1) {
        coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(O);

        // note: O and Q have swapped coord 1,2.
        uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
        coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);

        // L and M go after all O tiles: two ne1-sized rows (sum, then max) per split.
        o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
        coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N);
        coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N);
        return;
    }

    coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> Ldiag;

    // resize L by using smear/reduce
    coopMatReduceNV(Ldiag, L, gl_CooperativeMatrixReduceRowNV, smearReduce);

    // Attention sinks: fold the per-head sink value into the softmax
    // denominator, rescaling O when the sink exceeds the running max.
    if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) {
        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> S;
        coopMatPerElementNV(S, S, perElemOpGetSink, iq2);

        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> Mr;

        // resize M by using smear/reduce
        coopMatReduceNV(Mr, M, gl_CooperativeMatrixReduceRowNV, smearReduce);

        // O, Ldiag, Mr all have the same type so all element locations match
        [[unroll]] for (uint32_t i = 0; i < Ldiag.length(); ++i) {
            ACC_TYPE sink = S[i];

            ACC_TYPE ms = ACC_TYPE(1.0f);
            ACC_TYPE vs = ACC_TYPE(1.0f);

            if (sink > Mr[i]) {
                ms = exp(Mr[i] - sink);

                O[i] *= float16_t(ms);
            } else {
                vs = exp(sink - Mr[i]);
            }

            Ldiag[i] = Ldiag[i]*ms + vs;
        }
    }

    // Invert L for the final normalization; rows with zero sum stay zero so
    // the multiply below cannot produce inf/NaN.
    [[unroll]]
    for (int k = 0; k < Ldiag.length(); ++k) {
        Ldiag[k] = (Ldiag[k] == 0.0) ? ACC_TYPE(0.0) : (ACC_TYPE(1.0) / Ldiag[k]);
    }

    coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(O);

    // O /= L (componentwise; Ldiag holds 1/L smeared across each row).
    O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(Ldiag)*O_D;

#if defined(ACC_TYPE_MAX)
    // Clamp the output to the accumulator type's representable range.
    [[unroll]] for (uint i = 0; i < O_D.length(); ++i) { O_D[i] = clamp(O_D[i], D_TYPE(-ACC_TYPE_MAX), D_TYPE(ACC_TYPE_MAX)); }
#endif

    uint32_t o_offset = gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV;

    if (p.gqa_ratio > 1) {
        // GQA: rows map to consecutive heads starting at iq2; store element-wise.
        coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
    } else {
        tensorLayoutNV<3, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(3, gl_CooperativeMatrixClampModeConstantNV);
        tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, HSV);

        // permute dimensions
        tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2);

        coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, N, 0, HSV_pad), tensorViewPermute);
    }
}