1#pragma once
  2
  3#include "mma.cuh"
  4#include "common.cuh"
  5#include "convert.cuh"
  6
  7using namespace ggml_cuda_mma;
  8
  9#define MMF_ROWS_PER_BLOCK 32
 10#define MMF_ROWS_PER_BLOCK_CDNA 64
 11
 12static __forceinline__ int64_t mmf_get_max_block_size(int cc) {
 13    if (GGML_CUDA_CC_IS_CDNA(cc)) {
 14        return 512;
 15    } else {
 16        return 256;
 17    }
 18}
 19
 20static __forceinline__ int mmf_get_padding(int cc) {
 21    if (GGML_CUDA_CC_IS_CDNA(cc)) {
 22        return 2;
 23    } else {
 24        return 4;
 25    }
 26}
 27
 28static constexpr __device__ int mmf_get_padding() {
 29#if defined(AMD_MFMA_AVAILABLE)
 30    return 2;
 31#else
 32    return 4;
 33#endif // defined(AMD_MFMA_AVAILABLE)
 34}
 35
 36struct mmf_ids_data {
 37    const int32_t * ids_src_compact = nullptr;
 38    const int32_t * ids_dst_compact = nullptr;
 39    const int32_t * expert_bounds_dev = nullptr;
 40    int n_experts = 0;
 41    int sis1 = 0;
 42};
 43
 44void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
 45
 46bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, const size_t * src0_nb, const int src1_ncols, bool mul_mat_id);
 47
 48template <typename T, int rows_per_block, int cols_per_block, int nwarps, bool has_ids>
 49__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1)
 50static __global__ void mul_mat_f(
 51        const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
 52        const int ncols, const int ncols_dst_total, const int nchannels_dst, const int stride_row, const int stride_col_y, const int stride_col_dst,
 53        const int stride_col_id, const int stride_row_id,
 54        const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
 55        const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
 56// TODO: handle this in a consistent and simpler way after AMD MFMA support has been added
 57#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
 58#if defined(AMD_WMMA_AVAILABLE)
 59    if constexpr (!(std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
 60    typedef tile<16, 8,  T,     get_input_data_layout()> tile_A;
 61    typedef tile<16, 8,  T,     get_input_data_layout()> tile_B;
 62    typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR>     tile_C;
 63#elif defined(AMD_MFMA_AVAILABLE)
 64    if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK_CDNA) {NO_DEVICE_CODE;} else {
 65    typedef tile<16, 8,  T,     DATA_LAYOUT_I_MAJOR> tile_A;
 66    typedef tile<16, 8,  T,     DATA_LAYOUT_I_MAJOR> tile_B;
 67    typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C;
 68#else
 69#ifdef VOLTA_MMA_AVAILABLE
 70    if constexpr (!std::is_same_v<T, half2> || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
 71    typedef tile<32, 4, T,     DATA_LAYOUT_I_MAJOR>          tile_A;
 72    typedef tile< 8, 4, T,     DATA_LAYOUT_I_MAJOR_MIRRORED> tile_B;
 73    typedef tile<32, 8, float, DATA_LAYOUT_I_MAJOR>          tile_C;
 74#else
 75    if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
 76    typedef tile<16, 8, T>     tile_A;
 77    typedef tile<8,  8, T>     tile_B;
 78    typedef tile<16, 8, float> tile_C;
 79#endif // VOLTA_MMA_AVAILABLE
 80#endif // defined(AMD_WMMA_AVAILABLE)
 81    if constexpr (!tile_A::supported() || !tile_B::supported() || !tile_C::supported()) {
 82        NO_DEVICE_CODE;
 83        return;
 84    }
 85
 86    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 87    constexpr int tile_k_padded = warp_size + mmf_get_padding();
 88    constexpr int ntA = rows_per_block / tile_A::I;
 89    constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I;
 90
 91    const int row0        = blockIdx.x * rows_per_block;
 92
 93    int expert_idx = 0;
 94    int col_base = 0;
 95
 96    const int channel_dst = has_ids ? 0 : blockIdx.y;
 97
 98    if constexpr (has_ids) {
 99        // experts + tiles of ncols_dst are packed in the y dimension
100        int col_tiles = (ncols_dst_total + cols_per_block - 1) / cols_per_block;
101        const int nchannels_x = gridDim.y / col_tiles;
102        const int tile_idx = blockIdx.y / nchannels_x;
103        expert_idx = blockIdx.y - tile_idx * nchannels_x;
104        col_base = tile_idx * cols_per_block;
105    }
106
107    const int channel_x   = has_ids ? expert_idx : (channel_dst / channel_ratio);
108    const int channel_y   = channel_dst;
109    const int sample_dst  = blockIdx.z;
110    const int sample_x    = sample_dst / sample_ratio;
111    const int sample_y    = sample_dst;
112
113    x   += int64_t(sample_x)  *stride_sample_x   + channel_x  *stride_channel_x  + row0*stride_row ;
114    y   += int64_t(sample_y)  *stride_sample_y   + (has_ids ? 0 : channel_y  *stride_channel_y);
115    dst += int64_t(sample_dst)*stride_sample_dst + (has_ids ? 0 : channel_dst*stride_channel_dst);
116
117    if constexpr (has_ids) {
118        constexpr int y_stride_scale = std::is_same_v<T, float> ? 1 : 2;
119        const int64_t col_offset = col_base;
120        y   += col_offset * stride_col_y * y_stride_scale;
121        dst += col_offset * stride_col_dst;
122        ids += col_offset * stride_row_id;
123    }
124
125    const float2 * y2 = (const float2 *) y;
126
127    extern __shared__ char data_mmv[];
128
129    char * shmem_base = data_mmv;
130    int  * slot_map   = (int *) shmem_base;
131    char * compute_base = has_ids ? (shmem_base + GGML_PAD(cols_per_block, 16) * sizeof(int)) : shmem_base;
132
133    tile_C C[ntA][ntB];
134
135    T * tile_xy = (T *) compute_base + threadIdx.y*(tile_A::I * tile_k_padded);
136
137    if constexpr (has_ids) {
138        int found = 0;
139
140        for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) {
141            const int j = j0 + threadIdx.y;
142
143            if (threadIdx.x == 0) {
144                slot_map[j] = -1;
145            }
146
147            if (col_base + j >= ncols_dst_total) {
148                continue;
149            }
150
151            const int32_t * __restrict__ id_row = ids + j*stride_row_id;
152
153            for (int k = threadIdx.x; k < nchannels_dst; k += warp_size) {
154                int match = id_row[k*stride_col_id] == expert_idx;
155
156                if (match) {
157                    slot_map[j] = k;
158                    found = 1;
159                    break;
160                }
161            }
162        }
163
164        if (!__syncthreads_or(found)) {
165            return;
166        }
167    }
168
169
170    for (int col = threadIdx.y*warp_size + threadIdx.x; col < ncols; col += nwarps*warp_size) {
171        tile_A A[ntA][warp_size / tile_A::J];
172#pragma unroll
173        for (int itA = 0; itA < ntA; ++itA) {
174#pragma unroll
175            for (int i = 0; i < tile_A::I; ++i) {
176                tile_xy[i*tile_k_padded + threadIdx.x] = x[(itA*tile_A::I + i)*stride_row  + col];
177            }
178#pragma unroll
179            for (int k0 = 0; k0 < warp_size; k0 += tile_A::J) {
180                load_ldmatrix(A[itA][k0/tile_A::J], tile_xy + k0, tile_k_padded);
181            }
182        }
183
184#pragma unroll
185        for (int itB = 0; itB < ntB; ++itB) {
186            if constexpr (std::is_same_v<T, float>) {
187#pragma unroll
188                for (int j0 = 0; j0 < tile_B::I; ++j0) {
189                    const int j = j0 + itB*tile_B::I;
190
191                    if constexpr (!has_ids) {
192                        tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? y[j*stride_col_y + col] : 0.0f;
193                    } else {
194                        const bool valid = j < cols_per_block && (col_base + j) < ncols_dst_total && slot_map[j] >= 0;
195                        tile_xy[j0*tile_k_padded + threadIdx.x] = valid ? y[slot_map[j]*stride_channel_y + j*stride_col_y + col] : 0.0f;
196                    }
197                }
198            } else if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
199#pragma unroll
200                for (int j0 = 0; j0 < tile_B::I; ++j0) {
201                    const int j = j0 + itB*tile_B::I;
202
203                    if constexpr (!has_ids) {
204                        const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f);
205                        tile_xy[j0*tile_k_padded + threadIdx.x] = ggml_cuda_cast<T>(tmp);
206                    } else {
207                        const bool valid = j < cols_per_block && (col_base + j) < ncols_dst_total && slot_map[j] >= 0;
208                        float2 tmp = valid ? *(const float2*) &y[slot_map[j]*stride_channel_y + 2*(j*stride_col_y + col)] : make_float2(0.0f, 0.0f);
209                        tile_xy[j0*tile_k_padded + threadIdx.x] = ggml_cuda_cast<T>(tmp);
210                    }
211                }
212            } else {
213                static_assert(std::is_same_v<T, void>, "unsupported type");
214            }
215#pragma unroll
216            for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) {
217                tile_B B;
218                load_ldmatrix(B, tile_xy + k0, tile_k_padded);
219#pragma unroll
220                for (int itA = 0; itA < ntA; ++itA) {
221                    mma(C[itA][itB], A[itA][k0/tile_B::J], B);
222                }
223            }
224        }
225    }
226
227    float * buf_iw = (float *) compute_base;
228    constexpr int kiw = nwarps*rows_per_block + mmf_get_padding();
229
230    if (nwarps > 1) {
231        __syncthreads();
232    }
233#pragma unroll
234    for (int itB = 0; itB < ntB; ++itB) {
235#pragma unroll
236        for (int itA = 0; itA < ntA; ++itA) {
237#pragma unroll
238            for (int l = 0; l < tile_C::ne; ++l) {
239                const int i = threadIdx.y*rows_per_block + itA*tile_C::I + tile_C::get_i(l);
240                const int j = itB*tile_C::J + tile_C::get_j(l);
241                buf_iw[j*kiw + i] = C[itA][itB].x[l];
242            }
243        }
244    }
245
246    if (nwarps > 1) {
247        __syncthreads();
248    }
249
250#pragma unroll
251    for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) {
252        const int j = j0 + threadIdx.y;
253
254        if (j0 + nwarps > cols_per_block && j >= cols_per_block) {
255            return;
256        }
257
258        float sum[rows_per_block/warp_size] = {0.0f};
259        static_assert((rows_per_block % warp_size) == 0, "rows_per_block must be a multiple of warp_size.");
260#pragma unroll
261        for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) {
262#pragma unroll
263            for (int i1 = 0; i1 < sizeof(sum)/sizeof(sum[0]); ++i1) {
264                const int i = i0 + i1*warp_size + threadIdx.x;
265
266                sum[i1] += buf_iw[j*kiw + i];
267            }
268        }
269
270        if constexpr (!has_ids) {
271#pragma unroll
272            for (int i0 = 0; i0 < sizeof(sum)/sizeof(sum[0]); ++i0) {
273                dst[j*stride_col_dst + row0 + i0*warp_size + threadIdx.x] = sum[i0];
274            }
275        } else {
276            const int slot = (j < cols_per_block) ? slot_map[j] : -1;
277            if (slot >= 0 && (col_base + j) < ncols_dst_total) {
278#pragma unroll
279                for (int i0 = 0; i0 < sizeof(sum)/sizeof(sum[0]); ++i0) {
280                    dst[slot*stride_channel_dst + j*stride_col_dst + row0 + i0*warp_size + threadIdx.x] = sum[i0];
281                }
282            }
283        }
284    }
285    }
286#else
287    GGML_UNUSED_VARS(x, y, ids, dst,
288        ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
289        stride_col_id, stride_row_id,
290        channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
291        sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
292    NO_DEVICE_CODE;
293#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
294}
295
// Variant of mul_mat_f for larger batch sizes of mul_mat_id, driven by host-built compact id lists
// (see mmf_ids_data).
//
// Launch configuration (see mul_mat_f_switch_ids):
//   blockDim  = (physical warp size, nwarps, 1)
//   gridDim.x = row blocks of x; gridDim.y = experts;
//   gridDim.z = ceil(ncols_dst / cols_per_block), an upper bound on the column tiles of any expert
// expert_bounds[e] .. expert_bounds[e + 1] is expert e's range in ids_src_compact / ids_dst_compact.
// ids_src_compact entries are split by sis1_fd into (quotient = token, remainder = channel) to address y;
// ids_dst_compact entries are split by nch_fd into (quotient = token, remainder = slot) to address dst.
// Only sample 0 is processed (sample_dst is fixed to 0) and channel_ratio is unused.
template <typename T, int rows_per_block, int cols_per_block, int nwarps>
__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1)
static __global__ void mul_mat_f_ids(
        const T * __restrict__ x, const float * __restrict__ y,
        const int32_t * __restrict__ ids_src_compact, const int32_t * __restrict__ ids_dst_compact,
        const int32_t * __restrict__ expert_bounds, float * __restrict__ dst,
        const int ncols, const int ncols_dst_total, const int nchannels_dst, const int stride_row, const int stride_col_y, const int stride_col_dst,
        const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
        const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
        const uint3 sis1_fd, const uint3 nch_fd) {
// TODO: handle this in a consistent and simpler way after AMD MFMA support has been added
#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
    // Select the mma tile shapes for the target architecture; unsupported type / row-block
    // combinations compile to NO_DEVICE_CODE.
#if defined(AMD_WMMA_AVAILABLE)
    if constexpr (!(std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
    typedef tile<16, 8,  T,     get_input_data_layout()> tile_A;
    typedef tile<16, 8,  T,     get_input_data_layout()> tile_B;
    typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR>     tile_C;
#elif defined(AMD_MFMA_AVAILABLE)
    if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK_CDNA) {NO_DEVICE_CODE;} else {
    typedef tile<16, 8,  T,     DATA_LAYOUT_I_MAJOR> tile_A;
    typedef tile<16, 8,  T,     DATA_LAYOUT_I_MAJOR> tile_B;
    typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C;
#else
#ifdef VOLTA_MMA_AVAILABLE
    if constexpr (!std::is_same_v<T, half2> || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
    typedef tile<32, 4, T,     DATA_LAYOUT_I_MAJOR>          tile_A;
    typedef tile< 8, 4, T,     DATA_LAYOUT_I_MAJOR_MIRRORED> tile_B;
    typedef tile<32, 8, float, DATA_LAYOUT_I_MAJOR>          tile_C;
#else
    if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
    typedef tile<16, 8, T>     tile_A;
    typedef tile<8,  8, T>     tile_B;
    typedef tile<16, 8, float> tile_C;
#endif // VOLTA_MMA_AVAILABLE
#endif // defined(AMD_WMMA_AVAILABLE)
    if constexpr (!tile_A::supported() || !tile_B::supported() || !tile_C::supported()) {
        NO_DEVICE_CODE;
        return;
    }


    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
    constexpr int tile_k_padded = warp_size + mmf_get_padding();
    constexpr int ntA = rows_per_block / tile_A::I;
    constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I;

    const int row0 = blockIdx.x * rows_per_block;

    const int expert_idx   = blockIdx.y;
    const int expert_start = expert_bounds[expert_idx];
    const int expert_end   = expert_bounds[expert_idx + 1];
    const int ncols_expert = expert_end - expert_start;

    // gridDim.z is sized for the largest possible tile count; blocks past this expert's tiles exit
    // here (block-uniform, before any barrier).
    const int tiles_for_expert = (ncols_expert + cols_per_block - 1) / cols_per_block;
    const int tile_idx = blockIdx.z;
    if (tile_idx >= tiles_for_expert) {
        return;
    }

    const int col_base = tile_idx * cols_per_block;

    GGML_UNUSED(channel_ratio);

    const int channel_x  = expert_idx;
    const int sample_dst = 0;
    const int sample_x   = sample_dst / sample_ratio;
    const int sample_y   = sample_dst;

    // Base offsets are formed in 64-bit: expert_idx*stride_channel_x for the last experts of a large
    // MoE weight tensor can exceed INT_MAX, and an int*int product overflows before promotion.
    x   += int64_t(sample_x)  *stride_sample_x   + int64_t(channel_x)*stride_channel_x + int64_t(row0)*stride_row;
    y   += int64_t(sample_y)  *stride_sample_y;
    dst += int64_t(sample_dst)*stride_sample_dst;

    const int32_t * ids_src_expert = ids_src_compact + expert_start;
    const int32_t * ids_dst_expert = ids_dst_compact + expert_start;

    extern __shared__ char data_mmv[];
    char * compute_base = data_mmv;

    tile_C C[ntA][ntB];

    T * tile_xy = (T *) compute_base + threadIdx.y*(tile_A::I * tile_k_padded);

    for (int col = threadIdx.y*warp_size + threadIdx.x; col < ncols; col += nwarps*warp_size) {
        // Stage rows of x through the warp's shared tile and load them as A fragments.
        tile_A A[ntA][warp_size / tile_A::J];
#pragma unroll
        for (int itA = 0; itA < ntA; ++itA) {
#pragma unroll
            for (int i = 0; i < tile_A::I; ++i) {
                tile_xy[i*tile_k_padded + threadIdx.x] = x[(itA*tile_A::I + i)*stride_row  + col];
            }
#pragma unroll
            for (int k0 = 0; k0 < warp_size; k0 += tile_A::J) {
                load_ldmatrix(A[itA][k0/tile_A::J], tile_xy + k0, tile_k_padded);
            }
        }

        // The y gathers are double-buffered in registers: the values for tile itB + 1 are fetched
        // before the mma work on tile itB.
        if constexpr (std::is_same_v<T, float>) {
            float vals_buf[2][tile_B::I];
            auto gather_tile = [&](int tile_idx_local, float *vals) {
#pragma unroll
                for (int j0 = 0; j0 < tile_B::I; ++j0) {
                    const int j = j0 + tile_idx_local*tile_B::I;
                    const int global_j = col_base + j;
                    float val = 0.0f;
                    if (j < cols_per_block && global_j < ncols_expert) {
                        const int src_entry = ids_src_expert[global_j];
                        const uint2 qrm = fast_div_modulo((uint32_t) src_entry, sis1_fd);
                        const int token   = (int) qrm.x;
                        const int channel = (int) qrm.y;
                        if (token < ncols_dst_total) {
                            val = y[channel*stride_channel_y + token*stride_col_y + col];
                        }
                    }
                    vals[j0] = val;
                }
            };

            gather_tile(0, vals_buf[0]);

            int curr_buf = 0;
            int next_buf = 1;
#pragma unroll
            for (int itB = 0; itB < ntB; ++itB) {
#pragma unroll
                for (int j0 = 0; j0 < tile_B::I; ++j0) {
                    tile_xy[j0*tile_k_padded + threadIdx.x] = vals_buf[curr_buf][j0];
                }

                if (itB + 1 < ntB) {
                    gather_tile(itB + 1, vals_buf[next_buf]);
                }

#pragma unroll
                for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) {
                    tile_B B;
                    load_ldmatrix(B, tile_xy + k0, tile_k_padded);
#pragma unroll
                    for (int itA = 0; itA < ntA; ++itA) {
                        mma(C[itA][itB], A[itA][k0/tile_B::J], B);
                    }
                }

                if (itB + 1 < ntB) {
                    curr_buf ^= 1;
                    next_buf ^= 1;
                }
            }
        } else if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
            float2 vals_buf[2][tile_B::I];
            auto gather_tile = [&](int tile_idx_local, float2 *vals) {
#pragma unroll
                for (int j0 = 0; j0 < tile_B::I; ++j0) {
                    const int j = j0 + tile_idx_local*tile_B::I;
                    const int global_j = col_base + j;
                    float2 tmp = make_float2(0.0f, 0.0f);
                    if (j < cols_per_block && global_j < ncols_expert) {
                        const int src_entry = ids_src_expert[global_j];
                        const uint2 qrm = fast_div_modulo((uint32_t) src_entry, sis1_fd);
                        const int token   = (int) qrm.x;
                        const int channel = (int) qrm.y;
                        if (token < ncols_dst_total) {
                            tmp = *(const float2*) &y[channel*stride_channel_y + 2*(token*stride_col_y + col)];
                        }
                    }
                    vals[j0] = tmp;
                }
            };

            if (ntB > 0) {
                gather_tile(0, vals_buf[0]);
            }

            int curr_buf = 0;
            int next_buf = 1;
#pragma unroll
            for (int itB = 0; itB < ntB; ++itB) {
#pragma unroll
                for (int j0 = 0; j0 < tile_B::I; ++j0) {
                    const float2 tmp = vals_buf[curr_buf][j0];
                    tile_xy[j0*tile_k_padded + threadIdx.x] = ggml_cuda_cast<T>(tmp);
                }

                if (itB + 1 < ntB) {
                    gather_tile(itB + 1, vals_buf[next_buf]);
                }

#pragma unroll
                for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) {
                    tile_B B;
                    load_ldmatrix(B, tile_xy + k0, tile_k_padded);
#pragma unroll
                    for (int itA = 0; itA < ntA; ++itA) {
                        mma(C[itA][itB], A[itA][k0/tile_B::J], B);
                    }
                }

                if (itB + 1 < ntB) {
                    curr_buf ^= 1;
                    next_buf ^= 1;
                }
            }
        } else {
            static_assert(std::is_same_v<T, void>, "unsupported type");
        }
    }

    // Cross-warp reduction: every warp dumps its accumulators into shared memory (reusing the staging
    // area, hence the barrier first), then each warp sums whole dst columns over all warps.
    float * buf_iw = (float *) compute_base;
    constexpr int kiw = nwarps*rows_per_block + mmf_get_padding();

    if (nwarps > 1) {
        __syncthreads();
    }
#pragma unroll
    for (int itB = 0; itB < ntB; ++itB) {
#pragma unroll
        for (int itA = 0; itA < ntA; ++itA) {
#pragma unroll
            for (int l = 0; l < tile_C::ne; ++l) {
                const int i = threadIdx.y*rows_per_block + itA*tile_C::I + tile_C::get_i(l);
                const int j = itB*tile_C::J + tile_C::get_j(l);
                buf_iw[j*kiw + i] = C[itA][itB].x[l];
            }
        }
    }

    if (nwarps > 1) {
        __syncthreads();
    }

#pragma unroll
    for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) {
        const int j = j0 + threadIdx.y;

        // Last, partial round: warps without a column left are done (no barriers follow).
        if (j0 + nwarps > cols_per_block && j >= cols_per_block) {
            return;
        }

        float sum[rows_per_block/warp_size] = {0.0f};
        static_assert((rows_per_block % warp_size) == 0, "rows_per_block must be a multiple of warp_size.");
#pragma unroll
        for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) {
#pragma unroll
            for (int i1 = 0; i1 < sizeof(sum)/sizeof(sum[0]); ++i1) {
                const int i = i0 + i1*warp_size + threadIdx.x;

                sum[i1] += buf_iw[j * kiw + i];
            }
        }

        const int global_j = col_base + j;
        if (j < cols_per_block && global_j < ncols_expert && nchannels_dst > 0) {
            const int dst_entry = ids_dst_expert[global_j];
            const uint2 qrm = fast_div_modulo((uint32_t) dst_entry, nch_fd);
            const int token = (int) qrm.x;
            if (token < ncols_dst_total) {
                const int slot = (int) qrm.y;
                // 64-bit column offset: token*stride_col_dst grows with the batch size.
                float * dst_col = dst + int64_t(slot)*stride_channel_dst + int64_t(token)*stride_col_dst + row0;
#pragma unroll
                for (int i0 = 0; i0 < sizeof(sum)/sizeof(sum[0]); ++i0) {
                    dst_col[i0*warp_size + threadIdx.x] = sum[i0];
                }
            }
        }
    }
    }
#else
    GGML_UNUSED_VARS(x, y, ids_src_compact, ids_dst_compact, expert_bounds, dst,
        ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
        channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
        sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, sis1_fd, nch_fd);
    NO_DEVICE_CODE;
#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
}
571
// Launches the matching mul_mat_f variant for a fixed nwarps on `stream`:
//  - mul_mat_f_ids when ids_data carries compact id lists and ncols_dst > 16
//    (grid: block_nums.x x n_experts x column tiles),
//  - mul_mat_f<..., has_ids=true> when only `ids` is given (grid y multiplied by the column tiles),
//  - mul_mat_f<..., has_ids=false> otherwise.
// Launch errors are not checked here.
template<typename T, int rows_per_block, int cols_per_block, int nwarps>
static inline void mul_mat_f_switch_ids(
        const T * x, const float * y, const int32_t * ids, float * dst,
        const int64_t ncols_x, const int64_t ncols_dst, const int64_t nchannels_dst,
        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
        const int64_t stride_col_id, const int64_t stride_row_id,
        const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
        const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
        const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared_total, cudaStream_t stream,
        const mmf_ids_data * ids_data) {
    const bool has_ids_data = ids_data && ids_data->ids_src_compact;

    // Use the compact-ids kernel only for larger tiles; for small ncols_dst (<= 16)
    // we prefer the normal mul_mat_f path with has_ids=true.
    if (has_ids_data && ncols_dst > 16) {
        const int max_tiles = (int) ((ncols_dst + cols_per_block - 1) / cols_per_block);
        if (max_tiles == 0) {
            return;
        }
        dim3 block_nums_ids(block_nums.x, ids_data->n_experts, max_tiles);

        // fastdiv constants are never built for a zero divisor. The kernel only decodes with nch_fd
        // when nchannels_dst > 0, so the (0, 0, 1) placeholder is never used for dst.
        const uint3 sis1_fd = ids_data->sis1 > 0 ? init_fastdiv_values((uint32_t) ids_data->sis1) : make_uint3(0, 0, 1);
        const uint3 nch_fd  = nchannels_dst  > 0 ? init_fastdiv_values((uint32_t) nchannels_dst)  : make_uint3(0, 0, 1);

        mul_mat_f_ids<T, rows_per_block, cols_per_block, nwarps><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>>
            (x, y, ids_data->ids_src_compact, ids_data->ids_dst_compact, ids_data->expert_bounds_dev, dst,
            ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
            channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
            sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst,
            sis1_fd, nch_fd);
    } else if (ids) {
        // Experts and column tiles are packed into grid y; the kernel unpacks them again.
        const int64_t col_tiles = (ncols_dst + cols_per_block - 1) / cols_per_block;
        dim3 block_nums_ids = block_nums;
        block_nums_ids.y *= col_tiles;

        mul_mat_f<T, rows_per_block, cols_per_block, nwarps, true><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>>
            (x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
             stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
             sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
    } else {
        mul_mat_f<T, rows_per_block, cols_per_block, nwarps, false><<<block_nums, block_dims, nbytes_shared_total, stream>>>
            (x, y, ids, dst, ncols_x, cols_per_block, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
             stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
             sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
    }
}
618
619template <typename T, int rows_per_block, int cols_per_block>
620void mul_mat_f_cuda(
621        const T * x, const float * y, const int32_t * ids, float * dst,
622        const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,
623        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
624        const int64_t stride_col_id, const int64_t stride_row_id,
625        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
626        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
627        const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
628        cudaStream_t stream, const mmf_ids_data * ids_data) {
629    typedef tile<16, 8, T>     tile_A_16;
630    typedef tile<32, 8, T>     tile_A_32;
631    typedef tile<16, 8, T>     tile_B_16;
632    typedef tile< 8, 8, T>     tile_B_8;
633
634    GGML_ASSERT(ncols_x      % 2 == 0);
635    GGML_ASSERT(stride_row   % 2 == 0);
636    GGML_ASSERT(stride_col_y % 2 == 0);
637    GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
638    GGML_ASSERT(       nsamples_dst  % nsamples_x  == 0);
639    const int64_t channel_ratio = nchannels_dst / nchannels_x;
640    const int64_t sample_ratio  = nsamples_dst  / nsamples_x;
641
642    const int device    = ggml_cuda_get_device();
643    const int cc        = ggml_cuda_info().devices[device].cc;
644    const int warp_size = ggml_cuda_info().devices[device].warp_size;
645
646    int64_t nwarps_best     = 1;
647    int64_t niter_best      = (ncols_x + warp_size*2 - 1) / (warp_size*2);
648    int64_t max_block_size  = mmf_get_max_block_size(cc);
649    for (int64_t nwarps = 2; nwarps <= max_block_size/warp_size; nwarps++) {
650        const int64_t niter = (ncols_x + nwarps*warp_size*2 - 1) / (nwarps*warp_size*2);
651        if (niter < niter_best) {
652            niter_best  = niter;
653            nwarps_best = nwarps;
654        }
655    }
656
657    const int nbytes_shared_iter = nwarps_best * (volta_mma_available(cc) ? tile_A_32::I : tile_A_16::I) * (warp_size + mmf_get_padding(cc)) * 4;
658    const int nbytes_cols_per_block_pad = (amd_wmma_available(cc) || amd_mfma_available(cc)) ? tile_B_16::I : tile_B_8::I;
659    const int nbytes_shared_combine = GGML_PAD(cols_per_block, nbytes_cols_per_block_pad) * (nwarps_best*rows_per_block + mmf_get_padding(cc)) * 4;
660    const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine);
661    const int nbytes_slotmap = ids ? GGML_PAD(cols_per_block, 16) * sizeof(int) : 0;
662    const int nbytes_shared_total = nbytes_shared + nbytes_slotmap;
663    const int64_t grid_y = ids ? nchannels_x : nchannels_dst;
664
665    const dim3 block_nums(nrows_x/rows_per_block, grid_y, nsamples_dst);
666    const dim3 block_dims(warp_size, nwarps_best, 1);
667
668    switch (nwarps_best) {
669        case 1: {
670            mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 1>(
671                x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
672                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
673                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
674                ids_data);
675        } break;
676        case 2: {
677            mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 2>(
678                x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
679                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
680                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
681                ids_data);
682        } break;
683        case 3: {
684            mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 3>(
685                x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
686                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
687                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
688                ids_data);
689        } break;
690        case 4: {
691            mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 4>(
692                x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
693                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
694                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
695                ids_data);
696        } break;
697        case 5: {
698            mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 5>(
699                x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
700                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
701                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
702                ids_data);
703        } break;
704        case 6: {
705            mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 6>(
706                x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
707                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
708                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
709                ids_data);
710        } break;
711        case 7: {
712            mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 7>(
713                x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
714                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
715                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
716                ids_data);
717        } break;
718        case 8: {
719            mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 8>(
720                x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
721                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
722                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
723                ids_data);
724        } break;
725        default: {
726            GGML_ABORT("fatal error");
727        } break;
728    }
729
730    GGML_UNUSED_VARS(nchannels_y);
731}
732
// Host-side dispatcher that turns the runtime destination column count into the
// compile-time cols_per_block template argument of mul_mat_f_cuda.
//
// - Without ids, ncols_dst must be at most 16 (asserted) and is used directly as
//   cols_per_block.
// - With ids, any ncols_dst above 16 selects the 16-column instantiation.
//
// In both cases the full ncols_dst is still forwarded to mul_mat_f_cuda
// unchanged. A column count with no matching instantiation (e.g. ncols_dst <= 0)
// aborts.
//
// stride_row_id is int64_t, like every other stride here and like the
// corresponding parameter of mul_mat_f_cuda. This avoids an int64_t -> int -> int64_t
// round trip through this function. Widening the type is source-compatible for
// callers that pass an int.
template <typename T, int rows_per_block>
static void mul_mat_f_switch_cols_per_block(
        const T * x, const float * y, const int32_t * ids, float * dst,
        const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,
        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
        const int64_t stride_col_id, const int64_t stride_row_id,
        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
        const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
        cudaStream_t stream, const mmf_ids_data * ids_data) {

    // With ids, wide batches are clamped to the largest instantiation (16 columns).
    const int ncols_case = (ids && ncols_dst > 16) ? 16 : ncols_dst;

    GGML_ASSERT(ids || ncols_dst <= 16);

    // Every case forwards all arguments unchanged; only the cols_per_block
    // template argument differs. The macro is local to this function and is
    // #undef'd right after the switch.
#define MMF_SWITCH_COLS_CASE(cpb)                                                                                              \
        case cpb: {                                                                                                            \
            mul_mat_f_cuda<T, rows_per_block, cpb>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, \
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, \
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);                   \
        } break;

    switch (ncols_case) {
        MMF_SWITCH_COLS_CASE( 1)
        MMF_SWITCH_COLS_CASE( 2)
        MMF_SWITCH_COLS_CASE( 3)
        MMF_SWITCH_COLS_CASE( 4)
        MMF_SWITCH_COLS_CASE( 5)
        MMF_SWITCH_COLS_CASE( 6)
        MMF_SWITCH_COLS_CASE( 7)
        MMF_SWITCH_COLS_CASE( 8)
        MMF_SWITCH_COLS_CASE( 9)
        MMF_SWITCH_COLS_CASE(10)
        MMF_SWITCH_COLS_CASE(11)
        MMF_SWITCH_COLS_CASE(12)
        MMF_SWITCH_COLS_CASE(13)
        MMF_SWITCH_COLS_CASE(14)
        MMF_SWITCH_COLS_CASE(15)
        MMF_SWITCH_COLS_CASE(16)
        default: {
            GGML_ABORT("fatal error");
        } break;
    }

#undef MMF_SWITCH_COLS_CASE
}
834
// Host-side dispatcher that turns the runtime rows_per_block value into the
// compile-time template argument of mul_mat_f_switch_cols_per_block.
//
// Only MMF_ROWS_PER_BLOCK and MMF_ROWS_PER_BLOCK_CDNA are instantiated. Any other
// value aborts with a message that names the offending value. All remaining
// arguments are forwarded unchanged.
//
// stride_row_id is int64_t, like every other stride here and like the
// corresponding parameter of mul_mat_f_cuda. The previous `int` type silently
// narrowed the caller's value. Widening it is source-compatible for callers that
// pass an int.
template <typename T>
static void mul_mat_f_switch_rows_per_block(
        const int rows_per_block, const T * x, const float * y, const int32_t * ids, float * dst,
        const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,
        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
        const int64_t stride_col_id, const int64_t stride_row_id,
        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
        const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
        cudaStream_t stream, const mmf_ids_data * ids_data) {
    switch (rows_per_block) {
        case MMF_ROWS_PER_BLOCK: {
            mul_mat_f_switch_cols_per_block<T, MMF_ROWS_PER_BLOCK>(
                x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        case MMF_ROWS_PER_BLOCK_CDNA: {
            mul_mat_f_switch_cols_per_block<T, MMF_ROWS_PER_BLOCK_CDNA>(
                x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        default: {
            GGML_ABORT("unsupported rows_per_block: %i", rows_per_block);
        } break;
    }
}
862
// DECL_MMF_CASE_HELPER(T, rows_per_block, cols_per_block)
//
// Expands to an explicit instantiation of
// mul_mat_f_cuda<T, rows_per_block, cols_per_block>, spelling out the wrapper's
// full parameter list. When the expansion is prefixed with `extern` (as done by
// DECL_MMF_CASE_EXTERN), it becomes an explicit instantiation declaration
// instead of a definition.
//
// Note: no // comments may appear inside the macro body below. The trailing
// backslash would splice the following line into the comment.
#define DECL_MMF_CASE_HELPER(T, rows_per_block, cols_per_block)                                              \
    template void mul_mat_f_cuda<T, rows_per_block, cols_per_block>(                                        \
        const T * x, const float * y, const int32_t * ids, float * dst,                                     \
        const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst_total,                        \
        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,                 \
        const int64_t stride_col_id, const int64_t stride_row_id,                                           \
        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,                  \
        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,   \
        const int64_t nsamples_x, const int64_t nsamples_dst,                                               \
        const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,      \
        cudaStream_t stream, const mmf_ids_data * ids_data);
872
// Explicit-instantiation plumbing for mul_mat_f_cuda.
//
// The covered combinations are:
//   T              in {float, half2, nv_bfloat162}
//   rows_per_block in {MMF_ROWS_PER_BLOCK, MMF_ROWS_PER_BLOCK_CDNA}
//   cols_per_block in [1, 16]
//
// For each combination, this header declares the instantiation `extern`. A
// translation unit that includes the header therefore relies on an explicit
// instantiation definition provided elsewhere, rather than instantiating the
// template itself. DECL_MMF_CASE(n) emits those definitions for one column
// count. NOTE(review): presumably it is expanded once per n in separate .cu
// files; those files are not visible here, so confirm.
//
// Under GGML_USE_MUSA no extern declarations are emitted, and DECL_MMF_CASE
// expands to nothing.
#if !defined(GGML_USE_MUSA)
// Explicit instantiation *declarations* for one cols_per_block value: every T
// crossed with every rows_per_block variant.
#define DECL_MMF_CASE_EXTERN(ncols_dst) \
    extern DECL_MMF_CASE_HELPER(float, MMF_ROWS_PER_BLOCK, ncols_dst) \
    extern DECL_MMF_CASE_HELPER(half2, MMF_ROWS_PER_BLOCK, ncols_dst) \
    extern DECL_MMF_CASE_HELPER(nv_bfloat162, MMF_ROWS_PER_BLOCK, ncols_dst) \
    extern DECL_MMF_CASE_HELPER(float, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) \
    extern DECL_MMF_CASE_HELPER(half2, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) \
    extern DECL_MMF_CASE_HELPER(nv_bfloat162, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst)

// Explicit instantiation *definitions* matching DECL_MMF_CASE_EXTERN (the same
// six T x rows_per_block combinations, without `extern`).
#define DECL_MMF_CASE(ncols_dst) \
    DECL_MMF_CASE_HELPER(float, MMF_ROWS_PER_BLOCK, ncols_dst) \
    DECL_MMF_CASE_HELPER(half2, MMF_ROWS_PER_BLOCK, ncols_dst) \
    DECL_MMF_CASE_HELPER(nv_bfloat162, MMF_ROWS_PER_BLOCK, ncols_dst) \
    DECL_MMF_CASE_HELPER(float, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) \
    DECL_MMF_CASE_HELPER(half2, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) \
    DECL_MMF_CASE_HELPER(nv_bfloat162, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst)

// Declare every cols_per_block value handled by mul_mat_f_switch_cols_per_block
// (cases 1..16). The trailing ';' after each expansion is an empty declaration,
// which is allowed at namespace scope.
DECL_MMF_CASE_EXTERN(1);
DECL_MMF_CASE_EXTERN(2);
DECL_MMF_CASE_EXTERN(3);
DECL_MMF_CASE_EXTERN(4);
DECL_MMF_CASE_EXTERN(5);
DECL_MMF_CASE_EXTERN(6);
DECL_MMF_CASE_EXTERN(7);
DECL_MMF_CASE_EXTERN(8);
DECL_MMF_CASE_EXTERN(9);
DECL_MMF_CASE_EXTERN(10);
DECL_MMF_CASE_EXTERN(11);
DECL_MMF_CASE_EXTERN(12);
DECL_MMF_CASE_EXTERN(13);
DECL_MMF_CASE_EXTERN(14);
DECL_MMF_CASE_EXTERN(15);
DECL_MMF_CASE_EXTERN(16);
#else
// MUSA build: no extern declarations and no explicit definitions. Per normal
// C++ rules, mul_mat_f_cuda is then instantiated implicitly wherever it is used.
#define DECL_MMF_CASE(ncols_dst)
#endif