#include "common.cuh"
#include "fattn-common.cuh"
#include "fattn-wmma-f16.cuh"

// nbatch_fa == number of KQ rows to process per iteration
// nbatch_K == number of K columns to load in parallel for KQ calculation

// TODO optimize kernel parameters for FP16 NVIDIA (P100)
// TODO optimize kernel parameters for head sizes 40, 72, 80, 96, 112

// The ROCm compiler cannot handle templating in __launch_bounds__.
// As a workaround, define a macro to package the kernel parameters as uint32_t:
#define GGML_CUDA_FATTN_TILE_CONFIG_CASE(DKQ_, DV_, ncols_, nthreads, occupancy, nbatch_fa, nbatch_K) \
    if (DKQ == (DKQ_) && DV == (DV_) && ncols == (ncols_)) {                                          \
        static_assert((nthreads)          <= 512, "bad nthreads");                                    \
        static_assert((occupancy)         <=   8, "bad occupancy");                                   \
        static_assert((nbatch_fa)         <= 256, "bad nbatch_fa");                                   \
        static_assert((nbatch_K)          <= 256, "bad nbatch_K");                                    \
        return ((nthreads) << 0) | ((occupancy) << 10) | ((nbatch_fa) << 14) | ((nbatch_K) << 23);    \
    }                                                                                                 \

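// For reference, the packed config word decoded by the getters further below has the layout:
//   bits  0- 9: nthreads  (<= 512)
//   bits 10-13: occupancy (<=   8)
//   bits 14-22: nbatch_fa (<= 256)
//   bits 23-31: nbatch_K  (<= 256)
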
static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nvidia_fp16(const int DKQ, const int DV, const int ncols) {
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40,  2,  64, 2,  64,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40,  4, 128, 2,  64,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40,  8, 256, 2,  64,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40, 16, 256, 2,  64,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40, 32, 256, 2,  64,  40)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64,  2,  64, 2,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64,  4, 128, 2,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64,  8, 256, 2,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64, 16, 256, 2,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64, 32, 256, 2,  64,  64)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72,  2,  64, 2,  64,  72)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72,  4, 128, 2,  64,  72)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72,  8, 256, 2,  64,  72)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72, 16, 256, 2,  64,  72)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72, 32, 256, 2,  64,  72)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80,  2,  64, 2,  64,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80,  4, 128, 2,  64,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80,  8, 256, 2,  64,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80, 16, 256, 2,  64,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80, 32, 256, 2,  64,  40)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96,  2,  64, 2,  64,  48)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96,  4, 128, 2,  64,  48)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96,  8, 256, 2,  64,  48)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96, 16, 256, 2,  64,  48)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96, 32, 256, 2,  64,  48)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112,  2,  64, 2,  64,  56)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112,  4, 128, 2,  64,  56)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112,  8, 256, 2,  64,  56)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2,  64,  56)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2,  64,  56)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128,  2,  64, 2,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128,  4, 128, 2,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128,  8, 256, 2,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 16, 256, 2,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2,  64,  64)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256,  2,  64, 2,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256,  4, 128, 2,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256,  8, 256, 2,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2,  64,  64)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512,  4, 128, 2,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512,  8, 256, 2,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2,  64,  64)

    return 0;
}

static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nvidia_fp32(const int DKQ, const int DV, const int ncols) {
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40,  2,  64, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40,  4, 128, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40,  8, 256, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40, 16, 256, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40, 32, 256, 2,  32,  40)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64,  2, 128, 3,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64,  4, 128, 3,  32,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64,  8, 128, 3,  32,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64, 16, 128, 3,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64, 32, 256, 2,  64,  64)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72,  2,  64, 2,  32,  72)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72,  4, 128, 2,  32,  72)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72,  8, 256, 2,  32,  72)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72, 16, 256, 2,  32,  72)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72, 32, 256, 2,  32,  72)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80,  2,  64, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80,  4, 128, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80,  8, 256, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80, 16, 256, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80, 32, 256, 2,  32,  40)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96,  2,  64, 2,  32,  48)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96,  4, 128, 2,  32,  48)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96,  8, 256, 2,  32,  48)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96, 16, 256, 2,  32,  48)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96, 32, 256, 2,  32,  48)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112,  2,  64, 2,  32,  56)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112,  4, 128, 2,  32,  56)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112,  8, 256, 2,  32,  56)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2,  32,  56)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2,  32,  56)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128,  2, 128, 3,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128,  4, 128, 3,  32, 128)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128,  8, 128, 3,  64, 128)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 16, 128, 3,  32, 128)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2,  64,  64)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256,  2, 128, 3,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256,  4, 128, 3,  32,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256,  8, 256, 2,  32, 256)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2,  32, 128)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2,  32,  64)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512,  4, 128, 2,  32,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512,  8, 256, 2,  32,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2,  32,  64)

    return 0;
}

static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_amd(const int DKQ, const int DV, const int ncols) {
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40,  2,  64, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40,  4, 128, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40,  8, 256, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40, 16, 256, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40, 32, 256, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40, 64, 256, 2,  32,  40)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64,  2,  64, 3,  32,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64,  4, 128, 3,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64,  8, 128, 2,  32,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64, 16, 256, 2, 128,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64, 32, 256, 2,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64, 64, 256, 2,  64,  64)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72,  2,  64, 2,  32,  72)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72,  4, 128, 2,  32,  72)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72,  8, 256, 2,  32,  72)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72, 16, 256, 2,  32,  72)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72, 32, 256, 2,  32,  72)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72, 64, 256, 2,  32,  72)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80,  2,  64, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80,  4, 128, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80,  8, 256, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80, 16, 256, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80, 32, 256, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80, 64, 256, 2,  32,  40)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96,  2,  64, 2,  32,  48)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96,  4, 128, 2,  32,  48)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96,  8, 256, 2,  32,  48)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96, 16, 256, 2,  32,  48)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96, 32, 256, 2,  32,  48)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96, 64, 256, 2,  32,  48)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112,  2,  64, 2,  32,  56)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112,  4, 128, 2,  32,  56)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112,  8, 256, 2,  32,  56)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2,  32,  56)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2,  32,  56)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 64, 256, 2,  32,  56)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128,  2, 256, 2, 128,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128,  4, 128, 2,  64, 128)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128,  8, 256, 2,  64, 128)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 16, 256, 2,  64, 128)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 64, 256, 2,  64,  32)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256,  2, 256, 2, 128,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256,  4, 256, 2,  64, 128)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256,  8, 256, 2,  64, 128)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2,  32, 128)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2,  32, 128)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512,  4, 128, 2,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512,  8, 256, 2,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 512, 1, 128,  64)

    return 0;
}

static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_amd_rdna(const int DKQ, const int DV, const int ncols) {
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40,  2,  64, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40,  4, 128, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40,  8, 256, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40, 16, 256, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40, 32, 256, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40, 64, 256, 2,  32,  40)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64,  2,  64, 8,  32,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64,  4,  64, 8,  32,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64,  8, 128, 5, 128,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64, 16, 128, 5, 128,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64, 32, 128, 4,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64, 64, 128, 5,  64,  64)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72,  2,  64, 2,  32,  72)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72,  4, 128, 2,  32,  72)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72,  8, 256, 2,  32,  72)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72, 16, 256, 2,  32,  72)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72, 32, 256, 2,  32,  72)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72, 64, 256, 2,  32,  72)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80,  2,  64, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80,  4, 128, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80,  8, 256, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80, 16, 256, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80, 32, 256, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80, 64, 256, 2,  32,  40)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96,  2,  64, 2,  32,  48)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96,  4, 128, 2,  32,  48)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96,  8, 256, 2,  32,  48)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96, 16, 256, 2,  32,  48)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96, 32, 256, 2,  32,  48)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96, 64, 256, 2,  32,  48)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112,  2,  64, 2,  32,  56)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112,  4, 128, 2,  32,  56)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112,  8, 256, 2,  32,  56)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2,  32,  56)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2,  32,  56)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 64, 256, 2,  32,  56)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128,  2,  64, 8,  32,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128,  4, 128, 8,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128,  8, 128, 8,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 16, 256, 3, 128, 128)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 3, 128,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 64, 256, 3,  64,  64)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256,  2,  64, 8,  32,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256,  4, 128, 6,  32, 256)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256,  8, 128, 6,  32, 256)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5,  32, 256)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3,  64, 128)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512,  4, 128, 2,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512,  8, 256, 2,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 4,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 256, 2, 128,  64)

    return 0;
}

static __host__ uint32_t ggml_cuda_fattn_tile_get_config(const int DKQ, const int DV, const int ncols, const int cc) {
    if (GGML_CUDA_CC_IS_AMD(cc)) {
        if (GGML_CUDA_CC_IS_RDNA(cc)) {
            return ggml_cuda_fattn_tile_get_config_amd_rdna(DKQ, DV, ncols);
        }
        return ggml_cuda_fattn_tile_get_config_amd(DKQ, DV, ncols);
    }
    if (fast_fp16_available(cc)) {
        return ggml_cuda_fattn_tile_get_config_nvidia_fp16(DKQ, DV, ncols);
    }
    return ggml_cuda_fattn_tile_get_config_nvidia_fp32(DKQ, DV, ncols);
}

static constexpr __device__ uint32_t ggml_cuda_fattn_tile_get_config(const int DKQ, const int DV, const int ncols) {
#ifdef GGML_USE_HIP
#ifdef RDNA
    return ggml_cuda_fattn_tile_get_config_amd_rdna(DKQ, DV, ncols);
#else
    return ggml_cuda_fattn_tile_get_config_amd(DKQ, DV, ncols);
#endif // RDNA
#else
#ifdef FAST_FP16_AVAILABLE
    return ggml_cuda_fattn_tile_get_config_nvidia_fp16(DKQ, DV, ncols);
#else
    return ggml_cuda_fattn_tile_get_config_nvidia_fp32(DKQ, DV, ncols);
#endif // FAST_FP16_AVAILABLE
#endif // GGML_USE_HIP
}

static __host__ int ggml_cuda_fattn_tile_get_nthreads(const int DKQ, const int DV, const int ncols, const int cc) {
    return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols, cc) >> 0) & ((1 << 10) - 1);
}

static constexpr __device__ int ggml_cuda_fattn_tile_get_nthreads(const int DKQ, const int DV, const int ncols) {
    return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols) >> 0) & ((1 << 10) - 1);
}

static __host__ int ggml_cuda_fattn_tile_get_occupancy(const int DKQ, const int DV, const int ncols, const int cc) {
    return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols, cc) >> 10) & ((1 << 4) - 1);
}

static constexpr __device__ int ggml_cuda_fattn_tile_get_occupancy(const int DKQ, const int DV, const int ncols) {
    return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols) >> 10) & ((1 << 4) - 1);
}

static __host__ int ggml_cuda_fattn_tile_get_nbatch_fa(const int DKQ, const int DV, const int ncols, const int cc) {
    return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols, cc) >> 14) & ((1 << 9) - 1);
}

static constexpr __device__ int ggml_cuda_fattn_tile_get_nbatch_fa(const int DKQ, const int DV, const int ncols) {
    return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols) >> 14) & ((1 << 9) - 1);
}

static __host__ int ggml_cuda_fattn_tile_get_nbatch_K(const int DKQ, const int DV, const int ncols, const int cc) {
    return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols, cc) >> 23) & ((1 << 9) - 1);
}

static constexpr __device__ int ggml_cuda_fattn_tile_get_nbatch_K(const int DKQ, const int DV, const int ncols) {
    return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols) >> 23) & ((1 << 9) - 1);
}
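
// Illustrative sanity check of the packing above: the NVIDIA FP16 config for DKQ == DV == 128 and
// ncols == 32 is (nthreads = 256, occupancy = 2, nbatch_fa = 64, nbatch_K = 64), so for example:
static_assert(((ggml_cuda_fattn_tile_get_config_nvidia_fp16(128, 128, 32) >>  0) & ((1 << 10) - 1)) == 256, "bad config packing");
static_assert(((ggml_cuda_fattn_tile_get_config_nvidia_fp16(128, 128, 32) >> 23) & ((1 <<  9) - 1)) ==  64, "bad config packing");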

// TODO: deduplicate with mma-f16
template<int warp_size, int nwarps, int I, int J, int J_padding, bool oob_check>
static __device__ __forceinline__ void flash_attn_tile_load_tile(
        const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int stride_KV, const int i_sup) {
    constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
    constexpr int cpy_ne = cpy_nb / 4;

    auto load = [&] __device__ (const int n) {
        const int stride_j = warp_size >> n;

        if (stride_j == 0) {
            return;
        }

        const int j0_start = stride_j == warp_size ? 0 : ((J/2)/cpy_ne) - ((J/2)/cpy_ne) % (2*stride_j);
        const int j0_stop  =                             ((J/2)/cpy_ne) - ((J/2)/cpy_ne) % (1*stride_j);
        const int stride_i = warp_size / stride_j;

        if (j0_start == j0_stop) {
            return;
        }

#pragma unroll
        for (int i0 = 0; i0 < I; i0 += nwarps*stride_i) {
            const int i = i0 + threadIdx.y*stride_i + (stride_j == warp_size ? 0 : threadIdx.x / stride_j);

            if (i0 + nwarps*stride_i <= I || i < I) {
#pragma unroll
                for (int j0 = j0_start; j0 < j0_stop; j0 += stride_j) {
                    const int j = j0*cpy_ne + (stride_j == warp_size ? threadIdx.x : threadIdx.x % stride_j)*cpy_ne;

                    const __align__(16) half2 zero[cpy_ne] = {{0.0f, 0.0f}};
                    ggml_cuda_memcpy_1<cpy_nb>(
                        tile_KV + i*(J/2 + J_padding) + j,
                        !oob_check || i < i_sup ? KV + i*stride_KV + j : zero);
                }
            }
        }
    };
    // 1: max 64*16=1024 bytes, 512 half
    // 2: max 32*16= 512 bytes, 256 half
    // 3: max 16*16= 256 bytes, 128 half
    // 4: max  8*16= 128 bytes,  64 half
    // 5: max  4*16=  64 bytes,  32 half
    // 6: max  2*16=  32 bytes,  16 half
    // 7: max  1*16=  16 bytes,   8 half
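    // Worked example for J = 72 (cpy_ne = 4, warp_size = 32): a row holds (J/2)/cpy_ne = 9 chunks
    // of 4 half2. The stride_j = 8 level covers chunks 0..7 (8 lanes per row, 4 rows per warp) and
    // the stride_j = 1 level covers the remaining chunk 8 (1 lane per row, 32 rows per warp); all
    // other levels have j0_start == j0_stop and are skipped.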
    static_assert(J % 8 == 0, "bad J");
    static_assert((J/2) % cpy_ne == 0, "bad J");
    ggml_cuda_unroll<7>{}(load);
}

template<int warp_size, int nwarps, int I, int J, int J_padding, bool oob_check>
static __device__ __forceinline__ void flash_attn_tile_load_tile(
        const half2 * const __restrict__ KV, float * const __restrict__ tile_KV, const int stride_KV, const int i_sup) {
    constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
    constexpr int cpy_ne = cpy_nb / 4;

    auto load = [&] __device__ (const int n) {
        const int stride_j = warp_size >> n;

        if (stride_j == 0) {
            return;
        }

        const int j0_start = stride_j == warp_size ? 0 : (J/cpy_ne) - (J/cpy_ne) % (2*stride_j);
        const int j0_stop  =                             (J/cpy_ne) - (J/cpy_ne) % (1*stride_j);
        const int stride_i = warp_size / stride_j;

        if (j0_start == j0_stop) {
            return;
        }

#pragma unroll
        for (int i0 = 0; i0 < I; i0 += nwarps*stride_i) {
            const int i = i0 + threadIdx.y*stride_i + (stride_j == warp_size ? 0 : threadIdx.x / stride_j);

            if (i0 + nwarps*stride_i <= I || i < I) {
#pragma unroll
                for (int j0 = j0_start; j0 < j0_stop; j0 += stride_j) {
                    const int j = j0*(cpy_ne/2) + (stride_j == warp_size ? threadIdx.x : threadIdx.x % stride_j)*(cpy_ne/2);

                    const half2 zero[cpy_ne/2] = {{0.0f, 0.0f}};
                    __align__(16) half2 tmp_h2[cpy_ne/2];
                    ggml_cuda_memcpy_1<sizeof(tmp_h2)>(
                        tmp_h2, !oob_check || i < i_sup ? KV + i*stride_KV + j : zero);

                    __align__(16) float2 tmp_f2[cpy_ne/2];
#pragma unroll
                    for (int l = 0; l < cpy_ne/2; ++l) {
                        tmp_f2[l] = __half22float2(tmp_h2[l]);
                    }
                    ggml_cuda_memcpy_1<sizeof(tmp_f2)>(tile_KV + i*(J + J_padding) + 2*j, tmp_f2);
                }
            }
        }
    };
    // 1: max 32*16=512 bytes, 128 float
    // 2: max 16*16=256 bytes,  64 float
    // 3: max  8*16=128 bytes,  32 float
    // 4: max  4*16= 64 bytes,  16 float
    // 5: max  2*16= 32 bytes,   8 float
    static_assert(J % 8 == 0, "bad J");
    static_assert(J % cpy_ne == 0, "bad J");
    ggml_cuda_unroll<5>{}(load);
}

// Function that performs a single iteration of the KQ matrix multiplication:
template <int warp_size, int nwarps, int ncols1, int ncols2, int DKQ, int nbatch_fa, int nbatch_K,
    bool use_logit_softcap, bool oob_check, typename T_vec_dot>
static __device__ __forceinline__ void flash_attn_tile_iter_KQ(
        T_vec_dot   * const Q_tmp,
        const half2 * const __restrict__ K_h2,
        T_vec_dot   * const KV_tmp,
        const int stride_K2,
        const int k_VKQ_0,
        const int k_VKQ_sup,
        const int k_KQ_0,
        float * KQ_acc) {
    constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
    constexpr int cpy_ne = cpy_nb / 4;

    constexpr int ncols = ncols1*ncols2;
    constexpr int cpw   = ncols > nwarps ? ncols/nwarps : 1; // Q columns per warp
    constexpr int np    = nwarps > ncols ? nwarps/ncols : 1; // number of parallel warps per Q column

    flash_attn_tile_load_tile<warp_size, nwarps, nbatch_fa, nbatch_K, cpy_ne, oob_check>
        (K_h2 + int64_t(k_VKQ_0)*stride_K2 + k_KQ_0/2, KV_tmp, stride_K2, k_VKQ_sup);
    __syncthreads();

#ifdef FAST_FP16_AVAILABLE
    static_assert((nbatch_K/2) % cpy_ne == 0, "bad nbatch_K");
#pragma unroll
    for (int k_KQ_1 = 0; k_KQ_1 < nbatch_K/2; k_KQ_1 += cpy_ne) {
        __align__(16) half2 K_k[nbatch_fa/(np*warp_size)][cpy_ne];
        __align__(16) half2 Q_k[cpw][cpy_ne];
#else
    static_assert(nbatch_K % cpy_ne == 0, "bad nbatch_K");
#pragma unroll
    for (int k_KQ_1 = 0; k_KQ_1 < nbatch_K; k_KQ_1 += cpy_ne) {
        __align__(16) float K_k[nbatch_fa/(np*warp_size)][cpy_ne];
        __align__(16) float Q_k[cpw][cpy_ne];
#endif // FAST_FP16_AVAILABLE

#pragma unroll
        for (int i_KQ_0 = 0; i_KQ_0 < nbatch_fa; i_KQ_0 += np*warp_size) {
            const int i_KQ = i_KQ_0 + (threadIdx.y % np)*warp_size + threadIdx.x;

#ifdef FAST_FP16_AVAILABLE
            ggml_cuda_memcpy_1<cpy_nb>(&K_k[i_KQ_0/(np*warp_size)], &KV_tmp[i_KQ*(nbatch_K/2 + cpy_ne) + k_KQ_1]);
#else
            ggml_cuda_memcpy_1<cpy_nb>(&K_k[i_KQ_0/(np*warp_size)], &KV_tmp[i_KQ*(nbatch_K   + cpy_ne) + k_KQ_1]);
#endif // FAST_FP16_AVAILABLE
        }
#pragma unroll
        for (int jc0 = 0; jc0 < cpw; ++jc0) {
            const int jc = jc0 + (threadIdx.y / np)*cpw;

#ifdef FAST_FP16_AVAILABLE
            ggml_cuda_memcpy_1<cpy_nb>(&Q_k[jc0], &Q_tmp[jc*(DKQ/2) + k_KQ_0/2 + k_KQ_1]);
#else
            ggml_cuda_memcpy_1<cpy_nb>(&Q_k[jc0], &Q_tmp[jc* DKQ    + k_KQ_0   + k_KQ_1]);
#endif // FAST_FP16_AVAILABLE
        }

#pragma unroll
        for (int i_KQ_0 = 0; i_KQ_0 < nbatch_fa; i_KQ_0 += np*warp_size) {
#pragma unroll
            for (int jc0 = 0; jc0 < cpw; ++jc0) {
#pragma unroll
                for (int k = 0; k < cpy_ne; ++k) {
                    ggml_cuda_mad(KQ_acc[i_KQ_0/(np*warp_size)*cpw + jc0], K_k[i_KQ_0/(np*warp_size)][k], Q_k[jc0][k]);
                }
            }
        }
    }

    if (k_KQ_0 + nbatch_K < DKQ) {
        __syncthreads(); // Sync not needed on last iteration.
    }
}

// Function that performs a single iteration of the main loop over up to nbatch_fa tokens.
template <int warp_size, int nwarps, int ncols1, int ncols2, int DKQ, int DV, int nbatch_fa, int nbatch_K,
    bool use_logit_softcap, bool oob_check, typename T_vec_dot, typename T_KQ, typename T_acc>
static __device__ __forceinline__ void flash_attn_tile_iter(
        T_vec_dot * const Q_tmp,
        const half2 * const __restrict__ K_h2,
        const half2 * const __restrict__ V_h2,
        const half  * const __restrict__ mask,
        const uint3 ne01,
        const float logit_softcap,
        const float slope,
        T_KQ      * const KQ,
        T_vec_dot * const KV_tmp,
        const int stride_K2,
        const int stride_V2,
        const int stride_mask,
        float * const KQ_max,
        float * const KQ_sum,
        T_acc * const VKQ,
        const int k_VKQ_0,
        const int k_VKQ_max,
        const int col_Q_0) {
    constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
    constexpr int cpy_ne = cpy_nb / 4;

    constexpr int ncols = ncols1*ncols2;
    constexpr int cpw   = ncols > nwarps ? ncols/nwarps : 1; // Q columns per warp
    constexpr int np    = nwarps > ncols ? nwarps/ncols : 1; // number of parallel warps per Q column

    constexpr int DVp = (DV + 2*warp_size - 1) & ~(2*warp_size - 1); // DV padded to multiple of 2*warp_size.

    // KQ_cs == KQ chunk size, number of KQ values in j direction to store as one contiguous chunk in memory.
    // KQ is originally 2D but uses a Z-shaped 3D memory pattern like KQ[ncols/KQ_cs][nbatch_fa][KQ_cs].
#ifdef FAST_FP16_AVAILABLE
    constexpr int KQ_cs = cpw < 2*cpy_ne ? cpw : 2*cpy_ne;
#else
    constexpr int KQ_cs = cpw < 1*cpy_ne ? cpw : 1*cpy_ne;
#endif // FAST_FP16_AVAILABLE
    static_assert(cpw % KQ_cs == 0, "bad KQ_cs");
    const int k_VKQ_sup = k_VKQ_max - k_VKQ_0; // k supremum, only smaller k values have valid KV data

    float KQ_max_new[cpw];
#pragma unroll
    for (int jc0 = 0; jc0 < cpw; ++jc0) {
        KQ_max_new[jc0] = KQ_max[jc0];
    }

    float KQ_acc[nbatch_fa/(np*warp_size) * cpw] = {0.0f}; // Accumulators for KQ matrix multiplication.

    // KQ = K @ Q matrix multiplication:
    constexpr int nbatch_K_last = DKQ % nbatch_K;
#pragma unroll
    for (int k_KQ_0 = 0; k_KQ_0 < DKQ - nbatch_K_last; k_KQ_0 += nbatch_K) {
        flash_attn_tile_iter_KQ<warp_size, nwarps, ncols1, ncols2, DKQ, nbatch_fa, nbatch_K, use_logit_softcap, oob_check>(
            Q_tmp, K_h2, KV_tmp, stride_K2, k_VKQ_0, k_VKQ_sup, k_KQ_0, KQ_acc);
    }
    if (nbatch_K_last > 0) {
        constexpr int k_KQ_0 = DKQ - nbatch_K_last;
        flash_attn_tile_iter_KQ<warp_size, nwarps, ncols1, ncols2, DKQ, nbatch_fa, nbatch_K_last, use_logit_softcap, oob_check>(
            Q_tmp, K_h2, KV_tmp, stride_K2, k_VKQ_0, k_VKQ_sup, k_KQ_0, KQ_acc);
    }

    // Apply logit softcap + mask, update KQ_max:
#pragma unroll
    for (int jc0 = 0; jc0 < cpw; ++jc0) {
        const int j = fastmodulo(col_Q_0 + (jc0 + (threadIdx.y / np)*cpw)/ncols2, ne01);

#pragma unroll
        for (int i_KQ_0 = 0; i_KQ_0 < nbatch_fa; i_KQ_0 += np*warp_size) {
            const int i_KQ = i_KQ_0 + (threadIdx.y % np)*warp_size + threadIdx.x;

#if defined(FAST_FP16_AVAILABLE) && !defined(V_DOT2_F32_F16_AVAILABLE)
            // Without the v_dot2_f32_f16 instruction there is a higher risk of numerical overflow in the KQ calculation.
            // Therefore, scale down Q values and apply the inverse scale to the FP32 KQ values afterwards.
            KQ_acc[i_KQ_0/(np*warp_size)*cpw + jc0] *= 4.0f;
#endif // defined(FAST_FP16_AVAILABLE) && !defined(V_DOT2_F32_F16_AVAILABLE)

            if (use_logit_softcap) {
                KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] = logit_softcap * tanhf(KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0]);
            }

            if (!oob_check || i_KQ < k_VKQ_sup) {
                KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] += (ncols2 > 1 || mask) ?
                    slope*__half2float(mask[j*stride_mask + k_VKQ_0 + i_KQ]) : 0.0f;

                KQ_max_new[jc0] = fmaxf(KQ_max_new[jc0], KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] + FATTN_KQ_MAX_OFFSET);
            }
        }

        KQ_max_new[jc0] = warp_reduce_max<warp_size>(KQ_max_new[jc0]);
    }

    if constexpr (np == 1) {
        __syncthreads();
    } else {
        static_assert(cpw == 1, "bad cpw");
        __shared__ float KQ_max_new_shared[nwarps];
        if (threadIdx.x == 0) {
            KQ_max_new_shared[threadIdx.y] = KQ_max_new[0];
        }
        __syncthreads();
        KQ_max_new[0] = KQ_max_new_shared[(threadIdx.y & ~(np-1)) + threadIdx.x % np];
        KQ_max_new[0] = warp_reduce_max<np>(KQ_max_new[0]);
    }

    // Calculate KQ softmax, write to shared KQ buffer, re-scale VKQ accumulators:
#pragma unroll
    for (int jc0 = 0; jc0 < cpw; jc0 += KQ_cs) {
#ifdef FAST_FP16_AVAILABLE
        __align__(16) half  tmp[nbatch_fa/(np*warp_size)][KQ_cs];
#else
        __align__(16) float tmp[nbatch_fa/(np*warp_size)][KQ_cs];
#endif // FAST_FP16_AVAILABLE

#pragma unroll
        for (int jc1 = 0; jc1 < KQ_cs; ++jc1) {
            const int jc = jc0 + jc1;

            const float KQ_max_scale = expf(KQ_max[jc] - KQ_max_new[jc]);
            KQ_max[jc] = KQ_max_new[jc];

            float KQ_sum_add = 0.0f;
#pragma unroll
            for (int i0 = 0; i0 < nbatch_fa; i0 += np*warp_size) {
                const float val = !oob_check || i0 + (threadIdx.y % np)*warp_size + threadIdx.x < static_cast<uint32_t>(k_VKQ_sup) ?
                    expf(KQ_acc[(i0/(np*warp_size))*cpw + jc] - KQ_max[jc]) : 0.0f;
                KQ_sum_add += val;
                tmp[i0/(np*warp_size)][jc1] = val;
            }
            KQ_sum[jc] = KQ_sum[jc]*KQ_max_scale + KQ_sum_add;

#ifdef FAST_FP16_AVAILABLE
            const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
#pragma unroll
            for (int i0 = 0; i0 < DVp/2; i0 += warp_size) {
                VKQ[jc*((DVp/2)/warp_size) + i0/warp_size] *= KQ_max_scale_h2;
            }
#else
#pragma unroll
            for (int i0 = 0; i0 < DVp/2; i0 += warp_size) {
                VKQ[jc*((DVp/2)/warp_size) + i0/warp_size].x *= KQ_max_scale;
                VKQ[jc*((DVp/2)/warp_size) + i0/warp_size].y *= KQ_max_scale;
            }
#endif // FAST_FP16_AVAILABLE
        }

#pragma unroll
        for (int i0 = 0; i0 < nbatch_fa; i0 += np*warp_size) {
            const int i = i0 + (threadIdx.y % np)*warp_size + threadIdx.x;

            ggml_cuda_memcpy_1<sizeof(tmp[0])>(
                KQ + (jc0/KQ_cs + (threadIdx.y / np)*(cpw/KQ_cs))*(nbatch_fa*KQ_cs) + i*KQ_cs,
                tmp[i0/(np*warp_size)]);
        }
    }

    // VKQ = V @ KQ matrix multiplication:
    static_assert(DV <= DKQ, "bad DV");
    static_assert(DV % nbatch_K == 0 || (nbatch_K % 3 == 0 && DV % (nbatch_K*2/3) == 0), "bad nbatch_K");
    constexpr int nbatch_V = (DV % nbatch_K == 0 ? nbatch_K : nbatch_K*2/3) * nbatch_fa / DV; // Number of V columns that fit in SRAM for K.
    static_assert(nbatch_fa % nbatch_V == 0, "bad nbatch_V");
    static_assert(nbatch_V % np == 0, "bad nbatch_V");
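    // For example, DV = 512 with nbatch_K = 64 and nbatch_fa = 64 gives nbatch_V = 64*64/512 = 8,
    // i.e. each V tile covers 8 of the nbatch_fa KV positions and is the same size as a K tile.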
#pragma unroll
    for (int k0 = 0; k0 < nbatch_fa; k0 += nbatch_V) {
        flash_attn_tile_load_tile<warp_size, nwarps, nbatch_V, DV, 0, oob_check>
            (V_h2 + int64_t(k_VKQ_0 + k0)*stride_V2, KV_tmp, stride_V2, k_VKQ_sup - k0);
        __syncthreads();

#ifdef FAST_FP16_AVAILABLE
#pragma unroll
        for (int k1 = 0; k1 < nbatch_V; k1 += np) {
            __align__(16) half2 V_k[(DVp/2)/warp_size];
            __align__(16) half2 KQ_k[cpw];

            constexpr int cpy_ne_D = cpy_ne/2 < (DVp/2)/warp_size ? cpy_ne/2 : (DVp/2)/warp_size;
#pragma unroll
            for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) {
                ggml_cuda_memcpy_1<cpy_ne_D*4>(&V_k[i0/warp_size], &KV_tmp[(k1 + threadIdx.y % np)*(DV/2) + i0 + threadIdx.x*cpy_ne_D]);
            }
#pragma unroll
            for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; jc_VKQ_0 += KQ_cs) {
                const int jc_KQ = jc_VKQ_0/KQ_cs + (threadIdx.y / np)*(cpw/KQ_cs);

                __align__(16) half tmp[KQ_cs];
                ggml_cuda_memcpy_1<KQ_cs*sizeof(half)>(
                    &tmp, KQ + jc_KQ*(nbatch_fa*KQ_cs) + (k0 + k1 + threadIdx.y % np)*KQ_cs);
#pragma unroll
                for (int jc_VKQ_1 = 0; jc_VKQ_1 < KQ_cs; ++jc_VKQ_1) {
                    KQ_k[jc_VKQ_0+jc_VKQ_1] = __half2half2(tmp[jc_VKQ_1]);
                }
            }

#pragma unroll
            for (int i0 = 0; i0 < DVp/2; i0 += warp_size) {
#pragma unroll
                for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; ++jc_VKQ_0) {
                    VKQ[jc_VKQ_0*((DVp/2)/warp_size) + i0/warp_size] += V_k[i0/warp_size]*KQ_k[jc_VKQ_0];
                }
            }
        }
#else
#pragma unroll
        for (int k1 = 0; k1 < nbatch_V; k1 += np) {
            __align__(16) float2 V_k[(DVp/2)/warp_size];
            __align__(16) float  KQ_k[cpw];

            constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size;
#pragma unroll
            for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) {
                ggml_cuda_memcpy_1<cpy_ne_D*4>(&V_k[i0/(2*warp_size)], &KV_tmp[(k1 + threadIdx.y % np)*DV + i0 + threadIdx.x*cpy_ne_D]);
            }
#pragma unroll
            for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; jc_VKQ_0 += KQ_cs) {
                const int jc_KQ = jc_VKQ_0/KQ_cs + (threadIdx.y / np)*(cpw/KQ_cs);

                ggml_cuda_memcpy_1<KQ_cs*sizeof(float)>(
                    &KQ_k[jc_VKQ_0], KQ + jc_KQ*(nbatch_fa*KQ_cs) + (k0 + k1 + threadIdx.y % np)*KQ_cs);
            }

#pragma unroll
            for (int i0 = 0; i0 < DVp/2; i0 += warp_size) {
#pragma unroll
                for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; ++jc_VKQ_0) {
                    VKQ[jc_VKQ_0*((DVp/2)/warp_size) + i0/warp_size].x += V_k[i0/warp_size].x*KQ_k[jc_VKQ_0];
                    VKQ[jc_VKQ_0*((DVp/2)/warp_size) + i0/warp_size].y += V_k[i0/warp_size].y*KQ_k[jc_VKQ_0];
                }
            }
        }
#endif // FAST_FP16_AVAILABLE

        __syncthreads();
    }
}

template<int DKQ, int DV, int ncols1, int ncols2, bool use_logit_softcap> // DKQ == K/Q head size, DV == V head size
__launch_bounds__(ggml_cuda_fattn_tile_get_nthreads(DKQ, DV, ncols1*ncols2), ggml_cuda_fattn_tile_get_occupancy(DKQ, DV, ncols1*ncols2))
static __global__ void flash_attn_tile(
        const char * __restrict__ Q,
        const char * __restrict__ K,
        const char * __restrict__ V,
        const char * __restrict__ mask,
        const char * __restrict__ sinks,
        const int  * __restrict__ KV_max,
        float      * __restrict__ dst,
        float2     * __restrict__ dst_meta,
        const float scale,
        const float max_bias,
        const float m0,
        const float m1,
        const uint32_t n_head_log2,
        const float logit_softcap,
        const int32_t ne00, const uint3   ne01, const int32_t ne02, const int32_t ne03,
                            const int32_t nb01, const int32_t nb02, const int32_t nb03,
        const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
                            const int32_t nb11, const int32_t nb12, const int64_t nb13,
                            const int32_t nb21, const int32_t nb22, const int64_t nb23,
                            const int32_t ne31, const int32_t ne32, const int32_t ne33,
                            const int32_t nb31, const int32_t nb32, const int64_t nb33) {
#ifdef FLASH_ATTN_AVAILABLE

    // Skip unused kernel variants for faster compilation:

    if (
#ifdef GGML_USE_WMMA_FATTN
            (ncols2 != 1 && DV != 40 && DV != 72 && DV != 512) ||
#endif // GGML_USE_WMMA_FATTN
            (use_logit_softcap && !(DV == 128 || DV == 256))
    ) {
        GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
            max_bias, m0, m1, n_head_log2, logit_softcap,
            ne00, ne01, ne02, ne03,
                  nb01, nb02, nb03,
            ne10, ne11, ne12, ne13,
                  nb11, nb12, nb13,
                  nb21, nb22, nb23,
                  ne31, ne32, ne33,
                  nb31, nb32, nb33);
        NO_DEVICE_CODE;
        return;
    }

    static_assert(ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols1*ncols2) != 0, "kernel config not defined");

    constexpr int ncols     = ncols1*ncols2;
    constexpr int warp_size = 32;
    constexpr int nwarps    = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, ncols1*ncols2) / warp_size;
    constexpr int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, ncols1*ncols2);
    constexpr int nbatch_K  = ggml_cuda_fattn_tile_get_nbatch_K (DKQ, DV, ncols1*ncols2);

    // In this kernel Q, K, V are matrices while i, j, k are matrix indices.

    const int col_Q_0 = blockIdx.x * ncols1; // Index of the first Q column for this CUDA block to work on.

    const int sequence = blockIdx.z / (ne02/ncols2);
    const int head0 = blockIdx.z*ncols2 - sequence*ne02; // == (blockIdx.z % (ne02/ncols2)) * ncols2
    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
    const float * Q_f  = (const float *) (Q + nb03*sequence + nb02* head0);
    const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
    const half2 * V_h2 = (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); // K and V have same shape

    const half * maskh = mask ? (const half *) (mask + nb33*(sequence % ne33)) : nullptr;

    const int stride_K2   = nb11 / sizeof(half2);
    const int stride_V2   = nb21 / sizeof(half2);
    const int stride_mask = nb31 / sizeof(half);

    const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;

    constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
    constexpr int cpy_ne = cpy_nb / 4;

    constexpr int cpw = ncols > nwarps ? ncols/nwarps : 1; // Q columns per warp.
    constexpr int np  = nwarps > ncols ? nwarps/ncols : 1; // Number of parallel warps per Q column.
    static_assert(cpw == 1 || np == 1, "bad cpw / np");
    static_assert(nbatch_fa % (np*warp_size) == 0, "nbatch_fa % (np*warp_size) != 0");
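    // For example, nwarps = 8 with ncols = 32 gives cpw = 4, np = 1 (each warp handles 4 Q columns),
    // while nwarps = 8 with ncols = 4 gives cpw = 1, np = 2 (2 warps cooperate on each Q column).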

    constexpr int DKQp = (DKQ + 2*warp_size - 1) & ~(2*warp_size - 1); // DKQ padded to multiple of 2*warp_size.
    constexpr int DVp  = (DV  + 2*warp_size - 1) & ~(2*warp_size - 1); // DV  padded to multiple of 2*warp_size.

    // Q_tmp == SRAM buffer to hold Q data for the entire lifetime of the kernel.
    // KV_tmp == SRAM buffer to hold fragments of K/V data while iterating over ne11.
    //     KV_tmp is padded to avoid memory conflicts for K (cpy_ne) and OOB accesses for V (DVp-DV).
    // KQ == SRAM buffer to hold KQ fragments between KQ and VKQ matrix multiplications.
    // VKQ == Accumulators in registers for the final VKQ result.
#ifdef FAST_FP16_AVAILABLE
    __shared__ half2 Q_tmp[ncols * DKQ/2];
    __shared__ half2 KV_tmp[nbatch_fa * (nbatch_K/2 + cpy_ne) + DVp-DV];
    __shared__ half  KQ[ncols * nbatch_fa];
    __align__(16) half2 VKQ[cpw * ((DVp/2)/warp_size)] = {{0.0f, 0.0f}};
#else
    __shared__ float Q_tmp[ncols * DKQ];
    __shared__ float KV_tmp[nbatch_fa * (nbatch_K + cpy_ne) + DVp-DV];
    __shared__ float KQ[ncols * nbatch_fa];
    __align__(16) float2 VKQ[cpw * ((DVp/2)/warp_size)] = {{0.0f, 0.0f}};
#endif // FAST_FP16_AVAILABLE
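
    // Worked example of the SRAM footprint for the FP16 path with DKQ = DV = 128, ncols = 32,
    // nbatch_fa = 64, nbatch_K = 64, cpy_ne = 4 (so DVp = 128 and DVp-DV = 0):
    //   Q_tmp: 32*64 half2 = 8 KiB, KV_tmp: 64*(32+4) half2 = 9 KiB, KQ: 32*64 half = 4 KiB.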

    float KQ_max[cpw];
#pragma unroll
    for (int j0 = 0; j0 < ncols; j0 += nwarps) {
        KQ_max[j0/nwarps] = -FLT_MAX/2.0f;
    }
    float KQ_sum[cpw] = {0.0f};

    // Load Q data, convert to FP16 if fast:
#pragma unroll
    for (int jc0 = 0; jc0 < cpw; ++jc0) {
        const int jc = jc0 + (threadIdx.y / np)*cpw;

        const int j = jc / ncols2;
        const int c = jc % ncols2;

        constexpr int cpy_ne_D = cpy_ne < DKQp/warp_size ? cpy_ne : DKQp/warp_size;

#pragma unroll
        for (int i0 = 0; i0 < DKQp; i0 += np*warp_size*cpy_ne_D) {
            if (i0 + np*warp_size*cpy_ne_D <= DKQ || i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x*cpy_ne_D < DKQ) {
                __align__(16) float tmp_f[cpy_ne_D] = {0.0f};
                ggml_cuda_memcpy_1<sizeof(tmp_f)>
                    (tmp_f, &Q_f[c*(nb02/sizeof(float)) + fastmodulo(col_Q_0 + j, ne01)*(nb01/sizeof(float))
                                 + i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x*cpy_ne_D]);

#pragma unroll
                for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
                    tmp_f[i1] *= scale;
                }

#ifdef FAST_FP16_AVAILABLE
                __align__(16) half2 tmp_h2[cpy_ne_D/2];
#pragma unroll
                for (int i1 = 0; i1 < cpy_ne_D; i1 += 2) {
                    tmp_h2[i1/2] = make_half2(tmp_f[i1 + 0], tmp_f[i1 + 1]);
#if defined(FAST_FP16_AVAILABLE) && !defined(V_DOT2_F32_F16_AVAILABLE)
                    // Without the v_dot2_f32_f16 instruction there is a higher risk of numerical overflow in the KQ calculation.
                    // Therefore, scale down Q values and apply the inverse scale to the FP32 KQ values afterwards.
                    tmp_h2[i1/2] *= make_half2(0.25f, 0.25f);
#endif // defined(FAST_FP16_AVAILABLE) && !defined(V_DOT2_F32_F16_AVAILABLE)
                }
                ggml_cuda_memcpy_1<sizeof(tmp_h2)>(
                    &Q_tmp[jc*(DKQ/2) + i0/2 + (threadIdx.y % np)*(warp_size*cpy_ne_D/2) + threadIdx.x*(cpy_ne_D/2)],
                    tmp_h2);
#else
                ggml_cuda_memcpy_1<sizeof(tmp_f)>(
                    &Q_tmp[jc* DKQ    + i0   + (threadIdx.y % np)*(warp_size*cpy_ne_D)   + threadIdx.x* cpy_ne_D],
                    tmp_f);
#endif // FAST_FP16_AVAILABLE
            }
        }
    }

    __syncthreads();

    // Main loop over KV cache:
    const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
    if (ncols2 == 1) {
        // Branch with out-of-bounds checks.
        int k_VKQ_0 = blockIdx.y*nbatch_fa;
        while (k_VKQ_0 < k_VKQ_max - nbatch_fa) {
            constexpr bool oob_check = false;
            flash_attn_tile_iter<warp_size, nwarps, ncols1, ncols2, DKQ, DV, nbatch_fa, nbatch_K, use_logit_softcap, oob_check>
                (Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp,
                stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0);
            k_VKQ_0 += gridDim.y*nbatch_fa;
        }
        if (k_VKQ_0 < k_VKQ_max) {
            constexpr bool oob_check = true;
            flash_attn_tile_iter<warp_size, nwarps, ncols1, ncols2, DKQ, DV, nbatch_fa, nbatch_K, use_logit_softcap, oob_check>
                (Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp,
                stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0);
        }
    } else {
        // Branch without out-of-bounds checks.
        for (int k_VKQ_0 = blockIdx.y*nbatch_fa; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*nbatch_fa) {
            constexpr bool oob_check = false;
            flash_attn_tile_iter<warp_size, nwarps, ncols1, ncols2, DKQ, DV, nbatch_fa, nbatch_K, use_logit_softcap, oob_check>
                (Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp,
                stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0);
        }
    }

#pragma unroll
    for (int jc0 = 0; jc0 < cpw; ++jc0) {
        KQ_sum[jc0] = warp_reduce_sum<warp_size>(KQ_sum[jc0]);
    }

    if constexpr (np > 1) {
        static_assert(cpw == 1, "bad cpw");
        static_assert(nbatch_fa*nbatch_K >= nwarps*DVp, "KV_tmp too small");

#ifdef FAST_FP16_AVAILABLE
        half2 * VKQ_combine    = (half2 *) KV_tmp;
#else
        float * VKQ_combine    = (float *) KV_tmp;
#endif // FAST_FP16_AVAILABLE
        float * KQ_sum_combine = (float *) Q_tmp;

        if (threadIdx.y % np != 0) {
#ifdef FAST_FP16_AVAILABLE
            constexpr int cpy_ne_D = cpy_ne < (DVp/2)/warp_size ? cpy_ne : (DVp/2)/warp_size;
#pragma unroll
            for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) {
                ggml_cuda_memcpy_1<cpy_ne_D*4>(&VKQ_combine[threadIdx.y*(DVp/2) + i0 + threadIdx.x*cpy_ne_D], &VKQ[i0/warp_size]);
            }
#else
            constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size;
#pragma unroll
            for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) {
                ggml_cuda_memcpy_1<cpy_ne_D*4>(
                    &VKQ_combine[threadIdx.y*DVp + i0 + threadIdx.x*cpy_ne_D], ((const float *) VKQ) + i0/warp_size);
            }
#endif // FAST_FP16_AVAILABLE

            if (threadIdx.x == 0) {
                KQ_sum_combine[threadIdx.y] = KQ_sum[0];
            }

            return;
        }

        __syncthreads();

#pragma unroll
        for (int ip = 1; ip < np; ++ip) {
#ifdef FAST_FP16_AVAILABLE
            constexpr int cpy_ne_D = cpy_ne < (DVp/2)/warp_size ? cpy_ne : (DVp/2)/warp_size;
#pragma unroll
            for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) {
                __align__(16) half2 tmp[cpy_ne_D];
                ggml_cuda_memcpy_1<cpy_ne_D*4>(tmp, &VKQ_combine[(threadIdx.y + ip)*(DVp/2) + i0 + threadIdx.x*cpy_ne_D]);
#pragma unroll
                for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
                    VKQ[i0/warp_size + i1] += tmp[i1];
                }
            }
#else
            constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size;
#pragma unroll
            for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) {
                __align__(16) float tmp[cpy_ne_D];
                ggml_cuda_memcpy_1<cpy_ne_D*4>(tmp, &VKQ_combine[(threadIdx.y + ip)*DVp + i0 + threadIdx.x*cpy_ne_D]);
#pragma unroll
                for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
                    ((float *)VKQ)[i0/warp_size + i1] += tmp[i1];
                }
            }
#endif // FAST_FP16_AVAILABLE

            KQ_sum[0] += KQ_sum_combine[threadIdx.y + ip];
        }
    }

    // Attention sink: adjust KQ max and sum only for the first of all parallel blocks:
    if (sinks && blockIdx.y == 0) {
#pragma unroll
        for (int jc0 = 0; jc0 < cpw; ++jc0) {
            const int jc = jc0 + (threadIdx.y/np)*cpw;
            const float sink = ((const float *) sinks)[head0 + jc % ncols2];

            float KQ_max_new_j = fmaxf(KQ_max[jc0], sink);
            const float KQ_max_scale = expf(KQ_max[jc0] - KQ_max_new_j);
            KQ_max[jc0] = KQ_max_new_j;

            const float val = expf(sink - KQ_max[jc0]);
            KQ_sum[jc0] = KQ_sum[jc0]*KQ_max_scale + val;

#ifdef FAST_FP16_AVAILABLE
            const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
#pragma unroll
            for (int i0 = 0; i0 < DVp/2; i0 += warp_size) {
                VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size] *= KQ_max_scale_h2;
            }
#else
#pragma unroll
            for (int i0 = 0; i0 < DVp/2; i0 += warp_size) {
                VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size].x *= KQ_max_scale;
                VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size].y *= KQ_max_scale;
            }
#endif // FAST_FP16_AVAILABLE
        }
    }

    // Write back results:
#pragma unroll
    for (int jc0 = 0; jc0 < cpw; ++jc0) {
        const int jc = jc0 + (threadIdx.y/np)*cpw;

        const int j = jc / ncols2;
        const int c = jc % ncols2;

        if (ncols1 > 1 && col_Q_0 + j >= int(ne01.z)) {
            return;
        }

        const float scale = gridDim.y == 1 ? 1.0f/KQ_sum[jc0] : 1.0f;

        const int j_dst_unrolled = ((sequence*int(ne01.z) + col_Q_0 + j)*ne02 + head0 + c)*gridDim.y + blockIdx.y;

#ifdef FAST_FP16_AVAILABLE
        constexpr int cpy_ne_D = cpy_ne/2 < (DVp/2)/warp_size ? cpy_ne/2 : (DVp/2)/warp_size;
#pragma unroll
        for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) {
            __align__(16) float2 tmp[cpy_ne_D];
#pragma unroll
            for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
                tmp[i1] = __half22float2(VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size + i1]);
                tmp[i1].x *= scale;
                tmp[i1].y *= scale;
            }
            if (i0 + warp_size*cpy_ne_D <= DV/2 || i0 + threadIdx.x*cpy_ne_D < DV/2) {
                ggml_cuda_memcpy_1<sizeof(tmp)>(&dst[j_dst_unrolled*DV + 2*i0 + threadIdx.x*(2*cpy_ne_D)], tmp);
            }
        }
#else
        constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size;
#pragma unroll
        for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) {
            if (i0 + warp_size*cpy_ne_D <= DV || i0 + threadIdx.x*cpy_ne_D < DV) {
#pragma unroll
                for (int i1 = 0; i1 < cpy_ne_D/2; ++i1) {
                    VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size) + i1].x *= scale;
                    VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size) + i1].y *= scale;
                }
                ggml_cuda_memcpy_1<cpy_ne_D*4>(
                    &dst[j_dst_unrolled*DV + i0 + threadIdx.x*cpy_ne_D],
                    &VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size)]);
            }
        }
#endif // FAST_FP16_AVAILABLE

        if (gridDim.y != 1 && threadIdx.x == 0) {
            dst_meta[j_dst_unrolled] = make_float2(KQ_max[jc0], KQ_sum[jc0]);
        }
    }
#else
    GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
        max_bias, m0, m1, n_head_log2, logit_softcap,
        ne00, ne01, ne02, ne03,
              nb01, nb02, nb03,
        ne10, ne11, ne12, ne13,
              nb11, nb12, nb13,
              nb21, nb22, nb23,
              ne31, ne32, ne33,
              nb31, nb32, nb33);
    NO_DEVICE_CODE;
#endif // FLASH_ATTN_AVAILABLE
}

template <int DKQ, int DV, int ncols2, bool use_logit_softcap>
static void launch_fattn_tile_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * Q = dst->src[0];

    const int id        = ggml_cuda_get_device();
    const int cc        = ggml_cuda_info().devices[id].cc;
    const int warp_size = 32;

    constexpr size_t nbytes_shared = 0;

#ifdef GGML_USE_HIP
    if constexpr (DV <= 128) {
        if (Q->ne[1] > 32/ncols2) {
            constexpr int cols_per_block = 64;
            const int nwarps    = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
            const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
            fattn_kernel_t fattn_kernel = flash_attn_tile<DKQ, DV, cols_per_block/ncols2, ncols2, use_logit_softcap>;
            launch_fattn<DV, cols_per_block/ncols2, ncols2>
                (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size);
            return;
        }
    }
#endif // GGML_USE_HIP

#ifndef GGML_USE_HIP
    if constexpr (DV <= 256)
#endif // GGML_USE_HIP
    {
        if (Q->ne[1] > 16/ncols2) {
            constexpr int cols_per_block = 32;
            const int nwarps    = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
            const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
            fattn_kernel_t fattn_kernel = flash_attn_tile<DKQ, DV, cols_per_block/ncols2, ncols2, use_logit_softcap>;
            launch_fattn<DV, cols_per_block/ncols2, ncols2>
                (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size);
            return;
        }
    }

    if (Q->ne[1] > 8/ncols2) {
        constexpr int cols_per_block = 16;
        const int nwarps    = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
        const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
        fattn_kernel_t fattn_kernel = flash_attn_tile<DKQ, DV, cols_per_block/ncols2, ncols2, use_logit_softcap>;
        launch_fattn<DV, cols_per_block/ncols2, ncols2>
            (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size);
        return;
    }

    if constexpr (ncols2 <= 8) {
        if (Q->ne[1] > 4/ncols2) {
            constexpr int cols_per_block = 8;
            const int nwarps    = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
            const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
            fattn_kernel_t fattn_kernel = flash_attn_tile<DKQ, DV, cols_per_block/ncols2, ncols2, use_logit_softcap>;
            launch_fattn<DV, cols_per_block/ncols2, ncols2>
                (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size);
            return;
        }
    }

    if constexpr (ncols2 <= 4) {
        if (Q->ne[1] > 2/ncols2) {
            constexpr int cols_per_block = 4;
            const int nwarps    = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
            const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
            fattn_kernel_t fattn_kernel = flash_attn_tile<DKQ, DV, cols_per_block/ncols2, ncols2, use_logit_softcap>;
            launch_fattn<DV, cols_per_block/ncols2, ncols2>
                (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size);
            return;
        }
    }

    if constexpr (ncols2 <= 2) {
        constexpr int cols_per_block = 2;
        const int nwarps    = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
        const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
        fattn_kernel_t fattn_kernel = flash_attn_tile<DKQ, DV, cols_per_block/ncols2, ncols2, use_logit_softcap>;
        launch_fattn<DV, cols_per_block/ncols2, ncols2>
            (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size);
        return;
    }

    GGML_ABORT("fatal error");
}

template <int DKQ, int DV, bool use_logit_softcap>
static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * KQV  = dst;
    const ggml_tensor * Q    = dst->src[0];
    const ggml_tensor * K    = dst->src[1];
    const ggml_tensor * mask = dst->src[3];

    float max_bias = 0.0f;
    memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));

    GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
    const int gqa_ratio = Q->ne[2] / K->ne[2];

    const bool nvidia = GGML_CUDA_CC_IS_NVIDIA(ggml_cuda_info().devices[ggml_cuda_get_device()].cc);
    const int gqa_limit = nvidia && gqa_ratio <= 4 ? 16 : INT_MAX;
    const bool use_gqa_opt = mask && max_bias == 0.0f && Q->ne[1] <= gqa_limit && K->ne[1] % FATTN_KQ_STRIDE == 0;
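    // Example: with 32 Q heads and 8 KV heads gqa_ratio == 4, so if use_gqa_opt holds and DV <= 256
    // the gqa_ratio % 4 == 0 branch below is taken and 4 Q heads share each K/V head per kernel tile.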

    if constexpr (DV == 512) {
        if (use_gqa_opt && gqa_ratio % 16 == 0) {
            launch_fattn_tile_switch_ncols1<DKQ, DV, 16, use_logit_softcap>(ctx, dst);
            return;
        }
        if (use_gqa_opt && gqa_ratio % 4 == 0) {
            launch_fattn_tile_switch_ncols1<DKQ, DV, 4, use_logit_softcap>(ctx, dst);
            return;
        }
    }

    if constexpr (DV <= 256) {
        if (use_gqa_opt && gqa_ratio % 8 == 0) {
            launch_fattn_tile_switch_ncols1<DKQ, DV, 8, use_logit_softcap>(ctx, dst);
            return;
        }

        if (use_gqa_opt && gqa_ratio % 4 == 0) {
            launch_fattn_tile_switch_ncols1<DKQ, DV, 4, use_logit_softcap>(ctx, dst);
            return;
        }

        if (use_gqa_opt && gqa_ratio % 2 == 0) {
            launch_fattn_tile_switch_ncols1<DKQ, DV, 2, use_logit_softcap>(ctx, dst);
            return;
        }

        launch_fattn_tile_switch_ncols1<DKQ, DV, 1, use_logit_softcap>(ctx, dst);
        return;
    }
    GGML_ABORT("fatal error");
}

template <int DKQ, int DV>
void ggml_cuda_flash_attn_ext_tile_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * KQV = dst;

    float logit_softcap;
    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));

    if (logit_softcap == 0.0f) {
        constexpr bool use_logit_softcap = false;
        launch_fattn_tile_switch_ncols2<DKQ, DV, use_logit_softcap>(ctx, dst);
    } else {
        constexpr bool use_logit_softcap = true;
        launch_fattn_tile_switch_ncols2<DKQ, DV, use_logit_softcap>(ctx, dst);
    }
}

void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

#define DECL_FATTN_TILE_CASE(DKQ, DV)                             \
    template void ggml_cuda_flash_attn_ext_tile_case              \
    <DKQ, DV>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \

extern DECL_FATTN_TILE_CASE( 40,  40);
extern DECL_FATTN_TILE_CASE( 64,  64);
extern DECL_FATTN_TILE_CASE( 72,  72);
extern DECL_FATTN_TILE_CASE( 80,  80);
extern DECL_FATTN_TILE_CASE( 96,  96);
extern DECL_FATTN_TILE_CASE(112, 112);
extern DECL_FATTN_TILE_CASE(128, 128);
extern DECL_FATTN_TILE_CASE(256, 256);
extern DECL_FATTN_TILE_CASE(576, 512);
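
// Illustrative: one of the template instantiations declared above would be provided by a separate
// per-head-size compilation unit (exact file name/layout may differ), e.g.:
//     #include "fattn-tile.cuh" // i.e. this header
//     DECL_FATTN_TILE_CASE(128, 128);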