1#include "ggml-cuda/common.cuh"
2#include "ggml.h"
3#include "topk-moe.cuh"
4
5#include <cmath>
6#include <initializer_list>
7
8// Kernel config struct - passed by value to CUDA kernel
struct topk_moe_config {
    bool use_sigmoid;      // activate logits with sigmoid instead of softmax before selection
    bool with_norm;        // renormalize the selected weights so they sum to 1 (clamped denominator)
    bool delayed_softmax;  // keep raw logits for selection, apply softmax over the selected ones afterwards
};
14
15// Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path.
template <int experts_per_thread, bool use_limit>
__device__ void softmax_warp_inplace(float (&vals)[experts_per_thread], const int limit, const int lane) {
    // Pass 1: warp-wide maximum for numerical stability.
    float local_max = -INFINITY;

#pragma unroll
    for (int j = 0; j < experts_per_thread; j++) {
        const int pos = lane + j * WARP_SIZE;
        if (!use_limit || pos < limit) {
            local_max = max(local_max, vals[j]);
        }
    }

    local_max = warp_reduce_max(local_max);

    // Pass 2: exponentiate in place and accumulate the denominator;
    // slots at or past `limit` are zeroed so they contribute nothing.
    float denom = 0.f;

#pragma unroll
    for (int j = 0; j < experts_per_thread; j++) {
        const int pos = lane + j * WARP_SIZE;
        const float e = (!use_limit || pos < limit) ? expf(vals[j] - local_max) : 0.f;
        vals[j] = e;
        denom += e;
    }

    denom = warp_reduce_sum(denom);

    // Pass 3: scale the active slots by the reciprocal of the warp-wide sum.
    const float rcp = 1.0f / denom;

#pragma unroll
    for (int j = 0; j < experts_per_thread; j++) {
        const int pos = lane + j * WARP_SIZE;
        if (!use_limit || pos < limit) {
            vals[j] *= rcp;
        }
    }
}
59
template <int experts_per_thread, bool use_limit>
__device__ void sigmoid_warp_inplace(float (&vals)[experts_per_thread], const int limit, const int lane) {
    // In-place logistic function over each lane's slice; slots at or past
    // `limit` are forced to -inf so they can never win a later argmax.
#pragma unroll
    for (int j = 0; j < experts_per_thread; j++) {
        const int pos = lane + j * WARP_SIZE;
        if (!use_limit || pos < limit) {
            vals[j] = 1.f / (1.f + expf(-vals[j]));
        } else {
            vals[j] = -INFINITY;
        }
    }
}
69
70/*
71 This kernel does the following:
72 1. optionally softmax over the logits per token [n_experts, n_tokens]
73 2. argmax reduce over the top-k (n_experts_used) logits
74 3. write weights + ids to global memory
75 4. optionally normalize the weights or apply softmax over the selected logits
76
77 It is intended as fusion of softmax->top-k->get_rows pipeline for MoE models
78*/
template <int n_experts, bool has_bias>
__launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits,
                                                                  float * weights,
                                                                  int32_t * ids,
                                                                  float * bias,
                                                                  const int n_rows,
                                                                  const int n_expert_used,
                                                                  const float clamp_val,
                                                                  const float scale_val,
                                                                  const topk_moe_config config) {
    // One warp per token/row: threadIdx.x is the lane, threadIdx.y selects the
    // row within the block (the host launcher uses blockDim = (WARP_SIZE, 4, 1)).
    const int row = blockIdx.x * blockDim.y + threadIdx.y;
    if (row >= n_rows) {
        return;
    }

    // Advance to this row. weights is dense (n_expert_used per row) while ids
    // is strided by n_experts per row (asserted by the host wrapper).
    logits += n_experts * row;
    weights += n_expert_used * row;
    ids += n_experts * row;

    // Each lane owns a strided slice of the experts: lane l holds experts
    // l, l + 32, l + 64, ... in its register array.
    constexpr int experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1;

    float wt[experts_per_thread];

    // Initialize all slots to -INFINITY
#pragma unroll
    for (int i = 0; i < experts_per_thread; i++) {
        wt[i] = -INFINITY;
    }

    // Load the logits; out-of-range slots stay -inf so they never win the argmax.
#pragma unroll
    for (int i = 0; i < n_experts; i += WARP_SIZE) {
        const int expert = i + threadIdx.x;
        wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[expert] : -INFINITY;
    }

    // Pre-top-k activation. With delayed_softmax the raw logits are kept for
    // selection and softmax is applied to the selected values at the end.
    if (!config.delayed_softmax) {
        if (config.use_sigmoid) {
            sigmoid_warp_inplace<experts_per_thread, false>(wt, n_experts, threadIdx.x);
        } else {
            softmax_warp_inplace<experts_per_thread, false>(wt, n_experts, threadIdx.x);
        }
    }

    // selection_wt is only needed when bias is present (selection uses wt + bias)
    // when no bias, we use wt directly for both selection and weight values
    float selection_wt[has_bias ? experts_per_thread : 1];

    if constexpr (has_bias) {
#pragma unroll
        for (int i = 0; i < experts_per_thread; i++) {
            selection_wt[i] = -INFINITY;
        }
        // Selection score = activated weight + per-expert bias; the stored
        // output weight remains the bias-free value in wt.
#pragma unroll
        for (int i = 0; i < n_experts; i += WARP_SIZE) {
            const int expert = i + threadIdx.x;
            selection_wt[i / WARP_SIZE] =
                (n_experts % WARP_SIZE == 0 || expert < n_experts) ? wt[i / WARP_SIZE] + bias[expert] : -INFINITY;
        }
    }

    //at this point, each thread holds either a portion of the softmax distribution
    //or the raw logits. We do the argmax reduce over n_expert_used, each time marking
    //the expert weight as -inf to exclude from the next iteration

    // Running sum of the selected weights (only used when config.with_norm).
    float wt_sum = 0.f;

    // The k-th selected weight is kept by lane (k mod 32) in slot k / 32.
    float output_weights[experts_per_thread];

#pragma unroll
    for (int i = 0; i < experts_per_thread; i++) {
        output_weights[i] = 0.f;
    }

    for (int k = 0; k < n_expert_used; k++) {
        // Thread-local argmax over this lane's slots first...
        float max_val = wt[0];
        int max_expert = threadIdx.x;

        if constexpr (has_bias) {
            // With bias, the argmax is taken over selection_wt, but the value
            // carried along (and eventually written out) is the bias-free wt.
            float max_val_s = selection_wt[0];

#pragma unroll
            for (int i = 1; i < experts_per_thread; i++) {
                const int expert = threadIdx.x + i * WARP_SIZE;
                if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && selection_wt[i] > max_val_s) {
                    max_val = wt[i];
                    max_val_s = selection_wt[i];
                    max_expert = expert;
                }
            }

            // ...then a butterfly (xor-shuffle) argmax across the warp.
            // Ties break towards the smaller expert index.
#pragma unroll
            for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
                const float val = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, WARP_SIZE);
                const float val_s = __shfl_xor_sync(0xFFFFFFFF, max_val_s, mask, WARP_SIZE);
                const int expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, WARP_SIZE);
                if (val_s > max_val_s || (val_s == max_val_s && expert < max_expert)) {
                    max_val = val;
                    max_val_s = val_s;
                    max_expert = expert;
                }
            }

            // The lane owning the winner knocks it out for the next iteration.
            if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
                selection_wt[max_expert / WARP_SIZE] = -INFINITY;
            }
        } else {
#pragma unroll
            for (int i = 1; i < experts_per_thread; i++) {
                const int expert = threadIdx.x + i * WARP_SIZE;
                if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) {
                    max_val = wt[i];
                    max_expert = expert;
                }
            }

            // Butterfly argmax across the warp; ties break towards the
            // smaller expert index.
#pragma unroll
            for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
                const float val = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, WARP_SIZE);
                const int expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, WARP_SIZE);
                if (val > max_val || (val == max_val && expert < max_expert)) {
                    max_val = val;
                    max_expert = expert;
                }
            }

            // The lane owning the winner knocks it out for the next iteration.
            if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
                wt[max_expert / WARP_SIZE] = -INFINITY;
            }
        }

        // After the butterfly, every lane holds the same (max_val, max_expert);
        // lane (k mod 32) stores the k-th weight into its register slot.
        if ((k & (WARP_SIZE - 1)) == threadIdx.x) {
            output_weights[k / WARP_SIZE] = max_val;
        }

        // The lane owning the winning expert writes its id; the per-lane
        // partial wt_sum is combined by warp_reduce_sum below.
        if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
            ids[k] = max_expert;
            if (config.with_norm) {
                wt_sum += max_val;
            }
        }
    }

    if (config.with_norm) {
        // Clamp the denominator from below to avoid dividing by ~0.
        wt_sum = warp_reduce_sum(wt_sum);
        wt_sum = max(wt_sum, clamp_val);
        const float inv_sum = 1.0f / wt_sum;

        for (int i = 0; i < experts_per_thread; i++) {
            output_weights[i] *= inv_sum;
        }
    }

    // Delayed path: softmax only over the n_expert_used selected logits.
    if (config.delayed_softmax) {
        softmax_warp_inplace<experts_per_thread, true>(output_weights, n_expert_used, threadIdx.x);
    }

#pragma unroll
    for (int i = 0; i < experts_per_thread; i++) {
        const int idx = i * WARP_SIZE + threadIdx.x;
        if (idx < n_expert_used) {
            weights[idx] = output_weights[i] * scale_val;
        }
    }
}
243
// Host-side launcher: picks the kernel instantiation matching n_expert and
// launches one warp per row (4 rows per block) on the context's stream.
// Supported expert counts: powers of two up to 512, plus 576; anything else asserts.
template<bool has_bias>
static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
                                 const float * logits,
                                 float * weights,
                                 int32_t * ids,
                                 float * bias,
                                 const int n_rows,
                                 const int n_expert,
                                 const int n_expert_used,
                                 const float clamp_val,
                                 const float scale_val,
                                 const topk_moe_config config) {
    GGML_ASSERT(!(config.with_norm && config.delayed_softmax) &&
                "delayed softmax is not supported with weight normalization");

    const int rows_per_block = 4;
    dim3 grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1);
    dim3 block_dims(WARP_SIZE, rows_per_block, 1);
    cudaStream_t stream = ctx.stream();

    // n_experts is a compile-time template parameter of the kernel, so dispatch
    // over the supported counts. The macro avoids 11 hand-written copies of the
    // identical launch statement (and keeps future signature changes in one place).
#define GGML_CUDA_TOPK_MOE_CASE(n_expert_t)                                                       \
    case n_expert_t:                                                                              \
        topk_moe_cuda<n_expert_t, has_bias><<<grid_dims, block_dims, 0, stream>>>(                \
            logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config);     \
        break;

    switch (n_expert) {
        GGML_CUDA_TOPK_MOE_CASE(1)
        GGML_CUDA_TOPK_MOE_CASE(2)
        GGML_CUDA_TOPK_MOE_CASE(4)
        GGML_CUDA_TOPK_MOE_CASE(8)
        GGML_CUDA_TOPK_MOE_CASE(16)
        GGML_CUDA_TOPK_MOE_CASE(32)
        GGML_CUDA_TOPK_MOE_CASE(64)
        GGML_CUDA_TOPK_MOE_CASE(128)
        GGML_CUDA_TOPK_MOE_CASE(256)
        GGML_CUDA_TOPK_MOE_CASE(512)
        GGML_CUDA_TOPK_MOE_CASE(576)
        default:
            GGML_ASSERT(false && "fatal error");
            break;
    }

#undef GGML_CUDA_TOPK_MOE_CASE
}
313
// Top-level op entry point: unpacks the ggml tensors / op params and forwards
// to the templated launcher, selecting the has_bias instantiation at runtime.
void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
                           const ggml_tensor * logits,
                           ggml_tensor * weights,
                           ggml_tensor * ids,
                           const ggml_tensor * clamp,
                           const ggml_tensor * scale,
                           const ggml_tensor * bias,
                           const ggml_cuda_topk_moe_args & args) {
    GGML_ASSERT(logits->type == GGML_TYPE_F32);
    GGML_ASSERT(weights->type == GGML_TYPE_F32);
    GGML_ASSERT(ids->type == GGML_TYPE_I32);

    const int n_experts     = logits->ne[0];
    const int n_rows        = logits->ne[1];
    const int n_expert_used = weights->ne[1];

    // the kernel assumes ids rows are strided by exactly n_experts elements
    GGML_ASSERT(ids->nb[1] / ggml_type_size(ids->type) == (size_t) n_experts);

    const float * logits_data  = (const float *) logits->data;
    float *       weights_data = (float *) weights->data;
    int32_t *     ids_data     = (int32_t *) ids->data;
    float *       bias_data    = bias ? (float *) bias->data : nullptr;

    // optional uniform scale applied to the final weights
    const float scale_val = scale ? ggml_get_op_params_f32(scale, 0) : 1.0f;

    // normalization is requested via the presence of the clamp tensor; its
    // op param is the lower bound for the normalization denominator
    const bool  with_norm = clamp != nullptr;
    const float clamp_val = with_norm ? ggml_get_op_params_f32(clamp, 0) : -INFINITY;

    topk_moe_config config;
    config.use_sigmoid     = args.sigmoid;
    config.with_norm       = with_norm;
    config.delayed_softmax = args.delayed_softmax;

    if (bias_data) {
        launch_topk_moe_cuda<true>(ctx, logits_data, weights_data, ids_data, bias_data, n_rows, n_experts,
                                   n_expert_used, clamp_val, scale_val, config);
    } else {
        launch_topk_moe_cuda<false>(ctx, logits_data, weights_data, ids_data, bias_data, n_rows, n_experts,
                                    n_expert_used, clamp_val, scale_val, config);
    }
}
360
// Returns true when the softmax/sigmoid -> top-k -> get-rows pipeline rooted at
// gating_op can be replaced by the fused topk_moe kernel in this file.
bool ggml_cuda_should_use_topk_moe(const ggml_tensor * gating_op,
                                   const ggml_tensor * weights,
                                   const ggml_tensor * logits,
                                   const ggml_tensor * ids) {
    // the kernel is only instantiated for power-of-two expert counts up to 512, plus 576
    const int n_expert = ids->nb[1] / ids->nb[0];
    const bool is_pow2_le_512 = (n_expert & (n_expert - 1)) == 0 && n_expert <= 512;
    if (!is_pow2_le_512 && n_expert != 576) {
        return false;
    }

    if (!ggml_is_contiguous(weights) || !ggml_is_contiguous(logits)) {
        return false;
    }

    if (gating_op->op == GGML_OP_SOFT_MAX) {
        float scale    = 1.0f;
        float max_bias = 0.0f;
        memcpy(&scale,    (const float *) gating_op->op_params + 0, sizeof(float));
        memcpy(&max_bias, (const float *) gating_op->op_params + 1, sizeof(float));

        if (!ggml_is_contiguous(gating_op->src[0])) {
            return false;
        }

        // only a plain softmax (no scaling, no ALiBi bias) can be fused
        if (scale != 1.0f || max_bias != 0.0f) {
            return false;
        }

        // don't fuse when masks or sinks are present
        if (gating_op->src[1] || gating_op->src[2]) {
            return false;
        }
    } else if (gating_op->op == GGML_OP_UNARY) {
        if (ggml_get_unary_op(gating_op) != GGML_UNARY_OP_SIGMOID) {
            return false;
        }
    }

    return true;
}