summaryrefslogtreecommitdiff
path: root/llama.cpp/ggml/src/ggml-cuda/topk-moe.cu
diff options
context:
space:
mode:
Diffstat (limited to 'llama.cpp/ggml/src/ggml-cuda/topk-moe.cu')
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/topk-moe.cu403
1 files changed, 403 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-cuda/topk-moe.cu b/llama.cpp/ggml/src/ggml-cuda/topk-moe.cu
new file mode 100644
index 0000000..08a8899
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/topk-moe.cu
@@ -0,0 +1,403 @@
+#include "ggml-cuda/common.cuh"
+#include "ggml.h"
+#include "topk-moe.cuh"
+
+#include <cmath>
+#include <initializer_list>
+
+// Kernel config struct - passed by value to CUDA kernel
+struct topk_moe_config {
+ bool use_sigmoid;
+ bool with_norm;
+ bool delayed_softmax;
+};
+
+// Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path.
+template <int experts_per_thread, bool use_limit>
+__device__ void softmax_warp_inplace(float (&vals)[experts_per_thread], const int limit, const int lane) {
+ float max_val = -INFINITY;
+
+#pragma unroll
+ for (int i = 0; i < experts_per_thread; i++) {
+ const int idx = lane + i * WARP_SIZE;
+ const bool active = !use_limit || (idx < limit);
+ if (active) {
+ max_val = max(max_val, vals[i]);
+ }
+ }
+
+ max_val = warp_reduce_max(max_val);
+
+ float sum = 0.f;
+
+#pragma unroll
+ for (int i = 0; i < experts_per_thread; i++) {
+ const int idx = lane + i * WARP_SIZE;
+ const bool active = !use_limit || (idx < limit);
+ if (active) {
+ const float val = expf(vals[i] - max_val);
+ vals[i] = val;
+ sum += val;
+ } else {
+ vals[i] = 0.f;
+ }
+ }
+
+ sum = warp_reduce_sum(sum);
+
+ const float inv_sum = 1.0f / sum;
+
+#pragma unroll
+ for (int i = 0; i < experts_per_thread; i++) {
+ const int idx = lane + i * WARP_SIZE;
+ const bool active = !use_limit || (idx < limit);
+ if (active) {
+ vals[i] *= inv_sum;
+ }
+ }
+}
+
+template <int experts_per_thread, bool use_limit>
+__device__ void sigmoid_warp_inplace(float (&vals)[experts_per_thread], const int limit, const int lane) {
+#pragma unroll
+ for (int i = 0; i < experts_per_thread; i++) {
+ const int idx = lane + i * WARP_SIZE;
+ const bool active = !use_limit || (idx < limit);
+ vals[i] = active ? 1.f / (1.f + expf(-vals[i])) : -INFINITY;
+ }
+}
+
+/*
+ This kernel does the following:
+ 1. optionally softmax over the logits per token [n_experts, n_tokens]
+ 2. argmax reduce over the top-k (n_experts_used) logits
+ 3. write weights + ids to global memory
+ 4. optionally normalize the weights or apply softmax over the selected logits
+
+ It is intended as fusion of softmax->top-k->get_rows pipeline for MoE models
+*/
+template <int n_experts, bool has_bias>
+__launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits,
+ float * weights,
+ int32_t * ids,
+ float * bias,
+ const int n_rows,
+ const int n_expert_used,
+ const float clamp_val,
+ const float scale_val,
+ const topk_moe_config config) {
+ const int row = blockIdx.x * blockDim.y + threadIdx.y;
+ if (row >= n_rows) {
+ return;
+ }
+
+ logits += n_experts * row;
+ weights += n_expert_used * row;
+ ids += n_experts * row;
+
+ constexpr int experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1;
+
+ float wt[experts_per_thread];
+
+ // Initialize all slots to -INFINITY
+#pragma unroll
+ for (int i = 0; i < experts_per_thread; i++) {
+ wt[i] = -INFINITY;
+ }
+
+#pragma unroll
+ for (int i = 0; i < n_experts; i += WARP_SIZE) {
+ const int expert = i + threadIdx.x;
+ wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[expert] : -INFINITY;
+ }
+
+ if (!config.delayed_softmax) {
+ if (config.use_sigmoid) {
+ sigmoid_warp_inplace<experts_per_thread, false>(wt, n_experts, threadIdx.x);
+ } else {
+ softmax_warp_inplace<experts_per_thread, false>(wt, n_experts, threadIdx.x);
+ }
+ }
+
+ // selection_wt is only needed when bias is present (selection uses wt + bias)
+ // when no bias, we use wt directly for both selection and weight values
+ float selection_wt[has_bias ? experts_per_thread : 1];
+
+ if constexpr (has_bias) {
+#pragma unroll
+ for (int i = 0; i < experts_per_thread; i++) {
+ selection_wt[i] = -INFINITY;
+ }
+#pragma unroll
+ for (int i = 0; i < n_experts; i += WARP_SIZE) {
+ const int expert = i + threadIdx.x;
+ selection_wt[i / WARP_SIZE] =
+ (n_experts % WARP_SIZE == 0 || expert < n_experts) ? wt[i / WARP_SIZE] + bias[expert] : -INFINITY;
+ }
+ }
+
+ //at this point, each thread holds either a portion of the softmax distribution
+ //or the raw logits. We do the argmax reduce over n_expert_used, each time marking
+ //the expert weight as -inf to exclude from the next iteration
+
+ float wt_sum = 0.f;
+
+ float output_weights[experts_per_thread];
+
+#pragma unroll
+ for (int i = 0; i < experts_per_thread; i++) {
+ output_weights[i] = 0.f;
+ }
+
+ for (int k = 0; k < n_expert_used; k++) {
+ float max_val = wt[0];
+ int max_expert = threadIdx.x;
+
+ if constexpr (has_bias) {
+ float max_val_s = selection_wt[0];
+
+#pragma unroll
+ for (int i = 1; i < experts_per_thread; i++) {
+ const int expert = threadIdx.x + i * WARP_SIZE;
+ if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && selection_wt[i] > max_val_s) {
+ max_val = wt[i];
+ max_val_s = selection_wt[i];
+ max_expert = expert;
+ }
+ }
+
+#pragma unroll
+ for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
+ const float val = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, WARP_SIZE);
+ const float val_s = __shfl_xor_sync(0xFFFFFFFF, max_val_s, mask, WARP_SIZE);
+ const int expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, WARP_SIZE);
+ if (val_s > max_val_s || (val_s == max_val_s && expert < max_expert)) {
+ max_val = val;
+ max_val_s = val_s;
+ max_expert = expert;
+ }
+ }
+
+ if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
+ selection_wt[max_expert / WARP_SIZE] = -INFINITY;
+ }
+ } else {
+#pragma unroll
+ for (int i = 1; i < experts_per_thread; i++) {
+ const int expert = threadIdx.x + i * WARP_SIZE;
+ if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) {
+ max_val = wt[i];
+ max_expert = expert;
+ }
+ }
+
+#pragma unroll
+ for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
+ const float val = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, WARP_SIZE);
+ const int expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, WARP_SIZE);
+ if (val > max_val || (val == max_val && expert < max_expert)) {
+ max_val = val;
+ max_expert = expert;
+ }
+ }
+
+ if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
+ wt[max_expert / WARP_SIZE] = -INFINITY;
+ }
+ }
+
+ if ((k & (WARP_SIZE - 1)) == threadIdx.x) {
+ output_weights[k / WARP_SIZE] = max_val;
+ }
+
+ if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
+ ids[k] = max_expert;
+ if (config.with_norm) {
+ wt_sum += max_val;
+ }
+ }
+ }
+
+ if (config.with_norm) {
+ wt_sum = warp_reduce_sum(wt_sum);
+ wt_sum = max(wt_sum, clamp_val);
+ const float inv_sum = 1.0f / wt_sum;
+
+ for (int i = 0; i < experts_per_thread; i++) {
+ output_weights[i] *= inv_sum;
+ }
+ }
+
+ if (config.delayed_softmax) {
+ softmax_warp_inplace<experts_per_thread, true>(output_weights, n_expert_used, threadIdx.x);
+ }
+
+#pragma unroll
+ for (int i = 0; i < experts_per_thread; i++) {
+ const int idx = i * WARP_SIZE + threadIdx.x;
+ if (idx < n_expert_used) {
+ weights[idx] = output_weights[i] * scale_val;
+ }
+ }
+}
+
+template<bool has_bias>
+static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
+ const float * logits,
+ float * weights,
+ int32_t * ids,
+ float * bias,
+ const int n_rows,
+ const int n_expert,
+ const int n_expert_used,
+ const float clamp_val,
+ const float scale_val,
+ const topk_moe_config config) {
+ GGML_ASSERT(!(config.with_norm && config.delayed_softmax) &&
+ "delayed softmax is not supported with weight normalization");
+ const int rows_per_block = 4;
+ dim3 grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1);
+ dim3 block_dims(WARP_SIZE, rows_per_block, 1);
+ cudaStream_t stream = ctx.stream();
+
+ switch (n_expert) {
+ case 1:
+ topk_moe_cuda<1, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
+ clamp_val, scale_val, config);
+ break;
+ case 2:
+ topk_moe_cuda<2, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
+ clamp_val, scale_val, config);
+ break;
+ case 4:
+ topk_moe_cuda<4, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
+ clamp_val, scale_val, config);
+ break;
+ case 8:
+ topk_moe_cuda<8, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
+ clamp_val, scale_val, config);
+ break;
+ case 16:
+ topk_moe_cuda<16, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
+ clamp_val, scale_val, config);
+ break;
+ case 32:
+ topk_moe_cuda<32, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
+ clamp_val, scale_val, config);
+ break;
+ case 64:
+ topk_moe_cuda<64, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
+ clamp_val, scale_val, config);
+ break;
+ case 128:
+ topk_moe_cuda<128, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
+ clamp_val, scale_val, config);
+ break;
+ case 256:
+ topk_moe_cuda<256, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
+ clamp_val, scale_val, config);
+ break;
+ case 512:
+ topk_moe_cuda<512, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
+ clamp_val, scale_val, config);
+ break;
+ case 576:
+ topk_moe_cuda<576, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
+ clamp_val, scale_val, config);
+ break;
+ default:
+ GGML_ASSERT(false && "fatal error");
+ break;
+ }
+}
+
+void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
+ const ggml_tensor * logits,
+ ggml_tensor * weights,
+ ggml_tensor * ids,
+ const ggml_tensor * clamp,
+ const ggml_tensor * scale,
+ const ggml_tensor * bias,
+ const ggml_cuda_topk_moe_args & args) {
+ GGML_ASSERT(logits->type == GGML_TYPE_F32);
+ GGML_ASSERT(weights->type == GGML_TYPE_F32);
+ GGML_ASSERT(ids->type == GGML_TYPE_I32);
+
+ const int n_experts = logits->ne[0];
+ const int n_rows = logits->ne[1];
+
+ const float * logits_d = (const float *) logits->data;
+ float * weights_d = (float *) weights->data;
+ int32_t * ids_d = (int32_t *) ids->data;
+ float * bias_d = bias ? (float *) bias->data : nullptr;
+
+ float scale_val = scale ? ggml_get_op_params_f32(scale, 0) : 1.0f;
+
+ GGML_ASSERT(ids->nb[1] / ggml_type_size(ids->type) == (size_t) n_experts);
+
+ const int n_expert_used = weights->ne[1];
+
+ const bool with_norm = clamp != nullptr;
+
+ float clamp_val = -INFINITY;
+ if (clamp) {
+ clamp_val = ggml_get_op_params_f32(clamp, 0);
+ }
+
+ topk_moe_config config;
+ config.use_sigmoid = args.sigmoid;
+ config.with_norm = with_norm;
+ config.delayed_softmax = args.delayed_softmax;
+
+ if (bias) {
+ launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, bias_d, n_rows, n_experts, n_expert_used, clamp_val,
+ scale_val, config);
+ } else {
+ launch_topk_moe_cuda<false>(ctx, logits_d, weights_d, ids_d, bias_d, n_rows, n_experts, n_expert_used, clamp_val,
+ scale_val, config);
+ }
+}
+
+bool ggml_cuda_should_use_topk_moe(const ggml_tensor * gating_op,
+ const ggml_tensor * weights,
+ const ggml_tensor * logits,
+ const ggml_tensor * ids) {
+ const int n_expert = ids->nb[1] / ids->nb[0];
+ if (((n_expert & (n_expert - 1)) != 0 || n_expert > 512) && n_expert != 576) {
+ return false;
+ }
+
+ if (!ggml_is_contiguous(weights) || !ggml_is_contiguous(logits)) {
+ return false;
+ }
+
+ if (gating_op->op == GGML_OP_SOFT_MAX) {
+ const ggml_tensor * softmax = gating_op;
+ float scale = 1.0f;
+ float max_bias = 0.0f;
+
+ memcpy(&scale, (const float *) softmax->op_params + 0, sizeof(float));
+ memcpy(&max_bias, (const float *) softmax->op_params + 1, sizeof(float));
+
+ if (!ggml_is_contiguous(softmax->src[0])) {
+ return false;
+ }
+
+ if (scale != 1.0f || max_bias != 0.0f) {
+ return false;
+ }
+
+ // don't fuse when masks or sinks are present
+ if (softmax->src[1] || softmax->src[2]) {
+ return false;
+ }
+ } else if (gating_op->op == GGML_OP_UNARY) {
+ ggml_unary_op op = ggml_get_unary_op(gating_op);
+
+ if (op != GGML_UNARY_OP_SIGMOID) {
+ return false;
+ }
+ }
+
+ return true;
+}