#pragma once

#include "mma.cuh"
#include "common.cuh"
#include "convert.cuh"

using namespace ggml_cuda_mma;

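// Number of rows of the src0 (weight) matrix processed per CUDA block:
// CDNA (MFMA) devices use 64-row blocks, all other targets use 32.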
#define MMF_ROWS_PER_BLOCK 32
#define MMF_ROWS_PER_BLOCK_CDNA 64

static __forceinline__ int64_t mmf_get_max_block_size(int cc) {
    if (GGML_CUDA_CC_IS_CDNA(cc)) {
        return 512;
    } else {
        return 256;
    }
}

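// Extra elements of padding per shared memory row, used to avoid bank conflicts.
// The host overload selects the value from the compute capability; the constexpr
// device overload makes the same choice at compile time.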
static __forceinline__ int mmf_get_padding(int cc) {
    if (GGML_CUDA_CC_IS_CDNA(cc)) {
        return 2;
    } else {
        return 4;
    }
}

static constexpr __device__ int mmf_get_padding() {
#if defined(AMD_MFMA_AVAILABLE)
    return 2;
#else
    return 4;
#endif // defined(AMD_MFMA_AVAILABLE)
}

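// Compacted expert-routing data for the mul_mat_id path with large batches:
// ids_src_compact/ids_dst_compact hold per-expert contiguous index lists,
// expert_bounds_dev holds each expert's [start, end) range into those lists,
// and sis1 is the divisor used to split a compact src entry into (token, channel).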
struct mmf_ids_data {
    const int32_t * ids_src_compact = nullptr;
    const int32_t * ids_dst_compact = nullptr;
    const int32_t * expert_bounds_dev = nullptr;
    int n_experts = 0;
    int sis1 = 0;
};

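// Computes dst = src0 * src1 for F32/F16/BF16 src0; a non-null ids tensor
// selects the mul_mat_id (mixture-of-experts) path.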
void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);

bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne, const size_t * src0_nb, const int src1_ncols, bool mul_mat_id);

template <typename T, int rows_per_block, int cols_per_block, int nwarps, bool has_ids>
__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1)
static __global__ void mul_mat_f(
        const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
        const int ncols, const int ncols_dst_total, const int nchannels_dst, const int stride_row, const int stride_col_y, const int stride_col_dst,
        const int stride_col_id, const int stride_row_id,
        const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
        const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
// TODO: handle this in a consistent and simpler way after AMD MFMA support has been added
#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
#if defined(AMD_WMMA_AVAILABLE)
    if constexpr (!(std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
    typedef tile<16, 8, T, get_input_data_layout()> tile_A;
    typedef tile<16, 8, T, get_input_data_layout()> tile_B;
    typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C;
#elif defined(AMD_MFMA_AVAILABLE)
    if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK_CDNA) {NO_DEVICE_CODE;} else {
    typedef tile<16, 8, T, DATA_LAYOUT_I_MAJOR> tile_A;
    typedef tile<16, 8, T, DATA_LAYOUT_I_MAJOR> tile_B;
    typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C;
#else
#ifdef VOLTA_MMA_AVAILABLE
    if constexpr (!std::is_same_v<T, half2> || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
    typedef tile<32, 4, T, DATA_LAYOUT_I_MAJOR>          tile_A;
    typedef tile< 8, 4, T, DATA_LAYOUT_I_MAJOR_MIRRORED> tile_B;
    typedef tile<32, 8, float, DATA_LAYOUT_I_MAJOR>      tile_C;
#else
    if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
    typedef tile<16, 8, T>     tile_A;
    typedef tile< 8, 8, T>     tile_B;
    typedef tile<16, 8, float> tile_C;
#endif // VOLTA_MMA_AVAILABLE
#endif // defined(AMD_WMMA_AVAILABLE)
    if constexpr (!tile_A::supported() || !tile_B::supported() || !tile_C::supported()) {
        NO_DEVICE_CODE;
        return;
    }

    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
    constexpr int tile_k_padded = warp_size + mmf_get_padding();
    constexpr int ntA = rows_per_block / tile_A::I;
    constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I;

    const int row0 = blockIdx.x * rows_per_block;

    int expert_idx = 0;
    int col_base = 0;

    const int channel_dst = has_ids ? 0 : blockIdx.y;

    if constexpr (has_ids) {
        // experts + tiles of ncols_dst are packed in the y dimension
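        // e.g. with ncols_dst_total = 32, cols_per_block = 16 and 4 experts the
        // grid has gridDim.y = 8: y-blocks 0..3 handle columns 0..15 of experts
        // 0..3 and y-blocks 4..7 handle columns 16..31.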
        int col_tiles = (ncols_dst_total + cols_per_block - 1) / cols_per_block;
        const int nchannels_x = gridDim.y / col_tiles;
        const int tile_idx = blockIdx.y / nchannels_x;
        expert_idx = blockIdx.y - tile_idx * nchannels_x;
        col_base = tile_idx * cols_per_block;
    }

    const int channel_x = has_ids ? expert_idx : (channel_dst / channel_ratio);
    const int channel_y = channel_dst;
    const int sample_dst = blockIdx.z;
    const int sample_x = sample_dst / sample_ratio;
    const int sample_y = sample_dst;

    x   += int64_t(sample_x) *stride_sample_x   + channel_x  *stride_channel_x + row0*stride_row;
    y   += int64_t(sample_y) *stride_sample_y   + (has_ids ? 0 : channel_y  *stride_channel_y);
    dst += int64_t(sample_dst)*stride_sample_dst + (has_ids ? 0 : channel_dst*stride_channel_dst);

    if constexpr (has_ids) {
        constexpr int y_stride_scale = std::is_same_v<T, float> ? 1 : 2;
        const int64_t col_offset = col_base;
        y   += col_offset * stride_col_y * y_stride_scale;
        dst += col_offset * stride_col_dst;
        ids += col_offset * stride_row_id;
    }

    const float2 * y2 = (const float2 *) y;

    extern __shared__ char data_mmv[];

    char * shmem_base = data_mmv;
    int  * slot_map   = (int *) shmem_base;
    char * compute_base = has_ids ? (shmem_base + GGML_PAD(cols_per_block, 16) * sizeof(int)) : shmem_base;

    tile_C C[ntA][ntB];

    T * tile_xy = (T *) compute_base + threadIdx.y*(tile_A::I * tile_k_padded);

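    // For each column j of this tile find the dst channel (expert slot) whose
    // routing id matches this block's expert; unmatched columns are marked -1
    // and contribute zeros. If no column matches, the whole block exits early.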
    if constexpr (has_ids) {
        int found = 0;

        for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) {
            const int j = j0 + threadIdx.y;

            if (threadIdx.x == 0) {
                slot_map[j] = -1;
            }

            if (col_base + j >= ncols_dst_total) {
                continue;
            }

            const int32_t * __restrict__ id_row = ids + j*stride_row_id;

            for (int k = threadIdx.x; k < nchannels_dst; k += warp_size) {
                int match = id_row[k*stride_col_id] == expert_idx;

                if (match) {
                    slot_map[j] = k;
                    found = 1;
                    break;
                }
            }
        }

        if (!__syncthreads_or(found)) {
            return;
        }
    }

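    // Main loop over the reduced dimension: stage rows of x in shared memory,
    // load them as MMA A tiles, then gather y columns as B tiles and accumulate.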
    for (int col = threadIdx.y*warp_size + threadIdx.x; col < ncols; col += nwarps*warp_size) {
        tile_A A[ntA][warp_size / tile_A::J];
#pragma unroll
        for (int itA = 0; itA < ntA; ++itA) {
#pragma unroll
            for (int i = 0; i < tile_A::I; ++i) {
                tile_xy[i*tile_k_padded + threadIdx.x] = x[(itA*tile_A::I + i)*stride_row + col];
            }
#pragma unroll
            for (int k0 = 0; k0 < warp_size; k0 += tile_A::J) {
                load_ldmatrix(A[itA][k0/tile_A::J], tile_xy + k0, tile_k_padded);
            }
        }

#pragma unroll
        for (int itB = 0; itB < ntB; ++itB) {
            if constexpr (std::is_same_v<T, float>) {
#pragma unroll
                for (int j0 = 0; j0 < tile_B::I; ++j0) {
                    const int j = j0 + itB*tile_B::I;

                    if constexpr (!has_ids) {
                        tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? y[j*stride_col_y + col] : 0.0f;
                    } else {
                        const bool valid = j < cols_per_block && (col_base + j) < ncols_dst_total && slot_map[j] >= 0;
                        tile_xy[j0*tile_k_padded + threadIdx.x] = valid ? y[slot_map[j]*stride_channel_y + j*stride_col_y + col] : 0.0f;
                    }
                }
            } else if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
#pragma unroll
                for (int j0 = 0; j0 < tile_B::I; ++j0) {
                    const int j = j0 + itB*tile_B::I;

                    if constexpr (!has_ids) {
                        const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f);
                        tile_xy[j0*tile_k_padded + threadIdx.x] = ggml_cuda_cast<T>(tmp);
                    } else {
                        const bool valid = j < cols_per_block && (col_base + j) < ncols_dst_total && slot_map[j] >= 0;
                        float2 tmp = valid ? *(const float2 *) &y[slot_map[j]*stride_channel_y + 2*(j*stride_col_y + col)] : make_float2(0.0f, 0.0f);
                        tile_xy[j0*tile_k_padded + threadIdx.x] = ggml_cuda_cast<T>(tmp);
                    }
                }
            } else {
                static_assert(std::is_same_v<T, void>, "unsupported type");
            }
#pragma unroll
            for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) {
                tile_B B;
                load_ldmatrix(B, tile_xy + k0, tile_k_padded);
#pragma unroll
                for (int itA = 0; itA < ntA; ++itA) {
                    mma(C[itA][itB], A[itA][k0/tile_B::J], B);
                }
            }
        }
    }

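    // Combine the per-warp partial results: each warp writes its C tiles to
    // shared memory, then the partial sums are reduced across the nwarps copies
    // and written to dst.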
    float * buf_iw = (float *) compute_base;
    constexpr int kiw = nwarps*rows_per_block + mmf_get_padding();

    if (nwarps > 1) {
        __syncthreads();
    }
#pragma unroll
    for (int itB = 0; itB < ntB; ++itB) {
#pragma unroll
        for (int itA = 0; itA < ntA; ++itA) {
#pragma unroll
            for (int l = 0; l < tile_C::ne; ++l) {
                const int i = threadIdx.y*rows_per_block + itA*tile_C::I + tile_C::get_i(l);
                const int j = itB*tile_C::J + tile_C::get_j(l);
                buf_iw[j*kiw + i] = C[itA][itB].x[l];
            }
        }
    }

    if (nwarps > 1) {
        __syncthreads();
    }

#pragma unroll
    for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) {
        const int j = j0 + threadIdx.y;

        if (j0 + nwarps > cols_per_block && j >= cols_per_block) {
            return;
        }

        float sum[rows_per_block/warp_size] = {0.0f};
        static_assert((rows_per_block % warp_size) == 0, "rows_per_block must be a multiple of warp_size.");
#pragma unroll
        for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) {
#pragma unroll
            for (int i1 = 0; i1 < sizeof(sum)/sizeof(sum[0]); ++i1) {
                const int i = i0 + i1*warp_size + threadIdx.x;

                sum[i1] += buf_iw[j*kiw + i];
            }
        }

        if constexpr (!has_ids) {
#pragma unroll
            for (int i0 = 0; i0 < sizeof(sum)/sizeof(sum[0]); ++i0) {
                dst[j*stride_col_dst + row0 + i0*warp_size + threadIdx.x] = sum[i0];
            }
        } else {
            const int slot = (j < cols_per_block) ? slot_map[j] : -1;
            if (slot >= 0 && (col_base + j) < ncols_dst_total) {
#pragma unroll
                for (int i0 = 0; i0 < sizeof(sum)/sizeof(sum[0]); ++i0) {
                    dst[slot*stride_channel_dst + j*stride_col_dst + row0 + i0*warp_size + threadIdx.x] = sum[i0];
                }
            }
        }
    }
    }
#else
    GGML_UNUSED_VARS(x, y, ids, dst,
        ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
        stride_col_id, stride_row_id,
        channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
        sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
    NO_DEVICE_CODE;
#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
}

// This kernel is for larger batch sizes of mul_mat_id
template <typename T, int rows_per_block, int cols_per_block, int nwarps>
__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1)
static __global__ void mul_mat_f_ids(
        const T * __restrict__ x, const float * __restrict__ y,
        const int32_t * __restrict__ ids_src_compact, const int32_t * __restrict__ ids_dst_compact,
        const int32_t * __restrict__ expert_bounds, float * __restrict__ dst,
        const int ncols, const int ncols_dst_total, const int nchannels_dst, const int stride_row, const int stride_col_y, const int stride_col_dst,
        const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
        const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
        const uint3 sis1_fd, const uint3 nch_fd) {
// TODO: handle this in a consistent and simpler way after AMD MFMA support has been added
#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
#if defined(AMD_WMMA_AVAILABLE)
    if constexpr (!(std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
    typedef tile<16, 8, T, get_input_data_layout()> tile_A;
    typedef tile<16, 8, T, get_input_data_layout()> tile_B;
    typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C;
#elif defined(AMD_MFMA_AVAILABLE)
    if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK_CDNA) {NO_DEVICE_CODE;} else {
    typedef tile<16, 8, T, DATA_LAYOUT_I_MAJOR> tile_A;
    typedef tile<16, 8, T, DATA_LAYOUT_I_MAJOR> tile_B;
    typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C;
#else
#ifdef VOLTA_MMA_AVAILABLE
    if constexpr (!std::is_same_v<T, half2> || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
    typedef tile<32, 4, T, DATA_LAYOUT_I_MAJOR>          tile_A;
    typedef tile< 8, 4, T, DATA_LAYOUT_I_MAJOR_MIRRORED> tile_B;
    typedef tile<32, 8, float, DATA_LAYOUT_I_MAJOR>      tile_C;
#else
    if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
    typedef tile<16, 8, T>     tile_A;
    typedef tile< 8, 8, T>     tile_B;
    typedef tile<16, 8, float> tile_C;
#endif // VOLTA_MMA_AVAILABLE
#endif // defined(AMD_WMMA_AVAILABLE)
    if constexpr (!tile_A::supported() || !tile_B::supported() || !tile_C::supported()) {
        NO_DEVICE_CODE;
        return;
    }

    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
    constexpr int tile_k_padded = warp_size + mmf_get_padding();
    constexpr int ntA = rows_per_block / tile_A::I;
    constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I;

    const int row0 = blockIdx.x * rows_per_block;

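    // Each y-block handles one expert and each z-block one tile of that
    // expert's columns; the grid's z dimension is sized for the maximum
    // possible tile count, so surplus tiles for a given expert exit early.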
    const int expert_idx = blockIdx.y;
    const int expert_start = expert_bounds[expert_idx];
    const int expert_end = expert_bounds[expert_idx + 1];
    const int ncols_expert = expert_end - expert_start;

    const int tiles_for_expert = (ncols_expert + cols_per_block - 1) / cols_per_block;
    const int tile_idx = blockIdx.z;
    if (tile_idx >= tiles_for_expert) {
        return;
    }

    const int col_base = tile_idx * cols_per_block;

    GGML_UNUSED(channel_ratio);

    const int channel_x = expert_idx;
    const int sample_dst = 0;
    const int sample_x = sample_dst / sample_ratio;
    const int sample_y = sample_dst;

    x   += int64_t(sample_x) *stride_sample_x   + channel_x*stride_channel_x + row0*stride_row;
    y   += int64_t(sample_y) *stride_sample_y;
    dst += int64_t(sample_dst)*stride_sample_dst;

    const int32_t * ids_src_expert = ids_src_compact + expert_start;
    const int32_t * ids_dst_expert = ids_dst_compact + expert_start;

    extern __shared__ char data_mmv[];
    char * compute_base = data_mmv;
    tile_C C[ntA][ntB];

    T * tile_xy = (T *) compute_base + threadIdx.y*(tile_A::I * tile_k_padded);

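    // A compact src id encodes (token, channel) as token*sis1 + channel;
    // fast_div_modulo with the precomputed sis1_fd recovers both without a
    // hardware integer division.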
    for (int col = threadIdx.y*warp_size + threadIdx.x; col < ncols; col += nwarps*warp_size) {
        tile_A A[ntA][warp_size / tile_A::J];
#pragma unroll
        for (int itA = 0; itA < ntA; ++itA) {
#pragma unroll
            for (int i = 0; i < tile_A::I; ++i) {
                tile_xy[i*tile_k_padded + threadIdx.x] = x[(itA*tile_A::I + i)*stride_row + col];
            }
#pragma unroll
            for (int k0 = 0; k0 < warp_size; k0 += tile_A::J) {
                load_ldmatrix(A[itA][k0/tile_A::J], tile_xy + k0, tile_k_padded);
            }
        }

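        // The gathering of B-tile values is double-buffered: while tile itB is
        // staged and multiplied, the values of tile itB + 1 are fetched into
        // the other buffer so the global loads overlap with the MMA work.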
        if constexpr (std::is_same_v<T, float>) {
            float vals_buf[2][tile_B::I];
            auto gather_tile = [&](int tile_idx_local, float * vals) {
#pragma unroll
                for (int j0 = 0; j0 < tile_B::I; ++j0) {
                    const int j = j0 + tile_idx_local*tile_B::I;
                    const int global_j = col_base + j;
                    float val = 0.0f;
                    if (j < cols_per_block && global_j < ncols_expert) {
                        const int src_entry = ids_src_expert[global_j];
                        const uint2 qrm = fast_div_modulo((uint32_t) src_entry, sis1_fd);
                        const int token = (int) qrm.x;
                        const int channel = (int) qrm.y;
                        if (token < ncols_dst_total) {
                            val = y[channel*stride_channel_y + token*stride_col_y + col];
                        }
                    }
                    vals[j0] = val;
                }
            };

            gather_tile(0, vals_buf[0]);

            int curr_buf = 0;
            int next_buf = 1;
#pragma unroll
            for (int itB = 0; itB < ntB; ++itB) {
#pragma unroll
                for (int j0 = 0; j0 < tile_B::I; ++j0) {
                    tile_xy[j0*tile_k_padded + threadIdx.x] = vals_buf[curr_buf][j0];
                }

                if (itB + 1 < ntB) {
                    gather_tile(itB + 1, vals_buf[next_buf]);
                }

#pragma unroll
                for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) {
                    tile_B B;
                    load_ldmatrix(B, tile_xy + k0, tile_k_padded);
#pragma unroll
                    for (int itA = 0; itA < ntA; ++itA) {
                        mma(C[itA][itB], A[itA][k0/tile_B::J], B);
                    }
                }

                if (itB + 1 < ntB) {
                    curr_buf ^= 1;
                    next_buf ^= 1;
                }
            }
        } else if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
            float2 vals_buf[2][tile_B::I];
            auto gather_tile = [&](int tile_idx_local, float2 * vals) {
#pragma unroll
                for (int j0 = 0; j0 < tile_B::I; ++j0) {
                    const int j = j0 + tile_idx_local*tile_B::I;
                    const int global_j = col_base + j;
                    float2 tmp = make_float2(0.0f, 0.0f);
                    if (j < cols_per_block && global_j < ncols_expert) {
                        const int src_entry = ids_src_expert[global_j];
                        const uint2 qrm = fast_div_modulo((uint32_t) src_entry, sis1_fd);
                        const int token = (int) qrm.x;
                        const int channel = (int) qrm.y;
                        if (token < ncols_dst_total) {
                            tmp = *(const float2 *) &y[channel*stride_channel_y + 2*(token*stride_col_y + col)];
                        }
                    }
                    vals[j0] = tmp;
                }
            };

            gather_tile(0, vals_buf[0]);

            int curr_buf = 0;
            int next_buf = 1;
#pragma unroll
            for (int itB = 0; itB < ntB; ++itB) {
#pragma unroll
                for (int j0 = 0; j0 < tile_B::I; ++j0) {
                    const float2 tmp = vals_buf[curr_buf][j0];
                    tile_xy[j0*tile_k_padded + threadIdx.x] = ggml_cuda_cast<T>(tmp);
                }

                if (itB + 1 < ntB) {
                    gather_tile(itB + 1, vals_buf[next_buf]);
                }

#pragma unroll
                for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) {
                    tile_B B;
                    load_ldmatrix(B, tile_xy + k0, tile_k_padded);
#pragma unroll
                    for (int itA = 0; itA < ntA; ++itA) {
                        mma(C[itA][itB], A[itA][k0/tile_B::J], B);
                    }
                }

                if (itB + 1 < ntB) {
                    curr_buf ^= 1;
                    next_buf ^= 1;
                }
            }
        } else {
            static_assert(std::is_same_v<T, void>, "unsupported type");
        }
    }

    float * buf_iw = (float *) compute_base;
    constexpr int kiw = nwarps*rows_per_block + mmf_get_padding();

    if (nwarps > 1) {
        __syncthreads();
    }
#pragma unroll
    for (int itB = 0; itB < ntB; ++itB) {
#pragma unroll
        for (int itA = 0; itA < ntA; ++itA) {
#pragma unroll
            for (int l = 0; l < tile_C::ne; ++l) {
                const int i = threadIdx.y*rows_per_block + itA*tile_C::I + tile_C::get_i(l);
                const int j = itB*tile_C::J + tile_C::get_j(l);
                buf_iw[j*kiw + i] = C[itA][itB].x[l];
            }
        }
    }

    if (nwarps > 1) {
        __syncthreads();
    }

#pragma unroll
    for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) {
        const int j = j0 + threadIdx.y;

        if (j0 + nwarps > cols_per_block && j >= cols_per_block) {
            return;
        }

        float sum[rows_per_block/warp_size] = {0.0f};
        static_assert((rows_per_block % warp_size) == 0, "rows_per_block must be a multiple of warp_size.");
#pragma unroll
        for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) {
#pragma unroll
            for (int i1 = 0; i1 < sizeof(sum)/sizeof(sum[0]); ++i1) {
                const int i = i0 + i1*warp_size + threadIdx.x;

                sum[i1] += buf_iw[j*kiw + i];
            }
        }

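        // Decode the compact dst id into (token, slot) and scatter the row
        // results to the channel that selected this token.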
        const int global_j = col_base + j;
        if (j < cols_per_block && global_j < ncols_expert && nchannels_dst > 0) {
            const int dst_entry = ids_dst_expert[global_j];
            const uint2 qrm = fast_div_modulo((uint32_t) dst_entry, nch_fd);
            const int token = (int) qrm.x;
            if (token < ncols_dst_total) {
                const int slot = (int) qrm.y;
#pragma unroll
                for (int i0 = 0; i0 < sizeof(sum)/sizeof(sum[0]); ++i0) {
                    dst[slot*stride_channel_dst + token*stride_col_dst + row0 + i0*warp_size + threadIdx.x] = sum[i0];
                }
            }
        }
    }
    }
#else
    GGML_UNUSED_VARS(x, y, ids_src_compact, ids_dst_compact, expert_bounds, dst,
        ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
        channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
        sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, sis1_fd, nch_fd);
    NO_DEVICE_CODE;
#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
}

template <typename T, int rows_per_block, int cols_per_block, int nwarps>
static inline void mul_mat_f_switch_ids(
        const T * x, const float * y, const int32_t * ids, float * dst,
        const int64_t ncols_x, const int64_t ncols_dst, const int64_t nchannels_dst,
        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
        const int64_t stride_col_id, const int64_t stride_row_id,
        const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
        const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
        const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared_total, cudaStream_t stream,
        const mmf_ids_data * ids_data) {
    const bool has_ids_data = ids_data && ids_data->ids_src_compact;

    // Use the compact-ids kernel only for larger tiles; for small ncols_dst (< 16)
    // we prefer the normal mul_mat_f path with has_ids=true.
    if (has_ids_data && ncols_dst > 16) {
        const int max_tiles = (int) ((ncols_dst + cols_per_block - 1) / cols_per_block);
        if (max_tiles == 0) {
            return;
        }
        dim3 block_nums_ids(block_nums.x, ids_data->n_experts, max_tiles);

        const uint3 sis1_fd = ids_data->sis1 > 0 ? init_fastdiv_values((uint32_t) ids_data->sis1) : make_uint3(0, 0, 1);
        const uint3 nch_fd  = init_fastdiv_values((uint32_t) nchannels_dst);

        mul_mat_f_ids<T, rows_per_block, cols_per_block, nwarps><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>>
            (x, y, ids_data->ids_src_compact, ids_data->ids_dst_compact, ids_data->expert_bounds_dev, dst,
             ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
             channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
             sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst,
             sis1_fd, nch_fd);
    } else if (ids) {
        const int64_t col_tiles = (ncols_dst + cols_per_block - 1) / cols_per_block;
        dim3 block_nums_ids = block_nums;
        block_nums_ids.y *= col_tiles;

        mul_mat_f<T, rows_per_block, cols_per_block, nwarps, true><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>>
            (x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
             stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
             sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
    } else {
        mul_mat_f<T, rows_per_block, cols_per_block, nwarps, false><<<block_nums, block_dims, nbytes_shared_total, stream>>>
            (x, y, ids, dst, ncols_x, cols_per_block, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
             stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
             sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
    }
}

template <typename T, int rows_per_block, int cols_per_block>
void mul_mat_f_cuda(
        const T * x, const float * y, const int32_t * ids, float * dst,
        const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,
        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
        const int64_t stride_col_id, const int64_t stride_row_id,
        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
        const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
        cudaStream_t stream, const mmf_ids_data * ids_data) {
    typedef tile<16, 8, T> tile_A_16;
    typedef tile<32, 8, T> tile_A_32;
    typedef tile<16, 8, T> tile_B_16;
    typedef tile< 8, 8, T> tile_B_8;

    GGML_ASSERT(ncols_x      % 2 == 0);
    GGML_ASSERT(stride_row   % 2 == 0);
    GGML_ASSERT(stride_col_y % 2 == 0);
    GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
    GGML_ASSERT(       nsamples_dst  % nsamples_x  == 0);
    const int64_t channel_ratio = nchannels_dst / nchannels_x;
    const int64_t sample_ratio  = nsamples_dst  / nsamples_x;

    const int device = ggml_cuda_get_device();
    const int cc = ggml_cuda_info().devices[device].cc;
    const int warp_size = ggml_cuda_info().devices[device].warp_size;

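    // Pick the warp count that minimizes the number of main-loop iterations
    // over ncols_x; the factor 2 presumably reflects the two values packed
    // per half2/nv_bfloat162 element.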
    int64_t nwarps_best = 1;
    int64_t niter_best = (ncols_x + warp_size*2 - 1) / (warp_size*2);
    int64_t max_block_size = mmf_get_max_block_size(cc);
    for (int64_t nwarps = 2; nwarps <= max_block_size/warp_size; nwarps++) {
        const int64_t niter = (ncols_x + nwarps*warp_size*2 - 1) / (nwarps*warp_size*2);
        if (niter < niter_best) {
            niter_best  = niter;
            nwarps_best = nwarps;
        }
    }

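    // The shared memory buffer is reused: first to stage the A/B tiles during
    // the main loop, then to combine the per-warp partial sums; for the ids
    // path the slot map is carved off the front of the allocation.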
    const int nbytes_shared_iter = nwarps_best * (volta_mma_available(cc) ? tile_A_32::I : tile_A_16::I) * (warp_size + mmf_get_padding(cc)) * 4;
    const int nbytes_cols_per_block_pad = (amd_wmma_available(cc) || amd_mfma_available(cc)) ? tile_B_16::I : tile_B_8::I;
    const int nbytes_shared_combine = GGML_PAD(cols_per_block, nbytes_cols_per_block_pad) * (nwarps_best*rows_per_block + mmf_get_padding(cc)) * 4;
    const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine);
    const int nbytes_slotmap = ids ? GGML_PAD(cols_per_block, 16) * sizeof(int) : 0;
    const int nbytes_shared_total = nbytes_shared + nbytes_slotmap;
    const int64_t grid_y = ids ? nchannels_x : nchannels_dst;

    const dim3 block_nums(nrows_x/rows_per_block, grid_y, nsamples_dst);
    const dim3 block_dims(warp_size, nwarps_best, 1);

    switch (nwarps_best) {
        case 1: {
            mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 1>(
                x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
                ids_data);
        } break;
        case 2: {
            mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 2>(
                x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
                ids_data);
        } break;
        case 3: {
            mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 3>(
                x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
                ids_data);
        } break;
        case 4: {
            mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 4>(
                x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
                ids_data);
        } break;
        case 5: {
            mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 5>(
                x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
                ids_data);
        } break;
        case 6: {
            mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 6>(
                x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
                ids_data);
        } break;
        case 7: {
            mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 7>(
                x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
                ids_data);
        } break;
        case 8: {
            mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 8>(
                x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
                ids_data);
        } break;
        default: {
            GGML_ABORT("fatal error");
        } break;
    }

    GGML_UNUSED_VARS(nchannels_y);
}

template <typename T, int rows_per_block>
static void mul_mat_f_switch_cols_per_block(
        const T * x, const float * y, const int32_t * ids, float * dst,
        const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,
        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
        const int64_t stride_col_id, const int stride_row_id,
        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
        const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
        cudaStream_t stream, const mmf_ids_data * ids_data) {

    const int ncols_case = (ids && ncols_dst > 16) ? 16 : ncols_dst;

    GGML_ASSERT(ids || ncols_dst <= 16);

    switch (ncols_case) {
        case 1: {
            mul_mat_f_cuda<T, rows_per_block, 1>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        case 2: {
            mul_mat_f_cuda<T, rows_per_block, 2>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        case 3: {
            mul_mat_f_cuda<T, rows_per_block, 3>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        case 4: {
            mul_mat_f_cuda<T, rows_per_block, 4>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        case 5: {
            mul_mat_f_cuda<T, rows_per_block, 5>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        case 6: {
            mul_mat_f_cuda<T, rows_per_block, 6>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        case 7: {
            mul_mat_f_cuda<T, rows_per_block, 7>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        case 8: {
            mul_mat_f_cuda<T, rows_per_block, 8>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        case 9: {
            mul_mat_f_cuda<T, rows_per_block, 9>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        case 10: {
            mul_mat_f_cuda<T, rows_per_block, 10>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        case 11: {
            mul_mat_f_cuda<T, rows_per_block, 11>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        case 12: {
            mul_mat_f_cuda<T, rows_per_block, 12>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        case 13: {
            mul_mat_f_cuda<T, rows_per_block, 13>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        case 14: {
            mul_mat_f_cuda<T, rows_per_block, 14>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        case 15: {
            mul_mat_f_cuda<T, rows_per_block, 15>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        case 16: {
            mul_mat_f_cuda<T, rows_per_block, 16>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        default: {
            GGML_ABORT("fatal error");
        } break;
    }
}

template <typename T>
static void mul_mat_f_switch_rows_per_block(
        const int rows_per_block, const T * x, const float * y, const int32_t * ids, float * dst,
        const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,
        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
        const int64_t stride_col_id, const int stride_row_id,
        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
        const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
        cudaStream_t stream, const mmf_ids_data * ids_data) {
    switch (rows_per_block) {
        case MMF_ROWS_PER_BLOCK: {
            mul_mat_f_switch_cols_per_block<T, MMF_ROWS_PER_BLOCK>(
                x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        case MMF_ROWS_PER_BLOCK_CDNA: {
            mul_mat_f_switch_cols_per_block<T, MMF_ROWS_PER_BLOCK_CDNA>(
                x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        default:
            GGML_ABORT("unsupported rows_per_block: %i", rows_per_block);
    }
}

#define DECL_MMF_CASE_HELPER(T, nrows_dst, ncols_dst) \
    template void mul_mat_f_cuda<T, nrows_dst, ncols_dst>( \
        const T * x, const float * y, const int32_t * ids, float * dst, \
        const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst_total, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, \
        const int64_t stride_col_id, const int64_t stride_row_id, \
        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, \
        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, \
        const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, \
        cudaStream_t stream, const mmf_ids_data * ids_data);

#if !defined(GGML_USE_MUSA)
#define DECL_MMF_CASE_EXTERN(ncols_dst) \
    extern DECL_MMF_CASE_HELPER(float, MMF_ROWS_PER_BLOCK, ncols_dst) \
    extern DECL_MMF_CASE_HELPER(half2, MMF_ROWS_PER_BLOCK, ncols_dst) \
    extern DECL_MMF_CASE_HELPER(nv_bfloat162, MMF_ROWS_PER_BLOCK, ncols_dst) \
    extern DECL_MMF_CASE_HELPER(float, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) \
    extern DECL_MMF_CASE_HELPER(half2, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) \
    extern DECL_MMF_CASE_HELPER(nv_bfloat162, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst)

#define DECL_MMF_CASE(ncols_dst) \
    DECL_MMF_CASE_HELPER(float, MMF_ROWS_PER_BLOCK, ncols_dst) \
    DECL_MMF_CASE_HELPER(half2, MMF_ROWS_PER_BLOCK, ncols_dst) \
    DECL_MMF_CASE_HELPER(nv_bfloat162, MMF_ROWS_PER_BLOCK, ncols_dst) \
    DECL_MMF_CASE_HELPER(float, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) \
    DECL_MMF_CASE_HELPER(half2, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) \
    DECL_MMF_CASE_HELPER(nv_bfloat162, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst)

DECL_MMF_CASE_EXTERN(1);
DECL_MMF_CASE_EXTERN(2);
DECL_MMF_CASE_EXTERN(3);
DECL_MMF_CASE_EXTERN(4);
DECL_MMF_CASE_EXTERN(5);
DECL_MMF_CASE_EXTERN(6);
DECL_MMF_CASE_EXTERN(7);
DECL_MMF_CASE_EXTERN(8);
DECL_MMF_CASE_EXTERN(9);
DECL_MMF_CASE_EXTERN(10);
DECL_MMF_CASE_EXTERN(11);
DECL_MMF_CASE_EXTERN(12);
DECL_MMF_CASE_EXTERN(13);
DECL_MMF_CASE_EXTERN(14);
DECL_MMF_CASE_EXTERN(15);
DECL_MMF_CASE_EXTERN(16);
#else
#define DECL_MMF_CASE(ncols_dst)
#endif // !defined(GGML_USE_MUSA)