diff options
| author | Mitja Felicijan <mitja.felicijan@gmail.com> | 2026-02-12 20:57:17 +0100 |
|---|---|---|
| committer | Mitja Felicijan <mitja.felicijan@gmail.com> | 2026-02-12 20:57:17 +0100 |
| commit | b333b06772c89d96aacb5490d6a219fba7c09cc6 (patch) | |
| tree | 211df60083a5946baa2ed61d33d8121b7e251b06 /llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp | |
| download | llmnpc-b333b06772c89d96aacb5490d6a219fba7c09cc6.tar.gz | |
Engage!
Diffstat (limited to 'llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp')
| -rw-r--r-- | llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp | 213 |
1 file changed, 213 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp b/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp
new file mode 100644
index 0000000..ef2f202
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp
@@ -0,0 +1,213 @@
+#version 450
+
+#extension GL_EXT_control_flow_attributes : require
+#extension GL_KHR_shader_subgroup_basic : enable
+#extension GL_KHR_shader_subgroup_arithmetic : enable
+#extension GL_KHR_shader_subgroup_shuffle : enable
+
+#include "types.glsl"
+
+#define GATING_FUNC_SOFTMAX 0
+#define GATING_FUNC_SIGMOID 1
+#define GATING_FUNC_SOFTMAX_WEIGHT 2
+
+layout (push_constant) uniform parameter
+{
+    uint n_rows;
+    uint n_experts_push;
+    uint n_expert_used;
+    float clamp_min;
+    float clamp_max;
+    uint gating_func;
+    uint has_bias;
+    uint with_norm;
+    float output_scale;
+    float output_bias;
+};
+
+layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in;
+
+layout(constant_id = 0) const uint WARP_SIZE = 32;
+layout(constant_id = 1) const uint n_experts_spec = 512;
+layout(constant_id = 2) const bool nexperts_use_push = false;
+
+uint n_experts = nexperts_use_push ? n_experts_push : n_experts_spec;
+
+#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
+
+const uint experts_per_thread = CEIL_DIV(n_experts_spec, WARP_SIZE);
+
+layout (binding = 0, std430) readonly buffer Logits {float logits[];};
+layout (binding = 1, std430) readonly buffer BiasProbs {float bias[];};
+layout (binding = 2, std430) writeonly buffer Weights {float weights[];};
+layout (binding = 3, std430) writeonly buffer Ids {uint ids[];};
+
+const float INFINITY = 1.0 / 0.0;
+
+// Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path.
+void softmax_warp_inplace(inout float vals[experts_per_thread], const uint limit, const uint lane, const bool use_limit) {
+    float max_val = -INFINITY;
+
+    [[unroll]]
+    for (int i = 0; i < experts_per_thread; i++) {
+        const uint idx = lane + i * WARP_SIZE;
+        const bool is_active = !use_limit || (idx < limit);
+        if (is_active) {
+            max_val = max(max_val, vals[i]);
+        }
+    }
+
+    max_val = subgroupMax(max_val);
+
+    float sum = 0.f;
+
+    [[unroll]]
+    for (int i = 0; i < experts_per_thread; i++) {
+        const uint idx = lane + i * WARP_SIZE;
+        const bool is_active = !use_limit || (idx < limit);
+        if (is_active) {
+            const float val = exp(vals[i] - max_val);
+            vals[i] = val;
+            sum += val;
+        } else {
+            vals[i] = 0.f;
+        }
+    }
+
+    sum = subgroupAdd(sum);
+
+    const float inv_sum = 1.0f / sum;
+
+    [[unroll]]
+    for (int i = 0; i < experts_per_thread; i++) {
+        const uint idx = lane + i * WARP_SIZE;
+        const bool is_active = !use_limit || (idx < limit);
+        if (is_active) {
+            vals[i] *= inv_sum;
+        }
+    }
+}
+
+void main() {
+    const uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_SubgroupID;
+    if (row >= n_rows) {
+        return;
+    }
+
+    const uint logits_offset = n_experts * row;
+    const uint bias_offset = 0; // 1D
+    const uint weights_offset = n_expert_used * row;
+    const uint ids_offset = n_experts * row;
+    const uint lane = gl_SubgroupInvocationID;
+
+    float probs[experts_per_thread];
+    [[unroll]]
+    for (int i = 0; i < experts_per_thread; i++) {
+        probs[i] = -INFINITY;
+    }
+
+    [[unroll]]
+    for (uint i = 0; i < n_experts; i += WARP_SIZE) {
+        const uint expert = i + lane;
+        probs[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[logits_offset + expert] : -INFINITY;
+    }
+
+    if (gating_func == GATING_FUNC_SOFTMAX) {
+        softmax_warp_inplace(probs, n_experts, lane, nexperts_use_push);
+    } else if (gating_func == GATING_FUNC_SIGMOID) {
+        [[unroll]]
+        for (uint i = 0; i < n_experts; i += WARP_SIZE) {
+            const uint expert = i + lane;
+            probs[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? 1.f / (1.f + exp(-probs[i / WARP_SIZE])) : -INFINITY;
+        }
+    }
+
+    float selection_probs[experts_per_thread];
+    if (has_bias != 0) {
+        [[unroll]]
+        for (uint i = 0; i < n_experts; i += WARP_SIZE) {
+            const uint expert = i + lane;
+            selection_probs[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? probs[i / WARP_SIZE] + bias[bias_offset + expert] : -INFINITY;
+        }
+    } else {
+        [[unroll]]
+        for (int i = 0; i < experts_per_thread; i++) {
+            selection_probs[i] = probs[i];
+        }
+    }
+
+    // at this point, each thread holds a portion of softmax,
+    // we do the argmax reduce over n_expert_used, each time marking
+    // the expert weight as -inf to exclude from the next iteration
+
+    float wt_sum = 0.f;
+
+    float output_weights[experts_per_thread];
+
+    [[unroll]]
+    for (int i = 0; i < experts_per_thread; i++) {
+        output_weights[i] = 0.f;
+    }
+
+    for (int k = 0; k < n_expert_used; k++) {
+        float max_val = probs[0];
+        float max_val_s = selection_probs[0];
+        uint max_expert = lane;
+
+        [[unroll]]
+        for (uint i = WARP_SIZE; i < n_experts; i += WARP_SIZE) {
+            const uint expert = i + lane;
+            if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && selection_probs[i / WARP_SIZE] > max_val_s) {
+                max_val = probs[i / WARP_SIZE];
+                max_val_s = selection_probs[i / WARP_SIZE];
+                max_expert = expert;
+            }
+        }
+
+        [[unroll]]
+        for (uint mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
+            const float val = subgroupShuffleXor(max_val, mask);
+            const float val_s = subgroupShuffleXor(max_val_s, mask);
+            const uint expert = subgroupShuffleXor(max_expert, mask);
+            if (val_s > max_val_s || (val_s == max_val_s && expert < max_expert)) {
+                max_val = val;
+                max_val_s = val_s;
+                max_expert = expert;
+            }
+        }
+
+        if ((k & (WARP_SIZE - 1)) == lane) {
+            output_weights[k / WARP_SIZE] = max_val;
+        }
+
+        if ((max_expert & (WARP_SIZE - 1)) == lane) {
+            selection_probs[max_expert / WARP_SIZE] = -INFINITY;
+
+            ids[ids_offset + k] = max_expert;
+            wt_sum += max_val;
+        }
+    }
+
+    if (with_norm != 0) {
+        wt_sum = subgroupAdd(wt_sum);
+        wt_sum = clamp(wt_sum, clamp_min, clamp_max);
+        const float inv_sum = 1.0f / wt_sum;
+
+        [[unroll]]
+        for (uint i = 0; i < experts_per_thread; ++i) {
+            output_weights[i] *= inv_sum;
+        }
+    }
+
+    if (gating_func == GATING_FUNC_SOFTMAX_WEIGHT) {
+        softmax_warp_inplace(output_weights, n_expert_used, lane, true);
+    }
+
+    [[unroll]]
+    for (uint i = 0; i < experts_per_thread; ++i) {
+        uint idx = i * WARP_SIZE + lane;
+        if (idx < n_expert_used) {
+            weights[weights_offset + idx] = output_scale * output_weights[i] + output_bias;
+        }
+    }
+}
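For readers skimming the patch, the per-row math that the shader parallelizes across a subgroup can be sketched sequentially on the CPU. The following is a minimal, hypothetical C++ sketch and is not part of the patch: the helper name `topk_moe_ref` and its signature are made up for illustration, and the delayed `GATING_FUNC_SOFTMAX_WEIGHT` path (softmax applied to the selected weights after top-k) is omitted. It covers softmax or sigmoid gating, the optional per-expert selection bias (which influences which experts are picked but not the stored weight), iterative argmax over `n_expert_used`, optional renormalization with clamping, and the final affine output transform.

```cpp
// Hypothetical CPU reference for one row of top-k MoE gating, mirroring the
// steps in topk_moe.comp: gating function, optional selection bias, iterative
// argmax over n_expert_used, optional renormalization, affine output.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

void topk_moe_ref(const std::vector<float> & logits,     // n_experts gate logits for one row
                  const std::vector<float> & bias,       // n_experts selection bias (may be empty)
                  int n_expert_used, bool use_sigmoid, bool with_norm,
                  float clamp_min, float clamp_max,
                  float output_scale, float output_bias,
                  std::vector<float> & weights,          // out: n_expert_used weights
                  std::vector<int>   & ids) {            // out: n_expert_used expert indices
    const int n_experts = (int) logits.size();
    std::vector<float> probs(n_experts);

    if (use_sigmoid) {
        for (int e = 0; e < n_experts; ++e) probs[e] = 1.0f / (1.0f + std::exp(-logits[e]));
    } else {
        // numerically stable softmax over the expert logits
        const float mx = *std::max_element(logits.begin(), logits.end());
        float sum = 0.0f;
        for (int e = 0; e < n_experts; ++e) { probs[e] = std::exp(logits[e] - mx); sum += probs[e]; }
        for (int e = 0; e < n_experts; ++e) probs[e] /= sum;
    }

    // selection scores may include a per-expert bias; the stored weight does not
    std::vector<float> sel = probs;
    if (!bias.empty()) {
        for (int e = 0; e < n_experts; ++e) sel[e] += bias[e];
    }

    weights.assign(n_expert_used, 0.0f);
    ids.assign(n_expert_used, -1);

    float wt_sum = 0.0f;
    for (int k = 0; k < n_expert_used; ++k) {
        // argmax over the remaining experts (ties resolve to the lowest index)
        const int best = (int) (std::max_element(sel.begin(), sel.end()) - sel.begin());
        ids[k]     = best;
        weights[k] = probs[best];
        wt_sum    += probs[best];
        sel[best]  = -INFINITY;   // exclude this expert from the next iteration
    }

    if (with_norm) {
        // renormalize the selected weights, clamping the denominator
        const float denom = std::min(std::max(wt_sum, clamp_min), clamp_max);
        for (float & w : weights) w /= denom;
    }
    for (float & w : weights) w = output_scale * w + output_bias;
}

int main() {
    std::vector<float> weights;
    std::vector<int> ids;
    topk_moe_ref({0.1f, 2.0f, -1.0f, 0.5f}, {}, /*n_expert_used*/ 2,
                 /*use_sigmoid*/ false, /*with_norm*/ true,
                 1e-6f, 1e6f, /*output_scale*/ 1.0f, /*output_bias*/ 0.0f, weights, ids);
    for (size_t i = 0; i < ids.size(); ++i) printf("expert %d -> %f\n", ids[i], weights[i]);
}
```

In the shader itself the same selection runs warp-wide: each lane holds `experts_per_thread` candidates, the per-lane argmax is combined with a `subgroupShuffleXor` butterfly reduction, and the winning lane writes the chosen id and marks that expert's selection score as `-INFINITY` before the next of the `n_expert_used` iterations.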
