#pragma once

#include "common.cuh"
#include "vecdotq.cuh"
#include "mma.cuh"

#include <climits>
#include <cstdint>

using namespace ggml_cuda_mma;

#define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available.
#define MMQ_ITER_K 256
#define MMQ_ITER_K_MXFP4_FP4 512
#define MMQ_NWARPS 8

typedef void (*load_tiles_mmq_t)(const char * __restrict__ x, int * x_tile, const int kbx0, const int i_max, const int stride);
typedef void (*vec_dot_mmq_t)(const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00);
typedef void (*mmq_write_back_t)(const float * __restrict__ sum, const int32_t * __restrict__ get_rows_to_sorted,
    float * __restrict__ dst, const int stride, const int i_max, const int j_max);

enum mmq_q8_1_ds_layout {
    MMQ_Q8_1_DS_LAYOUT_D4,
    MMQ_Q8_1_DS_LAYOUT_DS4,
    MMQ_Q8_1_DS_LAYOUT_D2S6,
};

struct block_q8_1_mmq {
    // The y float data is converted to a data layout that can simply be copied to shared memory as a contiguous block.
    // The y float data is first grouped as blocks of 128 values.
    // These blocks are then treated as individual data values and transposed.
    //
    // To avoid shared memory bank conflicts each block is padded with 16 bytes.
    // This padding is also used to store block scales/partial sums.
    // The scales multiplied with the quantized data are equal to the unquantized values.
    // The partial sums are obtained by summing up a subgroup of the contained values (prior to quantization)
    // and are only needed for performance reasons.
    //
    // The exact data stored depends on the x data type.
    union {
        float d4[4];    // 1 32 bit scale per 32 values, stored as d0,d1,d2,d3
        half2 ds4[4];   // 1 16 bit scale + 1 16 bit partial sum per 32 values, stored as d0,s0,d1,s1,d2,s2,d3,s3
        half  d2s6[8];  // 1 16 bit scale per 64 values + 1 16 bit partial sum per 16 values for the first 96 values,
                        //     stored as d0,d1,s1,s2,s3,s4,s5,s6
    };
    int8_t qs[4*QK8_1]; // 128 values quantized to 8 bit each
};

struct block_fp4_mmq {
    uint32_t d4[4];      // 8 E8M0 scales (1 per 32 values), 2 packed per uint32: d4[0]={s0,s1}, d4[1]={s2,s3}, etc.
    int8_t   qs[4 * 32]; // 256 FP4 values packed as 4-bit pairs (2 per byte), 8 blocks of 32 values
};

static_assert(sizeof(block_q8_1_mmq) == 4*QK8_1 + 4*sizeof(half2), "Unexpected block_q8_1_mmq size");
static_assert(sizeof(block_q8_1_mmq) == 4*sizeof(block_q8_1), "Unexpected block_q8_1_mmq size");
static_assert(sizeof(block_fp4_mmq) == sizeof(block_q8_1_mmq), "Unexpected block_fp4_mmq size");
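
// Illustrative size check, not load-bearing: assuming ggml's QK8_1 == 32, one block holds
// 4*32 == 128 bytes of quantized values plus 16 bytes of scales/partial sums, i.e. 144 bytes.
static_assert(sizeof(block_q8_1_mmq) == 144, "Unexpected block_q8_1_mmq size (assumes QK8_1 == 32)");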

static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
    switch (type_x) {
        case GGML_TYPE_Q4_0:
        case GGML_TYPE_Q4_1:
            return MMQ_Q8_1_DS_LAYOUT_DS4;
        case GGML_TYPE_Q5_0:
            return MMQ_Q8_1_DS_LAYOUT_D4;
        case GGML_TYPE_Q5_1:
            return MMQ_Q8_1_DS_LAYOUT_DS4;
        case GGML_TYPE_Q8_0:
            return MMQ_Q8_1_DS_LAYOUT_D4;
        case GGML_TYPE_MXFP4:
            return MMQ_Q8_1_DS_LAYOUT_D4;
        case GGML_TYPE_Q2_K:
            return MMQ_Q8_1_DS_LAYOUT_D2S6;
        case GGML_TYPE_Q3_K:
            return MMQ_Q8_1_DS_LAYOUT_D4;
        case GGML_TYPE_Q4_K:
        case GGML_TYPE_Q5_K:
            return MMQ_Q8_1_DS_LAYOUT_DS4;
        case GGML_TYPE_Q6_K:
        case GGML_TYPE_IQ2_XXS:
        case GGML_TYPE_IQ2_XS:
        case GGML_TYPE_IQ2_S:
        case GGML_TYPE_IQ3_XXS:
        case GGML_TYPE_IQ3_S:
            return MMQ_Q8_1_DS_LAYOUT_D4;
        case GGML_TYPE_IQ1_S:
            return MMQ_Q8_1_DS_LAYOUT_DS4;
        case GGML_TYPE_IQ4_XS:
        case GGML_TYPE_IQ4_NL:
            return MMQ_Q8_1_DS_LAYOUT_D4;
        default:
            GGML_ABORT("fatal error");
            break;
    }
}

struct tile_x_sizes {
    int qs;
    int dm;
    int sc;
};

static int get_mmq_x_max_host(const int cc) {
    return (amd_mfma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc)) ? 128 :
        GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ?
#ifdef GGML_CUDA_FORCE_MMQ
            128 : 64;
#else
            MMQ_DP4A_MAX_BATCH_SIZE : 64;
#endif // GGML_CUDA_FORCE_MMQ
}

static constexpr __device__ int get_mmq_x_max_device() {
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
    return 128;
#else // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)

#if defined(GGML_USE_HIP)
    return 64;
#else // defined(GGML_USE_HIP)

#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
#ifdef GGML_CUDA_FORCE_MMQ
    return 128;
#else // GGML_CUDA_FORCE_MMQ
    return MMQ_DP4A_MAX_BATCH_SIZE;
#endif // GGML_CUDA_FORCE_MMQ
#else // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
    return 64;
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA

#endif // defined(GGML_USE_HIP)
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
}

static int get_mmq_y_host(const int cc) {
    return GGML_CUDA_CC_IS_AMD(cc) ? (GGML_CUDA_CC_IS_RDNA1(cc) ? 64 : 128) :
        ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ? 128 : 64);
}

static constexpr __device__ int get_iter_k([[maybe_unused]] const ggml_type type) {
#if defined(BLACKWELL_MMA_AVAILABLE)
    return type == GGML_TYPE_MXFP4 ? MMQ_ITER_K_MXFP4_FP4 : MMQ_ITER_K;
#else
    return MMQ_ITER_K;
#endif // defined(BLACKWELL_MMA_AVAILABLE)
}

static constexpr __device__ int get_mmq_y_device() {
#if defined(GGML_USE_HIP)
#if defined(RDNA1)
    return 64;
#else
    return 128;
#endif // defined RDNA1
#else
#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
    return 128;
#else
    return 64;
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
#endif // defined(GGML_USE_HIP)
}

// Decouple shared memory tile sizes from WARP_SIZE to allow for different warp sizes.
// The K dimension of the tiles has either
// 1*MMQ_TILE_NE_K==32 (always for TILE_Y_K) or 2*MMQ_TILE_NE_K==64 (typically for TILE_X_K)
// 32 bit elements for the quantized data (scales not included).
// In other words, the size of the quantized data in the K dimension is a multiple of MMQ_TILE_NE_K.
// The final tile size in the K direction is padded to avoid shared memory bank conflicts;
// in terms of 32 bit elements that means K % 2 == 1 for dp4a or K % 8 == 4 for mma.
#define MMQ_TILE_NE_K 32
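
// Worked example of the padding rule above (illustrative, QI8_0 == 8): the Q8_0 mma tile stride
// defined below is MMQ_MMA_TILE_X_K_Q8_0 == 2*32 + 2*32/8 + 4 == 76 32-bit elements, and
// 76 % 8 == 4 satisfies the mma padding condition (see the static_asserts further down).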

#define MMQ_DP4A_TXS_Q4_0    tile_x_sizes{mmq_y*MMQ_TILE_NE_K   + mmq_y, mmq_y*MMQ_TILE_NE_K/QI4_0   + mmq_y/QI4_0,     0}
#define MMQ_DP4A_TXS_Q4_1    tile_x_sizes{mmq_y*MMQ_TILE_NE_K   + mmq_y, mmq_y*MMQ_TILE_NE_K/QI4_1   + mmq_y/QI4_1,     0}
#define MMQ_DP4A_TXS_Q8_0    tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K*2/QI8_0 + mmq_y/(QI8_0/2), 0}
#define MMQ_DP4A_TXS_Q8_0_16 tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K*4/QI8_0 + mmq_y/(QI8_0/4), 0}
#define MMQ_DP4A_TXS_Q8_1    tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K*2/QI8_1 + mmq_y/(QI8_1/2), 0}
#define MMQ_DP4A_TXS_Q2_K    tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K         + mmq_y,           0}
#define MMQ_DP4A_TXS_Q3_K    tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y,                                         mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8}
#define MMQ_DP4A_TXS_Q4_K    tile_x_sizes{mmq_y*MMQ_TILE_NE_K   + mmq_y, mmq_y*MMQ_TILE_NE_K/QI4_K,                     mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8}
#define MMQ_DP4A_TXS_Q5_K    tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K/QI5_K   + mmq_y/QI5_K,     mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8}
#define MMQ_DP4A_TXS_Q6_K    tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K/QI6_K   + mmq_y/QI6_K,     mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8}

static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) {
    switch (type) {
        case GGML_TYPE_Q4_0:    return MMQ_DP4A_TXS_Q4_0;
        case GGML_TYPE_Q4_1:    return MMQ_DP4A_TXS_Q4_1;
        case GGML_TYPE_Q5_0:    return MMQ_DP4A_TXS_Q8_0;
        case GGML_TYPE_Q5_1:    return MMQ_DP4A_TXS_Q8_1;
        case GGML_TYPE_Q8_0:    return MMQ_DP4A_TXS_Q8_0;
        case GGML_TYPE_MXFP4:   return MMQ_DP4A_TXS_Q8_1;
        case GGML_TYPE_Q2_K:    return MMQ_DP4A_TXS_Q2_K;
        case GGML_TYPE_Q3_K:    return MMQ_DP4A_TXS_Q3_K;
        case GGML_TYPE_Q4_K:    return MMQ_DP4A_TXS_Q4_K;
        case GGML_TYPE_Q5_K:    return MMQ_DP4A_TXS_Q5_K;
        case GGML_TYPE_Q6_K:    return MMQ_DP4A_TXS_Q6_K;
        case GGML_TYPE_IQ2_XXS: return MMQ_DP4A_TXS_Q8_0;
        case GGML_TYPE_IQ2_XS:  return MMQ_DP4A_TXS_Q8_0_16;
        case GGML_TYPE_IQ2_S:   return MMQ_DP4A_TXS_Q8_0_16;
        case GGML_TYPE_IQ3_XXS: return MMQ_DP4A_TXS_Q8_0;
        case GGML_TYPE_IQ3_S:   return MMQ_DP4A_TXS_Q8_0;
        case GGML_TYPE_IQ1_S:   return MMQ_DP4A_TXS_Q8_0;
        case GGML_TYPE_IQ4_XS:  return MMQ_DP4A_TXS_Q8_0;
        case GGML_TYPE_IQ4_NL:  return MMQ_DP4A_TXS_Q8_0;
        default:                return tile_x_sizes{0, 0, 0};
    }
}

#define MMQ_MMA_TILE_X_K_Q8_0 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
#define MMQ_MMA_TILE_X_K_FP4  (2*MMQ_TILE_NE_K + 8 + 4)
#define MMQ_MMA_TILE_X_K_Q8_1 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
#define MMQ_MMA_TILE_X_K_Q2_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K + 4)
#define MMQ_MMA_TILE_X_K_Q3_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4)
#define MMQ_MMA_TILE_X_K_Q6_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI6_K + MMQ_TILE_NE_K/8 + 7)

static_assert(MMQ_MMA_TILE_X_K_Q8_0 % 8 == 4, "Wrong padding.");
static_assert(MMQ_MMA_TILE_X_K_Q8_1 % 8 == 4, "Wrong padding.");
static_assert(MMQ_MMA_TILE_X_K_Q2_K % 8 == 4, "Wrong padding.");
static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding.");
static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding.");
static_assert(MMQ_MMA_TILE_X_K_FP4  % 8 == 4, "Wrong padding.");
static_assert(MMQ_MMA_TILE_X_K_FP4 == MMQ_MMA_TILE_X_K_Q8_1, "Wrong tile size for MXFP4");

static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
    switch (type) {
        case GGML_TYPE_Q4_0:    return MMQ_MMA_TILE_X_K_Q8_0;
        case GGML_TYPE_Q4_1:    return MMQ_MMA_TILE_X_K_Q8_1;
        case GGML_TYPE_Q5_0:    return MMQ_MMA_TILE_X_K_Q8_0;
        case GGML_TYPE_Q5_1:    return MMQ_MMA_TILE_X_K_Q8_1;
        case GGML_TYPE_Q8_0:    return MMQ_MMA_TILE_X_K_Q8_0;
        // tile sizes are the same for Q8_1 and FP4 on Blackwell
        case GGML_TYPE_MXFP4:   return MMQ_MMA_TILE_X_K_Q8_1;
        case GGML_TYPE_Q2_K:    return MMQ_MMA_TILE_X_K_Q2_K;
        case GGML_TYPE_Q3_K:    return MMQ_MMA_TILE_X_K_Q3_K;
        case GGML_TYPE_Q4_K:    return MMQ_MMA_TILE_X_K_Q8_1;
        case GGML_TYPE_Q5_K:    return MMQ_MMA_TILE_X_K_Q8_1;
        case GGML_TYPE_Q6_K:    return MMQ_MMA_TILE_X_K_Q6_K;
        case GGML_TYPE_IQ2_XXS: return MMQ_MMA_TILE_X_K_Q8_0;
        case GGML_TYPE_IQ2_XS:  return MMQ_MMA_TILE_X_K_Q3_K;
        case GGML_TYPE_IQ2_S:   return MMQ_MMA_TILE_X_K_Q3_K;
        case GGML_TYPE_IQ3_XXS: return MMQ_MMA_TILE_X_K_Q8_0;
        case GGML_TYPE_IQ3_S:   return MMQ_MMA_TILE_X_K_Q8_0;
        case GGML_TYPE_IQ1_S:   return MMQ_MMA_TILE_X_K_Q8_0;
        case GGML_TYPE_IQ4_XS:  return MMQ_MMA_TILE_X_K_Q8_0;
        case GGML_TYPE_IQ4_NL:  return MMQ_MMA_TILE_X_K_Q8_0;
        default:                return 0;
    }
}

// block_q8_1_mmq has (128 8-bit ints == 32 32-bit ints + 4 32-bit scales)
#define MMQ_TILE_Y_K (MMQ_TILE_NE_K + MMQ_TILE_NE_K / QI8_1)
#define MMQ_TILE_Y_FP4_K MMQ_TILE_Y_K
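
// Worked example (illustrative, QI8_1 == 8): MMQ_TILE_Y_K == 32 + 32/8 == 36 32-bit elements,
// i.e. 144 bytes per tile row, which matches sizeof(block_q8_1_mmq) above.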

static int mmq_get_granularity_host(const int mmq_x, const int cc) {
    if (amd_mfma_available(cc) || amd_wmma_available(cc)) {
        return mmq_x >= 128 ? 32 : 16;
    } else if (turing_mma_available(cc) && mmq_x >= 48) {
        return 16;
    } else {
        return 8;
    }
}

#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) {
    return mmq_x >= 128 ? 32 : 16;
}
#elif defined(TURING_MMA_AVAILABLE)
static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) {
    return mmq_x >= 48 ? 16 : 8;
}
#else
static constexpr __device__ int mmq_get_granularity_device(const int /*mmq_x*/) {
    return 8;
}
#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)

#if defined(GGML_USE_HIP)
static int mmq_get_nwarps_host(const int cc, const int warp_size) {
    return amd_mfma_available(cc) ? 8 : 256/warp_size;
}
#else
static int mmq_get_nwarps_host(const int /*cc*/, const int warp_size) {
    return 256/warp_size;
}
#endif // (GGML_USE_HIP)

static constexpr __device__ int mmq_get_nwarps_device() {
#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
    return 8;
#else
    return 256/ggml_cuda_get_physical_warp_size();
#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
}
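
// Example values (illustrative): with a physical warp size of 32 (NVIDIA, RDNA) this yields
// 256/32 == 8 warps, and with a warp size of 64 it yields 256/64 == 4; the MFMA/WMMA path
// above pins nwarps to 8 independently of the warp size.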

// ------------------------------------------------------------

template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
    const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
    constexpr int nwarps = mmq_get_nwarps_device();
    constexpr int warp_size = ggml_cuda_get_physical_warp_size();

#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
    int * x_qs = (int *) x_tile;
    float * x_df = (float *) (x_qs + 2*MMQ_TILE_NE_K);
#else
    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
    int * x_qs = (int *) x_tile;
    float * x_df = (float *) (x_qs + txs.qs);
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)

    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_0);
    constexpr int nrows = warp_size / threads_per_row;
    const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
    const int kbx = txi / QI4_0;
    const int kqsx = txi % QI4_0;

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
        int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx;
        const int qs0 = get_int_b2(bxi->qs, kqsx);

#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + 0]     = __vsubss4((qs0 >> 0) & 0x0F0F0F0F, 0x08080808);
        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + QI4_0] = __vsubss4((qs0 >> 4) & 0x0F0F0F0F, 0x08080808);
#else
        x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
    }

    constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_0;
    constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
    const int kbxd = threadIdx.x % blocks_per_tile_x_row;

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
        int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbxd;

#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
#else
        x_df[i*(MMQ_TILE_NE_K/QI4_0) + i/QI4_0 + kbxd] = bxi->d;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
    }
}

template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
    constexpr int nwarps = mmq_get_nwarps_device();
    constexpr int warp_size = ggml_cuda_get_physical_warp_size();

    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
    const int * x_qs = (const int *) x;
    const float * x_df = (const float *) x_qs + txs.qs;
    const int * y_qs = (const int *) y + 4;
    const half2 * y_ds = (const half2 *) y;

// #pragma unroll
    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR4_0*VDR_Q4_0_Q8_1_MMQ) {
        const int k0 = k00 + k01;

#pragma unroll
        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
            const int j = j0 + threadIdx.y;

#pragma unroll
            for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
                const int i = i0 + threadIdx.x;

                const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2);

                int u[2*VDR_Q4_0_Q8_1_MMQ];

#pragma unroll
                for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) {
                    u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + kyqs + l];
                    u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_0)];
                }

                sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>
                    (&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/QR4_0], u,
                     x_df[i*(MMQ_TILE_NE_K/QI4_0) + i/QI4_0 + k0/(QR4_0*QI4_0)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
            }
        }
    }
}

template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
    const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
    constexpr int nwarps = mmq_get_nwarps_device();
    constexpr int warp_size = ggml_cuda_get_physical_warp_size();

#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
    int * x_qs = (int *) x_tile;
    half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
#else
    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
    int * x_qs = (int *) x_tile;
    half2 * x_dm = (half2 *) (x_qs + txs.qs);
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)

    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_1);
    constexpr int nrows = warp_size / threads_per_row;
    const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
    const int kbx = txi / QI4_1;
    const int kqsx = txi % QI4_1;

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
        int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx;
        const int qs0 = get_int_b4(bxi->qs, kqsx);

#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
        x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + 0]     = (qs0 >> 0) & 0x0F0F0F0F;
        x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + QI4_1] = (qs0 >> 4) & 0x0F0F0F0F;
#else
        x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
    }

    constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_1;
    constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
    const int kbxd = threadIdx.x % blocks_per_tile_x_row;

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
        int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbxd;

#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
        x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm;
#else
        x_dm[i*(MMQ_TILE_NE_K/QI4_1) + i/QI4_1 + kbxd] = bxi->dm;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
    }
}

template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
    constexpr int nwarps = mmq_get_nwarps_device();
    constexpr int warp_size = ggml_cuda_get_physical_warp_size();

    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
    const int * x_qs = (const int *) x;
    const half2 * x_dm = (const half2 *) x_qs + txs.qs;
    const int * y_qs = (const int *) y + 4;
    const half2 * y_ds = (const half2 *) y;

// #pragma unroll
    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR4_1*VDR_Q4_1_Q8_1_MMQ) {
        const int k0 = k00 + k01;

#pragma unroll
        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
            const int j = j0 + threadIdx.y;

#pragma unroll
            for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
                const int i = i0 + threadIdx.x;

                const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2);

                int u[2*VDR_Q4_1_Q8_1_MMQ];

#pragma unroll
                for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) {
                    u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + kyqs + l];
                    u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_1)];
                }

                sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>
                    (&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/QR4_1], u,
                     x_dm[i*(MMQ_TILE_NE_K/QI4_1) + i/QI4_1 + k0/(QR4_1*QI4_1)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
            }
        }
    }
}

template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q5_0(
    const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
    constexpr int nwarps = mmq_get_nwarps_device();
    constexpr int warp_size = ggml_cuda_get_physical_warp_size();

#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
    int * x_qs = (int *) x_tile;
    float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
#else
    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y);
    int * x_qs = (int *) x_tile;
    float * x_df = (float *) (x_qs + txs.qs);
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)

    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_0);
    constexpr int nrows = warp_size / threads_per_row;
    const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
    const int kbx = txi / QI5_0;
    const int kqsx = txi % QI5_0;

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
        int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbx;

        const int ql = get_int_b2(bxi->qs, kqsx);
        const int qh = get_int_b2(bxi->qh, 0) >> (4 * kqsx);

        int qs0 = (ql >> 0) & 0x0F0F0F0F;
        qs0 |= (qh << 4)  & 0x00000010; // 0 -> 4
        qs0 |= (qh << 11) & 0x00001000; // 1 -> 12
        qs0 |= (qh << 18) & 0x00100000; // 2 -> 20
        qs0 |= (qh << 25) & 0x10000000; // 3 -> 28
        qs0 = __vsubss4(qs0, 0x10101010); // subtract 16

        int qs1 = (ql >> 4) & 0x0F0F0F0F;
        qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4
        qs1 |= (qh >> 5)  & 0x00001000; // 17 -> 12
        qs1 |= (qh << 2)  & 0x00100000; // 18 -> 20
        qs1 |= (qh << 9)  & 0x10000000; // 19 -> 28
        qs1 = __vsubss4(qs1, 0x10101010); // subtract 16

#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + 0]     = qs0;
        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
#else
        x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + kqsx + 0]     = qs0;
        x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
    }

    constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_0;
    constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
    const int kbxd = threadIdx.x % blocks_per_tile_x_row;

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
        int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd;

#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
#else
        x_df[i*(MMQ_TILE_NE_K/QI5_0) + i/QI5_0 + kbxd] = bxi->d;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
    }
}

template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(
    const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
    constexpr int nwarps = mmq_get_nwarps_device();
    constexpr int warp_size = ggml_cuda_get_physical_warp_size();

#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
    int * x_qs = (int *) x_tile;
    half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
#else
    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
    int * x_qs = (int *) x_tile;
    half2 * x_dm = (half2 *) (x_qs + txs.qs);
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)

    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_1);
    constexpr int nrows = warp_size / threads_per_row;
    const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
    const int kbx = txi / QI5_1;
    const int kqsx = txi % QI5_1;

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
        int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbx;

        const int ql = get_int_b4(bxi->qs, kqsx);
        const int qh = get_int_b4(bxi->qh, 0) >> (4 * kqsx);

        int qs0 = (ql >> 0) & 0x0F0F0F0F;
        qs0 |= (qh << 4)  & 0x00000010; // 0 -> 4
        qs0 |= (qh << 11) & 0x00001000; // 1 -> 12
        qs0 |= (qh << 18) & 0x00100000; // 2 -> 20
        qs0 |= (qh << 25) & 0x10000000; // 3 -> 28

        int qs1 = (ql >> 4) & 0x0F0F0F0F;
        qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4
        qs1 |= (qh >> 5)  & 0x00001000; // 17 -> 12
        qs1 |= (qh << 2)  & 0x00100000; // 18 -> 20
        qs1 |= (qh << 9)  & 0x10000000; // 19 -> 28

#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
        x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + 0]     = qs0;
        x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
#else
        x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + 0]     = qs0;
        x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
    }

    constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_1;
    constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
    const int kbxd = threadIdx.x % blocks_per_tile_x_row;

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
        int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd;

#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
        x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm;
#else
        x_dm[i*(MMQ_TILE_NE_K/QI5_1) + i/QI5_1 + kbxd] = bxi->dm;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
    }
}

template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
    const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
    constexpr int nwarps = mmq_get_nwarps_device();
    constexpr int warp_size = ggml_cuda_get_physical_warp_size();

#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
    int * x_qs = (int *) x_tile;
    float * x_df = (float *) (x_tile + 2*MMQ_TILE_NE_K);
#else
    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
    int * x_qs = (int *) x_tile;
    float * x_df = (float *) (x_qs + txs.qs);
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)

    // MMQ_ITER_K / (4 * QR8_0) == 64 threads per row would be required, but NVIDIA has only 32 threads per warp.
    constexpr int threads_per_row = 32;
    constexpr int nrows = warp_size / threads_per_row;
    const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
    const int kbx = txi / QI8_0;
    const int kqsx = txi % QI8_0;
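
    // Illustrative consequence of capping threads_per_row at 32: on a 64-wide AMD wave,
    // nrows == 64/32 == 2, so each warp loads two tile rows per iteration; on 32-wide
    // warps nrows == 1 and txi is simply threadIdx.x.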

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
        int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx;

#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 0             + txi] = get_int_b2(bxi[0].qs, kqsx);
        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx);
#else
        x_qs[i*(2*MMQ_TILE_NE_K + 1) + 0             + txi] = get_int_b2(bxi[0].qs, kqsx);
        x_qs[i*(2*MMQ_TILE_NE_K + 1) + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx);
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
    }

    constexpr int blocks_per_tile_x_row = 2*MMQ_TILE_NE_K / QI8_0;
    constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
    const int kbxd = threadIdx.x % blocks_per_tile_x_row;

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
        int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd;

#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
#else
        x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + kbxd] = bxi->d;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
    }
}

template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_mxfp4(
    const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
    constexpr int nwarps = mmq_get_nwarps_device();
    constexpr int warp_size = ggml_cuda_get_physical_warp_size();

#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
    int * x_qs = (int *) x_tile;
    float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
#else
    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_MXFP4, mmq_y);
    int * x_qs = (int *) x_tile;
    float * x_df = (float *) (x_qs + txs.qs);
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)

    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR_MXFP4);
    constexpr int nrows = warp_size / threads_per_row;
    const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
    const int kbx = txi / QI_MXFP4;
    const int kqsx = txi % QI_MXFP4;

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
        int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);

        if (need_check) {
            i = min(i, i_max);
        }

        const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i*stride + kbx;

        const int aux_q4 = get_int_b1(bxi->qs, kqsx);
        const int2 v = get_int_from_table_16(aux_q4, kvalues_mxfp4);
        const int k0 = kbx * (2 * QI_MXFP4) + kqsx;

#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
        x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + 0]        = v.x;
        x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + QI_MXFP4] = v.y;
#else
        x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0]        = v.x;
        x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + QI_MXFP4] = v.y;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
    }

    constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI_MXFP4;
    constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
    const int kbxd = threadIdx.x % blocks_per_tile_x_row;

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
        int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i*stride + kbxd;

#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
        x_df[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f;
#else
        x_df[i*(MMQ_TILE_NE_K/QI_MXFP4) + i/QI_MXFP4 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
    }
}

template <int mmq_y, bool need_check>
static __device__ __forceinline__ void load_tiles_mxfp4_fp4(const char * __restrict__ x,
                                                            int * __restrict__ x_tile,
                                                            const int kbx0,
                                                            const int i_max,
                                                            const int stride) {
    constexpr int nwarps = mmq_get_nwarps_device();
    constexpr int warp_size = ggml_cuda_get_physical_warp_size();

    int * x_qs = (int *) x_tile;
    uint32_t * x_sc = (uint32_t *) (x_qs + 2 * MMQ_TILE_NE_K);

    const int txi = threadIdx.x;

    constexpr int iter_k = get_iter_k(GGML_TYPE_MXFP4);

    constexpr int threads_per_row = iter_k / QK_MXFP4; // each thread processes 1 block
    constexpr int rows_per_warp = warp_size / threads_per_row;
    const int kbx = txi % threads_per_row;
    const int row_in_warp = txi / threads_per_row;

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += rows_per_warp * nwarps) {
        int i = i0 + threadIdx.y * rows_per_warp + row_in_warp;

        if constexpr (need_check) {
            i = min(i, i_max);
        }

        const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i * stride + kbx;

        // quantize_mxfp4_mmq permutes nibbles to match the quantized format
        const int k0 = kbx * 4;
        memcpy(x_qs + i * MMQ_MMA_TILE_X_K_FP4 + k0, bxi->qs, 16);

        // Load E8M0 scales: pack 2 consecutive scales into one uint32
        if (kbx % 2 == 0) {
            uint32_t e = bxi->e;
            e |= ((bxi + 1)->e << 8);
            x_sc[i * MMQ_MMA_TILE_X_K_FP4 + kbx / 2] = e;
        }
    }
}
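
// Scale packing in load_tiles_mxfp4_fp4, illustrated with example values (E8M0 encodes 2^(e-127)):
// two consecutive block scales e0 == 0x7F (1.0f) and e1 == 0x80 (2.0f) end up in one shared
// memory word as 0x0000807F; vec_dot_mxfp4_mxfp4_mma below feeds these packed bytes to the
// scaled mma unchanged.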

template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
    constexpr int nwarps = mmq_get_nwarps_device();
    constexpr int warp_size = ggml_cuda_get_physical_warp_size();

    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
    const int * x_qs = (const int *) x;
    const float * x_df = (const float *) x_qs + txs.qs;
    const int * y_qs = (const int *) y + 4;
    const float * y_df = (const float *) y;

// #pragma unroll
    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += VDR_Q8_0_Q8_1_MMQ) {
        const int k0 = k00 + k01;

#pragma unroll
        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
            const int j = j0 + threadIdx.y;

#pragma unroll
            for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
                const int i = i0 + threadIdx.x;

                sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q8_0_q8_1_impl<float, VDR_Q8_0_Q8_1_MMQ>
                    (&x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k0 % MMQ_TILE_NE_K],
                     x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + k0/QI8_0], y_df[j*MMQ_TILE_Y_K + (k0/QI8_1) % (MMQ_TILE_NE_K/QI8_1)]);
            }
        }
    }
}

template <int mmq_x, int mmq_y, mmq_q8_1_ds_layout ds_layout>
static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
    constexpr data_layout input_layout = get_input_data_layout();
    typedef tile<16, 8, int, input_layout> tile_A;
    typedef tile<16, 8, int, input_layout> tile_B;
    typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;

    constexpr int granularity = mmq_get_granularity_device(mmq_x);
    constexpr int rows_per_warp = granularity;
    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.

    y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);

    const int * x_qs = (const int *) x;
    const float * x_df = (const float *) x_qs + 2*MMQ_TILE_NE_K;
    const int * y_qs = (const int *) y + 4;
    const float * y_df = (const float *) y;
    const half2 * y_ds = (const half2 *) y;

    const int i0 = (threadIdx.y / ntx) * rows_per_warp;

    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) {
        const int k0 = k00 + k01;

        tile_A A[ntx];
#pragma unroll
        for (int n = 0; n < ntx; ++n) {
            load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
        }

#pragma unroll
        for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
            tile_B B;
            load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);

            float dB;
            const int j = j0 + tile_C::get_j(0);
            if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) {
                dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
            } else {
                dB = __low2float(y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
            }

#pragma unroll
            for (int n = 0; n < ntx; ++n) {
                tile_C C;
                mma(C, A[n], B);

#pragma unroll
                for (int l = 0; l < tile_C::ne; ++l) {
                    const int i = i0 + n*tile_A::I + tile_C::get_i(l);
                    const float dA = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0];
                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l]*dA*dB;
                }
            }
        }
    }
#else
    typedef tile<16, 8, int> tile_A;
    typedef tile< 8, 8, int> tile_B;
    typedef tile<16, 8, int> tile_C;

    constexpr int granularity = mmq_get_granularity_device(mmq_x);
    constexpr int rows_per_warp = 2 * granularity;
    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.

    y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);

    const int * x_qs = (const int *) x;
    const float * x_df = (const float *) x_qs + 2*MMQ_TILE_NE_K;
    const int * y_qs = (const int *) y + 4;
    const float * y_df = (const float *) y;
    const half2 * y_ds = (const half2 *) y;

    tile_A A[ntx][MMQ_TILE_NE_K/QI8_0];
    float dA[ntx][tile_C::ne/2][MMQ_TILE_NE_K/QI8_0];

    const int i0 = (threadIdx.y/ntx)*rows_per_warp;

#pragma unroll
    for (int n = 0; n < ntx; ++n) {
#pragma unroll
        for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) {
            const int k0 = k00 + k01;

            load_ldmatrix(A[n][k01/QI8_0], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
        }

#pragma unroll
        for (int l = 0; l < tile_C::ne/2; ++l) {
            const int i = i0 + n*tile_A::I + tile_C::get_i(2*l);

#pragma unroll
            for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) {
                const int k0 = k00 + k01;

                dA[n][l][k01/QI8_0] = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0];
            }
        }
    }

#pragma unroll
    for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
#pragma unroll
        for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) {
            tile_B B;
            float dB[tile_C::ne/2];

            load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix

#pragma unroll
            for (int l = 0; l < tile_C::ne/2; ++l) {
                const int j = j0 + tile_C::get_j(l);

                if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) {
                    dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
                } else {
                    dB[l] = __low2float(y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
                }
            }

#pragma unroll
            for (int n = 0; n < ntx; ++n) {
                tile_C C;
                mma(C, A[n][k01/QI8_0], B);

#pragma unroll
                for (int l = 0; l < tile_C::ne; ++l) {
                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l]*dA[n][l/2][k01/QI8_0]*dB[l%2];
                }
            }
        }
    }
#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
}

template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_mxfp4_mxfp4_mma(const int * __restrict__ x,
                                                               const int * __restrict__ y,
                                                               float * __restrict__ sum,
                                                               const int k00) {
    typedef tile<16, 8, int> tile_A;
    typedef tile<8, 8, int> tile_B;
    typedef tile<16, 8, float> tile_C; // Output is float for native scaled MMA

    constexpr int granularity = mmq_get_granularity_device(mmq_x);
    constexpr int rows_per_warp = 2 * granularity;
    constexpr int ntx = rows_per_warp / tile_C::I; // Number of x minitiles per warp.

    y += (threadIdx.y % ntx) * (tile_C::J * MMQ_TILE_Y_FP4_K);

    // Match layout from load_tiles_mxfp4_fp4
    const int * x_qs = (const int *) x;
    const uint32_t * x_sc = (const uint32_t *) (x_qs + 2 * MMQ_TILE_NE_K);
    const int * y_qs = (const int *) y + 4;
    const uint32_t * y_sc = (const uint32_t *) y;

    // tile_A has a length of 64 logical values vs. 32 values in block_mxfp4
    tile_A A[ntx][MMQ_TILE_NE_K / (2 * QI_MXFP4)];
    uint32_t scaleA[ntx][MMQ_TILE_NE_K / (2 * QI_MXFP4)];

    // Block scale
    // Each thread has to point to a 4 byte scale value
    // https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-block-scaling
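    //
    // Illustrative lane mapping for the tidx computation below: within quad q (lanes 4q..4q+3),
    // even lanes point at the scale for row q and odd lanes at the scale for row q + 8, so the
    // lanes together cover all 16 rows of tile_A.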

    const int i0 = (threadIdx.y / ntx) * rows_per_warp;

#pragma unroll
    for (int n = 0; n < ntx; ++n) {
#pragma unroll
        for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 2 * QI_MXFP4) {
            const int k0 = k00 + k01;

            load_ldmatrix(A[n][k01 / (2 * QI_MXFP4)], x_qs + (i0 + n * tile_A::I) * MMQ_MMA_TILE_X_K_FP4 + k0,
                          MMQ_MMA_TILE_X_K_FP4);

            // based on the block-scaling document, 2 threads in each quad need to supply the scale value
            const int tidx = threadIdx.x / 4 + (threadIdx.x % 2) * 8;
            scaleA[n][k01 / (2 * QI_MXFP4)] =
                *(x_sc + (i0 + n * tile_A::I + tidx) * MMQ_MMA_TILE_X_K_FP4 + k0 / (2 * QI_MXFP4));
        }
    }

#pragma unroll
    for (int j0 = 0; j0 < mmq_x; j0 += ntx * tile_C::J) {
#pragma unroll
        for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 2 * QI_MXFP4) {
            tile_B B;
            uint32_t scaleB; // 2xN scales

            load_generic(B, y_qs + j0 * MMQ_TILE_Y_FP4_K + k01, MMQ_TILE_Y_FP4_K);

            scaleB = y_sc[(j0 + threadIdx.x / 4) * MMQ_TILE_Y_FP4_K + k01 / (2 * QI_MXFP4)];

#pragma unroll
            for (int n = 0; n < ntx; ++n) {
                tile_C C;

                mma_block_scaled(C, A[n][k01 / (2 * QI_MXFP4)], B, scaleA[n][k01 / (2 * QI_MXFP4)], scaleB);
#pragma unroll
                for (int l = 0; l < tile_C::ne; ++l) {
                    sum[(j0 / tile_C::J + n) * tile_C::ne + l] += C.x[l];
                }
            }
        }
    }
}

template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a(
    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
    constexpr int nwarps = mmq_get_nwarps_device();
    constexpr int warp_size = ggml_cuda_get_physical_warp_size();

    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
    const int * x_qs = (const int *) x;
    const half2 * x_dm = (const half2 *) x_qs + txs.qs;
    const int * y_qs = (const int *) y + 4;
    const half2 * y_ds = (const half2 *) y;

// #pragma unroll
    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += VDR_Q8_0_Q8_1_MMQ) {
        const int k0 = k00 + k01;

#pragma unroll
        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
            const int j = j0 + threadIdx.y;

#pragma unroll
            for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
                const int i = i0 + threadIdx.x;

                sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q8_1_q8_1_impl<QR5_1*VDR_Q5_1_Q8_1_MMQ>
                    (&x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
                     x_dm[i*(MMQ_TILE_NE_K/QI5_1) + i/QI5_1 + k0/QI8_1], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
            }
        }
    }
}

template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
    constexpr data_layout input_layout = get_input_data_layout();
    typedef tile<16, 8, int, input_layout> tile_A;
    typedef tile<16, 8, int, input_layout> tile_B;
    typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;

    constexpr int granularity = mmq_get_granularity_device(mmq_x);
    constexpr int rows_per_warp = granularity;
    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.

    y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);

    const int * x_qs = (const int *) x;
    const half2 * x_dm = (const half2 *) x_qs + 2*MMQ_TILE_NE_K;
    const int * y_qs = (const int *) y + 4;
    const half2 * y_dm = (const half2 *) y;

    const int i0 = (threadIdx.y / ntx) * rows_per_warp;

    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
        const int k0 = k00 + k01;

        tile_A A[ntx];
#pragma unroll
        for (int n = 0; n < ntx; ++n) {
            load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
        }

#pragma unroll
        for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
            tile_B B;
            load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);

            const int j = j0 + tile_C::get_j(0);
            const float2 dsB = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]);

#pragma unroll
            for (int n = 0; n < ntx; ++n) {
                tile_C C;
                mma(C, A[n], B);

#pragma unroll
                for (int l = 0; l < tile_C::ne; ++l) {
                    const int i = i0 + n*tile_A::I + tile_C::get_i(l);
                    float2 dmA = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + k0/QI8_1]);
                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA.x*dsB.x*C.x[l];
                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA.y*dsB.y;
                }
            }
        }
    }
#else
    typedef tile<16, 8, int> tile_A;
    typedef tile< 8, 8, int> tile_B;
    typedef tile<16, 8, int> tile_C;

    constexpr int granularity = mmq_get_granularity_device(mmq_x);
    constexpr int rows_per_warp = 2 * granularity;
    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.

    y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);

    const int * x_qs = (const int *) x;
    const half2 * x_dm = (const half2 *) x_qs + 2*MMQ_TILE_NE_K;
    const int * y_qs = (const int *) y + 4;
    const half2 * y_dm = (const half2 *) y;

    tile_A A[ntx][MMQ_TILE_NE_K/QI8_1];
    float2 dmA[ntx][tile_C::ne/2][MMQ_TILE_NE_K/QI8_1];

    const int i0 = (threadIdx.y/ntx)*rows_per_warp;

#pragma unroll
    for (int n = 0; n < ntx; ++n) {
#pragma unroll
        for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
            const int k0 = k00 + k01;

            load_ldmatrix(A[n][k01/QI8_1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
        }

#pragma unroll
        for (int l = 0; l < tile_C::ne/2; ++l) {
            const int i = i0 + n*tile_A::I + tile_C::get_i(2*l);

#pragma unroll
            for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
                const int k0 = k00 + k01;

                dmA[n][l][k01/QI8_1] = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + k0/QI8_1]);
            }
        }
    }

#pragma unroll
    for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
#pragma unroll
        for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
            tile_B B;
            float2 dsB[tile_C::ne/2];

            load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix

#pragma unroll
            for (int l = 0; l < tile_C::ne/2; ++l) {
                const int j = j0 + tile_C::get_j(l);

                dsB[l] = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]);
            }

#pragma unroll
            for (int n = 0; n < ntx; ++n) {
                tile_C C;
                mma(C, A[n][k01/QI8_1], B);

#pragma unroll
                for (int l = 0; l < tile_C::ne; ++l) {
                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA[n][l/2][k01/QI8_1].x*dsB[l%2].x*C.x[l];
                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA[n][l/2][k01/QI8_1].y*dsB[l%2].y;
                }
            }
        }
    }
#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
}

// Used for Q3_K, IQ2_S, and IQ2_XS
template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a(
    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
    constexpr int nwarps = mmq_get_nwarps_device();
    constexpr int warp_size = ggml_cuda_get_physical_warp_size();

    constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16;
    const int * x_qs = (const int *) x;
    const float * x_df = (const float *) x_qs + txs.qs;
    const int * y_qs = (const int *) y + 4;
    const float * y_df = (const float *) y;

// #pragma unroll
    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) {
        const int k0 = k00 + k01;

#pragma unroll
        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
            const int j = j0 + threadIdx.y;

#pragma unroll
            for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
                const int i = i0 + threadIdx.x;

                sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q8_0_16_q8_1_impl<QI8_0>(
                    &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0],
                    &y_qs[j*MMQ_TILE_Y_K + k01],
                    &x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + k0/(QI8_0/2)],
                    y_df[j*MMQ_TILE_Y_K + k01/QI8_1]);
            }
        }
    }
}

// Used for Q3_K, IQ2_S, and IQ2_XS
template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
#if defined(AMD_MFMA_AVAILABLE)
    constexpr data_layout input_layout = get_input_data_layout();
    typedef tile<16, 8, int, input_layout> tile_A;
    typedef tile<16, 8, int, input_layout> tile_B;
    typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
    typedef tile<64, 2, int, input_layout> tile_load;

    constexpr int granularity = mmq_get_granularity_device(mmq_x);
    constexpr int rows_per_warp = granularity;
    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.

    y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);

    const int * x_qs = (const int *) x;
    const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
    const int * y_qs = (const int *) y + 4;
    const float * y_df = (const float *) y;

    const int i0 = (threadIdx.y / ntx) * rows_per_warp;

    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
        const int k0 = k00 + k01;

        tile_A A[ntx];
#pragma unroll
        for (int n = 0; n < ntx; ++n) {
            load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
        }

#pragma unroll
        for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
            tile_B B[1];
            load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);

            const int j = j0 + tile_C::get_j(0);
            const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1] / 2;

#pragma unroll
            for (int n = 0; n < ntx; ++n) {
                tile_C C;
                mma(C, A[n], B[0]);

#pragma unroll
                for (int l = 0; l < tile_C::ne; ++l) {
                    const int i = i0 + n*tile_C::I + tile_C::get_i(l);
                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4] * dB;
                }
            }
        }
    }
#elif defined(AMD_WMMA_AVAILABLE) // WMMA instructions can handle 16x4 tiles, so loading 64x2 tiles is not required
    constexpr data_layout input_layout = get_input_data_layout();
    typedef tile<16, 4, int, input_layout> tile_A;
    typedef tile<16, 4, int, input_layout> tile_B;
    typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;

    constexpr int granularity = mmq_get_granularity_device(mmq_x);
    constexpr int rows_per_warp = granularity;
    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.

    y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);

    const int * x_qs = (const int *) x;
    const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
    const int * y_qs = (const int *) y + 4;
    const float * y_df = (const float *) y;

    const int i0 = (threadIdx.y / ntx) * rows_per_warp;

    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
        const int k0 = k00 + k01;

        tile_A A[ntx];
#pragma unroll
        for (int n = 0; n < ntx; ++n) {
            load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
        }

#pragma unroll
        for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
            tile_B B;
            load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);

            const int j = j0 + tile_C::get_j(0);
            const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];

#pragma unroll
            for (int n = 0; n < ntx; ++n) {
                tile_C C;
                mma(C, A[n], B);

#pragma unroll
                for (int l = 0; l < tile_C::ne; ++l) {
                    const int i = i0 + n*tile_C::I + tile_C::get_i(l);
                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4] * dB;
                }
            }
        }
    }
#elif defined(TURING_MMA_AVAILABLE)

    typedef tile<16, 4, int> tile_A;
    typedef tile<16, 8, int> tile_A_8;
    typedef tile< 8, 4, int> tile_B;
    typedef tile<16, 8, int> tile_C;

    constexpr int granularity = mmq_get_granularity_device(mmq_x);
    constexpr int rows_per_warp = 2 * granularity;
    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.

    y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);

    const int * x_qs = (const int *) x;
    const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
    const int * y_qs = (const int *) y + 4;
    const float * y_df = (const float *) y;

    const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I);

    tile_A A[ntx][8];
    float dA[ntx][tile_C::ne/2][8];

#pragma unroll
    for (int n = 0; n < ntx; ++n) {
#pragma unroll
        for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 8) {
            const int k0 = k00 + k01;

            load_ldmatrix(((tile_A_8 *) A[n])[k01/8], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
        }

#pragma unroll
        for (int l = 0; l < tile_C::ne/2; ++l) {
            const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);

#pragma unroll
            for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
                const int k0 = k00 + k01;

                dA[n][l][k01/4] = x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4];
            }
        }
    }

#pragma unroll
    for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
#pragma unroll
        for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) {
            tile_B B[2];
            float dB[tile_C::ne/2];

            // Here load_generic is faster than load_ldmatrix.
            load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + (k01 + 0),         MMQ_TILE_Y_K);
            load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + (k01 + tile_B::J), MMQ_TILE_Y_K);

#pragma unroll
            for (int l = 0; l < tile_C::ne/2; ++l) {
                const int j = j0 + tile_C::get_j(l);

                dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
            }

#pragma unroll
            for (int n = 0; n < ntx; ++n) {
                tile_C C[2];
                mma(C[0], A[n][k01/4 + 0], B[0]);
                mma(C[1], A[n][k01/4 + 1], B[1]);

#pragma unroll
                for (int l = 0; l < tile_C::ne; ++l) {
                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += dB[l%2]*(C[0].x[l]*dA[n][l/2][k01/4 + 0] + C[1].x[l]*dA[n][l/2][k01/4 + 1]);
                }
            }
        }
    }
#else
    GGML_UNUSED_VARS(x, y, sum, k00);
    NO_DEVICE_CODE;
#endif // AMD_MFMA_AVAILABLE || AMD_WMMA_AVAILABLE || TURING_MMA_AVAILABLE
1450}
1451
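// Loads a tile of q2_K data into shared memory. Each thread unpacks 2-bit quants into
// 8-bit values (4 per int). Per 16-value group a single half2 is stored that already
// contains the superblock (d, dmin) pair multiplied with the group's packed 4-bit
// scale and min, so the dot product kernels need only one half2 load per group.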
1452template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q2_K(
1453 const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
1454 constexpr int nwarps = mmq_get_nwarps_device();
1455
1456#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1457 int * x_qs = (int *) x_tile;
1458 half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
1459#else
1460 constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
1461 int * x_qs = (int *) x_tile;
1462 half2 * x_dm = (half2 *) (x_qs + txs.qs);
1463#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1464
1465 constexpr int threads_per_row = MMQ_ITER_K / (4 * QR2_K);
1466 constexpr int nrows = ggml_cuda_get_physical_warp_size() / threads_per_row;
1467 const int kqsx = threadIdx.x % threads_per_row;
1468
1469#pragma unroll
1470 for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
1471 int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
1472
1473 if (need_check) {
1474 i = min(i, i_max);
1475 }
1476
1477 const block_q2_K * bxi = (const block_q2_K *) x + kbx0 + i*stride;
1478
1479 const int x_ql_0 = get_int_b2(bxi->qs, kqsx);
1480
1481#pragma unroll
1482 for (int l = 0; l < QR2_K; ++l) {
1483 const int k = (kqsx/8)*32 + l*8 + kqsx % 8;
1484
1485 const int x_qs_k = (x_ql_0 >> (2*l)) & 0x03030303;
1486
1487#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1488 x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k] = x_qs_k;
1489#else
1490 x_qs[i*(2*MMQ_TILE_NE_K + 1) + k] = x_qs_k;
1491#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1492 }
1493
1494 const int sc_m = bxi->scales[kqsx];
1495#ifdef FAST_FP16_AVAILABLE
1496 const half2 x_dm_ik = __hmul2(bxi->dm, make_half2(sc_m & 0x0F, sc_m >> 4));
1497#else
1498 const float2 bxi_dmf = __half22float2(bxi->dm);
1499 const half2 x_dm_ik = make_half2(bxi_dmf.x*(sc_m & 0x0F), bxi_dmf.y*(sc_m >> 4));
1500#endif // FAST_FP16_AVAILABLE
1501
1502#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1503 x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + kqsx] = x_dm_ik;
1504#else
1505 x_dm[i*(MMQ_TILE_NE_K + 1) + kqsx] = x_dm_ik;
1506#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1507 }
1508}
1509
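// dp4a dot product for q2_K * q8_1. The per-column q8_1 scale pair is cached once in
// y_df; .x applies to the first half of the tile (k01 < MMQ_TILE_NE_K/2) and .y to the
// second half, matching the D2S6 scale layout used for q2_K.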
1510template <int mmq_x, int mmq_y>
1511static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
1512 const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
1513 constexpr int nwarps = mmq_get_nwarps_device();
1514 constexpr int warp_size = ggml_cuda_get_physical_warp_size();
1515
1516 constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
1517 const int * x_qs = (const int *) x;
1518 const half2 * x_dm = (const half2 *) x_qs + txs.qs;
1519 const int * y_qs = (const int *) y + 4;
1520 const half2 * y_ds = (const half2 *) y;
1521
1522 float2 y_df[mmq_x/nwarps];
1523#pragma unroll
1524 for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
1525 const int j = j0 + threadIdx.y;
1526
1527 y_df[j0/nwarps] = __half22float2(y_ds[j*MMQ_TILE_Y_K]);
1528 }
1529
1530#pragma unroll
1531 for (int k01 = 0; k01 < MMQ_TILE_NE_K/2; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) {
1532 const int k0 = k00 + k01;
1533
1534#pragma unroll
1535 for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
1536 const int j = j0 + threadIdx.y;
1537
1538#pragma unroll
1539 for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
1540 const int i = i0 + threadIdx.x;
1541
1542 constexpr int ns = 2;
1543 sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q2_K_q8_1_impl_mmq<ns>(
1544 &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
1545 &x_dm[i*(MMQ_TILE_NE_K + 1) + k0/4], k01 < MMQ_TILE_NE_K/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y,
1546 &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
1547 }
1548 }
1549 }
1550
    // Some compilers fail to unroll the loop over k01 if there is a conditional statement for ns in the inner loop.
    // As a workaround, 2 separate loops are used instead.
1553#pragma unroll
1554 for (int k01 = MMQ_TILE_NE_K/2; k01 < MMQ_TILE_NE_K; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) {
1555 const int k0 = k00 + k01;
1556
1557#pragma unroll
1558 for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
1559 const int j = j0 + threadIdx.y;
1560
1561#pragma unroll
1562 for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
1563 const int i = i0 + threadIdx.x;
1564
1565 constexpr int ns = 1;
1566 sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q2_K_q8_1_impl_mmq<ns>(
1567 &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
1568 &x_dm[i*(MMQ_TILE_NE_K + 1) + k0/4], k01 < MMQ_TILE_NE_K/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y,
1569 &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
1570 }
1571 }
1572 }
1573}
1574
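// mma dot product for q2_K * q8_1. Because q2_K values carry a per-group minimum, the
// result needs a -min * sum(y) correction: where the q8_1 tile stores partial sums they
// are used directly (sB); for the final quarter of the tile, which has none, the column
// sums are computed on the fly via an mma against an all-ones A tile (Cm).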
1575template <int mmq_x, int mmq_y>
1576static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
1577 const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
1578#if defined(AMD_MFMA_AVAILABLE)
1579 constexpr data_layout input_layout = get_input_data_layout();
1580 typedef tile<16, 8, int, input_layout> tile_A;
1581 typedef tile<16, 8, int, input_layout> tile_B;
1582 typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
1583 typedef tile<64, 2, int, input_layout> tile_load;
1584
1585 constexpr int granularity = mmq_get_granularity_device(mmq_x);
1586 constexpr int rows_per_warp = granularity;
1587 constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
1588
1589 y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
1590
1591 const int * x_qs = (const int *) x;
1592 const half2 * x_dm = (const half2 *) x_qs + MMQ_TILE_NE_K*2;
1593 const int * y_qs = (const int *) y + 4;
1594 const half2 * y_ds = (const half2 *) y;
1595
1596 const int i0 = (threadIdx.y / ntx) * rows_per_warp;
1597
1598 for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
1599 const int k0 = k00 + k01;
1600
1601 tile_A A[ntx];
1602#pragma unroll
1603 for (int n = 0; n < ntx; ++n) {
1604 load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
1605 }
1606
1607#pragma unroll
1608 for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
1609 tile_B B[1];
1610 load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
1611
1612 const int j = j0 + tile_C::get_j(0);
1613 const float dB = (k01 < MMQ_TILE_NE_K/2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K]).x/2 : __half22float2(y_ds[j*MMQ_TILE_Y_K]).y/2;
1614 const float sB = (k01 >= MMQ_TILE_NE_K * 3/4) ? 0
1615 : (((k01/4)%2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).y
1616 : __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).x);
1617
1618 tile_C Cm;
1619 if (k01 >= MMQ_TILE_NE_K * 3/4) {
1620 tile_A A1;
1621 A1.x[0] = 0x01010101;
1622 A1.x[1] = 0x01010101;
1623 mma(Cm, A1, B[0]);
1624 }
1625
1626#pragma unroll
1627 for (int n = 0; n < ntx; ++n) {
1628 tile_C Cd;
1629 mma(Cd, A[n], B[0]);
1630
1631#pragma unroll
1632 for (int l = 0; l < tile_C::ne; ++l) {
1633 const int i = i0 + n*tile_C::I + tile_C::get_i(l);
1634 const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/4]);
1635 float tmp = Cd.x[l]*dm.x;
1636 if (k01 >= MMQ_TILE_NE_K * 3/4) {
1637 tmp -= Cm.x[l]*dm.y;
1638 }
1639 sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*dB;
1640 sum[(j0/tile_C::J + n)*tile_C::ne + l] -= dm.y*sB;
1641 }
1642 }
1643 }
1644 }
#elif defined(AMD_WMMA_AVAILABLE) // WMMA instructions can handle 16x4 tiles and do not require loading 64x2 tiles.
1646 constexpr data_layout input_layout = get_input_data_layout();
1647 typedef tile<16, 4, int, input_layout> tile_A;
1648 typedef tile<16, 4, int, input_layout> tile_B;
1649 typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
1650
1651 constexpr int granularity = mmq_get_granularity_device(mmq_x);
1652 constexpr int rows_per_warp = granularity;
1653 constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
1654
1655 y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
1656
1657 const int * x_qs = (const int *) x;
1658 const half2 * x_dm = (const half2 *) x_qs + MMQ_TILE_NE_K*2;
1659 const int * y_qs = (const int *) y + 4;
1660 const half2 * y_ds = (const half2 *) y;
1661
1662 const int i0 = (threadIdx.y / ntx) * rows_per_warp;
1663
1664 for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
1665 const int k0 = k00 + k01;
1666
1667 tile_A A[ntx];
1668#pragma unroll
1669 for (int n = 0; n < ntx; ++n) {
1670 load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
1671 }
1672
1673#pragma unroll
1674 for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
1675 tile_B B;
1676 load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
1677
1678 const int j = j0 + tile_C::get_j(0);
1679 const float dB = (k01 < MMQ_TILE_NE_K/2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K]).x : __half22float2(y_ds[j*MMQ_TILE_Y_K]).y;
1680 const float sB = (k01 >= MMQ_TILE_NE_K * 3/4) ? 0
1681 : (((k01/4)%2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).y
1682 : __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).x);
1683
1684 tile_C Cm;
1685 if (k01 >= MMQ_TILE_NE_K * 3/4) {
1686 tile_A A1;
1687#pragma unroll
1688 for (int l = 0; l < tile_A::ne; ++l) {
1689 A1.x[l] = 0x01010101;
1690 }
1691 mma(Cm, A1, B);
1692 }
1693
1694#pragma unroll
1695 for (int n = 0; n < ntx; ++n) {
1696 tile_C Cd;
1697 mma(Cd, A[n], B);
1698
1699#pragma unroll
1700 for (int l = 0; l < tile_C::ne; ++l) {
1701 const int i = i0 + n*tile_C::I + tile_C::get_i(l);
1702 const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/4]);
1703 float tmp = Cd.x[l]*dm.x;
1704 if (k01 >= MMQ_TILE_NE_K * 3/4) {
1705 tmp -= Cm.x[l]*dm.y;
1706 }
1707 sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*dB;
1708 sum[(j0/tile_C::J + n)*tile_C::ne + l] -= dm.y*sB;
1709 }
1710 }
1711 }
1712 }
1713#elif defined(TURING_MMA_AVAILABLE)
1714
1715 typedef tile<16, 4, int> tile_A;
1716 typedef tile<16, 8, int> tile_A_8;
1717 typedef tile< 8, 4, int> tile_B;
1718 typedef tile<16, 8, int> tile_C;
1719
1720 constexpr int granularity = mmq_get_granularity_device(mmq_x);
1721 constexpr int rows_per_warp = 2 * granularity;
1722 constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
1723
1724 y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
1725
1726 const int * x_qs = (const int *) x;
1727 const half2 * x_dm = (const half2 *) x_qs + MMQ_TILE_NE_K*2;
1728 const int * y_qs = (const int *) y + 4;
1729 const half2 * y_ds = (const half2 *) y;
1730
1731 const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I);
1732
1733 tile_A A[ntx][8];
1734 float dA[ntx][tile_C::ne/2][8];
1735 float mA[ntx][tile_C::ne/2][8];
1736
1737#pragma unroll
1738 for (int n = 0; n < ntx; ++n) {
1739#pragma unroll
1740 for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
1741 const int k0 = k00 + k01;
1742
1743 load_ldmatrix(((tile_A_8 *) A[n])[k01/QI8_1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
1744 }
1745 }
1746
1747#pragma unroll
1748 for (int n = 0; n < ntx; ++n) {
1749#pragma unroll
1750 for (int l = 0; l < tile_C::ne/2; ++l) {
1751 const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);
1752
1753#pragma unroll
1754 for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1/2) {
1755 const int k0 = k00 + k01;
1756
1757 const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/(QI8_1/2)]);
1758
1759 dA[n][l][k01/(QI8_1/2)] = dm.x;
1760 mA[n][l][k01/(QI8_1/2)] = dm.y;
1761 }
1762 }
1763 }
1764
1765#pragma unroll
1766 for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
1767 float2 dB[tile_C::ne/2];
1768
1769#pragma unroll
1770 for (int l = 0; l < tile_C::ne/2; ++l) {
1771 const int j = j0 + tile_C::get_j(l);
1772
1773 dB[l] = __half22float2(y_ds[j*MMQ_TILE_Y_K]);
1774 }
1775
1776#pragma unroll
1777 for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
1778 tile_B B[2];
1779
1780 // Here load_generic is faster than load_ldmatrix.
1781 load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K);
1782 load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + (k01 + tile_B::J), MMQ_TILE_Y_K);
1783
1784 tile_C Cm[2];
1785 if (k01 >= MMQ_TILE_NE_K * 3/4) {
1786 tile_A A1;
1787 A1.x[0] = 0x01010101;
1788 A1.x[1] = 0x01010101;
1789 mma(Cm[0], A1, B[0]);
1790 mma(Cm[1], A1, B[1]);
1791 }
1792
1793#pragma unroll
1794 for (int n = 0; n < ntx; ++n) {
1795 tile_C Cd[2];
1796
1797 mma(Cd[0], A[n][k01/4 + 0], B[0]);
1798 mma(Cd[1], A[n][k01/4 + 1], B[1]);
1799
1800#pragma unroll
1801 for (int l = 0; l < tile_C::ne; ++l) {
1802 float tmp = Cd[0].x[l]*dA[n][l/2][k01/4 + 0] + Cd[1].x[l]*dA[n][l/2][k01/4 + 1];
1803 if (k01 >= MMQ_TILE_NE_K * 3/4) {
1804 tmp -= Cm[0].x[l]*mA[n][l/2][k01/4 + 0] + Cm[1].x[l]*mA[n][l/2][k01/4 + 1];
1805 }
1806 sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*(k01 < MMQ_TILE_NE_K/2 ? dB[l%2].x : dB[l%2].y);
1807 }
1808 }
1809 }
1810
1811#pragma unroll
1812 for (int k01 = 0; k01 < MMQ_TILE_NE_K * 3/4; k01 += QI8_1) {
1813 float2 sB[tile_C::ne/2];
1814
1815#pragma unroll
1816 for (int l = 0; l < tile_C::ne/2; ++l) {
1817 const int j = j0 + tile_C::get_j(l);
1818
1819 sB[l] = __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
1820 }
1821
1822#pragma unroll
1823 for (int n = 0; n < ntx; ++n) {
1824#pragma unroll
1825 for (int l = 0; l < tile_C::ne; ++l) {
1826 sum[(j0/tile_C::J + n)*tile_C::ne + l] -= mA[n][l/2][k01/4 + 0]*sB[l%2].x;
1827 sum[(j0/tile_C::J + n)*tile_C::ne + l] -= mA[n][l/2][k01/4 + 1]*sB[l%2].y;
1828 }
1829 }
1830 }
1831 }
1832#else
1833 GGML_UNUSED_VARS(x, y, sum, k00);
1834 NO_DEVICE_CODE;
#endif // AMD_MFMA_AVAILABLE || AMD_WMMA_AVAILABLE || TURING_MMA_AVAILABLE
1836}
1837
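// Loads a tile of q3_K data into shared memory. The 2 low bits of each quant come from
// qs and the high bit from hmask; __vsubss4 then recenters the 3-bit values to a signed
// range. For mma the 6-bit group scales are expanded to one float (d * sc) per group,
// while the dp4a layout keeps them packed in x_sc with d stored separately in x_df.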
1838template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q3_K(
1839 const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
1840 constexpr int nwarps = mmq_get_nwarps_device();
1841 constexpr int warp_size = ggml_cuda_get_physical_warp_size();
1842
1843#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1844 int * x_qs = (int *) x_tile;
1845 float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
1846#else
1847 constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y);
1848 int * x_qs = (int *) x_tile;
1849 float * x_df = (float *) (x_qs + txs.qs);
1850 int * x_sc = (int *) (x_df + txs.dm);
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1852
1853 constexpr int threads_per_row = MMQ_ITER_K / (4 * QR3_K);
1854 constexpr int nrows = warp_size / threads_per_row;
1855 const int kqsx = threadIdx.x % threads_per_row;
1856
1857#pragma unroll
1858 for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
1859 int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
1860
1861 if (need_check) {
1862 i = min(i, i_max);
1863 }
1864
1865 const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride;
1866
1867 const int x_ql_0 = get_int_b2(bxi->qs, kqsx);
1868 const int x_qh_0 = get_int_b2(bxi->hmask, kqsx % (QI3_K/2)) >> (4 * (kqsx / (QI3_K/2)));
1869
1870#pragma unroll
1871 for (int l = 0; l < QR3_K; ++l) {
1872 const int k = (kqsx/8)*32 + l*8 + kqsx % 8;
1873
1874 const int x_ql_k = (x_ql_0 >> (2*l)) & 0x03030303;
1875 const int x_qh_k = ((x_qh_0 >> l) << 2) & 0x04040404;
1876
1877 const int x_qs_k = __vsubss4(x_ql_k | x_qh_k, 0x04040404);
1878
1879#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1880 x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k] = x_qs_k;
1881#else
1882 x_qs[i*(2*MMQ_TILE_NE_K + 1) + k] = x_qs_k;
1883#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1884 }
1885 }
1886
1887 constexpr int rows_per_warp = warp_size / 4;
1888#pragma unroll
1889 for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
1890 int i = i0 + threadIdx.y*rows_per_warp + threadIdx.x/4;
1891
1892 if (need_check) {
1893 i = min(i, i_max);
1894 }
1895
1896 const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride;
1897
1898 const int ksc = threadIdx.x % 4;
1899
1900 const int ksc_low = ksc % (QI3_K/8);
1901 const int shift_low = 4 * (ksc / (QI3_K/8));
1902 const int sc_low = (get_int_b2(bxi->scales, ksc_low) >> shift_low) & 0x0F0F0F0F;
1903
1904 const int ksc_high = QI3_K/8;
1905 const int shift_high = 2 * ksc;
1906 const int sc_high = ((get_int_b2(bxi->scales, ksc_high) >> shift_high) << 4) & 0x30303030;
1907
1908 const int sc = __vsubss4(sc_low | sc_high, 0x20202020);
1909
1910#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1911 const int8_t * sc8 = (const int8_t *) ≻
1912 const float d = bxi->d;
1913
1914#pragma unroll
1915 for (int l = 0; l < int(sizeof(int)); ++l) {
1916 x_df[i*MMQ_MMA_TILE_X_K_Q3_K + sizeof(int)*ksc + l] = d*sc8[l];
1917 }
1918#else
1919 x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = sc;
1920#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1921 }
1922
1923#if !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE))
1924#pragma unroll
1925 for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) {
1926 int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y;
1927
1928 if (need_check) {
1929 i = min(i, i_max);
1930 }
1931
1932 const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride;
1933
1934 x_df[i] = bxi->d;
1935 }
#endif // !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE))
1937}
1938
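// dp4a dot product for q3_K * q8_1. The packed 6-bit group scales are reinterpreted as
// int8_t and handed to vec_dot_q3_K_q8_1_impl_mmq together with the per-row superblock
// scale from x_df.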
1939template <int mmq_x, int mmq_y>
1940static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a(
1941 const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
1942 constexpr int nwarps = mmq_get_nwarps_device();
1943 constexpr int warp_size = ggml_cuda_get_physical_warp_size();
1944
1945 constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y);
1946 const int * x_qs = (const int *) x;
1947 const float * x_df = (const float *) x_qs + txs.qs;
1948 const int * x_sc = (const int *) x_df + txs.dm;
1949 const int * y_qs = (const int *) y + 4;
1950 const float * y_df = (const float *) y;
1951
1952// #pragma unroll
1953 for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) {
1954 const int k0 = k00 + k01;
1955
1956#pragma unroll
1957 for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
1958 const int j = j0 + threadIdx.y;
1959
1960#pragma unroll
1961 for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
1962 const int i = i0 + threadIdx.x;
1963
1964 const int8_t * scales = ((const int8_t *) (x_sc + i*(MMQ_TILE_NE_K/8) + i/8)) + k0/4;
1965
1966 sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q3_K_q8_1_impl_mmq(
1967 &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], scales,
1968 x_df[i], y_df[j*MMQ_TILE_Y_K + k01/QI8_1]);
1969 }
1970 }
1971 }
1972}
1973
1974static __device__ __forceinline__ int unpack_scales_q45_K(const int * scales, const int ksc) {
1975 // scale arrangement after the following two lines:
1976 // - ksc == 0: sc0, sc1, sc2, sc3
1977 // - ksc == 1: sc4, sc5, sc6, sc7
1978 // - ksc == 2: m0, m1, m2, m3
1979 // - ksc == 3: m4, m5, m6, m7
1980 return ((scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F) | // lower 4 bits
1981 ((scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030); // upper 2 bits
1982}
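// For example, ksc == 1 takes the lower 4 bits from scales[2] (the shift
// 4 * (1 & 0) == 0 selects the low nibbles) and the upper 2 bits from scales[0] >> 2,
// reassembling sc4..sc7 as one 6-bit value per byte of the returned int.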
1983
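// Loads a tile of q4_K data into shared memory. For mma the low and high nibbles are
// split into separate 8-bit quants; for dp4a they stay packed. The mma path converts
// the packed 6-bit scales/mins into per-group half2 factors, pre-negating the min via
// dm * make_half2(1.0f, -1.0f) so downstream kernels can treat the min term with a
// uniform sign.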
1984template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q4_K(
1985 const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
1986 constexpr int nwarps = mmq_get_nwarps_device();
1987 constexpr int warp_size = ggml_cuda_get_physical_warp_size();
1988
1989#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1990 int * x_qs = (int *) x_tile;
1991 half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
1992#else
1993 constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y);
1994 int * x_qs = (int *) x_tile;
1995 half2 * x_dm = (half2 *) (x_qs + txs.qs);
1996 int * x_sc = (int *) (x_dm + txs.dm);
1997#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1998
1999 constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_K);
2000 constexpr int nrows = warp_size / threads_per_row;
2001 const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
2002
2003#pragma unroll
2004 for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
2005 int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
2006
2007 if (need_check) {
2008 i = min(i, i_max);
2009 }
2010
2011 const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;
2012 const int qs0 = get_int_b4(bxi->qs, txi);
2013
2014#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2015 x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 0] = (qs0 >> 0) & 0x0F0F0F0F;
2016 x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 8] = (qs0 >> 4) & 0x0F0F0F0F;
2017#else
2018 x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0;
2019#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2020 }
2021
2022#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2023 constexpr int rows_per_warp = warp_size / 2;
2024#pragma unroll
2025 for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
2026#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
        // On AMD an 'if' is needed here instead of '%' because warp_size == 64:
        // with '%' the excess threads would redo rows, causing double work and a
        // throughput loss (MI300X). On NVIDIA '%' is kept, since with the 'if'
        // an H100 loses about 100 t/s.
2030 int i = i0 + threadIdx.y*rows_per_warp + threadIdx.x/2;
2031 if (i < mmq_y) {
2032#else
2033 int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/2) % mmq_y;
2034 {
2035#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2036 if (need_check) {
2037 i = min(i, i_max);
2038 }
2039
2040 const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;
2041
2042 const int * scales = (const int *) bxi->scales;
2043 const int ksc = threadIdx.x % 2;
2044
2045 const int sc32 = unpack_scales_q45_K(scales, ksc + 0);
2046 const int m32 = unpack_scales_q45_K(scales, ksc + 2);
2047
2048 const uint8_t * sc8 = (const uint8_t *) &sc32;
2049 const uint8_t * m8 = (const uint8_t *) &m32;
2050
2051 const half2 dm = bxi->dm * make_half2(1.0f, -1.0f);
2052
#pragma unroll
            for (int l = 0; l < int(sizeof(int)); ++l) {
2055 x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]);
2056 }
2057 }
2058 }
2059#else
2060#pragma unroll
2061 for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) {
2062 int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y;
2063
2064 if (need_check) {
2065 i = min(i, i_max);
2066 }
2067
2068 const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;
2069
2070 x_dm[i] = bxi->dm;
2071 }
2072 constexpr int rows_per_warp = warp_size / 4;
2073#pragma unroll
2074 for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
2075 int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/(MMQ_TILE_NE_K/8)) % mmq_y;
2076
2077 if (need_check) {
2078 i = min(i, i_max);
2079 }
2080
2081 const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + (threadIdx.x % (MMQ_TILE_NE_K/8)) / (QI4_K/8);
2082
2083 const int * scales = (const int *) bxi->scales;
2084
2085 const int ksc = threadIdx.x % (MMQ_TILE_NE_K/8);
2086 const int scales8 = unpack_scales_q45_K(scales, ksc);
2087
2088 x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8;
2089 }
2090#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2091}
2092
2093template <int mmq_x, int mmq_y>
2094static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
2095 const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
2096 constexpr int nwarps = mmq_get_nwarps_device();
2097 constexpr int warp_size = ggml_cuda_get_physical_warp_size();
2098
2099 constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y);
2100 const int * x_qs = (const int *) x;
2101 const half2 * x_dm = (const half2 *) x_qs + txs.qs;
2102 const int * x_sc = (const int *) x_dm + txs.dm;
2103 const int * y_qs = (const int *) y + 4;
2104 const half2 * y_ds = (const half2 *) y;
2105
2106// #pragma unroll
2107 for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR4_K*VDR_Q4_K_Q8_1_MMQ) {
2108 const int k0 = k00 + k01;
2109
2110#pragma unroll
2111 for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
2112 const int j = j0 + threadIdx.y;
2113
2114#pragma unroll
2115 for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
2116 const int i = i0 + threadIdx.x;
2117
2118 const uint8_t * sc = (const uint8_t *) &x_sc[i * (MMQ_TILE_NE_K/8) + i/8 + k0/32] + 2*(k01/16);
2119
2120 sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_K_q8_1_impl_mmq(
2121 &x_qs[i*(MMQ_TILE_NE_K + 1) + k0/2], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8,
2122 x_dm[i], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
2123 }
2124 }
2125 }
2126}
2127
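// Loads a tile of q5_K data into shared memory. Each 5-bit quant combines a low nibble
// from qs with one bit from qh shifted into bit 4; kq0/kq1 place the two nibble planes
// QI5_K/4 ints apart so the unpacked values land contiguously. Scale/min handling
// matches load_tiles_q4_K.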
2128template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q5_K(
2129 const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
2130 constexpr int nwarps = mmq_get_nwarps_device();
2131 constexpr int warp_size = ggml_cuda_get_physical_warp_size();
2132
2133#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2134 int * x_qs = (int *) x_tile;
2135 half2 * x_dm = (half2 *) (x_qs + MMQ_TILE_NE_K*2);
2136#else
2137 constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y);
2138 int * x_qs = (int *) x_tile;
2139 half2 * x_dm = (half2 *) (x_qs + txs.qs);
2140 int * x_sc = (int *) (x_dm + txs.dm);
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2142
2143 constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_K);
2144 constexpr int nrows = warp_size / threads_per_row;
2145 const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
2146
2147#pragma unroll
2148 for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
2149 int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
2150
2151 if (need_check) {
2152 i = min(i, i_max);
2153 }
2154
2155 const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;
2156 const int ky = QR5_K*txi;
2157
2158 const int ql = get_int_b4(bxi->qs, txi);
2159 const int ql0 = (ql >> 0) & 0x0F0F0F0F;
2160 const int ql1 = (ql >> 4) & 0x0F0F0F0F;
2161
2162 const int qh = get_int_b4(bxi->qh, txi % (QI5_K/4));
2163 const int qh0 = ((qh >> (2 * (txi / (QI5_K/4)) + 0)) << 4) & 0x10101010;
2164 const int qh1 = ((qh >> (2 * (txi / (QI5_K/4)) + 1)) << 4) & 0x10101010;
2165
2166 const int kq0 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + 0;
2167 const int kq1 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + QI5_K/4;
2168
2169#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2170 x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq0] = ql0 | qh0;
2171 x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq1] = ql1 | qh1;
2172#else
2173 x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = ql0 | qh0;
2174 x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = ql1 | qh1;
2175#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2176 }
2177
2178#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2179 constexpr int rows_per_warp = warp_size / 2;
2180#pragma unroll
2181 for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
2182#if defined(AMD_MFMA_AVAILABLE)
        // On AMD an 'if' is needed here instead of '%' because warp_size == 64:
        // with '%' the excess threads would redo rows, causing double work and a
        // throughput loss (MI300X). On NVIDIA '%' is kept, since with the 'if'
        // an H100 loses about 100 t/s.
2186 int i = i0 + threadIdx.y*rows_per_warp + threadIdx.x/2;
2187 if (i < mmq_y) {
2188#else
2189 int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/2) % mmq_y;
2190 {
#endif // defined(AMD_MFMA_AVAILABLE)
2192 if (need_check) {
2193 i = min(i, i_max);
2194 }
2195
2196 const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;
2197
2198 const int * scales = (const int *) bxi->scales;
2199 const int ksc = threadIdx.x % 2;
2200
2201 const int sc32 = unpack_scales_q45_K(scales, ksc + 0);
2202 const int m32 = unpack_scales_q45_K(scales, ksc + 2);
2203
2204 const uint8_t * sc8 = (const uint8_t *) &sc32;
2205 const uint8_t * m8 = (const uint8_t *) &m32;
2206
2207 const half2 dm = bxi->dm * make_half2(1.0f, -1.0f);
2208
2209#pragma unroll
2210 for (int l = 0; l < int(sizeof(int)); ++l) {
2211 x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]);
2212 }
2213 }
2214 }
2215#else
2216#pragma unroll
2217 for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) {
2218 int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y;
2219
2220 if (need_check) {
2221 i = min(i, i_max);
2222 }
2223
2224 const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;
2225
2226 x_dm[i] = bxi->dm;
2227 }
2228
2229 constexpr int rows_per_warp = warp_size / 4;
2230#pragma unroll
2231 for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
2232 int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/(MMQ_TILE_NE_K/8)) % mmq_y;
2233
2234 if (need_check) {
2235 i = min(i, i_max);
2236 }
2237
2238 const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;
2239
2240 const int * scales = (const int *) bxi->scales;
2241
2242 const int ksc = threadIdx.x % (MMQ_TILE_NE_K/8);
2243 const int scales8 = unpack_scales_q45_K(scales, ksc);
2244
2245 x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8;
2246 }
2247#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2248}
2249
2250template <int mmq_x, int mmq_y>
2251static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
2252 const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
2253 constexpr int nwarps = mmq_get_nwarps_device();
2254 constexpr int warp_size = ggml_cuda_get_physical_warp_size();
2255
2256 constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y);
2257 const int * x_qs = (const int *) x;
2258 const half2 * x_dm = (const half2 *) x_qs + txs.qs;
2259 const int * x_sc = (const int *) x_dm + txs.dm;
2260 const int * y_qs = (const int *) y + 4;
2261 const half2 * y_ds = (const half2 *) y;
2262
2263// #pragma unroll
2264 for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR5_K*VDR_Q5_K_Q8_1_MMQ) {
2265 const int k0 = k00 + k01;
2266
2267#pragma unroll
2268 for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
2269 const int j = j0 + threadIdx.y;
2270
2271#pragma unroll
2272 for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
2273 const int i = i0 + threadIdx.x;
2274
2275 const uint8_t * sc = ((const uint8_t *) &x_sc[i * (MMQ_TILE_NE_K/8) + i/8 + k00/32]) + 2*(k01/16);
2276
2277 sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q5_K_q8_1_impl_mmq(
2278 &x_qs[i*(QR5_K*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8,
2279 x_dm[i], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
2280 }
2281 }
2282 }
2283}
2284
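// Loads a tile of q6_K data into shared memory. Each 6-bit quant combines a low nibble
// from ql with 2 bits from qh shifted into bits 4..5, then __vsubss4 recenters it to
// [-32, 31]. q6_K stores one float superblock scale (x_df) plus packed 8-bit group
// scales (x_sc).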
2285template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q6_K(
2286 const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
2287 constexpr int nwarps = mmq_get_nwarps_device();
2288 constexpr int warp_size = ggml_cuda_get_physical_warp_size();
2289
2290#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2291 int * x_qs = (int *) x_tile;
2292 float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
2293 int * x_sc = (int *) (x_df + MMQ_TILE_NE_K/QI6_K);
2294#else
2295 constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y);
2296 int * x_qs = (int *) x_tile;
2297 float * x_df = (float *) (x_qs + txs.qs);
2298 int * x_sc = (int *) (x_df + txs.dm);
2299#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2300
2301 constexpr int threads_per_row = MMQ_ITER_K / (4 * QR6_K);
2302 constexpr int nrows = warp_size / threads_per_row;
2303 const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
2304
2305#pragma unroll
2306 for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
2307 int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
2308
2309 if (need_check) {
2310 i = min(i, i_max);
2311 }
2312
2313 const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride;
2314
2315 const int ql = get_int_b2(bxi->ql, txi);
2316 const int ql0 = (ql >> 0) & 0x0F0F0F0F;
2317 const int ql1 = (ql >> 4) & 0x0F0F0F0F;
2318
2319 const int qh = get_int_b2(bxi->qh, (QI6_K/4) * (txi / (QI6_K/2)) + txi % (QI6_K/4));
2320 const int qh0 = ((qh >> ((txi & 0x08) >> 2)) << 4) & 0x30303030;
2321 const int qh1 = (qh >> ((txi & 0x08) >> 2)) & 0x30303030;
2322
2323 const int kq0 = 2*txi - txi % (QI6_K/2) + 0;
2324 const int kq1 = 2*txi - txi % (QI6_K/2) + QI6_K/2;
2325
2326#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2327 x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
2328 x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
2329#else
2330 x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
2331 x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
2332#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2333 }
2334
2335#pragma unroll
2336 for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) {
2337 int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y;
2338
2339 if (need_check) {
2340 i = min(i, i_max);
2341 }
2342
2343 const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride;
2344
2345#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2346 x_df[i*MMQ_MMA_TILE_X_K_Q6_K] = bxi->d;
2347#else
2348 x_df[i*(MMQ_TILE_NE_K/QI6_K) + i/QI6_K] = bxi->d;
2349#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2350 }
2351
2352 constexpr int rows_per_warp = warp_size / 4;
2353#pragma unroll
2354 for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
2355 int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/(MMQ_TILE_NE_K/8)) % mmq_y;
2356
2357 if (need_check) {
2358 i = min(i, i_max);
2359 }
2360
2361 const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (MMQ_TILE_NE_K/8)) / 4;
2362
2363#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2364 x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x%4] = get_int_b2(bxi->scales, threadIdx.x % (MMQ_TILE_NE_K/8));
2365#else
2366 x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + threadIdx.x%(MMQ_TILE_NE_K/8)] = get_int_b2(bxi->scales, threadIdx.x%(QI6_K/8));
2367#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2368 }
2369}
2370
2371template <int mmq_x, int mmq_y>
2372static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
2373 const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
2374 constexpr int nwarps = mmq_get_nwarps_device();
2375 constexpr int warp_size = ggml_cuda_get_physical_warp_size();
2376
2377 constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y);
2378 const int * x_qs = (const int *) x;
2379 const float * x_df = (const float *) x_qs + txs.qs;
2380 const int * x_sc = (const int *) x_df + txs.dm;
2381 const int * y_qs = (const int *) y + 4;
2382 const float * y_df = (const float *) y;
2383
2384// #pragma unroll
2385 for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR6_K*VDR_Q6_K_Q8_1_MMQ) {
2386 const int k0 = k00 + k01;
2387
2388#pragma unroll
2389 for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
2390 const int j = j0 + threadIdx.y;
2391
2392#pragma unroll
2393 for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
2394 const int i = i0 + threadIdx.x;
2395
2396 const int8_t * sc = ((const int8_t *) &x_sc[i * (MMQ_TILE_NE_K/8) + i/8 + k0/16]);
2397
2398 sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q6_K_q8_1_impl_mmq(
2399 &x_qs[i*(QR6_K*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc,
2400 x_df[i*(MMQ_TILE_NE_K/QI6_K) + i/QI6_K], &y_df[j*MMQ_TILE_Y_K + k01/QI8_1]);
2401 }
2402 }
2403 }
2404}
2405
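// mma dot product for q6_K * q8_1. On Turing the 8-bit group scales are unpacked into
// registers (scA) once per tile, and the products are first accumulated into tmp so
// that the single superblock scale dA is applied only once per column block.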
2406template <int mmq_x, int mmq_y>
2407static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
2408 const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
2409#if defined(AMD_MFMA_AVAILABLE)
2410 constexpr data_layout input_layout = get_input_data_layout();
2411 typedef tile<16, 8, int, input_layout> tile_A;
2412 typedef tile<16, 8, int, input_layout> tile_B;
2413 typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
2414 typedef tile<64, 2, int, input_layout> tile_load;
2415
2416 constexpr int granularity = mmq_get_granularity_device(mmq_x);
2417 constexpr int rows_per_warp = granularity;
2418 constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
2419
2420 y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
2421
2422 const int * x_qs = (const int *) x;
2423 const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
2424 const int * x_sc = (const int *) x_df + MMQ_TILE_NE_K/QI6_K;
2425 const int * y_qs = (const int *) y + 4;
2426 const float * y_df = (const float *) y;
2427
2428 const int i0 = (threadIdx.y / ntx) * rows_per_warp;
2429
2430 for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
2431 const int k0 = k00 + k01;
2432
2433 tile_A A[ntx];
2434#pragma unroll
2435 for (int n = 0; n < ntx; ++n) {
2436 load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K);
2437 }
2438
2439#pragma unroll
2440 for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
2441 tile_B B[1];
2442 load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
2443
2444 const int j = j0 + tile_C::get_j(0);
2445 const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1] / 2;
2446
2447#pragma unroll
2448 for (int n = 0; n < ntx; ++n) {
2449 tile_C C;
2450 mma(C, A[n], B[0]);
2451
2452#pragma unroll
2453 for (int l = 0; l < tile_C::ne; ++l) {
2454 const int i = i0 + n*tile_C::I + tile_C::get_i(l);
2455 const int8_t * sc = (const int8_t *) (x_sc + i*MMQ_MMA_TILE_X_K_Q6_K + k00/16);
2456 sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * sc[k01/4] * x_df[i*MMQ_MMA_TILE_X_K_Q6_K] * dB;
2457 }
2458 }
2459 }
2460 }
#elif defined(AMD_WMMA_AVAILABLE) // WMMA instructions can handle 16x4 tiles and do not require loading 64x2 tiles.
2462 constexpr data_layout input_layout = get_input_data_layout();
2463 typedef tile<16, 4, int, input_layout> tile_A;
2464 typedef tile<16, 4, int, input_layout> tile_B;
2465 typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
2466
2467 constexpr int granularity = mmq_get_granularity_device(mmq_x);
2468 constexpr int rows_per_warp = granularity;
2469 constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
2470
2471 y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
2472
2473 const int * x_qs = (const int *) x;
2474 const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
2475 const int * x_sc = (const int *) x_df + MMQ_TILE_NE_K/QI6_K;
2476 const int * y_qs = (const int *) y + 4;
2477 const float * y_df = (const float *) y;
2478
2479 const int i0 = (threadIdx.y / ntx) * rows_per_warp;
2480
2481 for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
2482 const int k0 = k00 + k01;
2483
2484 tile_A A[ntx];
2485#pragma unroll
2486 for (int n = 0; n < ntx; ++n) {
2487 load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K);
2488 }
2489
2490#pragma unroll
2491 for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
2492 tile_B B;
2493 load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
2494
2495 const int j = j0 + tile_C::get_j(0);
2496 const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
2497
2498#pragma unroll
2499 for (int n = 0; n < ntx; ++n) {
2500 tile_C C;
2501 mma(C, A[n], B);
2502
2503#pragma unroll
2504 for (int l = 0; l < tile_C::ne; ++l) {
2505 const int i = i0 + n*tile_C::I + tile_C::get_i(l);
2506 const int8_t * sc = (const int8_t *) (x_sc + i*MMQ_MMA_TILE_X_K_Q6_K + k00/16);
2507 sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * sc[k01/4] * x_df[i*MMQ_MMA_TILE_X_K_Q6_K] * dB;
2508 }
2509 }
2510 }
2511 }
2512#elif defined(TURING_MMA_AVAILABLE)
2513
2514 typedef tile<16, 4, int> tile_A;
2515 typedef tile< 8, 4, int> tile_B;
2516 typedef tile<16, 8, int> tile_C;
2517
2518 constexpr int granularity = mmq_get_granularity_device(mmq_x);
2519 constexpr int rows_per_warp = 2 * granularity;
2520 constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
2521
2522 y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
2523
2524 const int * x_qs = (const int *) x;
2525 const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
2526 const int * x_sc = (const int *) x_df + MMQ_TILE_NE_K/QI6_K;
2527 const int * y_qs = (const int *) y + 4;
2528 const float * y_df = (const float *) y;
2529
2530 const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I);
2531
2532 tile_A A[ntx][8];
2533 int scA[ntx][tile_C::ne/2][8];
2534 float dA[ntx][tile_C::ne/2];
2535
2536#pragma unroll
2537 for (int n = 0; n < ntx; ++n) {
2538#pragma unroll
2539 for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 8) {
2540 const int k0 = k00 + k01;
2541
2542 load_ldmatrix(A[n][k01/4 + 0], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + 0), MMQ_MMA_TILE_X_K_Q6_K);
2543 load_ldmatrix(A[n][k01/4 + 1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + tile_A::J), MMQ_MMA_TILE_X_K_Q6_K);
2544 }
2545
2546#pragma unroll
2547 for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 16) {
2548 const int k0 = k00 + k01;
2549
2550#pragma unroll
2551 for (int l = 0; l < tile_C::ne/2; ++l) {
2552 const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);
2553
2554 const int sc_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + k0/16];
2555 const int8_t * sc = (const int8_t *) &sc_packed;
2556
2557#pragma unroll
                for (int ksc = 0; ksc < int(sizeof(int)); ++ksc) {
2559 scA[n][l][k01/4 + ksc] = sc[ksc];
2560 }
2561 }
2562 }
2563
2564#pragma unroll
2565 for (int l = 0; l < tile_C::ne/2; ++l) {
2566 const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);
2567
2568 dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q6_K];
2569 }
2570 }
2571
2572#pragma unroll
2573 for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
2574 float tmp[ntx][tile_C::ne] = {{0.0f}};
2575
2576#pragma unroll
2577 for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 8) {
2578 tile_B B[2];
2579 float dB[tile_C::ne/2];
2580
2581 // Here load_generic is faster than load_ldmatrix.
2582 load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + 0 + k01, MMQ_TILE_Y_K);
2583 load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + tile_B::J + k01, MMQ_TILE_Y_K);
2584
2585#pragma unroll
2586 for (int l = 0; l < tile_C::ne/2; ++l) {
2587 const int j = j0 + tile_C::get_j(l);
2588
2589 dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
2590 }
2591
2592#pragma unroll
2593 for (int n = 0; n < ntx; ++n) {
2594 tile_C C[2];
2595 mma(C[0], A[n][k01/4 + 0], B[0]);
2596 mma(C[1], A[n][k01/4 + 1], B[1]);
2597
2598#pragma unroll
2599 for (int l = 0; l < tile_C::ne; ++l) {
2600 tmp[n][l] += (C[0].x[l]*scA[n][l/2][k01/4 + 0] + C[1].x[l]*scA[n][l/2][k01/4 + 1])*dB[l%2];
2601 }
2602 }
2603 }
2604
2605#pragma unroll
2606 for (int n = 0; n < ntx; ++n) {
2607#pragma unroll
2608 for (int l = 0; l < tile_C::ne; ++l) {
2609 sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp[n][l]*dA[n][l/2];
2610 }
2611 }
2612 }
2613#else
2614 GGML_UNUSED_VARS(x, y, sum, k00);
2615 NO_DEVICE_CODE;
#endif // AMD_MFMA_AVAILABLE || AMD_WMMA_AVAILABLE || TURING_MMA_AVAILABLE
2617}
2618
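// Loads a tile of iq4_nl data into shared memory. The 4-bit indices are expanded to
// 8-bit values through the nonlinear kvalues_iq4nl table via get_int_from_table_16;
// from then on the tile uses the same layout as q8_0 (MMQ_MMA_TILE_X_K_Q8_0).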
2619template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_nl(
2620 const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
2621 constexpr int nwarps = mmq_get_nwarps_device();
2622 constexpr int warp_size = ggml_cuda_get_physical_warp_size();
2623
2624#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2625 int * x_qs = (int *) x_tile;
2626 float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
2627#else
2628 constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_NL, mmq_y);
2629 int * x_qs = (int *) x_tile;
2630 float * x_df = (float *) (x_qs + txs.qs);
2631#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2632
2633 constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_NL);
2634 constexpr int nrows = warp_size / threads_per_row;
2635 const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
2636 const int kbx = txi / QI4_NL;
2637 const int kqsx = txi % QI4_NL;
2638
2639#pragma unroll
2640 for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
2641 int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
2642
2643 if (need_check) {
2644 i = min(i, i_max);
2645 }
2646
2647 const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbx;
2648
2649 const int aux_q4 = get_int_b2(bxi->qs, kqsx);
2650 const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
2651 const int k0 = kbx * (2 * QI4_NL) + kqsx;
2652
2653#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2654 x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
2655 x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + QI4_NL] = v.y;
2656#else
2657 x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x;
2658 x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + QI4_NL] = v.y;
2659#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2660 }
2661
2662 constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_NL;
2663 constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
2664 const int kbxd = threadIdx.x % blocks_per_tile_x_row;
2665
2666#pragma unroll
2667 for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
2668 int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;
2669
2670 if (need_check) {
2671 i = min(i, i_max);
2672 }
2673
2674 const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbxd;
2675
2676#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2677 x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = __half2float(bxi->d);
2678#else
2679 x_df[i*(MMQ_TILE_NE_K/QI4_NL) + i/QI4_NL + kbxd] = __half2float(bxi->d);
2680#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2681 }
2682}
2683
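// Loads a tile of iq2_xxs data into shared memory. Each 8-bit index selects 8 values
// from iq2xxs_grid; the 7-bit sign mask is expanded to byte masks with __vcmpne4 and
// applied as (x ^ s) - s, a branchless per-byte negation. The 4-bit scale in the top
// of aux32 becomes (ls*d + d/2)/4, the d/2 reflecting the format's 0.5 scale offset.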
2684template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_xxs(
2685 const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
2686 constexpr int nwarps = mmq_get_nwarps_device();
2687 constexpr int warp_size = ggml_cuda_get_physical_warp_size();
2688
2689#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2690 int * x_qs = (int *) x_tile;
2691 float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
2692#else
2693 constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_XXS, mmq_y);
2694 int * x_qs = (int *) x_tile;
2695 float * x_df = (float *) (x_qs + txs.qs);
2696#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2697
2698 constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XXS)) / 2;
2699 constexpr int nrows = warp_size / threads_per_row;
2700 const int kqsx = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
2701
2702#pragma unroll
2703 for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) {
2704 int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
2705
2706 if (need_check) {
2707 i = min(i, i_max);
2708 }
2709
2710 const block_iq2_xxs * bxi = (const block_iq2_xxs *) x + kbx0 + i*stride;
2711
2712 const int q2 = get_int_b2(bxi->qs, 2*kqsx+0);
2713 const uint8_t * aux8 = (const uint8_t *) &q2;
2714 const uint32_t aux32 = get_int_b2(bxi->qs, 2*kqsx+1);
2715
2716#pragma unroll
2717 for (int l = 0; l < QR2_XXS; ++l) {
2718 const int * grid_pos = (const int *) (iq2xxs_grid + aux8[l]);
2719 const int signs_packed = ksigns_iq2xs[(aux32 >> (7*l)) & 0x7F];
2720
2721 const int signs0 = __vcmpne4(((signs_packed & 0x03) << 7) | ((signs_packed & 0x0C) << 21), 0x00000000);
2722 const int grid0 = __vsub4(grid_pos[0] ^ signs0, signs0);
2723
2724 const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000);
2725 const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1);
2726
2727#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2728 x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid0;
2729 x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid1;
2730#else
2731 x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid0;
2732 x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid1;
2733#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2734 }
2735
2736 const int ls = aux32 >> 28;
2737 const float d = bxi->d;
2738#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2739 x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/4;
2740#else
2741 x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/4;
2742#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2743 }
2744}
2745
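// Loads a tile of iq2_xs data into shared memory. Like iq2_xxs, but each 16-bit entry
// holds a 9-bit grid index plus a 7-bit index into ksigns64, and there are two 4-bit
// scales per 32 values, stored with the q3_K tile stride.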
2746template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_xs(
2747 const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
2748 constexpr int nwarps = mmq_get_nwarps_device();
2749 constexpr int warp_size = ggml_cuda_get_physical_warp_size();
2750
2751#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2752 int * x_qs = (int *) x_tile;
2753 float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
2754#else
2755 constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16;
2756 int * x_qs = (int *) x_tile;
2757 float * x_df = (float *) (x_qs + txs.qs);
2758#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2759
2760 constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XS)) / 2;
2761 constexpr int nrows = warp_size / threads_per_row;
2762 const int kqsx = threadIdx.x % threads_per_row;
2763
2764#pragma unroll
2765 for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) {
2766 int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
2767
2768 if (need_check) {
2769 i = min(i, i_max);
2770 }
2771
2772 const block_iq2_xs * bxi = (const block_iq2_xs *) x + kbx0 + i*stride;
2773
2774 const int2 q2_packed = make_int2(get_int_b2(bxi->qs, 2*kqsx+0), get_int_b2(bxi->qs, 2*kqsx+1));
2775 const uint16_t * q2 = (const uint16_t *) &q2_packed;
2776
#pragma unroll
2778 for (int l = 0; l < QR2_XS; ++l) {
2779 const uint32_t * grid_pos = (const uint32_t *)(iq2xs_grid + (q2[l] & 0x000001FF));
2780 const uint32_t * signs = (const uint32_t *)(ksigns64 + (q2[l] >> 9));
2781
2782 const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]);
2783 const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]);
2784
2785#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2786 x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
2787 x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h;
2788#else
2789 x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l;
2790 x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h;
2791#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2792 }
2793
2794 const int ls = bxi->scales[kqsx];
2795 const float d = bxi->d;
2796#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2797 x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
2798 x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
2799#else
2800 x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
2801 x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
2802#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2803 }
2804}
2805
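// Loads a tile of iq2_s data into shared memory. The grid index takes 8 bits from qs
// plus 2 high bits from qh, and the sign bytes are stored explicitly in the second
// half of qs (offset QK_K/32 ints) instead of going through a ksigns lookup.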
2806template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_s(
2807 const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
2808 constexpr int nwarps = mmq_get_nwarps_device();
2809 constexpr int warp_size = ggml_cuda_get_physical_warp_size();
2810
2811#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2812 int * x_qs = (int *) x_tile;
2813 float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
2814#else
2815 constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_S, mmq_y);
2816 int * x_qs = (int *) x_tile;
2817 float * x_df = (float *) (x_qs + txs.qs);
2818#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2819 constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_S)) / 2;
2820 constexpr int nrows = warp_size / threads_per_row;
2821 const int kqsx = threadIdx.x % threads_per_row;
2822
2823#pragma unroll
2824 for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) {
2825 int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
2826
2827 if (need_check) {
2828 i = min(i, i_max);
2829 }
2830
2831 const block_iq2_s * bxi = (const block_iq2_s *) x + kbx0 + i*stride;
2832
2833 const int qs_packed = get_int_b2(bxi->qs, kqsx);
2834 const uint8_t * qs = (const uint8_t *) &qs_packed;
2835
2836 const int qh = bxi->qh[kqsx];
2837
2838 const int signs_packed_32 = get_int_b2(bxi->qs, QK_K/32 + kqsx);
2839 const uint8_t * signs_packed_8 = (const uint8_t *) &signs_packed_32;
2840
2841#pragma unroll
2842 for (int l = 0; l < QR2_S; ++l) {
2843 const int * grid_pos = (const int *)(iq2s_grid + (qs[l] | ((qh << (8-2*l)) & 0x300)));
2844
2845 const int signs0 = __vcmpne4(((signs_packed_8[l] & 0x03) << 7) | ((signs_packed_8[l] & 0x0C) << 21), 0x00000000);
2846 const int signs1 = __vcmpne4(((signs_packed_8[l] & 0x30) << 3) | ((signs_packed_8[l] & 0xC0) << 17), 0x00000000);
2847
2848 const int grid_l = __vsub4(grid_pos[0] ^ signs0, signs0);
2849 const int grid_h = __vsub4(grid_pos[1] ^ signs1, signs1);
2850
2851#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2852 x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
2853 x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h;
2854#else
2855 x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l;
2856 x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h;
2857#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2858 }
2859
2860 const int ls = bxi->scales[kqsx];
2861 const float d = bxi->d;
2862#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2863 x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
2864 x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
2865#else
2866 x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
2867 x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
2868#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2869 }
2870}
2871
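// Loads a tile of iq3_xxs data into shared memory. Two 8-bit indices per step select
// 4 values each from iq3xxs_grid, with the same sign trick as iq2_xxs; the scale
// folds the format's 0.5 offset in as (ls*d + d/2)/2.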
2872template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq3_xxs(
2873 const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
2874 constexpr int nwarps = mmq_get_nwarps_device();
2875 constexpr int warp_size = ggml_cuda_get_physical_warp_size();
2876
2877#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2878 int * x_qs = (int *) x_tile;
2879 float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
2880#else
2881 constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_XXS, mmq_y);
2882 int * x_qs = (int *) x_tile;
2883 float * x_df = (float *) (x_qs + txs.qs);
2884#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2885
2886 constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_XXS)) / 2;
2887 constexpr int nrows = warp_size / threads_per_row;
2888 const int kqsx = threadIdx.x % threads_per_row;
2889
2890#pragma unroll
2891 for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) {
2892 int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
2893
2894 if (need_check) {
2895 i = min(i, i_max);
2896 }
2897
2898 const block_iq3_xxs * bxi = (const block_iq3_xxs *) x + kbx0 + i*stride;
2899
2900 const int2 q3_packed = make_int2(get_int_b2(bxi->qs, 2*kqsx+0), get_int_b2(bxi->qs, 2*kqsx+1));
2901 const uint8_t * q3 = (const uint8_t *) &q3_packed;
2902 const uint32_t aux32 = get_int_b2(bxi->qs, QK_K/16 + kqsx);
2903
2904#pragma unroll
2905 for (int l = 0; l < QR3_XXS; ++l) {
2906 const int2 grid_pos = make_int2(iq3xxs_grid[q3[2*l+0]], iq3xxs_grid[q3[2*l+1]]);
2907
2908 const int * signs = (const int *)(ksigns64 + ((aux32 >> (7*l)) & 0x7F));
2909
2910 const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]);
2911 const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]);
2912
2913#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2914 x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid_l;
2915 x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid_h;
2916#else
2917 x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l;
2918 x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h;
2919#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2920 }
2921
2922 const int ls = aux32 >> 28;
2923 const float d = bxi->d;
2924#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2925 x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/2;
2926#else
2927 x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/2;
2928#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2929 }
2930}
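
// The top 4 bits of aux32 hold the block scale ls; the low 28 bits are the four 7-bit ksigns64
// indices consumed in the loop above. The stored scale (ls*d + d/2)/2 simplifies to d*(2*ls + 1)/4,
// e.g. ls = 3 yields an effective sub-block scale of 7*d/4.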
2931
2932template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq3_s(
2933 const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
2934 constexpr int nwarps = mmq_get_nwarps_device();
2935 constexpr int warp_size = ggml_cuda_get_physical_warp_size();
2936
2937#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2938 int * x_qs = (int *) x_tile;
2939 float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
2940#else
2941 constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y);
2942 int * x_qs = (int *) x_tile;
2943 float * x_df = (float *) (x_qs + txs.qs);
2944#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2945
2946 constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_S)) / 2;
2947 constexpr int nrows = warp_size / threads_per_row;
2948 const int kqsx = threadIdx.x % threads_per_row;
2949
2950#pragma unroll
2951 for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) {
2952 int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
2953
2954 if (need_check) {
2955 i = min(i, i_max);
2956 }
2957
2958 const block_iq3_s * bxi = (const block_iq3_s *) x + kbx0 + i*stride;
2959
2960 const int2 qs_packed = make_int2(get_int_b2(bxi->qs, 2*kqsx+0), get_int_b2(bxi->qs, 2*kqsx+1));
2961 const uint8_t * qs = (const uint8_t *) &qs_packed;
2962
2963 const int qh = bxi->qh[kqsx];
2964
2965 const int signs_packed_32 = get_int_b2(bxi->signs, kqsx);
2966 const uint8_t * signs_packed_8 = (const uint8_t *) &signs_packed_32;
2967
2968#pragma unroll
2969 for (int l = 0; l < QR3_S; ++l) {
2970 const int2 grid_pos = make_int2(
2971 iq3s_grid[qs[2*l+0] | ((qh << (8 - 2*l)) & 0x100)],
2972 iq3s_grid[qs[2*l+1] | ((qh << (7 - 2*l)) & 0x100)]);
2973
2974 const int signs0 = __vcmpne4(((signs_packed_8[l] & 0x03) << 7) | ((signs_packed_8[l] & 0x0C) << 21), 0x00000000);
2975 const int signs1 = __vcmpne4(((signs_packed_8[l] & 0x30) << 3) | ((signs_packed_8[l] & 0xC0) << 17), 0x00000000);
2976
2977 const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
2978 const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);
2979
2980#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2981 x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+0)] = grid_l;
2982 x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+1)] = grid_h;
2983#else
2984 x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+0)] = grid_l;
2985 x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+1)] = grid_h;
2986#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2987 }
2988
        const int ls = 1 + 2*((bxi->scales[kqsx/2] >> (4*(kqsx % 2))) & 0x0F); // low nibble for even kqsx, high nibble for odd
2990 const float d = bxi->d;
2991#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2992 x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = ls*d;
2993#else
2994 x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = ls*d;
2995#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2996 }
2997}
2998
2999template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq1_s(
3000 const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
3001 constexpr int nwarps = mmq_get_nwarps_device();
3002 constexpr int warp_size = ggml_cuda_get_physical_warp_size();
3003
3004#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
3005 int * x_qs = (int *) x_tile;
3006 half2 * x_ds = (half2 *) (x_qs + MMQ_TILE_NE_K*2);
3007#else
    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y); // tile sizes reused from IQ3_S; only txs.qs is needed for the offset below
3009 int * x_qs = (int *) x_tile;
3010 half2 * x_ds = (half2 *) (x_qs + txs.qs);
3011#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
3012
3013 constexpr int threads_per_row = MMQ_ITER_K / (4 * QR1_S);
3014 constexpr int nrows = warp_size / threads_per_row;
3015 const int kqsx = threadIdx.x % threads_per_row;
3016
3017#pragma unroll
3018 for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) {
3019 int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
3020
3021 if (need_check) {
3022 i = min(i, i_max);
3023 }
3024
3025 const block_iq1_s * bxi = (const block_iq1_s *) x + kbx0 + i*stride;
3026
3027 const int qs_packed = get_int_b2(bxi->qs, kqsx);
3028 const uint8_t * qs = (const uint8_t *) &qs_packed;
3029
3030 const int qh = bxi->qh[kqsx];
3031
#pragma unroll
3033 for (int l = 0; l < QR1_S/2; ++l) {
3034 const int grid = iq1s_grid_gpu[qs[l] | (((qh >> (3*l)) & 0x07) << 8)];
3035
3036 const int grid0 = (grid >> 0) & 0x0F0F0F0F;
3037 const int grid1 = (grid >> 4) & 0x0F0F0F0F;
3038
3039#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
3040 x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+0)] = grid0;
3041 x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+1)] = grid1;
3042#else
3043 x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+0)] = grid0;
3044 x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+1)] = grid1;
3045#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
3046 }
3047
3048 const float d1q = __half2float(bxi->d) * (((qh >> 11) & 0x0E) + 1);
3049 const float delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f*IQ1S_DELTA/0x8000);
3050
3051#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
3052 x_ds[i*MMQ_MMA_TILE_X_K_Q8_1 + kqsx] = make_half2(d1q, d1q*delta);
3053#else
3054 x_ds[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = make_half2(d1q, d1q*delta);
3055#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
3056 }
3057}
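
// iq1s_grid_gpu packs eight 4-bit grid values into a single int; the two masks above de-interleave
// them into grid0 (even nibbles) and grid1 (odd nibbles). The delta expression branchlessly selects
// -1.0f + IQ1S_DELTA when qh bit 15 is clear and -1.0f - IQ1S_DELTA when it is set. Storing
// (d1q, d1q*delta) in x_ds lets the Q8_1 dot product kernels apply the scale to the integer dot
// product and fold the per-block offset in through the y partial sums.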
3058
3059template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_xs(
3060 const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
3061 constexpr int nwarps = mmq_get_nwarps_device();
3062 constexpr int warp_size = ggml_cuda_get_physical_warp_size();
3063
3064#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
3065 int * x_qs = (int *) x_tile;
3066 float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
3067#else
3068 constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y);
3069 int * x_qs = (int *) x_tile;
3070 float * x_df = (float *) (x_qs + txs.qs);
3071#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
3072
3073 constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_XS);
3074 constexpr int nrows = warp_size / threads_per_row;
3075 const int kqsx = threadIdx.x % threads_per_row;
3076
3077#pragma unroll
3078 for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
3079 int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
3080
3081 if (need_check) {
3082 i = min(i, i_max);
3083 }
3084
3085 const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride;
3086
3087 const int aux_q4 = get_int_b4(bxi->qs, kqsx);
3088 const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
3089 const int k0 = 8 * (kqsx / 4) + kqsx % 4;
3090
3091#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
3092 x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
3093 x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
3094#else
3095 x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x;
3096 x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 4] = v.y;
3097#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
3098 }
3099
3100 constexpr int rows_per_warp = warp_size / 8;
3101#pragma unroll
3102 for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
3103 int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / (MMQ_TILE_NE_K/4);
3104
3105 if (need_check) {
3106 i = min(i, i_max);
3107 }
3108
3109 const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride;
3110
3111 const float d = __half2float(bxi->d);
3112
3113 const int ls = ((bxi->scales_l[(threadIdx.x % 8)/2] >> (4*(threadIdx.x % 2))) & 0x0F)
3114 | (((bxi->scales_h >> (2*(threadIdx.x % 8))) & 0x03) << 4);
3115
3116#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
3117 x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * (ls - 32);
3118#else
3119 x_df[i*(MMQ_TILE_NE_K/4) + i/4 + threadIdx.x % 8] = d * (ls - 32);
3120#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
3121 }
3122}
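
// Each 6-bit IQ4_XS scale is assembled from a 4-bit nibble of scales_l and a 2-bit field of scales_h,
// then re-centered: d*(ls - 32) maps ls in [0, 63] onto [-32*d, 31*d]. Example: nibble 0xC (= 12)
// combined with high bits 0b10 gives ls = 12 | (2 << 4) = 44, i.e. an effective scale of 12*d.
// The index k0 = 8*(kqsx/4) + kqsx % 4 interleaves the two de-quantized halves returned by
// get_int_from_table_16 (v.x at k0, v.y at k0 + 4) into the order the Q8_0 dot product kernels expect.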
3123
3124template<int mmq_x, int mmq_y, bool need_check>
3125static __device__ __forceinline__ void mmq_write_back_dp4a(
3126 const float * __restrict__ sum, const int32_t * __restrict__ ids_dst, float * __restrict__ dst,
3127 const int stride, const int i_max, const int j_max) {
3128 constexpr int nwarps = mmq_get_nwarps_device();
3129 constexpr int warp_size = ggml_cuda_get_physical_warp_size();
3130
3131#pragma unroll
3132 for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
3133 const int j = j0 + threadIdx.y;
3134
3135 if (j > j_max) {
3136 return;
3137 }
3138
3139#pragma unroll
3140 for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
3141 const int i = i0 + threadIdx.x;
3142
3143 if (need_check && i > i_max) {
3144 continue;
3145 }
3146
3147 dst[ids_dst[j]*stride + i] = sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size];
3148 }
3149 }
3150}
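
// Accumulator bookkeeping for the dp4a path: each thread owns mmq_x*mmq_y / (nwarps*warp_size)
// partial results, and thread (threadIdx.x, threadIdx.y) covers output element
// (i0 + threadIdx.x, j0 + threadIdx.y). With e.g. warp_size = 32, nwarps = 8 and mmq_x = mmq_y = 64
// every thread holds and writes back 16 floats.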
3151
3152template<ggml_type type, int mmq_x, int mmq_y, bool need_check>
3153static __device__ __forceinline__ void mmq_write_back_mma(
3154 const float * __restrict__ sum, const int * __restrict__ ids_dst, float * __restrict__ dst,
3155 const int stride, const int i_max, const int j_max) {
3156
3157 constexpr int granularity = mmq_get_granularity_device(mmq_x);
3158 constexpr int nwarps = mmq_get_nwarps_device();
3159
3160#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
3161 constexpr int tileC_IJ = mmq_get_granularity_device(0);
3162 typedef tile<tileC_IJ, tileC_IJ, int, DATA_LAYOUT_J_MAJOR> tile_C;
3163 constexpr int rows_per_warp = granularity;
3164#else
3165 typedef tile<16, 8, int> tile_C;
3166 constexpr int rows_per_warp = 2 * granularity;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
3168 constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
3169
3170 const int i0 = (threadIdx.y / ntx) * (ntx*tile_C::I);
3171#if defined(TURING_MMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
3172 static_assert(nwarps*tile_C::I == mmq_y, "nwarps*tile_C::I != mmq_y");
3173#else
3174 GGML_UNUSED(nwarps);
3175#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
3176
3177#pragma unroll
3178 for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
3179#pragma unroll
3180 for (int n = 0; n < ntx; ++n) {
3181#pragma unroll
3182 for (int l = 0; l < tile_C::ne; ++l) {
3183 const int j = j0 + (threadIdx.y % ntx) * tile_C::J + tile_C::get_j(l);
3184
3185 if (j > j_max) {
3186 continue;
3187 }
3188
3189 const int i = i0 + n*tile_C::I + tile_C::get_i(l);
3190
3191 if (need_check && i > i_max) {
3192 continue;
3193 }
3194
3195 dst[ids_dst[j]*stride + i] = sum[(j0/tile_C::J + n)*tile_C::ne + l];
3196 }
3197 }
3198 }
3199}
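
// In the MMA path the accumulators instead live in tile_C fragments: each thread holds tile_C::ne
// elements and tile_C::get_i/get_j recover their coordinates within a minitile. For the 16x8 int
// tile used on NVIDIA that is 16*8/32 = 4 elements per thread of a 32-lane warp.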
3200
3201// -------------------------------------------------------------------------------------------------------------------------------------
3202
3203template <int mmq_x, int mmq_y, bool need_check, ggml_type type>
3204struct mmq_type_traits;
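
// Every supported type is wired up below with the same four pieces: vdr (the vec dot ratio), the tile
// loader, and the MMA and dp4a dot product kernels between which mul_mat_q_process_tile selects per
// architecture. A hypothetical new type would only need its own specialization along these lines
// (GGML_TYPE_FOO and the *_foo/*_FOO symbols are illustrative, not real):
//
//     template <int mmq_x, int mmq_y, bool need_check>
//     struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_FOO> {
//         static constexpr int              vdr          = VDR_FOO_Q8_1_MMQ;
//         static constexpr load_tiles_mmq_t load_tiles   = load_tiles_foo<mmq_y, need_check>;
//         static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
//         static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
//     };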
3205
3206template <int mmq_x, int mmq_y, bool need_check>
3207struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q4_0> {
3208 static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ;
3209 static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0<mmq_y, need_check>;
3210 static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_DS4>;
3211 static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_0_q8_1_dp4a<mmq_x, mmq_y>;
3212};
3213
3214template <int mmq_x, int mmq_y, bool need_check>
3215struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q4_1> {
3216 static constexpr int vdr = VDR_Q4_1_Q8_1_MMQ;
3217 static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1<mmq_y, need_check>;
3218 static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y>;
3219 static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_1_q8_1_dp4a<mmq_x, mmq_y>;
3220};
3221
3222template <int mmq_x, int mmq_y, bool need_check>
3223struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q5_0> {
3224 static constexpr int vdr = VDR_Q5_0_Q8_1_MMQ;
3225 static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0<mmq_y, need_check>;
3226 static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
3227 static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
3228};
3229
3230template <int mmq_x, int mmq_y, bool need_check>
3231struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q5_1> {
3232 static constexpr int vdr = VDR_Q5_1_Q8_1_MMQ;
3233 static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1<mmq_y, need_check>;
3234 static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y>;
3235 static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a<mmq_x, mmq_y>;
3236};
3237
3238template <int mmq_x, int mmq_y, bool need_check>
3239struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q8_0> {
3240 static constexpr int vdr = VDR_Q8_0_Q8_1_MMQ;
3241 static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0<mmq_y, need_check>;
3242 static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
3243 static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
3244};
3245
3246template <int mmq_x, int mmq_y, bool need_check>
3247struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_MXFP4> {
3248 static constexpr int vdr = VDR_MXFP4_Q8_1_MMQ;
3249#ifdef BLACKWELL_MMA_AVAILABLE
3250 static constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4_fp4<mmq_y, need_check>;
3251 static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_mxfp4_mxfp4_mma<mmq_x, mmq_y>;
3252#else
3253 static constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4<mmq_y, need_check>;
3254 static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
3255#endif // BLACKWELL_MMA_AVAILABLE
3256 static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
3257};
3258
3259template <int mmq_x, int mmq_y, bool need_check>
3260struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q2_K> {
3261 static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ;
3262 static constexpr load_tiles_mmq_t load_tiles = load_tiles_q2_K<mmq_y, need_check>;
3263 static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q2_K_q8_1_mma<mmq_x, mmq_y>;
3264 static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q2_K_q8_1_dp4a<mmq_x, mmq_y>;
3265};
3266
3267template <int mmq_x, int mmq_y, bool need_check>
3268struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q3_K> {
3269 static constexpr int vdr = VDR_Q3_K_Q8_1_MMQ;
3270 static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K<mmq_y, need_check>;
3271 static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y>;
3272 static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q3_K_q8_1_dp4a<mmq_x, mmq_y>;
3273};
3274
3275template <int mmq_x, int mmq_y, bool need_check>
3276struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q4_K> {
3277 static constexpr int vdr = VDR_Q4_K_Q8_1_MMQ;
3278 static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K<mmq_y, need_check>;
3279 static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y>;
3280 static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_K_q8_1_dp4a<mmq_x, mmq_y>;
3281};
3282
3283template <int mmq_x, int mmq_y, bool need_check>
3284struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q5_K> {
3285 static constexpr int vdr = VDR_Q5_K_Q8_1_MMQ;
3286 static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K<mmq_y, need_check>;
3287 static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y>;
3288 static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_K_q8_1_dp4a<mmq_x, mmq_y>;
3289};
3290
3291template <int mmq_x, int mmq_y, bool need_check>
3292struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q6_K> {
3293 static constexpr int vdr = VDR_Q6_K_Q8_1_MMQ;
3294 static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K<mmq_y, need_check>;
3295 static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q6_K_q8_1_mma<mmq_x, mmq_y>;
3296 static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q6_K_q8_1_dp4a<mmq_x, mmq_y>;
3297};
3298
3299template <int mmq_x, int mmq_y, bool need_check>
3300struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ2_XXS> {
3301 static constexpr int vdr = VDR_IQ2_XXS_Q8_1_MMQ;
3302 static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xxs<mmq_y, need_check>;
3303 static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
3304 static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
3305};
3306
3307template <int mmq_x, int mmq_y, bool need_check>
3308struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ2_XS> {
3309 static constexpr int vdr = VDR_IQ2_XS_Q8_1_MMQ;
3310 static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xs<mmq_y, need_check>;
3311 static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y>;
3312 static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y>;
3313};
3314
3315template <int mmq_x, int mmq_y, bool need_check>
3316struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ2_S> {
3317 static constexpr int vdr = VDR_IQ2_S_Q8_1_MMQ;
3318 static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_s<mmq_y, need_check>;
3319 static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y>;
3320 static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y>;
3321};
3322
3323template <int mmq_x, int mmq_y, bool need_check>
3324struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ3_XXS> {
3325 static constexpr int vdr = VDR_IQ3_XXS_Q8_1_MMQ;
3326 static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_xxs<mmq_y, need_check>;
3327 static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
3328 static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
3329};
3330
3331template <int mmq_x, int mmq_y, bool need_check>
3332struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ3_S> {
3333 static constexpr int vdr = VDR_IQ3_S_Q8_1_MMQ;
3334 static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_s<mmq_y, need_check>;
3335 static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
3336 static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
3337};
3338
3339template <int mmq_x, int mmq_y, bool need_check>
3340struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ1_S> {
3341 static constexpr int vdr = VDR_IQ1_S_Q8_1_MMQ;
3342 static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq1_s<mmq_y, need_check>;
3343 static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y>;
3344 static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a<mmq_x, mmq_y>;
3345};
3346
3347template <int mmq_x, int mmq_y, bool need_check>
3348struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ4_NL> {
3349 static constexpr int vdr = VDR_IQ4_NL_Q8_1_MMQ;
3350 static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_nl<mmq_y, need_check>;
3351 static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
3352 static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
3353};
3354
3355template <int mmq_x, int mmq_y, bool need_check>
3356struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ4_XS> {
3357 static constexpr int vdr = VDR_IQ4_XS_Q8_1_MMQ;
3358 static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_xs<mmq_y, need_check>;
3359 static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
3360 static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
3361};
3362
3363template <ggml_type type, int mmq_x, bool need_check, bool fixup>
3364static __device__ __forceinline__ void mul_mat_q_process_tile(
3365 const char * __restrict__ x, const int offset_x, const int * __restrict__ y,
3366 const int * __restrict__ ids_dst, float * __restrict__ dst, float * __restrict__ tmp_fixup,
3367 const int stride_row_x, const int ncols_y, const int stride_col_dst,
3368 const int tile_x_max_i, const int tile_y_max_j, const int kb0_start, const int kb0_stop) {
3369
3370 constexpr int warp_size = ggml_cuda_get_physical_warp_size();
3371 constexpr int nwarps = mmq_get_nwarps_device();
3372 constexpr int qk = ggml_cuda_type_traits<type>::qk;
3373 constexpr int mmq_y = get_mmq_y_device();
3374 constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y, need_check, type>::load_tiles;
3375
3376 extern __shared__ int data_mul_mat_q[];
3377 int * tile_y = data_mul_mat_q + mmq_x;
3378 int * tile_x = tile_y + GGML_PAD(mmq_x*MMQ_TILE_Y_K, nwarps*warp_size);
3379
3380#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
3381 constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, need_check, type>::vec_dot_mma;
3382 constexpr mmq_write_back_t write_back = mmq_write_back_mma<type, mmq_x, mmq_y, need_check>;
3383#else
3384 constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, need_check, type>::vec_dot_dp4a;
3385 constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, need_check>;
3386#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
3387
3388#if defined(BLACKWELL_MMA_AVAILABLE)
    // A block_fp4_mmq tile holds 8 blocks of 32 values (256 total) instead of the 4 blocks of 32 in a block_q8_1_mmq.
3390 constexpr int ne_block = (type == GGML_TYPE_MXFP4) ? 8 * QK_MXFP4 : 4 * QK8_1;
3391#else
3392 constexpr int ne_block = 4 * QK8_1;
3393#endif // defined(BLACKWELL_MMA_AVAILABLE)
3394
3395 constexpr int ITER_K = get_iter_k(type);
3396 constexpr int blocks_per_iter = ITER_K / qk;
3397
3398 float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f};
3399
3400 constexpr int sz = sizeof(block_q8_1_mmq) / sizeof(int);
3401
3402 for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_iter) {
3403 load_tiles(x, tile_x, offset_x + kb0, tile_x_max_i, stride_row_x);
3404 {
3405 const int * by0 = y + ncols_y * (kb0 * qk / ne_block) * sz;
3406#pragma unroll
3407 for (int l0 = 0; l0 < mmq_x * MMQ_TILE_Y_K; l0 += nwarps * warp_size) {
3408 int l = l0 + threadIdx.y*warp_size + threadIdx.x;
3409
3410 tile_y[l] = by0[l];
3411 }
3412 }
3413
3414 __syncthreads();
3415
3416 vec_dot(tile_x, tile_y, sum, 0);
3417
3418 __syncthreads();
3419
3420 {
3421 const int * by0 = y + ncols_y * ((kb0 * qk / ne_block) * sz + sz);
3422#pragma unroll
3423 for (int l0 = 0; l0 < mmq_x * MMQ_TILE_Y_K; l0 += nwarps * warp_size) {
3424 int l = l0 + threadIdx.y*warp_size + threadIdx.x;
3425
3426 tile_y[l] = by0[l];
3427 }
3428 }
3429
3430 __syncthreads();
3431
3432 vec_dot(tile_x, tile_y, sum, MMQ_TILE_NE_K);
3433
3434 __syncthreads();
3435 }
3436
3437 if (fixup) {
3438 write_back(sum, ids_dst, tmp_fixup + blockIdx.x*(mmq_x*mmq_y), mmq_y, mmq_y, mmq_x);
3439 } else {
3440 write_back(sum, ids_dst, dst, stride_col_dst, tile_x_max_i, tile_y_max_j);
3441 }
3442}
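
// Each iteration of the kb0 loop above consumes ITER_K values along k: tile_x is loaded once while
// the matching y data is streamed through the shared tile_y buffer in two halves, the second offset
// by sz = sizeof(block_q8_1_mmq)/sizeof(int) ints per column. The __syncthreads() pairs are required
// because the second half overwrites tile_y in place before the second vec_dot call.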
3443
3444
3445// The mul_mat_q kernel implements "stream-k" work partitioning as described in https://arxiv.org/abs/2301.03598
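//
// Sketch of the partitioning (illustration only): the total work forms a flat index space of
// nsamples_y * nchannels_y * ntx * nty * blocks_per_ne00 k-blocks that is sliced evenly across the
// gridDim.x resident CUDA blocks:
//
//     kbc      =  bidx      * total / gridDim.x;  // first k-block owned by this CUDA block
//     kbc_stop = (bidx + 1) * total / gridDim.x;  // one past the last
//
// A CUDA block can therefore start and/or stop in the middle of an output tile; a partial tile at the
// end of a block's range is written to the tmp_fixup buffer and combined afterwards by
// mul_mat_q_stream_k_fixup.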
3446
3447template <ggml_type type, int mmq_x, bool need_check>
3448#if defined(GGML_USE_HIP)
3449#if defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
3450 __launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device(), 2)
3451#endif // defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
3452#else
3453#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
3454 __launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device(), 1)
3455#else
3456 __launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device(), 2)
3457#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
3458#endif // defined(GGML_USE_HIP)
3459static __global__ void mul_mat_q(
3460 const char * __restrict__ x, const int * __restrict__ y, const int32_t * __restrict__ ids_dst,
3461 const int32_t * __restrict__ expert_bounds, float * __restrict__ dst, float * __restrict__ tmp_fixup,
3462 const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_row_x, const int ncols_y, const int stride_col_dst,
3463 const int channel_ratio, const int nchannels_y, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
3464 const int sample_ratio, const int nsamples_y, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
3465 const int ncols_max) {
3466
3467 // Skip unused template specializations for faster compilation:
3468 if (mmq_x > get_mmq_x_max_device() || mmq_x % mmq_get_granularity_device(mmq_x) != 0) {
3469 NO_DEVICE_CODE;
3470 return;
3471 }
3472
3473 constexpr int nwarps = mmq_get_nwarps_device();
3474 constexpr int warp_size = ggml_cuda_get_physical_warp_size();
3475
3476 constexpr int qk = ggml_cuda_type_traits<type>::qk;
3477 constexpr int mmq_y = get_mmq_y_device();
3478
3479 const int ntx = (ncols_max + mmq_x - 1) / mmq_x; // Number of tiles x
3480 const int nty = (nrows_x + mmq_y - 1) / mmq_y; // Number of tiles y
3481
3482 // Initialize the ids for writing back data with just the index.
3483 // For regular matrix multiplications this is never changed.
3484 // For MoE the correct indices are loaded from ids_dst.
3485 extern __shared__ int ids_dst_shared[]; // Stored at beginning of shared memory.
3486#pragma unroll
3487 for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) {
3488 const int j = j0 + threadIdx.y*warp_size + threadIdx.x;
3489
3490 if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) {
3491 break;
3492 }
3493
3494 ids_dst_shared[j] = j;
3495 }
3496 __syncthreads();
3497
    // On non-CDNA AMD GPUs and on CUDA architectures older than Volta stream-k performed worse, so conventional tiling is used instead:
3499#if (defined(GGML_USE_HIP) && !defined(CDNA)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
3500 {
3501 const int wt = blockIdx.z / nchannels_y;
3502 const int zt = blockIdx.z - wt*nchannels_y;
3503 const int jt = blockIdx.y;
3504 const int it = blockIdx.x;
3505
3506 // Defaults for regular matrix multiplication:
3507 int col_low = 0;
3508 int col_high = ncols_dst;
3509 int col_diff = ncols_dst;
3510 int offset_y = wt*stride_sample_y + zt*stride_channel_y;
3511 int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst;
3512
3513 if (ids_dst) {
3514 col_low = expert_bounds[zt + 0];
3515 col_high = expert_bounds[zt + 1];
3516 col_diff = col_high - col_low;
3517
3518 offset_y = 0;
3519 offset_dst = 0;
3520
3521 if (jt*mmq_x >= col_diff) {
3522 return;
3523 }
3524
3525 // __syncthreads(); // There is no previous tile that could cause a race condition.
3526#pragma unroll
3527 for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) {
3528 const int j = j0 + threadIdx.y*warp_size + threadIdx.x;
3529
3530 if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) {
3531 break;
3532 }
3533
3534 ids_dst_shared[j] = ids_dst[col_low + jt*mmq_x + j];
3535 }
3536 __syncthreads();
3537 }
3538
3539 offset_y += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int));
3540 offset_dst += it*mmq_y;
3541
3542 const int tile_x_max_i = nrows_x - it*mmq_y - 1;
3543 const int tile_y_max_j = col_diff - jt*mmq_x - 1;
3544
3545 const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;
3546
3547 constexpr bool fixup = false;
3548 mul_mat_q_process_tile<type, mmq_x, need_check, fixup>
3549 (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst,
3550 tile_x_max_i, tile_y_max_j, 0, ncols_x/qk);
3551 return;
3552 }
#endif // (defined(GGML_USE_HIP) && !defined(CDNA)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
3554
3555 constexpr int ITER_K = get_iter_k(type);
3556
3557 const int64_t blocks_per_ne00 = ncols_x / qk;
3558 constexpr int blocks_per_iter = ITER_K / qk;
3559
3560 // kbc == k block continuous, current index in continuous ijk space.
3561 int64_t kbc = (int64_t) blockIdx.x *nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
3562 int64_t kbc_stop = (int64_t)(blockIdx.x + 1)*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
3563
3564 kbc -= (kbc % blocks_per_ne00) % blocks_per_iter;
3565 kbc_stop -= (kbc_stop % blocks_per_ne00) % blocks_per_iter;
3566
3567 // kb0 == k index when doing the matrix multiplication for an output tile.
3568 int kb0_start = kbc % blocks_per_ne00;
3569 int kb0_stop = min(blocks_per_ne00, kb0_start + kbc_stop - kbc);
3570 while (kbc < kbc_stop && kb0_stop == blocks_per_ne00) {
3571 int tmp = kbc;
3572 const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
3573 tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
3574 const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00);
3575 tmp -= wt * (nchannels_y*ntx*blocks_per_ne00);
3576 const int zt = tmp / (ntx*blocks_per_ne00);
3577 tmp -= zt * (ntx*blocks_per_ne00);
3578 const int jt = tmp / blocks_per_ne00;
3579
3580 // Defaults for regular matrix multiplication:
3581 int col_low = 0;
3582 int col_high = ncols_dst;
3583 int col_diff = ncols_dst;
3584 int offset_y = wt*stride_sample_y + zt*stride_channel_y;
3585 int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst;
3586
3587 if (ids_dst) {
3588 col_low = expert_bounds[zt + 0];
3589 col_high = expert_bounds[zt + 1];
3590 col_diff = col_high - col_low;
3591
3592 offset_y = 0;
3593 offset_dst = 0;
3594
3595 if (jt*mmq_x >= col_diff) {
3596 kbc += blocks_per_ne00;
3597 kbc -= kbc % blocks_per_ne00;
3598
3599 kb0_start = 0;
3600 kb0_stop = min(blocks_per_ne00, kbc_stop - kbc);
3601
3602 continue;
3603 }
3604
3605 __syncthreads();
3606#pragma unroll
3607 for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) {
3608 const int j = j0 + threadIdx.y*warp_size + threadIdx.x;
3609
3610 if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) {
3611 break;
3612 }
3613
3614 ids_dst_shared[j] = ids_dst[col_low + jt*mmq_x + j];
3615 }
3616 __syncthreads();
3617 }
3618
3619 offset_y += (col_low + jt * mmq_x) * (sizeof(block_q8_1_mmq) / sizeof(int));
3620 offset_dst += it*mmq_y;
3621
3622 const int tile_x_max_i = nrows_x - it*mmq_y - 1;
3623 const int tile_y_max_j = col_diff - jt*mmq_x - 1;
3624
3625 const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;
3626
3627 constexpr bool fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
3628 mul_mat_q_process_tile<type, mmq_x, need_check, fixup>
3629 (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst,
3630 tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop);
3631
3632 kbc += blocks_per_ne00;
3633 kbc -= kbc % blocks_per_ne00;
3634
3635 kb0_start = 0;
3636 kb0_stop = min(blocks_per_ne00, kbc_stop - kbc);
3637 }
3638
3639 if (kbc >= kbc_stop) {
3640 return;
3641 }
3642
3643 int tmp = kbc;
3644 const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
3645 tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
3646 const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00);
3647 tmp -= wt * (nchannels_y*ntx*blocks_per_ne00);
3648 const int zt = tmp / (ntx*blocks_per_ne00);
3649 tmp -= zt * (ntx*blocks_per_ne00);
3650 const int jt = tmp / blocks_per_ne00;
3651
3652 // Defaults for regular matrix multiplication:
3653 int col_low = 0;
3654 int col_high = ncols_dst;
3655 int col_diff = ncols_dst;
3656 int offset_y = wt*stride_sample_y + zt*stride_channel_y;
3657 int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst;
3658
3659 if (ids_dst) {
3660 col_low = expert_bounds[zt + 0];
3661 col_high = expert_bounds[zt + 1];
3662 col_diff = col_high - col_low;
3663
3664 offset_y = 0;
3665 offset_dst = 0;
3666
3667 if (jt*mmq_x >= col_diff) {
3668 return;
3669 }
3670
3671 // The memory layout for the fixup buffer is always contiguous, therefore reset ids:
3672 __syncthreads();
3673#pragma unroll
3674 for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) {
3675 const int j = j0 + threadIdx.y*warp_size + threadIdx.x;
3676
3677 if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) {
3678 break;
3679 }
3680
3681 ids_dst_shared[j] = j;
3682 }
3683 __syncthreads();
3684 }
3685
3686 offset_y += (col_low + jt * mmq_x) * (sizeof(block_q8_1_mmq) / sizeof(int));
3687 offset_dst += it*mmq_y;
3688
3689 const int tile_x_max_i = nrows_x - it*mmq_y - 1;
3690 const int tile_y_max_j = col_diff - jt*mmq_x - 1;
3691
3692 const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;
3693
3694 constexpr bool fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
3695 mul_mat_q_process_tile<type, mmq_x, need_check, fixup>
3696 (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst,
3697 tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop);
3698}
3699
3700template <ggml_type type, int mmq_x, bool need_check>
3701static __global__ void mul_mat_q_stream_k_fixup(const int32_t * ids_dst,
3702 const int32_t * expert_bounds,
3703 float * __restrict__ dst,
3704 const float * __restrict__ tmp_last_tile,
3705 const int ncols_x,
3706 const int nrows_x,
3707 const int ncols_dst,
3708 const size_t stride_col_dst,
3709 const int nchannels_y,
3710 const size_t stride_channel_dst,
3711 const int nsamples_y,
3712 const size_t stride_sample_dst,
3713 const int ncols_max) {
3714 constexpr int mmq_y = get_mmq_y_device();
3715 constexpr int qk = ggml_cuda_type_traits<type>::qk;
3716 constexpr int ITER_K = get_iter_k(type);
3717
3718 constexpr int blocks_per_iter = ITER_K / qk;
3719 const int64_t blocks_per_ne00 = ncols_x / qk;
3720
3721 constexpr int nwarps = mmq_get_nwarps_device();
3722 constexpr int warp_size = ggml_cuda_get_physical_warp_size();
3723
3724 float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f};
3725
3726 const int ntx = (ncols_max + mmq_x - 1) / mmq_x;
3727 const int nty = (nrows_x + mmq_y - 1) / mmq_y;
3728
3729 const int bidx0 = blockIdx.x;
3730
3731 // kbc == k block continuous, current index in continuous ijk space.
3732 int64_t kbc0 = (int64_t) bidx0 *nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
3733 int64_t kbc0_stop = (int64_t)(bidx0 + 1)*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
3734
3735 kbc0 -= (kbc0 % blocks_per_ne00) % blocks_per_iter;
3736 kbc0_stop -= (kbc0_stop % blocks_per_ne00) % blocks_per_iter;
3737
3738 const bool did_not_have_any_data = kbc0 == kbc0_stop;
3739 const bool wrote_beginning_of_tile = kbc0 % blocks_per_ne00 == 0;
3740 const bool did_not_write_last = kbc0/blocks_per_ne00 == kbc0_stop/blocks_per_ne00 && kbc0_stop % blocks_per_ne00 != 0;
3741 if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) {
3742 return;
3743 }
3744
3745 bool any_fixup = false;
3746
3747 // Iterate over previous blocks and sum up partial sums written to fixup buffer.
3748 // All CUDA blocks that get here must have a previous block that needs a fixup.
3749 int64_t bidx = bidx0 - 1;
3750 int64_t kbc_stop = kbc0;
    while (true) {
3752 int64_t kbc = bidx*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
3753 kbc -= (kbc % blocks_per_ne00) % blocks_per_iter;
3754
3755 if (kbc == kbc_stop) { // Did not have any data.
3756 bidx--;
3757 kbc_stop = kbc;
3758 continue;
3759 }
3760
3761 any_fixup = true;
3762
3763#pragma unroll
3764 for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
3765 const int j = j0 + threadIdx.y;
3766
3767#pragma unroll
3768 for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
3769 const int i = i0 + threadIdx.x;
3770
3771 sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size] += tmp_last_tile[bidx*(mmq_x*mmq_y) + j*mmq_y + i];
3772 }
3773 }
3774
3775 // If this block started in a previous tile we are done and don't need to combine additional partial results.
3776 if (kbc % blocks_per_ne00 == 0 || kbc/blocks_per_ne00 < kbc0/blocks_per_ne00) {
3777 break;
3778 }
3779 bidx--;
3780 kbc_stop = kbc;
3781 }
3782
3783 if (!any_fixup) {
3784 return;
3785 }
3786
3787 int tmp = kbc0;
3788 const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
3789 tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
3790 const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00);
3791 tmp -= wt * (nchannels_y*ntx*blocks_per_ne00);
3792 const int zt = tmp / (ntx*blocks_per_ne00);
3793 tmp -= zt * (ntx*blocks_per_ne00);
3794 const int jt = tmp / blocks_per_ne00;
3795
3796 if (!ids_dst) {
3797 const int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst + it*mmq_y;
3798 dst += offset_dst;
3799
3800 const int i_max = nrows_x - it*mmq_y - 1;
3801 const int j_max = ncols_dst - jt*mmq_x - 1;
3802
3803#pragma unroll
3804 for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
3805 const int j = j0 + threadIdx.y;
3806
3807 if (j > j_max) {
3808 return;
3809 }
3810
3811#pragma unroll
3812 for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
3813 const int i = i0 + threadIdx.x;
3814
3815 if (need_check && i > i_max) {
3816 continue;
3817 }
3818
3819 dst[j*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size];
3820 }
3821 }
3822 return;
3823 }
3824
3825 __shared__ int ids_dst_shared[mmq_x];
3826 const int col_low = expert_bounds[zt + 0];
3827 const int col_high = expert_bounds[zt + 1];
3828 const int col_diff = col_high - col_low;
3829
3830 for (int j = threadIdx.y*warp_size + threadIdx.x; j < mmq_x; j += nwarps*warp_size) {
3831 ids_dst_shared[j] = ids_dst[col_low + jt*mmq_x + j];
3832 }
3833 __syncthreads();
3834
3835 const int offset_dst = it*mmq_y;
3836 dst += offset_dst;
3837
3838 const int i_max = nrows_x - it*mmq_y - 1;
3839 const int j_max = col_diff - jt*mmq_x - 1;
3840
3841#pragma unroll
3842 for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
3843 const int j = j0 + threadIdx.y;
3844
3845 if (j > j_max) {
3846 return;
3847 }
3848
3849#pragma unroll
3850 for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
3851 const int i = i0 + threadIdx.x;
3852
3853 if (need_check && i > i_max) {
3854 continue;
3855 }
3856
3857 dst[ids_dst_shared[j]*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size];
3858 }
3859 }
3860}
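
// Worked example for the walk-back above (illustration only): suppose blocks_per_ne00 = 24 and three
// CUDA blocks covered one output tile as [0, 10), [10, 20) and [20, 24). Block 2 wrote the tile to dst
// directly; in this kernel it walks back over blocks 1 and 0, accumulating their partial results from
// tmp_last_tile, and stops at block 0 because that block started exactly at a tile boundary
// (kbc % blocks_per_ne00 == 0).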
3861
3862struct mmq_args {
3863 const char * x; ggml_type type_x; const int * y; const int32_t * ids_dst; const int32_t * expert_bounds; float * dst;
3864 int64_t ncols_x; int64_t nrows_x; int64_t ncols_dst; int64_t stride_row_x; int64_t ncols_y; int64_t nrows_dst;
3865 int64_t nchannels_x; int64_t nchannels_y; int64_t stride_channel_x; int64_t stride_channel_y; int64_t stride_channel_dst;
3866 int64_t nsamples_x; int64_t nsamples_y; int64_t stride_sample_x; int64_t stride_sample_y; int64_t stride_sample_dst;
3867 bool use_stream_k; int64_t ncols_max;
3868};
3869
3870template<ggml_type type>
3871static size_t mmq_get_nbytes_shared(const int mmq_x, const int mmq_y, const int cc, const int warp_size, const int nwarps) {
3872 const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y);
3873 const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type);
3874 const size_t nbs_ids = mmq_x*sizeof(int);
    const size_t nbs_x = (turing_mma_available(cc) || amd_mfma_available(cc) || amd_wmma_available(cc)) ?
        mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
3876 const size_t nbs_y = mmq_x * (sizeof(block_q8_1_mmq));
3877 return nbs_ids + nbs_x + GGML_PAD(nbs_y, nwarps*warp_size*sizeof(int));
3878}
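
// Worked example (illustration only): sizeof(block_q8_1_mmq) == 144 bytes, so mmq_x = 64 gives
// nbs_y = 64*144 = 9216 bytes, which with nwarps = 8 and warp_size = 32 is already a multiple of
// nwarps*warp_size*sizeof(int) = 1024 and needs no padding. nbs_ids adds 64*4 = 256 bytes; nbs_x
// depends on the architecture branch (MMA tile stride vs. dp4a tile sizes).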
3879
3880template <ggml_type type, int mmq_x>
3881static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
3882 const int id = ggml_cuda_get_device();
3883 const int cc = ggml_cuda_info().devices[id].cc;
3884 const int nsm = ggml_cuda_info().devices[id].nsm;
3885 const int warp_size = ggml_cuda_info().devices[id].warp_size;
3886 const int nwarps = mmq_get_nwarps_host(cc, warp_size);
3887 const int mmq_y = get_mmq_y_host(cc);
3888
3889 const dim3 block_dims(warp_size, nwarps, 1);
3890
3891 const int nbytes_shared = mmq_get_nbytes_shared<type>(mmq_x, mmq_y, cc, warp_size, nwarps);
3892
3893 CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q<type, mmq_x, false>), nbytes_shared);
3894 CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q<type, mmq_x, true>), nbytes_shared);
3895
3896 const int nty = (args.nrows_x + mmq_y - 1) / mmq_y;
3897 const int ntx = (args.ncols_max + mmq_x - 1) / mmq_x;
3898 const int ntzw = args.nchannels_y * args.nsamples_y;
3899 const dim3 block_nums_xy_tiling(nty, ntx, ntzw);
3900
3901 GGML_ASSERT(args.nchannels_y % args.nchannels_x == 0);
3902 GGML_ASSERT(args.nsamples_y % args.nsamples_x == 0);
3903 const int channel_ratio = args.nchannels_y / args.nchannels_x;
3904 const int sample_ratio = args.nsamples_y / args.nsamples_x;
3905
3906 if (!args.use_stream_k) {
3907 if (args.nrows_x % mmq_y == 0) {
3908 constexpr bool need_check = false;
3909 mul_mat_q<type, mmq_x, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
3910 (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
3911 args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
3912 channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
3913 sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
3914 args.ncols_max);
3915 } else {
3916 constexpr bool need_check = true;
3917 mul_mat_q<type, mmq_x, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
3918 (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
3919 args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
3920 channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
3921 sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
3922 args.ncols_max);
3923 }
3924 return;
3925 }
3926
3927 const dim3 block_nums_stream_k(nsm, 1, 1);
3928 const bool fixup_needed = ntx*nty*ntzw % nsm != 0;
3929
3930 ggml_cuda_pool & pool = ctx.pool(id);
3931 ggml_cuda_pool_alloc<float> tmp_fixup(pool);
3932 if (fixup_needed) {
3933 tmp_fixup.alloc(block_nums_stream_k.x * mmq_x*mmq_y);
3934 }
3935
3936 if (args.nrows_x % mmq_y == 0) {
3937 constexpr bool need_check = false;
3938 mul_mat_q<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
3939 (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
3940 args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
3941 channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
3942 sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
3943 args.ncols_max);
3944
3945 if (!fixup_needed) {
3946 return;
3947 }
3948
3949 mul_mat_q_stream_k_fixup<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
3950 (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst,
3951 args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst,
3952 args.ncols_max);
3953 } else {
3954 constexpr bool need_check = true;
3955 mul_mat_q<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
3956 (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
3957 args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
3958 channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
3959 sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
3960 args.ncols_max);
3961
3962 if (!fixup_needed) {
3963 return;
3964 }
3965
3966 mul_mat_q_stream_k_fixup<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
3967 (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst,
3968 args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst,
3969 args.ncols_max);
3970 }
3971}
3972
3973template <ggml_type type>
3974void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
3975 const int id = ggml_cuda_get_device();
3976 const int cc = ggml_cuda_info().devices[id].cc;
3977 const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
3978 const int warp_size = ggml_cuda_info().devices[id].warp_size;
3979 const int nwarps = mmq_get_nwarps_host(cc, warp_size);
3980
3981 const int mmq_x_max = get_mmq_x_max_host(cc);
3982 const int mmq_y = get_mmq_y_host(cc);
3983
3984 int mmq_x_best = 0;
3985 int ntiles_x_best = INT_MAX;
3986
3987 for (int mmq_x = 8; mmq_x <= mmq_x_max && ntiles_x_best > 1; mmq_x += 8) {
3988 const int granularity = mmq_get_granularity_host(mmq_x, cc);
3989
3990 if (mmq_x % granularity != 0 || mmq_get_nbytes_shared<type>(mmq_x, mmq_y, cc, warp_size, nwarps) > smpbo) {
3991 continue;
3992 }
3993
3994 const int ntiles_x = (args.ncols_max + mmq_x - 1) / mmq_x;
3995
3996 if (ntiles_x < ntiles_x_best) {
3997 mmq_x_best = mmq_x;
3998 ntiles_x_best = ntiles_x;
3999 }
4000 }
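
    // Example of the search above (illustration only): for ncols_max = 50 the tile counts are
    // mmq_x = 8 -> 7, 16 -> 4, 24 -> 3, 32 -> 2, 48 -> 2, 56 -> 1. Subject to the granularity and
    // shared memory constraints the search settles on mmq_x = 56, the smallest tile width covering
    // all 50 columns in a single pass, and stops as soon as a single tile suffices.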
4001
4002 switch (mmq_x_best) {
4003 case 8:
4004 launch_mul_mat_q<type, 8>(ctx, args, stream);
4005 break;
4006 case 16:
4007 launch_mul_mat_q<type, 16>(ctx, args, stream);
4008 break;
4009 case 24:
4010 launch_mul_mat_q<type, 24>(ctx, args, stream);
4011 break;
4012 case 32:
4013 launch_mul_mat_q<type, 32>(ctx, args, stream);
4014 break;
4015 case 40:
4016 launch_mul_mat_q<type, 40>(ctx, args, stream);
4017 break;
4018 case 48:
4019 launch_mul_mat_q<type, 48>(ctx, args, stream);
4020 break;
4021 case 56:
4022 launch_mul_mat_q<type, 56>(ctx, args, stream);
4023 break;
4024 case 64:
4025 launch_mul_mat_q<type, 64>(ctx, args, stream);
4026 break;
4027 case 72:
4028 launch_mul_mat_q<type, 72>(ctx, args, stream);
4029 break;
4030 case 80:
4031 launch_mul_mat_q<type, 80>(ctx, args, stream);
4032 break;
4033 case 88:
4034 launch_mul_mat_q<type, 88>(ctx, args, stream);
4035 break;
4036 case 96:
4037 launch_mul_mat_q<type, 96>(ctx, args, stream);
4038 break;
4039 case 104:
4040 launch_mul_mat_q<type, 104>(ctx, args, stream);
4041 break;
4042 case 112:
4043 launch_mul_mat_q<type, 112>(ctx, args, stream);
4044 break;
4045 case 120:
4046 launch_mul_mat_q<type, 120>(ctx, args, stream);
4047 break;
4048 case 128:
4049 launch_mul_mat_q<type, 128>(ctx, args, stream);
4050 break;
4051 default:
4052 fprintf(stderr, "mmq_x_best=%d\n", mmq_x_best);
4053 GGML_ABORT("fatal error");
4054 break;
4055 }
4056}
4057
#define DECL_MMQ_CASE(type) \
    template void mul_mat_q_case<type>(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream)
4060
4061extern DECL_MMQ_CASE(GGML_TYPE_Q4_0);
4062extern DECL_MMQ_CASE(GGML_TYPE_Q4_1);
4063extern DECL_MMQ_CASE(GGML_TYPE_Q5_0);
4064extern DECL_MMQ_CASE(GGML_TYPE_Q5_1);
4065extern DECL_MMQ_CASE(GGML_TYPE_Q8_0);
4066extern DECL_MMQ_CASE(GGML_TYPE_MXFP4);
4067extern DECL_MMQ_CASE(GGML_TYPE_Q2_K);
4068extern DECL_MMQ_CASE(GGML_TYPE_Q3_K);
4069extern DECL_MMQ_CASE(GGML_TYPE_Q4_K);
4070extern DECL_MMQ_CASE(GGML_TYPE_Q5_K);
4071extern DECL_MMQ_CASE(GGML_TYPE_Q6_K);
4072extern DECL_MMQ_CASE(GGML_TYPE_IQ2_XXS);
4073extern DECL_MMQ_CASE(GGML_TYPE_IQ2_XS);
4074extern DECL_MMQ_CASE(GGML_TYPE_IQ2_S);
4075extern DECL_MMQ_CASE(GGML_TYPE_IQ3_XXS);
4076extern DECL_MMQ_CASE(GGML_TYPE_IQ3_S);
4077extern DECL_MMQ_CASE(GGML_TYPE_IQ1_S);
4078extern DECL_MMQ_CASE(GGML_TYPE_IQ4_NL);
4079extern DECL_MMQ_CASE(GGML_TYPE_IQ4_XS);
4080
4081// -------------------------------------------------------------------------------------------------------------------------
4082
4083void ggml_cuda_mul_mat_q(
4084 ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
4085
4086void ggml_cuda_op_mul_mat_q(
4087 ggml_backend_cuda_context & ctx,
4088 const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
4089 const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
4090 const int64_t src1_padded_row_size, cudaStream_t stream);
4091
4092bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t n_experts);