1#include "common.cuh"
2#include "ggml.h"
3
4#include <initializer_list>
5
// Boolean options handed to ggml_cuda_op_topk_moe through its `args` parameter.
// Every flag starts out false, both for `ggml_cuda_topk_moe_args a;` and for
// `ggml_cuda_topk_moe_args a{};`. Member order is significant for positional
// aggregate initialization.
// NOTE(review): the per-flag meanings below are inferred from the names only;
// confirm them against the definition of ggml_cuda_op_topk_moe.
struct ggml_cuda_topk_moe_args {
    bool sigmoid         = false; // presumably: use a sigmoid gating function
    bool softmax         = false; // presumably: use a softmax gating function
    bool delayed_softmax = false; // presumably: apply softmax after top-k selection
    bool prob_bias       = false; // presumably: add the `bias` tensor to the gate probabilities
    bool norm            = false; // presumably: renormalize the selected weights
    bool scale           = false; // presumably: multiply weights by the `scale` tensor/value
};
14
// Declaration of the CUDA top-k MoE op; the definition is not in this header.
//
// Parameters:
//   ctx     - CUDA backend context the op runs with.
//   logits  - input tensor (read-only through this pointer).
//   weights - non-const tensor; presumably receives the selected expert
//             weights (TODO confirm against the definition).
//   ids     - non-const tensor; presumably receives the selected expert
//             indices (TODO confirm against the definition).
//   clamp   - optional extra input (const); NOTE(review): presumably may be
//             nullptr when no clamping is requested, so verify.
//   scale   - optional extra input (const); presumably used together with
//             args.scale and may be nullptr otherwise (verify).
//   bias    - optional extra input (const); presumably used together with
//             args.prob_bias and may be nullptr otherwise (verify).
//   args    - option flags (see ggml_cuda_topk_moe_args) selecting which
//             variant of the op is executed.
void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
                           const ggml_tensor *         logits,
                           ggml_tensor *               weights,
                           ggml_tensor *               ids,
                           const ggml_tensor *         clamp,
                           const ggml_tensor *         scale,
                           const ggml_tensor *         bias,
                           const ggml_cuda_topk_moe_args & args);
23
// Declaration of a predicate (returns bool); the definition is not in this
// header. NOTE(review): presumably reports whether ggml_cuda_op_topk_moe can
// handle the given tensors; confirm against the definition.
//
// Parameters (all read-only through const pointers):
//   gating_op - the tensor for the gating operation being considered.
//   weights   - the weights tensor of the candidate pattern.
//   logits    - the logits tensor of the candidate pattern.
//   ids       - the ids tensor of the candidate pattern.
bool ggml_cuda_should_use_topk_moe(const ggml_tensor * gating_op,
                                   const ggml_tensor * weights,
                                   const ggml_tensor * logits,
                                   const ggml_tensor * ids);