#include "mmvq.cuh"
#include "quantize.cuh"
#include "unary.cuh"
#include "vecdotq.cuh"

#include <cstdint>

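// Signature shared by the type-specific vec_dot functions: each computes the dot product of
// one quantized block of x (selected via the block index kbx and the quant index iqs)
// with the corresponding q8_1 data of y.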
typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs);

static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) {
    switch (type) {
        case GGML_TYPE_Q4_0:    return vec_dot_q4_0_q8_1;
        case GGML_TYPE_Q4_1:    return vec_dot_q4_1_q8_1;
        case GGML_TYPE_Q5_0:    return vec_dot_q5_0_q8_1;
        case GGML_TYPE_Q5_1:    return vec_dot_q5_1_q8_1;
        case GGML_TYPE_Q8_0:    return vec_dot_q8_0_q8_1;
        case GGML_TYPE_MXFP4:   return vec_dot_mxfp4_q8_1;
        case GGML_TYPE_Q2_K:    return vec_dot_q2_K_q8_1;
        case GGML_TYPE_Q3_K:    return vec_dot_q3_K_q8_1;
        case GGML_TYPE_Q4_K:    return vec_dot_q4_K_q8_1;
        case GGML_TYPE_Q5_K:    return vec_dot_q5_K_q8_1;
        case GGML_TYPE_Q6_K:    return vec_dot_q6_K_q8_1;
        case GGML_TYPE_IQ2_XXS: return vec_dot_iq2_xxs_q8_1;
        case GGML_TYPE_IQ2_XS:  return vec_dot_iq2_xs_q8_1;
        case GGML_TYPE_IQ2_S:   return vec_dot_iq2_s_q8_1;
        case GGML_TYPE_IQ3_XXS: return vec_dot_iq3_xxs_q8_1;
        case GGML_TYPE_IQ1_S:   return vec_dot_iq1_s_q8_1;
        case GGML_TYPE_IQ1_M:   return vec_dot_iq1_m_q8_1;
        case GGML_TYPE_IQ4_NL:  return vec_dot_iq4_nl_q8_1;
        case GGML_TYPE_IQ4_XS:  return vec_dot_iq4_xs_q8_1;
        case GGML_TYPE_IQ3_S:   return vec_dot_iq3_s_q8_1;
        default:                return nullptr;
    }
}

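// VDR = vec dot ratio: how many contiguous integers each thread processes per vec_dot call
// (see the VDR_*_MMVQ constants in vecdotq.cuh).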
static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
    switch (type) {
        case GGML_TYPE_Q4_0:    return VDR_Q4_0_Q8_1_MMVQ;
        case GGML_TYPE_Q4_1:    return VDR_Q4_1_Q8_1_MMVQ;
        case GGML_TYPE_Q5_0:    return VDR_Q5_0_Q8_1_MMVQ;
        case GGML_TYPE_Q5_1:    return VDR_Q5_1_Q8_1_MMVQ;
        case GGML_TYPE_Q8_0:    return VDR_Q8_0_Q8_1_MMVQ;
        case GGML_TYPE_MXFP4:   return VDR_MXFP4_Q8_1_MMVQ;
        case GGML_TYPE_Q2_K:    return VDR_Q2_K_Q8_1_MMVQ;
        case GGML_TYPE_Q3_K:    return VDR_Q3_K_Q8_1_MMVQ;
        case GGML_TYPE_Q4_K:    return VDR_Q4_K_Q8_1_MMVQ;
        case GGML_TYPE_Q5_K:    return VDR_Q5_K_Q8_1_MMVQ;
        case GGML_TYPE_Q6_K:    return VDR_Q6_K_Q8_1_MMVQ;
        case GGML_TYPE_IQ2_XXS: return VDR_IQ2_XXS_Q8_1_MMVQ;
        case GGML_TYPE_IQ2_XS:  return VDR_IQ2_XS_Q8_1_MMVQ;
        case GGML_TYPE_IQ2_S:   return VDR_IQ2_S_Q8_1_MMVQ;
        case GGML_TYPE_IQ3_XXS: return VDR_IQ3_XXS_Q8_1_MMVQ;
        case GGML_TYPE_IQ3_S:   return VDR_IQ3_S_Q8_1_MMVQ;
        case GGML_TYPE_IQ4_NL:  return VDR_IQ4_NL_Q8_1_MMVQ;
        case GGML_TYPE_IQ4_XS:  return VDR_IQ4_XS_Q8_1_MMVQ;
        default:                return 1;
    }
}

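// The kernel launch parameters (warps per block, rows per block) are tuned per GPU
// architecture; the table id selects which tuning is used below.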
enum mmvq_parameter_table_id {
    MMVQ_PARAMETERS_GENERIC = 0,
    MMVQ_PARAMETERS_GCN,
    MMVQ_PARAMETERS_RDNA2
};

static constexpr __device__ mmvq_parameter_table_id get_device_table_id() {
#if defined(RDNA2) || defined(RDNA3) || defined(RDNA4)
    return MMVQ_PARAMETERS_RDNA2;
#elif defined(GCN) || defined(CDNA)
    return MMVQ_PARAMETERS_GCN;
#else
    return MMVQ_PARAMETERS_GENERIC;
#endif
}

static __host__ mmvq_parameter_table_id get_device_table_id(int cc) {
    if (GGML_CUDA_CC_IS_RDNA2(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
        return MMVQ_PARAMETERS_RDNA2;
    }
    if (GGML_CUDA_CC_IS_GCN(cc) || GGML_CUDA_CC_IS_CDNA(cc)) {
        return MMVQ_PARAMETERS_GCN;
    }
    return MMVQ_PARAMETERS_GENERIC;
}

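// Warps per thread block as a function of the number of destination columns:
// wider batches use fewer warps per block.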
static constexpr __host__ __device__ int calc_nwarps(int ncols_dst, mmvq_parameter_table_id table_id) {
    if (table_id == MMVQ_PARAMETERS_GENERIC) {
        switch (ncols_dst) {
            case 1:
            case 2:
            case 3:
            case 4:
                return 4;
            case 5:
            case 6:
            case 7:
            case 8:
                return 2;
            default:
                return 1;
        }
    } else if (table_id == MMVQ_PARAMETERS_GCN) {
        switch (ncols_dst) {
            case 1:
            case 2:
            case 3:
            case 4:
                return 2;
            case 5:
            case 6:
            case 7:
            case 8:
            default:
                return 1;
        }
    }
    return 1;
}

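// Matrix rows processed per thread block: batched cases (2 <= ncols_dst <= 8) process
// two rows per block on the GENERIC and GCN tables, all other cases process one row.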
static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, mmvq_parameter_table_id table_id) {
    if (table_id == MMVQ_PARAMETERS_GENERIC || table_id == MMVQ_PARAMETERS_GCN) {
        switch (ncols_dst) {
            case 1:
                return 1;
            case 2:
            case 3:
            case 4:
            case 5:
            case 6:
            case 7:
            case 8:
                return 2;
            default:
                return 1;
        }
    }
    return 1;
}

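// Template parameters:
//   type              - quantized ggml type of the x matrix (and of the fused gate, if any)
//   ncols_dst         - number of destination columns computed per thread block
//   has_fusion        - whether the fused bias / gated-activation arguments are used
//   is_multi_token_id - multi-token MUL_MAT_ID path, one token per blockIdx.z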
template <ggml_type type, int ncols_dst, bool has_fusion, bool is_multi_token_id = false>
__launch_bounds__(calc_nwarps(ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
static __global__ void mul_mat_vec_q(
        const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst,
        const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
        const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x,
        const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio,
        const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst,
        const uint32_t ids_stride) {

    constexpr int qk  = ggml_cuda_type_traits<type>::qk;
    constexpr int qi  = ggml_cuda_type_traits<type>::qi;
    constexpr int vdr = get_vdr_mmvq(type);
    constexpr mmvq_parameter_table_id table_id = get_device_table_id();
    constexpr int nwarps              = calc_nwarps(ncols_dst, table_id);
    constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_dst, table_id);
    constexpr int warp_size           = ggml_cuda_get_physical_warp_size();

    constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);

    const int tid  = warp_size*threadIdx.y + threadIdx.x;
    const int row0 = rows_per_cuda_block*blockIdx.x;
    const int blocks_per_row_x = ncols_x / qk;
    constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi;

    const uint32_t channel_dst = blockIdx.y;

    uint32_t token_idx = 0;
    uint32_t channel_x;
    uint32_t channel_y;
    uint32_t sample_dst;

    if constexpr (is_multi_token_id) {
        // multi-token MUL_MAT_ID path; computing these in the normal path as well
        // causes a performance regression for the n_tokens == 1 case
        token_idx  = blockIdx.z;
        channel_x  = ids[channel_dst + token_idx*ids_stride];
        channel_y  = fastmodulo(channel_dst, nchannels_y);
        sample_dst = 0;
    } else {
        channel_x  = ncols_dst == 1 && ids ? ids[channel_dst] : fastdiv(channel_dst, channel_ratio);
        channel_y  = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst;
        sample_dst = blockIdx.z;
    }

    const uint32_t sample_x = fastdiv(sample_dst, sample_ratio);
    const uint32_t sample_y = sample_dst;

    bool use_gate      = false;
    bool use_bias      = false;
    bool use_gate_bias = false;
    const void  * vgate     = nullptr;
    const float * x_bias    = nullptr;
    const float * gate_bias = nullptr;
    ggml_glu_op active_glu;

    if constexpr (has_fusion) {
        use_gate      = fusion.gate != nullptr;
        use_bias      = fusion.x_bias != nullptr;
        use_gate_bias = fusion.gate_bias != nullptr && use_gate;
        vgate         = fusion.gate;
        x_bias        = (const float *) fusion.x_bias;
        gate_bias     = (const float *) fusion.gate_bias;
        active_glu    = fusion.glu_op;
    }

    float x_biases[ncols_dst]    = { 0.0f };
    float gate_biases[ncols_dst] = { 0.0f };
    if constexpr (has_fusion) {
        const uint32_t channel_bias = ids ? channel_x : channel_dst;
        if (use_bias) {
            x_bias = x_bias + sample_dst*stride_sample_dst + channel_bias*stride_channel_dst + row0;
            // 1. hide the latency by prefetching the bias and gate here
            // 2. load only on the threads that remain active after the partial-sum calculation
            if (threadIdx.x < rows_per_cuda_block && threadIdx.y == 0 &&
                    (rows_per_cuda_block == 1 || uint32_t(row0 + threadIdx.x) < stride_col_dst)) {
#pragma unroll
                for (int j = 0; j < ncols_dst; ++j) {
                    x_biases[j] = x_bias[j*stride_col_dst + threadIdx.x];
                }
            }
        }
        if (use_gate_bias) {
            gate_bias = gate_bias + sample_dst*stride_sample_dst + channel_bias*stride_channel_dst + row0;
            if (threadIdx.x < rows_per_cuda_block && threadIdx.y == 0 &&
                    (rows_per_cuda_block == 1 || uint32_t(row0 + threadIdx.x) < stride_col_dst)) {
#pragma unroll
                for (int j = 0; j < ncols_dst; ++j) {
                    gate_biases[j] = gate_bias[j*stride_col_dst + threadIdx.x];
                }
            }
        }
    }

    // partial sums for each thread
    float tmp[ncols_dst][rows_per_cuda_block]      = {{0.0f}};
    float tmp_gate[ncols_dst][rows_per_cuda_block] = {{0.0f}};

    const block_q8_1 * y = ((const block_q8_1 *) vy) + sample_y*stride_sample_y + channel_y*stride_channel_y;
    if constexpr (is_multi_token_id) {
        y += token_idx*stride_col_y;
    }
    const int kbx_offset = sample_x*stride_sample_x + channel_x*stride_channel_x + row0*stride_row_x;

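    // Each group of qi/vdr consecutive threads covers one block of x; the thread block
    // as a whole strides over the row in steps of blocks_per_iter blocks.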
    for (int kbx = tid / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) {
        const int kby = kbx * (qk/QK8_1); // y block index that aligns with kbx

        // x block quant index when casting the quants to int
        const int kqs = vdr * (tid % (qi/vdr));

#pragma unroll
        for (int j = 0; j < ncols_dst; ++j) {
#pragma unroll
            for (int i = 0; i < rows_per_cuda_block; ++i) {
                tmp[j][i] += vec_dot_q_cuda(
                    vx, &y[j*stride_col_y + kby], kbx_offset + i*stride_row_x + kbx, kqs);
                if constexpr (has_fusion) {
                    if (use_gate) {
                        tmp_gate[j][i] += vec_dot_q_cuda(
                            vgate, &y[j*stride_col_y + kby], kbx_offset + i*stride_row_x + kbx, kqs);
                    }
                }
            }
        }
    }

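    // Warps 1..nwarps-1 stage their partial sums in shared memory; warp 0 then reduces
    // across warps and across lanes before writing the result.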
    __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size];
    __shared__ float tmp_shared_gate[(has_fusion && (nwarps-1 > 0)) ? nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size];
    if constexpr (!has_fusion) {
        (void) tmp_shared_gate;
    } else if (!use_gate) {
        (void) tmp_shared_gate;
    }

    if (threadIdx.y > 0) {
#pragma unroll
        for (int j = 0; j < ncols_dst; ++j) {
#pragma unroll
            for (int i = 0; i < rows_per_cuda_block; ++i) {
                tmp_shared[threadIdx.y-1][j][i][threadIdx.x] = tmp[j][i];
                if constexpr (has_fusion) {
                    if (use_gate) {
                        tmp_shared_gate[threadIdx.y-1][j][i][threadIdx.x] = tmp_gate[j][i];
                    }
                }
            }
        }
    }
    __syncthreads();
    if (threadIdx.y > 0) {
        return;
    }

    dst += sample_dst*stride_sample_dst + channel_dst*stride_channel_dst + row0;

    if constexpr (is_multi_token_id) {
        dst += token_idx*stride_col_dst;
    }

    // sum up partial sums and write back result
#pragma unroll
    for (int j = 0; j < ncols_dst; ++j) {
#pragma unroll
        for (int i = 0; i < rows_per_cuda_block; ++i) {
#pragma unroll
            for (int l = 0; l < nwarps-1; ++l) {
                tmp[j][i] += tmp_shared[l][j][i][threadIdx.x];
                if constexpr (has_fusion) {
                    if (use_gate) {
                        tmp_gate[j][i] += tmp_shared_gate[l][j][i][threadIdx.x];
                    }
                }
            }
            tmp[j][i] = warp_reduce_sum<warp_size>(tmp[j][i]);
            if constexpr (has_fusion) {
                if (use_gate) {
                    tmp_gate[j][i] = warp_reduce_sum<warp_size>(tmp_gate[j][i]);
                }
            }
        }

        if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || uint32_t(row0 + threadIdx.x) < stride_col_dst)) {
            float result = tmp[j][threadIdx.x];
            if constexpr (has_fusion) {
                if (use_bias) {
                    result += x_biases[j];
                }
                if (use_gate) {
                    float gate_value = tmp_gate[j][threadIdx.x];
                    if (use_gate_bias) {
                        gate_value += gate_biases[j];
                    }
                    switch (active_glu) {
                        case GGML_GLU_OP_SWIGLU:
                            result *= ggml_cuda_op_silu_single(gate_value);
                            break;
                        case GGML_GLU_OP_GEGLU:
                            result *= ggml_cuda_op_gelu_single(gate_value);
                            break;
                        case GGML_GLU_OP_SWIGLU_OAI:
                            result = ggml_cuda_op_swiglu_oai_single(gate_value, result);
                            break;
                        default:
                            result = result*gate_value;
                            break;
                    }
                }
            }
            dst[j*stride_col_dst + threadIdx.x] = result;
        }
    }

    if constexpr (!has_fusion) {
        GGML_UNUSED_VARS(use_gate, use_bias, use_gate_bias, active_glu, gate_bias, x_bias, tmp_gate);
    }
}

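// Grid layout: x = group of rows, y = destination channel, z = sample
// (or token for the multi-token MUL_MAT_ID path).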
static std::pair<dim3, dim3> calc_launch_params(
        const int ncols_dst, const int nrows_x, const int nchannels_dst, const int nsamples_or_ntokens,
        const int warp_size, const mmvq_parameter_table_id table_id) {
    const int64_t nblocks = (nrows_x + calc_rows_per_block(ncols_dst, table_id) - 1) / calc_rows_per_block(ncols_dst, table_id);
    const dim3 block_nums(nblocks, nchannels_dst, nsamples_or_ntokens);
    const dim3 block_dims(warp_size, calc_nwarps(ncols_dst, table_id), 1);
    return {block_nums, block_dims};
}

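// Launches either the fused or the plain kernel variant; the fused variant is only
// instantiated for ncols_dst == 1.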
template <ggml_type type, int c_ncols_dst, bool is_multi_token_id = false>
static void mul_mat_vec_q_switch_fusion(
        const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
        const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
        const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x,
        const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio,
        const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst,
        const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared,
        const uint32_t ids_stride, cudaStream_t stream) {

    const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
    if constexpr (c_ncols_dst == 1) {
        if (has_fusion) {
            mul_mat_vec_q<type, c_ncols_dst, true, is_multi_token_id><<<block_nums, block_dims, nbytes_shared, stream>>>
                (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);
            return;
        }
    }

    GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1");

    mul_mat_vec_q<type, c_ncols_dst, false, is_multi_token_id><<<block_nums, block_dims, nbytes_shared, stream>>>
        (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
         channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
         sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);
}

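// Maps the runtime batch size ncols_dst (at most MMVQ_MAX_BATCH_SIZE) onto a compile-time
// template argument; the multi-token MUL_MAT_ID case instead launches one token per grid z slice.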
template <ggml_type type>
static void mul_mat_vec_q_switch_ncols_dst(
        const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
        const int ncols_x, const int nrows_x, const int ncols_dst,
        const int stride_row_x, const int stride_col_y, const int stride_col_dst,
        const int nchannels_x, const int nchannels_y, const int nchannels_dst,
        const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
        const int nsamples_x, const int nsamples_dst, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
        const int ids_stride, cudaStream_t stream) {

    GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0);
    GGML_ASSERT(ncols_dst <= MMVQ_MAX_BATCH_SIZE);

    const uint3 nchannels_y_fd   = ids ? init_fastdiv_values(nchannels_y) : make_uint3(0, 0, 0);
    const uint3 channel_ratio_fd = ids ? make_uint3(0, 0, 0) : init_fastdiv_values(nchannels_dst / nchannels_x);
    const uint3 sample_ratio_fd  = init_fastdiv_values(nsamples_dst / nsamples_x);

    const int device    = ggml_cuda_get_device();
    const int warp_size = ggml_cuda_info().devices[device].warp_size;
    const mmvq_parameter_table_id table_id = get_device_table_id(ggml_cuda_info().devices[device].cc);

    const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
    const bool has_ids    = ids != nullptr;

    if (has_ids && ncols_dst > 1) {
        // multi-token MUL_MAT_ID path only - the single-token case goes through the regular path below
        constexpr int c_ncols_dst = 1;
        std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, ncols_dst, warp_size, table_id);
        mul_mat_vec_q_switch_fusion<type, c_ncols_dst, true>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
            channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
            sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
            dims.first, dims.second, 0, ids_stride, stream);
        return;
    }

    switch (ncols_dst) {
        case 1: {
            constexpr int c_ncols_dst = 1;
            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
            mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
                channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
                dims.first, dims.second, 0, ids_stride, stream);
        } break;
        case 2: {
            constexpr int c_ncols_dst = 2;
            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
            mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
                channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
                dims.first, dims.second, 0, ids_stride, stream);
        } break;
        case 3: {
            constexpr int c_ncols_dst = 3;
            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
            mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
                channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
                dims.first, dims.second, 0, ids_stride, stream);
        } break;
        case 4: {
            constexpr int c_ncols_dst = 4;
            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
            mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
                channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
                dims.first, dims.second, 0, ids_stride, stream);
        } break;
        case 5: {
            constexpr int c_ncols_dst = 5;
            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
            mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
                channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
                dims.first, dims.second, 0, ids_stride, stream);
        } break;
        case 6: {
            constexpr int c_ncols_dst = 6;
            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
            mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
                channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
                dims.first, dims.second, 0, ids_stride, stream);
        } break;
        case 7: {
            constexpr int c_ncols_dst = 7;
            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
            mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
                channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
                dims.first, dims.second, 0, ids_stride, stream);
        } break;
        case 8: {
            constexpr int c_ncols_dst = 8;
            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
            mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
                channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
                dims.first, dims.second, 0, ids_stride, stream);
        } break;
        default:
            GGML_ABORT("fatal error");
            break;
    }

    GGML_UNUSED(has_fusion);
}
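
// Outer dispatch on the quantized type of the x matrix.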
static void mul_mat_vec_q_switch_type(
        const void * vx, const ggml_type type_x, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
        const int ncols_x, const int nrows_x, const int ncols_dst,
        const int stride_row_x, const int stride_col_y, const int stride_col_dst,
        const int nchannels_x, const int nchannels_y, const int nchannels_dst,
        const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
        const int nsamples_x, const int nsamples_dst, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
        const int ids_stride, cudaStream_t stream) {
    switch (type_x) {
        case GGML_TYPE_Q4_0:
            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_0>
                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
            break;
        case GGML_TYPE_Q4_1:
            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_1>
                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
            break;
        case GGML_TYPE_Q5_0:
            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_0>
                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
            break;
        case GGML_TYPE_Q5_1:
            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_1>
                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
            break;
        case GGML_TYPE_Q8_0:
            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q8_0>
                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
            break;
        case GGML_TYPE_MXFP4:
            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_MXFP4>
                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
            break;
        case GGML_TYPE_Q2_K:
            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q2_K>
                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
            break;
        case GGML_TYPE_Q3_K:
            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q3_K>
                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
            break;
        case GGML_TYPE_Q4_K:
            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_K>
                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
            break;
        case GGML_TYPE_Q5_K:
            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_K>
                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
            break;
        case GGML_TYPE_Q6_K:
            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q6_K>
                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
            break;
        case GGML_TYPE_IQ2_XXS:
            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_XXS>
                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
            break;
        case GGML_TYPE_IQ2_XS:
            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_XS>
                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
            break;
        case GGML_TYPE_IQ2_S:
            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_S>
                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
            break;
        case GGML_TYPE_IQ3_XXS:
            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ3_XXS>
                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
            break;
        case GGML_TYPE_IQ1_S:
            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ1_S>
                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
            break;
        case GGML_TYPE_IQ1_M:
            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ1_M>
                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
            break;
        case GGML_TYPE_IQ4_NL:
            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ4_NL>
                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
            break;
        case GGML_TYPE_IQ4_XS:
            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ4_XS>
                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
            break;
        case GGML_TYPE_IQ3_S:
            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ3_S>
                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
            break;
        default:
            GGML_ABORT("fatal error");
            break;
    }
}

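// Entry point for MUL_MAT / MUL_MAT_ID on quantized tensors: quantizes src1 to q8_1 on the
// fly, then dispatches on src0->type. fusion optionally carries the bias and gate tensors
// of a fused matmul + GLU.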
void ggml_cuda_mul_mat_vec_q(
        ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst,
        const ggml_cuda_mm_fusion_args_host * fusion) {
    GGML_ASSERT(        src1->type == GGML_TYPE_F32);
    GGML_ASSERT(         dst->type == GGML_TYPE_F32);
    GGML_ASSERT(!ids ||  ids->type == GGML_TYPE_I32); // optional, used for batched GGML_MUL_MAT_ID

    GGML_TENSOR_BINARY_OP_LOCALS;

    cudaStream_t stream = ctx.stream();

    const size_t ts_src0 = ggml_type_size(src0->type);
    const size_t ts_src1 = ggml_type_size(src1->type);
    const size_t ts_dst  = ggml_type_size(dst->type);

    GGML_ASSERT(        nb00       == ts_src0);
    GGML_ASSERT(        nb10       == ts_src1);
    GGML_ASSERT(        nb0        == ts_dst);
    GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type));

    GGML_ASSERT(!ids || ne12 <= MMVQ_MAX_BATCH_SIZE);

    const float   * src1_d = (const float   *) src1->data;
    const int32_t * ids_d  = ids ? (const int32_t *) ids->data : nullptr;
    float         * dst_d  = (float         *) dst->data;

    ggml_cuda_mm_fusion_args_device fusion_local{};

    if (fusion) {
        GGML_ASSERT(!ids || dst->ne[2] == 1);
        GGML_ASSERT( ids || dst->ne[1] == 1);

        if (fusion->x_bias) {
            GGML_ASSERT(fusion->x_bias->type == GGML_TYPE_F32);
            GGML_ASSERT(fusion->x_bias->ne[0] == dst->ne[0]);
            GGML_ASSERT(!ids || fusion->x_bias->ne[1] == src0->ne[2]);
            fusion_local.x_bias = fusion->x_bias->data;
        }
        if (fusion->gate) {
            GGML_ASSERT(fusion->gate->type == src0->type && ggml_are_same_stride(fusion->gate, src0));
            fusion_local.gate = fusion->gate->data;
        }
        if (fusion->gate_bias) {
            GGML_ASSERT(fusion->gate_bias->type == GGML_TYPE_F32);
            GGML_ASSERT(fusion->gate_bias->ne[0] == dst->ne[0]);
            GGML_ASSERT(!ids || fusion->gate_bias->ne[1] == src0->ne[2]);
            fusion_local.gate_bias = fusion->gate_bias->data;
        }
        fusion_local.glu_op = fusion->glu_op;
    }

    // If src0 is a temporary compute buffer, clear any potential padding.
    if (ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE) {
        const size_t size_data  = ggml_nbytes(src0);
        const size_t size_alloc = ggml_backend_buffer_get_alloc_size(src0->buffer, src0);
        if (size_alloc > size_data) {
            GGML_ASSERT(ggml_is_contiguously_allocated(src0));
            GGML_ASSERT(!src0->view_src);
            CUDA_CHECK(cudaMemsetAsync((char *) src0->data + size_data, 0, size_alloc - size_data, stream));
        }
    }

    const int64_t ne10_padded = GGML_PAD(ne10, MATRIX_ROW_PADDING);
    ggml_cuda_pool_alloc<char> src1_q8_1(ctx.pool(), ne13*ne12 * ne11*ne10_padded * sizeof(block_q8_1)/QK8_1);
    {
        const int64_t s11 = src1->nb[1] / ts_src1;
        const int64_t s12 = src1->nb[2] / ts_src1;
        const int64_t s13 = src1->nb[3] / ts_src1;
        quantize_row_q8_1_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded, ne11, ne12, ne13, stream);
    }

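    // The strides below are in units of elements / quantized blocks rather than bytes.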
    const int64_t s01 = src0->nb[1] / ts_src0;
    const int64_t s11 = ne10_padded / QK8_1;
    const int64_t s1  =  dst->nb[1] / ts_dst;
    const int64_t s02 = src0->nb[2] / ts_src0;
    const int64_t s2  =  dst->nb[2] / ts_dst;
    const int64_t s03 = src0->nb[3] / ts_src0;
    const int64_t s3  =  dst->nb[3] / ts_dst;

    const int64_t s12 = ne11*s11;
    const int64_t s13 = ne12*s12;

    // For MUL_MAT_ID the memory layout is different than for MUL_MAT:
    const int64_t ncols_dst          = ids ? ne2  : ne1;
    const int64_t nchannels_y        = ids ? ne11 : ne12;
    const int64_t nchannels_dst      = ids ? ne1  : ne2;
    const int64_t stride_col_dst     = ids ? s2   : s1;
    const int64_t stride_col_y       = ids ? s12  : s11;
    const int64_t stride_channel_dst = ids ? s1   : s2;
    const int64_t stride_channel_y   = ids ? s11  : s12;

    const int64_t ids_stride = ids ? ids->nb[1] / ggml_type_size(ids->type) : 0;

    mul_mat_vec_q_switch_type(
        src0->data, src0->type, src1_q8_1.get(), ids_d, fusion_local, dst_d, ne00,
        ne01, ncols_dst, s01, stride_col_y, stride_col_dst,
        ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
        ne03, ne3, s03, s13, s3, ids_stride, stream);
}

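// Variant used by the generic / split-buffer MUL_MAT path: src1 has already been quantized
// to q8_1 and only rows [row_low, row_high) of src0 are processed here.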
void ggml_cuda_op_mul_mat_vec_q(
        ggml_backend_cuda_context & ctx,
        const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
        const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
        const int64_t src1_padded_row_size, cudaStream_t stream) {

    const int64_t ne00 = src0->ne[0];
    const int64_t row_diff = row_high - row_low;

    const int64_t ne10 = src1->ne[0];
    GGML_ASSERT(ne10 % QK8_1 == 0);

    const int64_t ne0 = dst->ne[0];

    int id = ggml_cuda_get_device();

    // the main device has a larger memory buffer to hold the results from all GPUs
    // nrows_dst == nrows of the matrix that the kernel writes into
    const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff;

    const int stride_row_x = ne00 / ggml_blck_size(src0->type);
    const int stride_col_y = src1_padded_row_size / QK8_1;

    ggml_cuda_mm_fusion_args_device fusion_local{};
    mul_mat_vec_q_switch_type(
        src0_dd_i, src0->type, src1_ddq_i, nullptr, fusion_local, dst_dd_i, ne00, row_diff, src1_ncols, stride_row_x, stride_col_y, nrows_dst,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, stream);

    GGML_UNUSED_VARS(src1, dst, src1_ddf_i, src1_ncols, src1_padded_row_size);
}