1#pragma once
   2
   3#include "common.cuh"
   4#include "vecdotq.cuh"
   5#include "mma.cuh"
   6
   7#include <climits>
   8#include <cstdint>
   9
  10using namespace ggml_cuda_mma;
  11
#define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available.
// Default value returned by get_iter_k(); the load_tiles_* functions also derive threads_per_row from it.
#define MMQ_ITER_K 256
// Value returned by get_iter_k() for GGML_TYPE_MXFP4 when BLACKWELL_MMA_AVAILABLE is defined.
#define MMQ_ITER_K_MXFP4_FP4    512
// NOTE(review): not referenced in the visible part of this file; the warp counts used here come from
//     mmq_get_nwarps_host()/mmq_get_nwarps_device().
#define MMQ_NWARPS 8
  16
// Function pointer type matching the signature of the load_tiles_* functions (e.g. load_tiles_q4_0).
typedef void (*load_tiles_mmq_t)(const char * __restrict__ x, int * x_tile, const int kbx0, const int i_max, const int stride);
// Function pointer type matching the signature of the vec_dot_* functions (e.g. vec_dot_q4_0_q8_1_dp4a).
typedef void (*vec_dot_mmq_t)(const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00);
// Function pointer type for writing accumulated sums to dst.
// NOTE(review): no function of this type is visible in this part of the file; confirm the parameter semantics
//     at the implementations.
typedef void (*mmq_write_back_t)(const float * __restrict__ sum, const int32_t * __restrict__ get_rows_to_sorted,
    float * __restrict__ dst, const int stride, const int i_max, const int j_max);
  21
// Layout of the scale/partial sum data of a block_q8_1_mmq, selected per x type by mmq_get_q8_1_ds_layout().
// The enumerator names mirror the members of the union in block_q8_1_mmq (d4, ds4, d2s6).
enum mmq_q8_1_ds_layout {
    MMQ_Q8_1_DS_LAYOUT_D4,   // see block_q8_1_mmq::d4
    MMQ_Q8_1_DS_LAYOUT_DS4,  // see block_q8_1_mmq::ds4
    MMQ_Q8_1_DS_LAYOUT_D2S6, // see block_q8_1_mmq::d2s6
};
  27
struct block_q8_1_mmq {
    // The y float data is converted to a data layout that can simply be copied to shared memory as a contiguous block.
    // The y float data is first grouped as blocks of 128 values.
    // These blocks are then treated as individual data values and transposed.
    //
    // To avoid shared memory bank conflicts each block is padded with 16 bytes.
    // This padding is also used to store block scales/partial sums.
    // The scales multiplied with the quantized data are equal to the unquantized values.
    // The partial sums are obtained by summing up a subgroup of the contained values (prior to quantization)
    //     and are only needed for performance reasons.
    //
    // The exact data stored depends on the x data type.
    union {
        float d4[4];    // 1 32 bit scale per 32 values, stored as d0,d1,d2,d3
        half2 ds4[4];   // 1 16 bit scale + 1 16 bit partial sum per 32 values, stored as d0,s0,d1,s1,d2,s2,d3,s3
        half  d2s6[8];  // 1 16 bit scale per 64 values + 1 16 bit partial sum per 16 values for the first 96 values,
                        //     i.e. 2 scales followed by 6 partial sums: d0,d1,s0,s1,s2,s3,s4,s5
                        //     NOTE(review): the previous comment listed only 7 entries for this 8-element array;
                        //     confirm the exact order against the code that writes d2s6.
    };
    int8_t qs[4*QK8_1]; // 128 values quantized to 8 bit each
};
  48
// Block of FP4 data with its scales.
// A static_assert in this file requires this struct to have the same size as block_q8_1_mmq.
struct block_fp4_mmq {
    uint32_t d4[4];       // 8 E8M0 scales (1 per 32 values), 2 packed per uint32: d4[0]={s0,s1}, d4[1]={s2,s3}, etc.
    int8_t   qs[4 * 32];  // 256 FP4 values packed as 4-bit pairs (2 per byte), 8 blocks of 32 values
};
  53
// Compile-time layout checks:
//   - block_q8_1_mmq consists of exactly 4*QK8_1 bytes of quantized data plus 4*sizeof(half2) bytes of scale data,
//   - it has the same size as 4 block_q8_1,
//   - block_fp4_mmq has the same size as block_q8_1_mmq.
static_assert(sizeof(block_q8_1_mmq) == 4*QK8_1 + 4*sizeof(half2), "Unexpected block_q8_1_mmq size");
static_assert(sizeof(block_q8_1_mmq) == 4*sizeof(block_q8_1),      "Unexpected block_q8_1_mmq size");
static_assert(sizeof(block_fp4_mmq)  == sizeof(block_q8_1_mmq),    "Unexpected block_fp4_mmq size");
  57
  58static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
  59    switch (type_x) {
  60        case GGML_TYPE_Q4_0:
  61        case GGML_TYPE_Q4_1:
  62            return MMQ_Q8_1_DS_LAYOUT_DS4;
  63        case GGML_TYPE_Q5_0:
  64            return MMQ_Q8_1_DS_LAYOUT_D4;
  65        case GGML_TYPE_Q5_1:
  66            return MMQ_Q8_1_DS_LAYOUT_DS4;
  67        case GGML_TYPE_Q8_0:
  68            return MMQ_Q8_1_DS_LAYOUT_D4;
  69        case GGML_TYPE_MXFP4:
  70            return MMQ_Q8_1_DS_LAYOUT_D4;
  71        case GGML_TYPE_Q2_K:
  72            return MMQ_Q8_1_DS_LAYOUT_D2S6;
  73        case GGML_TYPE_Q3_K:
  74            return MMQ_Q8_1_DS_LAYOUT_D4;
  75        case GGML_TYPE_Q4_K:
  76        case GGML_TYPE_Q5_K:
  77            return MMQ_Q8_1_DS_LAYOUT_DS4;
  78        case GGML_TYPE_Q6_K:
  79        case GGML_TYPE_IQ2_XXS:
  80        case GGML_TYPE_IQ2_XS:
  81        case GGML_TYPE_IQ2_S:
  82        case GGML_TYPE_IQ3_XXS:
  83        case GGML_TYPE_IQ3_S:
  84            return MMQ_Q8_1_DS_LAYOUT_D4;
  85        case GGML_TYPE_IQ1_S:
  86            return MMQ_Q8_1_DS_LAYOUT_DS4;
  87        case GGML_TYPE_IQ4_XS:
  88        case GGML_TYPE_IQ4_NL:
  89            return MMQ_Q8_1_DS_LAYOUT_D4;
  90        default:
  91            GGML_ABORT("fatal error");
  92            break;
  93    }
  94}
  95
// Sizes of the regions of an x tile in the dp4a layout, as returned by mmq_get_dp4a_tile_x_sizes().
struct tile_x_sizes {
    int qs; // size of the quantized data region in ints; the dp4a code paths place the scales at x_qs + qs
    int dm; // NOTE(review): presumably the size of the scale (d/dm) region; not read in the visible part of the file
    int sc; // NOTE(review): presumably the size of an additional scale region; 0 in several MMQ_DP4A_TXS_* macros
};
 101
 102static int get_mmq_x_max_host(const int cc) {
 103    return (amd_mfma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc)) ? 128 :
 104        GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ?
 105#ifdef GGML_CUDA_FORCE_MMQ
 106            128                     : 64;
 107#else
 108            MMQ_DP4A_MAX_BATCH_SIZE : 64;
 109#endif // GGML_CUDA_FORCE_MMQ
 110}
 111
 112static constexpr __device__ int get_mmq_x_max_device() {
 113#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 114    return 128;
 115#else // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 116
 117#if defined(GGML_USE_HIP)
 118    return 64;
 119#else // defined(GGML_USE_HIP)
 120
 121#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
 122#ifdef GGML_CUDA_FORCE_MMQ
 123    return 128;
 124#else // GGML_CUDA_FORCE_MMQ
 125    return MMQ_DP4A_MAX_BATCH_SIZE;
 126#endif // GGML_CUDA_FORCE_MMQ
 127#else // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
 128    return 64;
 129#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
 130
 131#endif // defined(GGML_USE_HIP)
 132#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 133}
 134
 135static int get_mmq_y_host(const int cc) {
 136    return GGML_CUDA_CC_IS_AMD(cc) ? (GGML_CUDA_CC_IS_RDNA1(cc) ? 64 : 128) :
 137        ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ? 128 : 64);
 138}
 139
 140static constexpr __device__ int get_iter_k([[maybe_unused]] const ggml_type type) {
 141#if defined(BLACKWELL_MMA_AVAILABLE)
 142    return type == GGML_TYPE_MXFP4 ? MMQ_ITER_K_MXFP4_FP4 : MMQ_ITER_K;
 143#else
 144    return MMQ_ITER_K;
 145#endif // defined(BLACKWELL_MMA_AVAILABLE)
 146}
 147
 148static constexpr __device__ int get_mmq_y_device() {
 149#if defined(GGML_USE_HIP)
 150#if defined(RDNA1)
 151    return 64;
 152#else
 153    return 128;
 154#endif // defined RDNA1
 155#else
 156#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
 157    return 128;
 158#else
 159    return 64;
 160#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
 161#endif // defined(GGML_USE_HIP)
 162}
 163
// Decouple shared memory tile sizes from WARP_SIZE to allow for different warp sizes.
// In the K dimension the quantized data of a tile (not counting scales) consists of either
// 1*MMQ_TILE_NE_K == 32 (always for TILE_Y_K) or 2*MMQ_TILE_NE_K == 64 (typically for TILE_X_K) 32 bit elements.
// In other words, the size of the quantized data in the K dimension is a multiple of MMQ_TILE_NE_K.
// The final tile size in the K direction is padded to avoid shared memory bank conflicts;
// in terms of 32 bit elements that means K % 2 == 1 for dp4a or K % 8 == 4 for mma.
#define MMQ_TILE_NE_K 32
 172
// Tile sizes for the dp4a layout, one macro per tile format.
// Each macro expands to tile_x_sizes{qs, dm, sc} and requires a variable named mmq_y at the point of expansion
// (see mmq_get_dp4a_tile_x_sizes(), which maps ggml types to these macros).
#define MMQ_DP4A_TXS_Q4_0    tile_x_sizes{mmq_y*MMQ_TILE_NE_K   + mmq_y, mmq_y*MMQ_TILE_NE_K/QI4_0   + mmq_y/QI4_0,     0}
#define MMQ_DP4A_TXS_Q4_1    tile_x_sizes{mmq_y*MMQ_TILE_NE_K   + mmq_y, mmq_y*MMQ_TILE_NE_K/QI4_1   + mmq_y/QI4_1,     0}
#define MMQ_DP4A_TXS_Q8_0    tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K*2/QI8_0 + mmq_y/(QI8_0/2), 0}
#define MMQ_DP4A_TXS_Q8_0_16 tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K*4/QI8_0 + mmq_y/(QI8_0/4), 0}
#define MMQ_DP4A_TXS_Q8_1    tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K*2/QI8_1 + mmq_y/(QI8_1/2), 0}
#define MMQ_DP4A_TXS_Q2_K    tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K         + mmq_y,           0}
#define MMQ_DP4A_TXS_Q3_K    tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y,                                         mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8}
#define MMQ_DP4A_TXS_Q4_K    tile_x_sizes{mmq_y*MMQ_TILE_NE_K   + mmq_y, mmq_y*MMQ_TILE_NE_K/QI4_K,                     mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8}
#define MMQ_DP4A_TXS_Q5_K    tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K/QI5_K   + mmq_y/QI5_K,     mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8}
#define MMQ_DP4A_TXS_Q6_K    tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K/QI6_K   + mmq_y/QI6_K,     mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8}
 183
 184static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) {
 185    switch (type) {
 186        case GGML_TYPE_Q4_0:    return MMQ_DP4A_TXS_Q4_0;
 187        case GGML_TYPE_Q4_1:    return MMQ_DP4A_TXS_Q4_1;
 188        case GGML_TYPE_Q5_0:    return MMQ_DP4A_TXS_Q8_0;
 189        case GGML_TYPE_Q5_1:    return MMQ_DP4A_TXS_Q8_1;
 190        case GGML_TYPE_Q8_0:    return MMQ_DP4A_TXS_Q8_0;
 191        case GGML_TYPE_MXFP4:   return MMQ_DP4A_TXS_Q8_1;
 192        case GGML_TYPE_Q2_K:    return MMQ_DP4A_TXS_Q2_K;
 193        case GGML_TYPE_Q3_K:    return MMQ_DP4A_TXS_Q3_K;
 194        case GGML_TYPE_Q4_K:    return MMQ_DP4A_TXS_Q4_K;
 195        case GGML_TYPE_Q5_K:    return MMQ_DP4A_TXS_Q5_K;
 196        case GGML_TYPE_Q6_K:    return MMQ_DP4A_TXS_Q6_K;
 197        case GGML_TYPE_IQ2_XXS: return MMQ_DP4A_TXS_Q8_0;
 198        case GGML_TYPE_IQ2_XS:  return MMQ_DP4A_TXS_Q8_0_16;
 199        case GGML_TYPE_IQ2_S:   return MMQ_DP4A_TXS_Q8_0_16;
 200        case GGML_TYPE_IQ3_XXS: return MMQ_DP4A_TXS_Q8_0;
 201        case GGML_TYPE_IQ3_S:   return MMQ_DP4A_TXS_Q8_0;
 202        case GGML_TYPE_IQ1_S:   return MMQ_DP4A_TXS_Q8_0;
 203        case GGML_TYPE_IQ4_XS:  return MMQ_DP4A_TXS_Q8_0;
 204        case GGML_TYPE_IQ4_NL:  return MMQ_DP4A_TXS_Q8_0;
 205        default:                return tile_x_sizes{0, 0, 0};
 206    }
 207}
 208
// Size of one row of an x tile, in 32 bit elements, for the code paths that are compiled when
// AMD_MFMA_AVAILABLE, TURING_MMA_AVAILABLE or AMD_WMMA_AVAILABLE is defined.
// Each size starts with 2*MMQ_TILE_NE_K elements of quantized data; the remaining terms cover the scale data and padding.
#define MMQ_MMA_TILE_X_K_Q8_0 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0                   + 4)
#define MMQ_MMA_TILE_X_K_FP4  (2*MMQ_TILE_NE_K + 8                                       + 4)
#define MMQ_MMA_TILE_X_K_Q8_1 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0                   + 4)
#define MMQ_MMA_TILE_X_K_Q2_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K                           + 4)
#define MMQ_MMA_TILE_X_K_Q3_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2                         + 4)
#define MMQ_MMA_TILE_X_K_Q6_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI6_K   + MMQ_TILE_NE_K/8 + 7)

// Every row size must satisfy K % 8 == 4 (the padding rule for mma stated at MMQ_TILE_NE_K).
static_assert(MMQ_MMA_TILE_X_K_Q8_0 % 8 == 4, "Wrong padding.");
static_assert(MMQ_MMA_TILE_X_K_Q8_1 % 8 == 4, "Wrong padding.");
static_assert(MMQ_MMA_TILE_X_K_Q2_K % 8 == 4, "Wrong padding.");
static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding.");
static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding.");
static_assert(MMQ_MMA_TILE_X_K_FP4  % 8 == 4, "Wrong padding.");
// mmq_get_mma_tile_x_k() returns the Q8_1 size for GGML_TYPE_MXFP4, so the two sizes must agree.
static_assert(MMQ_MMA_TILE_X_K_FP4 == MMQ_MMA_TILE_X_K_Q8_1, "Wrong tile size for MXFP4");
 223
 224static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
 225    switch (type) {
 226        case GGML_TYPE_Q4_0:    return MMQ_MMA_TILE_X_K_Q8_0;
 227        case GGML_TYPE_Q4_1:    return MMQ_MMA_TILE_X_K_Q8_1;
 228        case GGML_TYPE_Q5_0:    return MMQ_MMA_TILE_X_K_Q8_0;
 229        case GGML_TYPE_Q5_1:    return MMQ_MMA_TILE_X_K_Q8_1;
 230        case GGML_TYPE_Q8_0:    return MMQ_MMA_TILE_X_K_Q8_0;
 231        // tile sizes are the same for Q8_1 and FP4 for blackwell
 232        case GGML_TYPE_MXFP4:   return MMQ_MMA_TILE_X_K_Q8_1;
 233        case GGML_TYPE_Q2_K:    return MMQ_MMA_TILE_X_K_Q2_K;
 234        case GGML_TYPE_Q3_K:    return MMQ_MMA_TILE_X_K_Q3_K;
 235        case GGML_TYPE_Q4_K:    return MMQ_MMA_TILE_X_K_Q8_1;
 236        case GGML_TYPE_Q5_K:    return MMQ_MMA_TILE_X_K_Q8_1;
 237        case GGML_TYPE_Q6_K:    return MMQ_MMA_TILE_X_K_Q6_K;
 238        case GGML_TYPE_IQ2_XXS: return MMQ_MMA_TILE_X_K_Q8_0;
 239        case GGML_TYPE_IQ2_XS:  return MMQ_MMA_TILE_X_K_Q3_K;
 240        case GGML_TYPE_IQ2_S:   return MMQ_MMA_TILE_X_K_Q3_K;
 241        case GGML_TYPE_IQ3_XXS: return MMQ_MMA_TILE_X_K_Q8_0;
 242        case GGML_TYPE_IQ3_S:   return MMQ_MMA_TILE_X_K_Q8_0;
 243        case GGML_TYPE_IQ1_S:   return MMQ_MMA_TILE_X_K_Q8_0;
 244        case GGML_TYPE_IQ4_XS:  return MMQ_MMA_TILE_X_K_Q8_0;
 245        case GGML_TYPE_IQ4_NL:  return MMQ_MMA_TILE_X_K_Q8_0;
 246        default:                return 0;
 247    }
 248}
 249
// Size of one row of the y tile in 32 bit elements.
// block_q8_1_mmq has (128 8-bit ints == 32 32-bit ints + 4 32-bit scales)
#define MMQ_TILE_Y_K     (MMQ_TILE_NE_K + MMQ_TILE_NE_K / QI8_1)
// The FP4 y tile uses the same row size (block_fp4_mmq and block_q8_1_mmq are asserted to have equal size).
#define MMQ_TILE_Y_FP4_K MMQ_TILE_Y_K
 253
 254static int mmq_get_granularity_host(const int mmq_x, const int cc) {
 255    if (amd_mfma_available(cc) || amd_wmma_available(cc)) {
 256        return mmq_x >= 128 ? 32 : 16;
 257    } else if (turing_mma_available(cc) && mmq_x >= 48) {
 258        return 16;
 259    } else {
 260        return 8;
 261    }
 262}
 263
 264#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 265static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) {
 266    return mmq_x >= 128 ? 32 : 16;
 267}
 268#elif defined(TURING_MMA_AVAILABLE)
 269static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) {
 270    return mmq_x >= 48 ? 16 : 8;
 271}
 272#else
 273static constexpr __device__ int mmq_get_granularity_device(const int /*mmq_x*/) {
 274    return 8;
 275}
 276#endif // AMD_MFMA_AVAILABLE
 277
 278#if defined(GGML_USE_HIP)
 279static int mmq_get_nwarps_host(const int cc, const int warp_size) {
 280    return amd_mfma_available(cc) ? 8 : 256/warp_size;
 281}
 282#else
 283static int mmq_get_nwarps_host(const int /*cc*/, const int warp_size) {
 284    return 256/warp_size;
 285}
 286#endif // (GGML_USE_HIP)
 287
 288static constexpr __device__ int mmq_get_nwarps_device() {
 289#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 290    return 8;
 291#else
 292    return 256/ggml_cuda_get_physical_warp_size();
 293#endif // AMD_MFMA_AVAILABLE
 294}
 295
 296// ------------------------------------------------------------
 297
 298template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
 299    const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
 300    constexpr int nwarps = mmq_get_nwarps_device();
 301    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 302
 303#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 304    int   * x_qs = (int   *)  x_tile;
 305    float * x_df = (float *) (x_qs + 2*MMQ_TILE_NE_K);
 306#else
 307    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
 308    int   * x_qs = (int   *)  x_tile;
 309    float * x_df = (float *) (x_qs + txs.qs);
 310#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 311
 312    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_0);
 313    constexpr int nrows = warp_size / threads_per_row;
 314    const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
 315    const int kbx  = txi / QI4_0;
 316    const int kqsx = txi % QI4_0;
 317
 318#pragma unroll
 319    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
 320        int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
 321
 322        if (need_check) {
 323            i = min(i, i_max);
 324        }
 325
 326        const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx;
 327        const int qs0 = get_int_b2(bxi->qs, kqsx);
 328
 329#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 330        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + 0]     = __vsubss4((qs0 >> 0) & 0x0F0F0F0F, 0x08080808);
 331        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + QI4_0] = __vsubss4((qs0 >> 4) & 0x0F0F0F0F, 0x08080808);
 332#else
 333        x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0;
 334#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 335    }
 336
 337    constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_0;
 338    constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
 339    const int kbxd = threadIdx.x % blocks_per_tile_x_row;
 340
 341#pragma unroll
 342    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
 343        int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;
 344
 345        if (need_check) {
 346            i = min(i, i_max);
 347        }
 348
 349        const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbxd;
 350
 351#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 352        x_df[i*MMQ_MMA_TILE_X_K_Q8_0           + kbxd] = bxi->d;
 353#else
 354        x_df[i*(MMQ_TILE_NE_K/QI4_0) + i/QI4_0 + kbxd] = bxi->d;
 355#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 356    }
 357}
 358
 359template <int mmq_x, int mmq_y>
 360static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
 361    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
 362    constexpr int nwarps = mmq_get_nwarps_device();
 363    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 364
 365    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
 366    const int   * x_qs = (const int   *) x;
 367    const float * x_df = (const float *) x_qs + txs.qs;
 368    const int   * y_qs = (const int   *) y + 4;
 369    const half2 * y_ds = (const half2 *) y;
 370
 371// #pragma unroll
 372    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR4_0*VDR_Q4_0_Q8_1_MMQ) {
 373        const int k0 = k00 + k01;
 374
 375#pragma unroll
 376        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
 377            const int j = j0 + threadIdx.y;
 378
 379#pragma unroll
 380            for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
 381                const int i = i0 + threadIdx.x;
 382
 383                const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2);
 384
 385                int u[2*VDR_Q4_0_Q8_1_MMQ];
 386
 387#pragma unroll
 388                for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) {
 389                    u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + kyqs +  l];
 390                    u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_0)];
 391                }
 392
 393                sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>
 394                    (&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/QR4_0], u,
 395                     x_df[i*(MMQ_TILE_NE_K/QI4_0) + i/QI4_0 + k0/(QR4_0*QI4_0)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
 396            }
 397        }
 398    }
 399}
 400
 401template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
 402    const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
 403    constexpr int nwarps = mmq_get_nwarps_device();
 404    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 405
 406#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 407    int   * x_qs = (int   *)  x_tile;
 408    half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
 409#else
 410    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
 411    int   * x_qs = (int   *)  x_tile;
 412    half2 * x_dm = (half2 *) (x_qs + txs.qs);
 413#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)  || defined(AMD_WMMA_AVAILABLE)
 414
 415    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_1);
 416    constexpr int nrows = warp_size / threads_per_row;
 417    const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
 418    const int kbx  = txi / QI4_1;
 419    const int kqsx = txi % QI4_1;
 420
 421#pragma unroll
 422    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
 423        int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
 424
 425        if (need_check) {
 426            i = min(i, i_max);
 427        }
 428
 429        const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx;
 430        const int qs0 = get_int_b4(bxi->qs, kqsx);
 431
 432#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 433        x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + 0]     = (qs0 >> 0) & 0x0F0F0F0F;
 434        x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + QI4_1] = (qs0 >> 4) & 0x0F0F0F0F;
 435#else
 436        x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0;
 437#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 438    }
 439
 440    constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_1;
 441    constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
 442    const int kbxd = threadIdx.x % blocks_per_tile_x_row;
 443
 444#pragma unroll
 445    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
 446        int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;
 447
 448        if (need_check) {
 449            i = min(i, i_max);
 450        }
 451
 452        const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbxd;
 453
 454#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 455        x_dm[i*MMQ_MMA_TILE_X_K_Q8_1           + kbxd] = bxi->dm;
 456#else
 457        x_dm[i*(MMQ_TILE_NE_K/QI4_1) + i/QI4_1 + kbxd] = bxi->dm;
 458#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 459    }
 460}
 461
 462template <int mmq_x, int mmq_y>
 463static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
 464    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
 465    constexpr int nwarps = mmq_get_nwarps_device();
 466    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 467
 468    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
 469    const int   * x_qs = (const int   *) x;
 470    const half2 * x_dm = (const half2 *) x_qs + txs.qs;
 471    const int   * y_qs = (const int   *) y + 4;
 472    const half2 * y_ds = (const half2 *) y;
 473
 474// #pragma unroll
 475    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR4_1*VDR_Q4_1_Q8_1_MMQ) {
 476        const int k0 = k00 + k01;
 477
 478#pragma unroll
 479        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
 480            const int j = j0 + threadIdx.y;
 481
 482#pragma unroll
 483            for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
 484                const int i = i0 + threadIdx.x;
 485
 486                const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2);
 487
 488                int u[2*VDR_Q4_1_Q8_1_MMQ];
 489
 490#pragma unroll
 491                for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) {
 492                    u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + kyqs +  l];
 493                    u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_1)];
 494                }
 495
 496                sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>
 497                    (&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/QR4_1], u,
 498                     x_dm[i*(MMQ_TILE_NE_K/QI4_1) + i/QI4_1 + k0/(QR4_1*QI4_1)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
 499            }
 500        }
 501    }
 502}
 503
 504template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q5_0(
 505    const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
 506    constexpr int nwarps = mmq_get_nwarps_device();
 507    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 508
 509#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 510    int   * x_qs = (int   *)  x_tile;
 511    float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
 512#else
 513    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y);
 514    int   * x_qs = (int   *)  x_tile;
 515    float * x_df = (float *) (x_qs + txs.qs);
 516#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 517
 518    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_0);
 519    constexpr int nrows = warp_size / threads_per_row;
 520    const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
 521    const int kbx  = txi / QI5_0;
 522    const int kqsx = txi % QI5_0;
 523
 524#pragma unroll
 525    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
 526        int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
 527
 528        if (need_check) {
 529            i = min(i, i_max);
 530        }
 531
 532        const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbx;
 533
 534        const int ql = get_int_b2(bxi->qs, kqsx);
 535        const int qh = get_int_b2(bxi->qh, 0) >> (4 * kqsx);
 536
 537        int qs0 = (ql >>  0)   & 0x0F0F0F0F;
 538        qs0    |= (qh <<  4)   & 0x00000010;  // 0 ->  4
 539        qs0    |= (qh << 11)   & 0x00001000;  // 1 -> 12
 540        qs0    |= (qh << 18)   & 0x00100000;  // 2 -> 20
 541        qs0    |= (qh << 25)   & 0x10000000;  // 3 -> 28
 542        qs0     = __vsubss4(qs0, 0x10101010); // subtract 16
 543
 544        int qs1 = (ql >>  4)   & 0x0F0F0F0F;
 545        qs1    |= (qh >> 12)   & 0x00000010;  // 16 ->  4
 546        qs1    |= (qh >>  5)   & 0x00001000;  // 17 -> 12
 547        qs1    |= (qh <<  2)   & 0x00100000;  // 18 -> 20
 548        qs1    |= (qh <<  9)   & 0x10000000;  // 19 -> 28
 549        qs1     = __vsubss4(qs1, 0x10101010); // subtract 16
 550
 551#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 552        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + 0]     = qs0;
 553        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
 554#else
 555        x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + kqsx + 0]     = qs0;
 556        x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
 557#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 558    }
 559
 560    constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_0;
 561    constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
 562    const int kbxd = threadIdx.x % blocks_per_tile_x_row;
 563
 564#pragma unroll
 565    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
 566        int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;
 567
 568        if (need_check) {
 569            i = min(i, i_max);
 570        }
 571
 572        const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd;
 573
 574#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 575        x_df[i*MMQ_MMA_TILE_X_K_Q8_0           + kbxd] = bxi->d;
 576#else
 577        x_df[i*(MMQ_TILE_NE_K/QI5_0) + i/QI5_0 + kbxd] = bxi->d;
 578#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)  || defined(AMD_WMMA_AVAILABLE)
 579    }
 580}
 581
 582template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(
 583    const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
 584    constexpr int nwarps = mmq_get_nwarps_device();
 585    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 586
 587#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 588    int   * x_qs = (int   *)  x_tile;
 589    half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
 590#else
 591    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
 592    int   * x_qs = (int   *)  x_tile;
 593    half2 * x_dm = (half2 *) (x_qs + txs.qs);
 594#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 595
 596    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_1);
 597    constexpr int nrows = warp_size / threads_per_row;
 598    const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
 599    const int kbx  = txi / QI5_1;
 600    const int kqsx = txi % QI5_1;
 601
 602#pragma unroll
 603    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
 604        int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
 605
 606        if (need_check) {
 607            i = min(i, i_max);
 608        }
 609
 610        const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbx;
 611
 612        const int ql = get_int_b4(bxi->qs, kqsx);
 613        const int qh = get_int_b4(bxi->qh, 0) >> (4 * kqsx);
 614
 615        int qs0 = (ql >>  0) & 0x0F0F0F0F;
 616        qs0    |= (qh <<  4) & 0x00000010; // 0 ->  4
 617        qs0    |= (qh << 11) & 0x00001000; // 1 -> 12
 618        qs0    |= (qh << 18) & 0x00100000; // 2 -> 20
 619        qs0    |= (qh << 25) & 0x10000000; // 3 -> 28
 620
 621        int qs1 = (ql >>  4) & 0x0F0F0F0F;
 622        qs1    |= (qh >> 12) & 0x00000010; // 16 ->  4
 623        qs1    |= (qh >>  5) & 0x00001000; // 17 -> 12
 624        qs1    |= (qh <<  2) & 0x00100000; // 18 -> 20
 625        qs1    |= (qh <<  9) & 0x10000000; // 19 -> 28
 626
 627#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 628        x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + 0]     = qs0;
 629        x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
 630#else
 631        x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + 0]     = qs0;
 632        x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
 633#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 634    }
 635
 636    constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_1;
 637    constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
 638    const int kbxd = threadIdx.x % blocks_per_tile_x_row;
 639
 640#pragma unroll
 641    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
 642        int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;
 643
 644        if (need_check) {
 645            i = min(i, i_max);
 646        }
 647
 648        const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd;
 649
 650#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 651        x_dm[i*MMQ_MMA_TILE_X_K_Q8_1           + kbxd] = bxi->dm;
 652#else
 653        x_dm[i*(MMQ_TILE_NE_K/QI5_1) + i/QI5_1 + kbxd] = bxi->dm;
 654#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 655    }
 656}
 657
 658template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
 659    const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
 660    constexpr int nwarps = mmq_get_nwarps_device();
 661    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 662
 663#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 664    int   * x_qs = (int   *)  x_tile;
 665    float * x_df = (float *) (x_tile + 2*MMQ_TILE_NE_K);
 666#else
 667    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
 668    int   * x_qs = (int   *)  x_tile;
 669    float * x_df = (float *) (x_qs + txs.qs);
 670#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 671
 672    // MMQ_ITER_K / (4 * QR8_0) == 64 required. but NV has only 32 threads per warp
 673    constexpr int threads_per_row = 32;
 674    constexpr int nrows = warp_size / threads_per_row;
 675    const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
 676    const int kbx  = txi / QI8_0;
 677    const int kqsx = txi % QI8_0;
 678
 679#pragma unroll
 680    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
 681        int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
 682
 683        if (need_check) {
 684            i = min(i, i_max);
 685        }
 686
 687        const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx;
 688
 689#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 690        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 0             + txi] = get_int_b2(bxi[0].qs,                   kqsx);
 691        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx);
 692#else
 693        x_qs[i*(2*MMQ_TILE_NE_K + 1) + 0             + txi] = get_int_b2(bxi[0].qs,                   kqsx);
 694        x_qs[i*(2*MMQ_TILE_NE_K + 1) + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx);
 695#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 696    }
 697
 698    constexpr int blocks_per_tile_x_row = 2*MMQ_TILE_NE_K / QI8_0;
 699    constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
 700    const int kbxd = threadIdx.x % blocks_per_tile_x_row;
 701
 702#pragma unroll
 703    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
 704        int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;
 705
 706        if (need_check) {
 707            i = min(i, i_max);
 708        }
 709
 710        const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd;
 711
 712#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 713        x_df[i*MMQ_MMA_TILE_X_K_Q8_0                 + kbxd] = bxi->d;
 714#else
 715        x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + kbxd] = bxi->d;
 716#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 717    }
 718}
 719
 720template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_mxfp4(
 721    const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
 722    constexpr int nwarps = mmq_get_nwarps_device();
 723    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 724
 725#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 726    int   * x_qs = (int   *)  x_tile;
 727    float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
 728#else
 729    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_MXFP4, mmq_y);
 730    int   * x_qs = (int   *)  x_tile;
 731    float * x_df = (float *) (x_qs + txs.qs);
 732#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 733
 734    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR_MXFP4);
 735    constexpr int nrows = warp_size / threads_per_row;
 736    const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
 737    const int kbx  = txi / QI_MXFP4;
 738    const int kqsx = txi % QI_MXFP4;
 739
 740#pragma unroll
 741    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
 742        int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
 743
 744        if (need_check) {
 745            i = min(i, i_max);
 746        }
 747
 748        const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i*stride + kbx;
 749
 750        const int aux_q4 = get_int_b1(bxi->qs, kqsx);
 751        const int2 v = get_int_from_table_16(aux_q4, kvalues_mxfp4);
 752        const int k0 = kbx * (2 * QI_MXFP4) + kqsx;
 753
 754#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 755        x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + 0]        = v.x;
 756        x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + QI_MXFP4] = v.y;
 757#else
 758        x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0]        = v.x;
 759        x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + QI_MXFP4] = v.y;
 760#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)  || defined(AMD_WMMA_AVAILABLE)
 761    }
 762
 763    constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI_MXFP4;
 764    constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
 765    const int kbxd = threadIdx.x % blocks_per_tile_x_row;
 766
 767#pragma unroll
 768    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
 769        int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;
 770
 771        if (need_check) {
 772            i = min(i, i_max);
 773        }
 774
 775        const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i*stride + kbxd;
 776
 777#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 778        x_df[i*MMQ_MMA_TILE_X_K_Q8_1                 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f;
 779#else
 780        x_df[i*(MMQ_TILE_NE_K/QI_MXFP4) + i/QI_MXFP4 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f;
 781#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 782    }
 783}
 784
 785template <int mmq_y, bool need_check>
 786static __device__ __forceinline__ void load_tiles_mxfp4_fp4(const char * __restrict__ x,
 787                                                            int * __restrict__ x_tile,
 788                                                            const int kbx0,
 789                                                            const int i_max,
 790                                                            const int stride) {
 791    constexpr int nwarps = mmq_get_nwarps_device();
 792    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 793
 794    int *      x_qs = (int *) x_tile;
 795    uint32_t * x_sc = (uint32_t *) (x_qs + 2 * MMQ_TILE_NE_K);
 796
 797    const int txi = threadIdx.x;
 798
 799    constexpr int iter_k = get_iter_k(GGML_TYPE_MXFP4);
 800
 801    constexpr int threads_per_row = iter_k / QK_MXFP4;  // each thread processes 1 block
 802    constexpr int rows_per_warp   = warp_size / threads_per_row;
 803    const int     kbx             = txi % threads_per_row;
 804    const int     row_in_warp     = txi / threads_per_row;
 805
 806#pragma unroll
 807    for (int i0 = 0; i0 < mmq_y; i0 += rows_per_warp * nwarps) {
 808        int i = i0 + threadIdx.y * rows_per_warp + row_in_warp;
 809
 810        if constexpr (need_check) {
 811            i = min(i, i_max);
 812        }
 813
 814        const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i * stride + kbx;
 815
 816        // quantize_mxfp4_mmq permutes nibbles to match the quantized format
 817        const int k0 = kbx * 4;
 818        memcpy(x_qs + i * MMQ_MMA_TILE_X_K_FP4 + k0, bxi->qs, 16);
 819
 820        // Load E8M0 scales: pack 2 consecutive scales into one uint32
 821        if (kbx % 2 == 0) {
 822            uint32_t e = bxi->e;
 823            e |= ((bxi + 1)->e << 8);
 824            x_sc[i * MMQ_MMA_TILE_X_K_FP4 + kbx / 2] = e;
 825        }
 826    }
 827}
 828
 829template <int mmq_x, int mmq_y>
 830static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
 831    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
 832    constexpr int nwarps = mmq_get_nwarps_device();
 833    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 834
 835    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
 836    const int   * x_qs = (const int   *) x;
 837    const float * x_df = (const float *) x_qs + txs.qs;
 838    const int   * y_qs = (const int   *) y + 4;
 839    const float * y_df = (const float *) y;
 840
 841// #pragma unroll
 842    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += VDR_Q8_0_Q8_1_MMQ) {
 843        const int k0 = k00 + k01;
 844
 845#pragma unroll
 846        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
 847            const int j = j0 + threadIdx.y;
 848
 849#pragma unroll
 850            for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
 851                const int i = i0 + threadIdx.x;
 852
 853                sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q8_0_q8_1_impl<float, VDR_Q8_0_Q8_1_MMQ>
 854                    (&x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k0 % MMQ_TILE_NE_K],
 855                     x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + k0/QI8_0], y_df[j*MMQ_TILE_Y_K + (k0/QI8_1) % (MMQ_TILE_NE_K/QI8_1)]);
 856            }
 857        }
 858    }
 859}
 860
// Multiplies an x tile in the q8_0 MMA layout (quantized ints plus one float scale per QI8_0 ints) with a q8_1 y tile
//     and adds the scaled results to sum, using the tile primitives (load_generic, load_ldmatrix, mma).
//   x:   x tile, the quantized ints come first with MMQ_MMA_TILE_X_K_Q8_0 ints between rows,
//        the float scales are read starting 2*MMQ_TILE_NE_K ints into the tile with the same distance between rows
//   y:   y tile with MMQ_TILE_Y_K ints between columns, the quantized ints start 4 ints into the tile,
//        the scales are read from the tile start
//   sum: per-thread accumulators, indexed as (j0/tile_C::J + n)*tile_C::ne + l
//   k00: offset of the current k slice within the x tile (y is indexed with the local offset k01 only)
// ds_layout selects how the y scale is read: as a float for MMQ_Q8_1_DS_LAYOUT_D4, otherwise as the low half of a half2.
template <int mmq_x, int mmq_y, mmq_q8_1_ds_layout ds_layout>
static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
    // AMD path: A tiles are reloaded for every k step and the x scales are read inside the accumulation loop.
    constexpr data_layout input_layout = get_input_data_layout();
    typedef tile<16,  8, int, input_layout>        tile_A;
    typedef tile<16,  8, int, input_layout>        tile_B;
    typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;

    constexpr int granularity = mmq_get_granularity_device(mmq_x);
    constexpr int rows_per_warp = granularity;
    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.

    // Warps that share the same x rows start at different y columns.
    y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);

    const int   * x_qs = (const int   *) x;
    const float * x_df = (const float *) x_qs + 2*MMQ_TILE_NE_K;
    const int   * y_qs = (const int   *) y + 4;
    const float * y_df = (const float *) y;
    const half2 * y_ds = (const half2 *) y;

    // First x row handled by this warp.
    const int i0 = (threadIdx.y / ntx) * rows_per_warp;

    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) {
        const int k0 = k00 + k01;

        tile_A A[ntx];
#pragma unroll
        for (int n = 0; n < ntx; ++n) {
            load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
        }

#pragma unroll
        for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
            tile_B B;
            load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);

            // A single y scale per thread, taken for the column of accumulator element 0.
            float dB;
            const int j = j0 + tile_C::get_j(0);
            if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) {
                dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
            } else {
                dB = __low2float(y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
            }

#pragma unroll
            for (int n = 0; n < ntx; ++n) {
                tile_C C;
                mma(C, A[n], B);

#pragma unroll
                for (int l = 0; l < tile_C::ne; ++l) {
                    // The x scale depends on the row of accumulator element l.
                    const int i = i0 + n*tile_A::I + tile_C::get_i(l);
                    const float dA = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0];
                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l]*dA*dB;
                }
            }
        }
    }
#else
    // Other targets: all A tiles and x scales for the whole k range are loaded once up front,
    //     afterwards only B tiles and y scales are loaded per (j0, k01) step.
    typedef tile<16, 8, int> tile_A;
    typedef tile< 8, 8, int> tile_B;
    typedef tile<16, 8, int> tile_C;

    constexpr int granularity = mmq_get_granularity_device(mmq_x);
    constexpr int rows_per_warp = 2 * granularity;
    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.

    // Warps that share the same x rows start at different y columns.
    y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);

    const int   * x_qs = (const int   *) x;
    const float * x_df = (const float *) x_qs + 2*MMQ_TILE_NE_K;
    const int   * y_qs = (const int   *) y + 4;
    const float * y_df = (const float *) y;
    const half2 * y_ds = (const half2 *) y;

    // One A tile and one x scale per k step; the scales are stored once per pair of accumulator elements (l/2).
    tile_A A[ntx][MMQ_TILE_NE_K/QI8_0];
    float dA[ntx][tile_C::ne/2][MMQ_TILE_NE_K/QI8_0];

    // First x row handled by this warp.
    const int i0 = (threadIdx.y/ntx)*rows_per_warp;

#pragma unroll
    for (int n = 0; n < ntx; ++n) {
#pragma unroll
        for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) {
            const int k0 = k00 + k01;

            load_ldmatrix(A[n][k01/QI8_0], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
        }

#pragma unroll
        for (int l = 0; l < tile_C::ne/2; ++l) {
            // Row of accumulator element 2*l; the same scale is later applied to elements 2*l and 2*l + 1.
            const int i = i0 + n*tile_A::I + tile_C::get_i(2*l);

#pragma unroll
            for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) {
                const int k0 = k00 + k01;

                dA[n][l][k01/QI8_0] = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0];
            }
        }
    }

#pragma unroll
    for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
#pragma unroll
        for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) {
            tile_B B;
            float dB[tile_C::ne/2];

            load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix

#pragma unroll
            for (int l = 0; l < tile_C::ne/2; ++l) {
                const int j = j0 + tile_C::get_j(l);

                if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) {
                    dB[l] =             y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
                } else {
                    dB[l] = __low2float(y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
                }
            }

#pragma unroll
            for (int n = 0; n < ntx; ++n) {
                tile_C C;
                mma(C, A[n][k01/QI8_0], B);

#pragma unroll
                for (int l = 0; l < tile_C::ne; ++l) {
                    // x scale selected by l/2, y scale selected by l%2.
                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l]*dA[n][l/2][k01/QI8_0]*dB[l%2];
                }
            }
        }
    }
#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
}
 998
 999template <int mmq_x, int mmq_y>
1000static __device__ __forceinline__ void vec_dot_mxfp4_mxfp4_mma(const int * __restrict__ x,
1001                                                               const int * __restrict__ y,
1002                                                               float * __restrict__ sum,
1003                                                               const int k00) {
1004    typedef tile<16, 8, int>   tile_A;
1005    typedef tile<8, 8, int>    tile_B;
1006    typedef tile<16, 8, float> tile_C;  // Output is float for native scaled MMA
1007
1008    constexpr int granularity   = mmq_get_granularity_device(mmq_x);
1009    constexpr int rows_per_warp = 2 * granularity;
1010    constexpr int ntx           = rows_per_warp / tile_C::I;  // Number of x minitiles per warp.
1011
1012    y += (threadIdx.y % ntx) * (tile_C::J * MMQ_TILE_Y_FP4_K);
1013
1014    // Match layout from load_tiles_mxfp4_fp4
1015    const int *      x_qs = (const int *) x;
1016    const uint32_t * x_sc = (const uint32_t *) (x_qs + 2 * MMQ_TILE_NE_K);
1017    const int *      y_qs = (const int *) y + 4;
1018    const uint32_t * y_sc = (const uint32_t *) y;
1019
1020    // tile_A has a length of 64 logical values vs. 32 values in block_mxfp4
1021    tile_A   A[ntx][MMQ_TILE_NE_K / (2 * QI_MXFP4)];
1022    uint32_t scaleA[ntx][MMQ_TILE_NE_K / (2 * QI_MXFP4)];
1023
1024    // Block scale
1025    // Each thread has to point to a 4 byte scale value
1026    // https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-block-scaling
1027
1028    const int i0 = (threadIdx.y / ntx) * rows_per_warp;
1029
1030#pragma unroll
1031    for (int n = 0; n < ntx; ++n) {
1032#pragma unroll
1033        for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 2 * QI_MXFP4) {
1034            const int k0 = k00 + k01;
1035
1036            load_ldmatrix(A[n][k01 / (2 * QI_MXFP4)], x_qs + (i0 + n * tile_A::I) * MMQ_MMA_TILE_X_K_FP4 + k0,
1037                          MMQ_MMA_TILE_X_K_FP4);
1038
1039            // based on block-scaling document, 2 threads in each quad need to supply to the scale value
1040            const int tidx         = threadIdx.x / 4 + (threadIdx.x % 2) * 8;
1041            scaleA[n][k01 / (2 * QI_MXFP4)] =
1042                *(x_sc + (i0 + n * tile_A::I + tidx) * MMQ_MMA_TILE_X_K_FP4 + k0 / (2 * QI_MXFP4));
1043        }
1044    }
1045
1046#pragma unroll
1047    for (int j0 = 0; j0 < mmq_x; j0 += ntx * tile_C::J) {
1048#pragma unroll
1049        for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 2 * QI_MXFP4) {
1050            tile_B   B;
1051            uint32_t scaleB;  // 2xN scales
1052
1053            load_generic(B, y_qs + j0 * MMQ_TILE_Y_FP4_K + k01, MMQ_TILE_Y_FP4_K);
1054
1055            scaleB = y_sc[(j0 + threadIdx.x / 4) * MMQ_TILE_Y_FP4_K + k01 / (2 * QI_MXFP4)];
1056
1057#pragma unroll
1058            for (int n = 0; n < ntx; ++n) {
1059                tile_C C;
1060
1061                mma_block_scaled(C, A[n][k01 / (2 * QI_MXFP4)], B, scaleA[n][k01 / (2 * QI_MXFP4)], scaleB);
1062#pragma unroll
1063                for (int l = 0; l < tile_C::ne; ++l) {
1064                    sum[(j0 / tile_C::J + n) * tile_C::ne + l] += C.x[l];
1065                }
1066            }
1067        }
1068    }
1069}
1070
// Accumulates the dot products between an x tile in the q8_1 dp4a layout and a q8_1 y tile into sum.
// Each thread handles the x rows i = i0 + threadIdx.x and the y columns j = j0 + threadIdx.y.
//   x:   x tile, quantized ints followed by the half2 dm values at offset txs.qs
//   y:   y tile, the quantized ints start 4 ints into the tile, the half2 ds values are read from the tile start
//   sum: per-thread accumulators, one per (j0, i0) combination
//   k00: offset of the current k slice within the x tile (y is indexed with the local offset k01 only)
template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a(
    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
    constexpr int nwarps    = mmq_get_nwarps_device();
    constexpr int warp_size = ggml_cuda_get_physical_warp_size();

    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);

    const int   * x_qs = (const int   *) x;
    const half2 * x_dm = (const half2 *) x_qs + txs.qs;
    const int   * y_qs = (const int   *) y + 4;
    const half2 * y_ds = (const half2 *) y;

// #pragma unroll
    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += VDR_Q8_0_Q8_1_MMQ) {
        const int k0 = k00 + k01;

        const int kx_dm = k0/QI8_1;
        const int ky_ds = k01/QI8_1;

#pragma unroll
        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
            const int j = j0 + threadIdx.y;

            const int * y_qs_jk    = &y_qs[j*MMQ_TILE_Y_K + k01];
            const int   y_ds_index = j*MMQ_TILE_Y_K + ky_ds;

#pragma unroll
            for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
                const int i = i0 + threadIdx.x;

                const int * x_qs_ik    = &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0];
                const int   x_dm_index = i*(MMQ_TILE_NE_K/QI5_1) + i/QI5_1 + kx_dm;

                sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q8_1_q8_1_impl<QR5_1*VDR_Q5_1_Q8_1_MMQ>
                    (x_qs_ik, y_qs_jk, x_dm[x_dm_index], y_ds[y_ds_index]);
            }
        }
    }
}
1102
// Multiplies an x tile in the q8_1 MMA layout (quantized ints plus one half2 dm value per QI8_1 ints) with a q8_1 y tile
//     and adds the results to sum, using the tile primitives (load_generic, load_ldmatrix, mma).
// For each accumulator element the contribution is dmA.x*dsB.x*C + dmA.y*dsB.y, where C is the integer mma result;
//     the second term does not depend on the quantized values.
//   x:   x tile, the quantized ints come first with MMQ_MMA_TILE_X_K_Q8_1 ints between rows,
//        the half2 dm values are read starting 2*MMQ_TILE_NE_K ints into the tile with the same distance between rows
//   y:   y tile with MMQ_TILE_Y_K ints between columns, the quantized ints start 4 ints into the tile,
//        the half2 values (scale and partial sum per the block_q8_1_mmq ds4 layout) are read from the tile start
//   sum: per-thread accumulators, indexed as (j0/tile_C::J + n)*tile_C::ne + l
//   k00: offset of the current k slice within the x tile (y is indexed with the local offset k01 only)
template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
    // AMD path: A tiles are reloaded for every k step and the x dm values are read inside the accumulation loop.
    constexpr data_layout input_layout = get_input_data_layout();
    typedef tile<16,  8, int, input_layout>        tile_A;
    typedef tile<16,  8, int, input_layout>        tile_B;
    typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;

    constexpr int granularity = mmq_get_granularity_device(mmq_x);
    constexpr int rows_per_warp = granularity;
    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.

    // Warps that share the same x rows start at different y columns.
    y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);

    const int   * x_qs = (const int   *) x;
    const half2 * x_dm = (const half2 *) x_qs + 2*MMQ_TILE_NE_K;
    const int   * y_qs = (const int   *) y + 4;
    const half2 * y_dm = (const half2 *) y;

    // First x row handled by this warp.
    const int i0 = (threadIdx.y / ntx) * rows_per_warp;

    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
        const int k0 = k00 + k01;

        tile_A A[ntx];
#pragma unroll
        for (int n = 0; n < ntx; ++n) {
            load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
        }

#pragma unroll
        for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
            tile_B B;
            load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);

            // A single y half2 value per thread, taken for the column of accumulator element 0.
            const int j = j0 + tile_C::get_j(0);
            const float2 dsB = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]);

#pragma unroll
            for (int n = 0; n < ntx; ++n) {
                tile_C C;
                mma(C, A[n], B);

#pragma unroll
                for (int l = 0; l < tile_C::ne; ++l) {
                    // The x dm value depends on the row of accumulator element l.
                    const int i = i0 + n*tile_A::I + tile_C::get_i(l);
                    float2 dmA = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + k0/QI8_1]);
                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA.x*dsB.x*C.x[l];
                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA.y*dsB.y;
                }
            }
        }
    }
#else
    // Other targets: all A tiles and x dm values for the whole k range are loaded once up front,
    //     afterwards only B tiles and y values are loaded per (j0, k01) step.
    typedef tile<16,  8, int> tile_A;
    typedef tile< 8,  8, int> tile_B;
    typedef tile<16,  8, int> tile_C;

    constexpr int granularity = mmq_get_granularity_device(mmq_x);
    constexpr int rows_per_warp = 2 * granularity;
    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.

    // Warps that share the same x rows start at different y columns.
    y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);

    const int   * x_qs = (const int   *) x;
    const half2 * x_dm = (const half2 *) x_qs + 2*MMQ_TILE_NE_K;
    const int   * y_qs = (const int   *) y + 4;
    const half2 * y_dm = (const half2 *) y;

    // One A tile and one x dm value per k step; the dm values are stored once per pair of accumulator elements (l/2).
    tile_A   A[ntx][MMQ_TILE_NE_K/QI8_1];
    float2 dmA[ntx][tile_C::ne/2][MMQ_TILE_NE_K/QI8_1];

    // First x row handled by this warp.
    const int i0 = (threadIdx.y/ntx)*rows_per_warp;

#pragma unroll
    for (int n = 0; n < ntx; ++n) {
#pragma unroll
        for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
            const int k0 = k00 + k01;

            load_ldmatrix(A[n][k01/QI8_1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
        }

#pragma unroll
        for (int l = 0; l < tile_C::ne/2; ++l) {
            // Row of accumulator element 2*l; the same dm value is later applied to elements 2*l and 2*l + 1.
            const int i = i0 + n*tile_A::I + tile_C::get_i(2*l);

#pragma unroll
            for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
                const int k0 = k00 + k01;

                dmA[n][l][k01/QI8_1] = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + k0/QI8_1]);
            }
        }
    }

#pragma unroll
    for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
#pragma unroll
        for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
            tile_B   B;
            float2 dsB[tile_C::ne/2];

            load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix

#pragma unroll
            for (int l = 0; l < tile_C::ne/2; ++l) {
                const int j = j0 + tile_C::get_j(l);

                dsB[l] = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]);
            }

#pragma unroll
            for (int n = 0; n < ntx; ++n) {
                tile_C C;
                mma(C, A[n][k01/QI8_1], B);

#pragma unroll
                for (int l = 0; l < tile_C::ne; ++l) {
                    // x value selected by l/2, y value selected by l%2.
                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA[n][l/2][k01/QI8_1].x*dsB[l%2].x*C.x[l];
                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA[n][l/2][k01/QI8_1].y*dsB[l%2].y;
                }
            }
        }
    }
#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
}
1231
// Used for Q3_K, IQ2_S, and IQ2_XS.
// Accumulates the dot products between an x tile in the MMQ_DP4A_TXS_Q8_0_16 layout and a q8_1 y tile into sum.
// In contrast to the plain q8_0 variant the x scales are passed to the implementation as a pointer
//     since a step of QI8_0 ints spans more than one x scale.
// Each thread handles the x rows i = i0 + threadIdx.x and the y columns j = j0 + threadIdx.y.
//   x:   x tile, quantized ints followed by the float scales at offset txs.qs
//   y:   y tile, the quantized ints start 4 ints into the tile, the float scales are read from the tile start
//   sum: per-thread accumulators, one per (j0, i0) combination
//   k00: offset of the current k slice within the x tile (y is indexed with the local offset k01 only)
template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a(
    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
    constexpr int nwarps    = mmq_get_nwarps_device();
    constexpr int warp_size = ggml_cuda_get_physical_warp_size();

    constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16;

    const int   * x_qs = (const int   *) x;
    const float * x_df = (const float *) x_qs + txs.qs;
    const int   * y_qs = (const int   *) y + 4;
    const float * y_df = (const float *) y;

// #pragma unroll
    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) {
        const int k0 = k00 + k01;

        const int kx_df = k0/(QI8_0/2);
        const int ky_df = k01/QI8_1;

#pragma unroll
        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
            const int j = j0 + threadIdx.y;

            const int * y_qs_jk    = &y_qs[j*MMQ_TILE_Y_K + k01];
            const int   y_df_index = j*MMQ_TILE_Y_K + ky_df;

#pragma unroll
            for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
                const int i = i0 + threadIdx.x;

                const int   * x_qs_ik = &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0];
                const float * x_df_ik = &x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + kx_df];

                sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q8_0_16_q8_1_impl<QI8_0>(
                    x_qs_ik,
                    y_qs_jk,
                    x_df_ik,
                    y_df[y_df_index]);
            }
        }
    }
}
1266
// Used for Q3_K, IQ2_S, and IQ2_XS:
// Multiplies an x tile in the MMQ_MMA_TILE_X_K_Q3_K layout (quantized ints plus one float scale per 4 ints)
//     with a q8_1 y tile and adds the scaled results to sum.
// There is one implementation per matrix instruction family (AMD MFMA, AMD WMMA, Turing-style MMA);
//     without any of them the function compiles to NO_DEVICE_CODE.
//   x:   x tile, the quantized ints come first with MMQ_MMA_TILE_X_K_Q3_K ints between rows,
//        the float scales are read starting MMQ_TILE_NE_K*2 ints into the tile with the same distance between rows
//   y:   y tile with MMQ_TILE_Y_K ints between columns, the quantized ints start 4 ints into the tile,
//        the float scales are read from the tile start
//   sum: per-thread accumulators, indexed as (j0/tile_C::J + n)*tile_C::ne + l
//   k00: offset of the current k slice within the x tile (y is indexed with the local offset k01 only)
template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
#if defined(AMD_MFMA_AVAILABLE)
    // MFMA path: the mma operates on 16x8 tiles but only 4 ints along k are consumed per step,
    //     the data is loaded by viewing the A/B tiles as 64x2 tiles (tile_load).
    constexpr data_layout input_layout = get_input_data_layout();
    typedef tile<16,  8, int, input_layout>        tile_A;
    typedef tile<16,  8, int, input_layout>        tile_B;
    typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
    typedef tile<64,  2, int, input_layout>        tile_load;

    constexpr int granularity = mmq_get_granularity_device(mmq_x);
    constexpr int rows_per_warp = granularity;
    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.

    // Warps that share the same x rows start at different y columns.
    y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);

    const int   * x_qs = (const int   *) x;
    const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
    const int   * y_qs = (const int   *) y + 4;
    const float * y_df = (const float *) y;

    // First x row handled by this warp.
    const int i0 = (threadIdx.y / ntx) * rows_per_warp;

    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
        const int k0 = k00 + k01;

        tile_A A[ntx];
#pragma unroll
        for (int n = 0; n < ntx; ++n) {
            load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
        }

#pragma unroll
        for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
            tile_B B[1];
            load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);

            // NOTE(review): the y scale is halved only in this branch, presumably to compensate for how the
            //     64x2 load feeds the 16x8 mma - confirm against the tile_load implementation.
            const int j = j0 + tile_C::get_j(0);
            const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1] / 2;

#pragma unroll
            for (int n = 0; n < ntx; ++n) {
                tile_C C;
                mma(C, A[n], B[0]);

#pragma unroll
                for (int l = 0; l < tile_C::ne; ++l) {
                    // The x scale depends on the row of accumulator element l, there is one x scale per 4 ints.
                    const int i = i0 + n*tile_C::I + tile_C::get_i(l);
                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4] * dB;
                }
            }
        }
    }
#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
    constexpr data_layout input_layout = get_input_data_layout();
    typedef tile<16,  4, int, input_layout>        tile_A;
    typedef tile<16,  4, int, input_layout>        tile_B;
    typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;

    constexpr int granularity = mmq_get_granularity_device(mmq_x);
    constexpr int rows_per_warp = granularity;
    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.

    // Warps that share the same x rows start at different y columns.
    y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);

    const int   * x_qs = (const int   *) x;
    const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
    const int   * y_qs = (const int   *) y + 4;
    const float * y_df = (const float *) y;

    // First x row handled by this warp.
    const int i0 = (threadIdx.y / ntx) * rows_per_warp;

    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
        const int k0 = k00 + k01;

        tile_A A[ntx];
#pragma unroll
        for (int n = 0; n < ntx; ++n) {
            load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
        }

#pragma unroll
        for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
            tile_B B;
            load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);

            // A single y scale per thread, taken for the column of accumulator element 0.
            const int j = j0 + tile_C::get_j(0);
            const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];

#pragma unroll
            for (int n = 0; n < ntx; ++n) {
                tile_C C;
                mma(C, A[n], B);

#pragma unroll
                for (int l = 0; l < tile_C::ne; ++l) {
                    // The x scale depends on the row of accumulator element l, there is one x scale per 4 ints.
                    const int i = i0 + n*tile_C::I + tile_C::get_i(l);
                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4] * dB;
                }
            }
        }
    }
#elif defined(TURING_MMA_AVAILABLE)
    // Turing-style MMA path: all A tiles and x scales for the whole k range are loaded once up front,
    //     afterwards only B tiles and y scales are loaded per (j0, k01) step.

    typedef tile<16, 4, int> tile_A;
    typedef tile<16, 8, int> tile_A_8;
    typedef tile< 8, 4, int> tile_B;
    typedef tile<16, 8, int> tile_C;

    constexpr int granularity = mmq_get_granularity_device(mmq_x);
    constexpr int rows_per_warp = 2 * granularity;
    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.

    // Warps that share the same x rows start at different y columns.
    y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);

    const int   * x_qs = (const int   *) x;
    const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
    const int   * y_qs = (const int   *) y + 4;
    const float * y_df = (const float *) y;

    // First x row handled by this warp.
    const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I);

    // One 16x4 A tile and one x scale per 4 ints along k; the scales are stored once per pair of accumulator elements (l/2).
    tile_A  A[ntx][8];
    float  dA[ntx][tile_C::ne/2][8];

#pragma unroll
    for (int n = 0; n < ntx; ++n) {
#pragma unroll
        for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 8) {
            const int k0 = k00 + k01;

            // A[n] is viewed as an array of 16x8 tiles so that one load_ldmatrix call fills 2 consecutive 16x4 tiles.
            load_ldmatrix(((tile_A_8 *) A[n])[k01/8], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
        }

#pragma unroll
        for (int l = 0; l < tile_C::ne/2; ++l) {
            // Row of accumulator element 2*l; the same scale is later applied to elements 2*l and 2*l + 1.
            const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);

#pragma unroll
            for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
                const int k0 = k00 + k01;

                dA[n][l][k01/4] = x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4];
            }
        }
    }

#pragma unroll
    for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
#pragma unroll
        for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) {
            // Two B tiles per step: B[0] starts at k01, B[1] starts tile_B::J ints further.
            tile_B B[2];
            float dB[tile_C::ne/2];

            // Here load_generic is faster than load_ldmatrix.
            load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + (k01 + 0),         MMQ_TILE_Y_K);
            load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + (k01 + tile_B::J), MMQ_TILE_Y_K);

#pragma unroll
            for (int l = 0; l < tile_C::ne/2; ++l) {
                const int j = j0 + tile_C::get_j(l);

                dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
            }

#pragma unroll
            for (int n = 0; n < ntx; ++n) {
                tile_C C[2];
                mma(C[0], A[n][k01/4 + 0], B[0]);
                mma(C[1], A[n][k01/4 + 1], B[1]);

#pragma unroll
                for (int l = 0; l < tile_C::ne; ++l) {
                    // Both halves share the y scale (selected by l%2) but each has its own x scale (selected by l/2).
                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += dB[l%2]*(C[0].x[l]*dA[n][l/2][k01/4 + 0] + C[1].x[l]*dA[n][l/2][k01/4 + 1]);
                }
            }
        }
    }
#else
    GGML_UNUSED_VARS(x, y, sum, k00);
    NO_DEVICE_CODE;
#endif // defined(AMD_MFMA_AVAILABLE) / defined(AMD_WMMA_AVAILABLE) / defined(TURING_MMA_AVAILABLE)
}
1451
// Unpacks q2_K blocks from x into the tile at x_tile.
//
// The tile holds the unpacked 2-bit values (one int = 4 values) followed by one half2 per
// scale byte. The half2 is the block's dm with .x multiplied by the low nibble and .y by the
// high nibble of that scale byte.
//
//   x         raw quantized data, interpreted as an array of block_q2_K
//   x_tile    destination tile
//   kbx0      index of the first block to read
//   i_max     largest valid row index; rows are clamped to it when need_check is set
//   stride    distance between consecutive rows of x, in blocks
template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q2_K(
    const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
    constexpr int nwarps = mmq_get_nwarps_device();

    int * x_qs = (int *) x_tile;
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
    // MMA layout: values and dm entries of a row share a single row stride.
    const int qs_row_stride = MMQ_MMA_TILE_X_K_Q2_K;
    const int dm_row_stride = MMQ_MMA_TILE_X_K_Q2_K;
    half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
#else
    // dp4a layout: separate regions, each row padded by one element.
    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
    const int qs_row_stride = 2*MMQ_TILE_NE_K + 1;
    const int dm_row_stride = MMQ_TILE_NE_K + 1;
    half2 * x_dm = (half2 *) (x_qs + txs.qs);
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)

    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR2_K);
    constexpr int nrows = ggml_cuda_get_physical_warp_size() / threads_per_row;
    const int kqsx = threadIdx.x % threads_per_row;

    // Part of the destination column that depends only on kqsx; the per-l offset is added below.
    const int k_base = (kqsx/8)*32 + kqsx % 8;

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
        int row = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
        if (need_check) {
            row = min(row, i_max);
        }

        const block_q2_K * blk = (const block_q2_K *) x + kbx0 + row*stride;

        // One int of packed data yields QR2_K ints of unpacked 2-bit values.
        const int packed = get_int_b2(blk->qs, kqsx);
        int * qs_row = x_qs + row*qs_row_stride;

#pragma unroll
        for (int l = 0; l < QR2_K; ++l) {
            qs_row[k_base + l*8] = (packed >> (2*l)) & 0x03030303;
        }

        const int sc_m  = blk->scales[kqsx];
        const int sc_lo = sc_m & 0x0F;
        const int sc_hi = sc_m >> 4;
#ifdef FAST_FP16_AVAILABLE
        const half2 dm_scaled = __hmul2(blk->dm, make_half2(sc_lo, sc_hi));
#else
        const float2 dmf = __half22float2(blk->dm);
        const half2 dm_scaled = make_half2(dmf.x*sc_lo, dmf.y*sc_hi);
#endif // FAST_FP16_AVAILABLE

        x_dm[row*dm_row_stride + kqsx] = dm_scaled;
    }
}
1509
// dp4a path of the q2_K x q8_1 tile product.
//
// Each thread accumulates into sum[] one partial result per (j0, i0) pair, where
// j = j0 + threadIdx.y indexes the y tile and i = i0 + threadIdx.x indexes the x tile.
//
//   x     x tile: unpacked 2-bit values followed (at offset txs.qs) by the half2 dm entries
//   y     y tile: the first 4 ints of every MMQ_TILE_Y_K-sized row are read as half2 values,
//         the quantized data starts at int offset 4
//   sum   per-thread accumulators, updated in place
//   k00   offset into the x tile added to every k index
//
// The k01 range is split into two loops which call the implementation with ns = 2 and ns = 1
// respectively. The first half uses the .x component of the leading half2 of the y row, the
// second half uses its .y component.
template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
    constexpr int nwarps = mmq_get_nwarps_device();
    constexpr int warp_size = ggml_cuda_get_physical_warp_size();

    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
    const int   * x_qs = (const int   *) x;
    const half2 * x_dm = (const half2 *) x_qs + txs.qs;
    const int   * y_qs = (const int   *) y + 4;
    const half2 * y_ds = (const half2 *) y;

    // Convert the leading half2 of every y row handled by this thread once up front.
    float2 y_df[mmq_x/nwarps];
#pragma unroll
    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
        const int j = j0 + threadIdx.y;

        y_df[j0/nwarps] = __half22float2(y_ds[j*MMQ_TILE_Y_K]);
    }

    // First half of the k range: ns = 2, always uses y_df[...].x.
#pragma unroll
    for (int k01 = 0; k01 < MMQ_TILE_NE_K/2; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) {
        const int k0 = k00 + k01;

#pragma unroll
        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
            const int j = j0 + threadIdx.y;

#pragma unroll
            for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
                const int i = i0 + threadIdx.x;

                constexpr int ns = 2;
                sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q2_K_q8_1_impl_mmq<ns>(
                    &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
                    &x_dm[i*(MMQ_TILE_NE_K + 1) + k0/4], y_df[j0/nwarps].x,
                    &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
            }
        }
    }

    // Some compilers fail to unroll the loop over k01 if there is a conditional statement for ns in the inner loop.
    // As a workaround 2 separate loops are used instead.
    // Second half of the k range: ns = 1, always uses y_df[...].y.
#pragma unroll
    for (int k01 = MMQ_TILE_NE_K/2; k01 < MMQ_TILE_NE_K; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) {
        const int k0 = k00 + k01;

#pragma unroll
        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
            const int j = j0 + threadIdx.y;

#pragma unroll
            for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
                const int i = i0 + threadIdx.x;

                constexpr int ns = 1;
                sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q2_K_q8_1_impl_mmq<ns>(
                    &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
                    &x_dm[i*(MMQ_TILE_NE_K + 1) + k0/4], y_df[j0/nwarps].y,
                    &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
            }
        }
    }
}
1574
// Matrix-multiply (MMA) path of the q2_K x q8_1 tile product.
//
//   x     x tile: unpacked 2-bit values followed (at int offset 2*MMQ_TILE_NE_K) by half2 dm
//         entries; rows are MMQ_MMA_TILE_X_K_Q2_K elements apart
//   y     y tile: the first 4 ints of every MMQ_TILE_Y_K-sized row are read as half2 values,
//         the quantized data starts at int offset 4
//   sum   per-thread accumulators, updated in place
//   k00   offset into the x tile added to every k index
//
// One of three implementations is selected at compile time (AMD MFMA, AMD WMMA, Turing-style
// MMA); without any of them the function compiles to NO_DEVICE_CODE.
//
// In all paths the result is built from three kinds of terms:
//   - Cd * dm.x * dB        : integer product of x and y tiles, scaled by dm.x and dB
//   - Cm * dm.y * dB        : only for k01 >= 3/4 of the tile; Cm is the product of an A tile
//                             whose bytes are all 0x01 with the y tile
//   - dm.y * sB             : only for k01 <  3/4 of the tile; sB is read from the half2
//                             values stored at the start of the y row
template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
#if defined(AMD_MFMA_AVAILABLE)
    constexpr data_layout input_layout = get_input_data_layout();
    typedef tile<16,  8, int, input_layout>        tile_A;
    typedef tile<16,  8, int, input_layout>        tile_B;
    typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
    typedef tile<64,  2, int, input_layout>        tile_load;

    constexpr int granularity = mmq_get_granularity_device(mmq_x);
    constexpr int rows_per_warp = granularity;
    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.

    y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);

    const int   * x_qs = (const int   *) x;
    const half2 * x_dm = (const half2 *) x_qs + MMQ_TILE_NE_K*2;
    const int   * y_qs = (const int   *) y + 4;
    const half2 * y_ds = (const half2 *) y;

    const int i0 = (threadIdx.y / ntx) * rows_per_warp;

    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
        const int k0 = k00 + k01;

        // A and B are declared as 16x8 tiles but filled through the 64x2 tile_load view.
        tile_A A[ntx];
#pragma unroll
        for (int n = 0; n < ntx; ++n) {
            load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
        }

#pragma unroll
        for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
            tile_B B[1];
            load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);

            const int j = j0 + tile_C::get_j(0);
            // Leading half2 of the y row: .x applies to the first half of k, .y to the second half.
            // NOTE(review): the extra /2 (absent in the WMMA path) presumably compensates for the
            // 64x2 tile_load view used above - confirm against the tile/mma definitions.
            const float2 dsB = __half22float2(y_ds[j*MMQ_TILE_Y_K]);
            const float dB = ((k01 < MMQ_TILE_NE_K/2) ? dsB.x : dsB.y)/2;
            const float sB = (k01 >= MMQ_TILE_NE_K * 3/4) ? 0
                                              : (((k01/4)%2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).y
                                                             : __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).x);

            // Cm is only written and only read for k01 >= 3/4 of the tile.
            tile_C Cm;
            if (k01 >= MMQ_TILE_NE_K * 3/4) {
                tile_A A1;
                A1.x[0] = 0x01010101;
                A1.x[1] = 0x01010101;
                mma(Cm, A1, B[0]);
            }

#pragma unroll
            for (int n = 0; n < ntx; ++n) {
                tile_C Cd;
                mma(Cd, A[n], B[0]);

#pragma unroll
                for (int l = 0; l < tile_C::ne; ++l) {
                    const int i = i0 + n*tile_C::I + tile_C::get_i(l);
                    const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/4]);
                    float tmp = Cd.x[l]*dm.x;
                    if (k01 >= MMQ_TILE_NE_K * 3/4) {
                        tmp -= Cm.x[l]*dm.y;
                    }
                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*dB;
                    sum[(j0/tile_C::J + n)*tile_C::ne + l] -= dm.y*sB;
                }
            }
        }
    }
#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
    constexpr data_layout input_layout = get_input_data_layout();
    typedef tile<16,  4, int, input_layout>        tile_A;
    typedef tile<16,  4, int, input_layout>        tile_B;
    typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;

    constexpr int granularity = mmq_get_granularity_device(mmq_x);
    constexpr int rows_per_warp = granularity;
    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.

    y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);

    const int   * x_qs = (const int   *) x;
    const half2 * x_dm = (const half2 *) x_qs + MMQ_TILE_NE_K*2;
    const int   * y_qs = (const int   *) y + 4;
    const half2 * y_ds = (const half2 *) y;

    const int i0 = (threadIdx.y / ntx) * rows_per_warp;

    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
        const int k0 = k00 + k01;

        tile_A A[ntx];
#pragma unroll
        for (int n = 0; n < ntx; ++n) {
            load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
        }

#pragma unroll
        for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
            tile_B B;
            load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);

            const int j = j0 + tile_C::get_j(0);
            // Leading half2 of the y row: .x applies to the first half of k, .y to the second half.
            const float2 dsB = __half22float2(y_ds[j*MMQ_TILE_Y_K]);
            const float dB = (k01 < MMQ_TILE_NE_K/2) ? dsB.x : dsB.y;
            const float sB = (k01 >= MMQ_TILE_NE_K * 3/4) ? 0
                                              : (((k01/4)%2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).y
                                                             : __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).x);

            // Cm is only written and only read for k01 >= 3/4 of the tile.
            tile_C Cm;
            if (k01 >= MMQ_TILE_NE_K * 3/4) {
                tile_A A1;
#pragma unroll
                for (int l = 0; l < tile_A::ne; ++l) {
                    A1.x[l] = 0x01010101;
                }
                mma(Cm, A1, B);
            }

#pragma unroll
            for (int n = 0; n < ntx; ++n) {
                tile_C Cd;
                mma(Cd, A[n], B);

#pragma unroll
                for (int l = 0; l < tile_C::ne; ++l) {
                    const int i = i0 + n*tile_C::I + tile_C::get_i(l);
                    const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/4]);
                    float tmp = Cd.x[l]*dm.x;
                    if (k01 >= MMQ_TILE_NE_K * 3/4) {
                        tmp -= Cm.x[l]*dm.y;
                    }
                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*dB;
                    sum[(j0/tile_C::J + n)*tile_C::ne + l] -= dm.y*sB;
                }
            }
        }
    }
#elif defined(TURING_MMA_AVAILABLE)

    typedef tile<16, 4, int> tile_A;
    typedef tile<16, 8, int> tile_A_8;
    typedef tile< 8, 4, int> tile_B;
    typedef tile<16, 8, int> tile_C;

    constexpr int granularity = mmq_get_granularity_device(mmq_x);
    constexpr int rows_per_warp = 2 * granularity;
    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.

    y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);

    const int   * x_qs = (const int   *) x;
    const half2 * x_dm = (const half2 *) x_qs + MMQ_TILE_NE_K*2;
    const int   * y_qs = (const int   *) y + 4;
    const half2 * y_ds = (const half2 *) y;

    const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I);

    // All x data needed by this warp is loaded once, before the loop over j0:
    // A holds the quantized values, dA / mA the .x / .y components of the dm entries.
    tile_A  A[ntx][8];
    float  dA[ntx][tile_C::ne/2][8];
    float  mA[ntx][tile_C::ne/2][8];

#pragma unroll
    for (int n = 0; n < ntx; ++n) {
#pragma unroll
        for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
            const int k0 = k00 + k01;

            // Two adjacent 16x4 A tiles are filled at once through the 16x8 tile_A_8 view.
            load_ldmatrix(((tile_A_8 *) A[n])[k01/QI8_1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
        }
    }

#pragma unroll
    for (int n = 0; n < ntx; ++n) {
#pragma unroll
        for (int l = 0; l < tile_C::ne/2; ++l) {
            const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);

#pragma unroll
            for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1/2) {
                const int k0 = k00 + k01;

                const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/(QI8_1/2)]);

                dA[n][l][k01/(QI8_1/2)] = dm.x;
                mA[n][l][k01/(QI8_1/2)] = dm.y;
            }
        }
    }

#pragma unroll
    for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
        float2 dB[tile_C::ne/2];

#pragma unroll
        for (int l = 0; l < tile_C::ne/2; ++l) {
            const int j = j0 + tile_C::get_j(l);

            dB[l] = __half22float2(y_ds[j*MMQ_TILE_Y_K]);
        }

#pragma unroll
        for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
            tile_B B[2];

            // Here load_generic is faster than load_ldmatrix.
            load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + (k01 + 0),         MMQ_TILE_Y_K);
            load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + (k01 + tile_B::J), MMQ_TILE_Y_K);

            // Cm is only written and only read for k01 >= 3/4 of the tile.
            tile_C Cm[2];
            if (k01 >= MMQ_TILE_NE_K * 3/4) {
                tile_A A1;
                A1.x[0] = 0x01010101;
                A1.x[1] = 0x01010101;
                mma(Cm[0], A1, B[0]);
                mma(Cm[1], A1, B[1]);
            }

#pragma unroll
            for (int n = 0; n < ntx; ++n) {
                tile_C Cd[2];

                mma(Cd[0], A[n][k01/4 + 0], B[0]);
                mma(Cd[1], A[n][k01/4 + 1], B[1]);

#pragma unroll
                for (int l = 0; l < tile_C::ne; ++l) {
                    float tmp = Cd[0].x[l]*dA[n][l/2][k01/4 + 0] + Cd[1].x[l]*dA[n][l/2][k01/4 + 1];
                    if (k01 >= MMQ_TILE_NE_K * 3/4) {
                        tmp -= Cm[0].x[l]*mA[n][l/2][k01/4 + 0] + Cm[1].x[l]*mA[n][l/2][k01/4 + 1];
                    }
                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*(k01 < MMQ_TILE_NE_K/2 ? dB[l%2].x : dB[l%2].y);
                }
            }
        }

        // For the first 3/4 of the tile the dm.y term is applied using the values stored
        // in the y row header instead of an extra mma.
#pragma unroll
        for (int k01 = 0; k01 < MMQ_TILE_NE_K * 3/4; k01 += QI8_1) {
            float2 sB[tile_C::ne/2];

#pragma unroll
            for (int l = 0; l < tile_C::ne/2; ++l) {
                const int j = j0 + tile_C::get_j(l);

                sB[l] = __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
            }

#pragma unroll
            for (int n = 0; n < ntx; ++n) {
#pragma unroll
                for (int l = 0; l < tile_C::ne; ++l) {
                    sum[(j0/tile_C::J + n)*tile_C::ne + l] -= mA[n][l/2][k01/4 + 0]*sB[l%2].x;
                    sum[(j0/tile_C::J + n)*tile_C::ne + l] -= mA[n][l/2][k01/4 + 1]*sB[l%2].y;
                }
            }
        }
    }
#else
    GGML_UNUSED_VARS(x, y, sum, k00);
    NO_DEVICE_CODE;
#endif // AMD_MFMA_AVAILABLE / AMD_WMMA_AVAILABLE / TURING_MMA_AVAILABLE
}
1837
// Unpacks q3_K blocks from x into the tile at x_tile.
//
// The quantized values are rebuilt from the 2 low bits in qs and the high bit in hmask and
// stored as signed bytes (4 is subtracted with saturation from each byte).
// The scale handling depends on the target:
//   - MMA targets: for every scale, d*scale is stored as a float inside the row of the value tile.
//   - otherwise:   the packed 6 bit scales (offset by -32) go to x_sc and d goes to x_df.
//
//   x         raw quantized data, interpreted as an array of block_q3_K
//   x_tile    destination tile
//   kbx0      index of the first block to read
//   i_max     largest valid row index; rows are clamped to it when need_check is set
//   stride    distance between consecutive rows of x, in blocks
template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q3_K(
    const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
    constexpr int nwarps = mmq_get_nwarps_device();
    constexpr int warp_size = ggml_cuda_get_physical_warp_size();

#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
    int   * x_qs = (int   *)  x_tile;
    float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
#else
    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y);
    int   * x_qs = (int   *)  x_tile;
    float * x_df = (float *) (x_qs + txs.qs);
    int   * x_sc = (int   *) (x_df + txs.dm);
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)

    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR3_K);
    constexpr int nrows = warp_size / threads_per_row;
    const int kqsx = threadIdx.x % threads_per_row;

    // Pass 1: quantized values.
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
        int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride;

        const int x_ql_0 = get_int_b2(bxi->qs,    kqsx);
        const int x_qh_0 = get_int_b2(bxi->hmask, kqsx % (QI3_K/2)) >> (4 * (kqsx / (QI3_K/2)));

#pragma unroll
        for (int l = 0; l < QR3_K; ++l) {
            const int k = (kqsx/8)*32 + l*8 + kqsx % 8;

            const int x_ql_k =  (x_ql_0 >> (2*l))       & 0x03030303; // 2 low bits per byte
            const int x_qh_k = ((x_qh_0 >>    l)  << 2) & 0x04040404; // high bit moved to bit 2 of each byte

            const int x_qs_k = __vsubss4(x_ql_k | x_qh_k, 0x04040404);

#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
            x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k] = x_qs_k;
#else
            x_qs[i*(2*MMQ_TILE_NE_K + 1) + k] = x_qs_k;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
        }
    }

    // Pass 2: scales, 4 threads per row, each thread handles one int (4 scales).
    constexpr int rows_per_warp = warp_size / 4;
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
        int i = i0 + threadIdx.y*rows_per_warp + threadIdx.x/4;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride;

        const int ksc = threadIdx.x % 4;

        // Lower 4 bits of each scale.
        const int ksc_low = ksc % (QI3_K/8);
        const int shift_low = 4 * (ksc / (QI3_K/8));
        const int sc_low = (get_int_b2(bxi->scales, ksc_low) >> shift_low) & 0x0F0F0F0F;

        // Upper 2 bits of each scale, moved to bits 4-5 of each byte.
        const int ksc_high = QI3_K/8;
        const int shift_high = 2 * ksc;
        const int sc_high = ((get_int_b2(bxi->scales, ksc_high) >> shift_high) << 4) & 0x30303030;

        const int sc = __vsubss4(sc_low | sc_high, 0x20202020);

#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
        constexpr int scales_per_int = int(sizeof(int));

        const int8_t * sc8 = (const int8_t *) &sc;
        const float d = bxi->d;

#pragma unroll
        for (int l = 0; l < scales_per_int; ++l) {
            x_df[i*MMQ_MMA_TILE_X_K_Q3_K + scales_per_int*ksc + l] = d*sc8[l];
        }
#else
        x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = sc;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
    }

    // Pass 3 (non-MMA targets only): one d per row.
#if !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE))
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) {
        int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride;

        x_df[i] = bxi->d;
    }
#endif // !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE))
}
1938
// dp4a path of the q3_K x q8_1 tile product.
//
// Each thread accumulates into sum[] one partial result per (j0, i0) pair, where
// j = j0 + threadIdx.y indexes the y tile and i = i0 + threadIdx.x indexes the x tile.
//
//   x     x tile: quantized values, then (at offset txs.qs) one float per row, then
//         (at offset txs.dm) the packed int8 scales
//   y     y tile: the first 4 ints of every MMQ_TILE_Y_K-sized row are read as floats,
//         the quantized data starts at int offset 4
//   sum   per-thread accumulators, updated in place
//   k00   offset into the x tile added to every k index
template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a(
    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
    constexpr int nwarps = mmq_get_nwarps_device();
    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y);

    // Sub-regions of the x tile.
    const int   * x_qs = (const int   *) x;
    const float * x_df = (const float *) x_qs + txs.qs;
    const int   * x_sc = (const int   *) x_df + txs.dm;

    // Sub-regions of the y tile.
    const float * y_df = (const float *) y;
    const int   * y_qs = (const int   *) y + 4;

// #pragma unroll
    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) {
        const int k0 = k00 + k01;

#pragma unroll
        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
            const int col = j0 + threadIdx.y;

            // Everything that depends only on the y row is looked up once per row.
            const int * y_vals  = &y_qs[col*MMQ_TILE_Y_K + k01];
            const float y_scale = y_df[col*MMQ_TILE_Y_K + k01/QI8_1];

#pragma unroll
            for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
                const int row = i0 + threadIdx.x;

                // Byte-wise view of the scales of this row, advanced to the current k position.
                const int    * sc_row = x_sc + row*(MMQ_TILE_NE_K/8) + row/8;
                const int8_t * scales = ((const int8_t *) sc_row) + k0/4;

                const float partial = vec_dot_q3_K_q8_1_impl_mmq(
                    &x_qs[row*(2*MMQ_TILE_NE_K + 1) + k0], y_vals, scales,
                    x_df[row], y_scale);

                sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += partial;
            }
        }
    }
}
1973
// Extracts four 6 bit values (one per byte of the result) from the packed scale ints.
//
// Result by ksc:
//   - ksc == 0: sc0, sc1, sc2, sc3
//   - ksc == 1: sc4, sc5, sc6, sc7
//   - ksc == 2:  m0,  m1,  m2,  m3
//   - ksc == 3:  m4,  m5,  m6,  m7
//
// Each result byte combines 4 low bits and 2 high bits that are stored in different ints.
static __device__ __forceinline__ int unpack_scales_q45_K(const int * scales, const int ksc) {
    // Source int and shift for the lower 4 bits of every byte.
    const int low_idx   = (ksc%2) + (ksc!=0);
    const int low_shift = 4 * (ksc & (ksc/2));

    // Source int and shift for the upper 2 bits of every byte.
    const int high_idx   = ksc/2;
    const int high_shift = 2 * (ksc % 2);

    const int low_bits  = (scales[low_idx]  >> low_shift)  & 0x0F0F0F0F;
    const int high_bits = (scales[high_idx] >> high_shift) & 0x30303030;

    return low_bits | high_bits;
}
1983
// Unpacks q4_K blocks from x into the tile at x_tile.
//
// The handling depends on the target:
//   - MMA targets: the 4 bit values are split into low and high nibbles and stored as separate
//     ints; for every (scale, min) pair a half2 equal to dm * (scale, -min) is stored inside
//     the row of the value tile.
//   - otherwise:   the packed 4 bit values are copied as is, dm goes to x_dm (one per row) and
//     the unpacked 6 bit scales/mins go to x_sc.
//
//   x         raw quantized data, interpreted as an array of block_q4_K
//   x_tile    destination tile
//   kbx0      index of the first block to read
//   i_max     largest valid row index; rows are clamped to it when need_check is set
//   stride    distance between consecutive rows of x, in blocks
template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q4_K(
    const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
    constexpr int nwarps = mmq_get_nwarps_device();
    constexpr int warp_size = ggml_cuda_get_physical_warp_size();

#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
    int   * x_qs = (int   *)  x_tile;
    half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
#else
    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y);
    int   * x_qs = (int   *)  x_tile;
    half2 * x_dm = (half2 *) (x_qs + txs.qs);
    int   * x_sc = (int   *) (x_dm + txs.dm);
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)

    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_K);
    constexpr int nrows = warp_size / threads_per_row;
    const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;

    // Pass 1: quantized values.
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
        int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;
        const int qs0 = get_int_b4(bxi->qs, txi);

#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
        // Low nibbles go to the first 8 ints of each group of 16, high nibbles to the last 8.
        x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 0] = (qs0 >> 0) & 0x0F0F0F0F;
        x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 8] = (qs0 >> 4) & 0x0F0F0F0F;
#else
        x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
    }

    // Pass 2: scales and mins.
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
    constexpr int rows_per_warp = warp_size / 2;
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
        // Need if on AMD instead of % because warp_size == 64
        // This causes double work and throughput loss (MI300X)
        // H100 loses about 100 t/s with 'if' condition over '%'
        int i = i0 + threadIdx.y*rows_per_warp + threadIdx.x/2;
        if (i < mmq_y) {
#else
        int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/2) % mmq_y;
        {
#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
            if (need_check) {
                i = min(i, i_max);
            }

            const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;

            const int * scales = (const int *) bxi->scales;
            const int ksc = threadIdx.x % 2;

            // 4 scales and the 4 matching mins, one per byte.
            const int sc32 = unpack_scales_q45_K(scales, ksc + 0);
            const int  m32 = unpack_scales_q45_K(scales, ksc + 2);

            const uint8_t * sc8 = (const uint8_t *) &sc32;
            const uint8_t *  m8 = (const uint8_t *)  &m32;

            // The sign of the second component of dm is flipped before it is multiplied with the mins.
            const half2 dm = bxi->dm * make_half2(1.0f, -1.0f);

            constexpr int scales_per_int = int(sizeof(int));
#pragma unroll
            for (int l = 0; l < scales_per_int; ++l) {
                x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + scales_per_int*ksc + l] = dm*make_half2(sc8[l], m8[l]);
            }
        }
    }
#else
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) {
        int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;

        x_dm[i] = bxi->dm;
    }
    constexpr int rows_per_warp = warp_size / 4;
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
        int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/(MMQ_TILE_NE_K/8)) % mmq_y;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + (threadIdx.x % (MMQ_TILE_NE_K/8)) / (QI4_K/8);

        const int * scales = (const int *) bxi->scales;

        const int ksc = threadIdx.x % (MMQ_TILE_NE_K/8);
        const int scales8 = unpack_scales_q45_K(scales, ksc);

        // Rows of x_sc are padded: one extra int every 8 rows (the i/8 term).
        x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8;
    }
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
}
2092
// dp4a path of the q4_K x q8_1 tile product.
//
// Each thread accumulates into sum[] one partial result per (j0, i0) pair, where
// j = j0 + threadIdx.y indexes the y tile and i = i0 + threadIdx.x indexes the x tile.
//
//   x     x tile: packed 4 bit values, then (at offset txs.qs) one half2 dm per row, then
//         (at offset txs.dm) the unpacked scales/mins
//   y     y tile: the first 4 ints of every MMQ_TILE_Y_K-sized row are read as half2 values,
//         the quantized data starts at int offset 4
//   sum   per-thread accumulators, updated in place
//   k00   offset into the x tile added to every k index
template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
    constexpr int nwarps = mmq_get_nwarps_device();
    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y);

    // Sub-regions of the x tile.
    const int   * x_qs = (const int   *) x;
    const half2 * x_dm = (const half2 *) x_qs + txs.qs;
    const int   * x_sc = (const int   *) x_dm + txs.dm;

    // Sub-regions of the y tile.
    const half2 * y_ds = (const half2 *) y;
    const int   * y_qs = (const int   *) y + 4;

// #pragma unroll
    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR4_K*VDR_Q4_K_Q8_1_MMQ) {
        const int k0 = k00 + k01;

#pragma unroll
        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
            const int col = j0 + threadIdx.y;

            // Everything that depends only on the y row is looked up once per row.
            const int   * y_vals = &y_qs[col*MMQ_TILE_Y_K + k01];
            const half2 * y_dsk  = &y_ds[col*MMQ_TILE_Y_K + k01/QI8_1];

#pragma unroll
            for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
                const int row = i0 + threadIdx.x;

                // Byte-wise view of the scales; the matching mins are read 8 bytes further on.
                const int     * sc_int = &x_sc[row * (MMQ_TILE_NE_K/8) + row/8 + k0/32];
                const uint8_t * sc     = (const uint8_t *) sc_int + 2*(k01/16);
                const uint8_t * mins   = sc + 8;

                const float partial = vec_dot_q4_K_q8_1_impl_mmq(
                    &x_qs[row*(MMQ_TILE_NE_K + 1) + k0/2], y_vals, sc, mins,
                    x_dm[row], y_dsk);

                sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += partial;
            }
        }
    }
}
2127
// Loads one tile of Q5_K blocks from x into the tile buffer x_tile.
//
// Template parameters:
//   mmq_y      - number of tile rows to fill.
//   need_check - if true, row indices are clamped to i_max before x is read.
// Parameters:
//   x      - quantized source data, reinterpreted as an array of block_q5_K.
//   x_tile - destination buffer; it is carved up into x_qs / x_dm (/ x_sc) below.
//   kbx0   - block offset added to every row's block pointer.
//   i_max  - largest row index that may be read when need_check is set.
//   stride - distance between consecutive rows of x, in blocks.
//
// What is written depends on the compiled code path:
//   - MMA paths (AMD_MFMA / TURING_MMA / AMD_WMMA): x_qs gets the 5 bit quants (one per byte) and
//     x_dm gets bxi->dm multiplied by (1, -1) and by each unpacked (scale, min) byte pair.
//   - dp4a path: x_qs gets the same quants, x_dm gets bxi->dm unchanged and x_sc the unpacked scales.
template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q5_K(
    const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
    constexpr int nwarps = mmq_get_nwarps_device();
    constexpr int warp_size = ggml_cuda_get_physical_warp_size();

#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
    // MMA layout: the half2 values start MMQ_TILE_NE_K*2 ints into the tile; there is no separate x_sc.
    int   * x_qs = (int   *)  x_tile;
    half2 * x_dm = (half2 *) (x_qs + MMQ_TILE_NE_K*2);
#else
    // dp4a layout: section sizes come from the tile_x_sizes of this type.
    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y);
    int   * x_qs = (int   *)  x_tile;
    half2 * x_dm = (half2 *) (x_qs + txs.qs);
    int   * x_sc = (int   *) (x_dm + txs.dm);
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)

    // threads_per_row threads cooperate on one row, so a warp covers nrows rows per iteration.
    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_K);
    constexpr int nrows = warp_size / threads_per_row;
    // Index of this thread within its row.
    const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;

    // Part 1: unpack the quantized values.
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
        int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);

        if (need_check) {
            // Clamp instead of skipping so that every thread still performs a valid read.
            i = min(i, i_max);
        }

        const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;
        const int ky = QR5_K*txi;

        // Split each byte of the packed low bits into its lower and upper nibble.
        const int ql = get_int_b4(bxi->qs, txi);
        const int ql0 = (ql >> 0) & 0x0F0F0F0F;
        const int ql1 = (ql >> 4) & 0x0F0F0F0F;

        // Pick one bit per byte out of qh and move it to bit 4 so it can be OR-ed onto a nibble.
        const int qh = get_int_b4(bxi->qh, txi % (QI5_K/4));
        const int qh0 = ((qh >> (2 * (txi / (QI5_K/4)) + 0)) << 4) & 0x10101010;
        const int qh1 = ((qh >> (2 * (txi / (QI5_K/4)) + 1)) << 4) & 0x10101010;

        // Destination ints within the row; kq1 lies QI5_K/4 ints after kq0.
        const int kq0 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + 0;
        const int kq1 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + QI5_K/4;

#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
        x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq0] = ql0 | qh0;
        x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq1] = ql1 | qh1;
#else
        x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = ql0 | qh0;
        x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = ql1 | qh1;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
    }

    // Part 2: scales and block factors.
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
    // Two threads per row: ksc (0 or 1) selects which group of 4 scale/min bytes a thread expands.
    constexpr int rows_per_warp = warp_size / 2;
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
#if defined(AMD_MFMA_AVAILABLE)
        // Need if on AMD instead of % because warp_size == 64
        // This causes double work and throughput loss (MI300X)
        // H100 loses about 100 t/s with 'if' condition over '%'
        int i = i0 + threadIdx.y*rows_per_warp + threadIdx.x/2;
        if (i < mmq_y) {
#else
        // Out-of-range rows wrap around with % mmq_y; the bare brace below pairs with the 'if' above.
        int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/2) % mmq_y;
        {
#endif // defined(AMD_MFMA_AVAILABLE)
            if (need_check) {
                i = min(i, i_max);
            }

            const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;

            const int * scales = (const int *) bxi->scales;
            const int ksc = threadIdx.x % 2;

            // Each unpacked int is read below as 4 bytes: sc32 through sc8, m32 through m8.
            const int sc32 = unpack_scales_q45_K(scales, ksc + 0);
            const int  m32 = unpack_scales_q45_K(scales, ksc + 2);

            const uint8_t * sc8 = (const uint8_t *) &sc32;
            const uint8_t *  m8 = (const uint8_t *)  &m32;

            // Negate the second component of dm before it is multiplied with the m8 bytes.
            const half2 dm = bxi->dm * make_half2(1.0f, -1.0f);

#pragma unroll
            for (int l = 0; l < int(sizeof(int)); ++l) {
                x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]);
            }
        }
    }
#else
    // dp4a path: one thread per row copies dm unchanged.
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) {
        int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;

        x_dm[i] = bxi->dm;
    }

    // dp4a path: MMQ_TILE_NE_K/8 threads per row each store one unpacked scale int.
    constexpr int rows_per_warp = warp_size / 4;
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
        int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/(MMQ_TILE_NE_K/8)) % mmq_y;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;

        const int * scales = (const int *) bxi->scales;

        const int ksc = threadIdx.x % (MMQ_TILE_NE_K/8);
        const int scales8 = unpack_scales_q45_K(scales, ksc);

        // Rows are MMQ_TILE_NE_K/8 ints wide plus one extra int of padding every 8 rows (the i/8 term).
        x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8;
    }
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
}
2249
// dp4a path: adds the Q5_K x Q8_1 dot products of the k-slice starting at k00 to sum.
//
// x is the x tile (quants, then half2 factors, then packed scales), y is the y tile.
// For every (j0, i0) step a thread handles y column j0 + threadIdx.y and x row i0 + threadIdx.x
// and accumulates into sum[j0/nwarps*mmq_y/warp_size + i0/warp_size].
template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
    constexpr int nwarps    = mmq_get_nwarps_device();
    constexpr int warp_size = ggml_cuda_get_physical_warp_size();

    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y);

    // Row strides (in elements) of the padded x tile sections.
    constexpr int qs_stride = QR5_K*MMQ_TILE_NE_K + 1;
    constexpr int sc_stride = MMQ_TILE_NE_K/8;

    // x tile sections: quants, half2 factors (txs.qs elements in), packed scales (txs.dm elements further).
    const int   * tile_qs = x;
    const half2 * tile_dm = reinterpret_cast<const half2 *>(tile_qs) + txs.qs;
    const int   * tile_sc = reinterpret_cast<const int   *>(tile_dm) + txs.dm;

    // y tile: the half2 values sit at the start, the quants begin 4 ints later.
    const half2 * col_ds = reinterpret_cast<const half2 *>(y);
    const int   * col_qs = y + 4;

// #pragma unroll
    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR5_K*VDR_Q5_K_Q8_1_MMQ) {
        const int k0 = k00 + k01;

#pragma unroll
        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
            const int col = j0 + threadIdx.y;

            const int   * yq = col_qs + (col*MMQ_TILE_Y_K + k01);
            const half2 * yd = col_ds + (col*MMQ_TILE_Y_K + k01/QI8_1);

#pragma unroll
            for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
                const int row = i0 + threadIdx.x;

                // Byte view into the packed scales of this row; the second scale pointer lies 8 bytes later.
                const int     * sc_word = tile_sc + (row*sc_stride + row/8 + k00/32);
                const uint8_t * sc      = reinterpret_cast<const uint8_t *>(sc_word) + 2*(k01/16);

                float & acc = sum[j0/nwarps*mmq_y/warp_size + i0/warp_size];
                acc += vec_dot_q5_K_q8_1_impl_mmq(
                    tile_qs + (row*qs_stride + k0), yq, sc, sc + 8, tile_dm[row], yd);
            }
        }
    }
}
2284
// Loads one tile of Q6_K blocks from x into the tile buffer x_tile.
//
// Template parameters:
//   mmq_y      - number of tile rows to fill.
//   need_check - if true, row indices are clamped to i_max before x is read.
// Parameters:
//   x      - quantized source data, reinterpreted as an array of block_q6_K.
//   x_tile - destination buffer, split into x_qs (quants), x_df (float factors) and x_sc (scales).
//   kbx0   - block offset added to every row's block pointer.
//   i_max  - largest row index that may be read when need_check is set.
//   stride - distance between consecutive rows of x, in blocks.
//
// The MMA paths (AMD_MFMA / TURING_MMA / AMD_WMMA) and the dp4a path write the same kinds of data
// but use different section offsets and row strides.
template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q6_K(
    const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
    constexpr int nwarps = mmq_get_nwarps_device();
    constexpr int warp_size = ggml_cuda_get_physical_warp_size();

#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
    // MMA layout: fixed section offsets inside each tile.
    int   * x_qs = (int   *)  x_tile;
    float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
    int   * x_sc = (int   *) (x_df + MMQ_TILE_NE_K/QI6_K);
#else
    // dp4a layout: section sizes come from the tile_x_sizes of this type.
    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y);
    int   * x_qs = (int   *)  x_tile;
    float * x_df = (float *) (x_qs + txs.qs);
    int   * x_sc = (int   *) (x_df + txs.dm);
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)

    // threads_per_row threads cooperate on one row, so a warp covers nrows rows per iteration.
    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR6_K);
    constexpr int nrows = warp_size / threads_per_row;
    // Index of this thread within its row.
    const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;

    // Part 1: unpack the quantized values.
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
        int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);

        if (need_check) {
            // Clamp instead of skipping so that every thread still performs a valid read.
            i = min(i, i_max);
        }

        const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride;

        // Split each byte of the packed low bits into its lower and upper nibble.
        const int ql = get_int_b2(bxi->ql, txi);
        const int ql0 = (ql >> 0) & 0x0F0F0F0F;
        const int ql1 = (ql >> 4) & 0x0F0F0F0F;

        // Select two high bits per byte and place them at bits 4-5 (mask 0x30) above the nibbles.
        // The shift is 0 or 2 depending on bit 3 of txi.
        const int qh = get_int_b2(bxi->qh, (QI6_K/4) * (txi / (QI6_K/2)) + txi % (QI6_K/4));
        const int qh0 = ((qh >> ((txi & 0x08) >> 2)) << 4) & 0x30303030;
        const int qh1 =  (qh >> ((txi & 0x08) >> 2))       & 0x30303030;

        // Destination ints within the row; kq1 lies QI6_K/2 ints after kq0.
        const int kq0 = 2*txi - txi % (QI6_K/2) + 0;
        const int kq1 = 2*txi - txi % (QI6_K/2) + QI6_K/2;

        // __vsubss4 subtracts 0x20 (32) from each byte with signed saturation before the store.
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
        x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
        x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
#else
        x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
        x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
    }

    // Part 2: one thread per row stores the block factor d.
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) {
        // Out-of-range rows wrap around with % mmq_y.
        int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride;

#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
        x_df[i*MMQ_MMA_TILE_X_K_Q6_K]           = bxi->d;
#else
        x_df[i*(MMQ_TILE_NE_K/QI6_K) + i/QI6_K] = bxi->d;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
    }

    // Part 3: MMQ_TILE_NE_K/8 threads per row each store one int of packed scales.
    constexpr int rows_per_warp = warp_size / 4;
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
        int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/(MMQ_TILE_NE_K/8)) % mmq_y;

        if (need_check) {
            i = min(i, i_max);
        }

        // The block pointer is additionally advanced by (thread index within the row) / 4.
        // NOTE(review): this offset is 0 whenever MMQ_TILE_NE_K/8 <= 4 - confirm against the macro value.
        const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (MMQ_TILE_NE_K/8)) / 4;

        // The two paths index bxi->scales with different moduli (MMQ_TILE_NE_K/8 vs. QI6_K/8).
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
        x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x%4] = get_int_b2(bxi->scales, threadIdx.x % (MMQ_TILE_NE_K/8));
#else
        x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + threadIdx.x%(MMQ_TILE_NE_K/8)] = get_int_b2(bxi->scales, threadIdx.x%(QI6_K/8));
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
    }
}
2370
// dp4a path: adds the Q6_K x Q8_1 dot products of the k-slice starting at k00 to sum.
//
// x is the x tile (quants, then float factors, then packed int8 scales), y is the y tile.
// For every (j0, i0) step a thread handles y column j0 + threadIdx.y and x row i0 + threadIdx.x
// and accumulates into sum[j0/nwarps*mmq_y/warp_size + i0/warp_size].
template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
    constexpr int nwarps    = mmq_get_nwarps_device();
    constexpr int warp_size = ggml_cuda_get_physical_warp_size();

    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y);

    // Row strides (in elements) of the padded x tile sections.
    constexpr int qs_stride = QR6_K*MMQ_TILE_NE_K + 1;
    constexpr int sc_stride = MMQ_TILE_NE_K/8;
    constexpr int df_stride = MMQ_TILE_NE_K/QI6_K;

    // x tile sections: quants, float factors (txs.qs elements in), packed scales (txs.dm elements further).
    const int   * tile_qs = x;
    const float * tile_df = reinterpret_cast<const float *>(tile_qs) + txs.qs;
    const int   * tile_sc = reinterpret_cast<const int   *>(tile_df) + txs.dm;

    // y tile: the float factors sit at the start, the quants begin 4 ints later.
    const float * col_df = reinterpret_cast<const float *>(y);
    const int   * col_qs = y + 4;

// #pragma unroll
    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR6_K*VDR_Q6_K_Q8_1_MMQ) {
        const int k0 = k00 + k01;

#pragma unroll
        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
            const int col = j0 + threadIdx.y;

            const int   * yq = col_qs + (col*MMQ_TILE_Y_K + k01);
            const float * yd = col_df + (col*MMQ_TILE_Y_K + k01/QI8_1);

#pragma unroll
            for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
                const int row = i0 + threadIdx.x;

                // Signed byte view of the packed scales belonging to this row and k position.
                const int8_t * sc = reinterpret_cast<const int8_t *>(tile_sc + (row*sc_stride + row/8 + k0/16));

                float & acc = sum[j0/nwarps*mmq_y/warp_size + i0/warp_size];
                acc += vec_dot_q6_K_q8_1_impl_mmq(
                    tile_qs + (row*qs_stride + k0), yq, sc, tile_df[row*df_stride + row/QI6_K], yd);
            }
        }
    }
}
2405
// Accumulates the Q6_K x Q8_1 dot products of the k-slice starting at k00 into sum using integer MMA tiles.
//
// Parameters:
//   x   - x tile: quantized values, then (MMQ_TILE_NE_K*2 elements in) the per-row float factors x_df,
//         then (MMQ_TILE_NE_K/QI6_K elements further) the packed int8 scales x_sc.
//   y   - y tile: float factors at the start, quantized values beginning 4 ints later.
//   sum - per-thread accumulators, indexed by (j0/tile_C::J + n)*tile_C::ne + l.
//   k00 - offset of the slice along k inside the x tile.
//
// Exactly one of three implementations is compiled (AMD_MFMA, AMD_WMMA or TURING_MMA);
// if none of them is available the body is NO_DEVICE_CODE.
template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
#if defined(AMD_MFMA_AVAILABLE)
    constexpr data_layout input_layout = get_input_data_layout();
    typedef tile<16,  8, int, input_layout>        tile_A;
    typedef tile<16,  8, int, input_layout>        tile_B;
    typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
    typedef tile<64,  2, int, input_layout>        tile_load;

    constexpr int granularity = mmq_get_granularity_device(mmq_x);
    constexpr int rows_per_warp = granularity;
    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.

    // Warps that share the same x rows start at different y columns.
    y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);

    const int   * x_qs = (const int   *) x;
    const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
    const int   * x_sc = (const int   *) x_df + MMQ_TILE_NE_K/QI6_K;
    const int   * y_qs = (const int   *) y + 4;
    const float * y_df = (const float *) y;

    // First x row handled by this warp.
    const int i0 = (threadIdx.y / ntx) * rows_per_warp;

    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
        const int k0 = k00 + k01;

        // The A/B tiles are filled through the 64x2 tile_load view of the same storage.
        tile_A A[ntx];
#pragma unroll
        for (int n = 0; n < ntx; ++n) {
            load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K);
        }

#pragma unroll
        for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
            tile_B B[1];
            load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);

            const int j = j0 + tile_C::get_j(0);
            // NOTE(review): the y factor is halved only on this MFMA path; presumably this
            // compensates for the 64x2 tile_load layout - confirm.
            const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1] / 2;

#pragma unroll
            for (int n = 0; n < ntx; ++n) {
                tile_C C;
                mma(C, A[n], B[0]);

                // Scale each integer result by the x scale byte, the x row factor and the y factor.
#pragma unroll
                for (int l = 0; l < tile_C::ne; ++l) {
                    const int i = i0 + n*tile_C::I + tile_C::get_i(l);
                    const int8_t * sc = (const int8_t *) (x_sc + i*MMQ_MMA_TILE_X_K_Q6_K + k00/16);
                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * sc[k01/4] * x_df[i*MMQ_MMA_TILE_X_K_Q6_K] * dB;
                }
            }
        }
    }
#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
    constexpr data_layout input_layout = get_input_data_layout();
    typedef tile<16,  4, int, input_layout>        tile_A;
    typedef tile<16,  4, int, input_layout>        tile_B;
    typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;

    constexpr int granularity = mmq_get_granularity_device(mmq_x);
    constexpr int rows_per_warp = granularity;
    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.

    // Warps that share the same x rows start at different y columns.
    y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);

    const int   * x_qs = (const int   *) x;
    const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
    const int   * x_sc = (const int   *) x_df + MMQ_TILE_NE_K/QI6_K;
    const int   * y_qs = (const int   *) y + 4;
    const float * y_df = (const float *) y;

    // First x row handled by this warp.
    const int i0 = (threadIdx.y / ntx) * rows_per_warp;

    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
        const int k0 = k00 + k01;

        tile_A A[ntx];
#pragma unroll
        for (int n = 0; n < ntx; ++n) {
            load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K);
        }

#pragma unroll
        for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
            tile_B B;
            load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);

            const int j = j0 + tile_C::get_j(0);
            const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];

#pragma unroll
            for (int n = 0; n < ntx; ++n) {
                tile_C C;
                mma(C, A[n], B);

                // Scale each integer result by the x scale byte, the x row factor and the y factor.
#pragma unroll
                for (int l = 0; l < tile_C::ne; ++l) {
                    const int i = i0 + n*tile_C::I + tile_C::get_i(l);
                    const int8_t * sc = (const int8_t *) (x_sc + i*MMQ_MMA_TILE_X_K_Q6_K + k00/16);
                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * sc[k01/4] * x_df[i*MMQ_MMA_TILE_X_K_Q6_K] * dB;
                }
            }
        }
    }
#elif defined(TURING_MMA_AVAILABLE)

    typedef tile<16, 4, int> tile_A;
    typedef tile< 8, 4, int> tile_B;
    typedef tile<16, 8, int> tile_C;

    constexpr int granularity = mmq_get_granularity_device(mmq_x);
    constexpr int rows_per_warp = 2 * granularity;
    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.

    // Warps that share the same x rows start at different y columns.
    y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);

    const int   * x_qs = (const int   *) x;
    const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
    const int   * x_sc = (const int   *) x_df + MMQ_TILE_NE_K/QI6_K;
    const int   * y_qs = (const int   *) y + 4;
    const float * y_df = (const float *) y;

    // First x row handled by this warp.
    const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I);

    // All x data of the slice is loaded up front and reused for every y column block:
    // A holds the quant tiles, scA the scale bytes widened to int, dA the per-row float factors.
    tile_A   A[ntx][8];
    int    scA[ntx][tile_C::ne/2][8];
    float   dA[ntx][tile_C::ne/2];

#pragma unroll
    for (int n = 0; n < ntx; ++n) {
#pragma unroll
        for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 8) {
            const int k0 = k00 + k01;

            load_ldmatrix(A[n][k01/4 + 0], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + 0),         MMQ_MMA_TILE_X_K_Q6_K);
            load_ldmatrix(A[n][k01/4 + 1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + tile_A::J), MMQ_MMA_TILE_X_K_Q6_K);
        }

#pragma unroll
        for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 16) {
            const int k0 = k00 + k01;

#pragma unroll
            for (int l = 0; l < tile_C::ne/2; ++l) {
                const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);

                // One int holds 4 int8 scales; copy it locally so the bytes can be read individually.
                const int      sc_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + k0/16];
                const int8_t * sc        = (const int8_t *) &sc_packed;

                // int(sizeof(int)) keeps the comparison signed, matching the loop in load_tiles_q5_K.
#pragma unroll
                for (int ksc = 0; ksc < int(sizeof(int)); ++ksc) {
                    scA[n][l][k01/4 + ksc] = sc[ksc];
                }
            }
        }

#pragma unroll
        for (int l = 0; l < tile_C::ne/2; ++l) {
            const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);

            dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q6_K];
        }
    }

#pragma unroll
    for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
        // Partial sums over k; the per-row x factor dA is applied once after the k loop.
        float tmp[ntx][tile_C::ne] = {{0.0f}};

#pragma unroll
        for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 8) {
            tile_B B[2];
            float dB[tile_C::ne/2];

            // Here load_generic is faster than load_ldmatrix.
            load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + 0         + k01, MMQ_TILE_Y_K);
            load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + tile_B::J + k01, MMQ_TILE_Y_K);

#pragma unroll
            for (int l = 0; l < tile_C::ne/2; ++l) {
                const int j = j0 + tile_C::get_j(l);

                dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
            }

#pragma unroll
            for (int n = 0; n < ntx; ++n) {
                tile_C C[2];
                mma(C[0], A[n][k01/4 + 0], B[0]);
                mma(C[1], A[n][k01/4 + 1], B[1]);

                // Each of the two integer results has its own x scale; both share the y factor dB.
#pragma unroll
                for (int l = 0; l < tile_C::ne; ++l) {
                    tmp[n][l] += (C[0].x[l]*scA[n][l/2][k01/4 + 0] + C[1].x[l]*scA[n][l/2][k01/4 + 1])*dB[l%2];
                }
            }
        }

#pragma unroll
        for (int n = 0; n < ntx; ++n) {
#pragma unroll
            for (int l = 0; l < tile_C::ne; ++l) {
                sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp[n][l]*dA[n][l/2];
            }
        }
    }
#else
    GGML_UNUSED_VARS(x, y, sum, k00);
    NO_DEVICE_CODE;
#endif // defined(AMD_MFMA_AVAILABLE) / defined(AMD_WMMA_AVAILABLE) / defined(TURING_MMA_AVAILABLE)
}
2618
// Loads one tile of IQ4_NL blocks from x into the tile buffer x_tile.
//
// The 4 bit indices of every block are translated through kvalues_iq4nl into two ints of values
// that are written to x_qs, and the half precision block factor d is converted to float and
// written to x_df. With need_check set, row indices are clamped to i_max before x is read.
template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_nl(
    const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
    constexpr int nwarps    = mmq_get_nwarps_device();
    constexpr int warp_size = ggml_cuda_get_physical_warp_size();

    // The quants always start at the beginning of the tile; the float section offset and the
    // quant row stride depend on the compiled code path.
    int * x_qs = x_tile;
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
    constexpr int qs_row_stride = MMQ_MMA_TILE_X_K_Q8_0;
    float * x_df = reinterpret_cast<float *>(x_qs + MMQ_TILE_NE_K*2);
#else
    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_NL, mmq_y);
    constexpr int qs_row_stride = 2*MMQ_TILE_NE_K + 1;
    float * x_df = reinterpret_cast<float *>(x_qs + txs.qs);
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)

    // First block of row 0 for this tile.
    const block_iq4_nl * blocks = reinterpret_cast<const block_iq4_nl *>(x) + kbx0;

    // --- quantized values: threads_per_row threads share one row ---
    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_NL);
    constexpr int nrows           = warp_size / threads_per_row;

    const int lane_in_row = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
    const int blk_in_row  = lane_in_row / QI4_NL; // block within the row
    const int int_in_blk  = lane_in_row % QI4_NL; // packed int within that block
    const int dst_k       = blk_in_row * (2 * QI4_NL) + int_in_blk;

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
        int row = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
        if (need_check) {
            row = min(row, i_max);
        }

        const block_iq4_nl * bxi = blocks + row*stride + blk_in_row;

        // One packed int of 4 bit indices expands to two ints of table values.
        const int2 v = get_int_from_table_16(get_int_b2(bxi->qs, int_in_blk), kvalues_iq4nl);

        int * dst = x_qs + row*qs_row_stride + dst_k;
        dst[0]      = v.x;
        dst[QI4_NL] = v.y;
    }

    // --- block factors: one thread per block of the row ---
    constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_NL;
    constexpr int rows_per_warp         = warp_size / blocks_per_tile_x_row;

    const int kbxd = threadIdx.x % blocks_per_tile_x_row;

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
        int row = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;
        if (need_check) {
            row = min(row, i_max);
        }

        const block_iq4_nl * bxi = blocks + row*stride + kbxd;
        const float d = __half2float(bxi->d);

#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
        x_df[row*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = d;
#else
        x_df[row*(MMQ_TILE_NE_K/QI4_NL) + row/QI4_NL + kbxd] = d;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
    }
}
2683
// Loads one tile of IQ2_XXS blocks from x into the tile buffer x_tile.
//
// Each thread reads two ints of bxi->qs: the first holds 4 byte-sized grid indices, the second
// holds 4 groups of 7 sign-selector bits plus a scale in its top 4 bits. The grid entries are
// sign-adjusted and written to x_qs; the float factor derived from the scale and bxi->d goes to x_df.
// With need_check set, row indices are clamped to i_max before x is read.
template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_xxs(
    const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
    constexpr int nwarps    = mmq_get_nwarps_device();
    constexpr int warp_size = ggml_cuda_get_physical_warp_size();

    // The quants always start at the beginning of the tile; the float section offset and the
    // quant row stride depend on the compiled code path.
    int * x_qs = x_tile;
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
    constexpr int qs_row_stride = MMQ_MMA_TILE_X_K_Q8_0;
    float * x_df = reinterpret_cast<float *>(x_qs + MMQ_TILE_NE_K*2);
#else
    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_XXS, mmq_y);
    constexpr int qs_row_stride = 2*MMQ_TILE_NE_K + 1;
    float * x_df = reinterpret_cast<float *>(x_qs + txs.qs);
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)

    constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XXS)) / 2;
    constexpr int nrows           = warp_size / threads_per_row;

    // Index of this thread within its row.
    const int kqsx = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) {
        int row = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
        if (need_check) {
            row = min(row, i_max);
        }

        const block_iq2_xxs * bxi = (const block_iq2_xxs *) x + kbx0 + row*stride;

        const int       grid_idx_packed = get_int_b2(bxi->qs, 2*kqsx+0);
        const uint32_t  sign_scale_bits = get_int_b2(bxi->qs, 2*kqsx+1);
        const uint8_t * grid_idx        = (const uint8_t *) &grid_idx_packed;

        // This thread fills 8 consecutive ints of the row.
        int * dst = x_qs + row*qs_row_stride + 8*kqsx;

#pragma unroll
        for (int l = 0; l < QR2_XXS; ++l) {
            const int * grid = (const int *) (iq2xxs_grid + grid_idx[l]);
            const int   sgn  = ksigns_iq2xs[(sign_scale_bits >> (7*l)) & 0x7F];

            // Expand one sign bit per byte into a 0x00/0xFF byte mask; (v ^ m) - m then negates
            // exactly the bytes whose mask is 0xFF.
            const int mask_lo = __vcmpne4(((sgn & 0x03) << 7) | ((sgn & 0x0C) << 21), 0x00000000);
            const int mask_hi = __vcmpne4(((sgn & 0x30) << 3) | ((sgn & 0xC0) << 17), 0x00000000);

            dst[2*l + 0] = __vsub4(grid[0] ^ mask_lo, mask_lo);
            dst[2*l + 1] = __vsub4(grid[1] ^ mask_hi, mask_hi);
        }

        // The scale is stored in the top 4 bits of the second int.
        const int   ls     = sign_scale_bits >> 28;
        const float d      = bxi->d;
        const float factor = (ls*d + d/2)/4;

#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
        x_df[row*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = factor;
#else
        x_df[row*(MMQ_TILE_NE_K/4) + row/4 + kqsx] = factor;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
    }
}
2745
// Loads one tile of IQ2_XS blocks from x into the tile buffer x_tile.
//
// Each thread reads two ints of bxi->qs and treats them as 4 16-bit codes: the low 9 bits index
// iq2xs_grid, the remaining upper bits index ksigns64. The sign-adjusted grid entries are written
// to x_qs; two float factors per thread (from the low and high nibble of one scale byte) go to x_df.
// With need_check set, row indices are clamped to i_max before x is read.
template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_xs(
    const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
    constexpr int nwarps    = mmq_get_nwarps_device();
    constexpr int warp_size = ggml_cuda_get_physical_warp_size();

    // The quants always start at the beginning of the tile; the float section offset and the
    // quant row stride depend on the compiled code path.
    int * x_qs = x_tile;
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
    constexpr int qs_row_stride = MMQ_MMA_TILE_X_K_Q3_K;
    float * x_df = reinterpret_cast<float *>(x_qs + MMQ_TILE_NE_K*2);
#else
    constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16;
    constexpr int qs_row_stride = 2*MMQ_TILE_NE_K + 1;
    float * x_df = reinterpret_cast<float *>(x_qs + txs.qs);
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)

    constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XS)) / 2;
    constexpr int nrows           = warp_size / threads_per_row;

    // Index of this thread within its row.
    const int kqsx = threadIdx.x % threads_per_row;

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) {
        int row = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
        if (need_check) {
            row = min(row, i_max);
        }

        const block_iq2_xs * bxi = (const block_iq2_xs *) x + kbx0 + row*stride;

        // Two ints = four 16-bit codes.
        const int2 codes_packed = make_int2(get_int_b2(bxi->qs, 2*kqsx+0), get_int_b2(bxi->qs, 2*kqsx+1));
        const uint16_t * codes  = (const uint16_t *) &codes_packed;

        // This thread fills 8 consecutive ints of the row.
        int * dst = x_qs + row*qs_row_stride + 8*kqsx;

#pragma unroll
        for (int l = 0; l < QR2_XS; ++l) {
            const uint16_t code = codes[l];

            const uint32_t * grid = (const uint32_t *)(iq2xs_grid + (code & 0x000001FF));
            const uint32_t * sgn  = (const uint32_t *)(ksigns64   + (code >> 9));

            // (v ^ m) - m per byte: applies the sign mask to the grid bytes.
            dst[2*l + 0] = __vsub4(grid[0] ^ sgn[0], sgn[0]);
            dst[2*l + 1] = __vsub4(grid[1] ^ sgn[1], sgn[1]);
        }

        // One scale byte carries two 4 bit scales.
        const int   ls = bxi->scales[kqsx];
        const float d  = bxi->d;

#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
        float * df_dst = x_df + row*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx;
#else
        float * df_dst = x_df + row*(2*MMQ_TILE_NE_K*2/QI8_0) + row/(QI8_0/4) + 2*kqsx;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
        df_dst[0] = ((ls &  0x0F)*d + d/2)/4;
        df_dst[1] = ((ls >>    4)*d + d/2)/4;
    }
}
2805
// Loads one tile of IQ2_S blocks from x into the tile buffer x_tile.
//
// Each thread reads one int of 4 grid index bytes from bxi->qs, one byte of bxi->qh that supplies
// two extra index bits per grid entry, and one int of sign bytes located QK_K/32 ints further into
// bxi->qs. The sign-adjusted grid entries are written to x_qs; two float factors per thread (from
// the low and high nibble of one scale byte) go to x_df.
// With need_check set, row indices are clamped to i_max before x is read.
template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_s(
    const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
    constexpr int nwarps    = mmq_get_nwarps_device();
    constexpr int warp_size = ggml_cuda_get_physical_warp_size();

    // The quants always start at the beginning of the tile; the float section offset and the
    // quant row stride depend on the compiled code path.
    int * x_qs = x_tile;
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
    constexpr int qs_row_stride = MMQ_MMA_TILE_X_K_Q3_K;
    float * x_df = reinterpret_cast<float *>(x_qs + MMQ_TILE_NE_K*2);
#else
    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_S, mmq_y);
    constexpr int qs_row_stride = 2*MMQ_TILE_NE_K + 1;
    float * x_df = reinterpret_cast<float *>(x_qs + txs.qs);
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)

    constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_S)) / 2;
    constexpr int nrows           = warp_size / threads_per_row;

    // Index of this thread within its row.
    const int kqsx = threadIdx.x % threads_per_row;

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) {
        int row = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
        if (need_check) {
            row = min(row, i_max);
        }

        const block_iq2_s * bxi = (const block_iq2_s *) x + kbx0 + row*stride;

        const int       idx_packed  = get_int_b2(bxi->qs, kqsx);
        const int       sign_packed = get_int_b2(bxi->qs, QK_K/32 + kqsx);
        const uint8_t * idx_lo      = (const uint8_t *) &idx_packed;
        const uint8_t * sign_bytes  = (const uint8_t *) &sign_packed;

        const int qh = bxi->qh[kqsx];

        // This thread fills 8 consecutive ints of the row.
        int * dst = x_qs + row*qs_row_stride + 8*kqsx;

#pragma unroll
        for (int l = 0; l < QR2_S; ++l) {
            // Bits 8-9 of the grid index come from the l-th bit pair of qh.
            const int * grid = (const int *)(iq2s_grid + (idx_lo[l] | ((qh << (8-2*l)) & 0x300)));

            const int sgn = sign_bytes[l];

            // Expand one sign bit per byte into a 0x00/0xFF byte mask; (v ^ m) - m then negates
            // exactly the bytes whose mask is 0xFF.
            const int mask_lo = __vcmpne4(((sgn & 0x03) << 7) | ((sgn & 0x0C) << 21), 0x00000000);
            const int mask_hi = __vcmpne4(((sgn & 0x30) << 3) | ((sgn & 0xC0) << 17), 0x00000000);

            dst[2*l + 0] = __vsub4(grid[0] ^ mask_lo, mask_lo);
            dst[2*l + 1] = __vsub4(grid[1] ^ mask_hi, mask_hi);
        }

        // One scale byte carries two 4 bit scales.
        const int   ls = bxi->scales[kqsx];
        const float d  = bxi->d;

#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
        float * df_dst = x_df + row*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx;
#else
        float * df_dst = x_df + row*(2*MMQ_TILE_NE_K*2/QI8_0) + row/(QI8_0/4) + 2*kqsx;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
        df_dst[0] = ((ls &  0x0F)*d + d/2)/4;
        df_dst[1] = ((ls >>    4)*d + d/2)/4;
    }
}
2871
// Loads IQ3_XXS blocks into the x tile: the grid entries selected by the quantized indices get their signs
// applied and are stored as packed ints in x_qs, the per-group float scale goes to x_df.
// x:      start of the quantized data, interpreted as an array of block_iq3_xxs.
// x_tile: destination tile; the scale region starts after the quantized-value region.
// kbx0:   block offset added to every row; stride: number of blocks between consecutive rows.
// i_max:  last valid row, rows beyond it are clamped to it when need_check is set.
template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq3_xxs(
    const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
    constexpr int nwarps    = mmq_get_nwarps_device();
    constexpr int warp_size = ggml_cuda_get_physical_warp_size();

    // A row is shared by threads_per_row threads, so one warp handles rows_per_warp rows per iteration.
    constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_XXS)) / 2;
    constexpr int rows_per_warp   = warp_size / threads_per_row;

    int * x_qs = (int *) x_tile;
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
    float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
#else
    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_XXS, mmq_y);
    float * x_df = (float *) (x_qs + txs.qs);
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)

    const int kqsx        = threadIdx.x % threads_per_row;
    const int row_in_warp = threadIdx.x / threads_per_row;

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
        int row = i0 + threadIdx.y*rows_per_warp + row_in_warp;

        if (need_check) {
            row = min(row, i_max);
        }

        // Index of the first packed int written by this thread and index of its scale; both depend on the tile layout.
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
        const int qs_base = row*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx;
        const int df_idx  = row*MMQ_MMA_TILE_X_K_Q8_0 + kqsx;
#else
        const int qs_base = row*(2*MMQ_TILE_NE_K + 1) + 8*kqsx;
        const int df_idx  = row*(MMQ_TILE_NE_K/4) + row/4 + kqsx;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)

        const block_iq3_xxs * blk = (const block_iq3_xxs *) x + kbx0 + row*stride;

        // 8 grid indices (one byte each) followed by a 32 bit word holding 7 bit sign selectors and, in the top 4 bits, the scale.
        const int2      idx_packed = make_int2(get_int_b2(blk->qs, 2*kqsx+0), get_int_b2(blk->qs, 2*kqsx+1));
        const uint8_t * idx        = (const uint8_t *) &idx_packed;
        const uint32_t  aux32      = get_int_b2(blk->qs, QK_K/16 + kqsx);

#pragma unroll
        for (int l = 0; l < QR3_XXS; ++l) {
            const int g0 = iq3xxs_grid[idx[2*l+0]];
            const int g1 = iq3xxs_grid[idx[2*l+1]];

            // Each ksigns64 entry is read as two ints of per-byte sign masks.
            const int * sign_masks = (const int *)(ksigns64 + ((aux32 >> (7*l)) & 0x7F));
            const int   s0         = sign_masks[0];
            const int   s1         = sign_masks[1];

            // (v ^ s) - s negates exactly those bytes whose mask byte is 0xFF.
            x_qs[qs_base + (2*l + 0)] = __vsub4(g0 ^ s0, s0);
            x_qs[qs_base + (2*l + 1)] = __vsub4(g1 ^ s1, s1);
        }

        const int   ls = aux32 >> 28;
        const float d  = blk->d;
        x_df[df_idx] = (ls*d + d/2)/2;
    }
}
2931
// Loads IQ3_S blocks into the x tile: grid entries are looked up from qs plus one extra index bit from qh,
// signs are expanded from the packed sign bits and applied, and the result is stored as packed ints in x_qs.
// The per-group float scale is written to x_df.
// x:      start of the quantized data, interpreted as an array of block_iq3_s.
// x_tile: destination tile; the scale region starts after the quantized-value region.
// kbx0:   block offset added to every row; stride: number of blocks between consecutive rows.
// i_max:  last valid row, rows beyond it are clamped to it when need_check is set.
template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq3_s(
    const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
    constexpr int nwarps    = mmq_get_nwarps_device();
    constexpr int warp_size = ggml_cuda_get_physical_warp_size();

    // A row is shared by threads_per_row threads, so one warp handles rows_per_warp rows per iteration.
    constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_S)) / 2;
    constexpr int rows_per_warp   = warp_size / threads_per_row;

    int * x_qs = (int *) x_tile;
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
    float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
#else
    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y);
    float * x_df = (float *) (x_qs + txs.qs);
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)

    const int kqsx        = threadIdx.x % threads_per_row;
    const int row_in_warp = threadIdx.x / threads_per_row;

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
        int row = i0 + threadIdx.y*rows_per_warp + row_in_warp;

        if (need_check) {
            row = min(row, i_max);
        }

        // Index of the first packed int written by this thread and index of its scale; both depend on the tile layout.
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
        const int qs_base = row*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx;
        const int df_idx  = row*MMQ_MMA_TILE_X_K_Q8_0 + kqsx;
#else
        const int qs_base = row*(2*MMQ_TILE_NE_K + 1) + 8*kqsx;
        const int df_idx  = row*(MMQ_TILE_NE_K/4) + row/4 + kqsx;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)

        const block_iq3_s * blk = (const block_iq3_s *) x + kbx0 + row*stride;

        const int2      idx_packed = make_int2(get_int_b2(blk->qs, 2*kqsx+0), get_int_b2(blk->qs, 2*kqsx+1));
        const uint8_t * idx_lo     = (const uint8_t *) &idx_packed;

        const int qh = blk->qh[kqsx];

        const int       sign_word  = get_int_b2(blk->signs, kqsx);
        const uint8_t * sign_bytes = (const uint8_t *) &sign_word;

#pragma unroll
        for (int l = 0; l < QR3_S; ++l) {
            // qh contributes bit 8 of each grid index: bit 2*l for the first entry, bit 2*l+1 for the second.
            const int g0 = iq3s_grid[idx_lo[2*l+0] | ((qh << (8 - 2*l)) & 0x100)];
            const int g1 = iq3s_grid[idx_lo[2*l+1] | ((qh << (7 - 2*l)) & 0x100)];

            // Spread the sign bits over the 4 bytes of an int, then turn every non-zero byte into 0xFF.
            const int mask0 = __vcmpne4(((sign_bytes[l] & 0x03) << 7) | ((sign_bytes[l] & 0x0C) << 21), 0x00000000);
            const int mask1 = __vcmpne4(((sign_bytes[l] & 0x30) << 3) | ((sign_bytes[l] & 0xC0) << 17), 0x00000000);

            // (v ^ m) - m negates exactly those bytes whose mask byte is 0xFF.
            x_qs[qs_base + (2*l+0)] = __vsub4(g0 ^ mask0, mask0);
            x_qs[qs_base + (2*l+1)] = __vsub4(g1 ^ mask1, mask1);
        }

        // Two 4 bit scales are packed per byte; the stored value s is used as 2*s + 1.
        const int   ls = 1 + 2*((blk->scales[kqsx/2] >> (((2*kqsx) << 1) & 0x04)) & 0x0F);
        const float d  = blk->d;
        x_df[df_idx] = ls*d;
    }
}
2998
// Loads IQ1_S blocks into the x tile.
// For every group handled by a thread the grid entries selected by qs (low 8 index bits) and qh (3 further
// index bits per entry) are split into their low and high nibbles and stored as packed ints in x_qs.
// The group scale and scale*delta are stored together as a half2 in x_ds.
// x:      start of the quantized data, interpreted as an array of block_iq1_s.
// x_tile: destination tile; the half2 region starts after the quantized-value region.
// kbx0:   block offset added to every row; stride: number of blocks between consecutive rows.
// i_max:  last valid row, rows beyond it are clamped to it when need_check is set.
template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq1_s(
    const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
    constexpr int nwarps = mmq_get_nwarps_device();
    constexpr int warp_size = ggml_cuda_get_physical_warp_size();

#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
    int   * x_qs = (int   *)  x_tile;
    half2 * x_ds = (half2 *) (x_qs + MMQ_TILE_NE_K*2);
#else
    // Query the layout with the type this loader actually handles (previously GGML_TYPE_IQ3_S was passed here),
    // so that the offset of x_ds always follows the IQ1_S tile sizes.
    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ1_S, mmq_y);
    int   * x_qs = (int   *)  x_tile;
    half2 * x_ds = (half2 *) (x_qs + txs.qs);
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)

    // A row is shared by threads_per_row threads, so one warp handles nrows rows per iteration.
    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR1_S);
    constexpr int nrows = warp_size / threads_per_row;
    const int kqsx = threadIdx.x % threads_per_row;

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) {
        int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_iq1_s * bxi = (const block_iq1_s *) x + kbx0 + i*stride;

        const int       qs_packed = get_int_b2(bxi->qs, kqsx);
        const uint8_t * qs        = (const uint8_t *) &qs_packed;

        const int qh = bxi->qh[kqsx];

#pragma unroll
        for (int l = 0; l < QR1_S/2; ++l) {
            // The grid index is the qs byte extended by 3 bits taken from qh.
            const int grid = iq1s_grid_gpu[qs[l] | (((qh >> (3*l)) & 0x07) << 8)];

            // Every byte of the grid entry holds two 4 bit values, separate them into two ints.
            const int grid0 = (grid >> 0) & 0x0F0F0F0F;
            const int grid1 = (grid >> 4) & 0x0F0F0F0F;

#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
            x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+0)] = grid0;
            x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+1)] = grid1;
#else
            x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+0)] = grid0;
            x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+1)] = grid1;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
        }

        // Bits 12-14 of qh hold the scale s, used as 2*s + 1; bit 15 selects the sign of IQ1S_DELTA.
        const float  d1q   = __half2float(bxi->d) * (((qh >> 11) & 0x0E) + 1);
        const float  delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f*IQ1S_DELTA/0x8000);

#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
        x_ds[i*MMQ_MMA_TILE_X_K_Q8_1     + kqsx] = make_half2(d1q, d1q*delta);
#else
        x_ds[i*(MMQ_TILE_NE_K/4) + i/4   + kqsx] = make_half2(d1q, d1q*delta);
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
    }
}
3058
// Loads IQ4_XS blocks into the x tile in two passes:
//   1. the packed 4 bit values are translated through kvalues_iq4nl and stored as packed ints in x_qs,
//   2. the 6 bit sub-block scales are assembled from scales_l/scales_h, offset by -32 and multiplied with d into x_df.
// x:      start of the quantized data, interpreted as an array of block_iq4_xs.
// x_tile: destination tile; the scale region starts after the quantized-value region.
// kbx0:   block offset added to every row; stride: number of blocks between consecutive rows.
// i_max:  last valid row, rows beyond it are clamped to it when need_check is set.
template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_xs(
    const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
    constexpr int nwarps    = mmq_get_nwarps_device();
    constexpr int warp_size = ggml_cuda_get_physical_warp_size();

    int * x_qs = (int *) x_tile;
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
    float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
#else
    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y);
    float * x_df = (float *) (x_qs + txs.qs);
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)

    // Pass 1: quantized values.
    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_XS);
    constexpr int nrows           = warp_size / threads_per_row;
    const int kqsx = threadIdx.x % threads_per_row;

    // Each source int expands to two ints that are stored 4 positions apart within a group of 8.
    const int k0 = 8 * (kqsx / 4) + kqsx % 4;

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
        int row = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);

        if (need_check) {
            row = min(row, i_max);
        }

#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
        const int qs_base = row*MMQ_MMA_TILE_X_K_Q8_0 + k0;
#else
        const int qs_base = row*(2*MMQ_TILE_NE_K + 1) + k0;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)

        const block_iq4_xs * blk = (const block_iq4_xs *) x + kbx0 + row*stride;

        const int  q4_packed = get_int_b4(blk->qs, kqsx);
        const int2 values    = get_int_from_table_16(q4_packed, kvalues_iq4nl);

        x_qs[qs_base + 0] = values.x;
        x_qs[qs_base + 4] = values.y;
    }

    // Pass 2: scales, 8 per row with one thread per scale.
    constexpr int rows_per_warp = warp_size / 8;
    const int scale_idx = threadIdx.x % 8;

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
        int row = i0 + threadIdx.y * rows_per_warp + threadIdx.x / (MMQ_TILE_NE_K/4);

        if (need_check) {
            row = min(row, i_max);
        }

#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
        const int df_idx = row*MMQ_MMA_TILE_X_K_Q8_0 + scale_idx;
#else
        const int df_idx = row*(MMQ_TILE_NE_K/4) + row/4 + scale_idx;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)

        const block_iq4_xs * blk = (const block_iq4_xs *) x + kbx0 + row*stride;

        const float d = __half2float(blk->d);

        // Lower 4 bits come from a nibble of scales_l, upper 2 bits from scales_h.
        const int ls_low  = (blk->scales_l[scale_idx/2] >> (4*(scale_idx % 2))) & 0x0F;
        const int ls_high = (blk->scales_h >> (2*scale_idx)) & 0x03;
        const int ls      = ls_low | (ls_high << 4);

        x_df[df_idx] = d * (ls - 32);
    }
}
3123
// Writes the per-thread accumulators of the dp4a path to dst.
// sum:     accumulators of this thread, mmq_y/warp_size consecutive values per handled column.
// ids_dst: maps the tile-local column j to the destination column.
// dst:     output, element (i, j) is stored at dst[ids_dst[j]*stride + i].
// i_max:   last valid row, only enforced when need_check is set.
// j_max:   last valid column; a warp returns as soon as its column exceeds it.
template<int mmq_x, int mmq_y, bool need_check>
static __device__ __forceinline__ void mmq_write_back_dp4a(
        const float * __restrict__ sum, const int32_t * __restrict__ ids_dst, float * __restrict__ dst,
        const int stride, const int i_max, const int j_max) {
    constexpr int nwarps       = mmq_get_nwarps_device();
    constexpr int warp_size    = ggml_cuda_get_physical_warp_size();
    constexpr int sums_per_col = mmq_y/warp_size;

#pragma unroll
    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
        // One warp (threadIdx.y) per column.
        const int col = j0 + threadIdx.y;

        if (col > j_max) {
            return;
        }

#pragma unroll
        for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
            // One lane (threadIdx.x) per row.
            const int row = i0 + threadIdx.x;

            if (!need_check || row <= i_max) {
                dst[ids_dst[col]*stride + row] = sum[(j0/nwarps) * sums_per_col + i0/warp_size];
            }
        }
    }
}
3151
// Writes the per-thread accumulators of the MMA path to dst.
// The accumulators are laid out as consecutive tile_C fragments of tile_C::ne values each; the position of
// element l inside a fragment is given by tile_C::get_i(l) / tile_C::get_j(l).
// sum:     accumulators of this thread.
// ids_dst: maps the tile-local column j to the destination column.
// dst:     output, element (i, j) is stored at dst[ids_dst[j]*stride + i].
// i_max:   last valid row, only enforced when need_check is set.
// j_max:   last valid column, elements beyond it are skipped.
template<ggml_type type, int mmq_x, int mmq_y, bool need_check>
static __device__ __forceinline__ void mmq_write_back_mma(
        const float * __restrict__ sum, const int * __restrict__ ids_dst, float * __restrict__ dst,
        const int stride, const int i_max, const int j_max) {

    constexpr int granularity = mmq_get_granularity_device(mmq_x);
    constexpr int nwarps      = mmq_get_nwarps_device();

#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
    constexpr int tileC_IJ = mmq_get_granularity_device(0);
    typedef tile<tileC_IJ, tileC_IJ, int, DATA_LAYOUT_J_MAJOR> tile_C;
    constexpr int rows_per_warp = granularity;
#else
    typedef tile<16, 8, int> tile_C;
    constexpr int rows_per_warp = 2 * granularity;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.

    // First row handled by the group of ntx warps this warp belongs to.
    const int row_base = (threadIdx.y / ntx) * (ntx*tile_C::I);
#if defined(TURING_MMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
    static_assert(nwarps*tile_C::I == mmq_y, "nwarps*tile_C::I != mmq_y");
#else
    GGML_UNUSED(nwarps);
#endif // defined(TURING_MMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)

#pragma unroll
    for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
#pragma unroll
        for (int n = 0; n < ntx; ++n) {
#pragma unroll
            for (int l = 0; l < tile_C::ne; ++l) {
                const int col = j0 + (threadIdx.y % ntx) * tile_C::J + tile_C::get_j(l);

                if (col <= j_max) {
                    const int row = row_base + n*tile_C::I + tile_C::get_i(l);

                    if (!need_check || row <= i_max) {
                        dst[ids_dst[col]*stride + row] = sum[(j0/tile_C::J + n)*tile_C::ne + l];
                    }
                }
            }
        }
    }
}
3200
3201// -------------------------------------------------------------------------------------------------------------------------------------
3202
// Per-type kernel configuration for MMQ, parameterized by tile sizes (mmq_x, mmq_y), bounds checking and ggml_type.
// The primary template is only declared; the configuration comes from a specialization for the given type.
template <int mmq_x, int mmq_y, bool need_check, ggml_type type>
struct mmq_type_traits;
3205
// Q4_0: dedicated tile loader and dp4a dot product; the MMA path uses the q8_0 x q8_1 kernel with the DS4 layout.
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q4_0> {
    static constexpr int              vdr          = VDR_Q4_0_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q4_0<mmq_y, need_check>;
    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_DS4>;
    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q4_0_q8_1_dp4a<mmq_x, mmq_y>;
};
3213
// Q4_1: dedicated tile loader and dp4a dot product; the MMA path uses the q8_1 x q8_1 kernel.
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q4_1> {
    static constexpr int              vdr          = VDR_Q4_1_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q4_1<mmq_y, need_check>;
    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y>;
    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q4_1_q8_1_dp4a<mmq_x, mmq_y>;
};
3221
// Q5_0: dedicated tile loader; both dot products are the q8_0 x q8_1 kernels (MMA with the D4 layout).
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q5_0> {
    static constexpr int              vdr          = VDR_Q5_0_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q5_0<mmq_y, need_check>;
    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
};
3229
// Q5_1: dedicated tile loader; both dot products are the q8_1 x q8_1 kernels.
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q5_1> {
    static constexpr int              vdr          = VDR_Q5_1_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q5_1<mmq_y, need_check>;
    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y>;
    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a<mmq_x, mmq_y>;
};
3237
// Q8_0: dedicated tile loader; both dot products are the q8_0 x q8_1 kernels (MMA with the D4 layout).
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q8_0> {
    static constexpr int              vdr          = VDR_Q8_0_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q8_0<mmq_y, need_check>;
    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
};
3245
// MXFP4: with BLACKWELL_MMA_AVAILABLE the FP4 tile loader and the mxfp4 x mxfp4 MMA kernel are selected,
// otherwise the regular MXFP4 loader is combined with the q8_0 x q8_1 MMA kernel (D4 layout).
// The dp4a dot product is the q8_0 x q8_1 kernel in both cases.
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_MXFP4> {
    static constexpr int              vdr          = VDR_MXFP4_Q8_1_MMQ;
#ifdef BLACKWELL_MMA_AVAILABLE
    static constexpr load_tiles_mmq_t load_tiles  = load_tiles_mxfp4_fp4<mmq_y, need_check>;
    static constexpr vec_dot_mmq_t    vec_dot_mma = vec_dot_mxfp4_mxfp4_mma<mmq_x, mmq_y>;
#else
    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_mxfp4<mmq_y, need_check>;
    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
#endif // BLACKWELL_MMA_AVAILABLE
    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
};
3258
// Q2_K: dedicated tile loader and dedicated MMA and dp4a dot products.
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q2_K> {
    static constexpr int              vdr          = VDR_Q2_K_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q2_K<mmq_y, need_check>;
    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q2_K_q8_1_mma<mmq_x, mmq_y>;
    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q2_K_q8_1_dp4a<mmq_x, mmq_y>;
};
3266
// Q3_K: dedicated tile loader and dp4a dot product; the MMA path uses the q8_0_16 x q8_1 kernel.
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q3_K> {
    static constexpr int              vdr          = VDR_Q3_K_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q3_K<mmq_y, need_check>;
    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y>;
    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q3_K_q8_1_dp4a<mmq_x, mmq_y>;
};
3274
// Q4_K: dedicated tile loader and dp4a dot product; the MMA path uses the q8_1 x q8_1 kernel.
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q4_K> {
    static constexpr int              vdr          = VDR_Q4_K_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q4_K<mmq_y, need_check>;
    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y>;
    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q4_K_q8_1_dp4a<mmq_x, mmq_y>;
};
3282
// Q5_K: dedicated tile loader and dp4a dot product; the MMA path uses the q8_1 x q8_1 kernel.
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q5_K> {
    static constexpr int              vdr          = VDR_Q5_K_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q5_K<mmq_y, need_check>;
    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y>;
    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q5_K_q8_1_dp4a<mmq_x, mmq_y>;
};
3290
// Q6_K: dedicated tile loader and dedicated MMA and dp4a dot products.
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q6_K> {
    static constexpr int              vdr          = VDR_Q6_K_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q6_K<mmq_y, need_check>;
    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q6_K_q8_1_mma<mmq_x, mmq_y>;
    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q6_K_q8_1_dp4a<mmq_x, mmq_y>;
};
3298
// IQ2_XXS: dedicated tile loader; both dot products are the q8_0 x q8_1 kernels (MMA with the D4 layout).
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ2_XXS> {
    static constexpr int              vdr          = VDR_IQ2_XXS_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq2_xxs<mmq_y, need_check>;
    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
};
3306
// IQ2_XS: dedicated tile loader; both dot products are the q8_0_16 x q8_1 kernels.
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ2_XS> {
    static constexpr int              vdr          = VDR_IQ2_XS_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq2_xs<mmq_y, need_check>;
    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y>;
    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y>;
};
3314
// IQ2_S: dedicated tile loader; both dot products are the q8_0_16 x q8_1 kernels.
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ2_S> {
    static constexpr int              vdr          = VDR_IQ2_S_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq2_s<mmq_y, need_check>;
    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y>;
    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y>;
};
3322
// IQ3_XXS: dedicated tile loader; both dot products are the q8_0 x q8_1 kernels (MMA with the D4 layout).
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ3_XXS> {
    static constexpr int              vdr          = VDR_IQ3_XXS_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq3_xxs<mmq_y, need_check>;
    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
};
3330
// IQ3_S: dedicated tile loader; both dot products are the q8_0 x q8_1 kernels (MMA with the D4 layout).
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ3_S> {
    static constexpr int              vdr          = VDR_IQ3_S_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq3_s<mmq_y, need_check>;
    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
};
3338
// IQ1_S: dedicated tile loader; both dot products are the q8_1 x q8_1 kernels.
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ1_S> {
    static constexpr int              vdr          = VDR_IQ1_S_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq1_s<mmq_y, need_check>;
    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y>;
    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a<mmq_x, mmq_y>;
};
3346
// IQ4_NL: dedicated tile loader; both dot products are the q8_0 x q8_1 kernels (MMA with the D4 layout).
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ4_NL> {
    static constexpr int              vdr          = VDR_IQ4_NL_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq4_nl<mmq_y, need_check>;
    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
};
3354
// IQ4_XS: dedicated tile loader; both dot products are the q8_0 x q8_1 kernels (MMA with the D4 layout).
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ4_XS> {
    static constexpr int              vdr          = VDR_IQ4_XS_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq4_xs<mmq_y, need_check>;
    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
};
3362
// Accumulates one output tile over the block range [kb0_start, kb0_stop) and writes the result out.
//
// x, offset_x, stride_row_x: quantized x data, block offset of this tile and row stride, forwarded to load_tiles.
// y, ncols_y:                y data as ints; consecutive y blocks are ncols_y*sz ints apart (sz = ints per block_q8_1_mmq).
// ids_dst:                   column index mapping forwarded to the write-back function.
// dst, stride_col_dst:       regular output and its column stride, used when fixup is false.
// tmp_fixup:                 per-thread-block scratch output, used when fixup is true.
// tile_x_max_i/tile_y_max_j: last valid row/column of the tile, used for bounds checks.
//
// Dynamic shared memory layout (in ints): mmq_x ints at the start, then tile_y, then tile_x.
template <ggml_type type, int mmq_x, bool need_check, bool fixup>
static __device__ __forceinline__ void mul_mat_q_process_tile(
        const char * __restrict__ x, const int offset_x, const int * __restrict__ y,
        const int * __restrict__ ids_dst, float * __restrict__ dst, float * __restrict__ tmp_fixup,
        const int stride_row_x, const int ncols_y, const int stride_col_dst,
        const int tile_x_max_i, const int tile_y_max_j, const int kb0_start, const int kb0_stop) {

    constexpr int              warp_size  = ggml_cuda_get_physical_warp_size();
    constexpr int              nwarps     = mmq_get_nwarps_device();
    constexpr int              qk         = ggml_cuda_type_traits<type>::qk;
    constexpr int              mmq_y      = get_mmq_y_device();
    constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y, need_check, type>::load_tiles;

    // tile_y starts after the first mmq_x ints; its size is padded to a multiple of the thread count
    // so that the strided copy loops below never write into tile_x.
    extern __shared__ int data_mul_mat_q[];
    int * tile_y = data_mul_mat_q + mmq_x;
    int * tile_x = tile_y + GGML_PAD(mmq_x*MMQ_TILE_Y_K, nwarps*warp_size);

    // Select the dot product and write-back implementations matching the available instructions.
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
    constexpr vec_dot_mmq_t    vec_dot    = mmq_type_traits<mmq_x, mmq_y, need_check, type>::vec_dot_mma;
    constexpr mmq_write_back_t write_back = mmq_write_back_mma<type, mmq_x, mmq_y, need_check>;
#else
    constexpr vec_dot_mmq_t    vec_dot    = mmq_type_traits<mmq_x, mmq_y, need_check, type>::vec_dot_dp4a;
    constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, need_check>;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)

    // ne_block: number of y values covered by one y block, used to convert kb0 into a y block index.
#if defined(BLACKWELL_MMA_AVAILABLE)
    // FP4 tile stores 8 blocks
    constexpr int ne_block = (type == GGML_TYPE_MXFP4) ? 8 * QK_MXFP4 : 4 * QK8_1;
#else
    constexpr int ne_block = 4 * QK8_1;
#endif  // defined(BLACKWELL_MMA_AVAILABLE)

    constexpr int ITER_K          = get_iter_k(type);
    constexpr int blocks_per_iter = ITER_K / qk;

    // Per-thread accumulators; all threads of the block together hold mmq_x*mmq_y values.
    float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f};

    constexpr int sz = sizeof(block_q8_1_mmq) / sizeof(int);

    for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_iter) {
        // The x tile is loaded once per iteration and consumed in two halves, each with its own y tile.
        load_tiles(x, tile_x, offset_x + kb0, tile_x_max_i, stride_row_x);
        {
            // First y block of this iteration, copied with all threads of the block.
            const int * by0 = y + ncols_y * (kb0 * qk / ne_block) * sz;
#pragma unroll
            for (int l0 = 0; l0 < mmq_x * MMQ_TILE_Y_K; l0 += nwarps * warp_size) {
                int l = l0 + threadIdx.y*warp_size + threadIdx.x;

                tile_y[l] = by0[l];
            }
        }

        // Make tile_x and tile_y visible to all threads before they are read.
        __syncthreads();

        vec_dot(tile_x, tile_y, sum, 0);

        // All reads of tile_y must be finished before it is overwritten.
        __syncthreads();

        {
            // Second y block of this iteration, one block (sz ints per column) further.
            const int * by0 = y + ncols_y * ((kb0 * qk / ne_block) * sz + sz);
#pragma unroll
            for (int l0 = 0; l0 < mmq_x * MMQ_TILE_Y_K; l0 += nwarps * warp_size) {
                int l = l0 + threadIdx.y*warp_size + threadIdx.x;

                tile_y[l] = by0[l];
            }
        }

        __syncthreads();

        // Second half of the x tile, starting at k offset MMQ_TILE_NE_K.
        vec_dot(tile_x, tile_y, sum, MMQ_TILE_NE_K);

        // Both tiles are overwritten by the next iteration, wait until every thread is done with them.
        __syncthreads();
    }

    if (fixup) {
        // Partial result: store the full tile into this block's slot of tmp_fixup with a dense stride of mmq_y.
        write_back(sum, ids_dst, tmp_fixup + blockIdx.x*(mmq_x*mmq_y), mmq_y, mmq_y, mmq_x);
    } else {
        write_back(sum, ids_dst, dst, stride_col_dst, tile_x_max_i, tile_y_max_j);
    }
}
3443
3444
3445// The mul_mat_q kernel implements "stream-k" work partitioning as described in https://arxiv.org/abs/2301.03598
3446
// Matrix multiplication kernel for quantized x data and q8_1 y data (block_q8_1_mmq layout).
//
// Launch layout (as used by launch_mul_mat_q):
//   - Block: (warp_size, nwarps, 1); threads are linearized as threadIdx.y*warp_size + threadIdx.x.
//   - Conventional tiling (non-CDNA HIP or pre-Volta CUDA): grid (nty, ntx, nchannels_y*nsamples_y),
//     one output tile per CUDA block, tmp_fixup is not needed.
//   - Stream-k (all other targets): 1D grid, the total work in "k block" units is split evenly across gridDim.x blocks.
//     A block that stops in the middle of an output tile writes its partial sums to tmp_fixup instead of dst;
//     those are later combined by mul_mat_q_stream_k_fixup.
//   - Dynamic shared memory: starts with mmq_x ints (ids_dst_shared), the remainder is consumed by mul_mat_q_process_tile.
//
// Parameters:
//   x, y            quantized input data; offsets into y are in units of int.
//   ids_dst         if non-null (MoE), maps the sorted column index to the dst column to write.
//   expert_bounds   only read if ids_dst is non-null; columns [expert_bounds[zt], expert_bounds[zt+1]) belong to channel zt.
//   dst             output.
//   tmp_fixup       stream-k scratch buffer with mmq_x*mmq_y floats per CUDA block.
//   ncols_max       upper bound for the number of columns per channel, determines the number of tiles in x direction.
template <ggml_type type, int mmq_x, bool need_check>
#if defined(GGML_USE_HIP)
#if defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
    __launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device(), 2)
#endif // defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
#else
#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
    __launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device(), 1)
#else
    __launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device(), 2)
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
#endif // defined(GGML_USE_HIP)
static __global__ void mul_mat_q(
        const char * __restrict__ x, const int * __restrict__ y, const int32_t * __restrict__ ids_dst,
        const int32_t * __restrict__ expert_bounds, float * __restrict__ dst, float * __restrict__ tmp_fixup,
        const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_row_x, const int ncols_y, const int stride_col_dst,
        const int channel_ratio, const int nchannels_y, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
        const int sample_ratio, const int nsamples_y, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
        const int ncols_max) {

    // Skip unused template specializations for faster compilation:
    if (mmq_x > get_mmq_x_max_device() || mmq_x % mmq_get_granularity_device(mmq_x) != 0) {
        NO_DEVICE_CODE;
        return;
    }

    constexpr int nwarps = mmq_get_nwarps_device();
    constexpr int warp_size = ggml_cuda_get_physical_warp_size();

    constexpr int qk    = ggml_cuda_type_traits<type>::qk;
    constexpr int mmq_y = get_mmq_y_device();

    const int ntx = (ncols_max + mmq_x - 1) / mmq_x; // Number of tiles x
    const int nty = (nrows_x   + mmq_y - 1) / mmq_y; // Number of tiles y

    // Initialize the ids for writing back data with just the index.
    // For regular matrix multiplications this is never changed.
    // For MoE the correct indices are loaded from ids_dst.
    extern __shared__ int ids_dst_shared[]; // Stored at beginning of shared memory.
#pragma unroll
    for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) {
        const int j = j0 + threadIdx.y*warp_size + threadIdx.x;

        // Only the last (partial) iteration needs the bounds check.
        if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) {
            break;
        }

        ids_dst_shared[j] = j;
    }
    __syncthreads();

    // On non-CDNA AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead:
#if (defined(GGML_USE_HIP) && !defined(CDNA)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
    {
        // blockIdx.z enumerates (sample, channel) pairs, blockIdx.y the tile in x direction, blockIdx.x the tile in y direction.
        const int wt = blockIdx.z / nchannels_y;
        const int zt = blockIdx.z - wt*nchannels_y;
        const int jt = blockIdx.y;
        const int it = blockIdx.x;

        // Defaults for regular matrix multiplication:
        int col_low    = 0;
        int col_high   = ncols_dst;
        int col_diff   = ncols_dst;
        int offset_y   = wt*stride_sample_y   + zt*stride_channel_y;
        int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst;

        if (ids_dst) {
            col_low  = expert_bounds[zt + 0];
            col_high = expert_bounds[zt + 1];
            col_diff = col_high - col_low;

            offset_y   = 0;
            offset_dst = 0;

            // This tile lies beyond the columns of this channel, nothing to do.
            // The condition is uniform across the block so skipping the barrier below is safe.
            if (jt*mmq_x >= col_diff) {
                return;
            }

            // __syncthreads(); // There is no previous tile that could cause a race condition.
#pragma unroll
            for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) {
                const int j = j0 + threadIdx.y*warp_size + threadIdx.x;

                if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) {
                    break;
                }

                ids_dst_shared[j] = ids_dst[col_low + jt*mmq_x + j];
            }
            __syncthreads();
        }

        offset_y   += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int));
        offset_dst += it*mmq_y;

        const int tile_x_max_i = nrows_x  - it*mmq_y - 1;
        const int tile_y_max_j = col_diff - jt*mmq_x - 1;

        const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;

        constexpr bool fixup = false;
        mul_mat_q_process_tile<type, mmq_x, need_check, fixup>
            (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst,
             tile_x_max_i, tile_y_max_j, 0, ncols_x/qk);
        return;
    }
#endif // (defined(GGML_USE_HIP) && !defined(CDNA)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA

    constexpr int ITER_K = get_iter_k(type);

    const     int64_t blocks_per_ne00 = ncols_x / qk;
    constexpr int     blocks_per_iter = ITER_K / qk;

    // kbc == k block continuous, current index in continuous ijk space.
    int64_t kbc      = (int64_t) blockIdx.x     *nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
    int64_t kbc_stop = (int64_t)(blockIdx.x + 1)*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;

    // Round the boundaries down so that within a tile each CUDA block starts at a multiple of blocks_per_iter.
    kbc      -= (kbc      % blocks_per_ne00) % blocks_per_iter;
    kbc_stop -= (kbc_stop % blocks_per_ne00) % blocks_per_iter;

    // kb0 == k index when doing the matrix multiplication for an output tile.
    int kb0_start = kbc % blocks_per_ne00;
    int kb0_stop  = min(blocks_per_ne00, kb0_start + kbc_stop - kbc);
    // Process all tiles for which this CUDA block reaches the end of the k dimension:
    while (kbc < kbc_stop && kb0_stop == blocks_per_ne00) {
        // Decompose kbc into tile indices, 64 bit because kbc can exceed the range of int for large problems.
        int64_t tmp = kbc;
        const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
        tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
        const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00);
        tmp -= wt * (nchannels_y*ntx*blocks_per_ne00);
        const int zt = tmp / (ntx*blocks_per_ne00);
        tmp -= zt * (ntx*blocks_per_ne00);
        const int jt = tmp / blocks_per_ne00;

        // Defaults for regular matrix multiplication:
        int col_low    = 0;
        int col_high   = ncols_dst;
        int col_diff   = ncols_dst;
        int offset_y   = wt*stride_sample_y   + zt*stride_channel_y;
        int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst;

        if (ids_dst) {
            col_low  = expert_bounds[zt + 0];
            col_high = expert_bounds[zt + 1];
            col_diff = col_high - col_low;

            offset_y   = 0;
            offset_dst = 0;

            // Tile is beyond the columns of this channel: advance to the start of the next tile without doing any work.
            if (jt*mmq_x >= col_diff) {
                kbc += blocks_per_ne00;
                kbc -= kbc % blocks_per_ne00;

                kb0_start = 0;
                kb0_stop  = min(blocks_per_ne00, kbc_stop - kbc);

                continue;
            }

            // The previous tile may still be reading ids_dst_shared, synchronize before overwriting it.
            __syncthreads();
#pragma unroll
            for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) {
                const int j = j0 + threadIdx.y*warp_size + threadIdx.x;

                if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) {
                    break;
                }

                ids_dst_shared[j] = ids_dst[col_low + jt*mmq_x + j];
            }
            __syncthreads();
        }

        offset_y += (col_low + jt * mmq_x) * (sizeof(block_q8_1_mmq) / sizeof(int));
        offset_dst += it*mmq_y;

        const int tile_x_max_i = nrows_x  - it*mmq_y - 1;
        const int tile_y_max_j = col_diff - jt*mmq_x - 1;

        const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;

        constexpr bool fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
        mul_mat_q_process_tile<type, mmq_x, need_check, fixup>
            (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst,
             tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop);

        // Advance to the start of the next tile.
        kbc += blocks_per_ne00;
        kbc -= kbc % blocks_per_ne00;

        kb0_start = 0;
        kb0_stop  = min(blocks_per_ne00, kbc_stop - kbc);
    }

    // No partial tile left for this CUDA block.
    if (kbc >= kbc_stop) {
        return;
    }

    // Decompose kbc into tile indices, 64 bit because kbc can exceed the range of int for large problems.
    int64_t tmp = kbc;
    const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
    tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
    const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00);
    tmp -= wt * (nchannels_y*ntx*blocks_per_ne00);
    const int zt = tmp / (ntx*blocks_per_ne00);
    tmp -= zt * (ntx*blocks_per_ne00);
    const int jt = tmp / blocks_per_ne00;

    // Defaults for regular matrix multiplication:
    int col_low    = 0;
    int col_high   = ncols_dst;
    int col_diff   = ncols_dst;
    int offset_y   = wt*stride_sample_y   + zt*stride_channel_y;
    int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst;

    if (ids_dst) {
        col_low  = expert_bounds[zt + 0];
        col_high = expert_bounds[zt + 1];
        col_diff = col_high - col_low;

        offset_y   = 0;
        offset_dst = 0;

        if (jt*mmq_x >= col_diff) {
            return;
        }

        // The memory layout for the fixup buffer is always contiguous, therefore reset ids:
        __syncthreads();
#pragma unroll
        for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) {
            const int j = j0 + threadIdx.y*warp_size + threadIdx.x;

            if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) {
                break;
            }

            ids_dst_shared[j] = j;
        }
        __syncthreads();
    }

    offset_y += (col_low + jt * mmq_x) * (sizeof(block_q8_1_mmq) / sizeof(int));
    offset_dst += it*mmq_y;

    const int tile_x_max_i = nrows_x  - it*mmq_y - 1;
    const int tile_y_max_j = col_diff - jt*mmq_x - 1;

    const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;

    constexpr bool fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
    mul_mat_q_process_tile<type, mmq_x, need_check, fixup>
        (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst,
         tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop);
}
3699
// Combines the partial results of the stream-k mul_mat_q kernel.
//
// Each CUDA block recomputes the work range that the mul_mat_q block with the same blockIdx.x had.
// If that block finished an output tile that was started by one or more preceding blocks,
// the partial sums of those preceding blocks are read from tmp_last_tile and added to dst.
//
// The work range is derived from gridDim.x, so the grid must match the one used for mul_mat_q
// (launch_mul_mat_q uses the same grid and block dimensions for both kernels).
// Block: (warp_size, nwarps, 1). No dynamic shared memory, mmq_x ints of static shared memory for the MoE path.
//
// Parameters:
//   ids_dst, expert_bounds   MoE column mapping, see mul_mat_q; ids_dst == nullptr selects regular matrix multiplication.
//   dst                      output that already holds the sums of the blocks which wrote directly to dst.
//   tmp_last_tile            fixup buffer written by mul_mat_q, mmq_x*mmq_y floats per CUDA block.
template <ggml_type type, int mmq_x, bool need_check>
static __global__ void mul_mat_q_stream_k_fixup(const int32_t * ids_dst,
                                                const int32_t * expert_bounds,
                                                float * __restrict__ dst,
                                                const float * __restrict__ tmp_last_tile,
                                                const int    ncols_x,
                                                const int    nrows_x,
                                                const int    ncols_dst,
                                                const size_t stride_col_dst,
                                                const int    nchannels_y,
                                                const size_t stride_channel_dst,
                                                const int    nsamples_y,
                                                const size_t stride_sample_dst,
                                                const int    ncols_max) {
    constexpr int     mmq_y           = get_mmq_y_device();
    constexpr int     qk              = ggml_cuda_type_traits<type>::qk;
    constexpr int     ITER_K          = get_iter_k(type);

    constexpr int     blocks_per_iter = ITER_K / qk;
    const     int64_t blocks_per_ne00 = ncols_x / qk;

    constexpr int nwarps = mmq_get_nwarps_device();
    constexpr int warp_size = ggml_cuda_get_physical_warp_size();

    // Per-thread accumulator for the partial sums, indexed as (j0/nwarps) * (mmq_y/warp_size) + i0/warp_size.
    float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f};

    const int ntx  = (ncols_max + mmq_x - 1) / mmq_x;
    const int nty  = (nrows_x   + mmq_y - 1) / mmq_y;

    const int bidx0 = blockIdx.x;

    // kbc == k block continuous, current index in continuous ijk space.
    int64_t kbc0      = (int64_t) bidx0     *nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
    int64_t kbc0_stop = (int64_t)(bidx0 + 1)*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;

    kbc0      -= (kbc0      % blocks_per_ne00) % blocks_per_iter;
    kbc0_stop -= (kbc0_stop % blocks_per_ne00) % blocks_per_iter;

    // Only a block that finished a tile which it did not start itself has anything to fix up.
    const bool did_not_have_any_data   = kbc0 == kbc0_stop;
    const bool wrote_beginning_of_tile = kbc0 % blocks_per_ne00 == 0;
    const bool did_not_write_last      = kbc0/blocks_per_ne00 == kbc0_stop/blocks_per_ne00 && kbc0_stop % blocks_per_ne00 != 0;
    if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) {
        return;
    }

    // Iterate over previous blocks and sum up partial sums written to fixup buffer.
    // All CUDA blocks that get here must have a previous block that needs a fixup.
    // The loop can only be left after at least one partial result has been accumulated.
    int64_t bidx = bidx0 - 1;
    int64_t kbc_stop = kbc0;
    while(true) {
        int64_t kbc = bidx*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
        kbc -= (kbc % blocks_per_ne00) % blocks_per_iter;

        if (kbc == kbc_stop) { // Did not have any data.
            bidx--;
            kbc_stop = kbc;
            continue;
        }

#pragma unroll
        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
            const int j = j0 + threadIdx.y;

#pragma unroll
            for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
                const int i = i0 + threadIdx.x;

                sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size] += tmp_last_tile[bidx*(mmq_x*mmq_y) + j*mmq_y + i];
            }
        }

        // If this block started in a previous tile we are done and don't need to combine additional partial results.
        if (kbc % blocks_per_ne00 == 0 || kbc/blocks_per_ne00 < kbc0/blocks_per_ne00) {
            break;
        }
        bidx--;
        kbc_stop = kbc;
    }

    // Decompose kbc0 into tile indices, 64 bit because kbc0 can exceed the range of int for large problems.
    int64_t tmp = kbc0;
    const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
    tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
    const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00);
    tmp -= wt * (nchannels_y*ntx*blocks_per_ne00);
    const int zt = tmp / (ntx*blocks_per_ne00);
    tmp -= zt * (ntx*blocks_per_ne00);
    const int jt = tmp / blocks_per_ne00;

    if (!ids_dst) {
        // The strides are size_t, keep the full width to avoid truncating large offsets.
        const size_t offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst + it*mmq_y;
        dst += offset_dst;

        const int i_max = nrows_x   - it*mmq_y - 1;
        const int j_max = ncols_dst - jt*mmq_x - 1;

#pragma unroll
        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
            const int j = j0 + threadIdx.y;

            // j only grows with j0, all further columns are out of bounds as well.
            if (j > j_max) {
                return;
            }

#pragma unroll
            for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
                const int i = i0 + threadIdx.x;

                if (need_check && i > i_max) {
                    continue;
                }

                dst[j*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size];
            }
        }
        return;
    }

    // MoE: the dst columns of this tile are given by ids_dst, load them to shared memory first.
    __shared__ int ids_dst_shared[mmq_x];
    const int col_low  = expert_bounds[zt + 0];
    const int col_high = expert_bounds[zt + 1];
    const int col_diff = col_high - col_low;

    for (int j = threadIdx.y*warp_size + threadIdx.x; j < mmq_x; j += nwarps*warp_size) {
        ids_dst_shared[j] = ids_dst[col_low + jt*mmq_x + j];
    }
    __syncthreads();

    const int offset_dst = it*mmq_y;
    dst += offset_dst;

    const int i_max = nrows_x  - it*mmq_y - 1;
    const int j_max = col_diff - jt*mmq_x - 1;

#pragma unroll
    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
        const int j = j0 + threadIdx.y;

        if (j > j_max) {
            return;
        }

#pragma unroll
        for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
            const int i = i0 + threadIdx.x;

            if (need_check && i > i_max) {
                continue;
            }

            dst[ids_dst_shared[j]*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size];
        }
    }
}
3861
// Host-side argument bundle for mul_mat_q_case / launch_mul_mat_q.
// The member order is part of the interface (aggregate initialization), do not reorder.
struct mmq_args {
    // Data pointers, forwarded unchanged to the mul_mat_q kernel.
    const char * x;
    ggml_type type_x; // NOTE(review): not read by the launch code in this header, the type is selected via template parameter.
    const int * y;
    const int32_t * ids_dst;       // Also forwarded to the stream-k fixup kernel.
    const int32_t * expert_bounds; // Also forwarded to the stream-k fixup kernel.
    float * dst;

    // Matrix dimensions and row/column strides.
    int64_t ncols_x;
    int64_t nrows_x;
    int64_t ncols_dst;
    int64_t stride_row_x;
    int64_t ncols_y;
    int64_t nrows_dst; // Passed to the kernels as stride_col_dst.

    // Channel dimension; nchannels_y must be a multiple of nchannels_x.
    int64_t nchannels_x;
    int64_t nchannels_y;
    int64_t stride_channel_x;
    int64_t stride_channel_y;
    int64_t stride_channel_dst;

    // Sample dimension; nsamples_y must be a multiple of nsamples_x.
    int64_t nsamples_x;
    int64_t nsamples_y;
    int64_t stride_sample_x;
    int64_t stride_sample_y;
    int64_t stride_sample_dst;

    bool use_stream_k; // Selects the stream-k launch (1D grid + optional fixup) over one CUDA block per output tile.
    int64_t ncols_max; // Used to calculate the number of tiles in x direction.
};
3869
// Returns the number of bytes of dynamic shared memory needed by mul_mat_q for the given tile sizes:
// the dst ids (mmq_x ints) + the x tile + the y tile padded to a multiple of nwarps*warp_size ints.
template<ggml_type type>
static size_t mmq_get_nbytes_shared(const int mmq_x, const int mmq_y, const int cc, const int warp_size, const int nwarps) {
    const tile_x_sizes dp4a_sizes = mmq_get_dp4a_tile_x_sizes(type, mmq_y);
    const int          mma_tile_k = mmq_get_mma_tile_x_k(type);

    // The x tile layout depends on whether one of the matrix core code paths is available for this device.
    const bool use_mma_layout = turing_mma_available(cc) || amd_mfma_available(cc) || amd_wmma_available(cc);

    size_t bytes_x;
    if (use_mma_layout) {
        bytes_x = mmq_y*mma_tile_k*sizeof(int);
    } else {
        bytes_x = dp4a_sizes.qs*sizeof(int) + dp4a_sizes.dm*sizeof(half2) + dp4a_sizes.sc*sizeof(int);
    }

    const size_t bytes_ids = mmq_x*sizeof(int);
    const size_t bytes_y   = mmq_x * (sizeof(block_q8_1_mmq));

    return bytes_ids + bytes_x + GGML_PAD(bytes_y, nwarps*warp_size*sizeof(int));
}
3879
// Enqueues mul_mat_q<type, mmq_x, need_check> on stream with the kernel arguments taken from args.
// tmp_fixup may be nullptr if the kernel never writes to the fixup buffer.
template <ggml_type type, int mmq_x, bool need_check>
static void launch_mul_mat_q_kernel(
        const mmq_args & args, const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared,
        float * tmp_fixup, const int channel_ratio, const int sample_ratio, cudaStream_t stream) {
    mul_mat_q<type, mmq_x, need_check><<<block_nums, block_dims, nbytes_shared, stream>>>
        (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup,
         args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
         channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
         sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
         args.ncols_max);
}

// Enqueues mul_mat_q_stream_k_fixup<type, mmq_x, need_check> on stream, must use the same grid as the preceding mul_mat_q launch.
template <ggml_type type, int mmq_x, bool need_check>
static void launch_mul_mat_q_stream_k_fixup(
        const mmq_args & args, const dim3 & block_nums, const dim3 & block_dims, const float * tmp_fixup, cudaStream_t stream) {
    mul_mat_q_stream_k_fixup<type, mmq_x, need_check><<<block_nums, block_dims, 0, stream>>>
        (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup, args.ncols_x, args.nrows_x, args.ncols_dst,
         args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst,
         args.ncols_max);
}

// Launches the MMQ kernel(s) for a fixed tile width mmq_x on the current device.
// Without stream-k one CUDA block is launched per output tile.
// With stream-k one CUDA block is launched per SM and, unless the number of tiles is a multiple of the number of SMs,
//     a second kernel combines the partial results from a temporary buffer allocated from the pool of ctx.
// Bounds checks in y direction (need_check) are only compiled in if nrows_x is not a multiple of mmq_y.
template <ggml_type type, int mmq_x>
static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
    const int id = ggml_cuda_get_device();
    const int cc = ggml_cuda_info().devices[id].cc;
    const int nsm = ggml_cuda_info().devices[id].nsm;
    const int warp_size = ggml_cuda_info().devices[id].warp_size;
    const int nwarps = mmq_get_nwarps_host(cc, warp_size);
    const int mmq_y = get_mmq_y_host(cc);

    const dim3 block_dims(warp_size, nwarps, 1);

    const int nbytes_shared = mmq_get_nbytes_shared<type>(mmq_x, mmq_y, cc, warp_size, nwarps);

    CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q<type, mmq_x, false>), nbytes_shared);
    CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q<type, mmq_x,  true>), nbytes_shared);

    const int nty  = (args.nrows_x   + mmq_y - 1) / mmq_y;
    const int ntx  = (args.ncols_max + mmq_x - 1) / mmq_x;
    const int ntzw = args.nchannels_y * args.nsamples_y;
    const dim3 block_nums_xy_tiling(nty, ntx, ntzw);

    GGML_ASSERT(args.nchannels_y % args.nchannels_x == 0);
    GGML_ASSERT(args.nsamples_y  % args.nsamples_x  == 0);
    const int channel_ratio = args.nchannels_y / args.nchannels_x;
    const int sample_ratio  = args.nsamples_y  / args.nsamples_x;

    // The last tile in y direction is only partially filled if nrows_x is not a multiple of mmq_y.
    const bool need_check = args.nrows_x % mmq_y != 0;

    if (!args.use_stream_k) {
        if (need_check) {
            launch_mul_mat_q_kernel<type, mmq_x,  true>(
                args, block_nums_xy_tiling, block_dims, nbytes_shared, nullptr, channel_ratio, sample_ratio, stream);
        } else {
            launch_mul_mat_q_kernel<type, mmq_x, false>(
                args, block_nums_xy_tiling, block_dims, nbytes_shared, nullptr, channel_ratio, sample_ratio, stream);
        }
        return;
    }

    const dim3 block_nums_stream_k(nsm, 1, 1);
    const bool fixup_needed = ntx*nty*ntzw % nsm != 0;

    // The fixup buffer holds one mmq_x*mmq_y tile of partial sums per CUDA block.
    ggml_cuda_pool & pool = ctx.pool(id);
    ggml_cuda_pool_alloc<float> tmp_fixup(pool);
    if (fixup_needed) {
        tmp_fixup.alloc(block_nums_stream_k.x * mmq_x*mmq_y);
    }

    if (need_check) {
        launch_mul_mat_q_kernel<type, mmq_x,  true>(
            args, block_nums_stream_k, block_dims, nbytes_shared, tmp_fixup.ptr, channel_ratio, sample_ratio, stream);
        if (fixup_needed) {
            launch_mul_mat_q_stream_k_fixup<type, mmq_x,  true>(args, block_nums_stream_k, block_dims, tmp_fixup.ptr, stream);
        }
    } else {
        launch_mul_mat_q_kernel<type, mmq_x, false>(
            args, block_nums_stream_k, block_dims, nbytes_shared, tmp_fixup.ptr, channel_ratio, sample_ratio, stream);
        if (fixup_needed) {
            launch_mul_mat_q_stream_k_fixup<type, mmq_x, false>(args, block_nums_stream_k, block_dims, tmp_fixup.ptr, stream);
        }
    }
}
3972
// Selects the tile width mmq_x for the current device and dispatches to the matching launch_mul_mat_q instantiation.
// Candidates are the multiples of 8 up to the device maximum that respect the granularity and
//     whose shared memory requirement does not exceed the device limit (smpbo).
// Out of those the smallest width that minimizes the number of tiles in x direction is chosen.
// Aborts if no candidate is usable.
template <ggml_type type>
void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
    const int    id        = ggml_cuda_get_device();
    const int    cc        = ggml_cuda_info().devices[id].cc;
    const size_t smpbo     = ggml_cuda_info().devices[id].smpbo;
    const int    warp_size = ggml_cuda_info().devices[id].warp_size;
    const int    nwarps    = mmq_get_nwarps_host(cc, warp_size);

    const int mmq_x_max = get_mmq_x_max_host(cc);
    const int mmq_y     = get_mmq_y_host(cc);

    int mmq_x_best    = 0;
    int ntiles_x_best = INT_MAX;

    for (int mmq_x = 8; mmq_x <= mmq_x_max; mmq_x += 8) {
        // A single tile in x direction cannot be improved upon, stop searching.
        if (ntiles_x_best <= 1) {
            break;
        }

        const int granularity = mmq_get_granularity_host(mmq_x, cc);
        if (mmq_x % granularity != 0) {
            continue;
        }
        if (mmq_get_nbytes_shared<type>(mmq_x, mmq_y, cc, warp_size, nwarps) > smpbo) {
            continue;
        }

        const int ntiles_x = (args.ncols_max + mmq_x - 1) / mmq_x;
        if (ntiles_x >= ntiles_x_best) {
            continue;
        }

        mmq_x_best    = mmq_x;
        ntiles_x_best = ntiles_x;
    }

    // mmq_x is a template parameter, map the runtime value to the corresponding instantiation:
    switch (mmq_x_best) {
        case   8: launch_mul_mat_q<type,   8>(ctx, args, stream); break;
        case  16: launch_mul_mat_q<type,  16>(ctx, args, stream); break;
        case  24: launch_mul_mat_q<type,  24>(ctx, args, stream); break;
        case  32: launch_mul_mat_q<type,  32>(ctx, args, stream); break;
        case  40: launch_mul_mat_q<type,  40>(ctx, args, stream); break;
        case  48: launch_mul_mat_q<type,  48>(ctx, args, stream); break;
        case  56: launch_mul_mat_q<type,  56>(ctx, args, stream); break;
        case  64: launch_mul_mat_q<type,  64>(ctx, args, stream); break;
        case  72: launch_mul_mat_q<type,  72>(ctx, args, stream); break;
        case  80: launch_mul_mat_q<type,  80>(ctx, args, stream); break;
        case  88: launch_mul_mat_q<type,  88>(ctx, args, stream); break;
        case  96: launch_mul_mat_q<type,  96>(ctx, args, stream); break;
        case 104: launch_mul_mat_q<type, 104>(ctx, args, stream); break;
        case 112: launch_mul_mat_q<type, 112>(ctx, args, stream); break;
        case 120: launch_mul_mat_q<type, 120>(ctx, args, stream); break;
        case 128: launch_mul_mat_q<type, 128>(ctx, args, stream); break;
        default:
            fprintf(stderr, "mmq_x_best=%d\n", mmq_x_best);
            GGML_ABORT("fatal error");
            break;
    }
}
4057
// Expands to an explicit instantiation of mul_mat_q_case for one ggml type (without trailing semicolon).
// Prefixed with `extern` below this becomes an explicit instantiation declaration:
//     the template is not implicitly instantiated by code including this header.
#define DECL_MMQ_CASE(type) \
    template void mul_mat_q_case<type>(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream)

extern DECL_MMQ_CASE(GGML_TYPE_Q4_0);
extern DECL_MMQ_CASE(GGML_TYPE_Q4_1);
extern DECL_MMQ_CASE(GGML_TYPE_Q5_0);
extern DECL_MMQ_CASE(GGML_TYPE_Q5_1);
extern DECL_MMQ_CASE(GGML_TYPE_Q8_0);
extern DECL_MMQ_CASE(GGML_TYPE_MXFP4);
extern DECL_MMQ_CASE(GGML_TYPE_Q2_K);
extern DECL_MMQ_CASE(GGML_TYPE_Q3_K);
extern DECL_MMQ_CASE(GGML_TYPE_Q4_K);
extern DECL_MMQ_CASE(GGML_TYPE_Q5_K);
extern DECL_MMQ_CASE(GGML_TYPE_Q6_K);
extern DECL_MMQ_CASE(GGML_TYPE_IQ2_XXS);
extern DECL_MMQ_CASE(GGML_TYPE_IQ2_XS);
extern DECL_MMQ_CASE(GGML_TYPE_IQ2_S);
extern DECL_MMQ_CASE(GGML_TYPE_IQ3_XXS);
extern DECL_MMQ_CASE(GGML_TYPE_IQ3_S);
extern DECL_MMQ_CASE(GGML_TYPE_IQ1_S);
extern DECL_MMQ_CASE(GGML_TYPE_IQ4_NL);
extern DECL_MMQ_CASE(GGML_TYPE_IQ4_XS);
4080
4081// -------------------------------------------------------------------------------------------------------------------------
4082
// Host entry points of the MMQ backend, defined outside of this header.

// NOTE(review): presumably ids is optional and selects the MoE code path - confirm against the definition.
void ggml_cuda_mul_mat_q(
        ggml_backend_cuda_context & ctx,
        const ggml_tensor *         src0,
        const ggml_tensor *         src1,
        const ggml_tensor *         ids,
        ggml_tensor *               dst);

void ggml_cuda_op_mul_mat_q(
        ggml_backend_cuda_context & ctx,
        const ggml_tensor *         src0,
        const ggml_tensor *         src1,
        ggml_tensor *               dst,
        const char *                src0_dd_i,
        const float *               src1_ddf_i,
        const char *                src1_ddq_i,
        float *                     dst_dd_i,
        const int64_t               row_low,
        const int64_t               row_high,
        const int64_t               src1_ncols,
        const int64_t               src1_padded_row_size,
        cudaStream_t                stream);

// NOTE(review): judging by the name this decides whether the MMQ kernels should be used for the given
//     type / compute capability / batch size - confirm against the definition.
bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t n_experts);