#pragma once

#include "common.cuh"
#include "convert.cuh"
#include "vecdotq.cuh"

#include <cstdint>

#define FATTN_KQ_STRIDE 256
#define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
#define SOFTMAX_FTZ_THRESHOLD -20.0f // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs.

// log(2) = 0.6931. Adding this to the KQ maximum used for the softmax effectively shifts the numerical range
// representable by the VKQ accumulators up by a factor of 2.
// This reduces issues with numerical overflow but also causes larger values to be flushed to zero.
// However, as the output from FlashAttention will usually be used as an input for a matrix multiplication, this should be negligible.
// Still, the value range should be shifted as much as necessary but as little as possible.
// The macro on the following line shifts it by a factor of 2**3=8, as was needed to fix https://github.com/ggml-org/llama.cpp/issues/18606 .
#define FATTN_KQ_MAX_OFFSET (3.0f*0.6931f)
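// For intuition (a sketch; the exact usage is in the individual kernels): assuming the kernels exponentiate
// KQ - (KQ_max + FATTN_KQ_MAX_OFFSET) instead of KQ - KQ_max, then
//   expf(KQ - KQ_max - 3.0f*0.6931f) == expf(KQ - KQ_max) / 8.0f
// so the largest softmax term shrinks from 1 to 1/8, giving the VKQ accumulators 8x more headroom,
// while the flush-to-zero cutoff relative to the maximum tightens by ~2.08.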

typedef void (* fattn_kernel_t)(
        const char * __restrict__ Q,
        const char * __restrict__ K,
        const char * __restrict__ V,
        const char * __restrict__ mask,
        const char * __restrict__ sinks,
        const int * __restrict__ KV_max,
        float * __restrict__ dst,
        float2 * __restrict__ dst_meta,
        const float scale,
        const float max_bias,
        const float m0,
        const float m1,
        const uint32_t n_head_log2,
        const float logit_softcap,
        const int32_t ne00, const uint3 ne01, const int32_t ne02, const int32_t ne03,
        const int32_t nb01, const int32_t nb02, const int32_t nb03,
        const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
        const int32_t nb11, const int32_t nb12, const int64_t nb13,
        const int32_t nb21, const int32_t nb22, const int64_t nb23,
        const int32_t ne31, const int32_t ne32, const int32_t ne33,
        const int32_t nb31, const int32_t nb32, const int64_t nb33);
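// Argument convention above (ggml-style, matching the launch in launch_fattn below): ne* are element counts,
// nb* are strides in bytes, dimension 0 is the contiguous one. ne00..ne03/nb01..nb03 describe Q (ne01 is
// passed as fastdiv values), ne10..ne13/nb11..nb13 describe K, nb21..nb23 are the V strides, and
// ne31..ne33/nb31..nb33 describe the mask.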

typedef float (*vec_dot_KQ_t)(
    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds);
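// In the signature above K_c points to the start of a single (possibly quantized) K row. Depending on the
// implementation either the raw Q values (Q_v, f16/f32) or the q8_1-quantized Q values (Q_q8) together with
// their per-block scale/sum pairs (Q_ds, filled by quantize_q8_1_to_shared below) are read; the unused
// arguments are ignored.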

template <int D, int nthreads>
static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_f16(
    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) {

    const half2 * K_h2 = (const half2 *) K_c;
    GGML_UNUSED(Q_q8);
    GGML_UNUSED(Q_ds_v);

    constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
    constexpr int cpy_ne = cpy_nb / 4;

    float sum = 0.0f;

#pragma unroll
    for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += nthreads*cpy_ne) {
        __align__(16) half2 tmp[cpy_ne];
        ggml_cuda_memcpy_1<sizeof(tmp)>(tmp, K_h2 + k_KQ_0 + (threadIdx.x % nthreads)*cpy_ne);
#pragma unroll
        for (int k_KQ_1 = 0; k_KQ_1 < cpy_ne; ++k_KQ_1) {
#ifdef V_DOT2_F32_F16_AVAILABLE
            ggml_cuda_mad(sum, tmp[k_KQ_1], ((const half2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]);
#else
            ggml_cuda_mad(sum, __half22float2(tmp[k_KQ_1]), ((const float2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]);
#endif // V_DOT2_F32_F16_AVAILABLE
        }
    }

    return sum;
}

template<int D, int nthreads>
static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q4_0(
    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {

    const block_q4_0 * K_q4_0 = (const block_q4_0 *) K_c;
    GGML_UNUSED(Q_v);

    float sum = 0.0f;

#pragma unroll
    for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) {
        const int k_KQ = k_KQ_0 + (nthreads == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads);

        const int ib = k_KQ / QI8_1;
        const int iqs4 = k_KQ % QI4_0;
        const int shift = k_KQ & (QI8_1/2);

        int v;
        ggml_cuda_memcpy_1<sizeof(int), 2>(&v, K_q4_0[ib].qs + sizeof(int)*iqs4);
        v = (v >> shift) & 0x0F0F0F0F;
        const int u = Q_q8[k_KQ_0/nthreads];

        const int sumi = ggml_cuda_dp4a(v, u, 0);

        const float2 Q_ds = ((const float2 *) Q_ds_v)[k_KQ_0/nthreads];
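        // q4_0 stores unsigned nibbles with an implicit offset of -8; Q_ds.y holds the sum of the (scaled)
        // Q values of the q8_1 block, so the second term subtracts 8*sum(Q) spread over the QI8_1 ints of
        // the block instead of shifting each nibble individually.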
        sum += __half2float(K_q4_0[ib].d) * (sumi*Q_ds.x - (8/QI8_1)*Q_ds.y);
    }

    return sum;
}

template<int D, int nthreads>
static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q4_1(
    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {

    const block_q4_1 * K_q4_1 = (const block_q4_1 *) K_c;
    GGML_UNUSED(Q_v);

    float sum = 0.0f;

#pragma unroll
    for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) {
        const int k_KQ = k_KQ_0 + (nthreads == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads);

        const int ib = k_KQ / QI8_1;
        const int iqs4 = k_KQ % QI4_1;
        const int shift = k_KQ & (QI8_1/2);

        int v;
        ggml_cuda_memcpy_1<sizeof(int)>(&v, K_q4_1[ib].qs + sizeof(int)*iqs4);
        v = (v >> shift) & 0x0F0F0F0F;
        const int u = Q_q8[k_KQ_0/nthreads];

        const int sumi = ggml_cuda_dp4a(v, u, 0);

        const float2 K_dm = __half22float2(K_q4_1[ib].dm);
        const float2 Q_ds = ((const float2 *) Q_ds_v)[k_KQ_0/nthreads];

        sum += K_dm.x*Q_ds.x*sumi + K_dm.y*Q_ds.y/QI8_1;
    }

    return sum;
}

template<int D, int nthreads>
static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q5_0(
    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {

    const block_q5_0 * K_q5_0 = (const block_q5_0 *) K_c;
    GGML_UNUSED(Q_v);

    float sum = 0.0f;

#pragma unroll
    for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) {
        const int k_KQ = k_KQ_0 + (nthreads == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads);

        const int ib = k_KQ / QI8_1;
        const int iqs4 = k_KQ % QI5_0;
        const int iqs8 = k_KQ % QI8_1;
        const int shift = k_KQ & (QI8_1/2);

        int v;
        ggml_cuda_memcpy_1<sizeof(int), 2>(&v, K_q5_0[ib].qs + sizeof(int)*iqs4);
        v = (v >> shift) & 0x0F0F0F0F;

        {
            int vh;
            ggml_cuda_memcpy_1<sizeof(int), 2>(&vh, K_q5_0[ib].qh);
            vh >>= iqs8 * QI5_0;

            v |= (vh << 4) & 0x00000010; // 0 -> 4
            v |= (vh << 11) & 0x00001000; // 1 -> 12
            v |= (vh << 18) & 0x00100000; // 2 -> 20
            v |= (vh << 25) & 0x10000000; // 3 -> 28
        }

        const int u = Q_q8[k_KQ_0/nthreads];

        const int sumi = ggml_cuda_dp4a(v, u, 0);

        const float2 Q_ds = ((const float2 *) Q_ds_v)[k_KQ_0/nthreads];

        sum += __half2float(K_q5_0[ib].d) * (sumi*Q_ds.x - (16/QI8_1)*Q_ds.y);
    }

    return sum;
}

template<int D, int nthreads>
static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q5_1(
    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {

    const block_q5_1 * K_q5_1 = (const block_q5_1 *) K_c;
    GGML_UNUSED(Q_v);

    float sum = 0.0f;

#pragma unroll
    for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) {
        const int k_KQ = k_KQ_0 + (nthreads == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads);

        const int ib = k_KQ / QI8_1;
        const int iqs4 = k_KQ % QI5_1;
        const int iqs8 = k_KQ % QI8_1;
        const int shift = k_KQ & (QI8_1/2);

        int v;
        ggml_cuda_memcpy_1<sizeof(int)>(&v, K_q5_1[ib].qs + sizeof(int)*iqs4);
        v = (v >> shift) & 0x0F0F0F0F;

        {
            int vh;
            ggml_cuda_memcpy_1<sizeof(int)>(&vh, K_q5_1[ib].qh);
            vh >>= iqs8 * QI5_0;

            v |= (vh << 4) & 0x00000010; // 0 -> 4
            v |= (vh << 11) & 0x00001000; // 1 -> 12
            v |= (vh << 18) & 0x00100000; // 2 -> 20
            v |= (vh << 25) & 0x10000000; // 3 -> 28
        }

        const int u = Q_q8[k_KQ_0/nthreads];

        const int sumi = ggml_cuda_dp4a(v, u, 0);

        const float2 K_dm = __half22float2(K_q5_1[ib].dm);
        const float2 Q_ds = ((const float2 *) Q_ds_v)[k_KQ_0/nthreads];

        sum += K_dm.x*Q_ds.x*sumi + K_dm.y*Q_ds.y/QI8_1;
    }

    return sum;
}

template <int D, int nthreads>
static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q8_0(
    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {

    const block_q8_0 * K_q8_0 = (const block_q8_0 *) K_c;
    GGML_UNUSED(Q_v);

    float sum = 0.0f;

#pragma unroll
    for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) {
        const int k_KQ = k_KQ_0 + (nthreads == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads);

        const int ib = k_KQ / QI8_0;
        const int iqs = k_KQ % QI8_0;

        int v;
        ggml_cuda_memcpy_1<sizeof(v), 2>(&v, K_q8_0[ib].qs + 4*iqs);

        const float2 * Q_ds = (const float2 *) Q_ds_v;
        const float Q_d = Q_ds[k_KQ_0/nthreads].x;

        sum += vec_dot_q8_0_q8_1_impl<float, 1>(&v, &Q_q8[k_KQ_0/nthreads], K_q8_0[ib].d, Q_d);
    }

    return sum;
}

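// Quantizes 4 consecutive floats per thread (x[4*threadIdx.x + 0..3], pre-scaled by scale) to q8_1:
// each thread writes one int of 4 packed int8 values to yq32[threadIdx.x] and, once per block of 32 values,
// the block's (d, sum) pair is written to yds as half2 or float2 depending on Tds.
// ni is the number of threads holding valid data; threads with threadIdx.x >= ni contribute zeros.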
template <typename Tds, int ni>
static __device__ __forceinline__ void quantize_q8_1_to_shared(
    const float * __restrict__ x, const float scale, int * __restrict__ yq32, void * __restrict__ yds) {

    float vals[sizeof(int)] = {0.0f};
#pragma unroll
    for (int l = 0; l < int(sizeof(int)); ++l) {
        vals[l] = (ni == WARP_SIZE || threadIdx.x < ni) ? scale * x[4*threadIdx.x + l] : 0.0f;
    }

    float amax = fabsf(vals[0]);
    float sum = vals[0];
#pragma unroll
    for (int l = 1; l < int(sizeof(int)); ++l) {
        amax = fmaxf(amax, fabsf(vals[l]));
        sum += vals[l];
    }
#pragma unroll
    for (int mask = QI8_1/2; mask > 0; mask >>= 1) {
        amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, 32));
        sum += __shfl_xor_sync(0xFFFFFFFF, sum, mask, 32);
    }

    const float d = amax / 127;
    int q32 = 0;
    int8_t * q8 = (int8_t *) &q32;

    if (d != 0.0f) {
#pragma unroll
        for (int l = 0; l < int(sizeof(int)); ++l) {
            q8[l] = roundf(vals[l] / d);
        }
    }

    yq32[threadIdx.x] = q32;
    if (threadIdx.x % QI8_1 == 0 && (ni == WARP_SIZE || threadIdx.x < ni)) {
        if (std::is_same<Tds, half2>::value) {
            ((half2 *) yds)[threadIdx.x/QI8_1] = make_half2(d, sum);
        } else {
            ((float2 *) yds)[threadIdx.x/QI8_1] = make_float2(d, sum);
        }
    }
}

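// Dequantizes ne contiguous V values of one (possibly quantized) row, starting at element index i0,
// into dst as either half or float (template parameter T of the implementations below).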
typedef void (*dequantize_V_t)(const void *, void *, const int64_t);

template <typename T, int ne>
static __device__ __forceinline__ void dequantize_V_f16(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
    if constexpr (std::is_same_v<T, half>) {
        ggml_cuda_memcpy_1<ne*sizeof(half)>(dst, (const half *) vx + i0);
    } else if constexpr (std::is_same_v<T, float>) {
        static_assert(ne % 2 == 0, "bad ne");
        __align__(16) half2 tmp[ne/2];
        ggml_cuda_memcpy_1<ne*sizeof(half)>(tmp, (const half *) vx + i0);
        float2 * dst_f2 = (float2 *) dst;
#pragma unroll
        for (int l = 0; l < ne/2; ++l) {
            dst_f2[l] = __half22float2(tmp[l]);
        }
    } else {
        static_assert(std::is_same_v<T, void>, "unsupported type");
    }
}

template <typename T, int ne>
static __device__ __forceinline__ void dequantize_V_q4_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
    const block_q4_0 * x = (const block_q4_0 *) vx;

    const int64_t ib = i0 / QK4_0;
    const int iqs = i0 % (QK4_0/2);
    const int shift = (i0 % QK4_0) / (QK4_0/2);

    int q;
    static_assert(ne == 2 || ne == 4, "bad ne");
    ggml_cuda_memcpy_1<ne, 2>(&q, x[ib].qs + iqs);
    q >>= 4*shift;
    q &= 0x0F0F0F0F;
    q = __vsubss4(q, 0x08080808);

    const int8_t * q8 = (const int8_t *) &q;

#ifdef FP16_AVAILABLE
    if constexpr (std::is_same_v<T, half>) {
        const half2 d = __half2half2(x[ib].d);

#pragma unroll
        for (int l0 = 0; l0 < ne; l0 += 2) {
            ((half2 *) dst)[l0/2] = d * make_half2(q8[l0 + 0], q8[l0 + 1]);
        }
    } else
#endif // FP16_AVAILABLE
    if constexpr (std::is_same_v<T, float>) {
        const float d = x[ib].d;

#pragma unroll
        for (int l = 0; l < ne; ++l) {
            ((float *) dst)[l] = d * q8[l];
        }
    } else {
        static_assert(std::is_same_v<T, void>, "bad type");
    }
}

template <typename T, int ne>
static __device__ __forceinline__ void dequantize_V_q4_1(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
    const block_q4_1 * x = (const block_q4_1 *) vx;

    const int64_t ib = i0 / QK4_1;
    const int iqs = i0 % (QK4_1/2);
    const int shift = (i0 % QK4_1) / (QK4_1/2);

    int q;
    static_assert(ne == 2 || ne == 4, "bad ne");
    ggml_cuda_memcpy_1<ne>(&q, x[ib].qs + iqs);
    q >>= 4*shift;
    q &= 0x0F0F0F0F;

    const int8_t * q8 = (const int8_t *) &q;

#ifdef FP16_AVAILABLE
    if constexpr (std::is_same_v<T, half>) {
        const half2 dm = x[ib].dm;
        const half2 d = __half2half2( __low2half(dm));
        const half2 m = __half2half2(__high2half(dm));

#pragma unroll
        for (int l0 = 0; l0 < ne; l0 += 2) {
            ((half2 *) dst)[l0/2] = d * make_half2(q8[l0 + 0], q8[l0 + 1]) + m;
        }
    } else
#endif // FP16_AVAILABLE
    if constexpr (std::is_same_v<T, float>) {
        const float2 dm = __half22float2(x[ib].dm);

#pragma unroll
        for (int l = 0; l < ne; ++l) {
            ((float *) dst)[l] = dm.x * q8[l] + dm.y;
        }
    } else {
        static_assert(std::is_same_v<T, void>, "bad type");
    }
}

template <typename T, int ne>
static __device__ __forceinline__ void dequantize_V_q5_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
    const block_q5_0 * x = (const block_q5_0 *) vx;

    const int64_t ib = i0 / QK5_0;
    const int idq = i0 % QK5_0;
    const int iqs = i0 % (QK5_0/2);
    const int shift = (i0 % QK5_0) / (QK5_0/2);

    int q;
    static_assert(ne == 2 || ne == 4, "bad ne");
    ggml_cuda_memcpy_1<ne, 2>(&q, x[ib].qs + iqs);
    q >>= 4*shift;
    q &= 0x0F0F0F0F;

    {
        int qh;
        ggml_cuda_memcpy_1<ne, 2>(&qh, x[ib].qh);
#pragma unroll
        for (int l = 0; l < ne; ++l) {
            q |= ((qh >> (idq + l)) & 0x00000001) << (8*l + 4);
        }
    }

    q = __vsubss4(q, 0x10101010);

    const int8_t * q8 = (const int8_t *) &q;

#ifdef FP16_AVAILABLE
    if constexpr (std::is_same_v<T, half>) {
        const half2 d = __half2half2(x[ib].d);

#pragma unroll
        for (int l0 = 0; l0 < ne; l0 += 2) {
            ((half2 *) dst)[l0/2] = d * make_half2(q8[l0 + 0], q8[l0 + 1]);
        }
    } else
#endif // FP16_AVAILABLE
    if constexpr (std::is_same_v<T, float>) {
        const float d = x[ib].d;

#pragma unroll
        for (int l = 0; l < ne; ++l) {
            ((float *) dst)[l] = d * q8[l];
        }
    } else {
        static_assert(std::is_same_v<T, void>, "bad type");
    }
}

template <typename T, int ne>
static __device__ __forceinline__ void dequantize_V_q5_1(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
    const block_q5_1 * x = (const block_q5_1 *) vx;

    const int64_t ib = i0 / QK5_1;
    const int idq = i0 % QK5_1;
    const int iqs = i0 % (QK5_1/2);
    const int shift = (i0 % QK5_1) / (QK5_1/2);

    int q;
    static_assert(ne == 2 || ne == 4, "bad ne");
    ggml_cuda_memcpy_1<ne>(&q, x[ib].qs + iqs);
    q >>= 4*shift;
    q &= 0x0F0F0F0F;

    {
        int qh;
        ggml_cuda_memcpy_1<ne>(&qh, x[ib].qh);
#pragma unroll
        for (int l = 0; l < ne; ++l) {
            q |= ((qh >> (idq + l)) & 0x00000001) << (8*l + 4);
        }
    }

    const int8_t * q8 = (const int8_t *) &q;

#ifdef FP16_AVAILABLE
    if constexpr (std::is_same_v<T, half>) {
        const half2 dm = x[ib].dm;
        const half2 d = __half2half2( __low2half(dm));
        const half2 m = __half2half2(__high2half(dm));

#pragma unroll
        for (int l0 = 0; l0 < ne; l0 += 2) {
            ((half2 *) dst)[l0/2] = d * make_half2(q8[l0 + 0], q8[l0 + 1]) + m;
        }
    } else
#endif // FP16_AVAILABLE
    if constexpr (std::is_same_v<T, float>) {
        const float2 dm = __half22float2(x[ib].dm);

#pragma unroll
        for (int l = 0; l < ne; ++l) {
            ((float *) dst)[l] = dm.x * q8[l] + dm.y;
        }
    } else {
        static_assert(std::is_same_v<T, void>, "bad type");
    }
}

template <typename T, int ne>
static __device__ __forceinline__ void dequantize_V_q8_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
    const block_q8_0 * x = (const block_q8_0 *) vx;

    const int64_t ib = i0 / QK8_0;
    const int iqs = i0 % QK8_0;

    static_assert(ne % 2 == 0, "bad ne");
    int8_t qs[ne];
    ggml_cuda_memcpy_1<ne, 2>(qs, x[ib].qs + iqs);

#ifdef FP16_AVAILABLE
    if constexpr (std::is_same<T, half>::value) {
        const half2 d = __half2half2(x[ib].d);

#pragma unroll
        for (int l0 = 0; l0 < ne; l0 += 2) {
            ((half2 *) dst)[l0/2] = d * make_half2(qs[l0 + 0], qs[l0 + 1]);
        }
    } else
#endif // FP16_AVAILABLE
    if constexpr (std::is_same<T, float>::value) {
        const float d = x[ib].d;

#pragma unroll
        for (int l = 0; l < ne; ++l) {
            ((float *) dst)[l] = d * qs[l];
        }
    } else {
        static_assert(std::is_same_v<T, void>, "unsupported type");
    }
}

template <ggml_type type_K, int D, int nthreads>
constexpr __device__ vec_dot_KQ_t get_vec_dot_KQ() {
    if constexpr (type_K == GGML_TYPE_F16) {
        return vec_dot_fattn_vec_KQ_f16<D, nthreads>;
    } else if constexpr (type_K == GGML_TYPE_Q4_0) {
        return vec_dot_fattn_vec_KQ_q4_0<D, nthreads>;
    } else if constexpr (type_K == GGML_TYPE_Q4_1) {
        return vec_dot_fattn_vec_KQ_q4_1<D, nthreads>;
    } else if constexpr (type_K == GGML_TYPE_Q5_0) {
        return vec_dot_fattn_vec_KQ_q5_0<D, nthreads>;
    } else if constexpr (type_K == GGML_TYPE_Q5_1) {
        return vec_dot_fattn_vec_KQ_q5_1<D, nthreads>;
    } else if constexpr (type_K == GGML_TYPE_Q8_0) {
        return vec_dot_fattn_vec_KQ_q8_0<D, nthreads>;
    } else {
        static_assert(type_K == -1, "bad type");
        return nullptr;
    }
}

template <ggml_type type_V, typename T, int ne>
constexpr __device__ dequantize_V_t get_dequantize_V() {
    if constexpr (type_V == GGML_TYPE_F16) {
        return dequantize_V_f16<T, ne>;
    } else if constexpr (type_V == GGML_TYPE_Q4_0) {
        return dequantize_V_q4_0<T, ne>;
    } else if constexpr (type_V == GGML_TYPE_Q4_1) {
        return dequantize_V_q4_1<T, ne>;
    } else if constexpr (type_V == GGML_TYPE_Q5_0) {
        return dequantize_V_q5_0<T, ne>;
    } else if constexpr (type_V == GGML_TYPE_Q5_1) {
        return dequantize_V_q5_1<T, ne>;
    } else if constexpr (type_V == GGML_TYPE_Q8_0) {
        return dequantize_V_q8_0<T, ne>;
    } else {
        static_assert(type_V == -1, "bad type");
        return nullptr;
    }
}

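// Scans the mask backwards in steps of FATTN_KQ_STRIDE for each combination of sequence and tile of ncols1
// Q columns and writes to KV_max the exclusive upper bound (a multiple of FATTN_KQ_STRIDE) of KV positions
// that still contain unmasked values, so the attention kernels can skip the fully masked tail of the KV cache.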
template <int ncols1>
__launch_bounds__(FATTN_KQ_STRIDE/2, 1)
static __global__ void flash_attn_mask_to_KV_max(
        const half2 * __restrict__ mask, int * __restrict__ KV_max, const int ne30, const int s31, const int s33) {
    const int ne31 = gridDim.x;
    const int tid = threadIdx.x;
    const int sequence = blockIdx.y;
    const int jt = blockIdx.x;

    mask += sequence*s33 + jt*ncols1*s31;

    __shared__ int buf_iw[WARP_SIZE];
    if (tid < WARP_SIZE) {
        buf_iw[tid] = 1;
    }
    __syncthreads();

    int KV_max_sj = (ne30 - 1) * FATTN_KQ_STRIDE;
    for (; KV_max_sj >= 0; KV_max_sj -= FATTN_KQ_STRIDE) {
        int all_inf = 1;

#pragma unroll
        for (int j = 0; j < ncols1; ++j) {
            const float2 tmp = __half22float2(mask[j*s31 + KV_max_sj/2 + tid]);
            all_inf = all_inf && int(isinf(tmp.x)) && int(isinf(tmp.y));
        }

        all_inf = warp_reduce_all(all_inf);
        if (tid % WARP_SIZE == 0) {
            buf_iw[tid / WARP_SIZE] = all_inf;
        }
        __syncthreads();
        all_inf = buf_iw[tid % WARP_SIZE];
        __syncthreads();
        all_inf = warp_reduce_all(all_inf);

        if (!all_inf) {
            break;
        }
    }

    // If the break in the loop was not triggered, KV_max_sj is now -FATTN_KQ_STRIDE.
    // If the break was triggered, it's the lower edge of the tile with the first non-masked values.
    // In either case, undo the last decrement by FATTN_KQ_STRIDE.
    KV_max_sj += FATTN_KQ_STRIDE;

    if (threadIdx.x != 0) {
        return;
    }

    KV_max[sequence*ne31 + jt] = KV_max_sj;
}

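// Fixup kernel for the stream-k decomposition: a CUDA block whose work assignment starts in the middle of
// a KQ tile only produced a partial result for that tile. This kernel reloads such partial results from dst,
// merges them with the partial results of the preceding blocks stored in dst_fixup/dst_fixup_data (rescaling
// both sides towards the combined KQ maximum), and writes back the properly normalized value.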
template<int D, int ncols1, int ncols2> // D == head size
__launch_bounds__(D, 1)
static __global__ void flash_attn_stream_k_fixup(
        float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03,
        const int ne11, const int ne12, const int nbatch_fa) {
    constexpr int ncols = ncols1*ncols2;

    const int bidx0 = blockIdx.x;
    const int j = blockIdx.y;
    const int c = blockIdx.z;
    const int jc = j*ncols2 + c;
    const int tid = threadIdx.x;

    const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);

    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.

    const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa;
    const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
    const int iter_z_gqa = (gqa_ratio + (ncols2 - 1)) / ncols2;

    const int kbc0 = int64_t(bidx0 + 0)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
    const int kbc0_stop = int64_t(bidx0 + 1)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;

    const bool did_not_have_any_data = kbc0 == kbc0_stop;
    const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
    const bool did_not_write_last = kbc0/iter_k == kbc0_stop/iter_k && kbc0_stop % iter_k != 0;
    if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) {
        return;
    }

    // z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index
    const int sequence = kbc0 /(iter_k*iter_j*iter_z_gqa*ne12);
    const int z_KV = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa);
    const int zt_gqa = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j);
    const int jt = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k;

    const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index.

    if (jt*ncols1 + j >= ne01 || zt_gqa*ncols2 + c >= gqa_ratio) {
        return;
    }

    dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + zt_Q*D + (j*ne02 + c)*D + tid;

    // Load the partial result that needs a fixup:
    float dst_val = 0.0f;
    float max_val = 0.0f;
    float rowsum = 0.0f;
    {
        dst_val = *dst;

        const float2 tmp = dst_fixup[bidx0*ncols + jc];
        max_val = tmp.x;
        rowsum = tmp.y;
    }

    // Iterate over previous blocks and compute the combined results.
    // All CUDA blocks that get here must have a previous block that needs a fixup.
    int bidx = bidx0 - 1;
    int kbc_stop = kbc0;
    while (true) {
        const int kbc = int64_t(bidx)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
        if (kbc == kbc_stop) { // Did not have any data.
            bidx--;
            kbc_stop = kbc;
            continue;
        }

        const float dst_add = dst_fixup_data[bidx*ncols*D + jc*D + tid];

        const float2 tmp = dst_fixup[(gridDim.x + bidx)*ncols + jc];

        // Scale the current and new value accumulators depending on the max. values.
        const float max_val_new = fmaxf(max_val, tmp.x);

        const float diff_val = max_val - max_val_new;
        const float diff_add = tmp.x - max_val_new;

        const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f;
        const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f;

        dst_val = scale_val*dst_val + scale_add*dst_add;
        rowsum = scale_val*rowsum + scale_add*tmp.y;

        max_val = max_val_new;

        // If this block started in a previous tile we are done and don't need to combine additional partial results.
        if (kbc % iter_k == 0 || kbc/iter_k < kbc0/iter_k) {
            break;
        }
        bidx--;
        kbc_stop = kbc;
    }

    // Write back final result:
    *dst = dst_val / rowsum;
}

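// Combines the partial results of the parallel_blocks CUDA blocks that each processed a different chunk of
// the KV cache for the same output row: the per-block (KQ max, rowsum) pairs in VKQ_meta are used to rescale
// each partial VKQ sum towards the global maximum before the final division by the combined rowsum.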
template<int D> // D == head size
__launch_bounds__(D, 1)
static __global__ void flash_attn_combine_results(
        const float * __restrict__ VKQ_parts,
        const float2 * __restrict__ VKQ_meta,
        float * __restrict__ dst,
        const int parallel_blocks) {
    // Dimension 0: threadIdx.x
    // Dimension 1: blockIdx.x
    // Dimension 2: blockIdx.y
    // Dimension 3: blockIdx.z
    // Memory layout is permuted with [0, 2, 1, 3]

    const int ne01 = gridDim.x;
    const int ne02 = gridDim.y;

    const int col = blockIdx.x;
    const int head = blockIdx.y;
    const int sequence = blockIdx.z;

    const int j_dst_unrolled = (sequence*ne01 + col)*ne02 + head;

    VKQ_parts += j_dst_unrolled * parallel_blocks*D;
    VKQ_meta += j_dst_unrolled * parallel_blocks;
    dst += j_dst_unrolled * D;

    const int tid = threadIdx.x;
    __builtin_assume(tid < D);

    extern __shared__ float2 meta[];
    for (int i = tid; i < 2*parallel_blocks; i += D) {
        ((float *) meta)[i] = ((const float *) VKQ_meta)[i];
    }

    __syncthreads();

    float kqmax = meta[0].x;
    for (int l = 1; l < parallel_blocks; ++l) {
        kqmax = max(kqmax, meta[l].x);
    }

    float VKQ_numerator = 0.0f;
    float VKQ_denominator = 0.0f;
    for (int l = 0; l < parallel_blocks; ++l) {
        const float KQ_max_scale = expf(meta[l].x - kqmax);

        VKQ_numerator += KQ_max_scale * VKQ_parts[l*D + tid];
        VKQ_denominator += KQ_max_scale * meta[l].y;
    }

    dst[tid] = VKQ_numerator / VKQ_denominator;
}

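// Common host-side launcher for all FlashAttention kernels: converts K/V to f16 if the kernel needs it,
// optionally precomputes per-tile KV bounds from the mask, chooses between the stream-k decomposition and
// a fixed number of parallel blocks per tile based on occupancy, and launches fattn_kernel plus the
// fixup/combine kernels where necessary.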
template <int DV, int ncols1, int ncols2>
void launch_fattn(
        ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared,
        const int nbatch_fa, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE
) {
    constexpr int ncols = ncols1 * ncols2;

    const ggml_tensor * Q = dst->src[0];
    const ggml_tensor * K = dst->src[1];
    const ggml_tensor * V = dst->src[2];

    const bool V_is_K_view = V->view_src && (V->view_src == K || (V->view_src == K->view_src && V->view_offs == K->view_offs));

    const ggml_tensor * mask = dst->src[3];
    const ggml_tensor * sinks = dst->src[4];

    ggml_tensor * KQV = dst;

    GGML_ASSERT(Q->type == GGML_TYPE_F32);
    GGML_ASSERT(KQV->type == GGML_TYPE_F32);

    GGML_ASSERT(Q->nb[0] == ggml_element_size(Q));
    GGML_ASSERT(K->nb[0] == ggml_element_size(K));
    GGML_ASSERT(V->nb[0] == ggml_element_size(V));

    GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);

    ggml_cuda_pool & pool = ctx.pool();
    cudaStream_t main_stream = ctx.stream();
    const int id = ggml_cuda_get_device();
    const int cc = ggml_cuda_info().devices[id].cc;
    const int nsm = ggml_cuda_info().devices[id].nsm;

    ggml_cuda_pool_alloc<half> K_f16(pool);
    ggml_cuda_pool_alloc<half> V_f16(pool);
    ggml_cuda_pool_alloc<int> KV_max(pool);
    ggml_cuda_pool_alloc<float> dst_tmp(pool);
    ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);

    const char * K_data = (const char *) K->data;
    size_t nb11 = K->nb[1];
    size_t nb12 = K->nb[2];
    size_t nb13 = K->nb[3];

    const char * V_data = (const char *) V->data;
    size_t nb21 = V->nb[1];
    size_t nb22 = V->nb[2];
    size_t nb23 = V->nb[3];

    if (need_f16_K && K->type != GGML_TYPE_F16) {
        const size_t bs = ggml_blck_size(K->type);
        const size_t ts = ggml_type_size(K->type);

        K_f16.alloc(ggml_nelements(K));
        if (ggml_is_contiguously_allocated(K)) {
            to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type);
            to_fp16(K_data, K_f16.ptr, ggml_nelements(K), main_stream);

            nb11 = nb11*bs*sizeof(half)/ts;
            nb12 = nb12*bs*sizeof(half)/ts;
            nb13 = nb13*bs*sizeof(half)/ts;
        } else {
            GGML_ASSERT(K->nb[0] == ts);
            to_fp16_nc_cuda_t to_fp16 = ggml_get_to_fp16_nc_cuda(K->type);
            const int64_t s01 = nb11 / ts;
            const int64_t s02 = nb12 / ts;
            const int64_t s03 = nb13 / ts;
            to_fp16(K_data, K_f16.ptr, K->ne[0], K->ne[1], K->ne[2], K->ne[3], s01, s02, s03, main_stream);

            nb11 = K->ne[0] * sizeof(half);
            nb12 = K->ne[1] * nb11;
            nb13 = K->ne[2] * nb12;
        }
        K_data = (char *) K_f16.ptr;
    }

    if (need_f16_V && V->type != GGML_TYPE_F16) {
        if (V_is_K_view) {
            V_data = K_data;
            nb21 = nb11;
            nb22 = nb12;
            nb23 = nb13;
        } else {
            const size_t bs = ggml_blck_size(V->type);
            const size_t ts = ggml_type_size(V->type);

            V_f16.alloc(ggml_nelements(V));
            if (ggml_is_contiguously_allocated(V)) {
                to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
                to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream);
                V_data = (char *) V_f16.ptr;

                nb21 = nb21*bs*sizeof(half)/ts;
                nb22 = nb22*bs*sizeof(half)/ts;
                nb23 = nb23*bs*sizeof(half)/ts;
            } else {
                GGML_ASSERT(V->nb[0] == ts);
                to_fp16_nc_cuda_t to_fp16 = ggml_get_to_fp16_nc_cuda(V->type);
                const int64_t s01 = nb21 / ts;
                const int64_t s02 = nb22 / ts;
                const int64_t s03 = nb23 / ts;
                to_fp16(V_data, V_f16.ptr, V->ne[0], V->ne[1], V->ne[2], V->ne[3], s01, s02, s03, main_stream);

                nb21 = V->ne[0] * sizeof(half);
                nb22 = V->ne[1] * nb21;
                nb23 = V->ne[2] * nb22;
            }
            V_data = (char *) V_f16.ptr;
        }
    }

    const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
    const int gqa_ratio = Q->ne[2] / K->ne[2];
    const int ntiles_z_gqa = ((gqa_ratio + ncols2 - 1) / ncols2);
    const int ntiles_total = ntiles_x * ntiles_z_gqa * K->ne[2] * Q->ne[3];

    // Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped.
    // Only worth the overhead if there is at least one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or
    // multiple sequences of possibly different lengths.
    if (mask && K->ne[1] % FATTN_KQ_STRIDE == 0 && (Q->ne[1] >= 1024 || Q->ne[3] > 1)) {
        const int s31 = mask->nb[1] / sizeof(half2);
        const int s33 = mask->nb[3] / sizeof(half2);

        const dim3 blocks_num_KV_max(ntiles_x, Q->ne[3], 1);
        const dim3 block_dim_KV_max(FATTN_KQ_STRIDE/2, 1, 1);

        const int ne_KV_max = blocks_num_KV_max.x*blocks_num_KV_max.y;
        const int iter_k = K->ne[1] / FATTN_KQ_STRIDE;

        KV_max.alloc(ne_KV_max);
        flash_attn_mask_to_KV_max<ncols1><<<blocks_num_KV_max, block_dim_KV_max, 0, main_stream>>>
            ((const half2 *) mask->data, KV_max.ptr, iter_k, s31, s33);
        CUDA_CHECK(cudaGetLastError());
    }

    const dim3 block_dim(warp_size, nwarps, 1);
    int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy.
    CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));
    GGML_ASSERT(max_blocks_per_sm > 0);
    int parallel_blocks = max_blocks_per_sm;

    dim3 blocks_num;
    if (stream_k) {
        // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
        const int max_blocks = max_blocks_per_sm*nsm;
        const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks;
        const int tiles_efficiency_percent = 100 * ntiles_total / (max_blocks*tiles_nwaves);

        const int nblocks_stream_k = max_blocks;

        const bool use_stream_k = cc >= GGML_CUDA_CC_ADA_LOVELACE || amd_wmma_available(cc) || tiles_efficiency_percent < 75;

        blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_total;
        blocks_num.y = 1;
        blocks_num.z = 1;

        if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
            dst_tmp_meta.alloc((size_t(blocks_num.x) * ncols * (2 + DV/2)));
        }
    } else {
        const int ntiles_KQ = (K->ne[1] + nbatch_fa - 1) / nbatch_fa; // Max. number of parallel blocks limited by tensor size.

        // parallel_blocks must not be larger than what the tensor size allows:
        parallel_blocks = std::min(parallel_blocks, ntiles_KQ);

        // If ntiles_total % blocks_per_wave != 0 then some efficiency is lost due to tail effects.
        // Test whether parallel_blocks can be set to a higher value for better efficiency.
        const int blocks_per_wave = nsm * max_blocks_per_sm;
        int nwaves_best = 0;
        int efficiency_percent_best = 0;
        for (int parallel_blocks_test = parallel_blocks; parallel_blocks_test <= ntiles_KQ; ++parallel_blocks_test) {
            const int nblocks_total = ntiles_total * parallel_blocks_test;
            const int nwaves = (nblocks_total + blocks_per_wave - 1) / blocks_per_wave;
            const int efficiency_percent = 100 * nblocks_total / (nwaves*blocks_per_wave);

            // Stop trying configurations with more waves if we already have good efficiency to avoid excessive overhead.
            if (efficiency_percent_best >= 95 && nwaves > nwaves_best) {
                break;
            }

            if (efficiency_percent > efficiency_percent_best) {
                nwaves_best = nwaves;
                efficiency_percent_best = efficiency_percent;
                parallel_blocks = parallel_blocks_test;
            }
        }

        blocks_num.x = ntiles_x;
        blocks_num.y = parallel_blocks;
        blocks_num.z = ntiles_z_gqa*K->ne[2]*Q->ne[3];

        if (parallel_blocks > 1) {
            dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
            dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
        }
    }

    float scale = 1.0f;
    float max_bias = 0.0f;
    float logit_softcap = 0.0f;

    memcpy(&scale, (const float *) KQV->op_params + 0, sizeof(float));
    memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));

    if (logit_softcap != 0.0f) {
        scale /= logit_softcap;
    }

    const uint32_t n_head = Q->ne[2];
    const uint32_t n_head_log2 = 1u << uint32_t(floorf(log2f(float(n_head))));

    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
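    // m0 and m1 above are the two ALiBi slope bases; the per-head slope is presumably derived in the kernels
    // as powf(m0, h + 1) for heads h < n_head_log2 and powf(m1, 2*(h - n_head_log2) + 1) otherwise
    // (the standard ggml ALiBi convention).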

    // TODO other tensor dimensions after removal of WMMA kernel:
    const uint3 ne01 = init_fastdiv_values(Q->ne[1]);

    GGML_ASSERT(block_dim.x % warp_size == 0);
    fattn_kernel<<<blocks_num, block_dim, nbytes_shared, main_stream>>>(
        (const char *) Q->data,
        K_data,
        V_data,
        mask ? ((const char *) mask->data) : nullptr,
        sinks ? ((const char *) sinks->data) : nullptr,
        KV_max.ptr,
        !stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
        scale, max_bias, m0, m1, n_head_log2, logit_softcap,
        Q->ne[0], ne01, Q->ne[2], Q->ne[3], Q->nb[1], Q->nb[2], Q->nb[3],
        K->ne[0], K->ne[1], K->ne[2], K->ne[3], nb11, nb12, nb13,
        nb21, nb22, nb23,
        mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0,
        mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0, mask ? mask->nb[3] : 0
    );
    CUDA_CHECK(cudaGetLastError());

    if (stream_k) {
        if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
            const dim3 block_dim_combine(DV, 1, 1);
            const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2};

            flash_attn_stream_k_fixup<DV, ncols1, ncols2>
                <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
                ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1], K->ne[2], nbatch_fa);
        }
    } else if (parallel_blocks > 1) {
        const dim3 block_dim_combine(DV, 1, 1);
        const dim3 blocks_num_combine(Q->ne[1], Q->ne[2], Q->ne[3]);
        const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2);

        flash_attn_combine_results<DV>
            <<<blocks_num_combine, block_dim_combine, nbytes_shared_combine, main_stream>>>
            (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data, parallel_blocks);
    }
    CUDA_CHECK(cudaGetLastError());
}