#pragma once
// This file contains primitives that expose the tensor core PTX instructions for CUDA code.
// The primitives can be used in a similar way as the nvcuda::wmma interface but with a well-defined memory layout.
// The documentation for the PTX instructions can be found under:
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-multiply-accumulate-operation-using-mma-instruction
//
// Like with nvcuda::wmma there are three types of matrix tiles: A, B, and C with A @ B = C.
// A is a row-major matrix with shape M x K.
// B is a column-major matrix with shape K x N.
// C is a column-major matrix with shape M x N.
// A, B, and C are represented using the same fundamental data type: a row-major matrix with I rows and J columns.
// Note that J is measured in physical 32 bit elements instead of logical elements.
// The methods get_i and get_j can be used to get the physical 32 bit index of the lth element of a thread within a tile.
// All matrix tiles have ne physical 32 bit elements per warp.
//
// As described in the PTX documentation, all pointers for load_ldmatrix must be to shared memory and aligned to 16 bytes.
// The API in this file also assumes that the pointers for load_generic are aligned to 16 bytes; unaligned pointers are considered undefined behavior.
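//
// As a worked example of the get_i/get_j mapping (the formulas assume the default Turing/Ampere-style
// layout defined further below, not the AMD or Volta specializations): for tile<16, 8, int>,
// get_i(l) == ((l / 2) * 8) + (threadIdx.x / 4) and get_j(l) == ((threadIdx.x % 4) * 2) + (l % 2),
// so the thread with threadIdx.x == 5 holds its ne == 4 elements at
// (i, j) == (1, 2), (1, 3), (9, 2), (9, 3) for l == 0, 1, 2, 3 respectively.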

#include "common.cuh"

// On Volta each warp performs 4 8x8 mma operations in parallel.
// The basic memory layout for a 32x8 output tile is to stack 4 input tiles in the I direction and to mirror the B tile.
// However, the i indices in this file are by default permuted to simplify the index calculations.
// #define GGML_CUDA_MMA_NO_VOLTA_PERM

#if CUDART_VERSION >= 11080

static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {
    int ret = 0;

#ifdef TURING_MMA_AVAILABLE
    asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
        : "=r"(ret) : "r"(x));
#else
    GGML_UNUSED(x);
    NO_DEVICE_CODE;
#endif // defined(TURING_MMA_AVAILABLE)
    return ret;
}

#else

static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {
    // Imagine transposing row-major matrix to column-major matrix.
    const int src_i_low  = 2 * (threadIdx.x % 4);
    const int src_i_high = src_i_low + 1;
    const int src_j      = threadIdx.x / 4;

    const int src_laneid_low  = src_i_low  * 4 + src_j / 2;
    const int src_laneid_high = src_i_high * 4 + src_j / 2;

    const int shift_low  = ((src_j + 0) % 2) * 16;
    const int shift_high = ((src_j + 1) % 2) * 16;

    const int ret_low  = (__shfl_sync(0xFFFFFFFF, x, src_laneid_low,  WARP_SIZE) >> shift_low)  & 0x0000FFFF;
    const int ret_high = (__shfl_sync(0xFFFFFFFF, x, src_laneid_high, WARP_SIZE) << shift_high) & 0xFFFF0000;

    return ret_low | ret_high;
}

#endif // CUDART_VERSION >= 11080

static __device__ __forceinline__ half2 ggml_cuda_movmatrix(const half2 x) {
    half2 ret;
    *((int *) &ret) = ggml_cuda_movmatrix(*((const int *) &x));
    return ret;
}

namespace ggml_cuda_mma {

    // Some architectures like Volta or CDNA3 perform multiple matrix multiplications per warp in parallel;
    // effectively, the warp is split into subgroups of threads that each perform a single mma instruction.
    // In those cases the data can be split in different ways across the warp.
    enum data_layout {
        // By default the data uses the I direction as its major dimension and the J direction as its minor dimension.
        // For the A/C matrices this means I major == row major, J major == column major.
        // For the B matrix this means I major == column major, J major == row major.
        // MIRRORED == Each data value is held exactly once per thread subgroup.
        DATA_LAYOUT_I_MAJOR          =  0, // Always used for Turing, Ampere, Ada Lovelace, consumer Blackwell, matrix A&B for RDNA4 and CDNA.
        DATA_LAYOUT_J_MAJOR          = 10, // Matrix C for CDNA and RDNA4, int and float matrix C for RDNA3.
        DATA_LAYOUT_I_MAJOR_MIRRORED = 20, // Volta, matrix A&B for RDNA3.
        DATA_LAYOUT_J_MAJOR_MIRRORED = 30,
    };
    // Implemented mma combinations are:
    // - (I_MAJOR, I_MAJOR) -> I_MAJOR
    // - (I_MAJOR, I_MAJOR_MIRRORED) -> I_MAJOR
    // - (I_MAJOR, J_MAJOR_MIRRORED) -> I_MAJOR
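    //
    // For example, the FP16 mma overloads on NVIDIA (Turing and newer) use the (I_MAJOR, I_MAJOR) -> I_MAJOR combination:
    // tile<16, 8, half2> A (M = 16, K = 16 logical), tile<8, 8, half2> B (N = 8, K = 16 logical) and
    // tile<16, 8, float> C (M = 16, N = 8), all with the default DATA_LAYOUT_I_MAJOR.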

    static constexpr bool is_i_major(const data_layout dl) {
        return dl == DATA_LAYOUT_I_MAJOR ||
               dl == DATA_LAYOUT_I_MAJOR_MIRRORED;
    }

    static constexpr __device__ data_layout get_input_data_layout() {
#if defined(RDNA3) || __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
        return DATA_LAYOUT_I_MAJOR_MIRRORED;
#else
        return DATA_LAYOUT_I_MAJOR;
#endif // defined(RDNA3) || __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
    }

    template <int I_, int J_, typename T, data_layout ds_=DATA_LAYOUT_I_MAJOR>
    struct tile {};

    template <int I_, int J_, typename T>
    struct tile<I_, J_, T, DATA_LAYOUT_I_MAJOR> {
        static constexpr int I = I_;
        static constexpr int J = J_;
        static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;

#if defined(AMD_MFMA_AVAILABLE)
        static constexpr int ne = I * J / 64;
        T x[ne] = {0};

        static constexpr __device__ bool supported() {
            if (I == 64 && J == 2) return true;
            if (I == 16 && J == 8) return true;
            if (I == 32 && J == 4) return true;
            if (I == 16 && J == 16) return true;
            if (I == 32 && J == 32) return true;
            return false;
        }

        static __device__ __forceinline__ int get_i(const int l) {
            if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>
                return threadIdx.x % 16;
            } else if constexpr (I == 16 && J == 8) {
                return threadIdx.x % 16;
            } else if constexpr (I == 32 && J == 4) {
                return threadIdx.x % 32;
            } else if constexpr (I == 16 && J == 16) {
                return threadIdx.x % 16;
            } else if constexpr (I == 32 && J == 32) {
                return threadIdx.x % 32;
            } else {
                NO_DEVICE_CODE;
                return -1;
            }
        }

        static __device__ __forceinline__ int get_j(const int l) {
            if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>
                return (2 * ((threadIdx.x / 16) % 2) + l);
            } else if constexpr (I == 16 && J == 8) {
                return 2 * (threadIdx.x / 16) + l;
            } else if constexpr (I == 32 && J == 4) {
                return 2 * (threadIdx.x / 32) + l;
            } else if constexpr (I == 16 && J == 16) {
                return 4 * (threadIdx.x / 16) + l;
            } else if constexpr (I == 32 && J == 32) {
                return 4 * (threadIdx.x / 32) + 8 * (l / 4) + (l % 4);
            } else {
                NO_DEVICE_CODE;
                return -1;
            }
        }
#elif __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
        static constexpr int ne = I * J / 32;
        T x[ne] = {0};

        static constexpr __device__ bool supported() {
            if (I == 32 && J == 8) return true;
            return false;
        }

        static __device__ __forceinline__ int get_i(const int l) {
            if constexpr (I == 32 && J == 8) {
#ifdef GGML_CUDA_MMA_NO_VOLTA_PERM
                return (((threadIdx.x % 16) / 4) * 8) + ((threadIdx.x / 16) * 4) + (l & 2) + (threadIdx.x % 2);
#else
                return (l & 2) + (threadIdx.x & ~2);
#endif // GGML_CUDA_MMA_NO_VOLTA_PERM
            } else {
                NO_DEVICE_CODE;
                return -1;
            }
        }

        static __device__ __forceinline__ int get_j(const int l) {
            if constexpr (I == 32 && J == 8) {
                return (threadIdx.x & 2) + (l & (4 + 1));
            } else {
                NO_DEVICE_CODE;
                return -1;
            }
        }
#elif defined(AMD_WMMA_AVAILABLE)
        static constexpr int ne = I * J / 32;
        T x[ne] = {0};

        static constexpr __device__ bool supported() {
            if (I == 16 && J == 16) return true;
            if (I == 16 && J == 8) return true;
            if (I == 16 && J == 4) return true;
            return false;
        }

        static __device__ __forceinline__ int get_i(const int l) {
            if constexpr (supported()) {
                return threadIdx.x % 16;
            } else {
                NO_DEVICE_CODE;
                return -1;
            }
        }

        static __device__ __forceinline__ int get_j(const int l) {
            if constexpr (I == 16 && J == 16) {
#if defined(RDNA3)
                if constexpr (std::is_same_v<T, float> || std::is_same_v<T, int>) {
                    // matrix C
                    return 2 * l + (threadIdx.x / 16);
                } else {
                    // matrix A&B
                    return l;
                }
#else
                // matrix C is the transposed matrix A&B on RDNA4
                return ne * (threadIdx.x / 16) + l;
#endif // defined(RDNA3)
            } else if constexpr (I == 16 && J == 8) {
                // mmq input for RDNA4
                return ne * (threadIdx.x / 16) + l;
            } else if constexpr (I == 16 && J == 4) {
                return ne * (threadIdx.x / 16) + l;
            } else {
                NO_DEVICE_CODE;
                return -1;
            }
        }
#else
        static constexpr int ne = I * J / 32;
        T x[ne] = {0};

        static constexpr __device__ bool supported() {
            if (I == 8 && J == 4) return true;
            if (I == 8 && J == 8) return true;
            if (I == 16 && J == 8) return true;
            if (I == 16 && J == 16) return true;
            if (I == 32 && J == 8) return true;
            return false;
        }

        static __device__ __forceinline__ int get_i(const int l) {
            if constexpr (I == 8 && J == 4) {
                return threadIdx.x / 4;
            } else if constexpr (I == 8 && J == 8) {
                return threadIdx.x / 4;
            } else if constexpr (I == 16 && J == 8) {
                return ((l / 2) * 8) + (threadIdx.x / 4);
            } else if constexpr (I == 16 && J == 16) {
                return (((l / 2) % 2) * 8) + (threadIdx.x / 4);
            } else if constexpr (I == 32 && J == 8) {
                return tile<16, 8, T>::get_i(l); // Memory layout simply repeated with same pattern in i direction.
            } else {
                NO_DEVICE_CODE;
                return -1;
            }
        }

        static __device__ __forceinline__ int get_j(const int l) {
            if constexpr (I == 8 && J == 4) {
                return threadIdx.x % 4;
            } else if constexpr (I == 8 && J == 8) {
                return (l * 4) + (threadIdx.x % 4);
            } else if constexpr (I == 16 && J == 8) {
                return ((threadIdx.x % 4) * 2) + (l % 2);
            } else if constexpr (I == 16 && J == 16) {
                return ((l / 4) * 8) + ((threadIdx.x % 4) * 2) + (l % 2);
            } else if constexpr (I == 32 && J == 8) {
                return tile<16, 8, T>::get_j(l); // Memory layout simply repeated with same pattern in i direction.
            } else {
                NO_DEVICE_CODE;
                return -1;
            }
        }
#endif // defined(AMD_MFMA_AVAILABLE)
    };

    template <int I_, int J_>
    struct tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR> {
        static constexpr int I = I_;
        static constexpr int J = J_;
        static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;

#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
        static constexpr int ne = I * J / WARP_SIZE;
        half2 x[ne] = {{0.0f, 0.0f}};

        static constexpr __device__ bool supported() {
            if (I == 32 && J == 4) return true;
            return false;
        }

        static __device__ __forceinline__ int get_i(const int l) {
            if constexpr (I == 32 && J == 4) {
#ifdef GGML_CUDA_MMA_NO_VOLTA_PERM
                return (((threadIdx.x % 16) / 4) * 8) + ((threadIdx.x / 16) * 4) + (threadIdx.x % 4);
#else
                return threadIdx.x;
#endif // GGML_CUDA_MMA_NO_VOLTA_PERM
            } else {
                NO_DEVICE_CODE;
                return -1;
            }
        }

        static __device__ __forceinline__ int get_j(const int l) {
            if constexpr (I == 32 && J == 4) {
                return l;
            } else {
                NO_DEVICE_CODE;
                return -1;
            }
        }
#elif defined(AMD_WMMA_AVAILABLE)
        static constexpr int ne = I * J / 32;
        half2 x[ne] = {{0.0f, 0.0f}};

        static constexpr __device__ bool supported() {
            if (I == 16 && J == 8) return true;
            return false;
        }

        static __device__ __forceinline__ int get_i(const int l) {
            if constexpr (I == 16 && J == 8) {
                return threadIdx.x % 16;
            } else {
                NO_DEVICE_CODE;
                return -1;
            }
        }

        static __device__ __forceinline__ int get_j(const int l) {
            if constexpr (I == 16 && J == 8) {
                return ne * (threadIdx.x / 16) + l;
            } else {
                NO_DEVICE_CODE;
                return -1;
            }
        }
#elif defined(AMD_MFMA_AVAILABLE)
        static constexpr int ne = I * J / 64;
        half2 x[ne] = {{0.0f, 0.0f}};

        static constexpr __device__ bool supported() {
            if (I == 16 && J == 8) return true;
            return false;
        }

        static __device__ __forceinline__ int get_i(const int l) {
            if constexpr (I == 16 && J == 8) {
                return threadIdx.x % 16;
            } else {
                NO_DEVICE_CODE;
                return -1;
            }
        }

        static __device__ __forceinline__ int get_j(const int l) {
            if constexpr (I == 16 && J == 8) {
                return ne * (threadIdx.x / 16) + l;
            } else {
                NO_DEVICE_CODE;
                return -1;
            }
        }
#else
        static constexpr int ne = I * J / WARP_SIZE;
        half2 x[ne] = {{0.0f, 0.0f}};

        static constexpr __device__ bool supported() {
            if (I == 8 && J == 4) return true;
            if (I == 8 && J == 8) return true;
            if (I == 16 && J == 8) return true;
            if (I == 16 && J == 16) return true;
            if (I == 32 && J == 8) return true;
            return false;
        }

        static __device__ __forceinline__ int get_i(const int l) {
            if constexpr (I == 8 && J == 8) {
                return threadIdx.x / 4;
            } else if constexpr (I == 16 && J == 4) {
                return (l * 8) + (threadIdx.x / 4);
            } else if constexpr (I == 16 && J == 8) {
                return ((l % 2) * 8) + (threadIdx.x / 4);
            } else if constexpr (I == 32 && J == 8) {
                return ((l / 4) * 16) + ((l % 2) * 8) + (threadIdx.x / 4);
            } else {
                NO_DEVICE_CODE;
                return -1;
            }
        }

        static __device__ __forceinline__ int get_j(const int l) {
            if constexpr (I == 8 && J == 8) {
                return (l * 4) + (threadIdx.x % 4);
            } else if constexpr (I == 16 && J == 4) {
                return threadIdx.x % 4;
            } else if constexpr (I == 16 && J == 8) {
                return ((l / 2) * 4) + (threadIdx.x % 4);
            } else if constexpr (I == 32 && J == 8) {
                return ((l & 2) * 2) + (threadIdx.x % 4);
            } else {
                NO_DEVICE_CODE;
                return -1;
            }
        }
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
    };

    template <int I_, int J_>
    struct tile<I_, J_, nv_bfloat162, DATA_LAYOUT_I_MAJOR> {
        static constexpr int I = I_;
        static constexpr int J = J_;
        static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;

#if defined(AMD_WMMA_AVAILABLE)
        static constexpr int ne = tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::ne;
        nv_bfloat162 x[ne] = {{0.0f, 0.0f}};

        static constexpr __device__ bool supported() {
            return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::supported();
        }

        static __device__ __forceinline__ int get_i(const int l) {
            return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::get_i(l);
        }

        static __device__ __forceinline__ int get_j(const int l) {
            return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::get_j(l);
        }
#elif defined(AMD_MFMA_AVAILABLE)
        static constexpr int ne = tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::ne;
        nv_bfloat162 x[ne] = {{0.0f, 0.0f}};

        static constexpr __device__ bool supported() {
            return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::supported();
        }

        static __device__ __forceinline__ int get_i(const int l) {
            return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::get_i(l);
        }

        static __device__ __forceinline__ int get_j(const int l) {
            return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::get_j(l);
        }
#else
        static constexpr int ne = I * J / WARP_SIZE;
        nv_bfloat162 x[ne] = {{0.0f, 0.0f}};

        static constexpr __device__ bool supported() {
            if (I == 8 && J == 8) return true;
            if (I == 16 && J == 4) return true;
            if (I == 16 && J == 8) return true;
            return false;
        }

        static __device__ __forceinline__ int get_i(const int l) {
            if constexpr (I == 8 && J == 8) {
                return threadIdx.x / 4;
            } else if constexpr (I == 16 && J == 4) {
                return (l * 8) + (threadIdx.x / 4);
            } else if constexpr (I == 16 && J == 8) {
                return ((l % 2) * 8) + (threadIdx.x / 4);
            } else {
                NO_DEVICE_CODE;
                return -1;
            }
        }

        static __device__ __forceinline__ int get_j(const int l) {
            if constexpr (I == 8 && J == 8) {
                return (l * 4) + (threadIdx.x % 4);
            } else if constexpr (I == 16 && J == 4) {
                return threadIdx.x % 4;
            } else if constexpr (I == 16 && J == 8) {
                return ((l / 2) * 4) + (threadIdx.x % 4);
            } else {
                NO_DEVICE_CODE;
                return -1;
            }
        }
#endif // defined(AMD_WMMA_AVAILABLE)
    };

    template <int I_, int J_, typename T>
    struct tile<I_, J_, T, DATA_LAYOUT_J_MAJOR> {
        static constexpr int I = I_;
        static constexpr int J = J_;
        static constexpr data_layout dl = DATA_LAYOUT_J_MAJOR;

        static constexpr int ne = tile<I_, J_, T, DATA_LAYOUT_I_MAJOR>::ne;
        T x[ne] = {0};

        static constexpr __device__ bool supported() {
            return tile<I_, J_, T, DATA_LAYOUT_I_MAJOR>::supported();
        }

        static __device__ __forceinline__ int get_i(const int l) {
            return tile<I_, J_, T, DATA_LAYOUT_I_MAJOR>::get_j(l);
        }

        static __device__ __forceinline__ int get_j(const int l) {
            return tile<I_, J_, T, DATA_LAYOUT_I_MAJOR>::get_i(l);
        }
    };

    template <int I_, int J_, typename T>
    struct tile<I_, J_, T, DATA_LAYOUT_I_MAJOR_MIRRORED> {
        static constexpr int I = I_;
        static constexpr int J = J_;
        static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED;

        // RDNA3
        static constexpr int ne = I * J / 32 * 2;

        T x[ne] = {0};

        static constexpr __device__ bool supported() {
            if (I == 16 && J == 16) return true;
            if (I == 16 && J == 8) return true;
            if (I == 16 && J == 4) return true;
            return false;
        }

        static __device__ __forceinline__ int get_i(const int /*l*/) {
            if constexpr (supported()) {
                return threadIdx.x % 16;
            } else {
                NO_DEVICE_CODE;
                return -1;
            }
        }

        static __device__ __forceinline__ int get_j(const int l) {
            if constexpr (supported()) {
                return l;
            } else {
                NO_DEVICE_CODE;
                return -1;
            }
        }
    };

    template <int I_, int J_>
    struct tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> {
        static constexpr int I = I_;
        static constexpr int J = J_;
        static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED;
#if defined(RDNA3)
        static constexpr int ne = tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::ne;

        half2 x[ne] = {{0.0f, 0.0f}};

        static constexpr __device__ bool supported() {
            return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::supported();
        }

        static __device__ __forceinline__ int get_i(const int l) {
            return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::get_i(l);
        }

        static __device__ __forceinline__ int get_j(const int l) {
            return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::get_j(l);
        }
#else // Volta
        static constexpr int ne = I * J / (WARP_SIZE/4);

        half2 x[ne] = {{0.0f, 0.0f}};

        static constexpr __device__ bool supported() {
            if (I == 8 && J == 4) return true;
            return false;
        }

        static __device__ __forceinline__ int get_i(const int /*l*/) {
            if constexpr (I == 8 && J == 4) {
                return ((threadIdx.x / 16) * 4) + (threadIdx.x % 4);
            } else {
                NO_DEVICE_CODE;
                return -1;
            }
        }

        static __device__ __forceinline__ int get_j(const int l) {
            if constexpr (I == 8 && J == 4) {
                return l;
            } else {
                NO_DEVICE_CODE;
                return -1;
            }
        }
#endif // defined(RDNA3)
    };

    template <int I_, int J_>
    struct tile<I_, J_, nv_bfloat162, DATA_LAYOUT_I_MAJOR_MIRRORED> {
        static constexpr int I = I_;
        static constexpr int J = J_;
        static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED;
        static constexpr int ne = tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::ne;

        nv_bfloat162 x[ne] = {{0.0f, 0.0f}};

        static constexpr __device__ bool supported() {
            return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::supported();
        }

        static __device__ __forceinline__ int get_i(const int l) {
            return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::get_i(l);
        }

        static __device__ __forceinline__ int get_j(const int l) {
            return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::get_j(l);
        }
    };

    template <int I_, int J_>
    struct tile<I_, J_, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> {
        static constexpr int I = I_;
        static constexpr int J = J_;
        static constexpr data_layout dl = DATA_LAYOUT_J_MAJOR_MIRRORED;
        static constexpr int ne = I * J / (WARP_SIZE/4);

        half2 x[ne] = {{0.0f, 0.0f}};

        static constexpr __device__ bool supported() {
            if (I == 8 && J == 4) return true;
            return false;
        }

        static __device__ __forceinline__ int get_i(const int l) {
            if constexpr (I == 8 && J == 4) {
                return ((l / 2) * 4) + (threadIdx.x % 4);
            } else {
                NO_DEVICE_CODE;
                return -1;
            }
        }

        static __device__ __forceinline__ int get_j(const int l) {
            if constexpr (I == 8 && J == 4) {
                return ((threadIdx.x / 16) * 2) + (l % 2);
            } else {
                NO_DEVICE_CODE;
                return -1;
            }
        }
    };

#if defined(TURING_MMA_AVAILABLE)
    template <int I, int J>
    static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
        tile<I, J/2, half2> ret;
#pragma unroll
        for (int l0 = 0; l0 < tile_float.ne; l0 += 2) {
            ret.x[l0/2] = make_half2(tile_float.x[l0 + 0], tile_float.x[l0 + 1]);
        }
        return ret;
    }

    static __device__ __forceinline__ tile<8, 8, half2> get_transposed(const tile<16, 4, half2> & t) {
        tile<8, 8, half2> ret;
        ret.x[0] = ggml_cuda_movmatrix(t.x[0]);
        ret.x[1] = ggml_cuda_movmatrix(t.x[1]);

        return ret;
    }
#elif defined(AMD_WMMA_AVAILABLE)
    template <int I, int J>
    static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
        tile<I, J/2, half2> ret;
#pragma unroll
        for (int l0 = 0; l0 < tile_float.ne; l0 += 2) {
            ret.x[l0/2] = make_half2(tile_float.x[l0 + 0], tile_float.x[l0 + 1]);
        }
        return ret;
    }

    static __device__ __forceinline__ tile<8, 8, half2> get_transposed(const tile<16, 4, half2> & t) {
        NO_DEVICE_CODE;
        return tile<8, 8, half2>{};
    }
#else // Volta
    template <int I, int J>
    static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
        tile<I, J/2, half2> ret;
#pragma unroll
        for (int l0 = 0; l0 < tile_float.ne; l0 += 4) {
            ret.x[l0/2 + 0] = make_half2(tile_float.x[l0 + 0], tile_float.x[l0 + 1]);
            ret.x[l0/2 + 1] = make_half2(tile_float.x[l0 + 2], tile_float.x[l0 + 3]);

            // On Volta FP16 and FP32 tiles have a different memory layout;
            // for the conversion, threads with an offset of 2 need to exchange half their values:
            ret.x[l0/2 + (((threadIdx.x % 4) / 2) ^ 1)] = __shfl_xor_sync(
                0xFFFFFFFF, ret.x[l0/2 + (((threadIdx.x % 4) / 2) ^ 1)], 2, WARP_SIZE);
        }
        return ret;
    }
#endif // defined(TURING_MMA_AVAILABLE)

    static __device__ __forceinline__ void make_identity_mat(tile<16, 8, half2> & t) {
#if defined(RDNA4)
        const int row = t.get_i(0);
        const int left_right = t.get_j(0) / 4;
        const int up_down = row / 8;
        const int idx = row % 8;
        reinterpret_cast<half*>(t.x)[idx] = left_right == up_down ? 1.0f : 0.0f;
#else
        GGML_UNUSED_VARS(t);
        NO_DEVICE_CODE;
#endif // defined(RDNA4)
    }

    template <int I, int J, typename T, data_layout dl>
    static __device__ __forceinline__ void load_generic(tile<I, J, T, dl> & t, const T * __restrict__ xs0, const int stride) {
#if defined(AMD_MFMA_AVAILABLE)
        if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>
#pragma unroll
            for (int l = 0; l < t.ne; ++l) {
                t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
            }
        } else {
            ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
        }
#elif defined(AMD_WMMA_AVAILABLE)
        // All WMMA layouts have contiguous data when i-major.
        if constexpr (is_i_major(dl)) {
            // The data must be aligned to 16 bytes when bigger than ggml_cuda_get_max_cpy_bytes().
            constexpr int aligned_copy_bytes = ggml_cuda_get_max_cpy_bytes();
            if constexpr (sizeof(t.x) > aligned_copy_bytes) {
                static_assert(sizeof(t.x) % aligned_copy_bytes == 0, "bad type size");
                constexpr int aligned_copy_count = sizeof(t.x)/aligned_copy_bytes;
#pragma unroll
                for (int i = 0; i < aligned_copy_count; ++i) {
                    ggml_cuda_memcpy_1<aligned_copy_bytes>(t.x + t.ne/aligned_copy_count*i, xs0 + t.get_i(0) * stride + t.get_j(t.ne/aligned_copy_count*i));
                }
            } else {
                ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
            }
        } else {
#pragma unroll
            for (int l = 0; l < t.ne; ++l) {
                t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
            }
        }
#else
#pragma unroll
        for (int l = 0; l < t.ne; ++l) {
            t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
        }
#endif // defined(AMD_MFMA_AVAILABLE)
    }
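
    // A minimal usage sketch of the API above (illustrative only, assumes a Turing or newer NVIDIA GPU;
    // buf_A and buf_B are hypothetical names for 16-byte-aligned shared memory buffers holding one A tile
    // and one B tile with 8 half2 per row, dst/ncols are a hypothetical global output buffer and its width):
    //
    //   using namespace ggml_cuda_mma;
    //
    //   tile<16, 8, half2> A; // M = 16, K = 16 logical elements
    //   tile< 8, 8, half2> B; // N =  8, K = 16 logical elements
    //   tile<16, 8, float> C; // M = 16, N =  8, zero-initialized accumulator
    //
    //   load_ldmatrix(A, buf_A, 8); // stride of 8 half2 per row of the A tile
    //   load_ldmatrix(B, buf_B, 8); // stride of 8 half2 per row of the B tile
    //   mma(C, A, B);               // C += A @ B
    //
    //   for (int l = 0; l < C.ne; ++l) {
    //       dst[C.get_i(l)*ncols + C.get_j(l)] = C.x[l];
    //   }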

    template <typename T>
    static __device__ __forceinline__ void load_ldmatrix(
            tile<8, 8, T> & t, const T * __restrict__ xs0, const int stride) {
#ifdef TURING_MMA_AVAILABLE
        int * xi = (int *) t.x;
        const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + ((threadIdx.x / t.I) * (t.J / 2)) % t.J;
        asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
            : "=r"(xi[0]), "=r"(xi[1])
            : "l"(xs));
#else
        load_generic(t, xs0, stride);
#endif // TURING_MMA_AVAILABLE
    }

    template <typename T>
    static __device__ __forceinline__ void load_ldmatrix(
            tile<16, 4, T> & t, const T * __restrict__ xs0, const int stride) {
#ifdef TURING_MMA_AVAILABLE
        int * xi = (int *) t.x;
        const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride;
        asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
            : "=r"(xi[0]), "=r"(xi[1])
            : "l"(xs));
#else
#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
        GGML_UNUSED_VARS(t, xs0, stride);
        NO_DEVICE_CODE;
#else
        load_generic(t, xs0, stride);
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
#endif // TURING_MMA_AVAILABLE
    }

    template <typename T, data_layout dl>
    static __device__ __forceinline__ void load_ldmatrix(
            tile<16, 8, T, dl> & t, const T * __restrict__ xs0, const int stride) {
#if defined(TURING_MMA_AVAILABLE)
        int * xi = (int *) t.x;
        const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
        asm volatile("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
            : "=r"(xi[0]), "=r"(xi[1]), "=r"(xi[2]), "=r"(xi[3])
            : "l"(xs));
#else
#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
#if 1
        // TODO: more generic handling
        static_assert(sizeof(T) == 4, "bad type size");
        ggml_cuda_memcpy_1<4*sizeof(T)>(t.x + 0, xs0 + t.get_i(0)*stride + 0);
        ggml_cuda_memcpy_1<4*sizeof(T)>(t.x + 4, xs0 + t.get_i(4)*stride + 4);
#else
        load_generic(t, xs0, stride);
#endif // 1
#else
        load_generic(t, xs0, stride);
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
#endif // TURING_MMA_AVAILABLE
    }

    static __device__ __forceinline__ void load_ldmatrix(
            tile<8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> & t, const half2 * __restrict__ xs0, const int stride) {
        ggml_cuda_memcpy_1<4*sizeof(half2)>(t.x, xs0 + t.get_i(0)*stride);
    }

    static __device__ __forceinline__ void load_ldmatrix(
            tile<8, 4, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> & t, const half2 * __restrict__ xs0, const int stride) {
#pragma unroll
        for (int l0 = 0; l0 < t.ne; l0 += 2) {
            ggml_cuda_memcpy_1<2*sizeof(half2)>(t.x + l0, xs0 + t.get_i(l0)*stride + t.get_j(l0));
        }
    }

    static __device__ __forceinline__ void load_ldmatrix(
            tile<32, 4, half2> & t, const half2 * __restrict__ xs0, const int stride) {
#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
        ggml_cuda_memcpy_1<4*sizeof(half2)>(t.x, xs0 + t.get_i(0)*stride);
#else
        GGML_UNUSED_VARS(t, xs0, stride);
        NO_DEVICE_CODE;
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
    }

    template <typename T>
    static __device__ __forceinline__ void load_ldmatrix_trans(
            tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
#ifdef TURING_MMA_AVAILABLE
        int * xi = (int *) t.x;
        const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
        asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.b16 {%0, %1, %2, %3}, [%4];"
            : "=r"(xi[0]), "=r"(xi[2]), "=r"(xi[1]), "=r"(xi[3])
            : "l"(xs));
#else
        GGML_UNUSED_VARS(t, xs0, stride);
        NO_DEVICE_CODE;
#endif // TURING_MMA_AVAILABLE
    }

    static __device__ __forceinline__ void mma(
            tile<16, 8, int> & D, const tile<16, 4, int> & A, const tile<8, 4, int> & B) {
#ifdef TURING_MMA_AVAILABLE
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
        asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
            : "+r"(D.x[0]), "+r"(D.x[1]), "+r"(D.x[2]), "+r"(D.x[3])
            : "r"(A.x[0]), "r"(A.x[1]), "r"(B.x[0]));
#else
        // On Turing m16n8k16 mma is not available, use 2x m8n8k16 mma instead:
        asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
            : "+r"(D.x[0]), "+r"(D.x[1])
            : "r"(A.x[0]), "r"(B.x[0]));
        asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
            : "+r"(D.x[2]), "+r"(D.x[3])
            : "r"(A.x[1]), "r"(B.x[0]));
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#else
        GGML_UNUSED_VARS(D, A, B);
        NO_DEVICE_CODE;
#endif // TURING_MMA_AVAILABLE
    }

    static __device__ __forceinline__ void mma(
            tile<16, 8, int> & D, const tile<16, 8, int> & A, const tile<8, 8, int> & B) {
#ifdef TURING_MMA_AVAILABLE
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
        asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
            : "+r"(D.x[0]), "+r"(D.x[1]), "+r"(D.x[2]), "+r"(D.x[3])
            : "r"(A.x[0]), "r"(A.x[1]), "r"(A.x[2]), "r"(A.x[3]), "r"(B.x[0]), "r"(B.x[1]));
#else
        // On Turing m16n8k32 mma is not available, use 4x m8n8k16 mma instead:
        asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
            : "+r"(D.x[0]), "+r"(D.x[1])
            : "r"(A.x[0]), "r"(B.x[0]));
        asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
            : "+r"(D.x[2]), "+r"(D.x[3])
            : "r"(A.x[1]), "r"(B.x[0]));
        asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
            : "+r"(D.x[0]), "+r"(D.x[1])
            : "r"(A.x[2]), "r"(B.x[1]));
        asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
            : "+r"(D.x[2]), "+r"(D.x[3])
            : "r"(A.x[3]), "r"(B.x[1]));
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#else
        GGML_UNUSED_VARS(D, A, B);
        NO_DEVICE_CODE;
#endif // TURING_MMA_AVAILABLE
    }

    static __device__ __forceinline__ void mma(
            tile<16, 4, half2> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
#ifdef TURING_MMA_AVAILABLE
        const int * Axi = (const int *) A.x;
        const int * Bxi = (const int *) B.x;
        int * Dxi = (int *) D.x;
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
        asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};"
            : "+r"(Dxi[0]), "+r"(Dxi[1])
            : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
#else
        // On Turing m16n8k16 mma is not available, use 2x m16n8k8 mma instead:
        asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
            : "+r"(Dxi[0]), "+r"(Dxi[1])
            : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
        asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
            : "+r"(Dxi[0]), "+r"(Dxi[1])
            : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]));
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#else
        GGML_UNUSED_VARS(D, A, B);
        NO_DEVICE_CODE;
#endif // TURING_MMA_AVAILABLE
    }

    static __device__ __forceinline__ void mma(
            tile<16, 8, half2> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) {
#ifdef TURING_MMA_AVAILABLE
        const int * Axi = (const int *) A.x;
        const int * Bxi = (const int *) B.x;
        int * Dxi = (int *) D.x;
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
        asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};"
            : "+r"(Dxi[0]), "+r"(Dxi[1])
            : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[2]));
        asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};"
            : "+r"(Dxi[2]), "+r"(Dxi[3])
            : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]), "r"(Bxi[3]));
#else
        // On Turing m16n8k16 mma is not available, use 4x m16n8k8 mma instead:
        asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
            : "+r"(Dxi[0]), "+r"(Dxi[1])
            : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
        asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
            : "+r"(Dxi[0]), "+r"(Dxi[1])
            : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]));
        asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
            : "+r"(Dxi[2]), "+r"(Dxi[3])
            : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[1]));
        asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
            : "+r"(Dxi[2]), "+r"(Dxi[3])
            : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3]));
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#elif defined(AMD_WMMA_AVAILABLE)
#if defined(RDNA4)
        using halfx8_t = __attribute__((ext_vector_type(8))) _Float16;
        halfx8_t& acc_frag = reinterpret_cast<halfx8_t&>(D.x[0]);
        const halfx8_t& a_frag = reinterpret_cast<const halfx8_t&>(A.x[0]);
        const halfx8_t& b_frag = reinterpret_cast<const halfx8_t&>(B.x[0]);
        acc_frag = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12(a_frag, b_frag, acc_frag);
#else
        GGML_UNUSED_VARS(D, A, B);
        NO_DEVICE_CODE;
#endif // defined(RDNA4)
#else
        GGML_UNUSED_VARS(D, A, B);
        NO_DEVICE_CODE;
#endif // TURING_MMA_AVAILABLE
    }

    template <data_layout dl_ab, data_layout dl_d>
    static __device__ __forceinline__ void mma(
            tile<16, 8, float, dl_d> & D, const tile<16, 8, float, dl_ab> & A, const tile<8, 8, float, dl_ab> & B) {
#ifdef AMPERE_MMA_AVAILABLE
        const int * Axi = (const int *) A.x;
        const int * Bxi = (const int *) B.x;
        int * Dxi = (int *) D.x;
        asm("mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
            : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
#else
        GGML_UNUSED_VARS(D, A, B);
        NO_DEVICE_CODE;
#endif // AMPERE_MMA_AVAILABLE
    }

    template <data_layout dl_ab, data_layout dl_d>
    static __device__ __forceinline__ void mma(
            tile<16, 16, float, dl_d> & D, const tile<16, 8, float, dl_ab> & A, const tile<16, 8, float, dl_ab> & B) {
#ifdef AMD_MFMA_AVAILABLE
        using floatx4_t = __attribute__((ext_vector_type(4))) float;
        floatx4_t& acc_frag = reinterpret_cast<floatx4_t&>(D.x[0]);
#if defined(CDNA3)
        using floatx2_t = __attribute__((ext_vector_type(2))) float;
        const floatx2_t& a_frag = reinterpret_cast<const floatx2_t&>(A.x[0]);
        const floatx2_t& b_frag = reinterpret_cast<const floatx2_t&>(B.x[0]);
        acc_frag = __builtin_amdgcn_mfma_f32_16x16x8_xf32(a_frag, b_frag, acc_frag, 0, 0, 0);
#elif defined(CDNA2) || defined(CDNA1)
#pragma unroll
        for (int i = 0; i < 2; ++i) {
            acc_frag = __builtin_amdgcn_mfma_f32_16x16x4f32(A.x[i], B.x[i], acc_frag, 0, 0, 0);
        }
#else
        GGML_UNUSED_VARS(D, A, B);
        NO_DEVICE_CODE;
#endif // defined(CDNA3)
#else
        GGML_UNUSED_VARS(D, A, B);
        NO_DEVICE_CODE;
#endif // AMD_MFMA_AVAILABLE
    }

    static __device__ __forceinline__ void mma_block_scaled(tile<16, 8, float> & D,
                                                            const tile<16, 8, int> & A,
                                                            const tile<8, 8, int> & B,
                                                            uint32_t a_scale,
                                                            uint32_t b_scale) {
#ifdef BLACKWELL_MMA_AVAILABLE
        const int * Axi = (const int *) A.x;
        const int * Bxi = (const int *) B.x;
        float * Dxi = (float *) D.x;

        asm volatile(
            "mma.sync.aligned.kind::mxf4.block_scale.scale_vec::2X.m16n8k64.row.col.f32.e2m1.e2m1.f32.ue8m0 "
            "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3}, "
            "%10, {0, 0}, %11, {0, 0};"
            : "+f"(Dxi[0]), "+f"(Dxi[1]), "+f"(Dxi[2]), "+f"(Dxi[3])
            : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]), "r"(a_scale), "r"(b_scale));
#else
        GGML_UNUSED_VARS(D, A, B, a_scale, b_scale);
#endif // BLACKWELL_MMA_AVAILABLE
    }

    static __device__ __forceinline__ void mma(
            tile<16, 8, float> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
#ifdef TURING_MMA_AVAILABLE
        const int * Axi = (const int *) A.x;
        const int * Bxi = (const int *) B.x;
        int * Dxi = (int *) D.x;
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
        asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
            : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
#else
        // On Turing m16n8k16 mma is not available, use 2x m16n8k8 mma instead:
        asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
            : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
        asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
            : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]));
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#else
        GGML_UNUSED_VARS(D, A, B);
        NO_DEVICE_CODE;
#endif // TURING_MMA_AVAILABLE
    }

    static __device__ __forceinline__ void mma(
            tile<16, 8, float> & D, const tile<16, 8, nv_bfloat162> & A, const tile<8, 8, nv_bfloat162> & B) {
#ifdef AMPERE_MMA_AVAILABLE
        const int * Axi = (const int *) A.x;
        const int * Bxi = (const int *) B.x;
        int * Dxi = (int *) D.x;
        asm("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
            : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
#else
        GGML_UNUSED_VARS(D, A, B);
        NO_DEVICE_CODE;
#endif // AMPERE_MMA_AVAILABLE
    }

    template <data_layout dl_ab, data_layout dl_d>
    static __device__ __forceinline__ void mma(
            tile<16, 16, float, dl_d> & D, const tile<16, 8, half2, dl_ab> & A, const tile<16, 8, half2, dl_ab> & B) {
#ifdef TURING_MMA_AVAILABLE
        const int * Axi = (const int *) A.x;
        const int * Bxi = (const int *) B.x;
        int * Dxi = (int *) D.x;
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
        asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
            : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[2]));
        asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
            : "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
            : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]), "r"(Bxi[3]));
#else
        // On Turing m16n8k16 mma is not available, use 4x m16n8k8 mma instead:
        asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
            : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
        asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
            : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]));
        asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
            : "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
            : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[1]));
        asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
            : "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
            : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3]));
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#elif defined(AMD_WMMA_AVAILABLE)
#if defined(RDNA4)
        using halfx8_t = __attribute__((ext_vector_type(8))) _Float16;
        using floatx8_t = __attribute__((ext_vector_type(8))) float;
        floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);
        const halfx8_t& a_frag = reinterpret_cast<const halfx8_t&>(A.x[0]);
        const halfx8_t& b_frag = reinterpret_cast<const halfx8_t&>(B.x[0]);
        acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(a_frag, b_frag, acc_frag);
#elif defined(RDNA3)
        using halfx16_t = __attribute__((ext_vector_type(16))) _Float16;
        using floatx8_t = __attribute__((ext_vector_type(8))) float;
        floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);
        const halfx16_t& a_frag = reinterpret_cast<const halfx16_t&>(A.x[0]);
        const halfx16_t& b_frag = reinterpret_cast<const halfx16_t&>(B.x[0]);
        acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a_frag, b_frag, acc_frag);
#else
        GGML_UNUSED_VARS(D, A, B);
        NO_DEVICE_CODE;
#endif // RDNA4
#elif defined(AMD_MFMA_AVAILABLE)
        using halfx4_t = __attribute__((ext_vector_type(4))) _Float16;
        using floatx4_t = __attribute__((ext_vector_type(4))) float;
        floatx4_t& acc_frag = reinterpret_cast<floatx4_t&>(D.x[0]);
        const halfx4_t& a_frag = reinterpret_cast<const halfx4_t&>(A.x[0]);
        const halfx4_t& b_frag = reinterpret_cast<const halfx4_t&>(B.x[0]);
        acc_frag = __builtin_amdgcn_mfma_f32_16x16x16f16(a_frag, b_frag, acc_frag, 0, 0, 0);
#else
        GGML_UNUSED_VARS(D, A, B);
        NO_DEVICE_CODE;
#endif // TURING_MMA_AVAILABLE
    }

    template <data_layout dl_ab, data_layout dl_d>
    static __device__ __forceinline__ void mma(
            tile<16, 16, float, dl_d> & D, const tile<16, 8, nv_bfloat162, dl_ab> & A, const tile<16, 8, nv_bfloat162, dl_ab> & B) {
#if defined(AMD_WMMA_AVAILABLE)
#if defined(RDNA4)
        using bf16x8_t = __attribute__((ext_vector_type(8))) __bf16;
        using floatx8_t = __attribute__((ext_vector_type(8))) float;
        floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);
        const bf16x8_t& a_frag = reinterpret_cast<const bf16x8_t&>(A.x[0]);
        const bf16x8_t& b_frag = reinterpret_cast<const bf16x8_t&>(B.x[0]);
        acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12(a_frag, b_frag, acc_frag);
#elif defined(RDNA3)
        using bf16x16_t = __attribute__((ext_vector_type(16))) __bf16;
        using floatx8_t = __attribute__((ext_vector_type(8))) float;
        floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);
        const bf16x16_t& a_frag = reinterpret_cast<const bf16x16_t&>(A.x[0]);
        const bf16x16_t& b_frag = reinterpret_cast<const bf16x16_t&>(B.x[0]);
        acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32(a_frag, b_frag, acc_frag);
#else
        GGML_UNUSED_VARS(D, A, B);
        NO_DEVICE_CODE;
#endif // defined(RDNA4)
#elif defined(AMD_MFMA_AVAILABLE)
        using floatx4_t = __attribute__((ext_vector_type(4))) float;
        floatx4_t& acc_frag = reinterpret_cast<floatx4_t&>(D.x[0]);
#if defined(CDNA3) || defined(CDNA2)
        using bf16x4_t = __attribute__((ext_vector_type(4))) __bf16;
        const bf16x4_t& a_frag = reinterpret_cast<const bf16x4_t&>(A.x[0]);
        const bf16x4_t& b_frag = reinterpret_cast<const bf16x4_t&>(B.x[0]);
        acc_frag = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_frag, b_frag, acc_frag, 0, 0, 0);
#elif defined(CDNA1)
#pragma unroll
        for (int i = 0; i < 2; ++i) {
            using bf16x2_t = __attribute__((ext_vector_type(2))) __bf16;
            const bf16x2_t& a_frag = reinterpret_cast<const bf16x2_t&>(A.x[i]);
            const bf16x2_t& b_frag = reinterpret_cast<const bf16x2_t&>(B.x[i]);
            acc_frag = __builtin_amdgcn_mfma_f32_16x16x8bf16(a_frag, b_frag, acc_frag, 0, 0, 0);
        }
#else
        GGML_UNUSED_VARS(D, A, B);
        NO_DEVICE_CODE;
#endif // defined(CDNA3) || defined(CDNA2)
#else
        GGML_UNUSED_VARS(D, A, B);
        NO_DEVICE_CODE;
#endif // defined(AMD_WMMA_AVAILABLE)
    }

    template <data_layout dl_d, data_layout dl_ab>
    static __device__ __forceinline__ void mma(
            tile<16, 16, int, dl_d> & D, const tile<16, 8, int, dl_ab> & A, const tile<16, 8, int, dl_ab> & B) {
#if defined(AMD_MFMA_AVAILABLE)
        using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
        int32x4_t * acc = (int32x4_t *) D.x;
#if defined(CDNA3)
        acc[0] = __builtin_amdgcn_mfma_i32_16x16x32_i8(((int64_t *) A.x)[0],
                                                       ((int64_t *) B.x)[0],
                                                       acc[0],
                                                       0, 0, 0);
#elif defined(CDNA2) || defined(CDNA)
        acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[0],
                                                      B.x[0],
                                                      acc[0],
                                                      0, 0, 0);
        acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[1],
                                                      B.x[1],
                                                      acc[0],
                                                      0, 0, 0);
#endif // defined(CDNA3)

#elif defined(AMD_WMMA_AVAILABLE)

        using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;
        int32x8_t * acc = (int32x8_t *) D.x;

#if defined(RDNA4)
        using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
        int32x2_t * a_vec = (int32x2_t *) A.x;
        int32x2_t * b_vec = (int32x2_t *) B.x;

        acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
            true,
            a_vec[0],
            true,
            b_vec[0],
            acc[0],
            true
        );

        acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
            true,
            a_vec[1],
            true,
            b_vec[1],
            acc[0],
            true
        );

#elif defined(RDNA3)
        using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
        int32x4_t * a_vec = (int32x4_t *) A.x;
        int32x4_t * b_vec = (int32x4_t *) B.x;

        acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(
            true,
            a_vec[0],
            true,
            b_vec[0],
            acc[0],
            true
        );

        acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(
            true,
            a_vec[1],
            true,
            b_vec[1],
            acc[0],
            true
        );
#endif // RDNA4

#else
        GGML_UNUSED_VARS(D, A, B);
        NO_DEVICE_CODE;
#endif // AMD_MFMA_AVAILABLE
    }

    static __device__ __forceinline__ void mma(
            tile<32, 32, int> & D, const tile<32, 4, int> & A, const tile<32, 4, int> & B) {
#if defined(AMD_MFMA_AVAILABLE)
        using int32x16_t = __attribute__((__vector_size__(16 * sizeof(int)))) int;
        int32x16_t * acc = (int32x16_t *) D.x;
#if defined(CDNA3)
        acc[0] = __builtin_amdgcn_mfma_i32_32x32x16_i8(((int64_t *) A.x)[0],
                                                       ((int64_t *) B.x)[0],
                                                       acc[0],
                                                       0, 0, 0);
#elif defined(CDNA2) || defined(CDNA)
        acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[0],
                                                     B.x[0],
                                                     acc[0],
                                                     0, 0, 0);
        acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[1],
                                                     B.x[1],
                                                     acc[0],
                                                     0, 0, 0);
#endif // defined(CDNA3)

#else
        GGML_UNUSED_VARS(D, A, B);
        NO_DEVICE_CODE;
#endif // AMD_MFMA_AVAILABLE
    }

    template <typename T1, typename T2, int J, int K>
    static __device__ __forceinline__ void mma(
            tile<32, J, T1> & D, const tile<32, K, T2> & A, const tile<J, K, T2> & B) {
        tile      <16, J, T1> * D16 = reinterpret_cast<      tile<16, J, T1> *>(&D);
        const tile<16, K, T2> * A16 = reinterpret_cast<const tile<16, K, T2> *>(&A);
        mma(D16[0], A16[0], B);
        mma(D16[1], A16[1], B);
    }

    static __device__ __forceinline__ void mma(
            tile<32, 8, float> & D, const tile<32, 4, half2> & A, const tile<8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> & B) {
#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
        const int * Axi = (const int *) A.x;
        const int * Bxi = (const int *) B.x;
        int * Dxi = (int *) D.x;
        asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
            "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
            : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]), "r"(Bxi[1]));
        asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
            "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
            : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]), "r"(Bxi[3]));
#else
        GGML_UNUSED_VARS(D, A, B);
        NO_DEVICE_CODE;
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
    }

    static __device__ __forceinline__ void mma(
            tile<32, 4, half2> & D, const tile<32, 4, half2> & A, const tile<8, 4, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> & B) {
#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
        const int * Axi = (const int *) A.x;
        const int * Bxi = (const int *) B.x;
        int * Dxi = (int *) D.x;
        asm("mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16 "
            "{%0, %1, %2, %3}, {%4, %5}, {%6, %7}, {%0, %1, %2, %3};"
            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
            : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]), "r"(Bxi[1]));
        asm("mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16 "
            "{%0, %1, %2, %3}, {%4, %5}, {%6, %7}, {%0, %1, %2, %3};"
            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
            : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]), "r"(Bxi[3]));
#else
        GGML_UNUSED_VARS(D, A, B);
        NO_DEVICE_CODE;
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
    }

    template <data_layout dl_d, data_layout dl_ab>
    static __device__ __forceinline__ void mma(
            tile<16, 16, int, dl_d> & D, const tile<16, 4, int, dl_ab> & A, const tile<16, 4, int, dl_ab> & B) {
#if defined(AMD_WMMA_AVAILABLE)
        using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;
        int32x8_t * acc = (int32x8_t *) D.x;
#if defined(RDNA4)
        using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
        int32x2_t * a_vec = (int32x2_t *) A.x;
        int32x2_t * b_vec = (int32x2_t *) B.x;

        acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
            true,
            a_vec[0],
            true,
            b_vec[0],
            acc[0],
            false
        );
#elif defined(RDNA3)
        using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
        int32x4_t * a_vec = (int32x4_t *) A.x;
        int32x4_t * b_vec = (int32x4_t *) B.x;

        acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(
            true,
            a_vec[0],
            true,
            b_vec[0],
            acc[0],
            false
        );
#endif // RDNA4
#else
        GGML_UNUSED(D);
        GGML_UNUSED(A);
        GGML_UNUSED(B);
        NO_DEVICE_CODE;
#endif // AMD_WMMA_AVAILABLE
    }
}