1#include "ggml-cuda/common.cuh"
  2#include "ggml.h"
  3#include "topk-moe.cuh"
  4
  5#include <cmath>
  6#include <initializer_list>
  7
  8// Kernel config struct - passed by value to CUDA kernel
struct topk_moe_config {
    bool use_sigmoid;     // apply sigmoid (instead of softmax) to the logits before selection
    bool with_norm;       // after selection, renormalize the selected weights by their (clamped) sum
    bool delayed_softmax; // skip the pre-selection activation; softmax only over the selected top-k logits
};
 14
 15// Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path.
 16template <int experts_per_thread, bool use_limit>
 17__device__ void softmax_warp_inplace(float (&vals)[experts_per_thread], const int limit, const int lane) {
 18    float max_val = -INFINITY;
 19
 20#pragma unroll
 21    for (int i = 0; i < experts_per_thread; i++) {
 22        const int  idx    = lane + i * WARP_SIZE;
 23        const bool active = !use_limit || (idx < limit);
 24        if (active) {
 25            max_val = max(max_val, vals[i]);
 26        }
 27    }
 28
 29    max_val = warp_reduce_max(max_val);
 30
 31    float sum = 0.f;
 32
 33#pragma unroll
 34    for (int i = 0; i < experts_per_thread; i++) {
 35        const int  idx    = lane + i * WARP_SIZE;
 36        const bool active = !use_limit || (idx < limit);
 37        if (active) {
 38            const float val = expf(vals[i] - max_val);
 39            vals[i]         = val;
 40            sum += val;
 41        } else {
 42            vals[i] = 0.f;
 43        }
 44    }
 45
 46    sum = warp_reduce_sum(sum);
 47
 48    const float inv_sum = 1.0f / sum;
 49
 50#pragma unroll
 51    for (int i = 0; i < experts_per_thread; i++) {
 52        const int  idx    = lane + i * WARP_SIZE;
 53        const bool active = !use_limit || (idx < limit);
 54        if (active) {
 55            vals[i] *= inv_sum;
 56        }
 57    }
 58}
 59
 60template <int experts_per_thread, bool use_limit>
 61__device__ void sigmoid_warp_inplace(float (&vals)[experts_per_thread], const int limit, const int lane) {
 62#pragma unroll
 63    for (int i = 0; i < experts_per_thread; i++) {
 64        const int  idx    = lane + i * WARP_SIZE;
 65        const bool active = !use_limit || (idx < limit);
 66        vals[i]           = active ? 1.f / (1.f + expf(-vals[i])) : -INFINITY;
 67    }
 68}
 69
 70/*
 71    This kernel does the following:
 72    1. optionally softmax over the logits per token [n_experts, n_tokens]
 73    2. argmax reduce over the top-k (n_experts_used) logits
 74    3. write weights + ids to global memory
 75    4. optionally normalize the weights or apply softmax over the selected logits
 76
 77    It is intended as fusion of softmax->top-k->get_rows pipeline for MoE models
 78*/
 79template <int n_experts, bool has_bias>
 80__launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *         logits,
 81                                                                  float *               weights,
 82                                                                  int32_t *             ids,
 83                                                                  float *               bias,
 84                                                                  const int             n_rows,
 85                                                                  const int             n_expert_used,
 86                                                                  const float           clamp_val,
 87                                                                  const float           scale_val,
 88                                                                  const topk_moe_config config) {
 89    const int row = blockIdx.x * blockDim.y + threadIdx.y;
 90    if (row >= n_rows) {
 91        return;
 92    }
 93
 94    logits += n_experts * row;
 95    weights += n_expert_used * row;
 96    ids += n_experts * row;
 97
 98    constexpr int experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1;
 99
100    float wt[experts_per_thread];
101
102    // Initialize all slots to -INFINITY
103#pragma unroll
104    for (int i = 0; i < experts_per_thread; i++) {
105        wt[i] = -INFINITY;
106    }
107
108#pragma unroll
109    for (int i = 0; i < n_experts; i += WARP_SIZE) {
110        const int expert  = i + threadIdx.x;
111        wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[expert] : -INFINITY;
112    }
113
114    if (!config.delayed_softmax) {
115        if (config.use_sigmoid) {
116           sigmoid_warp_inplace<experts_per_thread, false>(wt, n_experts, threadIdx.x);
117        } else {
118           softmax_warp_inplace<experts_per_thread, false>(wt, n_experts, threadIdx.x);
119        }
120    }
121
122    // selection_wt is only needed when bias is present (selection uses wt + bias)
123    // when no bias, we use wt directly for both selection and weight values
124    float selection_wt[has_bias ? experts_per_thread : 1];
125
126    if constexpr (has_bias) {
127#pragma unroll
128        for (int i = 0; i < experts_per_thread; i++) {
129            selection_wt[i] = -INFINITY;
130        }
131#pragma unroll
132        for (int i = 0; i < n_experts; i += WARP_SIZE) {
133            const int expert = i + threadIdx.x;
134            selection_wt[i / WARP_SIZE] =
135                (n_experts % WARP_SIZE == 0 || expert < n_experts) ? wt[i / WARP_SIZE] + bias[expert] : -INFINITY;
136        }
137    }
138
139    //at this point, each thread holds either a portion of the softmax distribution
140    //or the raw logits. We do the argmax reduce over n_expert_used, each time marking
141    //the expert weight as -inf to exclude from the next iteration
142
143    float wt_sum = 0.f;
144
145    float output_weights[experts_per_thread];
146
147#pragma unroll
148    for (int i = 0; i < experts_per_thread; i++) {
149        output_weights[i] = 0.f;
150    }
151
152    for (int k = 0; k < n_expert_used; k++) {
153        float max_val    = wt[0];
154        int   max_expert = threadIdx.x;
155
156        if constexpr (has_bias) {
157            float max_val_s = selection_wt[0];
158
159#pragma unroll
160            for (int i = 1; i < experts_per_thread; i++) {
161                const int expert = threadIdx.x + i * WARP_SIZE;
162                if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && selection_wt[i] > max_val_s) {
163                    max_val    = wt[i];
164                    max_val_s  = selection_wt[i];
165                    max_expert = expert;
166                }
167            }
168
169#pragma unroll
170            for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
171                const float val    = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, WARP_SIZE);
172                const float val_s  = __shfl_xor_sync(0xFFFFFFFF, max_val_s, mask, WARP_SIZE);
173                const int   expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, WARP_SIZE);
174                if (val_s > max_val_s || (val_s == max_val_s && expert < max_expert)) {
175                    max_val    = val;
176                    max_val_s  = val_s;
177                    max_expert = expert;
178                }
179            }
180
181            if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
182                selection_wt[max_expert / WARP_SIZE] = -INFINITY;
183            }
184        } else {
185#pragma unroll
186            for (int i = 1; i < experts_per_thread; i++) {
187                const int expert = threadIdx.x + i * WARP_SIZE;
188                if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) {
189                    max_val    = wt[i];
190                    max_expert = expert;
191                }
192            }
193
194#pragma unroll
195            for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
196                const float val    = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, WARP_SIZE);
197                const int   expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, WARP_SIZE);
198                if (val > max_val || (val == max_val && expert < max_expert)) {
199                    max_val    = val;
200                    max_expert = expert;
201                }
202            }
203
204            if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
205                wt[max_expert / WARP_SIZE] = -INFINITY;
206            }
207        }
208
209        if ((k & (WARP_SIZE - 1)) == threadIdx.x) {
210            output_weights[k / WARP_SIZE] = max_val;
211        }
212
213        if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
214            ids[k] = max_expert;
215            if (config.with_norm) {
216                wt_sum += max_val;
217            }
218        }
219    }
220
221    if (config.with_norm) {
222        wt_sum              = warp_reduce_sum(wt_sum);
223        wt_sum              = max(wt_sum, clamp_val);
224        const float inv_sum = 1.0f / wt_sum;
225
226        for (int i = 0; i < experts_per_thread; i++) {
227            output_weights[i] *= inv_sum;
228        }
229    }
230
231    if (config.delayed_softmax) {
232        softmax_warp_inplace<experts_per_thread, true>(output_weights, n_expert_used, threadIdx.x);
233    }
234
235#pragma unroll
236    for (int i = 0; i < experts_per_thread; i++) {
237        const int idx = i * WARP_SIZE + threadIdx.x;
238        if (idx < n_expert_used) {
239            weights[idx] = output_weights[i] * scale_val;
240        }
241    }
242}
243
// Host-side dispatcher: picks the kernel instantiation matching the runtime
// expert count (the kernel is specialized on n_experts at compile time).
template<bool has_bias>
static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
                                 const float *               logits,
                                 float *                     weights,
                                 int32_t *                   ids,
                                 float *                     bias,
                                 const int                   n_rows,
                                 const int                   n_expert,
                                 const int                   n_expert_used,
                                 const float                 clamp_val,
                                 const float                 scale_val,
                                 const topk_moe_config       config) {
    GGML_ASSERT(!(config.with_norm && config.delayed_softmax) &&
                "delayed softmax is not supported with weight normalization");

    // one warp per row, 4 rows per block
    constexpr int rows_per_block = 4;
    const dim3    grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1);
    const dim3    block_dims(WARP_SIZE, rows_per_block, 1);
    cudaStream_t  stream = ctx.stream();

#define GGML_TOPK_MOE_CASE(NE)                                                                          \
    case NE:                                                                                            \
        topk_moe_cuda<NE, has_bias><<<grid_dims, block_dims, 0, stream>>>(                              \
            logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config);           \
        break;

    switch (n_expert) {
        GGML_TOPK_MOE_CASE(1)
        GGML_TOPK_MOE_CASE(2)
        GGML_TOPK_MOE_CASE(4)
        GGML_TOPK_MOE_CASE(8)
        GGML_TOPK_MOE_CASE(16)
        GGML_TOPK_MOE_CASE(32)
        GGML_TOPK_MOE_CASE(64)
        GGML_TOPK_MOE_CASE(128)
        GGML_TOPK_MOE_CASE(256)
        GGML_TOPK_MOE_CASE(512)
        GGML_TOPK_MOE_CASE(576)
        default:
            GGML_ASSERT(false && "fatal error");
            break;
    }

#undef GGML_TOPK_MOE_CASE
}
313
// Entry point for the fused softmax/top-k/get-rows MoE gating op.
// `clamp` being non-null requests weight normalization; `scale`/`bias` are optional.
void ggml_cuda_op_topk_moe(ggml_backend_cuda_context &     ctx,
                           const ggml_tensor *             logits,
                           ggml_tensor *                   weights,
                           ggml_tensor *                   ids,
                           const ggml_tensor *             clamp,
                           const ggml_tensor *             scale,
                           const ggml_tensor *             bias,
                           const ggml_cuda_topk_moe_args & args) {
    GGML_ASSERT(logits->type == GGML_TYPE_F32);
    GGML_ASSERT(weights->type == GGML_TYPE_F32);
    GGML_ASSERT(ids->type == GGML_TYPE_I32);

    const int n_experts     = logits->ne[0];
    const int n_rows        = logits->ne[1];
    const int n_expert_used = weights->ne[1];

    // the kernel writes ids rows with a stride of n_experts
    GGML_ASSERT(ids->nb[1] / ggml_type_size(ids->type) == (size_t) n_experts);

    const float * logits_d  = (const float *) logits->data;
    float *       weights_d = (float *) weights->data;
    int32_t *     ids_d     = (int32_t *) ids->data;
    float *       bias_d    = bias ? (float *) bias->data : nullptr;

    const float scale_val = scale ? ggml_get_op_params_f32(scale, 0) : 1.0f;
    // the clamp op param is the lower bound applied to the weight sum before division
    const float clamp_val = clamp ? ggml_get_op_params_f32(clamp, 0) : -INFINITY;

    topk_moe_config config;
    config.use_sigmoid     = args.sigmoid;
    config.with_norm       = clamp != nullptr;
    config.delayed_softmax = args.delayed_softmax;

    if (bias) {
        launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, bias_d, n_rows, n_experts, n_expert_used,
                                   clamp_val, scale_val, config);
    } else {
        launch_topk_moe_cuda<false>(ctx, logits_d, weights_d, ids_d, bias_d, n_rows, n_experts, n_expert_used,
                                    clamp_val, scale_val, config);
    }
}
360
// Decides whether the softmax/sigmoid -> top-k -> get-rows pattern rooted at
// gating_op can be replaced by the fused kernel above.
bool ggml_cuda_should_use_topk_moe(const ggml_tensor * gating_op,
                                   const ggml_tensor * weights,
                                   const ggml_tensor * logits,
                                   const ggml_tensor * ids) {
    // supported expert counts: powers of two up to 512, plus 576
    const int  n_expert   = ids->nb[1] / ids->nb[0];
    const bool is_pow2_ok = (n_expert & (n_expert - 1)) == 0 && n_expert <= 512;
    if (!is_pow2_ok && n_expert != 576) {
        return false;
    }

    if (!ggml_is_contiguous(weights) || !ggml_is_contiguous(logits)) {
        return false;
    }

    if (gating_op->op == GGML_OP_SOFT_MAX) {
        // only a plain softmax can be fused: scale 1, no ALiBi max_bias
        float scale    = 1.0f;
        float max_bias = 0.0f;
        memcpy(&scale, (const float *) gating_op->op_params + 0, sizeof(float));
        memcpy(&max_bias, (const float *) gating_op->op_params + 1, sizeof(float));

        if (!ggml_is_contiguous(gating_op->src[0])) {
            return false;
        }

        if (scale != 1.0f || max_bias != 0.0f) {
            return false;
        }

        // don't fuse when masks or sinks are present
        if (gating_op->src[1] || gating_op->src[2]) {
            return false;
        }
    } else if (gating_op->op == GGML_OP_UNARY) {
        // sigmoid is the only supported unary gating activation
        if (ggml_get_unary_op(gating_op) != GGML_UNARY_OP_SIGMOID) {
            return false;
        }
    }

    return true;
}