#include "mmvq.cuh"
#include "quantize.cuh"
#include "unary.cuh"
#include "vecdotq.cuh"

#include <cstdint>
#include <type_traits>
#include <utility>
  7
  8typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs);
  9
 10static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) {
 11    switch (type) {
 12        case GGML_TYPE_Q4_0:    return vec_dot_q4_0_q8_1;
 13        case GGML_TYPE_Q4_1:    return vec_dot_q4_1_q8_1;
 14        case GGML_TYPE_Q5_0:    return vec_dot_q5_0_q8_1;
 15        case GGML_TYPE_Q5_1:    return vec_dot_q5_1_q8_1;
 16        case GGML_TYPE_Q8_0:    return vec_dot_q8_0_q8_1;
 17        case GGML_TYPE_MXFP4:   return vec_dot_mxfp4_q8_1;
 18        case GGML_TYPE_Q2_K:    return vec_dot_q2_K_q8_1;
 19        case GGML_TYPE_Q3_K:    return vec_dot_q3_K_q8_1;
 20        case GGML_TYPE_Q4_K:    return vec_dot_q4_K_q8_1;
 21        case GGML_TYPE_Q5_K:    return vec_dot_q5_K_q8_1;
 22        case GGML_TYPE_Q6_K:    return vec_dot_q6_K_q8_1;
 23        case GGML_TYPE_IQ2_XXS: return vec_dot_iq2_xxs_q8_1;
 24        case GGML_TYPE_IQ2_XS:  return vec_dot_iq2_xs_q8_1;
 25        case GGML_TYPE_IQ2_S:   return vec_dot_iq2_s_q8_1;
 26        case GGML_TYPE_IQ3_XXS: return vec_dot_iq3_xxs_q8_1;
 27        case GGML_TYPE_IQ1_S:   return vec_dot_iq1_s_q8_1;
 28        case GGML_TYPE_IQ1_M:   return vec_dot_iq1_m_q8_1;
 29        case GGML_TYPE_IQ4_NL:  return vec_dot_iq4_nl_q8_1;
 30        case GGML_TYPE_IQ4_XS:  return vec_dot_iq4_xs_q8_1;
 31        case GGML_TYPE_IQ3_S:   return vec_dot_iq3_s_q8_1;
 32        default:                return nullptr;
 33    }
 34}
 35
 36static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
 37    switch (type) {
 38        case GGML_TYPE_Q4_0:    return VDR_Q4_0_Q8_1_MMVQ;
 39        case GGML_TYPE_Q4_1:    return VDR_Q4_1_Q8_1_MMVQ;
 40        case GGML_TYPE_Q5_0:    return VDR_Q5_0_Q8_1_MMVQ;
 41        case GGML_TYPE_Q5_1:    return VDR_Q5_1_Q8_1_MMVQ;
 42        case GGML_TYPE_Q8_0:    return VDR_Q8_0_Q8_1_MMVQ;
 43        case GGML_TYPE_MXFP4:   return VDR_MXFP4_Q8_1_MMVQ;
 44        case GGML_TYPE_Q2_K:    return VDR_Q2_K_Q8_1_MMVQ;
 45        case GGML_TYPE_Q3_K:    return VDR_Q3_K_Q8_1_MMVQ;
 46        case GGML_TYPE_Q4_K:    return VDR_Q4_K_Q8_1_MMVQ;
 47        case GGML_TYPE_Q5_K:    return VDR_Q5_K_Q8_1_MMVQ;
 48        case GGML_TYPE_Q6_K:    return VDR_Q6_K_Q8_1_MMVQ;
 49        case GGML_TYPE_IQ2_XXS: return VDR_IQ2_XXS_Q8_1_MMVQ;
 50        case GGML_TYPE_IQ2_XS:  return VDR_IQ2_XS_Q8_1_MMVQ;
 51        case GGML_TYPE_IQ2_S:   return VDR_IQ2_S_Q8_1_MMVQ;
 52        case GGML_TYPE_IQ3_XXS: return VDR_IQ3_XXS_Q8_1_MMVQ;
 53        case GGML_TYPE_IQ3_S:   return VDR_IQ3_S_Q8_1_MMVQ;
 54        case GGML_TYPE_IQ4_NL:  return VDR_IQ4_NL_Q8_1_MMVQ;
 55        case GGML_TYPE_IQ4_XS:  return VDR_IQ4_XS_Q8_1_MMVQ;
 56        default:                return 1;
 57    }
 58}
 59
 60enum mmvq_parameter_table_id {
 61    MMVQ_PARAMETERS_GENERIC = 0,
 62    MMVQ_PARAMETERS_GCN,
 63    MMVQ_PARAMETERS_RDNA2
 64};
 65
 66static constexpr __device__ mmvq_parameter_table_id get_device_table_id() {
 67#if defined(RDNA2) || defined(RDNA3) || defined(RDNA4)
 68    return MMVQ_PARAMETERS_RDNA2;
 69#elif defined(GCN) || defined(CDNA)
 70    return MMVQ_PARAMETERS_GCN;
 71#else
 72    return MMVQ_PARAMETERS_GENERIC;
 73#endif
 74}
 75
 76static __host__ mmvq_parameter_table_id get_device_table_id(int cc) {
 77    if (GGML_CUDA_CC_IS_RDNA2(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
 78        return MMVQ_PARAMETERS_RDNA2;
 79    }
 80    if (GGML_CUDA_CC_IS_GCN(cc) || GGML_CUDA_CC_IS_CDNA(cc)) {
 81        return MMVQ_PARAMETERS_GCN;
 82    }
 83    return MMVQ_PARAMETERS_GENERIC;
 84}
 85
 86static constexpr __host__ __device__ int calc_nwarps(int ncols_dst, mmvq_parameter_table_id table_id) {
 87    if (table_id == MMVQ_PARAMETERS_GENERIC) {
 88        switch (ncols_dst) {
 89            case 1:
 90            case 2:
 91            case 3:
 92            case 4:
 93                return 4;
 94            case 5:
 95            case 6:
 96            case 7:
 97            case 8:
 98                return 2;
 99            default:
100                return 1;
101        }
102    } else if (table_id == MMVQ_PARAMETERS_GCN) {
103        switch (ncols_dst) {
104            case 1:
105            case 2:
106            case 3:
107            case 4:
108                return 2;
109            case 5:
110            case 6:
111            case 7:
112            case 8:
113            default:
114                return 1;
115        }
116    }
117    return 1;
118}
119
// Number of x/dst rows handled by one CUDA block of mul_mat_vec_q. The GENERIC and GCN tables
// use 2 rows for 2-8 dst columns; every other combination uses a single row.
static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int table_id) {
    const bool pairs_rows_for_table = table_id == MMVQ_PARAMETERS_GENERIC || table_id == MMVQ_PARAMETERS_GCN;
    const bool multi_column         = ncols_dst >= 2 && ncols_dst <= 8;
    return (pairs_rows_for_table && multi_column) ? 2 : 1;
}
139
// Quantized matrix-vector(s) product kernel (MMVQ).
//
// Computes dot products between rows of the quantized matrix x (vx, stored as `type`) and
// ncols_dst columns of y (vy, read as block_q8_1), and writes the float results to dst.
//
// Launch layout (as produced by calc_launch_params):
//   blockDim   = (warp_size, nwarps, 1)
//   blockIdx.x - group of rows_per_cuda_block consecutive rows starting at row0
//   blockIdx.y - dst channel
//   blockIdx.z - sample, or the token index when is_multi_token_id is true
// Only static shared memory is used (tmp_shared / tmp_shared_gate below).
//
// Template parameters:
//   type              - quantization type of x; selects qk, qi, vdr and the vec_dot function
//   ncols_dst         - number of y / dst columns processed by each CUDA block
//   has_fusion        - when true, the gate matrix, biases and GLU op carried in `fusion` are
//                       applied in the epilogue; when false `fusion` is not read
//   is_multi_token_id - multi-token MUL_MAT_ID path: the x channel is
//                       ids[channel_dst + token_idx*ids_stride] with token_idx = blockIdx.z
//
// Stride units follow the pointer types used below: x strides are in x blocks (they form the kbx
// index passed to vec_dot), y strides in block_q8_1 elements, dst strides in floats.
// nchannels_y, channel_ratio and sample_ratio are fastdiv/fastmodulo constants.
template <ggml_type type, int ncols_dst, bool has_fusion, bool is_multi_token_id = false>
__launch_bounds__(calc_nwarps(ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
static __global__ void mul_mat_vec_q(
        const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst,
        const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
        const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x,
        const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio,
        const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst,
        const uint32_t ids_stride) {

    constexpr int qk  = ggml_cuda_type_traits<type>::qk;
    constexpr int qi  = ggml_cuda_type_traits<type>::qi;
    constexpr int vdr = get_vdr_mmvq(type);
    constexpr mmvq_parameter_table_id table_id = get_device_table_id();
    constexpr int nwarps = calc_nwarps(ncols_dst, table_id);
    constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_dst, table_id);
    constexpr int warp_size = ggml_cuda_get_physical_warp_size();

    constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);

    const     int tid = warp_size*threadIdx.y + threadIdx.x;
    const     int row0 = rows_per_cuda_block*blockIdx.x;
    const     int blocks_per_row_x = ncols_x / qk;
    // qi/vdr consecutive threads share one x block, so the whole CUDA block advances
    // blocks_per_iter x blocks per iteration of the main loop.
    constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi;

    const uint32_t channel_dst = blockIdx.y;

    uint32_t token_idx = 0;
    uint32_t channel_x;
    uint32_t channel_y;
    uint32_t sample_dst;

    if constexpr (is_multi_token_id) {
        // Multi-token MUL_MAT_ID path, adding these in the normal path causes a perf regression for n_tokens=1 case
        token_idx  = blockIdx.z;
        channel_x  = ids[channel_dst + token_idx * ids_stride];
        channel_y  = fastmodulo(channel_dst, nchannels_y);
        sample_dst = 0;
    } else {
        // With ids and a single column, ids maps the dst channel to the x channel and y channels
        // wrap modulo nchannels_y; otherwise dst channels are broadcast over x via channel_ratio.
        channel_x  = ncols_dst == 1 && ids ? ids[channel_dst]                     : fastdiv(channel_dst, channel_ratio);
        channel_y  = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst;
        sample_dst = blockIdx.z;
    }

    const uint32_t sample_x    = fastdiv(sample_dst, sample_ratio);
    const uint32_t sample_y    = sample_dst;

    // Fusion state; only populated in the has_fusion instantiation.
    bool use_gate = false;
    bool use_bias = false;
    bool use_gate_bias = false;
    const void * vgate = nullptr;
    const float * x_bias = nullptr;
    const float * gate_bias = nullptr;
    ggml_glu_op active_glu;

    if constexpr (has_fusion) {
        use_gate      = fusion.gate      != nullptr;
        use_bias      = fusion.x_bias    != nullptr;
        use_gate_bias = fusion.gate_bias != nullptr && use_gate;
        vgate         = fusion.gate;
        x_bias        = (const float *) fusion.x_bias;
        gate_bias     = (const float *) fusion.gate_bias;
        active_glu    = fusion.glu_op;
    }

    // Per-column bias values for the row this lane will write (lanes < rows_per_cuda_block of warp 0).
    float x_biases[ncols_dst]    = { 0.0f };
    float gate_biases[ncols_dst] = { 0.0f };
    if constexpr (has_fusion) {
        // Biases are indexed by the x channel (expert) when ids is present, else by the dst channel.
        const uint32_t channel_bias = ids ? channel_x : channel_dst;
        if (use_bias) {
            x_bias = x_bias + sample_dst*stride_sample_dst + channel_bias*stride_channel_dst + row0;
            // 1. Hide latency by prefetching bias and gate here
            // 2. load only on threads that won't die after partial sum calculation
            if (threadIdx.x < rows_per_cuda_block && threadIdx.y == 0 &&
                (rows_per_cuda_block == 1 || uint32_t(row0 + threadIdx.x) < stride_col_dst)) {
#pragma unroll
                for (int j = 0; j < ncols_dst; ++j) {
                    x_biases[j] = x_bias[j * stride_col_dst + threadIdx.x];
                }
            }
        }
        if (use_gate_bias) {
            gate_bias = gate_bias + sample_dst*stride_sample_dst + channel_bias*stride_channel_dst + row0;
            if (threadIdx.x < rows_per_cuda_block && threadIdx.y == 0 &&
                (rows_per_cuda_block == 1 || uint32_t(row0 + threadIdx.x) < stride_col_dst)) {
#pragma unroll
                for (int j = 0; j < ncols_dst; ++j) {
                    gate_biases[j] = gate_bias[j * stride_col_dst + threadIdx.x];
                }
            }
        }
    }

    // partial sum for each thread
    float tmp[ncols_dst][rows_per_cuda_block] = {{0.0f}};
    float tmp_gate[ncols_dst][rows_per_cuda_block] = {{0.0f}};

    const block_q8_1 * y = ((const block_q8_1 *) vy) + sample_y*stride_sample_y + channel_y*stride_channel_y;
    if constexpr (is_multi_token_id) {
        // Each token reads its own y column.
        y += token_idx*stride_col_y;
    }
    const int kbx_offset = sample_x*stride_sample_x + channel_x*stride_channel_x + row0*stride_row_x;

    // NOTE(review): unlike the bias loads and the final store, this loop has no row bound check.
    // With rows_per_cuda_block == 2 and an odd number of rows, the last CUDA block reads x row
    // row0 + 1, which lies past the last row. Presumably covered by allocation padding; verify.
    for (int kbx = tid / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) {
        const int kby = kbx * (qk/QK8_1); // y block index that aligns with kbx

        // x block quant index when casting the quants to int
        const int kqs = vdr * (tid % (qi/vdr));

#pragma unroll
        for (int j = 0; j < ncols_dst; ++j) {
#pragma unroll
            for (int i = 0; i < rows_per_cuda_block; ++i) {
                tmp[j][i] += vec_dot_q_cuda(
                    vx, &y[j*stride_col_y + kby], kbx_offset + i*stride_row_x + kbx, kqs);
                if constexpr (has_fusion) {
                    if (use_gate) {
                        tmp_gate[j][i] += vec_dot_q_cuda(
                            vgate, &y[j*stride_col_y + kby], kbx_offset + i*stride_row_x + kbx, kqs);
                    }
                }
            }
        }
    }

    // Staging area for the partial sums of warps 1..nwarps-1. The first dimension is clamped to 1
    // so the arrays are never zero-sized (nwarps == 1, or no fusion for the gate buffer).
    __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size];
    __shared__ float tmp_shared_gate[(has_fusion && (nwarps-1 > 0)) ? nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size];
    if constexpr (!has_fusion) {
        (void) tmp_shared_gate;
    } else if (!use_gate) {
        (void) tmp_shared_gate;
    }

    if (threadIdx.y > 0) {
#pragma unroll
        for (int j = 0; j < ncols_dst; ++j) {
#pragma unroll
            for (int i = 0; i < rows_per_cuda_block; ++i) {
                tmp_shared[threadIdx.y-1][j][i][threadIdx.x] = tmp[j][i];
                if constexpr (has_fusion) {
                    if (use_gate) {
                        tmp_shared_gate[threadIdx.y-1][j][i][threadIdx.x] = tmp_gate[j][i];
                    }
                }
            }
        }
    }
    // Reached by every thread: makes the staged partial sums visible to warp 0.
    __syncthreads();
    // From here on only warp 0 works: it combines, reduces and writes the results.
    if (threadIdx.y > 0) {
        return;
    }

    dst += sample_dst*stride_sample_dst + channel_dst*stride_channel_dst + row0;

    if constexpr (is_multi_token_id) {
        dst += token_idx*stride_col_dst;
    }

    // sum up partial sums and write back result
#pragma unroll
    for (int j = 0; j < ncols_dst; ++j) {
#pragma unroll
        for (int i = 0; i < rows_per_cuda_block; ++i) {
#pragma unroll
            for (int l = 0; l < nwarps-1; ++l) {
                tmp[j][i] += tmp_shared[l][j][i][threadIdx.x];
                if constexpr (has_fusion) {
                    if (use_gate) {
                        tmp_gate[j][i] += tmp_shared_gate[l][j][i][threadIdx.x];
                    }
                }
            }
            // The write-back below has lane i read tmp[j][i], so it relies on warp_reduce_sum
            // leaving the total in every lane (presumably a butterfly all-reduce).
            tmp[j][i] = warp_reduce_sum<warp_size>(tmp[j][i]);
            if constexpr (has_fusion) {
                if (use_gate) {
                    tmp_gate[j][i] = warp_reduce_sum<warp_size>(tmp_gate[j][i]);
                }
            }
        }

        // Lane i writes row row0 + i. The bound check compares against stride_col_dst, which
        // assumes the dst column stride equals the number of dst rows; it is only needed when
        // rows_per_cuda_block > 1.
        if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || uint32_t(row0 + threadIdx.x) < stride_col_dst)) {
            float result = tmp[j][threadIdx.x];
            if constexpr (has_fusion) {
                if (use_bias) {
                    result += x_biases[j];
                }
                if (use_gate) {
                    float gate_value = tmp_gate[j][threadIdx.x];
                    if (use_gate_bias) {
                        gate_value += gate_biases[j];
                    }
                    // GLU epilogue: combine the x result with the gate projection.
                    switch (active_glu) {
                        case GGML_GLU_OP_SWIGLU:
                            result *= ggml_cuda_op_silu_single(gate_value);
                            break;
                        case GGML_GLU_OP_GEGLU:
                            result *= ggml_cuda_op_gelu_single(gate_value);
                            break;
                        case GGML_GLU_OP_SWIGLU_OAI: {
                            result = ggml_cuda_op_swiglu_oai_single(gate_value, result);
                            break;
                        }
                        default:
                            result = result * gate_value;
                            break;
                    }
                }
            }
            dst[j*stride_col_dst + threadIdx.x] = result;
        }
    }

    // Silence unused-variable warnings in the non-fused instantiation.
    if constexpr (!has_fusion) {
        GGML_UNUSED_VARS(use_gate, use_bias, use_gate_bias, active_glu, gate_bias, x_bias, tmp_gate);
    }
}
357
// Computes the launch configuration for mul_mat_vec_q:
//   grid  = (ceil(nrows_x / rows_per_block), nchannels_dst, nsamples_or_ntokens)
//   block = (warp_size, nwarps, 1)
// rows_per_block and nwarps come from the same tables (calc_rows_per_block / calc_nwarps) that
// the kernel evaluates at compile time, so table_id must describe the device being launched on.
// nsamples_or_ntokens is the sample count on the regular path and the token count on the
// multi-token MUL_MAT_ID path.
static std::pair<dim3, dim3> calc_launch_params(
        const int ncols_dst, const int nrows_x, const int nchannels_dst, const int nsamples_or_ntokens,
        const int warp_size, const mmvq_parameter_table_id table_id) {
    const int rows_per_block = calc_rows_per_block(ncols_dst, table_id);
    // Widen before adding: the ceil-division numerator nrows_x + rows_per_block - 1 would
    // otherwise be evaluated in int and could overflow for nrows_x close to INT_MAX.
    const int64_t nblocks = ((int64_t) nrows_x + rows_per_block - 1) / rows_per_block;
    const dim3 block_nums(nblocks, nchannels_dst, nsamples_or_ntokens);
    const dim3 block_dims(warp_size, calc_nwarps(ncols_dst, table_id), 1);
    return {block_nums, block_dims};
}
366
// Launches mul_mat_vec_q on `stream` with the given grid/block dimensions, choosing between the
// fused instantiation (gate / bias / GLU epilogue) and the plain one.
// Fusion is requested when any of fusion.gate, fusion.x_bias or fusion.gate_bias is set. The
// fused kernel is only instantiated for c_ncols_dst == 1; requesting fusion with more columns
// trips a GGML_ASSERT.
template<ggml_type type, int c_ncols_dst, bool is_multi_token_id = false>
static void mul_mat_vec_q_switch_fusion(
        const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
        const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
        const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x,
        const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio,
        const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst,
        const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared,
        const uint32_t ids_stride, cudaStream_t stream) {

    const bool fusion_requested =
        fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;

    if constexpr (c_ncols_dst != 1) {
        // No fused instantiation exists for multi-column launches.
        GGML_ASSERT(!fusion_requested && "fusion only supported for ncols_dst=1");
    } else if (fusion_requested) {
        mul_mat_vec_q<type, c_ncols_dst, true, is_multi_token_id><<<block_nums, block_dims, nbytes_shared, stream>>>
            (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
             channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
             sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);
        return;
    }

    mul_mat_vec_q<type, c_ncols_dst, false, is_multi_token_id><<<block_nums, block_dims, nbytes_shared, stream>>>
        (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
         channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
         sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);
}
395
// Host-side dispatcher for mul_mat_vec_q on quantization type `type`. It validates the shape,
// precomputes the fastdiv constants, picks the per-device tuning table and launches the kernel
// instantiation that matches ncols_dst.
//
// There are two paths:
//   * ids != nullptr && ncols_dst > 1: multi-token MUL_MAT_ID path. The single-column kernel
//     (c_ncols_dst = 1, is_multi_token_id = true) is launched with the tokens spread over
//     gridDim.z.
//   * otherwise: one instantiation per ncols_dst in [1, 8], with the samples over gridDim.z.
// Fails via GGML_ASSERT when ncols_dst exceeds MMVQ_MAX_BATCH_SIZE (this also caps the token
// count of the multi-token path), and via GGML_ABORT for any other unsupported ncols_dst.
template <ggml_type type>
static void mul_mat_vec_q_switch_ncols_dst(
        const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
        const int ncols_x, const int nrows_x, const int ncols_dst,
        const int stride_row_x, const int stride_col_y, const int stride_col_dst,
        const int nchannels_x, const int nchannels_y, const int nchannels_dst,
        const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
        const int nsamples_x, const int nsamples_dst, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
        const int ids_stride, cudaStream_t stream) {

    GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0);
    GGML_ASSERT(ncols_dst <= MMVQ_MAX_BATCH_SIZE);

    // With ids, x channels come from ids and y channels wrap modulo nchannels_y. Without ids,
    // dst channel c maps to x channel c / (nchannels_dst / nchannels_x).
    const uint3 nchannels_y_fd   = ids ? init_fastdiv_values(nchannels_y) : make_uint3(0, 0, 0);
    const uint3 channel_ratio_fd = ids ? make_uint3(0, 0, 0)              : init_fastdiv_values(nchannels_dst / nchannels_x);
    const uint3 sample_ratio_fd  = init_fastdiv_values(nsamples_dst  / nsamples_x);

    const int device = ggml_cuda_get_device();
    const int warp_size = ggml_cuda_info().devices[device].warp_size;
    const mmvq_parameter_table_id table_id = get_device_table_id(ggml_cuda_info().devices[device].cc);

    // Computes the launch dimensions and launches the instantiation selected by the compile-time
    // tags. ncols_tag is std::integral_constant<int, N> and multi_token_tag is
    // std::true_type / std::false_type. nsamples_or_ntokens becomes gridDim.z.
    const auto launch = [&](auto ncols_tag, auto multi_token_tag, const int nsamples_or_ntokens) {
        constexpr int  c_ncols_dst         = decltype(ncols_tag)::value;
        constexpr bool c_is_multi_token_id = decltype(multi_token_tag)::value;

        const std::pair<dim3, dim3> dims = calc_launch_params(
            c_ncols_dst, nrows_x, nchannels_dst, nsamples_or_ntokens, warp_size, table_id);

        mul_mat_vec_q_switch_fusion<type, c_ncols_dst, c_is_multi_token_id>(
            vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
            channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
            sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
            dims.first, dims.second, 0, ids_stride, stream);
    };

    if (ids != nullptr && ncols_dst > 1) {
        // Multi-token MUL_MAT_ID path only - single-token goes through the regular path below.
        launch(std::integral_constant<int, 1>{}, std::true_type{}, ncols_dst);
        return;
    }

    switch (ncols_dst) {
        case 1: launch(std::integral_constant<int, 1>{}, std::false_type{}, nsamples_dst); break;
        case 2: launch(std::integral_constant<int, 2>{}, std::false_type{}, nsamples_dst); break;
        case 3: launch(std::integral_constant<int, 3>{}, std::false_type{}, nsamples_dst); break;
        case 4: launch(std::integral_constant<int, 4>{}, std::false_type{}, nsamples_dst); break;
        case 5: launch(std::integral_constant<int, 5>{}, std::false_type{}, nsamples_dst); break;
        case 6: launch(std::integral_constant<int, 6>{}, std::false_type{}, nsamples_dst); break;
        case 7: launch(std::integral_constant<int, 7>{}, std::false_type{}, nsamples_dst); break;
        case 8: launch(std::integral_constant<int, 8>{}, std::false_type{}, nsamples_dst); break;
        default:
            GGML_ABORT("fatal error");
            break;
    }
}
503static void mul_mat_vec_q_switch_type(
504        const void * vx, const ggml_type type_x, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
505        const int ncols_x, const int nrows_x, const int ncols_dst,
506        const int stride_row_x, const int stride_col_y, const int stride_col_dst,
507        const int nchannels_x, const int nchannels_y, const int nchannels_dst,
508        const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
509        const int nsamples_x, const int nsamples_dst, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
510        const int ids_stride, cudaStream_t stream) {
511    switch (type_x) {
512        case GGML_TYPE_Q4_0:
513            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_0>
514                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
515                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
516                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
517            break;
518        case GGML_TYPE_Q4_1:
519            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_1>
520                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
521                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
522                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
523            break;
524        case GGML_TYPE_Q5_0:
525            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_0>
526                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
527                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
528                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
529            break;
530        case GGML_TYPE_Q5_1:
531            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_1>
532                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
533                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
534                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
535            break;
536        case GGML_TYPE_Q8_0:
537            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q8_0>
538                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
539                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
540                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
541            break;
542        case GGML_TYPE_MXFP4:
543            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_MXFP4>
544                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
545                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
546                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
547            break;
548        case GGML_TYPE_Q2_K:
549            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q2_K>
550                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
551                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
552                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
553            break;
554        case GGML_TYPE_Q3_K:
555            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q3_K>
556                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
557                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
558                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
559            break;
560        case GGML_TYPE_Q4_K:
561            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_K>
562                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
563                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
564                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
565            break;
566        case GGML_TYPE_Q5_K:
567            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_K>
568                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
569                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
570                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
571            break;
572        case GGML_TYPE_Q6_K:
573            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q6_K>
574                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
575                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
576                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
577            break;
578        case GGML_TYPE_IQ2_XXS:
579            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_XXS>
580                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
581                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
582                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
583            break;
584        case GGML_TYPE_IQ2_XS:
585            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_XS>
586                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
587                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
588                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
589            break;
590        case GGML_TYPE_IQ2_S:
591            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_S>
592                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
593                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
594                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
595            break;
596        case GGML_TYPE_IQ3_XXS:
597            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ3_XXS>
598                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
599                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
600                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
601            break;
602        case GGML_TYPE_IQ1_S:
603            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ1_S>
604                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
605                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
606                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
607            break;
608        case GGML_TYPE_IQ1_M:
609            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ1_M>
610                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
611                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
612                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
613            break;
614        case GGML_TYPE_IQ4_NL:
615            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ4_NL>
616                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
617                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
618                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
619            break;
620        case GGML_TYPE_IQ4_XS:
621            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ4_XS>
622                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
623                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
624                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
625            break;
626        case GGML_TYPE_IQ3_S:
627            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ3_S>
628                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
629                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
630                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
631            break;
632        default:
633            GGML_ABORT("fatal error");
634            break;
635    }
636}
637
// Translates the host-side fusion description into the device-side argument struct handed to the
// MMVQ dispatch, asserting that every fused operand is compatible with src0/dst (and, for the
// MUL_MAT_ID path, with the per-expert layout of src0). Returns a value-initialized struct when
// fusion is nullptr.
static ggml_cuda_mm_fusion_args_device ggml_cuda_mmvq_fusion_args_to_device(
        const ggml_tensor * src0, const ggml_tensor * ids, const ggml_tensor * dst,
        const ggml_cuda_mm_fusion_args_host * fusion) {
    ggml_cuda_mm_fusion_args_device fusion_local{};
    if (fusion == nullptr) {
        return fusion_local;
    }

    // Fusion is only accepted when dst has a single output column
    // (dim 2 for MUL_MAT_ID, dim 1 for plain MUL_MAT).
    GGML_ASSERT( !ids || dst->ne[2] == 1);
    GGML_ASSERT(  ids || dst->ne[1] == 1);

    if (fusion->x_bias) {
        GGML_ASSERT(fusion->x_bias->type == GGML_TYPE_F32);
        GGML_ASSERT(fusion->x_bias->ne[0] == dst->ne[0]);
        GGML_ASSERT(!ids || fusion->x_bias->ne[1] == src0->ne[2]);
        fusion_local.x_bias = fusion->x_bias->data;
    }
    if (fusion->gate) {
        // The gate matrix is consumed by the same kernel as src0, so it must share type and strides.
        GGML_ASSERT(fusion->gate->type == src0->type && ggml_are_same_stride(fusion->gate, src0));
        fusion_local.gate = fusion->gate->data;
    }
    if (fusion->gate_bias) {
        GGML_ASSERT(fusion->gate_bias->type == GGML_TYPE_F32);
        GGML_ASSERT(fusion->gate_bias->ne[0] == dst->ne[0]);
        GGML_ASSERT(!ids || fusion->gate_bias->ne[1] == src0->ne[2]);
        fusion_local.gate_bias = fusion->gate_bias->data;
    }
    fusion_local.glu_op = fusion->glu_op;
    return fusion_local;
}

// Computes dst = src0 * src1 with the MMVQ kernels. src1 (F32) is first quantized to q8_1 into a
// temporary pool buffer, then the kernel matching src0->type is dispatched on ctx's stream.
// When ids is non-null this is the batched GGML_MUL_MAT_ID path (I32 ids, ne12 bounded by
// MMVQ_MAX_BATCH_SIZE), in which dims 1 and 2 of src1/dst swap roles. fusion (optional)
// supplies extra bias / gate operands that are forwarded to the kernels.
void ggml_cuda_mul_mat_vec_q(
        ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst,
        const ggml_cuda_mm_fusion_args_host * fusion) {
    GGML_ASSERT(        src1->type == GGML_TYPE_F32);
    GGML_ASSERT(        dst->type  == GGML_TYPE_F32);
    GGML_ASSERT(!ids || ids->type  == GGML_TYPE_I32); // Optional, used for batched GGML_MUL_MAT_ID.

    GGML_TENSOR_BINARY_OP_LOCALS;

    cudaStream_t stream = ctx.stream();

    const size_t ts_src0 = ggml_type_size(src0->type);
    const size_t ts_src1 = ggml_type_size(src1->type);
    const size_t ts_dst  = ggml_type_size(dst->type);

    // Every operand must be densely packed along dim 0.
    GGML_ASSERT(        nb00       == ts_src0);
    GGML_ASSERT(        nb10       == ts_src1);
    GGML_ASSERT(        nb0        == ts_dst);
    GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type));

    GGML_ASSERT(!ids || ne12 <= MMVQ_MAX_BATCH_SIZE);

    ggml_cuda_mm_fusion_args_device fusion_dev = ggml_cuda_mmvq_fusion_args_to_device(src0, ids, dst, fusion);

    // src0 living in a temporary compute buffer may have allocation padding past its data; zero it.
    const bool src0_in_compute_buffer =
        ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE;
    if (src0_in_compute_buffer) {
        const size_t size_data  = ggml_nbytes(src0);
        const size_t size_alloc = ggml_backend_buffer_get_alloc_size(src0->buffer, src0);
        if (size_alloc > size_data) {
            GGML_ASSERT(ggml_is_contiguously_allocated(src0));
            GGML_ASSERT(!src0->view_src);
            CUDA_CHECK(cudaMemsetAsync((char *) src0->data + size_data, 0, size_alloc - size_data, stream));
        }
    }

    // Quantize the F32 src1 to q8_1, padding each row up to MATRIX_ROW_PADDING elements.
    const int64_t ne10_padded     = GGML_PAD(ne10, MATRIX_ROW_PADDING);
    const size_t  nbytes_src1_q81 = ne13*ne12 * ne11*ne10_padded * sizeof(block_q8_1)/QK8_1;
    ggml_cuda_pool_alloc<char> y_q8_1(ctx.pool(), nbytes_src1_q81);
    {
        // Element strides of the F32 input; only needed by the quantization step.
        const int64_t src1_s1 = src1->nb[1] / ts_src1;
        const int64_t src1_s2 = src1->nb[2] / ts_src1;
        const int64_t src1_s3 = src1->nb[3] / ts_src1;
        quantize_row_q8_1_cuda((const float *) src1->data, nullptr, y_q8_1.get(), src0->type, ne10,
                               src1_s1, src1_s2, src1_s3, ne10_padded, ne11, ne12, ne13, stream);
    }

    // src0 strides, in elements of src0's type.
    const int64_t x_s1 = src0->nb[1] / ts_src0;
    const int64_t x_s2 = src0->nb[2] / ts_src0;
    const int64_t x_s3 = src0->nb[3] / ts_src0;

    // Quantized-src1 strides, in q8_1 blocks (the buffer written above is densely packed).
    const int64_t y_s1 = ne10_padded / QK8_1;
    const int64_t y_s2 = ne11*y_s1;
    const int64_t y_s3 = ne12*y_s2;

    // dst strides, in F32 elements.
    const int64_t d_s1 = dst->nb[1] / ts_dst;
    const int64_t d_s2 = dst->nb[2] / ts_dst;
    const int64_t d_s3 = dst->nb[3] / ts_dst;

    // For MUL_MAT_ID the memory layout is different than for MUL_MAT:
    // dims 1 and 2 of src1/dst exchange their column/channel roles.
    const int32_t * ids_data   = nullptr;
    int64_t         ids_stride = 0;
    int64_t ncols_dst, nchannels_y, nchannels_dst;
    int64_t stride_col_y, stride_col_dst, stride_channel_y, stride_channel_dst;
    if (ids) {
        ids_data           = (const int32_t *) ids->data;
        ids_stride         = ids->nb[1] / ggml_type_size(ids->type);
        ncols_dst          = ne2;
        nchannels_y        = ne11;
        nchannels_dst      = ne1;
        stride_col_y       = y_s2;
        stride_col_dst     = d_s2;
        stride_channel_y   = y_s1;
        stride_channel_dst = d_s1;
    } else {
        ncols_dst          = ne1;
        nchannels_y        = ne12;
        nchannels_dst      = ne2;
        stride_col_y       = y_s1;
        stride_col_dst     = d_s1;
        stride_channel_y   = y_s2;
        stride_channel_dst = d_s2;
    }

    mul_mat_vec_q_switch_type(
        src0->data, src0->type, y_q8_1.get(), ids_data, fusion_dev, (float *) dst->data, ne00,
        ne01,              ncols_dst,     x_s1, stride_col_y,     stride_col_dst,
        ne02, nchannels_y, nchannels_dst, x_s2, stride_channel_y, stride_channel_dst,
        ne03,              ne3,           x_s3, y_s3,             d_s3,               ids_stride, stream);
}
737
// Row-slice entry point: multiplies rows [row_low, row_high) of quantized src0 (src0_dd_i) by the
// activations in src1_ddq_i (presumably src1 already quantized to q8_1, with rows padded to
// src1_padded_row_size elements), writing F32 results to dst_dd_i. Always a single
// channel/sample, no fusion, and no MUL_MAT_ID indirection.
void ggml_cuda_op_mul_mat_vec_q(
    ggml_backend_cuda_context & ctx,
    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
    const int64_t src1_padded_row_size, cudaStream_t stream) {

    // src1 rows must consist of whole q8_1 blocks.
    const int64_t ne10 = src1->ne[0];
    GGML_ASSERT(ne10 % QK8_1 == 0);

    const int64_t ncols_x = src0->ne[0];
    const int64_t nrows_x = row_high - row_low;

    // The main device has a larger memory buffer holding the results from all GPUs, so its
    // destination rows span dst->ne[0]; any other device writes only its own slice of rows.
    const bool    on_main_device = ggml_cuda_get_device() == ctx.device;
    const int64_t nrows_dst      = on_main_device ? dst->ne[0] : nrows_x;

    // Row strides in blocks: src0 in its own quant blocks, src1 in q8_1 blocks.
    const int x_row_stride = ncols_x / ggml_blck_size(src0->type);
    const int y_col_stride = src1_padded_row_size / QK8_1;

    ggml_cuda_mm_fusion_args_device no_fusion{};
    mul_mat_vec_q_switch_type(
        src0_dd_i, src0->type, src1_ddq_i, /*ids=*/nullptr, no_fusion, dst_dd_i,
        ncols_x, nrows_x, src1_ncols, x_row_stride, y_col_stride, nrows_dst,
        /*nchannels_x, nchannels_y, nchannels_dst:*/             1, 1, 1,
        /*stride_channel_x, stride_channel_y, stride_channel_dst:*/ 1, 1, 1,
        /*nsamples_x, nsamples_dst:*/                             1, 1,
        /*stride_sample_x, stride_sample_y, stride_sample_dst:*/  1, 1, 1,
        /*ids_stride:*/ 0, stream);

    GGML_UNUSED_VARS(src1_ddf_i);
}