1#include "common.cuh"
  2#include "mmq.cuh"
  3#include "quantize.cuh"
  4#include "mmid.cuh"
  5
  6static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
  7    switch (args.type_x) {
  8        case GGML_TYPE_Q4_0:
  9            mul_mat_q_case<GGML_TYPE_Q4_0>(ctx, args, stream);
 10            break;
 11        case GGML_TYPE_Q4_1:
 12            mul_mat_q_case<GGML_TYPE_Q4_1>(ctx, args, stream);
 13            break;
 14        case GGML_TYPE_Q5_0:
 15            mul_mat_q_case<GGML_TYPE_Q5_0>(ctx, args, stream);
 16            break;
 17        case GGML_TYPE_Q5_1:
 18            mul_mat_q_case<GGML_TYPE_Q5_1>(ctx, args, stream);
 19            break;
 20        case GGML_TYPE_Q8_0:
 21            mul_mat_q_case<GGML_TYPE_Q8_0>(ctx, args, stream);
 22            break;
 23        case GGML_TYPE_MXFP4:
 24            mul_mat_q_case<GGML_TYPE_MXFP4>(ctx, args, stream);
 25            break;
 26        case GGML_TYPE_Q2_K:
 27            mul_mat_q_case<GGML_TYPE_Q2_K>(ctx, args, stream);
 28            break;
 29        case GGML_TYPE_Q3_K:
 30            mul_mat_q_case<GGML_TYPE_Q3_K>(ctx, args, stream);
 31            break;
 32        case GGML_TYPE_Q4_K:
 33            mul_mat_q_case<GGML_TYPE_Q4_K>(ctx, args, stream);
 34            break;
 35        case GGML_TYPE_Q5_K:
 36            mul_mat_q_case<GGML_TYPE_Q5_K>(ctx, args, stream);
 37            break;
 38        case GGML_TYPE_Q6_K:
 39            mul_mat_q_case<GGML_TYPE_Q6_K>(ctx, args, stream);
 40            break;
 41        case GGML_TYPE_IQ2_XXS:
 42            mul_mat_q_case<GGML_TYPE_IQ2_XXS>(ctx, args, stream);
 43            break;
 44        case GGML_TYPE_IQ2_XS:
 45            mul_mat_q_case<GGML_TYPE_IQ2_XS>(ctx, args, stream);
 46            break;
 47        case GGML_TYPE_IQ2_S:
 48            mul_mat_q_case<GGML_TYPE_IQ2_S>(ctx, args, stream);
 49            break;
 50        case GGML_TYPE_IQ3_XXS:
 51            mul_mat_q_case<GGML_TYPE_IQ3_XXS>(ctx, args, stream);
 52            break;
 53        case GGML_TYPE_IQ3_S:
 54            mul_mat_q_case<GGML_TYPE_IQ3_S>(ctx, args, stream);
 55            break;
 56        case GGML_TYPE_IQ1_S:
 57            mul_mat_q_case<GGML_TYPE_IQ1_S>(ctx, args, stream);
 58            break;
 59        case GGML_TYPE_IQ4_XS:
 60            mul_mat_q_case<GGML_TYPE_IQ4_XS>(ctx, args, stream);
 61            break;
 62        case GGML_TYPE_IQ4_NL:
 63            mul_mat_q_case<GGML_TYPE_IQ4_NL>(ctx, args, stream);
 64            break;
 65        default:
 66            GGML_ABORT("fatal error");
 67            break;
 68    }
 69}
 70
 71void ggml_cuda_mul_mat_q(
 72        ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
 73    GGML_ASSERT(        src1->type == GGML_TYPE_F32);
 74    GGML_ASSERT(        dst->type  == GGML_TYPE_F32);
 75    GGML_ASSERT(!ids || ids->type  == GGML_TYPE_I32); // Optional, used for batched GGML_MUL_MAT_ID.
 76
 77    GGML_TENSOR_BINARY_OP_LOCALS;
 78
 79    cudaStream_t stream = ctx.stream();
 80    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
 81
 82    const size_t ts_src0 = ggml_type_size(src0->type);
 83    const size_t ts_src1 = ggml_type_size(src1->type);
 84    const size_t ts_dst  = ggml_type_size(dst->type);
 85
 86    GGML_ASSERT(        nb00       == ts_src0);
 87    GGML_ASSERT(        nb10       == ts_src1);
 88    GGML_ASSERT(        nb0        == ts_dst);
 89    GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type));
 90
 91    const char  * src0_d = (const char  *) src0->data;
 92    const float * src1_d = (const float *) src1->data;
 93    float       *  dst_d = (float       *)  dst->data;
 94
 95    // If src0 is a temporary compute buffer, clear any potential padding.
 96    if (ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE) {
 97        const size_t size_data  = ggml_nbytes(src0);
 98        const size_t size_alloc = ggml_backend_buffer_get_alloc_size(src0->buffer, src0);
 99        if (size_alloc > size_data) {
100            GGML_ASSERT(ggml_is_contiguously_allocated(src0));
101            GGML_ASSERT(!src0->view_src);
102            CUDA_CHECK(cudaMemsetAsync((char *) src0->data + size_data, 0, size_alloc - size_data, stream));
103        }
104    }
105
106    const int64_t ne10_padded = GGML_PAD(ne10, MATRIX_ROW_PADDING);
107
108    const int64_t s01 = src0->nb[1] / ts_src0;
109    const int64_t s1  =  dst->nb[1] / ts_dst;
110    const int64_t s02 = src0->nb[2] / ts_src0;
111    const int64_t s2  =  dst->nb[2] / ts_dst;
112    const int64_t s03 = src0->nb[3] / ts_src0;
113    const int64_t s3  =  dst->nb[3] / ts_dst;
114
115    const bool use_stream_k = (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA)
116                            || GGML_CUDA_CC_IS_CDNA(cc);
117
118    // TODO: tighter pool buffer size vs q8 path
119    const bool use_native_mxfp4 = blackwell_mma_available(cc) && src0->type == GGML_TYPE_MXFP4;
120
121    if (!ids) {
122        const size_t nbytes_src1_q8_1 = ne13*ne12 * ne11*ne10_padded * sizeof(block_q8_1)/QK8_1 +
123            get_mmq_x_max_host(cc)*sizeof(block_q8_1_mmq);
124        ggml_cuda_pool_alloc<char> src1_q8_1(ctx.pool(), nbytes_src1_q8_1);
125
126        {
127            const int64_t s11 = src1->nb[1] / ts_src1;
128            const int64_t s12 = src1->nb[2] / ts_src1;
129            const int64_t s13 = src1->nb[3] / ts_src1;
130            if (use_native_mxfp4) {
131                static_assert(sizeof(block_fp4_mmq) == 4 * sizeof(block_q8_1));
132                quantize_mmq_mxfp4_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded,
133                                        ne11, ne12, ne13, stream);
134
135            } else {
136                quantize_mmq_q8_1_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded,
137                                       ne11, ne12, ne13, stream);
138            }
139            CUDA_CHECK(cudaGetLastError());
140        }
141
142        // Stride depends on quantization format
143        const int64_t s12 = use_native_mxfp4 ?
144                                ne11 * ne10_padded * sizeof(block_fp4_mmq) /
145                                    (8 * QK_MXFP4 * sizeof(int))  // block_fp4_mmq holds 256 values (8 blocks of 32)
146                                :
147                                ne11 * ne10_padded * sizeof(block_q8_1) / (QK8_1 * sizeof(int));
148        const int64_t s13 = ne12*s12;
149
150        const mmq_args args = {
151            src0_d, src0->type, (const int *) src1_q8_1.ptr, nullptr, nullptr, dst_d,
152            ne00, ne01, ne1, s01, ne11, s1,
153            ne02, ne12, s02, s12, s2,
154            ne03, ne13, s03, s13, s3,
155            use_stream_k, ne1};
156        ggml_cuda_mul_mat_q_switch_type(ctx, args, stream);
157        return;
158    }
159
160    GGML_ASSERT(ne13 == 1);
161    GGML_ASSERT(nb12 % nb11 == 0);
162    GGML_ASSERT(nb2  % nb1  == 0);
163
164    const int64_t n_expert_used = ids->ne[0];
165    const int64_t ne_get_rows = ne12 * n_expert_used;
166    GGML_ASSERT(ne1 == n_expert_used);
167
168    ggml_cuda_pool_alloc<int32_t> ids_src1(ctx.pool(), ne_get_rows);
169    ggml_cuda_pool_alloc<int32_t> ids_dst(ctx.pool(), ne_get_rows);
170    ggml_cuda_pool_alloc<int32_t> expert_bounds(ctx.pool(), ne02 + 1);
171
172    {
173        GGML_ASSERT(ids->nb[0] == ggml_element_size(ids));
174        const int si1  = ids->nb[1] / ggml_element_size(ids);
175        const int sis1 = nb12 / nb11;
176
177        ggml_cuda_launch_mm_ids_helper((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
178            ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
179        CUDA_CHECK(cudaGetLastError());
180    }
181
182    const size_t nbytes_src1_q8_1 = ne12*n_expert_used*ne10_padded * sizeof(block_q8_1)/QK8_1 +
183        get_mmq_x_max_host(cc)*sizeof(block_q8_1_mmq);
184    ggml_cuda_pool_alloc<char> src1_q8_1(ctx.pool(), nbytes_src1_q8_1);
185
186    const int64_t ne11_flat = ne12*n_expert_used;
187    const int64_t ne12_flat = 1;
188    const int64_t ne13_flat = 1;
189
190    {
191        const int64_t s11 = src1->nb[1] / ts_src1;
192        const int64_t s12 = src1->nb[2] / ts_src1;
193        const int64_t s13 = src1->nb[3] / ts_src1;
194
195        if (use_native_mxfp4) {
196            quantize_mmq_mxfp4_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13,
197                                    ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
198        } else {
199            quantize_mmq_q8_1_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13,
200                                   ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
201        }
202        CUDA_CHECK(cudaGetLastError());
203    }
204
205    const int64_t s12 = use_native_mxfp4 ? ne11 * ne10_padded * sizeof(block_fp4_mmq) / (8 * QK_MXFP4 * sizeof(int)) :
206                                           ne11 * ne10_padded * sizeof(block_q8_1) / (QK8_1 * sizeof(int));
207    const int64_t s13 = ne12*s12;
208
209    // Note that ne02 is used instead of ne12 because the number of y channels determines the z dimension of the CUDA grid.
210    const mmq_args args = {
211        src0_d, src0->type, (const int *) src1_q8_1.get(), ids_dst.get(), expert_bounds.get(), dst_d,
212        ne00, ne01, ne_get_rows, s01, ne_get_rows, s1,
213        ne02, ne02, s02, s12, s2,
214        ne03, ne13, s03, s13, s3,
215        use_stream_k, ne12};
216
217    ggml_cuda_mul_mat_q_switch_type(ctx, args, stream);
218}
219
// Per-device MMQ entry point.
// Multiplies rows [row_low, row_high) of src0 (data at src0_dd_i) by src1_ncols columns of src1, whose quantized
// data is at src1_ddq_i, and writes F32 results to dst_dd_i.
// NOTE(review): presumably called from the split-buffer / multi-GPU mul_mat path with src1_ddq_i already in the
// MMQ layout - confirm at the call site.
// src1_ddf_i and src1_padded_row_size are part of the shared op signature and are unused here.
void ggml_cuda_op_mul_mat_q(
    ggml_backend_cuda_context & ctx,
    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
    const int64_t src1_padded_row_size, cudaStream_t stream) {

    const int64_t ncols_x = src0->ne[0];
    const int64_t ne11    = src1->ne[1];
    GGML_ASSERT(src1->ne[0] % QK8_1 == 0);

    const int64_t nrows_x  = row_high - row_low;
    const int64_t stride_x = ncols_x / ggml_blck_size(src0->type);

    const int device = ggml_cuda_get_device();
    const int cc     = ggml_cuda_info().devices[device].cc;

    // The main device has a larger memory buffer that collects the results from all GPUs, so its dst rows are ne0
    // apart. Every other device writes into a buffer sized for its own row slice.
    int64_t nrows_dst = nrows_x;
    if (device == ctx.device) {
        nrows_dst = dst->ne[0];
    }

    // The stream-k decomposition is only faster for recent NVIDIA GPUs (and CDNA).
    // Its fixup also allocates a temporary pool buffer. When src1_ncols != ne11 there are multiple parallel CUDA
    // streams, which would race on that buffer.
    const bool arch_wants_stream_k = (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA)
                                   || GGML_CUDA_CC_IS_CDNA(cc);
    const bool use_stream_k = arch_wants_stream_k && src1_ncols == ne11;

    // One 2D matrix: a single channel and a single sample, so all channel/sample strides are 0.
    const mmq_args args = {
        src0_dd_i, src0->type, (const int *) src1_ddq_i, nullptr, nullptr, dst_dd_i,
        ncols_x, nrows_x, src1_ncols, stride_x, ne11, nrows_dst,
        1, 1, 0, 0, 0,
        1, 1, 0, 0, 0,
        use_stream_k, src1_ncols};

    ggml_cuda_mul_mat_q_switch_type(ctx, args, stream);

    GGML_UNUSED_VARS(src1_ddf_i, src1_padded_row_size);
}
261
// Heuristic: should the MMQ kernels handle a src0 of the given type on a device with compute capability cc, rather
// than a dequantize + cuBLAS/hipBLAS path?
// ne11 is compared against batch-size limits (presumably the src1 column count).
// n_experts is consulted only by the AMD heuristics.
// Always returns false for types without an MMQ kernel or when GGML_CUDA_FORCE_CUBLAS is defined.
bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t n_experts) {
#ifdef GGML_CUDA_FORCE_CUBLAS
    return false;
#endif // GGML_CUDA_FORCE_CUBLAS

    // Reject types that have no MMQ kernel.
    switch (type) {
        case GGML_TYPE_Q4_0:
        case GGML_TYPE_Q4_1:
        case GGML_TYPE_Q5_0:
        case GGML_TYPE_Q5_1:
        case GGML_TYPE_Q8_0:
        case GGML_TYPE_MXFP4:
        case GGML_TYPE_Q2_K:
        case GGML_TYPE_Q3_K:
        case GGML_TYPE_Q4_K:
        case GGML_TYPE_Q5_K:
        case GGML_TYPE_Q6_K:
        case GGML_TYPE_IQ2_XXS:
        case GGML_TYPE_IQ2_XS:
        case GGML_TYPE_IQ2_S:
        case GGML_TYPE_IQ3_XXS:
        case GGML_TYPE_IQ3_S:
        case GGML_TYPE_IQ1_S:
        case GGML_TYPE_IQ4_XS:
        case GGML_TYPE_IQ4_NL:
            break;
        default:
            return false;
    }

    // With Turing-class MMA available, MMQ is always chosen.
    if (turing_mma_available(cc)) {
        return true;
    }

    // Everything below requires a compiled architecture of at least GGML_CUDA_CC_DP4A.
    if (ggml_cuda_highest_compiled_arch(cc) < GGML_CUDA_CC_DP4A) {
        return false;
    }

#ifdef GGML_CUDA_FORCE_MMQ
    return true;
#endif //GGML_CUDA_FORCE_MMQ

    if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
        // Use MMQ whenever FP16 MMA hardware is missing, otherwise only below MMQ_DP4A_MAX_BATCH_SIZE.
        return !fp16_mma_hardware_available(cc) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
    }

    if (amd_mfma_available(cc)) {
        // As of ROCM 7.0 rocblas/tensile performs very poorly on CDNA3 and hipblaslt (via ROCBLAS_USE_HIPBLASLT)
        // performs better but is currently suffering from a crash on this architecture.
        // TODO: Revisit when hipblaslt is fixed on CDNA3
        if (GGML_CUDA_CC_IS_CDNA3(cc)) {
            return true;
        }

        const bool many_experts_or_small_batch = n_experts > 64 || ne11 <= 128;
        const bool is_legacy_quant             = type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1
                                              || type == GGML_TYPE_Q5_0 || type == GGML_TYPE_Q5_1;
        const bool is_mid_batch_q45_k          = ne11 <= 256 && (type == GGML_TYPE_Q4_K || type == GGML_TYPE_Q5_K);
        return many_experts_or_small_batch || is_legacy_quant || is_mid_batch_q45_k;
    }

    if (amd_wmma_available(cc)) {
        // For RDNA4 MMQ is consistently faster than dequantization + hipBLAS:
        // https://github.com/ggml-org/llama.cpp/pull/18537#issuecomment-3706422301
        if (!GGML_CUDA_CC_IS_RDNA3(cc)) {
            return true;
        }

        // High expert counts are almost always better on MMQ due to
        //     the synchronization overhead in the cuBLAS/hipBLAS path:
        // https://github.com/ggml-org/llama.cpp/pull/18202
        if (n_experts >= 64) {
            return true;
        }

        // For some quantization types MMQ can have lower peak TOPS than hipBLAS
        //     so it's only faster for sufficiently small batch sizes:
        switch (type) {
            case GGML_TYPE_Q2_K:
                return ne11 <= 128;
            case GGML_TYPE_Q6_K:
                return ne11 <= (GGML_CUDA_CC_IS_RDNA3_0(cc) ? 128 : 256);
            case GGML_TYPE_IQ2_XS:
            case GGML_TYPE_IQ2_S:
                return GGML_CUDA_CC_IS_RDNA3_5(cc) || ne11 <= 128;
            default:
                return true;
        }
    }

    // Remaining (non-NVIDIA, no MFMA/WMMA) devices: CDNA only uses MMQ below MMQ_DP4A_MAX_BATCH_SIZE.
    return !(GGML_CUDA_CC_IS_CDNA(cc) && ne11 >= MMQ_DP4A_MAX_BATCH_SIZE);
}