#include "ggml.h"
#include "common.cuh"
#include "unary.cuh"
#include "mmvf.cuh"
#include "convert.cuh"

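// mmvf: mul_mat_vec kernels for the non-quantized float types (F32, F16, BF16),
// with optional fused bias / gated-activation (GLU) epilogues and optional
// expert selection via MUL_MAT_ID ids.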
template <typename T, typename type_acc, int ncols_dst, int block_size, bool has_fusion = false, bool is_multi_token_id = false>
static __global__ void mul_mat_vec_f(
        const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst,
        const int ncols2, const uint3 nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst,
        const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
        const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
        const int ids_stride) {
    const int row         = blockIdx.x;
    // for MUL_MAT_ID - blockIdx.y = n_expert_used, blockIdx.z = ncols_dst (tokens)
    const int channel_dst = blockIdx.y;
    const int tid         = threadIdx.x;

    int token_idx;
    int channel_x;
    int channel_y;
    int sample_dst;

    if constexpr (is_multi_token_id) {
        // Multi-token MUL_MAT_ID path; handling this in the normal path causes a perf regression for the n_tokens=1 case.
        token_idx  = blockIdx.z;
        channel_x  = ids[channel_dst + token_idx * ids_stride];
        channel_y  = fastmodulo(channel_dst, nchannels_y);
        sample_dst = 0;
    } else {
        token_idx  = ids ? blockIdx.z                                          : 0;
        channel_x  = ids ? ids[blockIdx.y + token_idx * ids_stride]            : fastdiv((uint32_t) channel_dst, channel_ratio);
        channel_y  = ids ? fastmodulo(blockIdx.y, nchannels_y)                 : channel_dst;
        sample_dst = ids ? 0                                                   : blockIdx.z;
    }
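    // Without ids, channel_ratio == nchannels_dst/nchannels_x and sample_ratio ==
    // nsamples_dst/nsamples_x, i.e. x is broadcast across the channels/samples of
    // y (the fastdiv/fastmodulo values are precomputed on the host in
    // launch_mul_mat_vec_f_cuda).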

    const int sample_x    = fastdiv((uint32_t) sample_dst, sample_ratio);
    const int sample_y    = sample_dst;

    constexpr int warp_size = ggml_cuda_get_physical_warp_size();

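    // Offset the x, y, and dst pointers to the sample/channel/row this block works on: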
    x   += int64_t(sample_x)  *stride_sample_x   + channel_x  *stride_channel_x   + row*stride_row;
    y   += int64_t(sample_y)  *stride_sample_y   + channel_y  *stride_channel_y;
    dst += int64_t(sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst;
    if constexpr (is_multi_token_id) {
        y   += token_idx*stride_col_y2*2;
        dst += token_idx*stride_col_dst;
    }

    bool use_gate = false;
    bool use_bias = false;
    bool use_gate_bias = false;
    ggml_glu_op glu_op = ggml_glu_op::GGML_GLU_OP_SWIGLU;
    const T * gate_x = nullptr;
    const float * x_bias = nullptr;
    const float * gate_bias = nullptr;

    if constexpr (has_fusion) {
        use_gate = fusion.gate != nullptr;
        use_bias = fusion.x_bias != nullptr;
        use_gate_bias = fusion.gate_bias != nullptr;
        glu_op = fusion.glu_op;

        if (use_gate) {
            gate_x = static_cast<const T *>(fusion.gate);
        }
        if (use_bias) {
            x_bias = static_cast<const float *>(fusion.x_bias);
        }
        if (use_gate_bias) {
            gate_bias = static_cast<const float *>(fusion.gate_bias);
            use_gate_bias = use_gate; // the gate bias is only applied if there is a gate to add it to
        } else {
            use_gate_bias = false;
        }
    }

    if (use_gate) {
        gate_x += int64_t(sample_x)  *stride_sample_x   + channel_x  *stride_channel_x   + row*stride_row;
    }

    const int channel_bias = ids ? channel_x : channel_dst;

    if constexpr (has_fusion) {
        if (use_bias) {
            x_bias += int64_t(sample_dst)*stride_sample_dst + channel_bias*stride_channel_dst;
        }
        if (use_gate_bias) {
            gate_bias += int64_t(sample_dst)*stride_sample_dst + channel_bias*stride_channel_dst;
        }
    }

    const float2 * y2 = (const float2 *) y;

    extern __shared__ char data_mmv[];
    float * buf_iw = (float *) data_mmv;
    float * buf_iw_gate = nullptr;
    if constexpr (has_fusion) {
        buf_iw_gate = (float *) (data_mmv + warp_size*sizeof(float));
    }
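    // Shared memory layout: buf_iw holds one partial sum per warp; with fusion a
    // second warp_size-sized buffer for the gate sums follows directly after it.
    // This matches the nbytes_shared value computed in launch_mul_mat_vec_f_cuda.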

    if (block_size > warp_size) {
        if (tid < warp_size) {
            buf_iw[tid] = 0.0f;
            if constexpr (has_fusion) {
                if (use_gate) {
                    buf_iw_gate[tid] = 0.0f;
                }
            }
        }
        __syncthreads();
    }

    float sumf[ncols_dst] = {0.0f};
    float sumf_gate[ncols_dst];
    if constexpr (has_fusion) {
#pragma unroll
        for (int j = 0; j < ncols_dst; ++j) {
            sumf_gate[j] = 0.0f;
        }
    }

    if constexpr (std::is_same_v<T, float>) {
        const float2 * x2 = (const float2 *) x;
        const float2 * gate_x2 = nullptr;
        if constexpr (has_fusion) {
            if (use_gate) {
                gate_x2 = (const float2 *) gate_x;
            }
        }

        for (int col2 = tid; col2 < ncols2; col2 += block_size) {
            const float2 tmpx = x2[col2];
            float2 tmpx_gate = make_float2(0.0f, 0.0f);
            if constexpr (has_fusion) {
                if (use_gate) {
                    tmpx_gate = gate_x2[col2];
                }
            }

#pragma unroll
            for (int j = 0; j < ncols_dst; ++j) {
                const float2 tmpy = y2[j*stride_col_y2 + col2];
                ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
                ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);

                if constexpr (has_fusion) {
                    if (use_gate) {
                        ggml_cuda_mad(sumf_gate[j], tmpx_gate.x, tmpy.x);
                        ggml_cuda_mad(sumf_gate[j], tmpx_gate.y, tmpy.y);
                    }
                }
            }
        }
    } else if constexpr (std::is_same_v<T, half>) {
        const half2 * x2 = (const half2 *) x;
        const half2 * gate_x2 = nullptr;
        if constexpr (has_fusion) {
            if (use_gate) {
                gate_x2 = (const half2 *) gate_x;
            }
        }

        if constexpr (std::is_same_v<type_acc, float>) {
            for (int col2 = tid; col2 < ncols2; col2 += block_size) {
                const float2 tmpx = __half22float2(x2[col2]);
                float2 tmpx_gate = make_float2(0.0f, 0.0f);
                if constexpr (has_fusion) {
                    if (use_gate) {
                        tmpx_gate = __half22float2(gate_x2[col2]);
                    }
                }
#pragma unroll
                for (int j = 0; j < ncols_dst; ++j) {
                    const float2 tmpy = y2[j*stride_col_y2 + col2];
                    ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
                    ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);

                    if constexpr (has_fusion) {
                        if (use_gate) {
                            ggml_cuda_mad(sumf_gate[j], tmpx_gate.x, tmpy.x);
                            ggml_cuda_mad(sumf_gate[j], tmpx_gate.y, tmpy.y);
                        }
                    }
                }
            }
        } else {
#ifdef FP16_AVAILABLE
            half2 sumh2[ncols_dst] = {{0.0f, 0.0f}};
            half2 sumh2_gate[ncols_dst] = {{0.0f, 0.0f}};

            for (int col2 = tid; col2 < ncols2; col2 += block_size) {
                const half2 tmpx = x2[col2];
                half2 tmpx_gate = make_half2(0.0f, 0.0f);
                if constexpr (has_fusion) {
                    if (use_gate) {
                        tmpx_gate = gate_x2[col2];
                    }
                }
#pragma unroll
                for (int j = 0; j < ncols_dst; ++j) {
                    const float2 tmpy = y2[j*stride_col_y2 + col2];
                    sumh2[j] += tmpx * make_half2(tmpy.x, tmpy.y);

                    if constexpr (has_fusion) {
                        if (use_gate) {
                            sumh2_gate[j] += tmpx_gate * make_half2(tmpy.x, tmpy.y);
                        }
                    }
                }
            }

#pragma unroll
            for (int j = 0; j < ncols_dst; ++j) {
                sumf[j] = __low2float(sumh2[j]) + __high2float(sumh2[j]);
            }

            if constexpr (has_fusion) {
                if (use_gate) {
#pragma unroll
                    for (int j = 0; j < ncols_dst; ++j) {
                        sumf_gate[j] = __low2float(sumh2_gate[j]) + __high2float(sumh2_gate[j]);
                    }
                }
            }
#else
            NO_DEVICE_CODE;
#endif // FP16_AVAILABLE
        }
    } else if constexpr (std::is_same_v<T, nv_bfloat16>) {
// TODO: add support for ggml_cuda_mad for hip_bfloat162
#if defined(GGML_USE_HIP)
        const int * x2 = (const int *) x;
        const int * gate_x2 = nullptr;
        if constexpr (has_fusion) {
            if (use_gate) {
                gate_x2 = (const int *) gate_x;
            }
        }
        for (int col2 = tid; col2 < ncols2; col2 += block_size) {
            const int tmpx = x2[col2];
            int tmpx_gate = 0;
            if constexpr (has_fusion) {
                if (use_gate) {
                    tmpx_gate = gate_x2[col2];
                }
            }
#pragma unroll
            for (int j = 0; j < ncols_dst; ++j) {
                const float2 tmpy = y2[j*stride_col_y2 + col2];
                const float tmpx0 = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]);
                const float tmpx1 = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]);
                ggml_cuda_mad(sumf[j], tmpx0, tmpy.x);
                ggml_cuda_mad(sumf[j], tmpx1, tmpy.y);

                if constexpr (has_fusion) {
                    if (use_gate) {
                        const float tmpx0_gate = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx_gate)[0]);
                        const float tmpx1_gate = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx_gate)[1]);
                        ggml_cuda_mad(sumf_gate[j], tmpx0_gate, tmpy.x);
                        ggml_cuda_mad(sumf_gate[j], tmpx1_gate, tmpy.y);
                    }
                }
            }
        }
#else
        const nv_bfloat162 * x2 = (const nv_bfloat162 *) x;
        const nv_bfloat162 * gate_x2 = nullptr;
        if constexpr (has_fusion) {
            if (use_gate) {
                gate_x2 = (const nv_bfloat162 *) gate_x;
            }
        }
        for (int col2 = tid; col2 < ncols2; col2 += block_size) {
            const nv_bfloat162 tmpx = x2[col2];
            nv_bfloat162 tmpx_gate;
            if constexpr (has_fusion) {
                if (use_gate) {
                    tmpx_gate = gate_x2[col2];
                }
            }
#pragma unroll
            for (int j = 0; j < ncols_dst; ++j) {
                const float2 tmpy = y2[j*stride_col_y2 + col2];
                ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
                ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);

                if constexpr (has_fusion) {
                    if (use_gate) {
                        ggml_cuda_mad(sumf_gate[j], tmpx_gate.x, tmpy.x);
                        ggml_cuda_mad(sumf_gate[j], tmpx_gate.y, tmpy.y);
                    }
                }
            }
        }
#endif // GGML_USE_HIP
    } else {
        static_assert(std::is_same_v<T, void>, "unsupported type");
    }

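    // At this point each thread holds in sumf[j] its partial dot product of the x
    // row with column j of y (and in sumf_gate[j] the same for the gate row when
    // fused); reduce first within each warp, then across warps via shared memory.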
#pragma unroll
    for (int j = 0; j < ncols_dst; ++j) {
        sumf[j] = warp_reduce_sum<warp_size>(sumf[j]);

        if constexpr (has_fusion) {
            if (use_gate) {
                sumf_gate[j] = warp_reduce_sum<warp_size>(sumf_gate[j]);
            }
        }

        if (block_size > warp_size) {
            buf_iw[tid/warp_size] = sumf[j];
            if constexpr (has_fusion) {
                if (use_gate) {
                    buf_iw_gate[tid/warp_size] = sumf_gate[j];
                }
            }
            __syncthreads();
            if (tid < warp_size) {
                sumf[j] = buf_iw[tid];
                sumf[j] = warp_reduce_sum<warp_size>(sumf[j]);
                if constexpr (has_fusion) {
                    if (use_gate) {
                        sumf_gate[j] = buf_iw_gate[tid];
                        sumf_gate[j] = warp_reduce_sum<warp_size>(sumf_gate[j]);
                    }
                }
            }

            if (j < ncols_dst) {
                __syncthreads();
            }
        }
    }

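    // Each of the first ncols_dst threads now holds the fully reduced sum for one
    // dst column; with block_size >= warp_size and ncols_dst <= 8 there are always
    // enough threads for the writeback below.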
    if (tid >= ncols_dst) {
        return;
    }

    float value = sumf[tid];

    if constexpr (has_fusion) {
        if (use_bias) {
            value += x_bias[tid*stride_col_dst + row];
        }

        if (use_gate) {
            float gate_value = sumf_gate[tid];
            if (use_gate_bias) {
                gate_value += gate_bias[tid*stride_col_dst + row];
            }
            switch (glu_op) {
                case GGML_GLU_OP_SWIGLU:
                    value *= ggml_cuda_op_silu_single(gate_value);
                    break;
                case GGML_GLU_OP_GEGLU:
                    value *= ggml_cuda_op_gelu_single(gate_value);
                    break;
                case GGML_GLU_OP_SWIGLU_OAI: {
                    value = ggml_cuda_op_swiglu_oai_single(gate_value, value);
                    break;
                }
                default:
                    break;
            }
        }
    }

    dst[tid*stride_col_dst + row] = value;

    if constexpr (!has_fusion) {
        GGML_UNUSED_VARS(use_gate, use_bias, use_gate_bias, glu_op, gate_x, x_bias, gate_bias, sumf_gate);
    }
}

template<typename T, typename type_acc, int ncols_dst, int block_size, bool is_multi_token_id = false>
static void mul_mat_vec_f_switch_fusion(
        const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
        const int64_t ncols, const uint3 nchannels_y,
        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
        const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
        const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
        const dim3 & block_dims, const dim3 & block_nums, const int nbytes_shared, const int ids_stride, const cudaStream_t stream) {

    const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
    if constexpr (ncols_dst == 1) {
        if (has_fusion) {
            mul_mat_vec_f<T, type_acc, ncols_dst, block_size, true, is_multi_token_id><<<block_nums, block_dims, nbytes_shared, stream>>>
                (x, y, ids, fusion, dst, ncols, nchannels_y, stride_row, stride_col_y, stride_col_dst,
                channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);
            return;
        }
    }

    GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1");

    mul_mat_vec_f<T, type_acc, ncols_dst, block_size, false, is_multi_token_id><<<block_nums, block_dims, nbytes_shared, stream>>>
        (x, y, ids, fusion, dst, ncols, nchannels_y, stride_row, stride_col_y, stride_col_dst,
        channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
        sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);
}

template <typename T, typename type_acc, int ncols_dst, bool is_multi_token_id = false>
void launch_mul_mat_vec_f_cuda(
        const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
        const int64_t ncols, const int64_t nrows,
        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
        const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
        const int64_t nsamples_or_ntokens, const int64_t ids_stride, cudaStream_t stream) {
    GGML_ASSERT(ncols        % 2 == 0);
    GGML_ASSERT(stride_row   % 2 == 0);
    GGML_ASSERT(stride_col_y % 2 == 0);
    GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
    GGML_ASSERT(       nsamples_dst  % nsamples_x  == 0);
    const uint3 nchannels_y_fd   = ids ? init_fastdiv_values(nchannels_y) : make_uint3(0, 0, 0);
    const uint3 channel_ratio_fd = ids ? make_uint3(0, 0, 0) : init_fastdiv_values(nchannels_dst / nchannels_x);
    const uint3 sample_ratio_fd  = init_fastdiv_values(nsamples_dst  / nsamples_x);

    const int device = ggml_cuda_get_device();
    const int warp_size = ggml_cuda_info().devices[device].warp_size;

    int64_t block_size_best = warp_size;
    int64_t niter_best      = (ncols + 2*warp_size - 1) / (2*warp_size);
    int64_t max_block_size  = 256;
    if (ggml_cuda_info().devices[device].cc > GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_info().devices[device].cc < GGML_CUDA_CC_RDNA1) {
        max_block_size = 128;
    }
    for (int64_t block_size = 2*warp_size; block_size <= max_block_size; block_size += warp_size) {
        const int64_t niter = (ncols + 2*block_size - 1) / (2*block_size);
        if (niter < niter_best) {
            niter_best      = niter;
            block_size_best = block_size;
        }
    }
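    // Each thread processes 2 columns per iteration, so a block of size bs needs
    // ceil(ncols / (2*bs)) iterations over the row; e.g. for ncols = 4096, 256
    // threads need 8 iterations while 32 threads would need 64. The smallest block
    // size that minimizes the number of iterations is chosen.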

    const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;

    const int nbytes_shared = warp_size*sizeof(float) + (has_fusion ? warp_size*sizeof(float) : 0);
    const dim3 block_nums(nrows, nchannels_dst, nsamples_or_ntokens);
    const dim3 block_dims(block_size_best, 1, 1);
    switch (block_size_best) {
        case   32: {
            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 32, is_multi_token_id>
                (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
        } break;
        case   64: {
            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 64, is_multi_token_id>
                (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
        } break;
        case   96: {
            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 96, is_multi_token_id>
                (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
        } break;
        case  128: {
            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 128, is_multi_token_id>
                (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
        } break;
        case  160: {
            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 160, is_multi_token_id>
                (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
        } break;
        case  192: {
            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 192, is_multi_token_id>
                (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
        } break;
        case  224: {
            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 224, is_multi_token_id>
                (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
        } break;
        case  256: {
            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 256, is_multi_token_id>
                (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
        } break;
        default: {
            GGML_ABORT("fatal error");
        } break;
    }
}

template <typename T, typename type_acc>
static void mul_mat_vec_f_cuda_switch_ncols_dst(
        const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
        const int64_t ncols, const int64_t nrows, const int64_t ncols_dst,
        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
        const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
        const int64_t ids_stride, cudaStream_t stream) {

    const bool has_ids = ids != nullptr;

    if (has_ids && ncols_dst > 1) {
        // Multi-token MUL_MAT_ID path: tokens are mapped to blockIdx.z, so each kernel
        // instance still computes a single dst column; single-token MUL_MAT_ID goes
        // through the regular path below.
        constexpr int c_ncols_dst = 1;
        launch_mul_mat_vec_f_cuda<T, type_acc, c_ncols_dst, true>
            (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
             nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
             stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
             ncols_dst, ids_stride, stream);
        return;
    }

    if (has_ids) {
        // Single-token MUL_MAT_ID path
        constexpr int c_ncols_dst = 1;
        launch_mul_mat_vec_f_cuda<T, type_acc, c_ncols_dst>
            (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
             nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
             stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
             ncols_dst, ids_stride, stream);
        return;
    }

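    // Without ids the kernel is instantiated for 1 to 8 dst columns; this mirrors
    // the ne11 <= 8 cap in ggml_cuda_should_use_mmvf below.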
    switch (ncols_dst) {
        case 1:
            launch_mul_mat_vec_f_cuda<T, type_acc, 1>
                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
                 nsamples_dst, ids_stride, stream);
            break;
        case 2:
            launch_mul_mat_vec_f_cuda<T, type_acc, 2>
                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
                 nsamples_dst, ids_stride, stream);
            break;
        case 3:
            launch_mul_mat_vec_f_cuda<T, type_acc, 3>
                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
                 nsamples_dst, ids_stride, stream);
            break;
        case 4:
            launch_mul_mat_vec_f_cuda<T, type_acc, 4>
                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
                 nsamples_dst, ids_stride, stream);
            break;
        case 5:
            launch_mul_mat_vec_f_cuda<T, type_acc, 5>
                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
                 nsamples_dst, ids_stride, stream);
            break;
        case 6:
            launch_mul_mat_vec_f_cuda<T, type_acc, 6>
                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
                 nsamples_dst, ids_stride, stream);
            break;
        case 7:
            launch_mul_mat_vec_f_cuda<T, type_acc, 7>
                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
                 nsamples_dst, ids_stride, stream);
            break;
        case 8:
            launch_mul_mat_vec_f_cuda<T, type_acc, 8>
                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
                 nsamples_dst, ids_stride, stream);
            break;
        default:
            GGML_ABORT("fatal error");
            break;
    }
}

template<typename T>
static void mul_mat_vec_f_cuda(
        const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
        const int64_t ncols, const int64_t nrows, const int64_t ncols_dst,
        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
        const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
        const int64_t ids_stride, enum ggml_prec prec, cudaStream_t stream) {

    if constexpr (std::is_same_v<T, half>) {
        if (prec == GGML_PREC_DEFAULT) {
            mul_mat_vec_f_cuda_switch_ncols_dst<T, half>
                (x, y, ids, fusion, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
            return;
        }
    }
    mul_mat_vec_f_cuda_switch_ncols_dst<T, float>
        (x, y, ids, fusion, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
        nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
        stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
}

void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst,
    const ggml_cuda_mm_fusion_args_host * fusion) {
    GGML_ASSERT(        src1->type == GGML_TYPE_F32);
    GGML_ASSERT(!ids ||  ids->type == GGML_TYPE_I32);
    GGML_ASSERT(         dst->type == GGML_TYPE_F32);

    GGML_TENSOR_BINARY_OP_LOCALS;

    const size_t ts_src0 = ggml_type_size(src0->type);
    const size_t ts_src1 = ggml_type_size(src1->type);
    const size_t ts_dst  = ggml_type_size(dst->type);

    GGML_ASSERT(!ids || ne12 <= MMVF_MAX_BATCH_SIZE);
    GGML_ASSERT(ne13 == ne3);

    GGML_ASSERT(        nb00       == ts_src0);
    GGML_ASSERT(        nb10       == ts_src1);
    GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type));
    GGML_ASSERT(        nb0        == ts_dst);

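    // If fast FP16 arithmetic is available the precision requested in op_params[0]
    // is honored, otherwise accumulation is forced to FP32: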
    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
    const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;

    const float   * src1_d =       (const float   *) src1->data;
    const int32_t *  ids_d = ids ? (const int32_t *)  ids->data : nullptr;
    float         *  dst_d =       (float         *)  dst->data;

    ggml_cuda_mm_fusion_args_device fusion_local{};

    if (fusion) {
        GGML_ASSERT( !ids || dst->ne[2] == 1);
        GGML_ASSERT(  ids || dst->ne[1] == 1);
        if (fusion->x_bias) {
            GGML_ASSERT(fusion->x_bias->type == GGML_TYPE_F32);
            GGML_ASSERT(fusion->x_bias->ne[0] == dst->ne[0]);
            GGML_ASSERT(!ids || fusion->x_bias->ne[1] == src0->ne[2]);
            fusion_local.x_bias = fusion->x_bias->data;
        }
        if (fusion->gate) {
            GGML_ASSERT(fusion->gate->type == src0->type && ggml_are_same_stride(fusion->gate, src0));
            fusion_local.gate = fusion->gate->data;
        }
        if (fusion->gate_bias) {
            GGML_ASSERT(fusion->gate_bias->type == GGML_TYPE_F32);
            GGML_ASSERT(fusion->gate_bias->ne[0] == dst->ne[0]);
            GGML_ASSERT(!ids || fusion->gate_bias->ne[1] == src0->ne[2]);
            fusion_local.gate_bias = fusion->gate_bias->data;
        }
        fusion_local.glu_op = fusion->glu_op;
    }

    const int64_t s01 = src0->nb[1] / ts_src0;
    const int64_t s11 = src1->nb[1] / ts_src1;
    const int64_t s1  =  dst->nb[1] / ts_dst;
    const int64_t s02 = src0->nb[2] / ts_src0;
    const int64_t s12 = src1->nb[2] / ts_src1;
    const int64_t s2  =  dst->nb[2] / ts_dst;
    const int64_t s03 = src0->nb[3] / ts_src0;
    const int64_t s13 = src1->nb[3] / ts_src1;
    const int64_t s3  =  dst->nb[3] / ts_dst;

    // For MUL_MAT_ID the memory layout is different from MUL_MAT:
    const int64_t ncols_dst          = ids ? ne2  : ne1;
    const int64_t nchannels_y        = ids ? ne11 : ne12;
    const int64_t nchannels_dst      = ids ? ne1  : ne2;
    const int64_t stride_col_dst     = ids ? s2   : s1;
    const int64_t stride_col_y       = ids ? s12  : s11;
    const int64_t stride_channel_dst = ids ? s1   : s2;
    const int64_t stride_channel_y   = ids ? s11  : s12;
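    // (with ids, dim 1 of dst indexes the used experts and dim 2 the tokens, hence
    // the swapped extents/strides relative to the regular MUL_MAT case)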

    const int64_t ids_stride = ids ? ids->nb[1] / ggml_type_size(ids->type) : 0;

    switch (src0->type) {
        case GGML_TYPE_F32: {
            const float * src0_d = (const float *) src0->data;
            mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, stride_col_y, stride_col_dst,
                ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
                ne03,              ne3,           s03, s13,              s3,                 ids_stride, prec, ctx.stream());
        } break;
        case GGML_TYPE_F16: {
            const half * src0_d = (const half *) src0->data;
            mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, stride_col_y, stride_col_dst,
                ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
                ne03,              ne3,           s03, s13,              s3,                 ids_stride, prec, ctx.stream());
        } break;
        case GGML_TYPE_BF16: {
            const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data;
            mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, stride_col_y, stride_col_dst,
                ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
                ne03,              ne3,           s03, s13,              s3,                 ids_stride, prec, ctx.stream());
        } break;
        default:
            GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
    }
}

void ggml_cuda_op_mul_mat_vec_f(
    ggml_backend_cuda_context & ctx,
    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
    const int64_t src1_padded_row_size, cudaStream_t stream) {

    GGML_ASSERT(src1->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type  == GGML_TYPE_F32);

    const int64_t ne00 = src0->ne[0];
    const int64_t ne10 = src1->ne[0];
    const int64_t ne0  =  dst->ne[0];
    const int64_t row_diff = row_high - row_low;

    const int id = ggml_cuda_get_device();
    const int cc = ggml_cuda_info().devices[id].cc;
    const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;

    // ggml_cuda_op provides single, contiguous matrices
    const int64_t stride_row         = ne00;
    const int64_t stride_col_y       = ne10;
    const int64_t stride_col_dst     = id == ctx.device ? ne0 : row_diff; // main device has larger memory buffer
    const int64_t nchannels_x        = 1;
    const int64_t nchannels_y        = 1;
    const int64_t nchannels_dst      = 1;
    const int64_t stride_channel_x   = 0;
    const int64_t stride_channel_y   = 0;
    const int64_t stride_channel_dst = 0;
    const int64_t nsamples_x         = 1;
    const int64_t nsamples_dst       = 1;
    const int64_t stride_sample_x    = 0;
    const int64_t stride_sample_y    = 0;
    const int64_t stride_sample_dst  = 0;

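    // no fusion on this path, pass empty fusion args: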
    ggml_cuda_mm_fusion_args_device empty{};
    switch (src0->type) {
        case GGML_TYPE_F32: {
            const float * src0_d = (const float *) src0_dd_i;
            mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, 0, prec, stream);
        } break;
        case GGML_TYPE_F16: {
            const half * src0_d = (const half *) src0_dd_i;
            mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, 0, prec, stream);
        } break;
        case GGML_TYPE_BF16: {
            const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i;
            mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, 0, prec, stream);
        } break;
        default:
            GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
    }

    GGML_UNUSED_VARS(ctx, src1, dst, src1_ddq_i, src1_padded_row_size);
}

bool ggml_cuda_should_use_mmvf(enum ggml_type type, int cc, const int64_t * src0_ne, const size_t * src0_nb, int64_t ne11) {
    if (src0_ne[0] % 2 != 0) {
        return false;
    }

    const size_t ts = ggml_type_size(type);
    if (src0_nb[0] != ts) {
        return false;
    }

    // Pointers not aligned to the size of half2/nv_bfloat162/float2 would result in a crash:
    for (size_t i = 1; i < GGML_MAX_DIMS; ++i) {
        if (src0_nb[i] % (2*ts) != 0) {
            return false;
        }
    }

    switch (type) {
        case GGML_TYPE_F32:
            if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
                if (ampere_mma_available(cc)) {
                    return ne11 <= 3;
                }
                if (cc >= GGML_CUDA_CC_TURING) {
                    return ne11 <= 4;
                }
                return ne11 <= 3;
            } else if (GGML_CUDA_CC_IS_AMD(cc)) {
                if (fp32_mma_hardware_available(cc)) {
                    return ne11 <= 3;
                }
                return ne11 <= 8;
            }
            return ne11 <= 8;
        case GGML_TYPE_F16:
            if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
                const bool src0_small = (src0_ne[1] <= 512 || src0_ne[2]*src0_ne[3] == 1);
                if (ampere_mma_available(cc)) {
                    return src0_small && ne11 == 1;
                }
                if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
                    return src0_small && ne11 <= 4;
                }
                if (fp16_mma_hardware_available(cc)) {
                    return src0_small && ne11 <= 3;
                }
                return ne11 <= 8;
            } else if (GGML_CUDA_CC_IS_AMD(cc)) {
                if (fp16_mma_hardware_available(cc)) {
                    if (GGML_CUDA_CC_IS_RDNA3(cc)) {
                        return ne11 <= 3;
                    }
                    if (GGML_CUDA_CC_IS_RDNA4(cc)) {
                        return ne11 <= 5;
                    }
                    return ne11 <= 2;
                }
                return ne11 <= 8;
            }
            return ne11 <= 8;
        case GGML_TYPE_BF16:
            if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
                const bool src0_small = (src0_ne[1] <= 512 || src0_ne[2]*src0_ne[3] == 1);
                if (ampere_mma_available(cc)) {
                    return src0_small && ne11 == 1;
                }
                if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
                    return src0_small && ne11 <= 4;
                }
                if (bf16_mma_hardware_available(cc)) {
                    return src0_small && ne11 <= 3;
                }
                return ne11 <= 8;
            } else if (GGML_CUDA_CC_IS_AMD(cc)) {
                if (bf16_mma_hardware_available(cc)) {
                    return ne11 <= 3;
                }
                return ne11 <= 8;
            }
            return ne11 <= 8;
        default:
            return false;
    }
}