#include "ggml.h"
#include "common.cuh"
#include "unary.cuh"
#include "mmvf.cuh"
#include "convert.cuh"

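// Matrix-vector (and small-batch) multiplication kernels for F32/F16/BF16 weights
// ("mmvf"), covering both MUL_MAT and MUL_MAT_ID, with an optional fused bias and
// GLU gating epilogue (see ggml_cuda_mm_fusion_args_host/_device).
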
template <typename T, typename type_acc, int ncols_dst, int block_size, bool has_fusion = false, bool is_multi_token_id = false>
static __global__ void mul_mat_vec_f(
        const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst,
        const int ncols2, const uint3 nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst,
        const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
        const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
        const int ids_stride) {
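    // One thread block computes one row of dst; blockIdx.y/z select the dst channel and
    // sample for regular MUL_MAT, or the used expert and token for MUL_MAT_ID.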
    const int row = blockIdx.x;
    // for MUL_MAT_ID - blockIdx.y = n_expert_used, blockIdx.z = ncols_dst (tokens)
    const int channel_dst = blockIdx.y;
    const int tid = threadIdx.x;

    int token_idx;
    int channel_x;
    int channel_y;
    int sample_dst;

    if constexpr (is_multi_token_id) {
        // Multi-token MUL_MAT_ID path; doing these lookups in the normal path causes a perf regression for the n_tokens=1 case.
        token_idx = blockIdx.z;
        channel_x = ids[channel_dst + token_idx * ids_stride];
        channel_y = fastmodulo(channel_dst, nchannels_y);
        sample_dst = 0;
    } else {
        token_idx = ids ? blockIdx.z : 0;
        channel_x = ids ? ids[blockIdx.y + token_idx * ids_stride] : fastdiv((uint32_t) channel_dst, channel_ratio);
        channel_y = ids ? fastmodulo(blockIdx.y, nchannels_y) : channel_dst;
        sample_dst = ids ? 0 : blockIdx.z;
    }

    const int sample_x = fastdiv((uint32_t) sample_dst, sample_ratio);
    const int sample_y = sample_dst;

    constexpr int warp_size = ggml_cuda_get_physical_warp_size();

    x   += int64_t(sample_x) *stride_sample_x   + channel_x  *stride_channel_x   + row*stride_row;
    y   += int64_t(sample_y) *stride_sample_y   + channel_y  *stride_channel_y;
    dst += int64_t(sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst;
    if constexpr (is_multi_token_id) {
        y   += token_idx*stride_col_y2*2;
        dst += token_idx*stride_col_dst;
    }

    bool use_gate = false;
    bool use_bias = false;
    bool use_gate_bias = false;
    ggml_glu_op glu_op = ggml_glu_op::GGML_GLU_OP_SWIGLU;
    const T * gate_x = nullptr;
    const float * x_bias = nullptr;
    const float * gate_bias = nullptr;

    if constexpr (has_fusion) {
        use_gate = fusion.gate != nullptr;
        use_bias = fusion.x_bias != nullptr;
        use_gate_bias = fusion.gate_bias != nullptr;
        glu_op = fusion.glu_op;

        if (use_gate) {
            gate_x = static_cast<const T *>(fusion.gate);
        }
        if (use_bias) {
            x_bias = static_cast<const float *>(fusion.x_bias);
        }
        if (use_gate_bias) {
            gate_bias = static_cast<const float *>(fusion.gate_bias);
            use_gate_bias = use_gate; // the gate bias is only applied if the gate itself is fused
        }
    }

    if (use_gate) {
        gate_x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row*stride_row;
    }

    const int channel_bias = ids ? channel_x : channel_dst;

    if constexpr (has_fusion) {
        if (use_bias) {
            x_bias += int64_t(sample_dst)*stride_sample_dst + channel_bias*stride_channel_dst;
        }
        if (use_gate_bias) {
            gate_bias += int64_t(sample_dst)*stride_sample_dst + channel_bias*stride_channel_dst;
        }
    }

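    // x and y are consumed two values at a time (float2/half2/nv_bfloat162), which is why
    // the launcher passes ncols2 = ncols/2 and stride_col_y2 = stride_col_y/2.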
    const float2 * y2 = (const float2 *) y;

    extern __shared__ char data_mmv[];
    float * buf_iw = (float *) data_mmv;
    float * buf_iw_gate = nullptr;
    if constexpr (has_fusion) {
        buf_iw_gate = (float *) (data_mmv + warp_size*sizeof(float));
    }

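    // For blocks with more than one warp the per-warp partial sums are staged in shared
    // memory, with a second buffer for the gate matrix when fusion is active.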
    if (block_size > warp_size) {
        if (tid < warp_size) {
            buf_iw[tid] = 0.0f;
            if constexpr (has_fusion) {
                if (use_gate) {
                    buf_iw_gate[tid] = 0.0f;
                }
            }
        }
        __syncthreads();
    }

    float sumf[ncols_dst] = {0.0f};
    float sumf_gate[ncols_dst];
    if constexpr (has_fusion) {
#pragma unroll
        for (int j = 0; j < ncols_dst; ++j) {
            sumf_gate[j] = 0.0f;
        }
    }

    if constexpr (std::is_same_v<T, float>) {
        const float2 * x2 = (const float2 *) x;
        const float2 * gate_x2 = nullptr;
        if constexpr (has_fusion) {
            if (use_gate) {
                gate_x2 = (const float2 *) gate_x;
            }
        }

        for (int col2 = tid; col2 < ncols2; col2 += block_size) {
            const float2 tmpx = x2[col2];
            float2 tmpx_gate = make_float2(0.0f, 0.0f);
            if constexpr (has_fusion) {
                if (use_gate) {
                    tmpx_gate = gate_x2[col2];
                }
            }

#pragma unroll
            for (int j = 0; j < ncols_dst; ++j) {
                const float2 tmpy = y2[j*stride_col_y2 + col2];
                ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
                ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);

                if constexpr (has_fusion) {
                    if (use_gate) {
                        ggml_cuda_mad(sumf_gate[j], tmpx_gate.x, tmpy.x);
                        ggml_cuda_mad(sumf_gate[j], tmpx_gate.y, tmpy.y);
                    }
                }
            }
        }
    } else if constexpr (std::is_same_v<T, half>) {
        const half2 * x2 = (const half2 *) x;
        const half2 * gate_x2 = nullptr;
        if constexpr (has_fusion) {
            if (use_gate) {
                gate_x2 = (const half2 *) gate_x;
            }
        }

        if constexpr (std::is_same_v<type_acc, float>) {
            for (int col2 = tid; col2 < ncols2; col2 += block_size) {
                const float2 tmpx = __half22float2(x2[col2]);
                float2 tmpx_gate = make_float2(0.0f, 0.0f);
                if constexpr (has_fusion) {
                    if (use_gate) {
                        tmpx_gate = __half22float2(gate_x2[col2]);
                    }
                }
#pragma unroll
                for (int j = 0; j < ncols_dst; ++j) {
                    const float2 tmpy = y2[j*stride_col_y2 + col2];
                    ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
                    ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);

                    if constexpr (has_fusion) {
                        if (use_gate) {
                            ggml_cuda_mad(sumf_gate[j], tmpx_gate.x, tmpy.x);
                            ggml_cuda_mad(sumf_gate[j], tmpx_gate.y, tmpy.y);
                        }
                    }
                }
            }
        } else {
#ifdef FP16_AVAILABLE
            half2 sumh2[ncols_dst] = {{0.0f, 0.0f}};
            half2 sumh2_gate[ncols_dst] = {{0.0f, 0.0f}};

            for (int col2 = tid; col2 < ncols2; col2 += block_size) {
                const half2 tmpx = x2[col2];
                half2 tmpx_gate = make_half2(0.0f, 0.0f);
                if constexpr (has_fusion) {
                    if (use_gate) {
                        tmpx_gate = gate_x2[col2];
                    }
                }
#pragma unroll
                for (int j = 0; j < ncols_dst; ++j) {
                    const float2 tmpy = y2[j*stride_col_y2 + col2];
                    sumh2[j] += tmpx * make_half2(tmpy.x, tmpy.y);

                    if constexpr (has_fusion) {
                        if (use_gate) {
                            sumh2_gate[j] += tmpx_gate * make_half2(tmpy.x, tmpy.y);
                        }
                    }
                }
            }

#pragma unroll
            for (int j = 0; j < ncols_dst; ++j) {
                sumf[j] = __low2float(sumh2[j]) + __high2float(sumh2[j]);
            }

            if constexpr (has_fusion) {
                if (use_gate) {
#pragma unroll
                    for (int j = 0; j < ncols_dst; ++j) {
                        sumf_gate[j] = __low2float(sumh2_gate[j]) + __high2float(sumh2_gate[j]);
                    }
                }
            }
#else
            NO_DEVICE_CODE;
#endif // FP16_AVAILABLE
        }
    } else if constexpr (std::is_same_v<T, nv_bfloat16>) {
//TODO: add support for ggml_cuda_mad for hip_bfloat162
#if defined(GGML_USE_HIP)
        const int * x2 = (const int *) x;
        const int * gate_x2 = nullptr;
        if constexpr (has_fusion) {
            if (use_gate) {
                gate_x2 = (const int *) gate_x;
            }
        }
        for (int col2 = tid; col2 < ncols2; col2 += block_size) {
            const int tmpx = x2[col2];
            int tmpx_gate = 0;
            if constexpr (has_fusion) {
                if (use_gate) {
                    tmpx_gate = gate_x2[col2];
                }
            }
#pragma unroll
            for (int j = 0; j < ncols_dst; ++j) {
                const float2 tmpy = y2[j*stride_col_y2 + col2];
                const float tmpx0 = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]);
                const float tmpx1 = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]);
                ggml_cuda_mad(sumf[j], tmpx0, tmpy.x);
                ggml_cuda_mad(sumf[j], tmpx1, tmpy.y);

                if constexpr (has_fusion) {
                    if (use_gate) {
                        const float tmpx0_gate = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx_gate)[0]);
                        const float tmpx1_gate = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx_gate)[1]);
                        ggml_cuda_mad(sumf_gate[j], tmpx0_gate, tmpy.x);
                        ggml_cuda_mad(sumf_gate[j], tmpx1_gate, tmpy.y);
                    }
                }
            }
        }
#else
        const nv_bfloat162 * x2 = (const nv_bfloat162 *) x;
        const nv_bfloat162 * gate_x2 = nullptr;
        if constexpr (has_fusion) {
            if (use_gate) {
                gate_x2 = (const nv_bfloat162 *) gate_x;
            }
        }
        for (int col2 = tid; col2 < ncols2; col2 += block_size) {
            const nv_bfloat162 tmpx = x2[col2];
            nv_bfloat162 tmpx_gate;
            if constexpr (has_fusion) {
                if (use_gate) {
                    tmpx_gate = gate_x2[col2];
                }
            }
#pragma unroll
            for (int j = 0; j < ncols_dst; ++j) {
                const float2 tmpy = y2[j*stride_col_y2 + col2];
                ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
                ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);

                if constexpr (has_fusion) {
                    if (use_gate) {
                        ggml_cuda_mad(sumf_gate[j], tmpx_gate.x, tmpy.x);
                        ggml_cuda_mad(sumf_gate[j], tmpx_gate.y, tmpy.y);
                    }
                }
            }
        }
#endif // defined(GGML_USE_HIP)
    } else {
        static_assert(std::is_same_v<T, void>, "unsupported type");
    }

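    // Reduce the partial sums: first within each warp, then across warps via the
    // shared-memory staging buffers.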
#pragma unroll
    for (int j = 0; j < ncols_dst; ++j) {
        sumf[j] = warp_reduce_sum<warp_size>(sumf[j]);

        if constexpr (has_fusion) {
            if (use_gate) {
                sumf_gate[j] = warp_reduce_sum<warp_size>(sumf_gate[j]);
            }
        }

        if (block_size > warp_size) {
            buf_iw[tid/warp_size] = sumf[j];
            if constexpr (has_fusion) {
                if (use_gate) {
                    buf_iw_gate[tid/warp_size] = sumf_gate[j];
                }
            }
            __syncthreads();
            if (tid < warp_size) {
                sumf[j] = buf_iw[tid];
                sumf[j] = warp_reduce_sum<warp_size>(sumf[j]);
                if constexpr (has_fusion) {
                    if (use_gate) {
                        sumf_gate[j] = buf_iw_gate[tid];
                        sumf_gate[j] = warp_reduce_sum<warp_size>(sumf_gate[j]);
                    }
                }
            }

            if (j < ncols_dst) {
                __syncthreads();
            }
        }
    }

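    // The first ncols_dst threads each write one output column, applying the fused bias
    // and GLU activation to the final scalar if requested.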
    if (tid >= ncols_dst) {
        return;
    }

    float value = sumf[tid];

    if constexpr (has_fusion) {
        if (use_bias) {
            value += x_bias[tid*stride_col_dst + row];
        }

        if (use_gate) {
            float gate_value = sumf_gate[tid];
            if (use_gate_bias) {
                gate_value += gate_bias[tid*stride_col_dst + row];
            }
            switch (glu_op) {
                case GGML_GLU_OP_SWIGLU:
                    value *= ggml_cuda_op_silu_single(gate_value);
                    break;
                case GGML_GLU_OP_GEGLU:
                    value *= ggml_cuda_op_gelu_single(gate_value);
                    break;
                case GGML_GLU_OP_SWIGLU_OAI: {
                    value = ggml_cuda_op_swiglu_oai_single(gate_value, value);
                    break;
                }
                default:
                    break;
            }
        }
    }

    dst[tid*stride_col_dst + row] = value;

    if constexpr (!has_fusion) {
        GGML_UNUSED_VARS(use_gate, use_bias, use_gate_bias, glu_op, gate_x, x_bias, gate_bias, sumf_gate);
    }
}

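// Dispatches to the fused or non-fused kernel instantiation. The fused variant is only
// compiled for ncols_dst == 1, which keeps the number of template instantiations down.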
template<typename T, typename type_acc, int ncols_dst, int block_size, bool is_multi_token_id = false>
static void mul_mat_vec_f_switch_fusion(
        const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
        const int64_t ncols, const uint3 nchannels_y,
        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
        const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
        const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
        const dim3 & block_dims, const dim3 & block_nums, const int nbytes_shared, const int ids_stride, const cudaStream_t stream) {

    const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
    if constexpr (ncols_dst == 1) {
        if (has_fusion) {
            mul_mat_vec_f<T, type_acc, ncols_dst, block_size, true, is_multi_token_id><<<block_nums, block_dims, nbytes_shared, stream>>>
                (x, y, ids, fusion, dst, ncols, nchannels_y, stride_row, stride_col_y, stride_col_dst,
                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);
            return;
        }
    }

    GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1");

    mul_mat_vec_f<T, type_acc, ncols_dst, block_size, false, is_multi_token_id><<<block_nums, block_dims, nbytes_shared, stream>>>
        (x, y, ids, fusion, dst, ncols, nchannels_y, stride_row, stride_col_y, stride_col_dst,
         channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
         sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);
}

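// Sets up the launch configuration (fastdiv constants, block size, shared memory size)
// and instantiates the kernel for the chosen compile-time block size.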
template <typename T, typename type_acc, int ncols_dst, bool is_multi_token_id = false>
void launch_mul_mat_vec_f_cuda(
        const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
        const int64_t ncols, const int64_t nrows,
        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
        const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
        const int64_t nsamples_or_ntokens, const int64_t ids_stride, cudaStream_t stream) {
    GGML_ASSERT(ncols % 2 == 0);
    GGML_ASSERT(stride_row % 2 == 0);
    GGML_ASSERT(stride_col_y % 2 == 0);
    GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
    GGML_ASSERT( nsamples_dst % nsamples_x == 0);
    const uint3 nchannels_y_fd = ids ? init_fastdiv_values(nchannels_y) : make_uint3(0, 0, 0);
    const uint3 channel_ratio_fd = ids ? make_uint3(0, 0, 0) : init_fastdiv_values(nchannels_dst / nchannels_x);
    const uint3 sample_ratio_fd = init_fastdiv_values(nsamples_dst / nsamples_x);

    const int device = ggml_cuda_get_device();
    const int warp_size = ggml_cuda_info().devices[device].warp_size;

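    // Use the smallest block size (a multiple of warp_size, at most 256 threads, or 128
    // on AMD GPUs older than RDNA1) that minimizes the number of loop iterations per thread.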
    int64_t block_size_best = warp_size;
    int64_t niter_best = (ncols + 2*warp_size - 1) / (2*warp_size);
    int64_t max_block_size = 256;
    if (ggml_cuda_info().devices[device].cc > GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_info().devices[device].cc < GGML_CUDA_CC_RDNA1) {
        max_block_size = 128;
    }
    for (int64_t block_size = 2*warp_size; block_size <= max_block_size; block_size += warp_size) {
        const int64_t niter = (ncols + 2*block_size - 1) / (2*block_size);
        if (niter < niter_best) {
            niter_best = niter;
            block_size_best = block_size;
        }
    }

    const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;

    const int nbytes_shared = warp_size*sizeof(float) + (has_fusion ? warp_size*sizeof(float) : 0);
    const dim3 block_nums(nrows, nchannels_dst, nsamples_or_ntokens);
    const dim3 block_dims(block_size_best, 1, 1);
    switch (block_size_best) {
        case 32: {
            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 32, is_multi_token_id>
                (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
        } break;
        case 64: {
            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 64, is_multi_token_id>
                (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
        } break;
        case 96: {
            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 96, is_multi_token_id>
                (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
        } break;
        case 128: {
            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 128, is_multi_token_id>
                (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
        } break;
        case 160: {
            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 160, is_multi_token_id>
                (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
        } break;
        case 192: {
            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 192, is_multi_token_id>
                (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
        } break;
        case 224: {
            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 224, is_multi_token_id>
                (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
        } break;
        case 256: {
            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 256, is_multi_token_id>
                (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
        } break;
        default: {
            GGML_ABORT("fatal error");
        } break;
    }
}

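// Maps the runtime number of dst columns to a compile-time ncols_dst template argument.
// MUL_MAT_ID always uses ncols_dst == 1 per block; multiple tokens are instead handled
// via the z dimension of the grid.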
template <typename T, typename type_acc>
static void mul_mat_vec_f_cuda_switch_ncols_dst(
        const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
        const int64_t ncols, const int64_t nrows, const int64_t ncols_dst,
        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
        const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
        const int64_t ids_stride, cudaStream_t stream) {

    const bool has_ids = ids != nullptr;

    if (has_ids && ncols_dst > 1) {
        // Multi-token MUL_MAT_ID path only - single-token goes through regular path below
        constexpr int c_ncols_dst = 1;
        launch_mul_mat_vec_f_cuda<T, type_acc, c_ncols_dst, true>
            (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
             nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
             stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
             ncols_dst, ids_stride, stream);
        return;
    }

    if (has_ids) {
        // Single-token MUL_MAT_ID path
        constexpr int c_ncols_dst = 1;
        launch_mul_mat_vec_f_cuda<T, type_acc, c_ncols_dst>
            (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
             nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
             stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
             ncols_dst, ids_stride, stream);
        return;
    }

    switch (ncols_dst) {
        case 1:
            launch_mul_mat_vec_f_cuda<T, type_acc, 1>
                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
                 nsamples_dst, ids_stride, stream);
            break;
        case 2:
            launch_mul_mat_vec_f_cuda<T, type_acc, 2>
                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
                 nsamples_dst, ids_stride, stream);
            break;
        case 3:
            launch_mul_mat_vec_f_cuda<T, type_acc, 3>
                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
                 nsamples_dst, ids_stride, stream);
            break;
        case 4:
            launch_mul_mat_vec_f_cuda<T, type_acc, 4>
                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
                 nsamples_dst, ids_stride, stream);
            break;
        case 5:
            launch_mul_mat_vec_f_cuda<T, type_acc, 5>
                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
                 nsamples_dst, ids_stride, stream);
            break;
        case 6:
            launch_mul_mat_vec_f_cuda<T, type_acc, 6>
                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
                 nsamples_dst, ids_stride, stream);
            break;
        case 7:
            launch_mul_mat_vec_f_cuda<T, type_acc, 7>
                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
                 nsamples_dst, ids_stride, stream);
            break;
        case 8:
            launch_mul_mat_vec_f_cuda<T, type_acc, 8>
                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
                 nsamples_dst, ids_stride, stream);
            break;
        default:
            GGML_ABORT("fatal error");
            break;
    }
}

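// Selects the accumulator type: F16 accumulation is only used for F16 data at default
// precision, everything else accumulates in F32.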
template<typename T>
static void mul_mat_vec_f_cuda(
        const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
        const int64_t ncols, const int64_t nrows, const int64_t ncols_dst,
        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
        const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
        const int64_t ids_stride, enum ggml_prec prec, cudaStream_t stream) {

    if constexpr (std::is_same_v<T, half>) {
        if (prec == GGML_PREC_DEFAULT) {
            mul_mat_vec_f_cuda_switch_ncols_dst<T, half>
                (x, y, ids, fusion, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
            return;
        }
    }
    mul_mat_vec_f_cuda_switch_ncols_dst<T, float>
        (x, y, ids, fusion, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
         nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
         stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
}

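// Host entry point for MUL_MAT/MUL_MAT_ID on F32/F16/BF16 weights, including the
// optional fused bias/GLU epilogue passed in via `fusion`.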
void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst,
        const ggml_cuda_mm_fusion_args_host * fusion) {
    GGML_ASSERT( src1->type == GGML_TYPE_F32);
    GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    GGML_TENSOR_BINARY_OP_LOCALS;

    const size_t ts_src0 = ggml_type_size(src0->type);
    const size_t ts_src1 = ggml_type_size(src1->type);
    const size_t ts_dst = ggml_type_size(dst->type);

    GGML_ASSERT(!ids || ne12 <= MMVF_MAX_BATCH_SIZE);
    GGML_ASSERT(ne13 == ne3);

    GGML_ASSERT( nb00 == ts_src0);
    GGML_ASSERT( nb10 == ts_src1);
    GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type));
    GGML_ASSERT( nb0 == ts_dst);

    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
    const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;

    const float * src1_d = (const float *) src1->data;
    const int32_t * ids_d = ids ? (const int32_t *) ids->data : nullptr;
    float * dst_d = (float *) dst->data;

    ggml_cuda_mm_fusion_args_device fusion_local{};

    if (fusion) {
        GGML_ASSERT(!ids || dst->ne[2] == 1);
        GGML_ASSERT( ids || dst->ne[1] == 1);
        if (fusion->x_bias) {
            GGML_ASSERT(fusion->x_bias->type == GGML_TYPE_F32);
            GGML_ASSERT(fusion->x_bias->ne[0] == dst->ne[0]);
            GGML_ASSERT(!ids || fusion->x_bias->ne[1] == src0->ne[2]);
            fusion_local.x_bias = fusion->x_bias->data;
        }
        if (fusion->gate) {
            GGML_ASSERT(fusion->gate->type == src0->type && ggml_are_same_stride(fusion->gate, src0));
            fusion_local.gate = fusion->gate->data;
        }
        if (fusion->gate_bias) {
            GGML_ASSERT(fusion->gate_bias->type == GGML_TYPE_F32);
            GGML_ASSERT(fusion->gate_bias->ne[0] == dst->ne[0]);
            GGML_ASSERT(!ids || fusion->gate_bias->ne[1] == src0->ne[2]);
            fusion_local.gate_bias = fusion->gate_bias->data;
        }
        fusion_local.glu_op = fusion->glu_op;
    }

    const int64_t s01 = src0->nb[1] / ts_src0;
    const int64_t s11 = src1->nb[1] / ts_src1;
    const int64_t s1 = dst->nb[1] / ts_dst;
    const int64_t s02 = src0->nb[2] / ts_src0;
    const int64_t s12 = src1->nb[2] / ts_src1;
    const int64_t s2 = dst->nb[2] / ts_dst;
    const int64_t s03 = src0->nb[3] / ts_src0;
    const int64_t s13 = src1->nb[3] / ts_src1;
    const int64_t s3 = dst->nb[3] / ts_dst;

    // For MUL_MAT_ID the memory layout is different than for MUL_MAT:
    // the dst columns iterate over tokens (ne2) and the dst channels over the used experts (ne1),
    // so the roles of dims 1 and 2 are swapped relative to the regular path.
    const int64_t ncols_dst = ids ? ne2 : ne1;
    const int64_t nchannels_y = ids ? ne11 : ne12;
    const int64_t nchannels_dst = ids ? ne1 : ne2;
    const int64_t stride_col_dst = ids ? s2 : s1;
    const int64_t stride_col_y = ids ? s12 : s11;
    const int64_t stride_channel_dst = ids ? s1 : s2;
    const int64_t stride_channel_y = ids ? s11 : s12;

    const int64_t ids_stride = ids ? ids->nb[1] / ggml_type_size(ids->type) : 0;

    switch (src0->type) {
        case GGML_TYPE_F32: {
            const float * src0_d = (const float *) src0->data;
            mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, stride_col_y, stride_col_dst,
                ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
                ne03, ne3, s03, s13, s3, ids_stride, prec, ctx.stream());
        } break;
        case GGML_TYPE_F16: {
            const half * src0_d = (const half *) src0->data;
            mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, stride_col_y, stride_col_dst,
                ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
                ne03, ne3, s03, s13, s3, ids_stride, prec, ctx.stream());
        } break;
        case GGML_TYPE_BF16: {
            const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data;
            mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, stride_col_y, stride_col_dst,
                ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
                ne03, ne3, s03, s13, s3, ids_stride, prec, ctx.stream());
        } break;
        default:
            GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
    }
}

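// Variant for the generic ggml_cuda_op mat-mul path, which hands each device a single,
// contiguous slice of rows; MUL_MAT_ID and fusion are not used here.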
void ggml_cuda_op_mul_mat_vec_f(
        ggml_backend_cuda_context & ctx,
        const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
        const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
        const int64_t src1_padded_row_size, cudaStream_t stream) {

    GGML_ASSERT(src1->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);

    const int64_t ne00 = src0->ne[0];
    const int64_t ne10 = src1->ne[0];
    const int64_t ne0 = dst->ne[0];
    const int64_t row_diff = row_high - row_low;

    const int id = ggml_cuda_get_device();
    const int cc = ggml_cuda_info().devices[id].cc;
    const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;

    // ggml_cuda_op provides single, contiguous matrices
    const int64_t stride_row = ne00;
    const int64_t stride_col_y = ne10;
    const int64_t stride_col_dst = id == ctx.device ? ne0 : row_diff; // main device has larger memory buffer
    const int64_t nchannels_x = 1;
    const int64_t nchannels_y = 1;
    const int64_t nchannels_dst = 1;
    const int64_t stride_channel_x = 0;
    const int64_t stride_channel_y = 0;
    const int64_t stride_channel_dst = 0;
    const int64_t nsamples_x = 1;
    const int64_t nsamples_dst = 1;
    const int64_t stride_sample_x = 0;
    const int64_t stride_sample_y = 0;
    const int64_t stride_sample_dst = 0;

    ggml_cuda_mm_fusion_args_device empty{};
    switch (src0->type) {
        case GGML_TYPE_F32: {
            const float * src0_d = (const float *) src0_dd_i;
            mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, 0, prec, stream);
        } break;
        case GGML_TYPE_F16: {
            const half * src0_d = (const half *) src0_dd_i;
            mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, 0, prec, stream);
        } break;
        case GGML_TYPE_BF16: {
            const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i;
            mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, 0, prec, stream);
        } break;
        default:
            GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
    }

    GGML_UNUSED_VARS(ctx, src1, dst, src1_ddq_i, src1_ncols, src1_padded_row_size);
}

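// Heuristic for choosing mmvf over the other mat-mul kernels: the weight rows must allow
// 2-wide loads, and the batch size ne11 is capped depending on type and hardware.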
bool ggml_cuda_should_use_mmvf(enum ggml_type type, int cc, const int64_t * src0_ne, const size_t * src0_nb, int64_t ne11) {
    if (src0_ne[0] % 2 != 0) {
        return false;
    }

    const size_t ts = ggml_type_size(type);
    if (src0_nb[0] != ts) {
        return false;
    }

    // Pointers not aligned to the size of half2/nv_bfloat162/float2 would result in a crash:
    for (size_t i = 1; i < GGML_MAX_DIMS; ++i) {
        if (src0_nb[i] % (2*ts) != 0) {
            return false;
        }
    }

    switch (type) {
        case GGML_TYPE_F32:
            if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
                if (ampere_mma_available(cc)) {
                    return ne11 <= 3;
                }
                if (cc >= GGML_CUDA_CC_TURING) {
                    return ne11 <= 4;
                }
                return ne11 <= 3;
            } else if (GGML_CUDA_CC_IS_AMD(cc)) {
                if (fp32_mma_hardware_available(cc)) {
                    return ne11 <= 3;
                }
                return ne11 <= 8;
            }
            return ne11 <= 8;
        case GGML_TYPE_F16:
            if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
                const bool src0_small = (src0_ne[1] <= 512 || src0_ne[2]*src0_ne[3] == 1);
                if (ampere_mma_available(cc)) {
                    return src0_small && ne11 == 1;
                }
                if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
                    return src0_small && ne11 <= 4;
                }
                if (fp16_mma_hardware_available(cc)) {
                    return src0_small && ne11 <= 3;
                }
                return ne11 <= 8;
            } else if (GGML_CUDA_CC_IS_AMD(cc)) {
                if (fp16_mma_hardware_available(cc)) {
                    if (GGML_CUDA_CC_IS_RDNA3(cc)) {
                        return ne11 <= 3;
                    }
                    if (GGML_CUDA_CC_IS_RDNA4(cc)) {
                        return ne11 <= 5;
                    }
                    return ne11 <= 2;
                }
                return ne11 <= 8;
            }
            return ne11 <= 8;
        case GGML_TYPE_BF16:
            if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
                const bool src0_small = (src0_ne[1] <= 512 || src0_ne[2]*src0_ne[3] == 1);
                if (ampere_mma_available(cc)) {
                    return src0_small && ne11 == 1;
                }
                if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
                    return src0_small && ne11 <= 4;
                }
                if (bf16_mma_hardware_available(cc)) {
                    return src0_small && ne11 <= 3;
                }
                return ne11 <= 8;
            } else if (GGML_CUDA_CC_IS_AMD(cc)) {
                if (bf16_mma_hardware_available(cc)) {
                    return ne11 <= 3;
                }
                return ne11 <= 8;
            }
            return ne11 <= 8;
        default:
            return false;
    }
}