1#include "common.cuh"
  2#include "fattn-common.cuh"
  3#include "fattn-mma-f16.cuh"
  4#include "fattn-tile.cuh"
  5#include "fattn-vec.cuh"
  6#include "fattn-wmma-f16.cuh"
  7#include "fattn.cuh"
  8
  9template <int DKQ, int DV, int ncols2>
 10static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 11    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
 12    const ggml_tensor * Q = dst->src[0];
 13
 14    if constexpr (ncols2 <= 8) {
 15        if (turing_mma_available(cc) && Q->ne[1] <= 8/ncols2) {
 16            ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 8/ncols2, ncols2>(ctx, dst);
 17            return;
 18        }
 19    }
 20
 21    if constexpr (ncols2 <= 16) {
 22        if ((turing_mma_available(cc) || amd_wmma_available(cc)) && Q->ne[1] <= 16/ncols2) {
 23            ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 16/ncols2, ncols2>(ctx, dst);
 24            return;
 25        }
 26    }
 27
 28    if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING || amd_wmma_available(cc) || Q->ne[1] <= 32/ncols2) {
 29        ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 32/ncols2, ncols2>(ctx, dst);
 30        return;
 31    }
 32
 33    ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 64/ncols2, ncols2>(ctx, dst);
 34}
 35
 36template <int DKQ, int DV>
 37static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 38    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
 39    const ggml_tensor * KQV  = dst;
 40    const ggml_tensor * Q    = dst->src[0];
 41    const ggml_tensor * K    = dst->src[1];
 42    const ggml_tensor * V    = dst->src[2];
 43    const ggml_tensor * mask = dst->src[3];
 44
 45    float max_bias = 0.0f;
 46    memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
 47
 48    // Edge cases like no mask, ALiBi, unpadded K/V, or misaligned addresses for large data transfers
 49    //     are put into the template specialization without GQA optimizations.
 50    bool use_gqa_opt = mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0;
 51    for (const ggml_tensor * t : {Q, K, V, mask}) {
 52        if (t == nullptr || ggml_is_quantized(t->type)) {
 53            continue;
 54        }
 55        for (size_t i = 1; i < GGML_MAX_DIMS; ++i) {
 56            if (t->nb[i] % 16 != 0) {
 57                use_gqa_opt = false;
 58                break;
 59            }
 60        }
 61    }
 62
 63    GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
 64    const int gqa_ratio = Q->ne[2] / K->ne[2];
 65
 66    // On Volta the GQA optimizations aren't as impactful vs. minimizing wasted compute:
 67    if (cc == GGML_CUDA_CC_VOLTA) {
 68        if (use_gqa_opt && gqa_ratio % 8 == 0) {
 69            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 8>(ctx, dst);
 70            return;
 71        }
 72
 73        if (use_gqa_opt && gqa_ratio % 4 == 0) {
 74            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 4>(ctx, dst);
 75            return;
 76        }
 77
 78        if (use_gqa_opt && gqa_ratio % 2 == 0) {
 79            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 2>(ctx, dst);
 80            return;
 81        }
 82
 83        ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 1>(ctx, dst);
 84        return;
 85    }
 86
 87    if (use_gqa_opt && gqa_ratio > 4) {
 88        ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 8>(ctx, dst);
 89        return;
 90    }
 91
 92    if (use_gqa_opt && gqa_ratio > 2) {
 93        ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 4>(ctx, dst);
 94        return;
 95    }
 96
 97    if (use_gqa_opt && gqa_ratio > 1) {
 98        ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 2>(ctx, dst);
 99        return;
100    }
101
102    ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 1>(ctx, dst);
103}
104
105static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
106    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
107    const ggml_tensor * KQV  = dst;
108    const ggml_tensor * Q    = dst->src[0];
109    const ggml_tensor * K    = dst->src[1];
110    const ggml_tensor * V    = dst->src[2];
111    const ggml_tensor * mask = dst->src[3];
112
113    switch (Q->ne[0]) {
114        case 64:
115            GGML_ASSERT(V->ne[0] == 64);
116            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2< 64,  64>(ctx, dst);
117            break;
118        case 80:
119            GGML_ASSERT(V->ne[0] == 80);
120            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2< 80,  80>(ctx, dst);
121            break;
122        case 96:
123            GGML_ASSERT(V->ne[0] == 96);
124            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2< 96,  96>(ctx, dst);
125            break;
126        case 112:
127            GGML_ASSERT(V->ne[0] == 112);
128            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<112, 112>(ctx, dst);
129            break;
130        case 128:
131            GGML_ASSERT(V->ne[0] == 128);
132            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<128, 128>(ctx, dst);
133            break;
134        case 256:
135            GGML_ASSERT(V->ne[0] == 256);
136            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<256, 256>(ctx, dst);
137            break;
138        case 576: {
139            // For Deepseek, go straight to the ncols1 switch to avoid compiling unnecessary kernels.
140            GGML_ASSERT(V->ne[0] == 512);
141            float max_bias = 0.0f;
142            memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
143
144            const bool use_gqa_opt = mask && max_bias == 0.0f;
145            GGML_ASSERT(use_gqa_opt);
146
147            GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
148            const int gqa_ratio = Q->ne[2] / K->ne[2];
149            if (gqa_ratio == 20) { // GLM 4.7 Flash
150                if (cc >= GGML_CUDA_CC_DGX_SPARK) {
151                    if (Q->ne[1] <= 8) {
152                        ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
153                        break;
154                    }
155                    ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
156                    break;
157                }
158                if (cc >= GGML_CUDA_CC_BLACKWELL) {
159                    if (Q->ne[1] <= 4 && K->ne[1] >= 65536) {
160                        ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
161                        break;
162                    }
163                    ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
164                    break;
165                }
166                if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
167                    if (Q->ne[1] <= 4) {
168                        ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
169                        break;
170                    }
171                    ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
172                    break;
173                }
174                if (cc >= GGML_CUDA_CC_TURING) {
175                    if (Q->ne[1] <= 4) {
176                        if (K->ne[1] <= 16384) {
177                            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
178                            break;
179                        }
180                        ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 32>(ctx, dst);
181                        break;
182                    }
183                    ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
184                    break;
185                }
186                // Volta:
187                ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
188            } else if (gqa_ratio % 16 == 0) {
189                ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
190            } else {
191                ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512,  4>(ctx, dst);
192            }
193        } break;
194        default:
195            GGML_ABORT("fatal error");
196            break;
197    }
198}
199
// Dispatch helper for ggml_cuda_flash_attn_ext_vec, requires Q, K, V, ctx and dst to be in scope where it is expanded.
// If Q->ne[0] == D and the types of K and V match type_K and type_V, the corresponding vector kernel case is run
// and the enclosing function returns. A tensor of type F32 is also accepted where F16 is requested.
// Expands to a braced block, no trailing semicolon is needed.
#define FATTN_VEC_CASE(D, type_K, type_V)                                                       \
    {                                                                                           \
        if (Q->ne[0] == (D) &&                                                                  \
            (K->type == (type_K) || ((type_K) == GGML_TYPE_F16 && K->type == GGML_TYPE_F32)) && \
            (V->type == (type_V) || ((type_V) == GGML_TYPE_F16 && V->type == GGML_TYPE_F32))) { \
            ggml_cuda_flash_attn_ext_vec_case<D, type_K, type_V>(ctx, dst);                     \
            return;                                                                             \
        }                                                                                       \
    }

// Expands FATTN_VEC_CASE once for each of the head sizes 64, 128, and 256.
#define FATTN_VEC_CASES_ALL_D(type_K, type_V) \
    FATTN_VEC_CASE( 64, type_K, type_V)       \
    FATTN_VEC_CASE(128, type_K, type_V)       \
    FATTN_VEC_CASE(256, type_K, type_V)
214
215static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
216    ggml_tensor * Q = dst->src[0];
217    ggml_tensor * K = dst->src[1];
218    ggml_tensor * V = dst->src[2];
219
220#ifdef GGML_CUDA_FA_ALL_QUANTS
221    FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16,  GGML_TYPE_F16)
222    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_F16)
223    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_F16)
224    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_F16)
225    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_F16)
226    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_F16)
227
228    FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16,  GGML_TYPE_Q4_0)
229    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
230    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q4_0)
231    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q4_0)
232    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q4_0)
233    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q4_0)
234
235    FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16,  GGML_TYPE_Q4_1)
236    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_1)
237    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q4_1)
238    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q4_1)
239    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q4_1)
240    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q4_1)
241
242    FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16,  GGML_TYPE_Q5_0)
243    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q5_0)
244    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q5_0)
245    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
246    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
247    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)
248
249    FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16,  GGML_TYPE_Q5_1)
250    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q5_1)
251    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q5_1)
252    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q5_1)
253    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
254    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q5_1)
255
256    FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16,  GGML_TYPE_Q8_0)
257    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q8_0)
258    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q8_0)
259    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q8_0)
260    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q8_0)
261    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
262#else
263    FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16,  GGML_TYPE_F16)
264    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
265    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
266#endif // GGML_CUDA_FA_ALL_QUANTS
267
268    GGML_ABORT("fatal error");
269}
270
// FlashAttention kernel family selected for a specific GPU and operation.
// The numeric values are assigned explicitly, the enumerators are listed in ascending order of value.
enum best_fattn_kernel {
    BEST_FATTN_KERNEL_NONE     =   0, // no usable kernel, the operation is reported as unsupported
    BEST_FATTN_KERNEL_VEC      = 100, // ggml_cuda_flash_attn_ext_vec
    BEST_FATTN_KERNEL_TILE     = 200, // ggml_cuda_flash_attn_ext_tile
    BEST_FATTN_KERNEL_WMMA_F16 = 300, // ggml_cuda_flash_attn_ext_wmma_f16
    BEST_FATTN_KERNEL_MMA_F16  = 400, // ggml_cuda_flash_attn_ext_mma_f16
};
279
280static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const ggml_tensor * dst) {
281#ifndef FLASH_ATTN_AVAILABLE
282    GGML_UNUSED(device); GGML_UNUSED(dst);
283    return BEST_FATTN_KERNEL_NONE;
284#endif// FLASH_ATTN_AVAILABLE
285
286    const ggml_tensor * KQV   = dst;
287    const ggml_tensor * Q     = dst->src[0];
288    const ggml_tensor * K     = dst->src[1];
289    const ggml_tensor * V     = dst->src[2];
290    const ggml_tensor * mask  = dst->src[3];
291
292    const int gqa_ratio = Q->ne[2] / K->ne[2];
293    GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
294
295    float max_bias = 0.0f;
296    memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
297
298    // The effective batch size for the kernel can be increased by gqa_ratio.
299    // The kernel versions without this optimization are also used for ALiBi, if there is no mask, or if the KV cache is not padded,
300    bool gqa_opt_applies = gqa_ratio >= 2 && mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0;
301    for (const ggml_tensor * t : {Q, K, V, mask}) {
302        if (t == nullptr || ggml_is_quantized(t->type)) {
303            continue;
304        }
305        for (size_t i = 1; i < GGML_MAX_DIMS; ++i) {
306            if (t->nb[i] % 16 != 0) {
307                gqa_opt_applies = false;
308                break;
309            }
310        }
311    }
312
313    const int cc = ggml_cuda_info().devices[device].cc;
314
315    switch (K->ne[0]) {
316        case  40:
317        case  64:
318        case  72:
319        case  80:
320        case  96:
321        case 128:
322        case 112:
323        case 256:
324            if (V->ne[0] != K->ne[0]) {
325                return BEST_FATTN_KERNEL_NONE;
326            }
327            break;
328        case 576:
329            if (V->ne[0] != 512) {
330                return BEST_FATTN_KERNEL_NONE;
331            }
332            if (!gqa_opt_applies) {
333                return BEST_FATTN_KERNEL_NONE;
334            }
335            break;
336        default:
337            return BEST_FATTN_KERNEL_NONE;
338    }
339
340#ifndef GGML_CUDA_FA_ALL_QUANTS
341    if (K->type != V->type) {
342        return BEST_FATTN_KERNEL_NONE;
343    }
344#endif // GGML_CUDA_FA_ALL_QUANTS
345
346    switch (K->type) {
347        case GGML_TYPE_F32:
348        case GGML_TYPE_F16:
349            break;
350        case GGML_TYPE_Q4_1:
351        case GGML_TYPE_Q5_0:
352        case GGML_TYPE_Q5_1:
353#ifndef GGML_CUDA_FA_ALL_QUANTS
354            return BEST_FATTN_KERNEL_NONE;
355#endif // GGML_CUDA_FA_ALL_QUANTS
356        case GGML_TYPE_Q4_0:
357        case GGML_TYPE_Q8_0:
358            break;
359        default:
360            return BEST_FATTN_KERNEL_NONE;
361    }
362
363    if (mask && mask->ne[2] != 1) {
364        return BEST_FATTN_KERNEL_NONE;
365    }
366
367    // For small batch sizes the vector kernel may be preferable over the kernels optimized for large batch sizes:
368    const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0 && K->ne[1] % FATTN_KQ_STRIDE == 0;
369
370    // If Turing tensor cores are available, use them:
371    if (turing_mma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72) {
372        if (can_use_vector_kernel) {
373            if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
374                if (cc >= GGML_CUDA_CC_ADA_LOVELACE && Q->ne[1] == 1 && Q->ne[3] == 1 && !(gqa_ratio > 4 && K->ne[1] >= 8192)) {
375                    return BEST_FATTN_KERNEL_VEC;
376                }
377            } else {
378                if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
379                    if (Q->ne[1] <= 2) {
380                        return BEST_FATTN_KERNEL_VEC;
381                    }
382                } else {
383                    if (Q->ne[1] == 1) {
384                        return BEST_FATTN_KERNEL_VEC;
385                    }
386                }
387            }
388            if (!gqa_opt_applies && Q->ne[1] == 1) {
389                return BEST_FATTN_KERNEL_VEC;
390            }
391        }
392        return BEST_FATTN_KERNEL_MMA_F16;
393    }
394
395    if (volta_mma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72) {
396        int gqa_ratio_eff = 1;
397        const int ncols2_max = Q->ne[0] == 576 ? 16 : 8;
398        while (gqa_ratio % (2*gqa_ratio_eff) == 0 && gqa_ratio_eff < ncols2_max) {
399            gqa_ratio_eff *= 2;
400        }
401        if (can_use_vector_kernel && Q->ne[1] * gqa_ratio_eff <= 2) {
402            return BEST_FATTN_KERNEL_VEC;
403        }
404        if (Q->ne[1] * gqa_ratio_eff <= 16) {
405            return BEST_FATTN_KERNEL_TILE; // On Volta tensor cores are only faster for sufficiently large matrices.
406        }
407        return BEST_FATTN_KERNEL_MMA_F16;
408    }
409
410    // Use the WMMA kernel if possible:
411    if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 576) {
412        if (can_use_vector_kernel && Q->ne[1] <= 2) {
413            return BEST_FATTN_KERNEL_VEC;
414        }
415        return BEST_FATTN_KERNEL_WMMA_F16;
416    }
417
418    if (amd_wmma_available(cc) && GGML_CUDA_CC_IS_RDNA4(cc) && gqa_opt_applies && Q->ne[0] <= 128 && Q->ne[0] != 40 && Q->ne[0] != 72) {
419        if (can_use_vector_kernel) {
420            if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
421                if (Q->ne[1] == 1) {
422                    if (!gqa_opt_applies) {
423                        return BEST_FATTN_KERNEL_VEC;
424                    }
425                }
426            } else {
427                if (Q->ne[1] <= 2) {
428                    return BEST_FATTN_KERNEL_VEC;
429                }
430            }
431        }
432        int gqa_ratio_eff = 1;
433        const int ncols2_max = Q->ne[0] == 576 ? 16 : 8;
434        while (gqa_ratio % (2*gqa_ratio_eff) == 0 && gqa_ratio_eff < ncols2_max) {
435            gqa_ratio_eff *= 2;
436        }
437        if (Q->ne[1] * gqa_ratio_eff <= 8) {
438            return BEST_FATTN_KERNEL_TILE; // AMD WMMA is only faster if the full tile width of 16 can be utilized.
439        }
440        return BEST_FATTN_KERNEL_MMA_F16;
441    }
442
443    // If there are no tensor cores available, use the generic tile kernel:
444    if (can_use_vector_kernel) {
445        if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
446            if (Q->ne[1] == 1) {
447                if (!gqa_opt_applies) {
448                    return BEST_FATTN_KERNEL_VEC;
449                }
450            }
451        } else {
452            if (Q->ne[1] <= 2) {
453                return BEST_FATTN_KERNEL_VEC;
454            }
455        }
456    }
457    return BEST_FATTN_KERNEL_TILE;
458}
459
460void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
461    ggml_cuda_set_device(ctx.device);
462    switch (ggml_cuda_get_best_fattn_kernel(ggml_cuda_get_device(), dst)) {
463        case BEST_FATTN_KERNEL_NONE:
464            GGML_ABORT("fatal error");
465        case BEST_FATTN_KERNEL_TILE:
466            ggml_cuda_flash_attn_ext_tile(ctx, dst);
467            break;
468        case BEST_FATTN_KERNEL_VEC:
469            ggml_cuda_flash_attn_ext_vec(ctx, dst);
470            break;
471        case BEST_FATTN_KERNEL_WMMA_F16:
472            ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
473            break;
474        case BEST_FATTN_KERNEL_MMA_F16:
475            ggml_cuda_flash_attn_ext_mma_f16(ctx, dst);
476            break;
477    }
478}
479
480bool ggml_cuda_flash_attn_ext_supported(int device, const ggml_tensor * dst) {
481    return ggml_cuda_get_best_fattn_kernel(device, dst) != BEST_FATTN_KERNEL_NONE;
482}