#include "ggml.h"
#include "mmf.cuh"
#include "mmid.cuh"

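// Number of rows of the output matrix that each CUDA block works on; CDNA devices use their own tile size.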
static __forceinline__ int mmf_get_rows_per_block(const int cc) {
    if (GGML_CUDA_CC_IS_CDNA(cc)) {
        return MMF_ROWS_PER_BLOCK_CDNA;
    } else {
        return MMF_ROWS_PER_BLOCK;
    }
}

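// Matrix multiplication with a non-quantized src0 (F32/F16/BF16) and an F32 src1.
// When ids is non-null this implements MUL_MAT_ID, i.e. per-token expert selection.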
void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
    GGML_ASSERT( src1->type == GGML_TYPE_F32);
    GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    GGML_TENSOR_BINARY_OP_LOCALS;

    const size_t ts_src0 = ggml_type_size(src0->type);
    const size_t ts_src1 = ggml_type_size(src1->type);
    const size_t ts_dst  = ggml_type_size(dst->type);

    GGML_ASSERT(ne13 == ne3);

    GGML_ASSERT(        nb00       == ts_src0);
    GGML_ASSERT(        nb10       == ts_src1);
    GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type));
    GGML_ASSERT(        nb0        == ts_dst);

    const float   * src1_d = (const float   *) src1->data;
    const int32_t * ids_d  = ids ? (const int32_t *) ids->data : nullptr;
    float         * dst_d  = (float         *)  dst->data;

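    // Tensor strides converted from bytes to elements: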
    const int64_t s01 = src0->nb[1] / ts_src0;
    const int64_t s11 = src1->nb[1] / ts_src1;
    const int64_t s1  =  dst->nb[1] / ts_dst;
    const int64_t s02 = src0->nb[2] / ts_src0;
    const int64_t s12 = src1->nb[2] / ts_src1;
    const int64_t s2  =  dst->nb[2] / ts_dst;
    const int64_t s03 = src0->nb[3] / ts_src0;
    const int64_t s13 = src1->nb[3] / ts_src1;
    const int64_t s3  =  dst->nb[3] / ts_dst;

    const int64_t ids_s0 = ids ? ids->nb[0] / ggml_type_size(ids->type) : 0;
    const int64_t ids_s1 = ids ? ids->nb[1] / ggml_type_size(ids->type) : 0;

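    // Scratch data for the MUL_MAT_ID path; only allocated and filled below when needed.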
    mmf_ids_data ids_info{};
    mmf_ids_data * ids_info_ptr = nullptr;
    ggml_cuda_pool_alloc<int32_t> ids_src_compact_dev;
    ggml_cuda_pool_alloc<int32_t> ids_dst_compact_dev;
    ggml_cuda_pool_alloc<int32_t> expert_bounds_dev;

    // For MUL_MAT_ID the memory layout is different than for MUL_MAT:
    const int64_t ncols_dst     = ids ? ne2 : ne1;
    const int64_t nchannels_dst = ids ? ne1 : ne2;

    const int64_t stride_col_dst     = ids ? s2  : s1;
    const int64_t stride_col_y       = ids ? s12 : s11;
    const int64_t stride_channel_dst = ids ? s1  : s2;

    int64_t stride_channel_y = ids ? s11  : s12;
    int64_t nchannels_y      = ids ? ne11 : ne12;

    // mul_mat_id: handle broadcast
    if (ids && nchannels_y == 1) {
        stride_channel_y = 0;
        nchannels_y      = ids->ne[0];
    }

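    // For larger MUL_MAT_ID batches, build compact per-expert index lists and expert bounds on the
    // device so that all rows routed to the same expert can be processed together.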
    if (ids && ncols_dst > 16) {
        const int64_t n_expert_used = ids->ne[0];
        const int64_t n_experts     = ne02;
        const int64_t n_tokens      = ne12;
        const int64_t ne_get_rows   = n_tokens * n_expert_used;

        ids_src_compact_dev.alloc(ctx.pool(), ne_get_rows);
        ids_dst_compact_dev.alloc(ctx.pool(), ne_get_rows);
        expert_bounds_dev.alloc(ctx.pool(), n_experts + 1);

        const int si1  = static_cast<int>(ids_s1);
        const int sis1 = static_cast<int>(src1->nb[2] / src1->nb[1]);

        GGML_ASSERT(sis1 > 0);

        ggml_cuda_launch_mm_ids_helper(ids_d, ids_src_compact_dev.get(), ids_dst_compact_dev.get(), expert_bounds_dev.get(),
            static_cast<int>(n_experts), static_cast<int>(n_tokens), static_cast<int>(n_expert_used), static_cast<int>(ne11), si1, sis1, ctx.stream());
        CUDA_CHECK(cudaGetLastError());

        ids_info.ids_src_compact   = ids_src_compact_dev.get();
        ids_info.ids_dst_compact   = ids_dst_compact_dev.get();
        ids_info.expert_bounds_dev = expert_bounds_dev.get();
        ids_info.n_experts         = static_cast<int>(n_experts);
        ids_info.sis1              = sis1;
        ids_info_ptr = &ids_info;
    }

    const int device         = ggml_cuda_get_device();
    const int cc             = ggml_cuda_info().devices[device].cc;
    const int rows_per_block = mmf_get_rows_per_block(cc);

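    // Dispatch on the src0 type. F16 and BF16 are handled as packed 2-element types
    // (half2/nv_bfloat162), so the affected extents and strides are divided by vals_per_T.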
    switch (src0->type) {
        case GGML_TYPE_F32: {
            const float * src0_d = (const float *) src0->data;
            constexpr int vals_per_T = 1;
            mul_mat_f_switch_rows_per_block<float>(
                rows_per_block, src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
                ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
                ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
        } break;
        case GGML_TYPE_F16: {
            const half2 * src0_d = (const half2 *) src0->data;
            constexpr int vals_per_T = 2;
            mul_mat_f_switch_rows_per_block<half2>(
                rows_per_block, src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
                ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
                ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
        } break;
        case GGML_TYPE_BF16: {
            const nv_bfloat162 * src0_d = (const nv_bfloat162 *) src0->data;
            constexpr int vals_per_T = 2;
            mul_mat_f_switch_rows_per_block<nv_bfloat162>(
                rows_per_block, src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
                ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
                ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
        } break;
        default:
            GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
    }
}

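// Heuristic for whether the MMF kernels should be used for a matrix multiplication with the given
// src0 type/shape, number of src1 columns and device compute capability.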
bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne,
        const size_t * src0_nb, const int src1_ncols, bool mul_mat_id) {
    if (ggml_is_quantized(type)) {
        return false;
    }

    const size_t ts = ggml_type_size(type);
    if (src0_ne[0] % (warp_size * (4/ts)) != 0) {
        return false;
    }

    if (src0_nb[0] != ts) {
        return false;
    }

    // Pointers not aligned to the size of half2/nv_bfloat162/float2 would result in a crash:
    for (size_t i = 1; i < GGML_MAX_DIMS; ++i) {
        if (src0_nb[i] % (2*ts) != 0) {
            return false;
        }
    }

    if (src0_ne[1] % mmf_get_rows_per_block(cc) != 0) {
        return false;
    }

    if (GGML_CUDA_CC_IS_CDNA3(cc) && type == GGML_TYPE_BF16) {
        return false;
    }

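    // Only use MMF for small batch sizes; the cutoffs differ between MUL_MAT_ID and plain MUL_MAT.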
    if (mul_mat_id) {
        if (src0_ne[1] <= 1024 && src1_ncols > 512) {
            return false;
        } else if (src0_ne[1] > 1024 && src1_ncols > 128) {
            return false;
        }
    } else {
        if (GGML_CUDA_CC_IS_RDNA3_0(cc) && src1_ncols > 8) {
            return false;
        } else if (GGML_CUDA_CC_IS_CDNA2(cc) && (type == GGML_TYPE_F16 || type == GGML_TYPE_BF16)) {
            // TODO: treat CDNA2 as CDNA1 for now; tune performance once CDNA2 hardware is available.
            return false;
        } else if (GGML_CUDA_CC_IS_CDNA1(cc) && (type == GGML_TYPE_F16 || type == GGML_TYPE_BF16)) {
            return false;
        } else if (src1_ncols > 16) {
            return false;
        }
    }

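    // Finally, require the matrix multiplication instructions needed for the given type on this device.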
    switch (type) {
        case GGML_TYPE_F32:
            return ampere_mma_available(cc) || amd_mfma_available(cc);
        case GGML_TYPE_F16:
            return volta_mma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc) || amd_mfma_available(cc);
        case GGML_TYPE_BF16:
            return ampere_mma_available(cc) || amd_wmma_available(cc) || amd_mfma_available(cc);
        default:
            return false;
    }
}