1#pragma once
   2// This file contains primitives that expose the tensor core PTX instructions for CUDA code.
   3// The primitives can be used in a similar way as the nvcuda::wmma interface but with a well-defined memory layout.
   4// The documentation for the PTX instructions can be found under:
   5//   https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-multiply-accumulate-operation-using-mma-instruction
   6//
   7// Like with nvcuda::wmma there are three types of matrix tiles: A, B, and C with A @ B = C.
   8// A is a row-major matrix with shape M x K.
   9// B is a column-major matrix with shape K x N.
  10// C is a column-major matrix with shape M x N.
  11// A, B, and C are represented using the same fundamental data type: a row-major matrix with I rows and J columns.
  12// Note that J is measured in physical 32 bit elements instead of logical elements.
  13// The methods get_i and get_j can be used to get the physical 32 bit index of the lth element of a thread within a tile.
// All matrix tiles have ne physical 32 bit elements per thread (the tile as a whole is distributed across the warp).
  15//
  16// As described in the PTX documentation, all pointers for load_ldmatrix must be to shared memory and aligned to 16 bytes.
  17// The API in this file also assumes that the pointers for load_generic are aligned to 16 bytes, unaligned pointers are considered undefined behavior.
  18
  19#include "common.cuh"
  20
  21// On Volta each warp is doing 4 8x8 mma operations in parallel.
  22// The basic memory layout for a 32x8 output tile is to stack 4 input tiles in I direction and to mirror the B tile.
  23// However, the i indices in this file are by default permuted to simplify the index calculations.
  24// #define GGML_CUDA_MMA_NO_VOLTA_PERM
  25
  26#if CUDART_VERSION >= 11080
  27
  28static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {
  29    int ret = 0;
  30
  31#ifdef TURING_MMA_AVAILABLE
  32    asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
  33        : "=r"(ret) : "r"(x));
  34#else
  35    GGML_UNUSED(x);
  36    NO_DEVICE_CODE;
  37#endif // defined(TURING_MMA_AVAILABLE)
  38    return ret;
  39}
  40
  41#else
  42
  43static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {
  44    // Imagine transposing row-major matrix to column-major matrix.
  45    const int src_i_low  = 2 * (threadIdx.x % 4);
  46    const int src_i_high = src_i_low + 1;
  47    const int src_j      = threadIdx.x / 4;
  48
  49    const int src_laneid_low  = src_i_low  * 4 + src_j / 2;
  50    const int src_laneid_high = src_i_high * 4 + src_j / 2;
  51
  52    const int shift_low  = ((src_j + 0) % 2) * 16;
  53    const int shift_high = ((src_j + 1) % 2) * 16;
  54
  55    const int ret_low  = (__shfl_sync(0xFFFFFFFF, x, src_laneid_low,  WARP_SIZE) >> shift_low)  & 0x0000FFFF;
  56    const int ret_high = (__shfl_sync(0xFFFFFFFF, x, src_laneid_high, WARP_SIZE) << shift_high) & 0xFFFF0000;
  57
  58    return ret_low | ret_high;
  59}
  60
  61#endif // CUDART_VERSION >= 11080
  62
  63static __device__ __forceinline__ half2 ggml_cuda_movmatrix(const half2 x) {
  64    half2 ret;
  65    *((int *) &ret) = ggml_cuda_movmatrix(*((const int *) &x));
  66    return ret;
  67}
  68
  69namespace ggml_cuda_mma {
  70
  71    // Some architectures like Volta or CDNA3 perform multiple matrix multiplications per warp in parallel,
  72    //     effectively the warp is being split into subgroups of threads that each perform a single mma instruction.
  73    // In those cases the data can be split in different ways across the warp.
    // Describes how the elements of a tile are distributed across the threads of a warp.
    // The numeric values are only used as distinct tags; the tile specializations below select on them.
    enum data_layout {
        // By default the data uses the I direction as its major dimension and the J direction as its minor dimension.
        // For the A/C matrices this means I major == row major, J major == column major.
        // For the B matrix this means I major == column major, J major == row major.
        // MIRRORED == Each data value is held exactly once per thread subgroup.
        DATA_LAYOUT_I_MAJOR           =  0, // Always used for Turing, Ampere, Ada Lovelace, consumer Blackwell, matrix A&B for RDNA4 and CDNA.
        DATA_LAYOUT_J_MAJOR           = 10, // Matrix C for CDNA and RDNA4, int and float matrix C for RDNA3.
        DATA_LAYOUT_I_MAJOR_MIRRORED  = 20, // Volta, matrix A&B for RDNA3.
        DATA_LAYOUT_J_MAJOR_MIRRORED  = 30, // Only a half2 <8, 4> specialization exists below; presumably a Volta operand - TODO confirm.
    };
  84    // Implemented mma combinations are:
  85    //   - (I_MAJOR, I_MAJOR)          -> I_MAJOR
  86    //   - (I_MAJOR, I_MAJOR_MIRRORED) -> I_MAJOR
  87    //   - (I_MAJOR, J_MAJOR_MIRRORED) -> I_MAJOR
  88
  89    static constexpr bool is_i_major(const data_layout dl) {
  90        return dl == DATA_LAYOUT_I_MAJOR ||
  91               dl == DATA_LAYOUT_I_MAJOR_MIRRORED;
  92    }
  93
  94    static constexpr __device__ data_layout get_input_data_layout() {
  95#if defined(RDNA3) || __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
  96        return DATA_LAYOUT_I_MAJOR_MIRRORED;
  97#else
  98        return DATA_LAYOUT_I_MAJOR;
  99#endif // defined(RDNA3) || __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
 100    }
 101
    // Primary template: intentionally empty, only the specializations below (per data layout and element type) are usable.
    // I_ x J_ is the tile shape with J_ counted in physical 32 bit elements, T is the element type, ds_ the data layout.
    template <int I_, int J_, typename T, data_layout ds_=DATA_LAYOUT_I_MAJOR>
    struct tile {};
 104
    // Generic I-major tile. The per-thread element count ne and the get_i/get_j mappings from a thread's
    // l-th element to its (row, 32 bit column) position are selected per architecture below.
    // NOTE(review): all mappings use threadIdx.x directly, i.e. they assume threadIdx.x is the lane index.
    template <int I_, int J_, typename T>
    struct tile<I_, J_, T, DATA_LAYOUT_I_MAJOR> {
        static constexpr int         I  = I_;
        static constexpr int         J  = J_;
        static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;

#if defined(AMD_MFMA_AVAILABLE)
        // MFMA: the tile is spread across 64 threads, so each thread holds I * J / 64 elements.
        static constexpr int ne = I * J / 64;
        T x[ne] = {0};

        static constexpr __device__ bool supported() {
            if (I == 64 && J ==  2) return true;
            if (I == 16 && J ==  8) return true;
            if (I == 32 && J ==  4) return true;
            if (I == 16 && J == 16) return true;
            if (I == 32 && J == 32) return true;
            return false;
        }

        // Row index: every thread holds a single row (independent of l).
        static __device__ __forceinline__ int get_i(const int l) {
            if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>
                return threadIdx.x % 16;
            } else if constexpr (I == 16 && J == 8) {
                return threadIdx.x % 16;
            } else if constexpr (I == 32 && J == 4) {
                return threadIdx.x % 32;
            } else if constexpr (I == 16 && J == 16) {
                return threadIdx.x % 16;
            } else if constexpr (I == 32 && J == 32) {
                return threadIdx.x % 32;
            } else {
                NO_DEVICE_CODE;
                return -1;
            }
        }

        // Column index: consecutive l are consecutive columns, except for <32, 32> where
        // groups of 4 elements are 8 columns apart.
        static __device__ __forceinline__ int get_j(const int l) {
            if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>
                return (2 * ((threadIdx.x / 16) % 2) + l);
            } else if constexpr (I == 16 && J == 8) {
                return 2 * (threadIdx.x / 16) + l;
            } else if constexpr (I == 32 && J == 4) {
                return 2 * (threadIdx.x / 32) + l;
            } else if constexpr (I == 16 && J == 16) {
                return 4 * (threadIdx.x / 16) + l;
            } else if constexpr (I == 32 && J == 32) {
                return 4 * (threadIdx.x / 32) + 8 * (l / 4) + (l % 4);
            } else {
                NO_DEVICE_CODE;
                return -1;
            }
        }
#elif __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
        static constexpr int ne = I * J / 32;
        T x[ne] = {0};

        static constexpr __device__ bool supported() {
            if (I == 32 && J ==  8) return true;
            return false;
        }

        static __device__ __forceinline__ int get_i(const int l) {
            if constexpr (I == 32 && J == 8) {
#ifdef GGML_CUDA_MMA_NO_VOLTA_PERM
                return (((threadIdx.x % 16) / 4) * 8) + ((threadIdx.x / 16) * 4) + (l & 2) + (threadIdx.x % 2);
#else
                // Default permuted row order: bit 1 of the row comes from l, the other bits from threadIdx.x.
                return (l & 2) + (threadIdx.x & ~2);
#endif // GGML_CUDA_MMA_NO_VOLTA_PERM
            } else {
                NO_DEVICE_CODE;
                return -1;
            }
        }

        static __device__ __forceinline__ int get_j(const int l) {
            if constexpr (I == 32 && J == 8) {
                // Bit 1 of the column comes from threadIdx.x, bits 0 and 2 from l.
                return (threadIdx.x & 2) + (l & (4 + 1));
            } else {
                NO_DEVICE_CODE;
                return -1;
            }
        }
#elif defined(AMD_WMMA_AVAILABLE)
        static constexpr int ne = I * J / 32;
        T x[ne] = {0};

        static constexpr __device__ bool supported() {
            if (I == 16 && J == 16) return true;
            if (I == 16 && J == 8) return true;
            if (I == 16 && J == 4) return true;
            return false;
        }

        // Lanes 16-31 hold the same rows as lanes 0-15.
        static __device__ __forceinline__ int get_i(const int l) {
            if constexpr (supported()) {
                return threadIdx.x % 16;
            } else {
                NO_DEVICE_CODE;
                return -1;
            }
        }

        static __device__ __forceinline__ int get_j(const int l) {
            if constexpr (I == 16 && J == 16) {
#if defined(RDNA3)
                if constexpr (std::is_same_v<T, float> || std::is_same_v<T, int>) {
                    // matrix C
                    return 2 * l + (threadIdx.x / 16);
                } else {
                    // matrix A&B
                    // NOTE(review): with ne == 8 this only covers columns 0-7; RDNA3 A&B normally use
                    //     DATA_LAYOUT_I_MAJOR_MIRRORED instead - TODO confirm this path is unused.
                    return l;
                }
#else
                // matrix C is the transposed matrix A&B on RDNA4
                return ne * (threadIdx.x / 16) + l;
#endif // defined(RDNA3)
            } else if constexpr (I == 16 && J == 8) {
                // mmq input for RDNA4
                return ne * (threadIdx.x / 16) + l;
            } else if constexpr (I == 16 && J == 4) {
                return ne * (threadIdx.x / 16) + l;
            } else {
                NO_DEVICE_CODE;
                return -1;
            }
        }
#else
        static constexpr int ne = I * J / 32;
        T x[ne] = {0};

        static constexpr __device__ bool supported() {
            if (I ==  8 && J ==  4) return true;
            if (I ==  8 && J ==  8) return true;
            if (I == 16 && J ==  8) return true;
            if (I == 16 && J == 16) return true;
            if (I == 32 && J ==  8) return true;
            return false;
        }

        // Each group of 4 consecutive lanes shares a row; for the 16 row shapes, l >= 2 selects rows 8 further down.
        static __device__ __forceinline__ int get_i(const int l) {
            if constexpr (I == 8 && J == 4) {
                return threadIdx.x / 4;
            } else if constexpr (I == 8 && J == 8) {
                return threadIdx.x / 4;
            } else if constexpr (I == 16 && J == 8) {
                return ((l / 2) * 8) + (threadIdx.x / 4);
            } else if constexpr (I == 16 && J == 16) {
                return (((l / 2) % 2) * 8) + (threadIdx.x / 4);
            } else if constexpr (I == 32 && J == 8) {
                // With ne == 8, (l / 2) * 8 reaches row offset 24, covering all 32 rows.
                return tile<16, 8, T>::get_i(l); // Memory layout simply repeated with same pattern in i direction.
            } else {
                NO_DEVICE_CODE;
                return -1;
            }
        }

        // For the 16 row shapes, pairs of consecutive l occupy two adjacent columns.
        static __device__ __forceinline__ int get_j(const int l) {
            if constexpr (I == 8 && J == 4) {
                return threadIdx.x % 4;
            } else if constexpr (I == 8 && J == 8) {
                return (l * 4) + (threadIdx.x % 4);
            } else if constexpr (I == 16 && J == 8) {
                return ((threadIdx.x % 4) * 2) + (l % 2);
            } else if constexpr (I == 16 && J == 16) {
                return ((l / 4) * 8) + ((threadIdx.x % 4) * 2) + (l % 2);
            } else if constexpr (I == 32 && J == 8) {
                return tile<16, 8, T>::get_j(l); // Memory layout simply repeated with same pattern in i direction.
            } else {
                NO_DEVICE_CODE;
                return -1;
            }
        }
#endif // defined(AMD_MFMA_AVAILABLE)
    };
 279
 280    template <int I_, int J_>
 281    struct tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR> {
 282        static constexpr int         I  = I_;
 283        static constexpr int         J  = J_;
 284        static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;
 285
 286#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
 287        static constexpr int ne = I * J / WARP_SIZE;
 288        half2 x[ne] = {{0.0f, 0.0f}};
 289
 290        static constexpr __device__ bool supported() {
 291            if (I == 32 && J ==  4) return true;
 292            return false;
 293        }
 294
 295        static __device__ __forceinline__ int get_i(const int l) {
 296            if constexpr (I == 32 && J == 4) {
 297#ifdef GGML_CUDA_MMA_NO_VOLTA_PERM
 298                return (((threadIdx.x % 16) / 4) * 8) + ((threadIdx.x / 16) * 4) + (threadIdx.x % 4);
 299#else
 300                return threadIdx.x;
 301#endif // GGML_CUDA_MMA_NO_VOLTA_PERM
 302            } else {
 303                NO_DEVICE_CODE;
 304                return -1;
 305            }
 306        }
 307
 308        static __device__ __forceinline__ int get_j(const int l) {
 309            if constexpr (I == 32 && J == 4) {
 310                return l;
 311            } else {
 312                NO_DEVICE_CODE;
 313                return -1;
 314            }
 315        }
 316#elif defined(AMD_WMMA_AVAILABLE)
 317        static constexpr int ne = I * J / 32;
 318        half2 x[ne] = {{0.0f, 0.0f}};
 319
 320        static constexpr __device__ bool supported() {
 321            if (I == 16 && J == 8) return true;
 322            return false;
 323        }
 324
 325        static __device__ __forceinline__ int get_i(const int l) {
 326            if constexpr (I == 16 && J == 8) {
 327                return threadIdx.x % 16;
 328            } else {
 329                NO_DEVICE_CODE;
 330                return -1;
 331            }
 332        }
 333
 334        static __device__ __forceinline__ int get_j(const int l) {
 335            if constexpr (I == 16 && J == 8) {
 336                return ne * (threadIdx.x / 16) + l;
 337            } else {
 338                NO_DEVICE_CODE;
 339                return -1;
 340            }
 341        }
 342#elif defined(AMD_MFMA_AVAILABLE)
 343        static constexpr int ne = I * J / 64;
 344        half2 x[ne] = {{0.0f, 0.0f}};
 345
 346        static constexpr __device__ bool supported() {
 347            if (I == 16 && J == 8) return true;
 348            return false;
 349        }
 350
 351        static __device__ __forceinline__ int get_i(const int l) {
 352            if constexpr (I == 16 && J == 8) {
 353                return threadIdx.x % 16;
 354            } else {
 355                NO_DEVICE_CODE;
 356                return -1;
 357            }
 358        }
 359
 360        static __device__ __forceinline__ int get_j(const int l) {
 361            if constexpr (I == 16 && J == 8) {
 362                return ne * (threadIdx.x / 16) + l;
 363            } else {
 364                NO_DEVICE_CODE;
 365                return -1;
 366            }
 367        }
 368#else
 369        static constexpr int ne = I * J / WARP_SIZE;
 370        half2 x[ne] = {{0.0f, 0.0f}};
 371
 372        static constexpr __device__ bool supported() {
 373            if (I ==  8 && J ==  4) return true;
 374            if (I ==  8 && J ==  8) return true;
 375            if (I == 16 && J ==  8) return true;
 376            if (I == 16 && J == 16) return true;
 377            if (I == 32 && J ==  8) return true;
 378            return false;
 379        }
 380
 381        static __device__ __forceinline__ int get_i(const int l) {
 382            if constexpr (I == 8 && J == 8) {
 383                return threadIdx.x / 4;
 384            } else if constexpr (I == 16 && J == 4) {
 385                return (l * 8) + (threadIdx.x / 4);
 386            } else if constexpr (I == 16 && J == 8) {
 387                return ((l % 2) * 8) + (threadIdx.x / 4);
 388            } else if constexpr (I == 32 && J == 8) {
 389                return ((l / 4) * 16) + ((l % 2) * 8) + (threadIdx.x / 4);
 390            } else {
 391                NO_DEVICE_CODE;
 392                return -1;
 393            }
 394        }
 395
 396        static __device__ __forceinline__ int get_j(const int l) {
 397            if constexpr (I == 8 && J == 8) {
 398                return (l * 4) + (threadIdx.x % 4);
 399            } else if constexpr (I == 16 && J == 4) {
 400                return threadIdx.x % 4;
 401            } else if constexpr (I == 16 && J == 8) {
 402                return ((l / 2) * 4) + (threadIdx.x % 4);
 403            } else if constexpr (I == 32 && J == 8) {
 404                return ((l & 2) * 2) + (threadIdx.x % 4);
 405            } else {
 406                NO_DEVICE_CODE;
 407                return -1;
 408            }
 409        }
 410#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
 411    };
 412
 413    template <int I_, int J_>
 414    struct tile<I_, J_, nv_bfloat162, DATA_LAYOUT_I_MAJOR> {
 415        static constexpr int         I  = I_;
 416        static constexpr int         J  = J_;
 417        static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;
 418
 419#if defined(AMD_WMMA_AVAILABLE)
 420        static constexpr int ne = tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::ne;
 421        nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
 422
 423        static constexpr __device__ bool supported() {
 424            return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::supported();
 425        }
 426
 427        static __device__ __forceinline__ int get_i(const int l) {
 428            return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::get_i(l);
 429        }
 430
 431        static __device__ __forceinline__ int get_j(const int l) {
 432            return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::get_j(l);
 433        }
 434#elif defined(AMD_MFMA_AVAILABLE)
 435        static constexpr int ne = tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::ne;
 436        nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
 437
 438        static constexpr __device__ bool supported() {
 439            return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::supported();
 440        }
 441
 442        static __device__ __forceinline__ int get_i(const int l) {
 443            return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::get_i(l);
 444        }
 445
 446        static __device__ __forceinline__ int get_j(const int l) {
 447            return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::get_j(l);
 448        }
 449#else
 450        static constexpr int ne = I * J / WARP_SIZE;
 451        nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
 452
 453        static constexpr __device__ bool supported() {
 454            if (I ==  8 && J ==  8) return true;
 455            if (I == 16 && J ==  4) return true;
 456            if (I == 16 && J ==  8) return true;
 457            return false;
 458        }
 459
 460        static __device__ __forceinline__ int get_i(const int l) {
 461            if constexpr (I == 8 && J == 8) {
 462                return threadIdx.x / 4;
 463            } else if constexpr (I == 16 && J == 4) {
 464                return (l * 8) + (threadIdx.x / 4);
 465            } else if constexpr (I == 16 && J == 8) {
 466                return ((l % 2) * 8) + (threadIdx.x / 4);
 467            } else {
 468                NO_DEVICE_CODE;
 469                return -1;
 470            }
 471        }
 472
 473        static __device__ __forceinline__ int get_j(const int l) {
 474            if constexpr (I == 8 && J == 8) {
 475                return (l * 4) + (threadIdx.x % 4);
 476            } else if constexpr (I == 16 && J == 4) {
 477                return threadIdx.x % 4;
 478            } else if constexpr (I == 16 && J == 8) {
 479                return ((l / 2) * 4) + (threadIdx.x % 4);
 480            } else {
 481                NO_DEVICE_CODE;
 482                return -1;
 483            }
 484        }
 485#endif  // defined(AMD_WMMA_AVAILABLE)
 486    };
 487
 488    template <int I_, int J_, typename T>
 489    struct tile<I_, J_, T, DATA_LAYOUT_J_MAJOR> {
 490        static constexpr int         I  = I_;
 491        static constexpr int         J  = J_;
 492        static constexpr data_layout dl = DATA_LAYOUT_J_MAJOR;
 493
 494        static constexpr int ne = tile<I_, J_, T, DATA_LAYOUT_I_MAJOR>::ne;
 495        T x[ne] = {0};
 496
 497        static constexpr __device__ bool supported() {
 498            return tile<I_, J_, T, DATA_LAYOUT_I_MAJOR>::supported();
 499        }
 500
 501        static __device__ __forceinline__ int get_i(const int l) {
 502            return tile<I_, J_, T, DATA_LAYOUT_I_MAJOR>::get_j(l);
 503        }
 504
 505        static __device__ __forceinline__ int get_j(const int l) {
 506            return tile<I_, J_, T, DATA_LAYOUT_I_MAJOR>::get_i(l);
 507        }
 508    };
 509
 510    template <int I_, int J_, typename T>
 511    struct tile<I_, J_, T, DATA_LAYOUT_I_MAJOR_MIRRORED> {
 512        static constexpr int         I  = I_;
 513        static constexpr int         J  = J_;
 514        static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED;
 515
 516        // RDNA3
 517        static constexpr int         ne = I * J / 32 * 2;
 518
 519        T x[ne] = {0};
 520
 521        static constexpr __device__ bool supported() {
 522            if (I == 16 && J == 16) return true;
 523            if (I == 16 && J == 8)  return true;
 524            if (I == 16 && J == 4)  return true;
 525            return false;
 526        }
 527
 528        static __device__ __forceinline__ int get_i(const int /*l*/) {
 529            if constexpr (supported()) {
 530                return threadIdx.x % 16;
 531            } else {
 532                NO_DEVICE_CODE;
 533                return -1;
 534            }
 535        }
 536
 537        static __device__ __forceinline__ int get_j(const int l) {
 538            if constexpr (supported()) {
 539                return l;
 540            } else {
 541                NO_DEVICE_CODE;
 542                return -1;
 543            }
 544        }
 545    };
 546
 547    template <int I_, int J_>
 548    struct tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> {
 549        static constexpr int         I  = I_;
 550        static constexpr int         J  = J_;
 551        static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED;
 552#if defined(RDNA3)
 553        static constexpr int         ne = tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::ne;
 554
 555        half2 x[ne] = {{0.0f, 0.0f}};
 556
 557        static constexpr __device__ bool supported() {
 558            return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::supported();
 559        }
 560
 561        static __device__ __forceinline__ int get_i(const int l) {
 562            return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::get_i(l);
 563        }
 564
 565        static __device__ __forceinline__ int get_j(const int l) {
 566            return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::get_j(l);
 567        }
 568#else // Volta
 569        static constexpr int         ne = I * J / (WARP_SIZE/4);
 570
 571        half2 x[ne] = {{0.0f, 0.0f}};
 572
 573        static constexpr __device__ bool supported() {
 574            if (I ==  8 && J ==  4) return true;
 575            return false;
 576        }
 577
 578        static __device__ __forceinline__ int get_i(const int /*l*/) {
 579            if constexpr (I == 8 && J == 4) {
 580                return ((threadIdx.x / 16) * 4) + (threadIdx.x % 4);
 581            } else {
 582                NO_DEVICE_CODE;
 583                return -1;
 584            }
 585        }
 586
 587        static __device__ __forceinline__ int get_j(const int l) {
 588            if constexpr (I == 8 && J == 4) {
 589                return l;
 590            } else {
 591                NO_DEVICE_CODE;
 592                return -1;
 593            }
 594        }
 595#endif // defined(RDNA3)
 596    };
 597
 598    template <int I_, int J_>
 599    struct tile<I_, J_, nv_bfloat162, DATA_LAYOUT_I_MAJOR_MIRRORED> {
 600        static constexpr int         I  = I_;
 601        static constexpr int         J  = J_;
 602        static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED;
 603        static constexpr int         ne = tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::ne;
 604
 605        nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
 606
 607        static constexpr __device__ bool supported() {
 608            return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::supported();
 609        }
 610
 611        static __device__ __forceinline__ int get_i(const int l) {
 612            return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::get_i(l);
 613        }
 614
 615        static __device__ __forceinline__ int get_j(const int l) {
 616            return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::get_j(l);
 617        }
 618    };
 619
 620    template <int I_, int J_>
 621    struct tile<I_, J_, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> {
 622        static constexpr int         I  = I_;
 623        static constexpr int         J  = J_;
 624        static constexpr data_layout dl = DATA_LAYOUT_J_MAJOR_MIRRORED;
 625        static constexpr int         ne = I * J / (WARP_SIZE/4);
 626
 627        half2 x[ne] = {{0.0f, 0.0f}};
 628
 629        static constexpr __device__ bool supported() {
 630            if (I ==  8 && J ==  4) return true;
 631            return false;
 632        }
 633
 634        static __device__ __forceinline__ int get_i(const int l) {
 635            if constexpr (I == 8 && J == 4) {
 636                return ((l / 2) * 4) + (threadIdx.x % 4);
 637            } else {
 638                NO_DEVICE_CODE;
 639                return -1;
 640            }
 641        }
 642
 643        static __device__ __forceinline__ int get_j(const int l) {
 644            if constexpr (I == 8 && J == 4) {
 645                return ((threadIdx.x / 16) * 2) + (l % 2);
 646            } else {
 647                NO_DEVICE_CODE;
 648                return -1;
 649            }
 650        }
 651    };
 652
 653#if defined(TURING_MMA_AVAILABLE)
 654    template <int I, int J>
 655    static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
 656        tile<I, J/2, half2> ret;
 657#pragma unroll
 658        for (int l0 = 0; l0 < tile_float.ne; l0 += 2) {
 659            ret.x[l0/2] = make_half2(tile_float.x[l0 + 0], tile_float.x[l0 + 1]);
 660        }
 661        return ret;
 662    }
 663
 664    static __device__ __forceinline__ tile<8, 8, half2> get_transposed(const tile<16, 4, half2> & t) {
 665        tile<8, 8, half2> ret;
 666        ret.x[0] = ggml_cuda_movmatrix(t.x[0]);
 667        ret.x[1] = ggml_cuda_movmatrix(t.x[1]);
 668
 669        return ret;
 670    }
 671#elif defined(AMD_WMMA_AVAILABLE)
 672    template <int I, int J>
 673    static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
 674        tile<I, J/2, half2> ret;
 675#pragma unroll
 676        for (int l0 = 0; l0 < tile_float.ne; l0 += 2) {
 677            ret.x[l0/2] = make_half2(tile_float.x[l0 + 0], tile_float.x[l0 + 1]);
 678        }
 679        return ret;
 680    }
 681
 682    static __device__ __forceinline__ tile<8, 8, half2> get_transposed(const tile<16, 4, half2> & t) {
 683        NO_DEVICE_CODE;
 684        return tile<8, 8, half2>{};
 685    }
 686#else // Volta
    // Volta: convert a float tile into a half2 tile with half as many 32 bit columns.
    // For every group of 4 float elements, two half2 values are packed. Then each lane swaps one of them
    // with lane threadIdx.x ^ 2, because the float and half2 tiles use different layouts on Volta.
    template <int I, int J>
    static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
        tile<I, J/2, half2> ret;
#pragma unroll
        for (int l0 = 0; l0 < tile_float.ne; l0 += 4) {
            ret.x[l0/2 + 0] = make_half2(tile_float.x[l0 + 0], tile_float.x[l0 + 1]);
            ret.x[l0/2 + 1] = make_half2(tile_float.x[l0 + 2], tile_float.x[l0 + 3]);

            // On Volta FP16 and FP32 tiles have a different memory layout,
            //     for the conversion threads with an offset of 2 need to exchange half their values:
            // Lanes with bit 1 of threadIdx.x clear exchange slot l0/2 + 1, lanes with it set exchange slot l0/2 + 0.
            // The value received from lane threadIdx.x ^ 2 overwrites the slot that was sent.
            ret.x[l0/2 + (((threadIdx.x % 4) / 2) ^ 1)] = __shfl_xor_sync(
                0xFFFFFFFF, ret.x[l0/2 + (((threadIdx.x % 4) / 2) ^ 1)], 2, WARP_SIZE);
        }
        return ret;
    }
 702#endif // defined(TURING_MMA_AVAILABLE)
 703
 704    static __device__ __forceinline__ void make_identity_mat(tile<16, 8, half2> & t) {
 705#if defined(RDNA4)
 706        const int row = t.get_i(0);
 707        const int left_right = t.get_j(0) / 4;
 708        const int up_down = row / 8;
 709        const int idx = row % 8;
 710        reinterpret_cast<half*>(t.x)[idx] = left_right == up_down ? 1.0f : 0.0f;
 711#else
 712        GGML_UNUSED_VARS(t);
 713        NO_DEVICE_CODE;
 714#endif // defined(RDNA4)
 715    }
 716
 717    template <int I, int J, typename T, data_layout dl>
 718    static __device__ __forceinline__ void load_generic(tile<I, J, T, dl> & t, const T * __restrict__ xs0, const int stride) {
 719#if defined(AMD_MFMA_AVAILABLE)
 720        if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>
 721#pragma unroll
 722            for (int l = 0; l < t.ne; ++l) {
 723                t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
 724            }
 725        } else {
 726            ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
 727        }
 728#elif defined(AMD_WMMA_AVAILABLE)
 729        // All wmma layout has contiguous data when i-major.
 730        if constexpr (is_i_major(dl)) {
 731            // the data must be aligned to 16 bytes when bigger than ggml_cuda_get_max_cpy_bytes()
 732            constexpr int aligned_copy_bytes = ggml_cuda_get_max_cpy_bytes();
 733            if constexpr (sizeof(t.x) > aligned_copy_bytes) {
 734                static_assert(sizeof(t.x) % aligned_copy_bytes == 0, "bad type size");
 735                constexpr int aligned_copy_count = sizeof(t.x)/aligned_copy_bytes;
 736#pragma unroll
 737                for (int i = 0; i < aligned_copy_count; ++i) {
 738                    ggml_cuda_memcpy_1<aligned_copy_bytes>(t.x + t.ne/aligned_copy_count*i, xs0 + t.get_i(0) * stride + t.get_j(t.ne/aligned_copy_count*i));
 739                }
 740            } else {
 741                ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
 742            }
 743        } else {
 744#pragma unroll
 745            for (int l = 0; l < t.ne; ++l) {
 746                t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
 747            }
 748        }
 749#else
 750#pragma unroll
 751        for (int l = 0; l < t.ne; ++l) {
 752            t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
 753        }
 754#endif // defined(AMD_MFMA_AVAILABLE)
 755    }
 756
 757    template <typename T>
 758    static __device__ __forceinline__ void load_ldmatrix(
 759            tile<8, 8, T> & t, const T * __restrict__ xs0, const int stride) {
 760#ifdef TURING_MMA_AVAILABLE
 761        int * xi = (int *) t.x;
 762        const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + ((threadIdx.x / t.I) * (t.J / 2)) % t.J;
 763        asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
 764            : "=r"(xi[0]), "=r"(xi[1])
 765            : "l"(xs));
 766#else
 767        load_generic(t, xs0, stride);
 768#endif // TURING_MMA_AVAILABLE
 769    }
 770
 771    template <typename T>
 772    static __device__ __forceinline__ void load_ldmatrix(
 773            tile<16, 4, T> & t, const T * __restrict__ xs0, const int stride) {
 774#ifdef TURING_MMA_AVAILABLE
 775        int * xi = (int *) t.x;
 776        const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride;
 777        asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
 778            : "=r"(xi[0]), "=r"(xi[1])
 779            : "l"(xs));
 780#else
 781#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
 782        GGML_UNUSED_VARS(t, xs0, stride);
 783        NO_DEVICE_CODE;
 784#else
 785        load_generic(t, xs0, stride);
 786#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
 787#endif // TURING_MMA_AVAILABLE
 788    }
 789
    // Loads a 16x8 tile (J counted in physical 32 bit elements) for any data layout dl.
    // Turing+: a single ldmatrix.x4 fills the four 32 bit registers of each thread.
    //   Lanes 0-15 address rows 0-15 at column 0; lanes 16-31 address the same rows at column t.J/2.
    //   This assumes threadIdx.x is the lane index (blockDim.x == 32) — TODO confirm at call sites.
    // Volta: two 16-byte copies per thread (see below).
    // Everything else: load_generic.
    template <typename T, data_layout dl>
    static __device__ __forceinline__ void load_ldmatrix(
            tile<16, 8, T, dl> & t, const T * __restrict__ xs0, const int stride) {
#if defined(TURING_MMA_AVAILABLE)
        int * xi = (int * ) t.x;
        const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
        asm volatile("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
            : "=r"(xi[0]), "=r"(xi[1]), "=r"(xi[2]), "=r"(xi[3])
            : "l"(xs));
#else
#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
#if 1
        // TODO: more generic handling
        // Copies t.x[0..3] from row get_i(0), columns 0-3.
        // Copies t.x[4..7] from row get_i(4), columns 4-7.
        // Each copy is 4*sizeof(T) = 16 bytes.
        // NOTE(review): the hard-coded column offsets 0 and 4 presumably equal get_j(0) and get_j(4)
        // for the Volta layout — confirm.
        static_assert(sizeof(T) == 4, "bad type size");
        ggml_cuda_memcpy_1<4*sizeof(T)>(t.x + 0, xs0 + t.get_i(0)*stride + 0);
        ggml_cuda_memcpy_1<4*sizeof(T)>(t.x + 4, xs0 + t.get_i(4)*stride + 4);
#else
        load_generic(t, xs0, stride);
#endif // 1
#else
        load_generic(t, xs0, stride);
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
#endif // TURING_MMA_AVAILABLE
    }
 814
 815    static __device__ __forceinline__ void load_ldmatrix(
 816            tile<8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> & t, const half2 * __restrict__ xs0, const int stride) {
 817        ggml_cuda_memcpy_1<4*sizeof(half2)>(t.x, xs0 + t.get_i(0)*stride);
 818    }
 819
 820    static __device__ __forceinline__ void load_ldmatrix(
 821            tile<8, 4, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> & t, const half2 * __restrict__ xs0, const int stride) {
 822#pragma unroll
 823        for (int l0 = 0; l0 < t.ne; l0 += 2) {
 824            ggml_cuda_memcpy_1<2*sizeof(half2)>(t.x + l0, xs0 + t.get_i(l0)*stride + t.get_j(l0));
 825        }
 826    }
 827
 828    static __device__ __forceinline__ void load_ldmatrix(
 829            tile<32, 4, half2> & t, const half2 * __restrict__ xs0, const int stride) {
 830#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
 831        ggml_cuda_memcpy_1<4*sizeof(half2)>(t.x, xs0 + t.get_i(0)*stride);
 832#else
 833        GGML_UNUSED_VARS(t, xs0, stride);
 834        NO_DEVICE_CODE;
 835#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
 836    }
 837
    // Transposing variant of the 16x8 ldmatrix load (Turing+ only; NO_DEVICE_CODE otherwise).
    // Lane addressing matches the non-transposed tile<16, 8, T> load:
    //   lanes 0-15 point at rows 0-15, and lanes 16-31 at the same rows offset by t.J/2 32 bit elements.
    //   This assumes threadIdx.x is the lane index (blockDim.x == 32) — TODO confirm.
    // The .trans qualifier transposes each 8x8 b16 matrix as it is loaded into registers.
    template <typename T>
    static __device__ __forceinline__ void load_ldmatrix_trans(
            tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
#ifdef TURING_MMA_AVAILABLE
        int * xi = (int * ) t.x;
        const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
        // Outputs are bound in the order xi[0], xi[2], xi[1], xi[3]. The 2nd and 3rd matrices are
        // swapped relative to the non-transposed load, presumably so the transposed data lands in
        // the register order the tile layout expects — confirm against get_i/get_j.
        asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.b16 {%0, %1, %2, %3}, [%4];"
            : "=r"(xi[0]), "=r"(xi[2]), "=r"(xi[1]), "=r"(xi[3])
            : "l"(xs));
#else
        GGML_UNUSED_VARS(t, xs0, stride);
        NO_DEVICE_CODE;
#endif // TURING_MMA_AVAILABLE
    }
 852
    // int8 tensor-core mma: D (16x8, s32) += A (16x4 ints = 16 rows of 16 s8) @ B (8x4 ints = 8 columns of 16 s8).
    // Ampere+: one m16n8k16 instruction.
    // Turing: two m8n8k16 instructions, one per 8-row half of A/D.
    // Other targets: NO_DEVICE_CODE.
    static __device__ __forceinline__ void mma(
            tile<16, 8, int> & D, const tile<16, 4, int> & A, const tile<8, 4, int> & B) {
#ifdef TURING_MMA_AVAILABLE
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
        asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
            : "+r"(D.x[0]), "+r"(D.x[1]), "+r"(D.x[2]), "+r"(D.x[3])
            : "r"(A.x[0]), "r"(A.x[1]), "r"(B.x[0]));
#else
        // On Turing m16n8k16 mma is not available, use 2x m8n8k16 mma instead:
        // A.x[0] -> D.x[0..1] (upper 8 rows), A.x[1] -> D.x[2..3] (lower 8 rows); both use B.x[0].
        asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
            : "+r"(D.x[0]), "+r"(D.x[1])
            : "r"(A.x[0]), "r"(B.x[0]));
        asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
            : "+r"(D.x[2]), "+r"(D.x[3])
            : "r"(A.x[1]), "r"(B.x[0]));
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#else
        GGML_UNUSED_VARS(D, A, B);
        NO_DEVICE_CODE;
#endif // TURING_MMA_AVAILABLE
    }
 874
    // int8 tensor-core mma with K = 32: D (16x8, s32) += A (16x8 ints = 16 rows of 32 s8) @ B (8x8 ints).
    // Ampere+: one m16n8k32 instruction.
    // Turing: four m8n8k16 instructions (2 row halves x 2 k halves).
    // Other targets: NO_DEVICE_CODE.
    static __device__ __forceinline__ void mma(
            tile<16, 8, int> & D, const tile<16, 8, int> & A, const tile<8, 8, int> & B) {
#ifdef TURING_MMA_AVAILABLE
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
        asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
            : "+r"(D.x[0]), "+r"(D.x[1]), "+r"(D.x[2]), "+r"(D.x[3])
            : "r"(A.x[0]), "r"(A.x[1]), "r"(A.x[2]), "r"(A.x[3]), "r"(B.x[0]), "r"(B.x[1]));
#else
        // On Turing m16n8k32 mma is not available, use 4x m8n8k16 mma instead:
        // the first k half uses A.x[0..1] with B.x[0], the second k half uses A.x[2..3] with B.x[1].
        asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
            : "+r"(D.x[0]), "+r"(D.x[1])
            : "r"(A.x[0]), "r"(B.x[0]));
        asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
            : "+r"(D.x[2]), "+r"(D.x[3])
            : "r"(A.x[1]), "r"(B.x[0]));
        asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
            : "+r"(D.x[0]), "+r"(D.x[1])
            : "r"(A.x[2]), "r"(B.x[1]));
        asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
            : "+r"(D.x[2]), "+r"(D.x[3])
            : "r"(A.x[3]), "r"(B.x[1]));
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#else
        GGML_UNUSED_VARS(D, A, B);
        NO_DEVICE_CODE;
#endif // TURING_MMA_AVAILABLE
    }
 902
    // fp16 mma with fp16 accumulation: D (16x4 half2 = 16x8 f16) += A (16x8 half2) @ B (8x8 half2).
    // Ampere+: one m16n8k16 instruction.
    // Turing: two m16n8k8 instructions, one per k half.
    // Other targets: NO_DEVICE_CODE.
    static __device__ __forceinline__ void mma(
            tile<16, 4, half2> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
#ifdef TURING_MMA_AVAILABLE
        const int * Axi = (const int *) A.x;
        const int * Bxi = (const int *) B.x;
        int       * Dxi = (int       *) D.x;
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
        asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};"
            : "+r"(Dxi[0]), "+r"(Dxi[1])
            : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
#else
        // On Turing m16n8k16 mma is not available, use 2x m16n8k8 mma instead (one per k half):
        // A regs 0-1 with B reg 0, then A regs 2-3 with B reg 1.
        asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
            : "+r"(Dxi[0]), "+r"(Dxi[1])
            : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
        asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
            : "+r"(Dxi[0]), "+r"(Dxi[1])
            : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]));
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#else
        GGML_UNUSED_VARS(D, A, B);
        NO_DEVICE_CODE;
#endif // TURING_MMA_AVAILABLE
    }
 927
    // fp16 mma with fp16 accumulation: D (16x8 half2 = 16x16 f16) += A (16x8 half2) @ B (16x8 half2).
    // On NVIDIA the 16 output columns are computed as two n=8 halves:
    //   D.x[0..1] use B regs 0 and 2; D.x[2..3] use B regs 1 and 3.
    // Ampere+: two m16n8k16 instructions.
    // Turing: four m16n8k8 instructions.
    // AMD RDNA4: one 16x16x16 WMMA instruction.
    // Other targets: NO_DEVICE_CODE.
    static __device__ __forceinline__ void mma(
            tile<16, 8, half2> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) {
#ifdef TURING_MMA_AVAILABLE
        const int * Axi = (const int *) A.x;
        const int * Bxi = (const int *) B.x;
        int       * Dxi = (int       *) D.x;
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
        asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};"
            : "+r"(Dxi[0]), "+r"(Dxi[1])
            : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[2]));
        asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};"
            : "+r"(Dxi[2]), "+r"(Dxi[3])
            : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]), "r"(Bxi[3]));
#else
        // On Turing m16n8k16 mma is not available, use 4x m16n8k8 mma instead
        // (2 output column halves x 2 k halves):
        asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
            : "+r"(Dxi[0]), "+r"(Dxi[1])
            : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
        asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
            : "+r"(Dxi[0]), "+r"(Dxi[1])
            : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]));
        asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
            : "+r"(Dxi[2]), "+r"(Dxi[3])
            : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[1]));
        asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
            : "+r"(Dxi[2]), "+r"(Dxi[3])
            : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3]));
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#elif defined(AMD_WMMA_AVAILABLE)
#if defined(RDNA4)
        // Each lane's A, B and D fragments are reinterpreted as 8 x f16 vectors.
        using halfx8_t = __attribute__((ext_vector_type(8))) _Float16;
        halfx8_t& acc_frag = reinterpret_cast<halfx8_t&>(D.x[0]);
        const halfx8_t& a_frag = reinterpret_cast<const halfx8_t&>(A.x[0]);
        const halfx8_t& b_frag = reinterpret_cast<const halfx8_t&>(B.x[0]);
        acc_frag = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12(a_frag, b_frag, acc_frag);
#else
        GGML_UNUSED_VARS(D, A, B);
        NO_DEVICE_CODE;
#endif // defined(RDNA4)
#else
        GGML_UNUSED_VARS(D, A, B);
        NO_DEVICE_CODE;
#endif // TURING_MMA_AVAILABLE
    }
 972
    // TF32 mma with fp32 accumulation: D (16x8 f32) += A (16x8 f32) @ B (8x8 f32).
    // The float registers are passed to the instruction as raw 32 bit values.
    // Ampere+ only, as one m16n8k8 instruction; other targets hit NO_DEVICE_CODE.
    // NOTE(review): inputs are consumed as tf32, so presumably only the tf32-representable
    // mantissa bits take part — confirm whether callers pre-round.
    template <data_layout dl_ab, data_layout dl_d>
    static __device__ __forceinline__ void mma(
            tile<16, 8, float, dl_d> & D, const tile<16, 8, float, dl_ab> & A, const tile<8, 8, float, dl_ab> & B) {
#ifdef AMPERE_MMA_AVAILABLE
        const int * Axi = (const int *) A.x;
        const int * Bxi = (const int *) B.x;
        int       * Dxi = (int       *) D.x;
        asm("mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
            : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
#else
        GGML_UNUSED_VARS(D, A, B);
        NO_DEVICE_CODE;
#endif // AMPERE_MMA_AVAILABLE
    }
 988
    // AMD MFMA fp32 mma: D (16x16 f32) += A (16x8 f32) @ B (16x8 f32).
    // CDNA3: one xf32 16x16x8 instruction on 2-float fragments.
    // CDNA2/CDNA1: two fp32 16x16x4 steps over A.x[i], B.x[i].
    // Other targets: NO_DEVICE_CODE.
    template <data_layout dl_ab, data_layout dl_d>
    static __device__ __forceinline__ void mma(
            tile<16, 16, float, dl_d> & D, const tile<16, 8, float, dl_ab> & A, const tile<16, 8, float, dl_ab> & B) {
#ifdef AMD_MFMA_AVAILABLE
        using floatx4_t = __attribute__((ext_vector_type(4))) float;
        floatx4_t& acc_frag = reinterpret_cast<floatx4_t&>(D.x[0]);
#if defined(CDNA3)
        using floatx2_t = __attribute__((ext_vector_type(2))) float;
        const floatx2_t& a_frag = reinterpret_cast<const floatx2_t&>(A.x[0]);
        const floatx2_t& b_frag = reinterpret_cast<const floatx2_t&>(B.x[0]);
        acc_frag = __builtin_amdgcn_mfma_f32_16x16x8_xf32(a_frag, b_frag, acc_frag, 0, 0, 0);
#elif defined(CDNA2) || defined(CDNA1)
#pragma unroll
        for (int i = 0; i < 2; ++i) {
            acc_frag = __builtin_amdgcn_mfma_f32_16x16x4f32(A.x[i], B.x[i], acc_frag, 0, 0, 0);
        }
#else
        GGML_UNUSED_VARS(D, A, B);
        NO_DEVICE_CODE;
#endif // defined(CDNA3)
#else
        GGML_UNUSED_VARS(D, A, B);
        NO_DEVICE_CODE;
#endif // AMD_MFMA_AVAILABLE
    }
1014
1015    static __device__ __forceinline__ void mma_block_scaled(tile<16, 8, float> &     D,
1016                                                            const tile<16, 8, int> & A,
1017                                                            const tile<8, 8, int> &  B,
1018                                                            uint32_t                 a_scale,
1019                                                            uint32_t                 b_scale) {
1020#ifdef BLACKWELL_MMA_AVAILABLE
1021        const int * Axi = (const int *) A.x;
1022        const int * Bxi = (const int *) B.x;
1023        float *     Dxi = (float *) D.x;
1024
1025        asm volatile(
1026            "mma.sync.aligned.kind::mxf4.block_scale.scale_vec::2X.m16n8k64.row.col.f32.e2m1.e2m1.f32.ue8m0 "
1027            "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3}, "
1028            "%10, {0, 0}, %11, {0, 0};"
1029            : "+f"(Dxi[0]), "+f"(Dxi[1]), "+f"(Dxi[2]), "+f"(Dxi[3])
1030            : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]), "r"(a_scale), "r"(b_scale));
1031#else
1032        GGML_UNUSED_VARS(D, A, B, a_scale, b_scale);
1033#endif  // BLACKWELL_MMA_AVAILABLE
1034    }
1035
    // fp16 mma with fp32 accumulation: D (16x8 f32) += A (16x8 half2) @ B (8x8 half2).
    // Ampere+: one m16n8k16 instruction.
    // Turing: two m16n8k8 instructions, one per k half.
    // Other targets: NO_DEVICE_CODE.
    static __device__ __forceinline__ void mma(
            tile<16, 8, float> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
#ifdef TURING_MMA_AVAILABLE
        const int * Axi = (const int *) A.x;
        const int * Bxi = (const int *) B.x;
        int       * Dxi = (int       *) D.x;
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
        asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
            : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
#else
        // On Turing m16n8k16 mma is not available, use 2x m16n8k8 mma instead (one per k half):
        asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
            : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
        asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
            : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]));
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#else
        GGML_UNUSED_VARS(D, A, B);
        NO_DEVICE_CODE;
#endif // TURING_MMA_AVAILABLE
    }
1060
    // bf16 mma with fp32 accumulation: D (16x8 f32) += A (16x8 bf16x2) @ B (8x8 bf16x2).
    // Ampere+ only, as one m16n8k16 instruction; other targets hit NO_DEVICE_CODE.
    static __device__ __forceinline__ void mma(
            tile<16, 8, float> & D, const tile<16, 8, nv_bfloat162> & A, const tile<8, 8, nv_bfloat162> & B) {
#ifdef AMPERE_MMA_AVAILABLE
        const int * Axi = (const int *) A.x;
        const int * Bxi = (const int *) B.x;
        int       * Dxi = (int       *) D.x;
        asm("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
            : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
#else
        GGML_UNUSED_VARS(D, A, B);
        NO_DEVICE_CODE;
#endif // AMPERE_MMA_AVAILABLE
    }
1075
    // fp16 mma with fp32 accumulation: D (16x16 f32) += A (16x8 half2) @ B (16x8 half2).
    // NVIDIA computes the 16 output columns as two n=8 halves:
    //   Dxi[0..3] use B regs 0 and 2; Dxi[4..7] use B regs 1 and 3.
    //   Ampere+: two m16n8k16 instructions.
    //   Turing: four m16n8k8 instructions.
    // AMD: a single 16x16x16 WMMA (RDNA4/RDNA3) or MFMA instruction.
    // Other targets: NO_DEVICE_CODE.
    template <data_layout dl_ab, data_layout dl_d>
    static __device__ __forceinline__ void mma(
            tile<16, 16, float, dl_d> & D, const tile<16, 8, half2, dl_ab> & A, const tile<16, 8, half2, dl_ab> & B) {
#ifdef TURING_MMA_AVAILABLE
        const int * Axi = (const int *) A.x;
        const int * Bxi = (const int *) B.x;
        int       * Dxi = (int       *) D.x;
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
        asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
            : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[2]));
        asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
            : "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
            : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]), "r"(Bxi[3]));
#else
        // On Turing m16n8k16 mma is not available, use 4x m16n8k8 mma instead
        // (2 output column halves x 2 k halves):
        asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
            : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
        asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
            : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]));
        asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
            : "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
            : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[1]));
        asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
            : "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
            : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3]));
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#elif defined(AMD_WMMA_AVAILABLE)
#if defined(RDNA4)
        // RDNA4: each lane holds 8 f16 of A and of B, and 8 f32 accumulators.
        using halfx8_t = __attribute__((ext_vector_type(8))) _Float16;
        using floatx8_t = __attribute__((ext_vector_type(8))) float;
        floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);
        const halfx8_t& a_frag = reinterpret_cast<const halfx8_t&>(A.x[0]);
        const halfx8_t& b_frag = reinterpret_cast<const halfx8_t&>(B.x[0]);
        acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(a_frag, b_frag, acc_frag);
#elif defined(RDNA3)
        // RDNA3: each lane holds 16 f16 of A and of B, and 8 f32 accumulators.
        using halfx16_t = __attribute__((ext_vector_type(16))) _Float16;
        using floatx8_t = __attribute__((ext_vector_type(8))) float;
        floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);
        const halfx16_t& a_frag = reinterpret_cast<const halfx16_t&>(A.x[0]);
        const halfx16_t& b_frag = reinterpret_cast<const halfx16_t&>(B.x[0]);
        acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a_frag, b_frag, acc_frag);
#else
        GGML_UNUSED_VARS(D, A, B);
        NO_DEVICE_CODE;
#endif // RDNA4
#elif defined(AMD_MFMA_AVAILABLE)
        // CDNA MFMA: each lane holds 4 f16 of A and of B, and 4 f32 accumulators.
        using halfx4_t = __attribute__((ext_vector_type(4))) _Float16;
        using floatx4_t = __attribute__((ext_vector_type(4))) float;
        floatx4_t& acc_frag = reinterpret_cast<floatx4_t&>(D.x[0]);
        const halfx4_t& a_frag = reinterpret_cast<const halfx4_t&>(A.x[0]);
        const halfx4_t& b_frag = reinterpret_cast<const halfx4_t&>(B.x[0]);
        acc_frag = __builtin_amdgcn_mfma_f32_16x16x16f16(a_frag, b_frag, acc_frag, 0, 0, 0);
#else
        GGML_UNUSED_VARS(D, A, B);
        NO_DEVICE_CODE;
#endif // TURING_MMA_AVAILABLE
    }
1136
    // bf16 mma with fp32 accumulation: D (16x16 f32) += A (16x8 bf16x2) @ B (16x8 bf16x2).
    // AMD only; NVIDIA targets hit NO_DEVICE_CODE here.
    //   RDNA4/RDNA3: one WMMA instruction.
    //   CDNA3/CDNA2: one 16x16x16 bf16_1k MFMA instruction.
    //   CDNA1: two 16x16x8 bf16 MFMA steps.
    template <data_layout dl_ab, data_layout dl_d>
    static __device__ __forceinline__ void mma(
            tile<16, 16, float, dl_d> & D, const tile<16, 8, nv_bfloat162, dl_ab> & A, const tile<16, 8, nv_bfloat162, dl_ab> & B) {
#if defined(AMD_WMMA_AVAILABLE)
#if defined(RDNA4)
        using bf16x8_t = __attribute__((ext_vector_type(8))) __bf16;
        using floatx8_t = __attribute__((ext_vector_type(8))) float;
        floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);
        const bf16x8_t& a_frag = reinterpret_cast<const bf16x8_t&>(A.x[0]);
        const bf16x8_t& b_frag = reinterpret_cast<const bf16x8_t&>(B.x[0]);
        acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12(a_frag, b_frag, acc_frag);
#elif defined(RDNA3)
        using bf16x16_t = __attribute__((ext_vector_type(16))) __bf16;
        using floatx8_t = __attribute__((ext_vector_type(8))) float;
        floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);
        const bf16x16_t& a_frag = reinterpret_cast<const bf16x16_t&>(A.x[0]);
        const bf16x16_t& b_frag = reinterpret_cast<const bf16x16_t&>(B.x[0]);
        acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32(a_frag, b_frag, acc_frag);
#else
        GGML_UNUSED_VARS(D, A, B);
        NO_DEVICE_CODE;
#endif // defined(RDNA4)
#elif defined(AMD_MFMA_AVAILABLE)
        using floatx4_t = __attribute__((ext_vector_type(4))) float;
        floatx4_t& acc_frag = reinterpret_cast<floatx4_t&>(D.x[0]);
#if defined(CDNA3) || defined(CDNA2)
        using bf16x4_t = __attribute__((ext_vector_type(4))) __bf16;
        const bf16x4_t& a_frag = reinterpret_cast<const bf16x4_t&>(A.x[0]);
        const bf16x4_t& b_frag = reinterpret_cast<const bf16x4_t&>(B.x[0]);
        acc_frag = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_frag, b_frag, acc_frag, 0, 0, 0);
#elif defined(CDNA1)
        // Each step consumes one 32 bit register (2 bf16) of A and of B.
#pragma unroll
        for (int i = 0; i < 2; ++i) {
            using bf16x2_t = __attribute__((ext_vector_type(2))) __bf16;
            const bf16x2_t& a_frag = reinterpret_cast<const bf16x2_t&>(A.x[i]);
            const bf16x2_t& b_frag = reinterpret_cast<const bf16x2_t&>(B.x[i]);
            acc_frag = __builtin_amdgcn_mfma_f32_16x16x8bf16(a_frag, b_frag, acc_frag, 0, 0, 0);
        }
#else
        GGML_UNUSED_VARS(D, A, B);
        NO_DEVICE_CODE;
#endif // defined(CDNA3) || defined(CDNA2)
#else
        GGML_UNUSED_VARS(D, A, B);
        NO_DEVICE_CODE;
#endif // defined(AMD_WMMA_AVAILABLE)
    }
1184
1185    template <data_layout dl_d, data_layout dl_ab>
1186    static __device__ __forceinline__ void mma(
1187            tile<16, 16, int, dl_d> & D, const tile<16, 8, int, dl_ab> & A, const tile<16, 8, int, dl_ab> & B) {
1188#if defined(AMD_MFMA_AVAILABLE)
1189        using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
1190        int32x4_t * acc = (int32x4_t *) D.x;
1191#if defined(CDNA3)
1192        acc[0] = __builtin_amdgcn_mfma_i32_16x16x32_i8(((int64_t *) A.x)[0],
1193                                                       ((int64_t *) B.x)[0],
1194                                                       acc[0],
1195                                                       0, 0, 0);
1196#elif defined(CDNA2) || defined(CDNA)
1197        acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[0],
1198                                                      B.x[0],
1199                                                      acc[0],
1200                                                      0, 0, 0);
1201        acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[1],
1202                                                      B.x[1],
1203                                                      acc[0],
1204                                                      0, 0, 0);
1205#endif // defined(CDNA3)
1206
1207#elif defined(AMD_WMMA_AVAILABLE)
1208
1209        using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;
1210        int32x8_t * acc = (int32x8_t *) D.x;
1211
1212#if defined(RDNA4)
1213        using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
1214        int32x2_t * a_vec = (int32x2_t *) A.x;
1215        int32x2_t * b_vec = (int32x2_t *) B.x;
1216
1217        acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
1218            true,
1219            a_vec[0],
1220            true,
1221            b_vec[0],
1222            acc[0],
1223            true
1224        );
1225
1226        acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
1227            true,
1228            a_vec[1],
1229            true,
1230            b_vec[1],
1231            acc[0],
1232            true
1233        );
1234
1235#elif defined(RDNA3)
1236        using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
1237        int32x4_t * a_vec = (int32x4_t *) A.x;
1238        int32x4_t * b_vec = (int32x4_t *) B.x;
1239
1240        acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(
1241            true,
1242            a_vec[0],
1243            true,
1244            b_vec[0],
1245            acc[0],
1246            true
1247        );
1248
1249        acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(
1250            true,
1251            a_vec[1],
1252            true,
1253            b_vec[1],
1254            acc[0],
1255            true
1256        );
1257#endif // RDNA4
1258
1259#else
1260        GGML_UNUSED_VARS(D, A, B);
1261        NO_DEVICE_CODE;
1262#endif // AMD_MFMA_AVAILABLE
1263    }
1264
1265    static __device__ __forceinline__ void mma(
1266            tile<32, 32, int> & D, const tile<32, 4, int> & A, const tile<32, 4, int> & B) {
1267#if defined(AMD_MFMA_AVAILABLE)
1268        using int32x16_t = __attribute__((__vector_size__(16 * sizeof(int)))) int;
1269        int32x16_t * acc = (int32x16_t *) D.x;
1270#if defined(CDNA3)
1271        acc[0] = __builtin_amdgcn_mfma_i32_32x32x16_i8(((int64_t *) A.x)[0],
1272                                                       ((int64_t *) B.x)[0],
1273                                                       acc[0],
1274                                                       0, 0, 0);
1275#elif defined(CDNA2) || defined(CDNA)
1276        acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[0],
1277                                                     B.x[0],
1278                                                     acc[0],
1279                                                     0, 0, 0);
1280        acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[1],
1281                                                     B.x[1],
1282                                                     acc[0],
1283                                                     0, 0, 0);
1284#endif // defined(CDNA3)
1285
1286#else
1287        GGML_UNUSED_VARS(D, A, B);
1288        NO_DEVICE_CODE;
1289#endif // AMD_MFMA_AVAILABLE
1290    }
1291
1292    template <typename T1, typename T2, int J, int K>
1293    static __device__ __forceinline__ void mma(
1294            tile<32, J, T1> & D, const tile<32, K, T2> & A, const tile<J, K, T2> & B) {
1295        tile      <16, J, T1> * D16 = reinterpret_cast<      tile<16, J, T1> *>(&D);
1296        const tile<16, K, T2> * A16 = reinterpret_cast<const tile<16, K, T2> *>(&A);
1297        mma(D16[0], A16[0], B);
1298        mma(D16[1], A16[1], B);
1299    }
1300
    // Volta-only fp16 mma with fp32 accumulation:
    //   D (32x8 f32) += A (32x4 half2) @ B (8x4 half2, mirrored i-major layout).
    // Uses two m8n8k4 instructions, one per k half: A regs 0-1 with B regs 0-1, then A regs 2-3
    // with B regs 2-3.
    // Per the file header, each warp performs 4 8x8 mma operations in parallel on Volta.
    // Other targets hit NO_DEVICE_CODE.
    static __device__ __forceinline__ void mma(
            tile<32, 8, float> & D, const tile<32, 4, half2> & A, const tile<8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> & B) {
#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
        const int * Axi = (const int *) A.x;
        const int * Bxi = (const int *) B.x;
        int       * Dxi = (int       *) D.x;
        asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
            "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
            : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]), "r"(Bxi[1]));
        asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
            "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
            : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]), "r"(Bxi[3]));
#else
        GGML_UNUSED_VARS(D, A, B);
        NO_DEVICE_CODE;
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
    }
1320
    // Volta-only fp16 mma with fp16 accumulation:
    //   D (32x4 half2) += A (32x4 half2) @ B (8x4 half2, mirrored j-major layout).
    // B is consumed in .row layout (".row.row"), unlike the f32 overload, which uses ".row.col".
    // Uses two m8n8k4 instructions, one per k half.
    // Other targets hit NO_DEVICE_CODE.
    static __device__ __forceinline__ void mma(
            tile<32, 4, half2> & D, const tile<32, 4, half2> & A, const tile<8, 4, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> & B) {
#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
        const int * Axi = (const int *) A.x;
        const int * Bxi = (const int *) B.x;
        int       * Dxi = (int       *) D.x;
        asm("mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16 "
            "{%0, %1, %2, %3}, {%4, %5}, {%6, %7}, {%0, %1, %2, %3};"
            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
            : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]), "r"(Bxi[1]));
        asm("mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16 "
            "{%0, %1, %2, %3}, {%4, %5}, {%6, %7}, {%0, %1, %2, %3};"
            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
            : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]), "r"(Bxi[3]));
#else
        GGML_UNUSED_VARS(D, A, B);
        NO_DEVICE_CODE;
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
    }
1340
1341    template <data_layout dl_d, data_layout dl_ab>
1342    static __device__ __forceinline__ void mma(
1343            tile<16, 16, int, dl_d> & D, const tile<16, 4, int, dl_ab> & A, const tile<16, 4, int, dl_ab> & B) {
1344#if defined(AMD_WMMA_AVAILABLE)
1345        using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;
1346        int32x8_t * acc = (int32x8_t *) D.x;
1347#if defined(RDNA4)
1348        using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
1349        int32x2_t * a_vec = (int32x2_t *) A.x;
1350        int32x2_t * b_vec = (int32x2_t *) B.x;
1351
1352        acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
1353            true,
1354            a_vec[0],
1355            true,
1356            b_vec[0],
1357            acc[0],
1358            false
1359        );
1360#elif defined(RDNA3)
1361        using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
1362        int32x4_t * a_vec = (int32x4_t *) A.x;
1363        int32x4_t * b_vec = (int32x4_t *) B.x;
1364
1365        acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(
1366            true,
1367            a_vec[0],
1368            true,
1369            b_vec[0],
1370            acc[0],
1371            false
1372        );
1373#endif // RDNA4
1374#else
1375        GGML_UNUSED(D);
1376        GGML_UNUSED(A);
1377        GGML_UNUSED(B);
1378        NO_DEVICE_CODE;
1379#endif // AMD_WMMA_AVAILABLE
1380    }
1381}