#include "common.cuh"
#include "fattn-common.cuh"
#include "fattn-mma-f16.cuh"
#include "fattn-tile.cuh"
#include "fattn-vec.cuh"
#include "fattn-wmma-f16.cuh"
#include "fattn.cuh"

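// As used below: the mma kernels process tiles of ncols1 x ncols2 Q columns,
// where ncols1 spans sequence positions and ncols2 packs Q heads that share the
// same K/V head (the GQA optimization). This helper picks the smallest ncols1
// that still covers the batch size Q->ne[1], minimizing wasted compute.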
template <int DKQ, int DV, int ncols2>
static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
    const ggml_tensor * Q = dst->src[0];

    if constexpr (ncols2 <= 8) {
        if (turing_mma_available(cc) && Q->ne[1] <= 8/ncols2) {
            ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 8/ncols2, ncols2>(ctx, dst);
            return;
        }
    }

    if constexpr (ncols2 <= 16) {
        if ((turing_mma_available(cc) || amd_wmma_available(cc)) && Q->ne[1] <= 16/ncols2) {
            ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 16/ncols2, ncols2>(ctx, dst);
            return;
        }
    }

    if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING || amd_wmma_available(cc) || Q->ne[1] <= 32/ncols2) {
        ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 32/ncols2, ncols2>(ctx, dst);
        return;
    }

    ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 64/ncols2, ncols2>(ctx, dst);
}

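// Pick ncols2 from the GQA ratio Q->ne[2] / K->ne[2]: the more Q heads share a
// single KV head, the more columns of the tile can reuse the same K/V data.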
template <int DKQ, int DV>
static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
    const ggml_tensor * KQV = dst;
    const ggml_tensor * Q = dst->src[0];
    const ggml_tensor * K = dst->src[1];
    const ggml_tensor * V = dst->src[2];
    const ggml_tensor * mask = dst->src[3];

    float max_bias = 0.0f;
    memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));

    // Edge cases such as a missing mask, ALiBi, an unpadded KV cache, or addresses misaligned
    // for large data transfers are routed to the template specializations without GQA optimizations.
    bool use_gqa_opt = mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0;
    for (const ggml_tensor * t : {Q, K, V, mask}) {
        if (t == nullptr || ggml_is_quantized(t->type)) {
            continue;
        }
        for (size_t i = 1; i < GGML_MAX_DIMS; ++i) {
            if (t->nb[i] % 16 != 0) {
                use_gqa_opt = false;
                break;
            }
        }
    }

    GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
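    // Illustrative example: with 32 Q heads attending over 8 KV heads the ratio below is 4,
    // i.e. up to 4 Q heads can reuse a single pass over the same K/V data.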
    const int gqa_ratio = Q->ne[2] / K->ne[2];

    // On Volta the GQA optimizations matter less than minimizing wasted compute:
    if (cc == GGML_CUDA_CC_VOLTA) {
        if (use_gqa_opt && gqa_ratio % 8 == 0) {
            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 8>(ctx, dst);
            return;
        }

        if (use_gqa_opt && gqa_ratio % 4 == 0) {
            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 4>(ctx, dst);
            return;
        }

        if (use_gqa_opt && gqa_ratio % 2 == 0) {
            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 2>(ctx, dst);
            return;
        }

        ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 1>(ctx, dst);
        return;
    }

    if (use_gqa_opt && gqa_ratio > 4) {
        ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 8>(ctx, dst);
        return;
    }

    if (use_gqa_opt && gqa_ratio > 2) {
        ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 4>(ctx, dst);
        return;
    }

    if (use_gqa_opt && gqa_ratio > 1) {
        ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 2>(ctx, dst);
        return;
    }

    ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 1>(ctx, dst);
}

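// Dispatch the mma kernel on the head size of Q; DKQ and DV differ only for
// MLA-style attention (head size 576 with a 512-wide V).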
static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
    const ggml_tensor * KQV = dst;
    const ggml_tensor * Q = dst->src[0];
    const ggml_tensor * K = dst->src[1];
    const ggml_tensor * V = dst->src[2];
    const ggml_tensor * mask = dst->src[3];

    switch (Q->ne[0]) {
        case 64:
            GGML_ASSERT(V->ne[0] == 64);
            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2< 64,  64>(ctx, dst);
            break;
        case 80:
            GGML_ASSERT(V->ne[0] == 80);
            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2< 80,  80>(ctx, dst);
            break;
        case 96:
            GGML_ASSERT(V->ne[0] == 96);
            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2< 96,  96>(ctx, dst);
            break;
        case 112:
            GGML_ASSERT(V->ne[0] == 112);
            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<112, 112>(ctx, dst);
            break;
        case 128:
            GGML_ASSERT(V->ne[0] == 128);
            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<128, 128>(ctx, dst);
            break;
        case 256:
            GGML_ASSERT(V->ne[0] == 256);
            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<256, 256>(ctx, dst);
            break;
        case 576: {
            // For DeepSeek, go straight to the ncols1 switch to avoid compiling unnecessary kernels.
            GGML_ASSERT(V->ne[0] == 512);
            float max_bias = 0.0f;
            memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));

            const bool use_gqa_opt = mask && max_bias == 0.0f;
            GGML_ASSERT(use_gqa_opt);

            GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
            const int gqa_ratio = Q->ne[2] / K->ne[2];
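            // The ncols2 choice below is tuned per architecture and context length:
            // larger ncols2 lets more Q heads share each pass over K/V, while smaller
            // ncols2 wastes less work at small batch sizes.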
            if (gqa_ratio == 20) { // GLM 4.7 Flash
                if (cc >= GGML_CUDA_CC_DGX_SPARK) {
                    if (Q->ne[1] <= 8) {
                        ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
                        break;
                    }
                    ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
                    break;
                }
                if (cc >= GGML_CUDA_CC_BLACKWELL) {
                    if (Q->ne[1] <= 4 && K->ne[1] >= 65536) {
                        ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
                        break;
                    }
                    ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
                    break;
                }
                if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
                    if (Q->ne[1] <= 4) {
                        ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
                        break;
                    }
                    ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
                    break;
                }
                if (cc >= GGML_CUDA_CC_TURING) {
                    if (Q->ne[1] <= 4) {
                        if (K->ne[1] <= 16384) {
                            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
                            break;
                        }
                        ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 32>(ctx, dst);
                        break;
                    }
                    ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
                    break;
                }
                // Volta:
                ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
            } else if (gqa_ratio % 16 == 0) {
                ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
            } else {
                ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
            }
        } break;
        default:
            GGML_ABORT("fatal error");
            break;
    }
}

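// Try a single vector-kernel case; F32 K/V tensors are accepted by the F16
// instantiations, hence the relaxed type checks.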
#define FATTN_VEC_CASE(D, type_K, type_V) \
    { \
        const bool type_K_okay = K->type == (type_K) || (K->type == GGML_TYPE_F32 && (type_K) == GGML_TYPE_F16); \
        const bool type_V_okay = V->type == (type_V) || (V->type == GGML_TYPE_F32 && (type_V) == GGML_TYPE_F16); \
        if (Q->ne[0] == (D) && type_K_okay && type_V_okay) { \
            ggml_cuda_flash_attn_ext_vec_case<D, type_K, type_V>(ctx, dst); \
            return; \
        } \
    }

#define FATTN_VEC_CASES_ALL_D(type_K, type_V) \
    FATTN_VEC_CASE( 64, type_K, type_V) \
    FATTN_VEC_CASE(128, type_K, type_V) \
    FATTN_VEC_CASE(256, type_K, type_V)

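// Dispatch to the vector kernel. With GGML_CUDA_FA_ALL_QUANTS every K/V type
// combination below is compiled; otherwise only F16/F16 and the matching
// Q4_0/Q8_0 pairs are available, trading coverage for binary size.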
static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_tensor * Q = dst->src[0];
    ggml_tensor * K = dst->src[1];
    ggml_tensor * V = dst->src[2];

#ifdef GGML_CUDA_FA_ALL_QUANTS
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16,  GGML_TYPE_F16)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_F16)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_F16)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_F16)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_F16)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_F16)

    FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16,  GGML_TYPE_Q4_0)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q4_0)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q4_0)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q4_0)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q4_0)

    FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16,  GGML_TYPE_Q4_1)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_1)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q4_1)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q4_1)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q4_1)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q4_1)

    FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16,  GGML_TYPE_Q5_0)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q5_0)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q5_0)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)

    FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16,  GGML_TYPE_Q5_1)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q5_1)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q5_1)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q5_1)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q5_1)

    FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16,  GGML_TYPE_Q8_0)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q8_0)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q8_0)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q8_0)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q8_0)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
#else
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16,  GGML_TYPE_F16)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
#endif // GGML_CUDA_FA_ALL_QUANTS

    GGML_ABORT("fatal error");
}

// Best FlashAttention kernel for a specific GPU:
enum best_fattn_kernel {
    BEST_FATTN_KERNEL_NONE     =   0,
    BEST_FATTN_KERNEL_TILE     = 200,
    BEST_FATTN_KERNEL_VEC      = 100,
    BEST_FATTN_KERNEL_WMMA_F16 = 300,
    BEST_FATTN_KERNEL_MMA_F16  = 400,
};

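// Select the best FlashAttention kernel for this device and problem shape,
// or BEST_FATTN_KERNEL_NONE if the op cannot be handled.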
static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const ggml_tensor * dst) {
#ifndef FLASH_ATTN_AVAILABLE
    GGML_UNUSED(device); GGML_UNUSED(dst);
    return BEST_FATTN_KERNEL_NONE;
#endif // FLASH_ATTN_AVAILABLE

    const ggml_tensor * KQV = dst;
    const ggml_tensor * Q = dst->src[0];
    const ggml_tensor * K = dst->src[1];
    const ggml_tensor * V = dst->src[2];
    const ggml_tensor * mask = dst->src[3];

    GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
    const int gqa_ratio = Q->ne[2] / K->ne[2];

    float max_bias = 0.0f;
    memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));

    // The effective batch size for the kernel can be increased by gqa_ratio.
    // The kernel versions without this optimization are also used for ALiBi, if there is no mask,
    // if the KV cache is not padded, or if the tensor strides are misaligned.
    bool gqa_opt_applies = gqa_ratio >= 2 && mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0;
    for (const ggml_tensor * t : {Q, K, V, mask}) {
        if (t == nullptr || ggml_is_quantized(t->type)) {
            continue;
        }
        for (size_t i = 1; i < GGML_MAX_DIMS; ++i) {
            if (t->nb[i] % 16 != 0) {
                gqa_opt_applies = false;
                break;
            }
        }
    }

    const int cc = ggml_cuda_info().devices[device].cc;

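    // Only head sizes with compiled kernels are usable; 576/512 is the MLA case
    // and additionally requires the GQA optimization.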
    switch (K->ne[0]) {
        case 40:
        case 64:
        case 72:
        case 80:
        case 96:
        case 112:
        case 128:
        case 256:
            if (V->ne[0] != K->ne[0]) {
                return BEST_FATTN_KERNEL_NONE;
            }
            break;
        case 576:
            if (V->ne[0] != 512) {
                return BEST_FATTN_KERNEL_NONE;
            }
            if (!gqa_opt_applies) {
                return BEST_FATTN_KERNEL_NONE;
            }
            break;
        default:
            return BEST_FATTN_KERNEL_NONE;
    }

#ifndef GGML_CUDA_FA_ALL_QUANTS
    if (K->type != V->type) {
        return BEST_FATTN_KERNEL_NONE;
    }
#endif // GGML_CUDA_FA_ALL_QUANTS

    switch (K->type) {
        case GGML_TYPE_F32:
        case GGML_TYPE_F16:
            break;
        case GGML_TYPE_Q4_1:
        case GGML_TYPE_Q5_0:
        case GGML_TYPE_Q5_1:
#ifndef GGML_CUDA_FA_ALL_QUANTS
            return BEST_FATTN_KERNEL_NONE;
#endif // GGML_CUDA_FA_ALL_QUANTS
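        // with GGML_CUDA_FA_ALL_QUANTS these types fall through and are supported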
        case GGML_TYPE_Q4_0:
        case GGML_TYPE_Q8_0:
            break;
        default:
            return BEST_FATTN_KERNEL_NONE;
    }

    if (mask && mask->ne[2] != 1) {
        return BEST_FATTN_KERNEL_NONE;
    }

    // For small batch sizes the vector kernel may be preferable to the kernels optimized for large batch sizes:
    const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0 && K->ne[1] % FATTN_KQ_STRIDE == 0;

    // If Turing tensor cores are available, use them:
    if (turing_mma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72) {
        if (can_use_vector_kernel) {
            if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
                if (cc >= GGML_CUDA_CC_ADA_LOVELACE && Q->ne[1] == 1 && Q->ne[3] == 1 && !(gqa_ratio > 4 && K->ne[1] >= 8192)) {
                    return BEST_FATTN_KERNEL_VEC;
                }
            } else {
                if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
                    if (Q->ne[1] <= 2) {
                        return BEST_FATTN_KERNEL_VEC;
                    }
                } else {
                    if (Q->ne[1] == 1) {
                        return BEST_FATTN_KERNEL_VEC;
                    }
                }
            }
            if (!gqa_opt_applies && Q->ne[1] == 1) {
                return BEST_FATTN_KERNEL_VEC;
            }
        }
        return BEST_FATTN_KERNEL_MMA_F16;
    }

    if (volta_mma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72) {
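        // gqa_ratio_eff is the largest power of two (up to ncols2_max) that divides
        // gqa_ratio, i.e. the ncols2 value the mma kernel would actually use.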
        int gqa_ratio_eff = 1;
        const int ncols2_max = Q->ne[0] == 576 ? 16 : 8;
        while (gqa_ratio % (2*gqa_ratio_eff) == 0 && gqa_ratio_eff < ncols2_max) {
            gqa_ratio_eff *= 2;
        }
        if (can_use_vector_kernel && Q->ne[1] * gqa_ratio_eff <= 2) {
            return BEST_FATTN_KERNEL_VEC;
        }
        if (Q->ne[1] * gqa_ratio_eff <= 16) {
            return BEST_FATTN_KERNEL_TILE; // On Volta tensor cores are only faster for sufficiently large matrices.
        }
        return BEST_FATTN_KERNEL_MMA_F16;
    }

    // Use the WMMA kernel if possible:
    if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 576) {
        if (can_use_vector_kernel && Q->ne[1] <= 2) {
            return BEST_FATTN_KERNEL_VEC;
        }
        return BEST_FATTN_KERNEL_WMMA_F16;
    }

    if (amd_wmma_available(cc) && GGML_CUDA_CC_IS_RDNA4(cc) && gqa_opt_applies && Q->ne[0] <= 128 && Q->ne[0] != 40 && Q->ne[0] != 72) {
        if (can_use_vector_kernel) {
            if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
                if (Q->ne[1] == 1) {
                    if (!gqa_opt_applies) {
                        return BEST_FATTN_KERNEL_VEC;
                    }
                }
            } else {
                if (Q->ne[1] <= 2) {
                    return BEST_FATTN_KERNEL_VEC;
                }
            }
        }
        int gqa_ratio_eff = 1;
        const int ncols2_max = Q->ne[0] == 576 ? 16 : 8;
        while (gqa_ratio % (2*gqa_ratio_eff) == 0 && gqa_ratio_eff < ncols2_max) {
            gqa_ratio_eff *= 2;
        }
        if (Q->ne[1] * gqa_ratio_eff <= 8) {
            return BEST_FATTN_KERNEL_TILE; // AMD WMMA is only faster if the full tile width of 16 can be utilized.
        }
        return BEST_FATTN_KERNEL_MMA_F16;
    }

    // Without usable tensor cores, fall back to the vector kernel for small batches and to the generic tile kernel otherwise:
    if (can_use_vector_kernel) {
        if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
            if (Q->ne[1] == 1) {
                if (!gqa_opt_applies) {
                    return BEST_FATTN_KERNEL_VEC;
                }
            }
        } else {
            if (Q->ne[1] <= 2) {
                return BEST_FATTN_KERNEL_VEC;
            }
        }
    }
    return BEST_FATTN_KERNEL_TILE;
}

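// Public entry point: run FlashAttention with the best kernel for the current device.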
void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_cuda_set_device(ctx.device);
    switch (ggml_cuda_get_best_fattn_kernel(ggml_cuda_get_device(), dst)) {
        case BEST_FATTN_KERNEL_NONE:
            GGML_ABORT("fatal error");
        case BEST_FATTN_KERNEL_TILE:
            ggml_cuda_flash_attn_ext_tile(ctx, dst);
            break;
        case BEST_FATTN_KERNEL_VEC:
            ggml_cuda_flash_attn_ext_vec(ctx, dst);
            break;
        case BEST_FATTN_KERNEL_WMMA_F16:
            ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
            break;
        case BEST_FATTN_KERNEL_MMA_F16:
            ggml_cuda_flash_attn_ext_mma_f16(ctx, dst);
            break;
    }
}

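// Backend support check: the op is supported iff some kernel can handle it.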
bool ggml_cuda_flash_attn_ext_supported(int device, const ggml_tensor * dst) {
    return ggml_cuda_get_best_fattn_kernel(device, dst) != BEST_FATTN_KERNEL_NONE;
}