1#pragma once
 2
 3#include "common.cuh"
 4
 5#if defined(GGML_USE_MUSA)
 6#define GGML_USE_WMMA_FATTN
 7#endif // defined(GGML_USE_MUSA)
 8
 9#if defined(GGML_HIP_ROCWMMA_FATTN)
10#if defined(CDNA) && (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0)
11#define GGML_USE_WMMA_FATTN
12#elif defined(CDNA)
13#warning "rocwmma fattn on CDNA is broken on rocwmma v2.0.0, expect degraded performance"
14#endif // defined(CDNA) && (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0)
15#if defined(RDNA3)
16#define GGML_USE_WMMA_FATTN
17#endif // defined(RDNA3)
18#if defined(RDNA4) && ROCWMMA_VERSION_MAJOR > 1
19#define GGML_USE_WMMA_FATTN
20#elif defined(RDNA4)
21#warning "rocwmma fattn is not suported on RDNA4 on rocwmma < v2.0.0, expect degraded performance"
22#endif // defined(RDNA4) && ROCWMMA_VERSION_MAJOR > 1
23#endif // defined(GGML_HIP_ROCWMMA_FATTN)
24
25// WMMA flash attention requires FP16 matrix instructions to be available for ggml code.
26static bool ggml_cuda_should_use_wmma_fattn(const int cc) {
27#if defined(GGML_USE_HIP) && !defined(GGML_HIP_ROCWMMA_FATTN)
28    return false;
29#else
30    if ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_VOLTA) ||
31        GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_MTHREADS(cc)) {
32        return true;
33    } else if (GGML_CUDA_CC_IS_CDNA(cc)){
34#if defined(GGML_HIP_ROCWMMA_FATTN) && (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0)
35        return true;
36#else
37        return false;
38#endif // defined(GGML_HIP_ROCWMMA_FATTN) (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0)
39    } else if (GGML_CUDA_CC_IS_RDNA4(cc)) {
40#if defined(GGML_HIP_ROCWMMA_FATTN) && ROCWMMA_VERSION_MAJOR > 1
41        return true;
42#else
43        return false;
44#endif // defined(GGML_HIP_ROCWMMA_FATTN) && ROCWMMA_VERSION_MAJOR > 1
45    } else {
46        return false;
47    }
48#endif // defined(GGML_USE_HIP) && !defined(GGML_HIP_ROCWMMA_FATTN)
49}
50
// WMMA (FP16) flash-attention op, declared here and defined elsewhere (presumably the matching .cu file).
// NOTE(review): `dst` is presumably the flash-attention output tensor whose sources carry Q/K/V — confirm
// against the definition. Callers are expected to gate this on ggml_cuda_should_use_wmma_fattn().
void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst);