1#version 450
  2
  3#extension GL_EXT_control_flow_attributes : enable
  4#extension GL_EXT_shader_16bit_storage : require
  5
  6#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
  7#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
  8#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
  9#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
 10
 11#extension GL_KHR_memory_scope_semantics : enable
 12#extension GL_KHR_cooperative_matrix : enable
 13#extension GL_NV_cooperative_matrix2 : enable
 14#extension GL_EXT_buffer_reference : enable
 15#extension GL_KHR_shader_subgroup_ballot : enable
 16#extension GL_KHR_shader_subgroup_vote : enable
 17#extension GL_EXT_null_initializer : enable
 18
 19#include "types.glsl"
 20#include "dequant_funcs_cm2.glsl"
 21#include "flash_attn_base.glsl"
 22
 23layout (binding = 0) readonly buffer Q {uint8_t data_q[];};
 24layout (binding = 1) readonly buffer K {uint8_t data_k[];};
 25layout (binding = 2) readonly buffer V {uint8_t data_v[];};
 26layout (binding = 3) readonly buffer M {uint8_t data_m[];};
 27
 28ACC_TYPE maxReduce(const in ACC_TYPE x, const in ACC_TYPE y) {
 29    return max(x, y);
 30}
 31
 32float16_t maxReduceFp16(const in float16_t x, const in float16_t y) {
 33    return max(x, y);
 34}
 35
 36ACC_TYPE smearReduce(const in ACC_TYPE x, const in ACC_TYPE y) {
 37    return x;
 38}
 39
 40// Replace matrix elements >= numRows or numCols with 'replace'
 41ACC_TYPE replacePadding(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem, const in ACC_TYPE replace, const in uint32_t numRows, const in uint32_t numCols) {
 42    if (row >= numRows || col >= numCols) {
 43        return replace;
 44    }
 45    return elem;
 46}
 47
 48ACC_TYPE Exp(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem)
 49{
 50    return exp(elem);
 51}
 52
 53ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem0, const in ACC_TYPE elem1)
 54{
 55    return max(elem0, elem1);
 56}
 57
 58#if BLOCK_SIZE > 1
 59#define DECODEFUNC , DEQUANTFUNC
 60#else
 61#define DECODEFUNC
 62#endif
 63
 64// Store the output when doing grouped query attention.
 65// Rows index by Q's dimension 2, and the first N rows are valid.
 66D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
 67{
 68    if (r < N && c < HSV) {
 69        uint32_t offset = (iq2 + r) * HSV + c;
 70        data_o[o_offset + offset] = D_TYPE(elem);
 71    }
 72    return elem;
 73}
 74
// Flash-attention kernel using NV cooperative matrix 2 tensor ops.
// One workgroup computes a Br-row tile of the output O = softmax(Q*K^T)*V,
// iterating over Bc-wide column tiles of K/V with the online-softmax
// recurrence (running row max M, running row sum L).
void main() {
#ifdef NEEDS_INIT_IQ_SHMEM
    init_iq_shmem(gl_WorkGroupSize);
#endif

    // Sets up the per-workgroup indices used below (i, start_j/end_j, iq2/iq3,
    // ik2/ik3, iv2/iv3, gqa_iq1, split_k_index, q/k/v/m strides, ...).
    // Defined in flash_attn_base.glsl — see that file for the exact set.
    init_indices();

    tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutQ = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
    tensorLayoutNV<2, Clamp> tensorLayoutK = createTensorLayoutNV(2, Clamp);
    tensorLayoutNV<2, Clamp> tensorLayoutV = createTensorLayoutNV(2, Clamp);

    // View that swaps the two dimensions — used to load K transposed so the
    // matmul computes Q*K^T directly.
    tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0);

#if BLOCK_SIZE > 1
    // Quantized K/V: declare the per-row block size so the tensor loads can
    // invoke DEQUANTFUNC per block (see DECODEFUNC above).
    tensorLayoutK = setTensorLayoutBlockSizeNV(tensorLayoutK, 1, BLOCK_SIZE);
    tensorLayoutV = setTensorLayoutBlockSizeNV(tensorLayoutV, 1, BLOCK_SIZE);
#endif

    // Logical dimensions: Q is N x HSK, K is KV x HSK, V is KV x HSV.
    tensorLayoutQ = setTensorLayoutDimensionNV(tensorLayoutQ, N, HSK);
    tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, HSK);
    tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, HSV);

    // hint to the compiler that strides are aligned for the aligned variant of the shader
    if (Clamp != gl_CooperativeMatrixClampModeConstantNV)
    {
        q_stride &= ~7;
#if BLOCK_SIZE == 1
        k_stride &= ~7;
        v_stride &= ~7;
#endif
        m_stride &= ~7;
    }
    tensorLayoutQ = setTensorLayoutStrideNV(tensorLayoutQ, q_stride, 1);
    tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1);
    tensorLayoutV = setTensorLayoutStrideNV(tensorLayoutV, v_stride, 1);

    coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseAccumulator> Q;
    coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseA> Qf16;

    // Byte offset of this workgroup's Q rows (nb01 is in floats, hence *4).
    uint32_t q_offset = gqa_iq1*p.nb01*4/*sizeof(float)*/ + iq2*p.nb02+iq3*p.nb03;
    coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, HSK_pad));

    // Fold the softmax scale into Q once, instead of scaling S every tile.
    Qf16 = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseA>(Q);
    Qf16 *= float16_t(p.scale);

    // O accumulates P*V across tiles, rescaled each iteration by eM.
    coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(0);

    // Online-softmax state: L = running row sum, M = running row max.
    coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> L, M;

    // Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. when computing Mold-M.
    const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF);

    L = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
#if defined(ACC_TYPE_MAX)
    M = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(-ACC_TYPE_MAX / ACC_TYPE(2));
#else
    M = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(NEG_FLT_MAX_OVER_2);
#endif

    coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> slopeMat = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(1.0);

    // ALiBi
    if (p.max_bias > 0.0f) {
        coopMatPerElementNV(slopeMat, slopeMat, perElemOpComputeSlope, iq2);
    }

    // Mask-optimization metadata: each uint packs 16 two-bit codes, one per
    // Bc-wide mask tile (see the j/16 index and (j%16)*2 shift in the loop).
    const uint32_t mo_stride = CEIL_DIV(KV, 16 * Bc);
    // mo_offset will point to the tile starting at row i*Br and col 0
    uint32_t mo_offset = mo_stride * i;

    // Mask byte offset (mask elements are float16_t, hence *2).
    uint32_t m_offset = gqa_iq1*KV * 2 /*sizeof(float16_t)*/;
    if (p.nem2 != 1 || p.nem3 != 1) {
        // Broadcast the mask over heads/batches when it has extra dims.
        m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV * 2 /*sizeof(float16_t)*/;
        mo_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * CEIL_DIV(p.nem1, Br) * mo_stride;
    }

    // Cached mask-opt word and the index it was loaded for (~0 forces the
    // first load).
    uint32_t mask_opt = 0;
    uint32_t mask_opt_idx = ~0;

    [[dont_unroll]]
    for (uint32_t j = start_j; j < end_j; ++j) {

        coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv = coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
        if (MASK_ENABLE) {

            // Refresh the cached mask-opt word when j crosses into a new
            // 16-tile group.
            if (USE_MASK_OPT && mask_opt_idx != j / 16) {
                mask_opt_idx = j / 16;
                mask_opt = data_mask_opt[mo_offset + mask_opt_idx];
            }
            uint32_t mask_opt_bits = (mask_opt >> ((j % 16) * 2)) & 0x3;
            if (mask_opt_bits == MASK_OPT_ALL_NEG_INF) {
                // skip this block
                continue;
            }
            // Only load if the block is not all zeros
            if (mask_opt_bits != MASK_OPT_ALL_ZERO) {
                // Out-of-bounds mask rows need explicit clamping only when GQA
                // is off and N isn't a multiple of Br.
                bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;

                if (nem1_bounds_check) {
                    tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
                    tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
                    tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
                    tensorLayoutM = setTensorLayoutClampValueNV(tensorLayoutM, 0xfc00); // -inf in float16_t

                    coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
                } else {
                    tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
                    // Don't clamp against nem1 when GQA is enabled
                    uint32_t m_height = p.gqa_ratio > 1 ? ~0 : p.nem1;
                    tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, m_height, KV);
                    tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);

                    coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
                }
            }
        }

        // S = (scaled Q) * K^T for this Bc-wide tile of K.
        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> S = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);

        coopmat<float16_t, gl_ScopeWorkgroup, HSK_pad, Bc, gl_MatrixUseB> K_T;

        uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13;
        coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose DECODEFUNC);
        S = coopMatMulAdd(Qf16, K_T, S);

        // Optional logit softcap: S = softcap * tanh(S).
        if (LOGIT_SOFTCAP) {
            [[unroll]]
            for (int k = 0; k < S.length(); ++k) {
                S[k] = ACC_TYPE(p.logit_softcap)*tanh(S[k]);
            }
        }

        // Add the (slope-weighted, for ALiBi) mask tile to the logits.
        if (MASK_ENABLE) {
            S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
        }

        // Clear padding elements to -inf, so they don't contribute to rowmax
        if (Clamp != 0 &&
            ((j + 1) * Bc > KV ||
             (i + 1) * Br > N)) {

            uint R = ((i + 1) * Br >  N) ?  (N % Br) : Br;
            uint C = ((j + 1) * Bc > KV) ? (KV % Bc) : Bc;

            coopMatPerElementNV(S, S, replacePadding, ACC_TYPE(NEG_FLT_MAX_OVER_2), R, C);
        }

        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> rowmax, P, rowsum, eM;

        coopMatReduceNV(rowmax, S, gl_CooperativeMatrixReduceRowNV, maxReduce);

        // NOTE(review): constant bias added to the row max before the softmax
        // recurrence; presumably a numerical-safety margin — see where
        // FATTN_KQ_MAX_OFFSET is defined for its rationale.
        rowmax += coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(FATTN_KQ_MAX_OFFSET);

        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> Mold = M;

        // M = max(rowmax, Mold)
        // P = e^(S - M)
        // eM = e^(Mold - M)
        coopMatPerElementNV(M, rowmax, Max, Mold);
        coopMatPerElementNV(P, S - M, Exp);
        coopMatPerElementNV(eM, Mold - M, Exp);

        // Clear padding elements to 0, so they don't contribute to rowsum
        if (Clamp != 0 &&
            ((j + 1) * Bc > KV ||
             (i + 1) * Br > N)) {

            uint R = ((i + 1) * Br >  N) ?  (N % Br) : Br;
            uint C = ((j + 1) * Bc > KV) ? (KV % Bc) : Bc;

            coopMatPerElementNV(P, P, replacePadding, ACC_TYPE(0.0), R, C);
        }

        coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseA> P_A = coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseA>(P);

        // compute rowsum by multiplying by matrix of all ones.
        coopmat<float16_t, gl_ScopeWorkgroup, Bc, Bc, gl_MatrixUseB> One = coopmat<float16_t, gl_ScopeWorkgroup, Bc, Bc, gl_MatrixUseB>(1.0);

        rowsum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0.0);
        rowsum = coopMatMulAdd(P_A, One, rowsum);

        coopmat<float16_t, gl_ScopeWorkgroup, Bc, HSV_pad, gl_MatrixUseB> V;
        uint32_t v_offset = iv2*p.nb22 + iv3*p.nb23;
        coopMatLoadTensorNV(V,  data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad) DECODEFUNC);

        // Online-softmax sum update: rescale the old sum, add this tile's.
        L = eM*L + rowsum;

        // This is the "diagonal" matrix in the paper, but since we do componentwise
        // multiply rather than matrix multiply it has the diagonal element smeared
        // across the row
        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> eMdiag;

        // resize eM by using smear/reduce
        coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce);

        // Rescale the running output by eM, then accumulate P*V.
        O *= coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(eMdiag);
        O = coopMatMulAdd(P_A, V, O);
    }

    // If there is split_k, then the split_k resolve shader does the final
    // division by L. Store the intermediate O value and per-row m and L values.
    if (p.k_num > 1) {
        coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(O);

        // note: O and Q have swapped coord 1,2.
        uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
        coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);

        // L and M rows are stored after all the O splits (2 values per row).
        o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
        coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N);
        coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N);
        return;
    }

    coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> Ldiag;

    // resize L by using smear/reduce
    coopMatReduceNV(Ldiag, L, gl_CooperativeMatrixReduceRowNV, smearReduce);

    // Attention sinks: fold a per-head sink logit into the softmax as one
    // extra "virtual" key.
    if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) {
        // NOTE(review): S is read by coopMatPerElementNV before being written;
        // presumably perElemOpGetSink ignores the input element, and
        // GL_EXT_null_initializer (enabled above) covers the read — confirm.
        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> S;
        coopMatPerElementNV(S, S, perElemOpGetSink, iq2);

        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> Mr;

        // resize M by using smear/reduce
        coopMatReduceNV(Mr, M, gl_CooperativeMatrixReduceRowNV, smearReduce);

        // O, Ldiag, Mr all have the same type so all element locations match
        [[unroll]] for (uint32_t i = 0; i < Ldiag.length(); ++i) {
            ACC_TYPE sink = S[i];

            ACC_TYPE ms = ACC_TYPE(1.0f);
            ACC_TYPE vs = ACC_TYPE(1.0f);

            // Standard online-softmax merge of one extra logit: rescale the
            // smaller side by e^(smaller - larger).
            if (sink > Mr[i]) {
                ms = exp(Mr[i] - sink);

                O[i] *= float16_t(ms);
            } else {
                vs = exp(sink - Mr[i]);
            }

            Ldiag[i] = Ldiag[i]*ms + vs;
        }
    }

    // Final softmax normalization: O /= L (0 rows stay 0 instead of NaN).
    [[unroll]]
    for (int k = 0; k < Ldiag.length(); ++k) {
        Ldiag[k] = (Ldiag[k] == 0.0) ? ACC_TYPE(0.0) : (ACC_TYPE(1.0) / Ldiag[k]);
    }

    coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(O);

    O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(Ldiag)*O_D;

#if defined(ACC_TYPE_MAX)
    // Keep the converted output within the accumulator's representable range.
    [[unroll]] for (uint i = 0; i < O_D.length(); ++i) { O_D[i] = clamp(O_D[i], D_TYPE(-ACC_TYPE_MAX), D_TYPE(ACC_TYPE_MAX)); }
#endif

    uint32_t o_offset = gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV;

    if (p.gqa_ratio > 1) {
        // GQA: rows map to Q's dim 2; per-element store handles the bounds.
        coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
    } else {
        tensorLayoutNV<3, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(3, gl_CooperativeMatrixClampModeConstantNV);
        tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, HSV);

        // permute dimensions
        tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2);

        coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, N, 0, HSV_pad), tensorViewPermute);
    }
}