diff options
| author | Mitja Felicijan <mitja.felicijan@gmail.com> | 2026-02-12 20:57:17 +0100 |
|---|---|---|
| committer | Mitja Felicijan <mitja.felicijan@gmail.com> | 2026-02-12 20:57:17 +0100 |
| commit | b333b06772c89d96aacb5490d6a219fba7c09cc6 (patch) | |
| tree | 211df60083a5946baa2ed61d33d8121b7e251b06 /llama.cpp/ggml/src/ggml-cuda/topk-moe.cu | |
| download | llmnpc-b333b06772c89d96aacb5490d6a219fba7c09cc6.tar.gz | |
Engage!
Diffstat (limited to 'llama.cpp/ggml/src/ggml-cuda/topk-moe.cu')
| -rw-r--r-- | llama.cpp/ggml/src/ggml-cuda/topk-moe.cu | 403 |
1 files changed, 403 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-cuda/topk-moe.cu b/llama.cpp/ggml/src/ggml-cuda/topk-moe.cu new file mode 100644 index 0000000..08a8899 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/topk-moe.cu @@ -0,0 +1,403 @@ +#include "ggml-cuda/common.cuh" +#include "ggml.h" +#include "topk-moe.cuh" + +#include <cmath> +#include <initializer_list> + +// Kernel config struct - passed by value to CUDA kernel +struct topk_moe_config { + bool use_sigmoid; + bool with_norm; + bool delayed_softmax; +}; + +// Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path. +template <int experts_per_thread, bool use_limit> +__device__ void softmax_warp_inplace(float (&vals)[experts_per_thread], const int limit, const int lane) { + float max_val = -INFINITY; + +#pragma unroll + for (int i = 0; i < experts_per_thread; i++) { + const int idx = lane + i * WARP_SIZE; + const bool active = !use_limit || (idx < limit); + if (active) { + max_val = max(max_val, vals[i]); + } + } + + max_val = warp_reduce_max(max_val); + + float sum = 0.f; + +#pragma unroll + for (int i = 0; i < experts_per_thread; i++) { + const int idx = lane + i * WARP_SIZE; + const bool active = !use_limit || (idx < limit); + if (active) { + const float val = expf(vals[i] - max_val); + vals[i] = val; + sum += val; + } else { + vals[i] = 0.f; + } + } + + sum = warp_reduce_sum(sum); + + const float inv_sum = 1.0f / sum; + +#pragma unroll + for (int i = 0; i < experts_per_thread; i++) { + const int idx = lane + i * WARP_SIZE; + const bool active = !use_limit || (idx < limit); + if (active) { + vals[i] *= inv_sum; + } + } +} + +template <int experts_per_thread, bool use_limit> +__device__ void sigmoid_warp_inplace(float (&vals)[experts_per_thread], const int limit, const int lane) { +#pragma unroll + for (int i = 0; i < experts_per_thread; i++) { + const int idx = lane + i * WARP_SIZE; + const bool active = !use_limit || (idx < limit); + vals[i] = active ? 1.f / (1.f + expf(-vals[i])) : -INFINITY; + } +} + +/* + This kernel does the following: + 1. optionally softmax over the logits per token [n_experts, n_tokens] + 2. argmax reduce over the top-k (n_experts_used) logits + 3. write weights + ids to global memory + 4. optionally normalize the weights or apply softmax over the selected logits + + It is intended as fusion of softmax->top-k->get_rows pipeline for MoE models +*/ +template <int n_experts, bool has_bias> +__launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits, + float * weights, + int32_t * ids, + float * bias, + const int n_rows, + const int n_expert_used, + const float clamp_val, + const float scale_val, + const topk_moe_config config) { + const int row = blockIdx.x * blockDim.y + threadIdx.y; + if (row >= n_rows) { + return; + } + + logits += n_experts * row; + weights += n_expert_used * row; + ids += n_experts * row; + + constexpr int experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1; + + float wt[experts_per_thread]; + + // Initialize all slots to -INFINITY +#pragma unroll + for (int i = 0; i < experts_per_thread; i++) { + wt[i] = -INFINITY; + } + +#pragma unroll + for (int i = 0; i < n_experts; i += WARP_SIZE) { + const int expert = i + threadIdx.x; + wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[expert] : -INFINITY; + } + + if (!config.delayed_softmax) { + if (config.use_sigmoid) { + sigmoid_warp_inplace<experts_per_thread, false>(wt, n_experts, threadIdx.x); + } else { + softmax_warp_inplace<experts_per_thread, false>(wt, n_experts, threadIdx.x); + } + } + + // selection_wt is only needed when bias is present (selection uses wt + bias) + // when no bias, we use wt directly for both selection and weight values + float selection_wt[has_bias ? experts_per_thread : 1]; + + if constexpr (has_bias) { +#pragma unroll + for (int i = 0; i < experts_per_thread; i++) { + selection_wt[i] = -INFINITY; + } +#pragma unroll + for (int i = 0; i < n_experts; i += WARP_SIZE) { + const int expert = i + threadIdx.x; + selection_wt[i / WARP_SIZE] = + (n_experts % WARP_SIZE == 0 || expert < n_experts) ? wt[i / WARP_SIZE] + bias[expert] : -INFINITY; + } + } + + //at this point, each thread holds either a portion of the softmax distribution + //or the raw logits. We do the argmax reduce over n_expert_used, each time marking + //the expert weight as -inf to exclude from the next iteration + + float wt_sum = 0.f; + + float output_weights[experts_per_thread]; + +#pragma unroll + for (int i = 0; i < experts_per_thread; i++) { + output_weights[i] = 0.f; + } + + for (int k = 0; k < n_expert_used; k++) { + float max_val = wt[0]; + int max_expert = threadIdx.x; + + if constexpr (has_bias) { + float max_val_s = selection_wt[0]; + +#pragma unroll + for (int i = 1; i < experts_per_thread; i++) { + const int expert = threadIdx.x + i * WARP_SIZE; + if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && selection_wt[i] > max_val_s) { + max_val = wt[i]; + max_val_s = selection_wt[i]; + max_expert = expert; + } + } + +#pragma unroll + for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) { + const float val = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, WARP_SIZE); + const float val_s = __shfl_xor_sync(0xFFFFFFFF, max_val_s, mask, WARP_SIZE); + const int expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, WARP_SIZE); + if (val_s > max_val_s || (val_s == max_val_s && expert < max_expert)) { + max_val = val; + max_val_s = val_s; + max_expert = expert; + } + } + + if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) { + selection_wt[max_expert / WARP_SIZE] = -INFINITY; + } + } else { +#pragma unroll + for (int i = 1; i < experts_per_thread; i++) { + const int expert = threadIdx.x + i * WARP_SIZE; + if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) { + max_val = wt[i]; + max_expert = expert; + } + } + +#pragma unroll + for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) { + const float val = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, WARP_SIZE); + const int expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, WARP_SIZE); + if (val > max_val || (val == max_val && expert < max_expert)) { + max_val = val; + max_expert = expert; + } + } + + if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) { + wt[max_expert / WARP_SIZE] = -INFINITY; + } + } + + if ((k & (WARP_SIZE - 1)) == threadIdx.x) { + output_weights[k / WARP_SIZE] = max_val; + } + + if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) { + ids[k] = max_expert; + if (config.with_norm) { + wt_sum += max_val; + } + } + } + + if (config.with_norm) { + wt_sum = warp_reduce_sum(wt_sum); + wt_sum = max(wt_sum, clamp_val); + const float inv_sum = 1.0f / wt_sum; + + for (int i = 0; i < experts_per_thread; i++) { + output_weights[i] *= inv_sum; + } + } + + if (config.delayed_softmax) { + softmax_warp_inplace<experts_per_thread, true>(output_weights, n_expert_used, threadIdx.x); + } + +#pragma unroll + for (int i = 0; i < experts_per_thread; i++) { + const int idx = i * WARP_SIZE + threadIdx.x; + if (idx < n_expert_used) { + weights[idx] = output_weights[i] * scale_val; + } + } +} + +template<bool has_bias> +static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx, + const float * logits, + float * weights, + int32_t * ids, + float * bias, + const int n_rows, + const int n_expert, + const int n_expert_used, + const float clamp_val, + const float scale_val, + const topk_moe_config config) { + GGML_ASSERT(!(config.with_norm && config.delayed_softmax) && + "delayed softmax is not supported with weight normalization"); + const int rows_per_block = 4; + dim3 grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1); + dim3 block_dims(WARP_SIZE, rows_per_block, 1); + cudaStream_t stream = ctx.stream(); + + switch (n_expert) { + case 1: + topk_moe_cuda<1, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used, + clamp_val, scale_val, config); + break; + case 2: + topk_moe_cuda<2, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used, + clamp_val, scale_val, config); + break; + case 4: + topk_moe_cuda<4, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used, + clamp_val, scale_val, config); + break; + case 8: + topk_moe_cuda<8, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used, + clamp_val, scale_val, config); + break; + case 16: + topk_moe_cuda<16, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used, + clamp_val, scale_val, config); + break; + case 32: + topk_moe_cuda<32, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used, + clamp_val, scale_val, config); + break; + case 64: + topk_moe_cuda<64, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used, + clamp_val, scale_val, config); + break; + case 128: + topk_moe_cuda<128, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used, + clamp_val, scale_val, config); + break; + case 256: + topk_moe_cuda<256, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used, + clamp_val, scale_val, config); + break; + case 512: + topk_moe_cuda<512, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used, + clamp_val, scale_val, config); + break; + case 576: + topk_moe_cuda<576, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used, + clamp_val, scale_val, config); + break; + default: + GGML_ASSERT(false && "fatal error"); + break; + } +} + +void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx, + const ggml_tensor * logits, + ggml_tensor * weights, + ggml_tensor * ids, + const ggml_tensor * clamp, + const ggml_tensor * scale, + const ggml_tensor * bias, + const ggml_cuda_topk_moe_args & args) { + GGML_ASSERT(logits->type == GGML_TYPE_F32); + GGML_ASSERT(weights->type == GGML_TYPE_F32); + GGML_ASSERT(ids->type == GGML_TYPE_I32); + + const int n_experts = logits->ne[0]; + const int n_rows = logits->ne[1]; + + const float * logits_d = (const float *) logits->data; + float * weights_d = (float *) weights->data; + int32_t * ids_d = (int32_t *) ids->data; + float * bias_d = bias ? (float *) bias->data : nullptr; + + float scale_val = scale ? ggml_get_op_params_f32(scale, 0) : 1.0f; + + GGML_ASSERT(ids->nb[1] / ggml_type_size(ids->type) == (size_t) n_experts); + + const int n_expert_used = weights->ne[1]; + + const bool with_norm = clamp != nullptr; + + float clamp_val = -INFINITY; + if (clamp) { + clamp_val = ggml_get_op_params_f32(clamp, 0); + } + + topk_moe_config config; + config.use_sigmoid = args.sigmoid; + config.with_norm = with_norm; + config.delayed_softmax = args.delayed_softmax; + + if (bias) { + launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, bias_d, n_rows, n_experts, n_expert_used, clamp_val, + scale_val, config); + } else { + launch_topk_moe_cuda<false>(ctx, logits_d, weights_d, ids_d, bias_d, n_rows, n_experts, n_expert_used, clamp_val, + scale_val, config); + } +} + +bool ggml_cuda_should_use_topk_moe(const ggml_tensor * gating_op, + const ggml_tensor * weights, + const ggml_tensor * logits, + const ggml_tensor * ids) { + const int n_expert = ids->nb[1] / ids->nb[0]; + if (((n_expert & (n_expert - 1)) != 0 || n_expert > 512) && n_expert != 576) { + return false; + } + + if (!ggml_is_contiguous(weights) || !ggml_is_contiguous(logits)) { + return false; + } + + if (gating_op->op == GGML_OP_SOFT_MAX) { + const ggml_tensor * softmax = gating_op; + float scale = 1.0f; + float max_bias = 0.0f; + + memcpy(&scale, (const float *) softmax->op_params + 0, sizeof(float)); + memcpy(&max_bias, (const float *) softmax->op_params + 1, sizeof(float)); + + if (!ggml_is_contiguous(softmax->src[0])) { + return false; + } + + if (scale != 1.0f || max_bias != 0.0f) { + return false; + } + + // don't fuse when masks or sinks are present + if (softmax->src[1] || softmax->src[2]) { + return false; + } + } else if (gating_op->op == GGML_OP_UNARY) { + ggml_unary_op op = ggml_get_unary_op(gating_op); + + if (op != GGML_UNARY_OP_SIGMOID) { + return false; + } + } + + return true; +} |
