author    Mitja Felicijan <mitja.felicijan@gmail.com>  2026-02-12 20:57:17 +0100
committer Mitja Felicijan <mitja.felicijan@gmail.com>  2026-02-12 20:57:17 +0100
commit    b333b06772c89d96aacb5490d6a219fba7c09cc6 (patch)
tree      211df60083a5946baa2ed61d33d8121b7e251b06 /llama.cpp/ggml/src/ggml-cuda/fattn.cu
Engage!
Diffstat (limited to 'llama.cpp/ggml/src/ggml-cuda/fattn.cu')
-rw-r--r--  llama.cpp/ggml/src/ggml-cuda/fattn.cu  482
1 file changed, 482 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-cuda/fattn.cu b/llama.cpp/ggml/src/ggml-cuda/fattn.cu
new file mode 100644
index 0000000..721edd9
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/fattn.cu
@@ -0,0 +1,482 @@
+#include "common.cuh"
+#include "fattn-common.cuh"
+#include "fattn-mma-f16.cuh"
+#include "fattn-tile.cuh"
+#include "fattn-vec.cuh"
+#include "fattn-wmma-f16.cuh"
+#include "fattn.cuh"
+
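+// Selects ncols1, the number of Q columns processed per tile for a fixed ncols2, so that the
+// combined tile width ncols1*ncols2 is just wide enough for the Q batch size (Q->ne[1]);
+// narrower tiles waste less compute when only a few tokens are in flight.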
+template <int DKQ, int DV, int ncols2>
+static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+ const ggml_tensor * Q = dst->src[0];
+
+ if constexpr (ncols2 <= 8) {
+ if (turing_mma_available(cc) && Q->ne[1] <= 8/ncols2) {
+ ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 8/ncols2, ncols2>(ctx, dst);
+ return;
+ }
+ }
+
+ if constexpr (ncols2 <= 16) {
+ if ((turing_mma_available(cc) || amd_wmma_available(cc)) && Q->ne[1] <= 16/ncols2) {
+ ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 16/ncols2, ncols2>(ctx, dst);
+ return;
+ }
+ }
+
+ if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING || amd_wmma_available(cc) || Q->ne[1] <= 32/ncols2) {
+ ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 32/ncols2, ncols2>(ctx, dst);
+ return;
+ }
+
+ ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 64/ncols2, ncols2>(ctx, dst);
+}
+
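+// Selects ncols2, the number of Q heads grouped per K/V head (GQA optimization). The group size is
+// chosen from the GQA ratio Q->ne[2]/K->ne[2]; ncols2 == 1 is used when the optimization does not
+// apply (no mask, ALiBi, unpadded K/V, or misaligned strides).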
+template <int DKQ, int DV>
+static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+ const ggml_tensor * KQV = dst;
+ const ggml_tensor * Q = dst->src[0];
+ const ggml_tensor * K = dst->src[1];
+ const ggml_tensor * V = dst->src[2];
+ const ggml_tensor * mask = dst->src[3];
+
+ float max_bias = 0.0f;
+ memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
+
+ // Edge cases like no mask, ALiBi, unpadded K/V, or misaligned addresses for large data transfers
+ // are put into the template specialization without GQA optimizations.
+ bool use_gqa_opt = mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0;
+ for (const ggml_tensor * t : {Q, K, V, mask}) {
+ if (t == nullptr || ggml_is_quantized(t->type)) {
+ continue;
+ }
+ for (size_t i = 1; i < GGML_MAX_DIMS; ++i) {
+ if (t->nb[i] % 16 != 0) {
+ use_gqa_opt = false;
+ break;
+ }
+ }
+ }
+
+ GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
+ const int gqa_ratio = Q->ne[2] / K->ne[2];
+
+ // On Volta the GQA optimizations aren't as impactful vs. minimizing wasted compute:
+ if (cc == GGML_CUDA_CC_VOLTA) {
+ if (use_gqa_opt && gqa_ratio % 8 == 0) {
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 8>(ctx, dst);
+ return;
+ }
+
+ if (use_gqa_opt && gqa_ratio % 4 == 0) {
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 4>(ctx, dst);
+ return;
+ }
+
+ if (use_gqa_opt && gqa_ratio % 2 == 0) {
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 2>(ctx, dst);
+ return;
+ }
+
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 1>(ctx, dst);
+ return;
+ }
+
+ if (use_gqa_opt && gqa_ratio > 4) {
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 8>(ctx, dst);
+ return;
+ }
+
+ if (use_gqa_opt && gqa_ratio > 2) {
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 4>(ctx, dst);
+ return;
+ }
+
+ if (use_gqa_opt && gqa_ratio > 1) {
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 2>(ctx, dst);
+ return;
+ }
+
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 1>(ctx, dst);
+}
+
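+// Dispatches the MMA FlashAttention kernel on the K head size (Q->ne[0]). V must have the same
+// head size, except for the 576/512 DeepSeek case which is special-cased below.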
+static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+ const ggml_tensor * KQV = dst;
+ const ggml_tensor * Q = dst->src[0];
+ const ggml_tensor * K = dst->src[1];
+ const ggml_tensor * V = dst->src[2];
+ const ggml_tensor * mask = dst->src[3];
+
+ switch (Q->ne[0]) {
+ case 64:
+ GGML_ASSERT(V->ne[0] == 64);
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2< 64, 64>(ctx, dst);
+ break;
+ case 80:
+ GGML_ASSERT(V->ne[0] == 80);
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2< 80, 80>(ctx, dst);
+ break;
+ case 96:
+ GGML_ASSERT(V->ne[0] == 96);
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2< 96, 96>(ctx, dst);
+ break;
+ case 112:
+ GGML_ASSERT(V->ne[0] == 112);
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<112, 112>(ctx, dst);
+ break;
+ case 128:
+ GGML_ASSERT(V->ne[0] == 128);
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<128, 128>(ctx, dst);
+ break;
+ case 256:
+ GGML_ASSERT(V->ne[0] == 256);
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<256, 256>(ctx, dst);
+ break;
+ case 576: {
+ // For Deepseek, go straight to the ncols1 switch to avoid compiling unnecessary kernels.
+ GGML_ASSERT(V->ne[0] == 512);
+ float max_bias = 0.0f;
+ memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
+
+ const bool use_gqa_opt = mask && max_bias == 0.0f;
+ GGML_ASSERT(use_gqa_opt);
+
+ GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
+ const int gqa_ratio = Q->ne[2] / K->ne[2];
+ if (gqa_ratio == 20) { // GLM 4.7 Flash
+ if (cc >= GGML_CUDA_CC_DGX_SPARK) {
+ if (Q->ne[1] <= 8) {
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
+ break;
+ }
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
+ break;
+ }
+ if (cc >= GGML_CUDA_CC_BLACKWELL) {
+ if (Q->ne[1] <= 4 && K->ne[1] >= 65536) {
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
+ break;
+ }
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
+ break;
+ }
+ if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
+ if (Q->ne[1] <= 4) {
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
+ break;
+ }
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
+ break;
+ }
+ if (cc >= GGML_CUDA_CC_TURING) {
+ if (Q->ne[1] <= 4) {
+ if (K->ne[1] <= 16384) {
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
+ break;
+ }
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 32>(ctx, dst);
+ break;
+ }
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
+ break;
+ }
+ // Volta:
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
+ } else if (gqa_ratio % 16 == 0) {
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
+ } else {
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
+ }
+ } break;
+ default:
+ GGML_ABORT("fatal error");
+ break;
+ }
+}
+
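+// Dispatch helpers for the vector kernel: each case matches on the head size D and the K/V types.
+// The F16 specializations also accept F32 K/V data.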
+#define FATTN_VEC_CASE(D, type_K, type_V) \
+ { \
+ const bool type_K_okay = K->type == (type_K) || (K->type == GGML_TYPE_F32 && (type_K) == GGML_TYPE_F16); \
+ const bool type_V_okay = V->type == (type_V) || (V->type == GGML_TYPE_F32 && (type_V) == GGML_TYPE_F16); \
+ if (Q->ne[0] == (D) && type_K_okay && type_V_okay) { \
+ ggml_cuda_flash_attn_ext_vec_case<D, type_K, type_V>(ctx, dst); \
+ return; \
+ } \
+ } \
+
+#define FATTN_VEC_CASES_ALL_D(type_K, type_V) \
+ FATTN_VEC_CASE( 64, type_K, type_V) \
+ FATTN_VEC_CASE(128, type_K, type_V) \
+ FATTN_VEC_CASE(256, type_K, type_V) \
+
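+// Tries every compiled (D, type_K, type_V) combination and aborts if no specialization matches.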
+static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ ggml_tensor * Q = dst->src[0];
+ ggml_tensor * K = dst->src[1];
+ ggml_tensor * V = dst->src[2];
+
+#ifdef GGML_CUDA_FA_ALL_QUANTS
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_F16)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_F16)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_F16)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_F16)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_F16)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_F16)
+
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q4_0)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q4_0)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q4_0)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q4_0)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q4_0)
+
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q4_1)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_1)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q4_1)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q4_1)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q4_1)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q4_1)
+
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q5_0)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q5_0)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q5_0)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)
+
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q5_1)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q5_1)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q5_1)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q5_1)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q5_1)
+
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q8_0)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q8_0)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q8_0)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q8_0)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q8_0)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
+#else
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_F16)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
+#endif // GGML_CUDA_FA_ALL_QUANTS
+
+ GGML_ABORT("fatal error");
+}
+
+// Best FlashAttention kernel for a specific GPU:
+enum best_fattn_kernel {
+ BEST_FATTN_KERNEL_NONE = 0,
+ BEST_FATTN_KERNEL_TILE = 200,
+ BEST_FATTN_KERNEL_VEC = 100,
+ BEST_FATTN_KERNEL_WMMA_F16 = 300,
+ BEST_FATTN_KERNEL_MMA_F16 = 400,
+};
+
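+// Heuristically picks the fastest kernel for the given device and tensor shapes/types,
+// or BEST_FATTN_KERNEL_NONE if the op is not supported at all.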
+static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const ggml_tensor * dst) {
+#ifndef FLASH_ATTN_AVAILABLE
+ GGML_UNUSED(device); GGML_UNUSED(dst);
+ return BEST_FATTN_KERNEL_NONE;
+#endif // FLASH_ATTN_AVAILABLE
+
+ const ggml_tensor * KQV = dst;
+ const ggml_tensor * Q = dst->src[0];
+ const ggml_tensor * K = dst->src[1];
+ const ggml_tensor * V = dst->src[2];
+ const ggml_tensor * mask = dst->src[3];
+
+ const int gqa_ratio = Q->ne[2] / K->ne[2];
+ GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
+
+ float max_bias = 0.0f;
+ memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
+
+ // The effective batch size for the kernel can be increased by gqa_ratio.
+ // The kernel versions without this optimization are also used for ALiBi, if there is no mask, or if the KV cache is not padded.
+ bool gqa_opt_applies = gqa_ratio >= 2 && mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0;
+ for (const ggml_tensor * t : {Q, K, V, mask}) {
+ if (t == nullptr || ggml_is_quantized(t->type)) {
+ continue;
+ }
+ for (size_t i = 1; i < GGML_MAX_DIMS; ++i) {
+ if (t->nb[i] % 16 != 0) {
+ gqa_opt_applies = false;
+ break;
+ }
+ }
+ }
+
+ const int cc = ggml_cuda_info().devices[device].cc;
+
+ switch (K->ne[0]) {
+ case 40:
+ case 64:
+ case 72:
+ case 80:
+ case 96:
+ case 112:
+ case 128:
+ case 256:
+ if (V->ne[0] != K->ne[0]) {
+ return BEST_FATTN_KERNEL_NONE;
+ }
+ break;
+ case 576:
+ if (V->ne[0] != 512) {
+ return BEST_FATTN_KERNEL_NONE;
+ }
+ if (!gqa_opt_applies) {
+ return BEST_FATTN_KERNEL_NONE;
+ }
+ break;
+ default:
+ return BEST_FATTN_KERNEL_NONE;
+ }
+
+#ifndef GGML_CUDA_FA_ALL_QUANTS
+ if (K->type != V->type) {
+ return BEST_FATTN_KERNEL_NONE;
+ }
+#endif // GGML_CUDA_FA_ALL_QUANTS
+
+ switch (K->type) {
+ case GGML_TYPE_F32:
+ case GGML_TYPE_F16:
+ break;
+ case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q5_0:
+ case GGML_TYPE_Q5_1:
+#ifndef GGML_CUDA_FA_ALL_QUANTS
+ return BEST_FATTN_KERNEL_NONE;
+#endif // GGML_CUDA_FA_ALL_QUANTS
+ case GGML_TYPE_Q4_0:
+ case GGML_TYPE_Q8_0:
+ break;
+ default:
+ return BEST_FATTN_KERNEL_NONE;
+ }
+
+ if (mask && mask->ne[2] != 1) {
+ return BEST_FATTN_KERNEL_NONE;
+ }
+
+ // For small batch sizes the vector kernel may be preferable over the kernels optimized for large batch sizes:
+ const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0 && K->ne[1] % FATTN_KQ_STRIDE == 0;
+
+ // If Turing tensor cores are available, use them:
+ if (turing_mma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72) {
+ if (can_use_vector_kernel) {
+ if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
+ if (cc >= GGML_CUDA_CC_ADA_LOVELACE && Q->ne[1] == 1 && Q->ne[3] == 1 && !(gqa_ratio > 4 && K->ne[1] >= 8192)) {
+ return BEST_FATTN_KERNEL_VEC;
+ }
+ } else {
+ if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
+ if (Q->ne[1] <= 2) {
+ return BEST_FATTN_KERNEL_VEC;
+ }
+ } else {
+ if (Q->ne[1] == 1) {
+ return BEST_FATTN_KERNEL_VEC;
+ }
+ }
+ }
+ if (!gqa_opt_applies && Q->ne[1] == 1) {
+ return BEST_FATTN_KERNEL_VEC;
+ }
+ }
+ return BEST_FATTN_KERNEL_MMA_F16;
+ }
+
+ if (volta_mma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72) {
+ int gqa_ratio_eff = 1;
+ const int ncols2_max = Q->ne[0] == 576 ? 16 : 8;
+ while (gqa_ratio % (2*gqa_ratio_eff) == 0 && gqa_ratio_eff < ncols2_max) {
+ gqa_ratio_eff *= 2;
+ }
+ if (can_use_vector_kernel && Q->ne[1] * gqa_ratio_eff <= 2) {
+ return BEST_FATTN_KERNEL_VEC;
+ }
+ if (Q->ne[1] * gqa_ratio_eff <= 16) {
+ return BEST_FATTN_KERNEL_TILE; // On Volta tensor cores are only faster for sufficiently large matrices.
+ }
+ return BEST_FATTN_KERNEL_MMA_F16;
+ }
+
+ // Use the WMMA kernel if possible:
+ if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 576) {
+ if (can_use_vector_kernel && Q->ne[1] <= 2) {
+ return BEST_FATTN_KERNEL_VEC;
+ }
+ return BEST_FATTN_KERNEL_WMMA_F16;
+ }
+
+ if (amd_wmma_available(cc) && GGML_CUDA_CC_IS_RDNA4(cc) && gqa_opt_applies && Q->ne[0] <= 128 && Q->ne[0] != 40 && Q->ne[0] != 72) {
+ if (can_use_vector_kernel) {
+ if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
+ if (Q->ne[1] == 1) {
+ if (!gqa_opt_applies) {
+ return BEST_FATTN_KERNEL_VEC;
+ }
+ }
+ } else {
+ if (Q->ne[1] <= 2) {
+ return BEST_FATTN_KERNEL_VEC;
+ }
+ }
+ }
+ int gqa_ratio_eff = 1;
+ const int ncols2_max = Q->ne[0] == 576 ? 16 : 8;
+ while (gqa_ratio % (2*gqa_ratio_eff) == 0 && gqa_ratio_eff < ncols2_max) {
+ gqa_ratio_eff *= 2;
+ }
+ if (Q->ne[1] * gqa_ratio_eff <= 8) {
+ return BEST_FATTN_KERNEL_TILE; // AMD WMMA is only faster if the full tile width of 16 can be utilized.
+ }
+ return BEST_FATTN_KERNEL_MMA_F16;
+ }
+
+ // If there are no tensor cores available, use the generic tile kernel:
+ if (can_use_vector_kernel) {
+ if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
+ if (Q->ne[1] == 1) {
+ if (!gqa_opt_applies) {
+ return BEST_FATTN_KERNEL_VEC;
+ }
+ }
+ } else {
+ if (Q->ne[1] <= 2) {
+ return BEST_FATTN_KERNEL_VEC;
+ }
+ }
+ }
+ return BEST_FATTN_KERNEL_TILE;
+}
+
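+// Main entry point: runs whichever kernel ggml_cuda_get_best_fattn_kernel selected for this device.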
+void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ ggml_cuda_set_device(ctx.device);
+ switch (ggml_cuda_get_best_fattn_kernel(ggml_cuda_get_device(), dst)) {
+ case BEST_FATTN_KERNEL_NONE:
+ GGML_ABORT("fatal error");
+ case BEST_FATTN_KERNEL_TILE:
+ ggml_cuda_flash_attn_ext_tile(ctx, dst);
+ break;
+ case BEST_FATTN_KERNEL_VEC:
+ ggml_cuda_flash_attn_ext_vec(ctx, dst);
+ break;
+ case BEST_FATTN_KERNEL_WMMA_F16:
+ ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
+ break;
+ case BEST_FATTN_KERNEL_MMA_F16:
+ ggml_cuda_flash_attn_ext_mma_f16(ctx, dst);
+ break;
+ }
+}
+
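+// Reports whether any FlashAttention kernel can handle this op on the given device.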
+bool ggml_cuda_flash_attn_ext_supported(int device, const ggml_tensor * dst) {
+ return ggml_cuda_get_best_fattn_kernel(device, dst) != BEST_FATTN_KERNEL_NONE;
+}