#include "common.cuh"
#include "cp-async.cuh"
#include "mma.cuh"
#include "fattn-common.cuh"

using namespace ggml_cuda_mma;

// Config options for the MMA kernel.
// Should not affect results, only speed/register pressure/shared memory use.
struct fattn_mma_config {
    int  nthreads;       // Number of threads per CUDA block.
    int  occupancy;      // Targeted occupancy for the MMA kernel.
    int  nbatch_fa;      // Number of KV rows per softmax rescaling of KQ rowsums and VKQ accumulators.
    int  nbatch_K2;      // Number of K half2 values in direction of DKQ to load in parallel.
    int  nbatch_V2;      // Number of V half2 values in direction of DV to load in parallel.
    int  nbatch_combine; // Number of VKQ half2 values in direction of DV to combine in parallel.
    int  nstages_target; // Number of pipeline stages to use ideally, 1 == always load data synchronously, 2 == preload data if there is hardware support.
    bool Q_in_reg;       // Whether the Q values should be kept permanently in registers.

    constexpr __host__ __device__ fattn_mma_config(
            int nthreads, int occupancy, int nbatch_fa, int nbatch_K2, int nbatch_V2, int nbatch_combine, int nstages_target, bool Q_in_reg) :
        nthreads(nthreads), occupancy(occupancy), nbatch_fa(nbatch_fa), nbatch_K2(nbatch_K2), nbatch_V2(nbatch_V2), nbatch_combine(nbatch_combine),
        nstages_target(nstages_target), Q_in_reg(Q_in_reg) {}
};
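
// For illustration: the Ampere entry
//     GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 32, 128, 2, 64, 64, 64, 64, 2, true)
// below describes the kernel variant for DKQ == DV == 128 with 32 Q columns per CUDA block:
// 128 threads, target occupancy 2, 64 KV rows per softmax rescale, K/V loaded and VKQ combined
// in chunks of 64 half2 values, 2 pipeline stages, and Q kept permanently in registers.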

#define GGML_CUDA_FATTN_MMA_CONFIG_CASE(DKQ_, DV_, ncols_, nthreads_, occupancy_, nbatch_fa_, nbatch_K2_, nbatch_V2_, nbatch_combine_, nstages_target_, Q_in_reg_) \
    if (DKQ == (DKQ_) && DV == (DV_) && ncols == (ncols_)) {                                                                                                       \
        static_assert((nthreads_)       % 32 == 0 && (nthreads_)       <= 512, "bad nthreads");                                                                    \
        static_assert(                               (occupancy_)      <=   8, "bad occupancy");                                                                   \
        static_assert((nbatch_fa_)      % 32 == 0 && (nbatch_fa_)      <= 256, "bad nbatch_fa");                                                                   \
        static_assert((nbatch_K2_)      %  4 == 0 && (nbatch_K2_)      <= 512, "bad nbatch_K2");                                                                   \
        static_assert((nbatch_V2_)      %  4 == 0 && (nbatch_V2_)      <= 256, "bad nbatch_V2");                                                                   \
        static_assert((nbatch_combine_) %  4 == 0 && (nbatch_combine_) <= 128, "bad nbatch_combine");                                                              \
        static_assert((nstages_target_)      >= 1 && (nstages_target_) <=   2, "bad nstages_target");                                                              \
        return fattn_mma_config{(nthreads_), (occupancy_), (nbatch_fa_), (nbatch_K2_), (nbatch_V2_), (nbatch_combine_), (nstages_target_), (Q_in_reg_)};           \
    }                                                                                                                                                              \

static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_ampere(const int DKQ, const int DV, const int ncols) {
    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64,  64,  8, 128, 2, 128,  32,  32,  32, 2, true);
    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64,  64, 16, 128, 2,  64,  32,  32,  32, 2, true);
    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64,  64, 32, 128, 2,  64,  32,  32,  32, 2, true);
    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64,  64, 64, 128, 2,  64,  32,  32,  32, 2, true);

    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80,  80,  8, 128, 2, 128,  40,  40,  40, 2, true);
    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80,  80, 16, 128, 2,  64,  40,  40,  40, 2, true);
    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80,  80, 32, 128, 2,  64,  40,  40,  40, 2, true);
    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80,  80, 64, 128, 2,  64,  40,  40,  40, 2, true);

    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96,  96,  8, 128, 2, 128,  48,  48,  48, 2, true);
    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96,  96, 16, 128, 2,  64,  48,  48,  48, 2, true);
    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96,  96, 32, 128, 2,  64,  48,  48,  48, 2, true);
    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96,  96, 64, 128, 2,  64,  48,  48,  48, 2, true);

    GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112,  8, 128, 2, 128,  56,  56,  56, 2, true);
    GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 16, 128, 2,  64,  56,  56,  56, 2, true);
    GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 32, 128, 2,  64,  56,  56,  56, 2, true);
    GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 64, 128, 2,  64,  56,  56,  56, 2, true);

    GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128,  8, 128, 2, 128,  64,  64,  64, 2, true);
    GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 16, 128, 2,  64,  64,  64,  64, 2, true);
    GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 32, 128, 2,  64,  64,  64,  64, 2, true);
    GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 64, 128, 2,  64,  64,  64,  64, 2, true);

    GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256,  8,  64, 4,  64, 128, 128, 128, 2, true);
    GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16,  64, 4,  32, 128, 128, 128, 2, true);
    GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2,  32, 128, 128, 128, 2, true);
    GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2,  32, 128, 128, 128, 2, true);

    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512,  8,  64, 4,  32, 288, 256, 128, 1, false);
    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16,  64, 4,  32, 288, 256, 128, 1, false);
    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2,  32, 160, 128, 128, 1, false);
    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1,  32, 160, 128, 128, 1, false);

    return fattn_mma_config(32, 1, 0, 0, 0, 0, 0, false);
}

static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_turing(const int DKQ, const int DV, const int ncols) {
    GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256,  8, 128, 2,  64, 128, 128, 128, 2, true);
    GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 128, 2,  64, 128, 128, 128, 2, true);
    GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2,  64, 128, 128,  64, 2, true);
    GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2,  64, 128, 128,  64, 2, true);

    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512,  8,  64, 4,  32,  96,  64, 128, 1, false);
    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16,  64, 4,  32,  96,  64, 128, 1, false);
    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2,  32, 160, 128, 128, 1, false);
    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1,  32, 160, 128, 128, 1, false);

    return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
}

static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_volta(const int DKQ, const int DV, const int ncols) {
    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512,  8,  64, 4,  32, 288, 256,  64, 1, false);
    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16,  64, 4,  32, 288, 256,  64, 1, false);
    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2,  32, 160, 128,  64, 1, false);
    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1,  32, 160, 128,  64, 1, false);

    // TODO tune specifically for Volta
    return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
}

static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_rdna(const int DKQ, const int DV, const int ncols) {
    GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 128, 2,  64, 128, 128, 128, 2, true);
    GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2,  64, 128, 128,  64, 2, true);
    GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2,  64, 128, 128,  64, 2, true);

    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16,  64, 4,  32,  96,  64, 128, 1, false);
    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2,  32, 160, 128, 128, 1, false);
    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1,  32, 160, 128, 128, 1, false);

    // TODO tune specifically for RDNA
    return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
}

static __host__ fattn_mma_config ggml_cuda_fattn_mma_get_config(const int DKQ, const int DV, const int ncols, const int cc) {
    if (ampere_mma_available(cc)) {
        return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
    }
    if (turing_mma_available(cc)) {
        return ggml_cuda_fattn_mma_get_config_turing(DKQ, DV, ncols);
    }
    if (amd_wmma_available(cc)) {
        return ggml_cuda_fattn_mma_get_config_rdna(DKQ, DV, ncols);
    }
    GGML_ASSERT(volta_mma_available(cc));
    return ggml_cuda_fattn_mma_get_config_volta(DKQ, DV, ncols);
}

static constexpr __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config(const int DKQ, const int DV, const int ncols) {
#if defined(AMPERE_MMA_AVAILABLE)
    return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
#elif defined(TURING_MMA_AVAILABLE)
    return ggml_cuda_fattn_mma_get_config_turing(DKQ, DV, ncols);
#elif defined(VOLTA_MMA_AVAILABLE)
    return ggml_cuda_fattn_mma_get_config_volta(DKQ, DV, ncols);
#elif defined(AMD_WMMA_AVAILABLE)
    return ggml_cuda_fattn_mma_get_config_rdna(DKQ, DV, ncols);
#else
    GGML_UNUSED_VARS(DKQ, DV, ncols);
    return fattn_mma_config(32, 1, 0, 0, 0, 0, 0, false);
#endif // defined(AMPERE_MMA_AVAILABLE)
}

static __host__ int ggml_cuda_fattn_mma_get_nthreads(const int DKQ, const int DV, const int ncols, const int cc) {
    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nthreads;
}

static constexpr __device__ int ggml_cuda_fattn_mma_get_nthreads(const int DKQ, const int DV, const int ncols) {
    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nthreads;
}

static __host__ int ggml_cuda_fattn_mma_get_occupancy(const int DKQ, const int DV, const int ncols, const int cc) {
    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).occupancy;
}

static constexpr __device__ int ggml_cuda_fattn_mma_get_occupancy(const int DKQ, const int DV, const int ncols) {
    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).occupancy;
}

static __host__ int ggml_cuda_fattn_mma_get_nbatch_fa(const int DKQ, const int DV, const int ncols, const int cc) {
    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nbatch_fa;
}

static constexpr __device__ int ggml_cuda_fattn_mma_get_nbatch_fa(const int DKQ, const int DV, const int ncols) {
    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nbatch_fa;
}

static __host__ int ggml_cuda_fattn_mma_get_nbatch_K2(const int DKQ, const int DV, const int ncols, const int cc) {
    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nbatch_K2;
}

static constexpr __device__ int ggml_cuda_fattn_mma_get_nbatch_K2(const int DKQ, const int DV, const int ncols) {
    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nbatch_K2;
}

static __host__ int ggml_cuda_fattn_mma_get_nbatch_V2(const int DKQ, const int DV, const int ncols, const int cc) {
    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nbatch_V2;
}

static constexpr __device__ int ggml_cuda_fattn_mma_get_nbatch_V2(const int DKQ, const int DV, const int ncols) {
    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nbatch_V2;
}

static __host__ int ggml_cuda_fattn_mma_get_nbatch_combine(const int DKQ, const int DV, const int ncols, const int cc) {
    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nbatch_combine;
}

static constexpr __device__ int ggml_cuda_fattn_mma_get_nbatch_combine(const int DKQ, const int DV, const int ncols) {
    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nbatch_combine;
}

static __host__ int ggml_cuda_fattn_mma_get_nstages_target(const int DKQ, const int DV, const int ncols, const int cc) {
    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nstages_target;
}

static constexpr __device__ int ggml_cuda_fattn_mma_get_nstages_target(const int DKQ, const int DV, const int ncols) {
    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nstages_target;
}

static __host__ bool ggml_cuda_fattn_mma_get_Q_in_reg(const int DKQ, const int DV, const int ncols, const int cc) {
    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).Q_in_reg;
}

static constexpr __device__ bool ggml_cuda_fattn_mma_get_Q_in_reg(const int DKQ, const int DV, const int ncols) {
    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).Q_in_reg;
}

static constexpr __device__ int get_cols_per_thread() {
#if defined(AMD_WMMA_AVAILABLE)
    return 1; // RDNA has a single column.
#else
    return 2; // This is specifically KQ columns, Volta only has a single VKQ column.
#endif // defined(AMD_WMMA_AVAILABLE)
}

static __host__ int get_cols_per_warp(const int cc) {
    if (turing_mma_available(cc) || amd_wmma_available(cc)) {
        return 16;
    } else {
        // Volta
        return 32;
    }
}

// ------------------------------------------------------------------------------------------------------------------

static __host__ int ggml_cuda_fattn_mma_get_nstages(const int DKQ, const int DV, const int ncols1, const int ncols2, const int cc) {
    return cp_async_available(cc) && ncols2 >= 2 ? ggml_cuda_fattn_mma_get_nstages_target(DKQ, DV, ncols1*ncols2, cc) : 0;
}

static constexpr __device__ int ggml_cuda_fattn_mma_get_nstages(const int DKQ, const int DV, const int ncols1, const int ncols2) {
#ifdef CP_ASYNC_AVAILABLE
    return ncols2 >= 2 ? ggml_cuda_fattn_mma_get_nstages_target(DKQ, DV, ncols1*ncols2) : 0;
#else
    GGML_UNUSED_VARS(DKQ, DV, ncols1, ncols2);
    return 0;
#endif // CP_ASYNC_AVAILABLE
}

// ------------------------------------------------------------------------------------------------------------------

template<int stride_tile, int nwarps, int nbatch_fa, bool use_cp_async, bool oob_check>
static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
        const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV, const int i_sup) {
    // K/V data is loaded with decreasing granularity for D for better memory bandwidth.
    // The minimum granularity with cp.async is 16 bytes, with synchronous data loading it's 4 bytes.
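    // Worked example (illustrative only, assuming the cp.async path and WARP_SIZE == 32): for D2 == 40 there are
    //     10 16-byte chunks per row. The stride_k == 8 pass covers chunks 0-7 (128 contiguous bytes per row,
    //     4 rows per warp and iteration), the stride_k == 2 pass covers chunks 8-9 (32 contiguous bytes per row,
    //     16 rows per warp and iteration), all other passes have k0_start == k0_stop and do nothing.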
    if constexpr (use_cp_async) {
        static_assert(!oob_check, "OOB check not compatible with cp_async");
        constexpr int preload = 64;
        constexpr int h2_per_chunk = 16/sizeof(half2);
        const int chunks_per_row = D2 / h2_per_chunk;

        const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV);

        auto load = [&] __device__ (auto n) {
            const int stride_k = WARP_SIZE >> n;
            const int k0_start = stride_k == WARP_SIZE ? 0 : chunks_per_row - chunks_per_row % (2*stride_k);
            const int k0_stop  =                             chunks_per_row - chunks_per_row % (1*stride_k);
            const int stride_i = WARP_SIZE / stride_k;

            if (k0_start == k0_stop) {
                return;
            }

#pragma unroll
            for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) {
                const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);

                if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) {
                    break;
                }

#pragma unroll
                for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
                    const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);

                    cp_async_cg_16<preload>(tile_KV_32 + i*(stride_tile*sizeof(half2)) + k*16, KV + i*stride_KV + k*h2_per_chunk);
                }
            }
        };
        // 1: max 32*16=512 bytes, 256 half
        // 2: max 16*16=256 bytes, 128 half
        // 3: max  8*16=128 bytes,  64 half
        // 4: max  4*16= 64 bytes,  32 half
        // 5: max  2*16= 32 bytes,  16 half
        // 6: max  1*16= 16 bytes,   8 half
        ggml_cuda_unroll<6>{}(load);
    } else {
        // TODO use ggml_cuda_memcpy_1
        auto load = [&] __device__ (const int n) {
            const int stride_k = WARP_SIZE >> n;
            const int k0_start = stride_k == WARP_SIZE ? 0 : D2 - D2 % (2*stride_k);
            const int k0_stop  =                             D2 - D2 % (1*stride_k);
            const int stride_i = WARP_SIZE / stride_k;

            if (k0_start == k0_stop) {
                return;
            }

#pragma unroll
            for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) {
                const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);

                if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) {
                    break;
                }

#pragma unroll
                for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
                    const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);

                    tile_KV[i*stride_tile + k] = !oob_check || i < i_sup ? KV[i*stride_KV + k] : make_half2(0.0f, 0.0f);
                }
            }
        };
        // 1: max 32* 4=128 bytes,  64 half
        // 2: max 16* 4= 64 bytes,  32 half
        // 3: max  8* 4= 32 bytes,  16 half
        // 4: max  4* 4= 16 bytes,   8 half
        ggml_cuda_unroll<4>{}(load);
    }
}

template<int ncols1, int nwarps, int nbatch_fa, bool use_cp_async, bool oob_check>
static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
        const half * const __restrict__ mask_h, half * const __restrict__ tile_mask,
        const int stride_mask, const int i_sup, const int j0, const uint3 ne01) {
    if constexpr (use_cp_async) {
        static_assert(nbatch_fa <= 8*WARP_SIZE && nbatch_fa % 8 == 0, "bad nbatch_fa");
        static_assert(!oob_check, "OOB check incompatible with cp_async");
        constexpr int preload = nbatch_fa >= 32 ? nbatch_fa * sizeof(half) : 64;
        constexpr int cols_per_warp = 8*WARP_SIZE/nbatch_fa;
        constexpr int stride_j = nwarps * cols_per_warp;

        const unsigned int tile_mask_32 = ggml_cuda_cvta_generic_to_shared(tile_mask);

#pragma unroll
        for (int j1 = 0; j1 < ncols1; j1 += stride_j) {
            const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (WARP_SIZE/cols_per_warp);
            const int j_vram = fastmodulo(j0 + j_sram, ne01);

            if (j1 + stride_j > ncols1 && j_sram >= ncols1) {
                break;
            }

            const int i = 8 * (threadIdx.x % (nbatch_fa/8));

            cp_async_cg_16<preload>(tile_mask_32 + j_sram*(nbatch_fa*sizeof(half) + 16) + i*sizeof(half), mask_h + j_vram*stride_mask + i);
        }
    } else if constexpr (oob_check) {
#pragma unroll
        for (int j1 = 0; j1 < ncols1; j1 += nwarps) {
            const int j_sram = j1 + threadIdx.y;
            const int j_vram = fastmodulo(j0 + j_sram, ne01);

            if (j1 + nwarps > ncols1 && j_sram >= ncols1) {
                break;
            }

#pragma unroll
            for (int i0 = 0; i0 < nbatch_fa; i0 += WARP_SIZE) {
                const int i = i0 + threadIdx.x;

                tile_mask[j_sram*(nbatch_fa + 8) + i] = i < i_sup ? mask_h[j_vram*stride_mask + i] : half(0.0f);
            }
        }
    } else if constexpr (nbatch_fa < 2*WARP_SIZE) {
        constexpr int cols_per_warp = 2*WARP_SIZE/nbatch_fa;
        constexpr int stride_j = nwarps * cols_per_warp;
#pragma unroll
        for (int j1 = 0; j1 < ncols1; j1 += stride_j) {
            const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (WARP_SIZE/cols_per_warp);
            const int j_vram = fastmodulo(j0 + j_sram, ne01);

            if (j1 + stride_j > ncols1 && j_sram >= ncols1) {
                break;
            }

            const int i = threadIdx.x % (WARP_SIZE/cols_per_warp);

            ggml_cuda_memcpy_1<sizeof(half2)>(tile_mask + j_sram*(nbatch_fa + 8) + 2*i, mask_h + j_vram*stride_mask + 2*i);
        }
    } else {
#pragma unroll
        for (int j1 = 0; j1 < ncols1; j1 += nwarps) {
            const int j_sram = j1 + threadIdx.y;
            const int j_vram = fastmodulo(j0 + j_sram, ne01);

            if (j1 + nwarps > ncols1 && j_sram >= ncols1) {
                break;
            }

#pragma unroll
            for (int i0 = 0; i0 < nbatch_fa; i0 += 2*WARP_SIZE) {
                const int i = i0 + 2*threadIdx.x;

                ggml_cuda_memcpy_1<sizeof(half2)>(tile_mask + j_sram*(nbatch_fa + 8) + i, mask_h + j_vram*stride_mask + i);
            }
        }
    }
}

template<int DKQ, int DV, int ncols1, int ncols2, int nwarps,
    bool use_logit_softcap, bool V_is_K_view, bool needs_fixup, bool is_fixup, bool last_iter, bool oob_check,
    typename T_A_KQ, typename T_B_KQ, typename T_C_KQ, typename T_A_VKQ, typename T_B_VKQ, typename T_C_VKQ>
static __device__ __forceinline__ void flash_attn_ext_f16_iter(
        const float2 * const __restrict__ Q_f2,
        const half2  * const __restrict__ K_h2,
        const half2  * const __restrict__ V_h2,
        const half   * const __restrict__ mask_h,
        float2       * const __restrict__ dstk,
        float2       * const __restrict__ dstk_fixup,
        const float scale,
        const float slope,
        const float logit_softcap,
        const uint3 ne01,
        const int ne02,
        const int stride_K,
        const int stride_V,
        const int stride_mask,
        half2        * const __restrict__ tile_Q,
        half2        * const __restrict__ tile_K,
        half2        * const __restrict__ tile_V,
        half         * const __restrict__ tile_mask,
        T_B_KQ       * const __restrict__ Q_B,
        T_C_VKQ      * const __restrict__ VKQ_C,
        float        * const __restrict__ KQ_max,
        float        * const __restrict__ KQ_rowsum,
        const int jt,
        const int kb0,
        const int k_VKQ_sup) {
#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))
    constexpr int  ncols           = ncols1 * ncols2;
    constexpr int  cols_per_warp   = T_B_KQ::I;
    constexpr int  cols_per_thread = get_cols_per_thread();
    constexpr int  np              = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column.
    constexpr int  nbatch_fa       = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols);
    constexpr int  nbatch_K2       = ggml_cuda_fattn_mma_get_nbatch_K2(DKQ, DV, ncols);
    constexpr int  nbatch_V2       = ggml_cuda_fattn_mma_get_nbatch_V2(DKQ, DV, ncols);
    constexpr bool Q_in_reg        = ggml_cuda_fattn_mma_get_Q_in_reg (DKQ, DV, ncols);
    constexpr int  nstages         = ggml_cuda_fattn_mma_get_nstages  (DKQ, DV, ncols1, ncols2);

    constexpr int stride_tile_Q = DKQ/2     + 4;
    constexpr int stride_tile_K = nbatch_K2 + 4;

    constexpr int stride_tile_V = V_is_K_view ? stride_tile_K : nbatch_V2 + 4;

    const int k_VKQ_0 = kb0 * nbatch_fa;
#if defined(TURING_MMA_AVAILABLE)
    T_C_KQ KQ_C[nbatch_fa/(np*(cols_per_warp == 8 ? T_C_KQ::I : T_C_KQ::J))];
#elif defined(AMD_WMMA_AVAILABLE)
    T_C_KQ KQ_C[nbatch_fa/(np*T_C_KQ::J)];
#else // Volta
    T_C_KQ KQ_C[nbatch_fa/(np*T_C_KQ::J)];
#endif // defined(TURING_MMA_AVAILABLE)

    if constexpr (nstages > 1) {
        static_assert(!oob_check, "OOB check incompatible with multi-stage pipeline");
        static_assert(!V_is_K_view, "K data reuse not implemented for multi-stage loading");
        static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi-stage loading");
        constexpr bool use_cp_async = true;
        cp_async_wait_all();
        __syncthreads();
        flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, nbatch_fa, use_cp_async, oob_check>
            (V_h2 + int64_t(k_VKQ_0)*stride_V, tile_V, nbatch_V2, stride_V, k_VKQ_sup);
    } else {
        constexpr bool use_cp_async = nstages == 1;
        if (ncols2 > 1 || mask_h) {
            flash_attn_ext_f16_load_mask<ncols1, nwarps, nbatch_fa, use_cp_async, oob_check>
                (mask_h + k_VKQ_0, tile_mask, stride_mask, k_VKQ_sup, jt*ncols1, ne01);
        }
    }

    // For MLA K and V have the same data.
    // Therefore, iterate over K in reverse and later re-use the data if possible.
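    // Illustration (assuming the MLA case, e.g. DKQ == 576, DV == 512 with V_is_K_view): because of the reverse
    //     iteration the last K tile kept in shared memory is the one starting at k0_start == 0, which is exactly
    //     the beginning of V, so the first VKQ chunk(s) below can reuse it instead of loading the same data again.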
#pragma unroll
    for (int k0_start = (DKQ/2-1) - (DKQ/2-1) % nbatch_K2; k0_start >= 0; k0_start -= nbatch_K2) {
        const int k0_stop = k0_start + nbatch_K2 < DKQ/2 ? k0_start + nbatch_K2 : DKQ/2;
        const int k0_diff = k0_stop - k0_start;

        if constexpr (nstages <= 1) {
            constexpr bool use_cp_async = nstages == 1;
            flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, nbatch_fa, use_cp_async, oob_check>
                (K_h2 + int64_t(k_VKQ_0)*stride_K + k0_start, tile_K, k0_diff, stride_K, k_VKQ_sup);
            if (use_cp_async) {
                cp_async_wait_all();
            }
            __syncthreads();
        }

        // Calculate tile of KQ:
        if constexpr (Q_in_reg) {
#pragma unroll
            for (int i_KQ_00 = 0; i_KQ_00 < nbatch_fa; i_KQ_00 += np*T_A_KQ::I) {
                const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*T_A_KQ::I;
#pragma unroll
                for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += T_A_KQ::J) {
                    T_A_KQ K_A;
                    load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K);
                    if constexpr (cols_per_warp == 8) {
                        mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[k_KQ_0/T_A_KQ::J]);
                    } else {
                        // Wide version of KQ_C is column-major
#if defined(AMD_WMMA_AVAILABLE)
                        // RDNA matrix C is column-major.
                        mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[k_KQ_0/T_A_KQ::J]);
#else
                        // swap A and B for CUDA.
                        mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[k_KQ_0/T_A_KQ::J], K_A);
#endif // defined(AMD_WMMA_AVAILABLE)
                    }
                }
            }
        } else {
#pragma unroll
            for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += T_A_KQ::J) {
                load_ldmatrix(Q_B[0], tile_Q + (threadIdx.y / np)*(T_B_KQ::I*stride_tile_Q) + k_KQ_0, stride_tile_Q);

#pragma unroll
                for (int i_KQ_00 = 0; i_KQ_00 < nbatch_fa; i_KQ_00 += np*T_A_KQ::I) {
                    const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*T_A_KQ::I;

                    T_A_KQ K_A;
                    load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K);

                    if constexpr (cols_per_warp == 8) {
                        mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
                    } else {
                        // Wide version of KQ_C is column-major
#if defined(AMD_WMMA_AVAILABLE)
                        // RDNA matrix C is column-major.
                        mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
#else
                        // swap A and B for CUDA.
                        mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A);
#endif // defined(AMD_WMMA_AVAILABLE)
                    }
                }
            }
        }

        if constexpr (nstages <= 1) {
            __syncthreads(); // Only needed if tile_K == tile_V.
        }
    }

    if (use_logit_softcap) {
        constexpr int stride = cols_per_warp == 8 ? np*T_C_KQ::I : np*T_C_KQ::J;
        static_assert(nbatch_fa % stride == 0, "bad loop size");
#pragma unroll
        for (int i = 0; i < nbatch_fa/stride; ++i) {
#pragma unroll
            for (int l = 0; l < T_C_KQ::ne; ++l) {
                KQ_C[i].x[l] = logit_softcap*tanhf(KQ_C[i].x[l]);
            }
        }
    }

    float KQ_max_new[cols_per_thread];
#pragma unroll
    for (int col = 0; col < cols_per_thread; ++col) {
        KQ_max_new[col] = KQ_max[col];
    }
    float KQ_rowsum_add[cols_per_thread] = {0.0f};

    if constexpr (cols_per_warp == 8) {
        if (ncols2 > 1 || mask_h) {
#pragma unroll
            for (int i00 = 0; i00 < nbatch_fa; i00 += np*T_C_KQ::I) {
                const int i0 = i00 + (threadIdx.y % np)*T_C_KQ::I;
#pragma unroll
                for (int l = 0; l < T_C_KQ::ne; ++l) {
                    const int i = i0 + T_C_KQ::get_i(l);
                    const int j = ((threadIdx.y / np)*T_C_KQ::J + T_C_KQ::get_j(l)) / ncols2;

                    KQ_C[i00/(np*T_C_KQ::I)].x[l] += slope * __half2float(tile_mask[j*(nbatch_fa + 8) + i]);
                }
            }
        }

        // Calculate softmax for each KQ column using the current max. value.
        // The divisor is stored in KQ_rowsum and will be applied at the end.
        static_assert(nbatch_fa % (np*T_C_KQ::I) == 0, "bad loop size");
#pragma unroll
        for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::I) {
#pragma unroll
            for (int l = 0; l < T_C_KQ::ne; ++l) {
                if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::I + T_C_KQ::get_i(l) < k_VKQ_sup) {
#if defined(AMD_WMMA_AVAILABLE)
                    constexpr int KQ_idx = 0;
#else
                    // Turing + Volta:
                    const int KQ_idx = l % 2;
#endif // defined(AMD_WMMA_AVAILABLE)
                    KQ_max_new[KQ_idx] = fmaxf(KQ_max_new[KQ_idx], KQ_C[k0/(np*T_C_KQ::I)].x[l] + FATTN_KQ_MAX_OFFSET);
                }
            }
        }

        // Values per KQ column are spread across 8 threads:
#pragma unroll
        for (int col = 0; col < cols_per_thread; ++col) {
#pragma unroll
            for (int offset = 16; offset >= 4; offset >>= 1) {
                KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE));
            }
        }

        static_assert(nbatch_fa % (np*T_C_KQ::I) == 0, "bad loop size");
#pragma unroll
        for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::I) {
#pragma unroll
            for (int l = 0; l < T_C_KQ::ne; ++l) {
                if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::I + T_C_KQ::get_i(l) < k_VKQ_sup) {
#if defined(AMD_WMMA_AVAILABLE)
                    constexpr int KQ_idx = 0;
#else
                    // Turing + Volta:
                    const int KQ_idx = l % 2;
#endif // defined(AMD_WMMA_AVAILABLE)
                    KQ_C[k0/(np*T_C_KQ::I)].x[l] = expf(KQ_C[k0/(np*T_C_KQ::I)].x[l] - KQ_max_new[KQ_idx]);
                    KQ_rowsum_add[KQ_idx] += KQ_C[k0/(np*T_C_KQ::I)].x[l];
                } else {
                    KQ_C[k0/(np*T_C_KQ::I)].x[l] = 0.0f;
                }
            }
        }
    } else { // cols_per_warp > 8: Volta, RDNA, or the wide Turing tiles with T_B_KQ::I == 16.
        if (ncols2 > 1 || mask_h) {
#pragma unroll
            for (int i00 = 0; i00 < nbatch_fa; i00 += np*T_C_KQ::J) {
                const int i0 = i00 + (threadIdx.y % np)*T_C_KQ::J;
#pragma unroll
                for (int l0 = 0; l0 < T_C_KQ::ne; l0 += 2) {
                    const int i = (i0 + T_C_KQ::get_j(l0)) / 2;
                    const int j = ((threadIdx.y / np)*cols_per_warp + T_C_KQ::get_i(l0)) / ncols2;

                    const float2 tmp = __half22float2(((const half2 *)tile_mask)[j*(nbatch_fa/2 + 4) + i]);
                    KQ_C[i00/(np*T_C_KQ::J)].x[l0 + 0] += slope*tmp.x;
                    KQ_C[i00/(np*T_C_KQ::J)].x[l0 + 1] += slope*tmp.y;
                }
            }
        }

        // Calculate softmax for each KQ column using the current max. value.
        // The divisor is stored in KQ_rowsum and will be applied at the end.
        static_assert(nbatch_fa % (np*T_C_KQ::J) == 0, "bad loop size");
#pragma unroll
        for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::J) {
#pragma unroll
            for (int l = 0; l < T_C_KQ::ne; ++l) {
                if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::J + T_C_KQ::get_j(l) < k_VKQ_sup) {
#if defined(AMD_WMMA_AVAILABLE)
                    constexpr int KQ_idx = 0;
#else
                    // Turing + Volta:
                    const int KQ_idx = (l/2) % 2;
#endif // defined(AMD_WMMA_AVAILABLE)
                    KQ_max_new[KQ_idx] = fmaxf(KQ_max_new[KQ_idx], KQ_C[(k0/(np*T_C_KQ::J))].x[l] + FATTN_KQ_MAX_OFFSET);
                }
            }
        }

#pragma unroll
        for (int col = 0; col < cols_per_thread; ++col) {
#if defined(TURING_MMA_AVAILABLE)
            // Values per KQ column are spread across 4 threads:
            constexpr int offset_first = 2;
            constexpr int offset_last  = 1;
#elif defined(AMD_WMMA_AVAILABLE)
            // Values per KQ column are spread across 2 threads:
            constexpr int offset_first = 16;
            constexpr int offset_last  = 16;
#else // Volta
            // Values per KQ column are spread across 2 threads:
            constexpr int offset_first = 2;
            constexpr int offset_last  = 2;
#endif // defined(TURING_MMA_AVAILABLE)
#pragma unroll
            for (int offset = offset_first; offset >= offset_last; offset >>= 1) {
                KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE));
            }
        }

        static_assert(nbatch_fa % (np*T_C_KQ::J) == 0, "bad loop size");
#pragma unroll
        for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::J) {
#pragma unroll
            for (int l = 0; l < T_C_KQ::ne; ++l) {
                if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::J + T_C_KQ::get_j(l) < k_VKQ_sup) {
#if defined(AMD_WMMA_AVAILABLE)
                    constexpr int KQ_idx = 0;
#else
                    // Turing + Volta:
                    const int KQ_idx = (l/2) % 2;
#endif // defined(AMD_WMMA_AVAILABLE)
                    KQ_C[(k0/(np*T_C_KQ::J))].x[l] = expf(KQ_C[(k0/(np*T_C_KQ::J))].x[l] - KQ_max_new[KQ_idx]);
                    KQ_rowsum_add[KQ_idx] += KQ_C[(k0/(np*T_C_KQ::J))].x[l];
                } else {
                    KQ_C[(k0/(np*T_C_KQ::J))].x[l] = 0.0f;
                }
            }
        }
    }

    {
        float KQ_max_scale[cols_per_thread];
#pragma unroll
        for (int col = 0; col < cols_per_thread; ++col) {
            const float KQ_max_diff = KQ_max[col] - KQ_max_new[col];
            KQ_max_scale[col] = expf(KQ_max_diff);
            KQ_max[col] = KQ_max_new[col];

            *((uint32_t *) &KQ_max_scale[col]) *= KQ_max_diff >= SOFTMAX_FTZ_THRESHOLD;
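            // The line above is a branchless flush-to-zero: if KQ_max_diff is below the threshold, the bit pattern of
            //     KQ_max_scale is multiplied by 0 (giving +0.0f), otherwise it is multiplied by 1 and left unchanged.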

            // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
            KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_rowsum_add[col];
        }

#if defined(TURING_MMA_AVAILABLE)
        if constexpr (cols_per_warp == 8) {
            const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[cols_per_thread - 1]);
#pragma unroll
            for (int i = 0; i < DV/T_C_VKQ::I; ++i) {
#pragma unroll
                for (int l = 0; l < T_C_VKQ::ne; ++l) {
                    VKQ_C[i].x[l] *= KQ_max_scale_h2;
                }
            }
        } else {
#pragma unroll
            for (int col = 0; col < cols_per_thread; ++col) {
                const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]);
#pragma unroll
                for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) {
#pragma unroll
                    for (int l0 = 0; l0 < T_C_VKQ::ne; l0 += 2) {
                        VKQ_C[i].x[l0 + col] *= KQ_max_scale_h2;
                    }
                }
            }
        }
#elif defined(AMD_WMMA_AVAILABLE)
        const half2 KQ_max_scale_h2 = make_half2(
            KQ_max_scale[0], KQ_max_scale[0]);
#pragma unroll
        for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) {
#pragma unroll
            for (int l = 0; l < T_C_VKQ::ne; ++l) {
                VKQ_C[i].x[l] *= KQ_max_scale_h2;
            }
        }
#else // Volta
        const half2 KQ_max_scale_h2 = make_half2(
            KQ_max_scale[(threadIdx.x / 2) % 2], KQ_max_scale[(threadIdx.x / 2) % 2]);
#pragma unroll
        for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) {
#pragma unroll
            for (int l = 0; l < T_C_VKQ::ne; ++l) {
                VKQ_C[i].x[l] *= KQ_max_scale_h2;
            }
        }
#endif // defined(TURING_MMA_AVAILABLE)
    }

    // Convert KQ C tiles into B tiles for VKQ calculation:
    T_B_VKQ B[nbatch_fa/(np*2*T_B_VKQ::J)];
    static_assert(nbatch_fa % (np*2*T_B_VKQ::J) == 0, "bad loop size");
    if constexpr (cols_per_warp == 8) {
#pragma unroll
        for (int k = 0; k < nbatch_fa/(np*2*T_B_VKQ::J); ++k) {
            B[k] = get_transposed(get_half2(KQ_C[k]));
        }
    } else {
        for (int k = 0; k < nbatch_fa/(np*2*T_B_VKQ::J); ++k) {
            B[k] = get_half2(KQ_C[k]);
        }
    }

    if constexpr (nstages > 1) {
        static_assert(!V_is_K_view, "K data reuse not implemented for multi-stage loading");
        // Preload K tile for next iteration:
        constexpr bool use_cp_async = true;
        cp_async_wait_all();
        __syncthreads();
        if (!last_iter) {
            if (ncols2 > 1 || mask_h) {
                flash_attn_ext_f16_load_mask<ncols1, nwarps, nbatch_fa, use_cp_async, oob_check>
                    (mask_h + k_VKQ_0 + nbatch_fa, tile_mask, stride_mask, k_VKQ_sup, jt*ncols1, ne01);
            }
            flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, nbatch_fa, use_cp_async, oob_check>
                (K_h2 + int64_t(k_VKQ_0 + nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K, k_VKQ_sup);
        }
    }


#if defined(AMD_WMMA_AVAILABLE) && !defined(LDMATRIX_TRANS_AVAILABLE)
    T_A_VKQ A_identity;
    make_identity_mat(A_identity);
#endif // defined(AMD_WMMA_AVAILABLE) && !defined(LDMATRIX_TRANS_AVAILABLE)

    // Calculate VKQ tile, need to use logical rather than physical elements for i0 due to transposition of V:
#pragma unroll
    for (int i0_start = 0; i0_start < DV; i0_start += 2*nbatch_V2) {
        static_assert(DV % (2*nbatch_V2) == 0, "bad loop size");
        const int i0_stop = i0_start + 2*nbatch_V2;
        const int i0_diff = i0_stop - i0_start;

        if constexpr (nstages <= 1) {
            if (!V_is_K_view || i0_stop > 2*nbatch_K2) {
                constexpr bool use_cp_async = nstages == 1;
                flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, nbatch_fa, use_cp_async, oob_check>
                    (V_h2 + int64_t(k_VKQ_0)*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V, k_VKQ_sup);
                if (use_cp_async) {
                    cp_async_wait_all();
                }
                __syncthreads();
            }
        }
        const half2 * tile_V_i = !V_is_K_view || i0_stop > 2*nbatch_K2 ? tile_V : tile_V + i0_start/2;

#if defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
        constexpr int i0_stride = cols_per_warp == 8 ? T_C_VKQ::I : 2*T_C_VKQ::J;
#pragma unroll
        for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += i0_stride) {
            static_assert((nbatch_fa/2) % (np*T_A_VKQ::J) == 0, "bad loop size");
#pragma unroll
            for (int k00 = 0; k00 < nbatch_fa/2; k00 += np*T_A_VKQ::J) {
                const int k0 = k00 + (threadIdx.y % np)*T_A_VKQ::J;

                T_A_VKQ A; // Transposed in SRAM but not in registers, gets transposed on load.
#if defined(LDMATRIX_TRANS_AVAILABLE)
                load_ldmatrix_trans(A, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
#else
                // TODO: Try to transpose tile_V when loading gmem to smem.
                // Use mma to transpose T_A_VKQ for RDNA.
                T_A_VKQ A_trans;
                load_ldmatrix(A_trans, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
                mma(A, A_trans, A_identity);
#endif // defined(LDMATRIX_TRANS_AVAILABLE)
                if constexpr (T_B_KQ::I == 8) {
                    mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]);
                } else {
                    // Wide version of VKQ_C is column-major.
#if defined(AMD_WMMA_AVAILABLE)
                    // RDNA matrix C is column-major.
                    mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]);
#else
                    // swap A and B for CUDA.
                    mma(VKQ_C[i_VKQ_0/i0_stride], B[k00/(np*T_A_VKQ::J)], A);
#endif // defined(AMD_WMMA_AVAILABLE)
                }
            }
        }
#else // Volta
        constexpr int i0_stride = 2*T_C_VKQ::J;
#pragma unroll
        for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += i0_stride) {
            static_assert(nbatch_fa % (np*T_A_VKQ::I) == 0, "bad loop size");
            static_assert(2*T_B_VKQ::J == T_A_VKQ::I, "bad tile sizes");
#pragma unroll
            for (int k00 = 0; k00 < nbatch_fa; k00 += np*T_A_VKQ::I) {
                const int k0 = k00 + (threadIdx.y % np)*T_A_VKQ::I;

                T_A_VKQ A; // Transposed in both SRAM and registers, load normally.
                load_ldmatrix(A, tile_V_i + k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
                mma(VKQ_C[i_VKQ_0/i0_stride], B[k00/(np*T_A_VKQ::I)], A);
            }
        }
#endif // defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)

        if constexpr (nstages <= 1) {
            __syncthreads(); // Only needed if tile_K == tile_V.
        }
    }
#else
    GGML_UNUSED_VARS(Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup,
        scale, slope, logit_softcap, ne01, ne02,
        stride_K, stride_V, stride_mask,
        tile_Q, tile_K, tile_V, tile_mask,
        Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
    NO_DEVICE_CODE;
#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))
}

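// Orientation note: T_A_*, T_B_*, T_C_* are the A/B/C operands of the mma calls, the *_KQ types are used for
//     calculating the KQ tiles and the *_VKQ types for the subsequent multiplication with V. The concrete shapes
//     differ per architecture and per number of Q columns and are chosen to match the hardware mma/WMMA instructions.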
#if defined(TURING_MMA_AVAILABLE)
template<int ncols> struct mma_tile_sizes {
    using T_A_KQ  = tile<16,  8, half2>; // row-major
    using T_B_KQ  = tile<16,  8, half2>; // column-major
    using T_C_KQ  = tile<16, 16, float>; // column-major
    using T_A_VKQ = tile<16,  8, half2>; // row-major
    using T_B_VKQ = tile<16,  8, half2>; // column-major
    using T_C_VKQ = tile<16,  8, half2>; // column-major
};
template<> struct mma_tile_sizes<8> {
    using T_A_KQ  = tile<16,  8, half2>; // row-major
    using T_B_KQ  = tile< 8,  8, half2>; // column-major
    using T_C_KQ  = tile<16,  8, float>; // row-major
    using T_A_VKQ = tile<16,  8, half2>; // row-major
    using T_B_VKQ = tile< 8,  8, half2>; // column-major
    using T_C_VKQ = tile<16,  4, half2>; // row-major
};
#elif defined(AMD_WMMA_AVAILABLE)
template<int ncols> struct mma_tile_sizes {
    using T_A_KQ  = tile<16,  8, half2>; // row-major
    using T_B_KQ  = tile<16,  8, half2>; // column-major
    using T_C_KQ  = tile<16, 16, float>; // column-major
    using T_A_VKQ = tile<16,  8, half2>; // row-major
    using T_B_VKQ = tile<16,  8, half2>; // column-major
    using T_C_VKQ = tile<16,  8, half2>; // column-major
};
#else // Volta
template<int ncols> struct mma_tile_sizes {
    using T_A_KQ  = tile< 8,  4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major
    using T_B_KQ  = tile<32,  4, half2, DATA_LAYOUT_I_MAJOR>;          // column-major
    using T_C_KQ  = tile<32,  8, float, DATA_LAYOUT_I_MAJOR>;          // column-major
    using T_A_VKQ = tile< 8,  4, half2, DATA_LAYOUT_J_MAJOR_MIRRORED>; // column-major
    using T_B_VKQ = tile<32,  4, half2, DATA_LAYOUT_I_MAJOR>;          // column-major
    using T_C_VKQ = tile<32,  4, half2, DATA_LAYOUT_I_MAJOR>;          // column-major
};
#endif // defined(TURING_MMA_AVAILABLE)

template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, bool use_logit_softcap, bool V_is_K_view, bool needs_fixup, bool is_fixup>
static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
        const float2 * const __restrict__ Q_f2,
        const half2  * const __restrict__ K_h2,
        const half2  * const __restrict__ V_h2,
        const half   * const __restrict__ mask_h,
        const float  * const __restrict__ sinks_f,
        float2       * const __restrict__ dstk,
        float2       * const __restrict__ dstk_fixup,
        const float scale,
        const float slope,
        const float logit_softcap,
        const uint3 ne01,
        const int ne02,
        const int gqa_ratio,
        const int ne11,
        const int stride_Q1,
        const int stride_Q2,
        const int stride_K,
        const int stride_V,
        const int stride_mask,
        const int jt,
        const int zt_gqa,
        const int kb0_start,
        const int kb0_stop) {
#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))
    // In this kernel Q, K, V are matrices while i, j, k are matrix indices.

    constexpr int ncols = ncols1 * ncols2;
    using     T_A_KQ    = typename mma_tile_sizes<ncols>::T_A_KQ;
    using     T_B_KQ    = typename mma_tile_sizes<ncols>::T_B_KQ;
    using     T_C_KQ    = typename mma_tile_sizes<ncols>::T_C_KQ;
    using     T_A_VKQ   = typename mma_tile_sizes<ncols>::T_A_VKQ;
    using     T_B_VKQ   = typename mma_tile_sizes<ncols>::T_B_VKQ;
    using     T_C_VKQ   = typename mma_tile_sizes<ncols>::T_C_VKQ;

    constexpr int  cols_per_warp   = T_B_KQ::I;
    constexpr int  cols_per_thread = get_cols_per_thread();
    constexpr int  np              = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column.
    constexpr int  nbatch_fa       = ggml_cuda_fattn_mma_get_nbatch_fa     (DKQ, DV, ncols);
    constexpr int  nbatch_K2       = ggml_cuda_fattn_mma_get_nbatch_K2     (DKQ, DV, ncols);
    constexpr int  nbatch_V2       = ggml_cuda_fattn_mma_get_nbatch_V2     (DKQ, DV, ncols);
    constexpr int  nbatch_combine  = ggml_cuda_fattn_mma_get_nbatch_combine(DKQ, DV, ncols);
    constexpr bool Q_in_reg        = ggml_cuda_fattn_mma_get_Q_in_reg      (DKQ, DV, ncols);
    constexpr int  nstages         = ggml_cuda_fattn_mma_get_nstages       (DKQ, DV, ncols1, ncols2);

    if (cols_per_warp > ncols) {
        NO_DEVICE_CODE;
        return;
    }

    static_assert(nwarps * (cols_per_warp/ncols2) % ncols1 == 0, "bad nwarps");

    constexpr int stride_tile_Q = DKQ/2     + 4;
    constexpr int stride_tile_K = nbatch_K2 + 4;

    constexpr int stride_tile_V = V_is_K_view ? stride_tile_K : nbatch_V2 + 4;
    constexpr int stride_tile_KV_max = stride_tile_K > stride_tile_V ? stride_tile_K : stride_tile_V;

    extern __shared__ half2 tile_Q[];
    half2 * tile_K    = Q_in_reg              ? tile_Q                             : tile_Q + ncols     * stride_tile_Q;
    half2 * tile_V    =           nstages > 1 ? tile_K + nbatch_fa * stride_tile_K : tile_K;
    half  * tile_mask = (half *) (nstages > 1 ? tile_V + nbatch_fa * stride_tile_V : tile_V + nbatch_fa * stride_tile_KV_max);

    T_B_KQ    Q_B[(Q_in_reg ? DKQ/(2*T_B_KQ::J) : 1)];
#if defined(TURING_MMA_AVAILABLE)
    T_C_VKQ VKQ_C[cols_per_warp == 8 ? DV/T_C_VKQ::I : DV/(2*T_C_VKQ::J)];
#elif defined(AMD_WMMA_AVAILABLE)
    T_C_VKQ VKQ_C[                                     DV/(2*T_C_VKQ::J)];
#else // Volta
    T_C_VKQ VKQ_C[                                     DV/(2*T_C_VKQ::J)];
#endif // defined(TURING_MMA_AVAILABLE)

    float KQ_rowsum[cols_per_thread] = {0.0f};
    float KQ_max[cols_per_thread];
#pragma unroll
    for (int col = 0; col < cols_per_thread; ++col) {
        KQ_max[col] = -FLT_MAX/2.0f;
    }

    // Load Q data into tile_Q, either temporarily or permanently.
    // Q in registers is faster, but register pressure is the biggest bottleneck.
    // The loading is done with decreasing granularity for D for better memory bandwidth.
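    // For example (illustrative): with ncols1 == 8 and ncols2 == 4 there are ncols == 32 tile columns;
    //     column jc maps to Q position j == jc/4 within the tile and GQA head offset c == jc%4,
    //     i.e. the ncols2 heads that share a KV head are stored in consecutive tile columns.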
    const half2 scale_h2 = make_half2(scale, scale);
#pragma unroll
    for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
        const int k0_start  = stride_k == WARP_SIZE ? 0 : DKQ/2 - (DKQ/2) % (2*stride_k);
        const int k0_stop   =                             DKQ/2 - (DKQ/2) % (1*stride_k);
        const int stride_jc = WARP_SIZE / stride_k;

        if (k0_start == k0_stop) {
            continue;
        }

#pragma unroll
        for (int jc0 = 0; jc0 < ncols; jc0 += nwarps*stride_jc) {
            const int jc = jc0 + threadIdx.y*stride_jc + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);

            if (jc0 + nwarps*stride_jc > ncols && jc >= ncols) {
                break;
            }

            const int j = jc / ncols2;
            const int c = jc % ncols2;

            if ((ncols1 == 1 || jt*ncols1 + j < int(ne01.z)) && (ncols2 == 1 || zt_gqa*ncols2 + c < gqa_ratio)) {
#pragma unroll
                for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
                    const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);

                    const float2 tmp = Q_f2[(jt*ncols1 + j)*stride_Q1 + c*stride_Q2 + k];
                    tile_Q[jc*stride_tile_Q + k] = scale_h2 * make_half2(tmp.x, tmp.y);
                }
            } else {
#pragma unroll
                for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
                    const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);

                    tile_Q[jc*stride_tile_Q + k] = make_half2(0.0f, 0.0f);
                }
            }
        }
    }

    __syncthreads();

    if (Q_in_reg) {
        const int j0 = (threadIdx.y / np) * cols_per_warp;

#pragma unroll
        for (int k0 = 0; k0 < DKQ/2; k0 += T_B_KQ::J) {
            load_ldmatrix(Q_B[k0/T_B_KQ::J], tile_Q + j0*stride_tile_Q + k0, stride_tile_Q);
        }
    }

    __syncthreads();

    int kb0 = kb0_start;

    // Preload mask and K data for first iteration when using cp_async with multiple stages:
    if constexpr (nstages > 1) {
        static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi-stage pipeline");
        constexpr bool use_cp_async = true;
        constexpr bool oob_check    = false;
        constexpr int  k_VKQ_sup    = nbatch_fa;
        if (ncols2 > 1 || mask_h) {
            flash_attn_ext_f16_load_mask<ncols1, nwarps, nbatch_fa, use_cp_async, oob_check>
                (mask_h + kb0*nbatch_fa, tile_mask, stride_mask, k_VKQ_sup, jt*ncols1, ne01);
        }
        flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, nbatch_fa, use_cp_async, oob_check>
            (K_h2 + int64_t(kb0)*nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K, k_VKQ_sup);
    }

    // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
    if constexpr (ncols2 == 1) {
        constexpr bool oob_check = true;
        for (; kb0 < kb0_stop-1; ++kb0) {
            constexpr bool last_iter = false;
            constexpr int  k_VKQ_sup = nbatch_fa;
            flash_attn_ext_f16_iter
                <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup, last_iter, oob_check,
                 T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
                (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
                 ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
                 KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);
        }
        constexpr bool last_iter = true;
        const     int  k_VKQ_sup = ne11 - kb0*nbatch_fa;
        flash_attn_ext_f16_iter
            <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup, last_iter, oob_check,
             T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
            (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
             ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
             KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);
    } else {
        constexpr bool oob_check = false;
        for (; kb0 < kb0_stop-1; ++kb0) {
            constexpr bool last_iter = false;
            constexpr int  k_VKQ_sup = nbatch_fa;
            flash_attn_ext_f16_iter
                <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup, last_iter, oob_check,
                 T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
                (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
                 ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
                 KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);
        }
        constexpr bool last_iter = true;
        constexpr int  k_VKQ_sup = nbatch_fa;
        flash_attn_ext_f16_iter
            <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup, last_iter, oob_check,
             T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
            (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
             ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
             KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);
    }

1118    // With multi-stage loading there is no __syncthreads at the end of the iteration,
1119    //     so there can be a race condition on shared memory accesses when combining/writing back results.
1120    if constexpr (nstages > 1 && nwarps*cols_per_warp > nbatch_fa) {
1121        __syncthreads();
1122    }
1123
1124    // Finally, sum up partial KQ rowsums.
1125    {
1126#if defined(TURING_MMA_AVAILABLE)
1127        // The partial sums are spread across 8 threads (cols_per_warp == 8) or 4 threads (otherwise).
1128        constexpr int offset_first = cols_per_warp == 8 ? 16 : 2;
1129        constexpr int offset_last  = cols_per_warp == 8 ?  4 : 1;
1130#elif defined(AMD_WMMA_AVAILABLE)
1131        // The partial sums are spread across 2 threads.
1132        constexpr int offset_first = 16;
1133        constexpr int offset_last  = 16;
1134#else // Volta
1135        // The partial sums are spread across 2 threads.
1136        constexpr int offset_first = 2;
1137        constexpr int offset_last  = 2;
1138#endif // defined(TURING_MMA_AVAILABLE)
1139#pragma unroll
1140        for (int col = 0; col < cols_per_thread; ++col) {
1141#pragma unroll
1142            for (int offset = offset_first; offset >= offset_last; offset >>= 1) {
1143                KQ_rowsum[col] += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum[col], offset, WARP_SIZE);
1144            }
1145        }
1146    }
1147
1148    // If attention sinks are used, potentially re-scale if KQ_max is small.
1149    // Also add the sink as a value to KQ_rowsum; since this happens after the synchronization of KQ_rowsum
1150    //     it is done unconditionally for every thread.
1151    if (!is_fixup && (np == 1 || threadIdx.y % np == 0) && sinks_f) {
1152        float KQ_max_scale[cols_per_thread];
1153#pragma unroll
1154        for (int col = 0; col < cols_per_thread; ++col) {
1155            const int jc = cols_per_warp == 8 ? T_C_KQ::get_j(col) : T_C_KQ::get_i(2*col);
1156            const float sink = sinks_f[jc % ncols2];
1157
1158            const float KQ_max_new = fmaxf(KQ_max[col], sink);
1159            const float KQ_max_diff = KQ_max[col] - KQ_max_new;
1160            KQ_max_scale[col] = expf(KQ_max_diff);
1161            KQ_max[col] = KQ_max_new;
1162
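                // Bit trick: multiply the bit pattern of KQ_max_scale by 0 or 1 to flush it to zero if KQ_max_diff is below SOFTMAX_FTZ_THRESHOLD.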
1163            *((uint32_t *) &KQ_max_scale[col]) *= KQ_max_diff >= SOFTMAX_FTZ_THRESHOLD;
1164
1165            const float KQ_max_add = expf(sink - KQ_max_new);
1166            KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_max_add;
1167        }
1168
1169#if defined(TURING_MMA_AVAILABLE)
1170        if constexpr (cols_per_warp == 8) {
1171            const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[cols_per_thread - 1]);
1172#pragma unroll
1173            for (int i = 0; i < DV/T_C_VKQ::I; ++i) {
1174#pragma unroll
1175                for (int l = 0; l < T_C_VKQ::ne; ++l) {
1176                    VKQ_C[i].x[l] *= KQ_max_scale_h2;
1177                }
1178            }
1179        } else {
1180#pragma unroll
1181            for (int col = 0; col < cols_per_thread; ++col) {
1182                const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]);
1183#pragma unroll
1184                for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) {
1185#pragma unroll
1186                    for (int l0 = 0; l0 < T_C_VKQ::ne; l0 += 2) {
1187                        VKQ_C[i].x[l0 + col] *= KQ_max_scale_h2;
1188                    }
1189                }
1190            }
1191        }
1192#elif defined(AMD_WMMA_AVAILABLE)
1193        const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[0]);
1194#pragma unroll
1195        for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) {
1196#pragma unroll
1197            for (int l = 0; l < T_C_VKQ::ne; ++l) {
1198                VKQ_C[i].x[l] *= KQ_max_scale_h2;
1199            }
1200        }
1201#else // Volta
1202        const int col = (threadIdx.x / 2) % 2;
1203        const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]);
1204#pragma unroll
1205        for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) {
1206#pragma unroll
1207            for (int l = 0; l < T_C_VKQ::ne; ++l) {
1208                VKQ_C[i].x[l] *= KQ_max_scale_h2;
1209            }
1210        }
1211#endif // defined(TURING_MMA_AVAILABLE)
1212    }
1213
1214    // Combine VKQ accumulator values if np > 1.
1215    // It's also faster to do small writes to shared memory followed by a single large write to VRAM than to do many small writes to VRAM.
1216    // So the VKQ accumulators are also written to shared memory in column-major format if np == 1.
1217
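        // Each row of the combine tile is padded by 4 half2 (16 bytes); the padding is used below to store the combine meta data.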
1218    constexpr int tile_stride = nbatch_combine + 4;
1219    static_assert((DV/2) % nbatch_combine == 0, "bad nbatch_combine");
1220
1221    if constexpr (cols_per_warp == 8) {
1222        const int jc_cwmo = (threadIdx.x % (2*T_C_VKQ::J)) / T_C_VKQ::J; // jc combine write meta offset
1223        const int jc_cwm = threadIdx.y*(2*T_C_VKQ::J) + 2*T_C_VKQ::get_j(-1) + jc_cwmo; // jc combine write meta
1224        const float2 KQ_cmr = make_float2(KQ_max[jc_cwmo], KQ_rowsum[jc_cwmo]); // KQ combine max rowsum
1225
1226        if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*T_C_VKQ::J) {
1227            // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale.
1228            ((float2 *) tile_Q)[jc_cwm*(tile_stride/2) + nbatch_combine/2] = KQ_cmr;
1229        }
1230
1231        __syncthreads();
1232
1233        if (np == 1) {
1234            // No combination is needed; the meta data can be written directly from registers to VRAM.
1235            if (needs_fixup && threadIdx.x < T_B_KQ::I) {
1236                float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
1237                dstk_fixup_meta[jc_cwm] = KQ_cmr;
1238            }
1239            if (is_fixup && threadIdx.x < T_B_KQ::I) {
1240                float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
1241                dstk_fixup_meta[jc_cwm] = KQ_cmr;
1242            }
1243        }
1244    } else {
1245        // jc_cwm = jc combine write meta
1246        // KQ_cmr = KQ combine max rowsum
1247        // Use the 16 bytes of padding in each Q column to store the meta data: KQ max, KQ rowsum, KQ max scale.
1248#if defined(TURING_MMA_AVAILABLE)
1249        const int jc_cwm = threadIdx.y*cols_per_warp + T_C_VKQ::get_i(threadIdx.x % 4);
1250        const float2 KQ_cmr = make_float2(KQ_max[threadIdx.x % cols_per_thread], KQ_rowsum[threadIdx.x % cols_per_thread]);
1251        const bool thread_should_write = threadIdx.x % 4 < cols_per_thread;
1252#elif defined(AMD_WMMA_AVAILABLE)
1253        const int jc_cwm = threadIdx.y*cols_per_warp + T_C_VKQ::get_i(0);
1254        const float2 KQ_cmr = make_float2(KQ_max[0], KQ_rowsum[0]);
1255        const bool thread_should_write = threadIdx.x / 16 < cols_per_thread;
1256#else // Volta
1257        const int jc_cwm = threadIdx.y*cols_per_warp + T_C_KQ::get_i(threadIdx.x & 2);
1258        const float2 KQ_cmr = make_float2(KQ_max[(threadIdx.x & 2) / 2], KQ_rowsum[(threadIdx.x & 2) / 2]);
1259        const bool thread_should_write = T_C_KQ::J == 8 || T_C_KQ::get_j(threadIdx.x & 2) < 8;
1260#endif // defined(TURING_MMA_AVAILABLE)
1261
1262        if (((!needs_fixup && !is_fixup) || np > 1) && thread_should_write) {
1263            ((float2 *) tile_Q)[jc_cwm*(tile_stride/2) + nbatch_combine/2] = KQ_cmr;
1264        }
1265
1266        __syncthreads();
1267
1268        if (np == 1) {
1269            // No combination is needed; the meta data can be written directly from registers to VRAM.
1270            if (needs_fixup && thread_should_write) {
1271                float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
1272                dstk_fixup_meta[jc_cwm] = KQ_cmr;
1273            }
1274            if (is_fixup && thread_should_write) {
1275                float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
1276                dstk_fixup_meta[jc_cwm] = KQ_cmr;
1277            }
1278        }
1279    }
1280
1281    if (np > 1 && threadIdx.y % np == 0) {
1282        // Combine the meta data for parallel warps via shared memory.
1283        // Warps with threadIdx.y % np != 0 must NOT return early.
1284        // All threads must return simultaneously to avoid race conditions with work on the next tile.
1285
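            // The leading warp of each group reduces the meta data of its np parallel warps, i.e. np*cols_per_warp columns in total.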
1286        constexpr int nmeta = np*cols_per_warp >= WARP_SIZE ? np*cols_per_warp/WARP_SIZE : 1;
1287
1288        const int jc_meta = threadIdx.y*cols_per_warp + (np*cols_per_warp < WARP_SIZE ? threadIdx.x % (np*cols_per_warp) : threadIdx.x);
1289        float2 * const meta_ptr = ((float2 *) tile_Q) + jc_meta*(tile_stride/2) + nbatch_combine/2;
1290        float2 meta[nmeta];
1291#pragma unroll
1292        for (int imeta = 0; imeta < nmeta; ++imeta) {
1293            meta[imeta] = meta_ptr[imeta * WARP_SIZE * tile_stride/2];
1294        }
1295
1296        float KQ_cmn = meta[0].x; // KQ combine max new, maximum over all parallel warps.
1297#pragma unroll
1298        for (int imeta = 1; imeta < nmeta; ++imeta) {
1299            KQ_cmn = fmaxf(KQ_cmn, meta[imeta].x);
1300        }
1301#pragma unroll
1302        for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) {
1303            if (offset < WARP_SIZE) {
1304                KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE));
1305            }
1306        }
1307
1308        float KQ_cms[nmeta]; // KQ combine max scale per warp.
1309#pragma unroll
1310        for (int imeta = 0; imeta < nmeta; ++imeta) {
1311            KQ_cms[imeta] = expf(meta[imeta].x - KQ_cmn);
1312        }
1313
1314        float KQ_crs = KQ_cms[0]*meta[0].y; // KQ combine rowsum, scaled sum of all parallel warps.
1315#pragma unroll
1316        for (int imeta = 1; imeta < nmeta; ++imeta) {
1317            KQ_crs += KQ_cms[imeta]*meta[imeta].y;
1318        }
1319#pragma unroll
1320        for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) {
1321            if (offset < WARP_SIZE) {
1322                KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE);
1323            }
1324        }
1325
1326        __syncthreads();
1327
1328        // Write back combined meta data:
1329#pragma unroll
1330        for (int imeta = 0; imeta < nmeta; ++imeta) {
1331            if (np*cols_per_warp >= WARP_SIZE || threadIdx.x < np*cols_per_warp) {
1332                // Combined KQ max scale + rowsum.
1333                meta_ptr[imeta * WARP_SIZE * tile_stride/2] = make_float2(KQ_cms[imeta], KQ_crs);
1334            }
1335        }
1336
1337        // Combined KQ max + rowsum.
1338        static_assert(cols_per_warp <= WARP_SIZE);
1339        if (needs_fixup && (cols_per_warp == WARP_SIZE || threadIdx.x < cols_per_warp)) {
1340            float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
1341            dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
1342        }
1343        if (is_fixup && (cols_per_warp == WARP_SIZE || threadIdx.x < cols_per_warp)) {
1344            float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
1345            dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
1346        }
1347    } else if (np > 1) {
1348        // Warps with threadIdx.y % np == 0 execute a __syncthreads() in the if branch.
1349        // Therefore, all other warps also need to execute a __syncthreads().
1350        // Otherwise the points at which warps synchronize with each other would become misaligned.
1351        __syncthreads();
1352    }
1353
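        // Write the VKQ accumulators to shared memory in chunks of nbatch_combine, combine them across parallel warps, and write the final values to VRAM.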
1354#pragma unroll
1355    for (int k00 = 0; k00 < DV/2; k00 += nbatch_combine) {
1356        if constexpr (cols_per_warp == 8) {
1357            const int jc_cwd = threadIdx.y*T_B_KQ::I + T_B_KQ::get_i(-1); // jc combine write data
1358#pragma unroll
1359            for (int k1 = 0; k1 < nbatch_combine; k1 += T_B_KQ::J) {
1360                const T_B_KQ B = get_transposed(VKQ_C[(k00 + k1)/T_B_KQ::J]); // Conversion of C to B matrix puts it in column-major format.
1361
1362#pragma unroll
1363                for (int l = 0; l < T_B_KQ::ne; ++l) {
1364                    const int k = k1 + T_B_KQ::get_j(l);
1365
1366                    tile_Q[jc_cwd*tile_stride + k] = B.x[l];
1367                }
1368            }
1369        } else {
1370            const int j0 = threadIdx.y*cols_per_warp;
1371#pragma unroll
1372            for (int k1 = 0; k1 < nbatch_combine; k1 += T_C_VKQ::J) {
1373#pragma unroll
1374                for (int l = 0; l < T_C_VKQ::ne; ++l) {
1375                    const int j = j0 + T_C_VKQ::get_i(l);
1376                    const int k = k1 + T_C_VKQ::get_j(l);
1377
1378                    tile_Q[j*tile_stride + k] = VKQ_C[(k00 + k1)/T_C_VKQ::J].x[l];
1379                }
1380            }
1381        }
1382
1383        __syncthreads();
1384
1385        if (np == 1 || threadIdx.y % np == 0) {
1386            // The first 2*2*gridDim.x*ncols floats in dstk_fixup are for storing max. values and row sums.
1387            // The values after that are for the partial results of the individual blocks.
1388            float2 * dstk_fixup_data = dstk_fixup + gridDim.x*(2*ncols) + blockIdx.x*(ncols*(DV/2));
1389
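                // Iterate with decreasing stride_k so that a tail of nbatch_combine that is not a multiple of WARP_SIZE is also covered.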
1390#pragma unroll
1391            for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
1392                const int k0_start  = stride_k == WARP_SIZE ? 0 : nbatch_combine - nbatch_combine % (2*stride_k);
1393                const int k0_stop   =                             nbatch_combine - nbatch_combine % (1*stride_k);
1394                const int stride_jc = WARP_SIZE / stride_k;
1395
1396                if (k0_start == k0_stop) {
1397                    continue;
1398                }
1399
1400#pragma unroll
1401                for (int jc0_dst = 0; jc0_dst < ncols; jc0_dst += (nwarps/np)*stride_jc) {
1402                    const int jc_dst = jc0_dst + (threadIdx.y/np)*stride_jc + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
1403
1404                    if (jc0_dst + (nwarps/np)*stride_jc > ncols && jc_dst >= ncols) {
1405                        break;
1406                    }
1407
1408                    const int jc_tile_K = (jc_dst/cols_per_warp)*(np*cols_per_warp) + jc_dst % cols_per_warp;
1409
1410                    const int j_dst = jc_dst / ncols2;
1411                    const int c_dst = jc_dst % ncols2;
1412
1413                    if (!is_fixup && ((ncols1 > 1 && jt*ncols1 + j_dst >= int(ne01.z)) || (ncols2 > 1 && zt_gqa*ncols2 + c_dst >= gqa_ratio))) {
1414                        continue;
1415                    }
1416
1417                    const float * meta_j = (const float *) tile_Q + jc_tile_K*tile_stride + nbatch_combine;
1418#pragma unroll
1419                    for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
1420                        const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
1421
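                            // Sum the partial results of the np parallel warps, each scaled by its combined KQ max scale (the first float of its meta data).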
1422                        float2 dstk_val = make_float2(0.0f, 0.0f);
1423#pragma unroll
1424                        for (int ip = 0; ip < np; ++ip) {
1425                            const float KQ_crs = np == 1 ? 1.0f : meta_j[ip*cols_per_warp * tile_stride + 0];
1426                            const float2 dstk_val_add = __half22float2(tile_Q[(jc_tile_K + ip*cols_per_warp) * tile_stride + k]);
1427                            dstk_val.x += dstk_val_add.x*KQ_crs;
1428                            dstk_val.y += dstk_val_add.y*KQ_crs;
1429                        }
1430
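                            // Only normalize by the KQ rowsum if this block computed the complete tile; partial results are combined and normalized afterwards using the stored meta data.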
1431                        if (!needs_fixup && !is_fixup) {
1432                            const float KQ_rowsum_j = meta_j[1];
1433                            dstk_val.x /= KQ_rowsum_j;
1434                            dstk_val.y /= KQ_rowsum_j;
1435                        }
1436
1437                        if (is_fixup) {
1438                            dstk_fixup_data[jc_dst*(DV/2) + k00 + k] = dstk_val;
1439                        } else {
1440                            dstk[((jt*ncols1 + j_dst)*ne02 + c_dst)*(DV/2) + k00 + k] = dstk_val;
1441                        }
1442                    }
1443                }
1444            }
1445        }
1446        if (np > 1) {
1447            __syncthreads();
1448        }
1449    }
1450#else
1451    GGML_UNUSED_VARS(Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dstk_fixup,
1452        scale, slope, logit_softcap, ne01, ne02, gqa_ratio,
1453        stride_Q1, stride_Q2, stride_K, stride_V, stride_mask,
1454        jt, kb0_start, kb0_stop);
1455    NO_DEVICE_CODE;
1456#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))
1457}
1458
1459template<int DKQ, int DV, int ncols1, int ncols2, bool use_logit_softcap, bool V_is_K_view>
1460__launch_bounds__(ggml_cuda_fattn_mma_get_nthreads(DKQ, DV, ncols1*ncols2), ggml_cuda_fattn_mma_get_occupancy(DKQ, DV, ncols1*ncols2))
1461static __global__ void flash_attn_ext_f16(
1462        const char * __restrict__ Q,
1463        const char * __restrict__ K,
1464        const char * __restrict__ V,
1465        const char * __restrict__ mask,
1466        const char * __restrict__ sinks,
1467        const int  * __restrict__ KV_max,
1468        float      * __restrict__ dst,
1469        float2     * __restrict__ dst_meta,
1470        const float scale,
1471        const float max_bias,
1472        const float m0,
1473        const float m1,
1474        const uint32_t n_head_log2,
1475        const float logit_softcap,
1476        const int32_t ne00, const uint3   ne01, const int32_t ne02, const int32_t ne03,
1477                            const int32_t nb01, const int32_t nb02, const int32_t nb03,
1478        const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
1479                            const int32_t nb11, const int32_t nb12, const int64_t nb13,
1480                            const int32_t nb21, const int32_t nb22, const int64_t nb23,
1481                            const int32_t ne31, const int32_t ne32, const int32_t ne33,
1482                            const int32_t nb31, const int32_t nb32, const int64_t nb33) {
1483#if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)))
1484
1485    // Skip unused kernel variants for faster compilation:
1486    if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) {
1487        NO_DEVICE_CODE;
1488        return;
1489    }
1490#ifdef VOLTA_MMA_AVAILABLE
1491    if (ncols1*ncols2 < 32) {
1492        NO_DEVICE_CODE;
1493        return;
1494    }
1495#endif // VOLTA_MMA_AVAILABLE
1496
1497#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
1498    if (ncols1*ncols2 > 32) {
1499        NO_DEVICE_CODE;
1500        return;
1501    }
1502#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
1503
1504#if defined(AMD_WMMA_AVAILABLE)
1505    if (ncols1*ncols2 > 32 || ncols1*ncols2 < 16 || DKQ > 128 || ncols2 == 1) {
1506        NO_DEVICE_CODE;
1507        return;
1508    }
1509#endif // defined(AMD_WMMA_AVAILABLE)
1510
1511    constexpr int ncols     = ncols1 * ncols2;
1512    constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols);
1513    constexpr int nthreads  = ggml_cuda_fattn_mma_get_nthreads(DKQ, DV, ncols);
1514    constexpr int nwarps    = nthreads / WARP_SIZE;
1515
1516    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
1517
1518    const int stride_Q1   = nb01 / sizeof(float2);
1519    const int stride_Q2   = nb02 / sizeof(float2);
1520    const int stride_K    = nb11 / sizeof(half2);
1521    const int stride_mask = nb31 / sizeof(half);
1522
1523    const int stride_V = V_is_K_view ? stride_K : nb21 / sizeof(half2);
1524
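        // Number of batches over KV rows (iter_k), Q token columns (iter_j), and GQA Q heads per K/V head (iter_z_gqa), each rounded up.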
1525    const int iter_k     = (ne11      + (nbatch_fa - 1)) / nbatch_fa;
1526    const int iter_j     = (ne01.z    + (ncols1    - 1)) / ncols1;
1527    const int iter_z_gqa = (gqa_ratio + (ncols2    - 1)) / ncols2;
1528
1529    // kbc == k block continuous, current index in continuous ijk space.
1530    int       kbc      = int64_t(blockIdx.x + 0)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
1531    const int kbc_stop = int64_t(blockIdx.x + 1)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
1532
1533    // If the seams of 2 CUDA blocks fall within an output tile, their results need to be combined.
1534    // For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup).
1535    // In the most general case >2 seams can fall into the same tile.
1536
1537    // kb0 == k start index within the output tile.
1538    int kb0_start = kbc % iter_k;
1539    int kb0_stop  = min(iter_k, kb0_start + kbc_stop - kbc);
1540
1541    while (kbc < kbc_stop && kb0_stop == iter_k) {
1542        // z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index
1543        const int sequence =  kbc /(iter_k*iter_j*iter_z_gqa*ne12);
1544        const int z_KV     = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa);
1545        const int zt_gqa   = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j);
1546        const int jt       = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k;
1547
1548        const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index.
1549
1550        const float2 * Q_f2   = (const float2 *) (Q + nb03*sequence + nb02*zt_Q);
1551        const half2  * K_h2   = (const half2  *) (K + nb13*sequence + nb12*z_KV);
1552        const half   * mask_h = ncols2 == 1 && !mask ? nullptr :
1553            (const half *) (mask + nb33*(sequence % ne33));
1554        float2       * dstk   = ((float2 *) dst) + (sequence*ne01.z*ne02 + zt_Q) * (DV/2);
1555
1556        const half2 * V_h2 = V_is_K_view ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*z_KV);
1557        const float * sinks_f = sinks ? (const float *) sinks + zt_Q : nullptr;
1558
1559        const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, zt_Q, n_head_log2, m0, m1) : 1.0f;
1560
1561        if (KV_max) {
1562            kb0_stop = min(kb0_stop, KV_max[sequence*iter_j + jt] / nbatch_fa);
1563        }
1564        constexpr bool is_fixup = false; // All but (potentially) the last iteration write their data to dst rather than to the fixup buffer.
1565        if (kb0_start == 0) {
1566            constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
1567            flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup>
1568                (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
1569                 ne01, ne02, gqa_ratio, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt_gqa, kb0_start, kb0_stop);
1570        } else {
1571            constexpr bool needs_fixup = true; // CUDA block is missing the beginning of a tile.
1572            flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup>
1573                (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
1574                 ne01, ne02, gqa_ratio, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt_gqa, kb0_start, kb0_stop);
1575        }
1576
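            // Advance kbc to the start of the next output tile, i.e. the next multiple of iter_k.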
1577        kbc += iter_k;
1578        kbc -= kbc % iter_k;
1579
1580        kb0_start = 0;
1581        kb0_stop  = min(iter_k, kbc_stop - kbc);
1582    }
1583
1584    if (kbc >= kbc_stop) {
1585        return;
1586    }
1587
1588    // z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index.
1589    const int sequence =  kbc /(iter_k*iter_j*iter_z_gqa*ne12);
1590    const int z_KV     = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa);
1591    const int zt_gqa   = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j);
1592    const int jt       = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k;
1593
1594    const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index.
1595
1596    const float2 * Q_f2   = (const float2 *) (Q + nb03*sequence + nb02*zt_Q);
1597    const half2  * K_h2   = (const half2  *) (K + nb13*sequence + nb12*z_KV);
1598    const half   * mask_h = ncols2 == 1 && !mask ? nullptr :
1599        (const half *) (mask + nb33*(sequence % ne33));
1600    float2       * dstk   = ((float2 *) dst) + (sequence*ne01.z*ne02 + zt_Q) * (DV/2);
1601
1602    const half2 * V_h2 = V_is_K_view ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*z_KV);
1603    const float * sinks_f = sinks ? (const float *) sinks + zt_Q : nullptr;
1604
1605    const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, zt_Q, n_head_log2, m0, m1) : 1.0f;
1606
1607    if (KV_max) {
1608        kb0_stop = min(kb0_stop, KV_max[sequence*iter_j + jt] / nbatch_fa);
1609    }
1610
1611    constexpr bool is_fixup = true; // The last index writes its data to the fixup buffer to avoid data races with other blocks.
1612    constexpr bool needs_fixup = false;
1613    flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup>
1614        (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
1615         ne01, ne02, gqa_ratio, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt_gqa, kb0_start, kb0_stop);
1616#else
1617    GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
1618        max_bias, m0, m1, n_head_log2, logit_softcap,
1619        ne00, ne01, ne02, ne03,
1620              nb01, nb02, nb03,
1621        ne10, ne11, ne12, ne13,
1622              nb11, nb12, nb13,
1623              nb21, nb22, nb23,
1624              ne31, ne32, ne33,
1625              nb31, nb32, nb33);
1626    NO_DEVICE_CODE;
1627#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)))
1628}
1629
1630template <int DKQ, int DV, int ncols1, int ncols2>
1631void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
1632    const ggml_tensor * KQV = dst;
1633    const int id = ggml_cuda_get_device();
1634    const int cc = ggml_cuda_info().devices[id].cc;
1635
1636    constexpr int ncols = ncols1 * ncols2;
1637
1638    const int  nthreads       = ggml_cuda_fattn_mma_get_nthreads      (DKQ, DV, ncols, cc);
1639    const int  nbatch_fa      = ggml_cuda_fattn_mma_get_nbatch_fa     (DKQ, DV, ncols, cc);
1640    const int  nbatch_K2      = ggml_cuda_fattn_mma_get_nbatch_K2     (DKQ, DV, ncols, cc);
1641    const int  nbatch_V2      = ggml_cuda_fattn_mma_get_nbatch_V2     (DKQ, DV, ncols, cc);
1642    const int  nbatch_combine = ggml_cuda_fattn_mma_get_nbatch_combine(DKQ, DV, ncols, cc);
1643    const bool Q_in_reg       = ggml_cuda_fattn_mma_get_Q_in_reg      (DKQ, DV, ncols, cc);
1644    const int  nstages        = ggml_cuda_fattn_mma_get_nstages       (DKQ, DV, ncols1, ncols2, cc);
1645
1646    const int cols_per_warp = std::min(ncols, get_cols_per_warp(cc));
1647    const int nwarps        = nthreads / WARP_SIZE;
1648
1649    constexpr bool V_is_K_view = DKQ == 576; // Guaranteed by the kernel selection logic in fattn.cu
1650
1651    const size_t nbytes_shared_KV_1stage = nbatch_fa            * std::max(nbatch_K2 + 4,  nbatch_V2 + 4) * sizeof(half2);
1652    const size_t nbytes_shared_KV_2stage = nbatch_fa            *         (nbatch_K2 + 4 + nbatch_V2 + 4) * sizeof(half2);
1653    const size_t nbytes_shared_Q         = ncols                * (DKQ/2 + 4)                             * sizeof(half2);
1654    const size_t nbytes_shared_mask      = ncols1               * (nbatch_fa/2 + 4)                       * sizeof(half2);
1655    const size_t nbytes_shared_combine   = nwarps*cols_per_warp * (nbatch_combine + 4)                    * sizeof(half2);
1656
1657    const size_t nbytes_shared_KV = nstages <= 1 ? nbytes_shared_KV_1stage : nbytes_shared_KV_2stage;
1658
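        // If Q is kept in registers its shared memory can be reused for K/V + mask; the combine buffer reuses the same shared memory afterwards.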
1659    const size_t nbytes_shared_total = std::max(nbytes_shared_combine, Q_in_reg ?
1660        std::max(nbytes_shared_Q,  nbytes_shared_KV + nbytes_shared_mask) :
1661                 nbytes_shared_Q + nbytes_shared_KV + nbytes_shared_mask);
1662
1663    float logit_softcap;
1664    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
1665
1666#if defined(GGML_USE_HIP)
1667    using fattn_kernel_ptr_t = const void*;
1668#else
1669    using fattn_kernel_ptr_t = fattn_kernel_t;
1670#endif // defined(GGML_USE_HIP)
1671    fattn_kernel_t fattn_kernel;
1672    if (logit_softcap == 0.0f) {
1673        constexpr bool use_logit_softcap = false;
1674        fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, use_logit_softcap, V_is_K_view>;
1675
1676#if !defined(GGML_USE_MUSA)
1677        static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
1678        if (!shared_memory_limit_raised[id]) {
1679            CUDA_CHECK(cudaFuncSetAttribute(reinterpret_cast<fattn_kernel_ptr_t>(fattn_kernel), cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total));
1680            shared_memory_limit_raised[id] = true;
1681        }
1682#endif // !defined(GGML_USE_MUSA)
1683    } else {
1684        constexpr bool use_logit_softcap = true;
1685        fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, use_logit_softcap, V_is_K_view>;
1686
1687#if !defined(GGML_USE_MUSA)
1688        static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
1689        if (!shared_memory_limit_raised[id]) {
1690            CUDA_CHECK(cudaFuncSetAttribute(reinterpret_cast<fattn_kernel_ptr_t>(fattn_kernel), cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total));
1691            shared_memory_limit_raised[id] = true;
1692        }
1693#endif // !defined(GGML_USE_MUSA)
1694    }
1695
1696    launch_fattn<DV, ncols1, ncols2>
1697        (ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, nbatch_fa, true, true, true);
1698}
1699
1700
1701#define DECL_FATTN_MMA_F16_CASE(DKQ, DV, ncols1, ncols2)                          \
1702    template void ggml_cuda_flash_attn_ext_mma_f16_case                           \
1703    <DKQ, DV, ncols1, ncols2>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \
1704
1705#define DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(DKQ, DV, ncols)   \
1706    extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/ 1,  1); \
1707    extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/ 2,  2); \
1708    extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/ 4,  4); \
1709    extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/ 8,  8); \
1710    extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/16, 16); \
1711
1712DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64,  64,   8)
1713DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80,  80,   8)
1714DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96,  96,   8)
1715DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 112,   8)
1716DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128,   8)
1717DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256,   8)
1718
1719DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64,  64,  16)
1720DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80,  80,  16)
1721DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96,  96,  16)
1722DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 112,  16)
1723DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128,  16)
1724DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256,  16)
1725
1726DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64,  64,  32)
1727DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80,  80,  32)
1728DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96,  96,  32)
1729DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 112,  32)
1730DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128,  32)
1731DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256,  32)
1732
1733DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64,  64,  64)
1734DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80,  80,  64)
1735DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96,  96,  64)
1736DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 112,  64)
1737DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128,  64)
1738DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256,  64)
1739
1740// The number of viable configurations for DeepSeek is very limited:
1741extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16);
1742extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16);
1743extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16);
1744
1745// For GLM 4.7 Flash
1746extern DECL_FATTN_MMA_F16_CASE(576, 512,  4,  4);
1747extern DECL_FATTN_MMA_F16_CASE(576, 512,  8,  4);
1748extern DECL_FATTN_MMA_F16_CASE(576, 512, 16,  4);
1749extern DECL_FATTN_MMA_F16_CASE(576, 512,  1, 32);
1750extern DECL_FATTN_MMA_F16_CASE(576, 512,  2, 32);