summaryrefslogtreecommitdiff
path: root/llama.cpp/ggml/src/ggml-cuda/topk-moe.cuh
diff options
context:
space:
mode:
Diffstat (limited to 'llama.cpp/ggml/src/ggml-cuda/topk-moe.cuh')
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/topk-moe.cuh27
1 files changed, 27 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-cuda/topk-moe.cuh b/llama.cpp/ggml/src/ggml-cuda/topk-moe.cuh
new file mode 100644
index 0000000..243dc2f
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/topk-moe.cuh
@@ -0,0 +1,27 @@
1#include "common.cuh"
2#include "ggml.h"
3
4#include <initializer_list>
5
6struct ggml_cuda_topk_moe_args {
7 bool sigmoid{};
8 bool softmax{};
9 bool delayed_softmax{};
10 bool prob_bias{};
11 bool norm{};
12 bool scale{};
13};
14
15void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
16 const ggml_tensor * logits,
17 ggml_tensor * weights,
18 ggml_tensor * ids,
19 const ggml_tensor * clamp,
20 const ggml_tensor * scale,
21 const ggml_tensor * bias,
22 const ggml_cuda_topk_moe_args & args);
23
24bool ggml_cuda_should_use_topk_moe(const ggml_tensor * gating_op,
25 const ggml_tensor * weights,
26 const ggml_tensor * logits,
27 const ggml_tensor * ids);