1#include "common.cuh"
 2#include "fattn-tile.cuh"
 3#include "fattn-wmma-f16.cuh"
 4
 5void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 6    const ggml_tensor * K = dst->src[1];
 7    const ggml_tensor * V = dst->src[2];
 8    switch (K->ne[0]) {
 9        case  40: {
10            GGML_ASSERT(V->ne[0] == K->ne[0]);
11            ggml_cuda_flash_attn_ext_tile_case< 40,  40>(ctx, dst);
12        } break;
13        case  64: {
14            GGML_ASSERT(V->ne[0] == K->ne[0]);
15            ggml_cuda_flash_attn_ext_tile_case< 64,  64>(ctx, dst);
16        } break;
17        case  72: {
18            GGML_ASSERT(V->ne[0] == K->ne[0]);
19            ggml_cuda_flash_attn_ext_tile_case< 72,  72>(ctx, dst);
20        } break;
21        case  80: {
22            GGML_ASSERT(V->ne[0] == K->ne[0]);
23            ggml_cuda_flash_attn_ext_tile_case< 80,  80>(ctx, dst);
24        } break;
25        case  96: {
26            GGML_ASSERT(V->ne[0] == K->ne[0]);
27            ggml_cuda_flash_attn_ext_tile_case< 96,  96>(ctx, dst);
28        } break;
29        case 112: {
30            GGML_ASSERT(V->ne[0] == K->ne[0]);
31            ggml_cuda_flash_attn_ext_tile_case<112, 112>(ctx, dst);
32        } break;
33        case 128: {
34            GGML_ASSERT(V->ne[0] == K->ne[0]);
35            ggml_cuda_flash_attn_ext_tile_case<128, 128>(ctx, dst);
36        } break;
37        case 256: {
38            GGML_ASSERT(V->ne[0] == K->ne[0]);
39            ggml_cuda_flash_attn_ext_tile_case<256, 256>(ctx, dst);
40        } break;
41        case 576: {
42            GGML_ASSERT(V->ne[0] == 512);
43            ggml_cuda_flash_attn_ext_tile_case<576, 512>(ctx, dst);
44        } break;
45        default: {
46            GGML_ABORT("Unsupported head size");
47        } break;
48    }
49}