summaryrefslogtreecommitdiff
path: root/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp
diff options
context:
space:
mode:
Diffstat (limited to 'llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp')
-rw-r--r--llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp213
1 files changed, 213 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp b/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp
new file mode 100644
index 0000000..ef2f202
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp
@@ -0,0 +1,213 @@
+#version 450
+
+#extension GL_EXT_control_flow_attributes : require
+#extension GL_KHR_shader_subgroup_basic : enable
+#extension GL_KHR_shader_subgroup_arithmetic : enable
+#extension GL_KHR_shader_subgroup_shuffle : enable
+
+#include "types.glsl"
+
+#define GATING_FUNC_SOFTMAX 0
+#define GATING_FUNC_SIGMOID 1
+#define GATING_FUNC_SOFTMAX_WEIGHT 2
+
+layout (push_constant) uniform parameter
+{
+ uint n_rows;
+ uint n_experts_push;
+ uint n_expert_used;
+ float clamp_min;
+ float clamp_max;
+ uint gating_func;
+ uint has_bias;
+ uint with_norm;
+ float output_scale;
+ float output_bias;
+};
+
+layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in;
+
+layout(constant_id = 0) const uint WARP_SIZE = 32;
+layout(constant_id = 1) const uint n_experts_spec = 512;
+layout(constant_id = 2) const bool nexperts_use_push = false;
+
+uint n_experts = nexperts_use_push ? n_experts_push : n_experts_spec;
+
+#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
+
+const uint experts_per_thread = CEIL_DIV(n_experts_spec, WARP_SIZE);
+
+layout (binding = 0, std430) readonly buffer Logits {float logits[];};
+layout (binding = 1, std430) readonly buffer BiasProbs {float bias[];};
+layout (binding = 2, std430) writeonly buffer Weights {float weights[];};
+layout (binding = 3, std430) writeonly buffer Ids {uint ids[];};
+
+const float INFINITY = 1.0 / 0.0;
+
+// Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path.
+void softmax_warp_inplace(inout float vals[experts_per_thread], const uint limit, const uint lane, const bool use_limit) {
+ float max_val = -INFINITY;
+
+ [[unroll]]
+ for (int i = 0; i < experts_per_thread; i++) {
+ const uint idx = lane + i * WARP_SIZE;
+ const bool is_active = !use_limit || (idx < limit);
+ if (is_active) {
+ max_val = max(max_val, vals[i]);
+ }
+ }
+
+ max_val = subgroupMax(max_val);
+
+ float sum = 0.f;
+
+ [[unroll]]
+ for (int i = 0; i < experts_per_thread; i++) {
+ const uint idx = lane + i * WARP_SIZE;
+ const bool is_active = !use_limit || (idx < limit);
+ if (is_active) {
+ const float val = exp(vals[i] - max_val);
+ vals[i] = val;
+ sum += val;
+ } else {
+ vals[i] = 0.f;
+ }
+ }
+
+ sum = subgroupAdd(sum);
+
+ const float inv_sum = 1.0f / sum;
+
+ [[unroll]]
+ for (int i = 0; i < experts_per_thread; i++) {
+ const uint idx = lane + i * WARP_SIZE;
+ const bool is_active = !use_limit || (idx < limit);
+ if (is_active) {
+ vals[i] *= inv_sum;
+ }
+ }
+}
+
+void main() {
+ const uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_SubgroupID;
+ if (row >= n_rows) {
+ return;
+ }
+
+ const uint logits_offset = n_experts * row;
+ const uint bias_offset = 0; // 1D
+ const uint weights_offset = n_expert_used * row;
+ const uint ids_offset = n_experts * row;
+ const uint lane = gl_SubgroupInvocationID;
+
+ float probs[experts_per_thread];
+ [[unroll]]
+ for (int i = 0; i < experts_per_thread; i++) {
+ probs[i] = -INFINITY;
+ }
+
+ [[unroll]]
+ for (uint i = 0; i < n_experts; i += WARP_SIZE) {
+ const uint expert = i + lane;
+ probs[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[logits_offset + expert] : -INFINITY;
+ }
+
+ if (gating_func == GATING_FUNC_SOFTMAX) {
+ softmax_warp_inplace(probs, n_experts, lane, nexperts_use_push);
+ } else if (gating_func == GATING_FUNC_SIGMOID) {
+ [[unroll]]
+ for (uint i = 0; i < n_experts; i += WARP_SIZE) {
+ const uint expert = i + lane;
+ probs[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? 1.f / (1.f + exp(-probs[i / WARP_SIZE])) : -INFINITY;
+ }
+ }
+
+ float selection_probs[experts_per_thread];
+ if (has_bias != 0) {
+ [[unroll]]
+ for (uint i = 0; i < n_experts; i += WARP_SIZE) {
+ const uint expert = i + lane;
+ selection_probs[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? probs[i / WARP_SIZE] + bias[bias_offset + expert] : -INFINITY;
+ }
+ } else {
+ [[unroll]]
+ for (int i = 0; i < experts_per_thread; i++) {
+ selection_probs[i] = probs[i];
+ }
+ }
+
+ // at this point, each thread holds a portion of softmax,
+ // we do the argmax reduce over n_expert_used, each time marking
+ // the expert weight as -inf to exclude from the next iteration
+
+ float wt_sum = 0.f;
+
+ float output_weights[experts_per_thread];
+
+ [[unroll]]
+ for (int i = 0; i < experts_per_thread; i++) {
+ output_weights[i] = 0.f;
+ }
+
+ for (int k = 0; k < n_expert_used; k++) {
+ float max_val = probs[0];
+ float max_val_s = selection_probs[0];
+ uint max_expert = lane;
+
+ [[unroll]]
+ for (uint i = WARP_SIZE; i < n_experts; i += WARP_SIZE) {
+ const uint expert = i + lane;
+ if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && selection_probs[i / WARP_SIZE] > max_val_s) {
+ max_val = probs[i / WARP_SIZE];
+ max_val_s = selection_probs[i / WARP_SIZE];
+ max_expert = expert;
+ }
+ }
+
+ [[unroll]]
+ for (uint mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
+ const float val = subgroupShuffleXor(max_val, mask);
+ const float val_s = subgroupShuffleXor(max_val_s, mask);
+ const uint expert = subgroupShuffleXor(max_expert, mask);
+ if (val_s > max_val_s || (val_s == max_val_s && expert < max_expert)) {
+ max_val = val;
+ max_val_s = val_s;
+ max_expert = expert;
+ }
+ }
+
+ if ((k & (WARP_SIZE - 1)) == lane) {
+ output_weights[k / WARP_SIZE] = max_val;
+ }
+
+ if ((max_expert & (WARP_SIZE - 1)) == lane) {
+ selection_probs[max_expert / WARP_SIZE] = -INFINITY;
+
+ ids[ids_offset + k] = max_expert;
+ wt_sum += max_val;
+ }
+ }
+
+ if (with_norm != 0) {
+ wt_sum = subgroupAdd(wt_sum);
+ wt_sum = clamp(wt_sum, clamp_min, clamp_max);
+ const float inv_sum = 1.0f / wt_sum;
+
+ [[unroll]]
+ for (uint i = 0; i < experts_per_thread; ++i) {
+ output_weights[i] *= inv_sum;
+ }
+ }
+
+ if (gating_func == GATING_FUNC_SOFTMAX_WEIGHT) {
+ softmax_warp_inplace(output_weights, n_expert_used, lane, true);
+ }
+
+ [[unroll]]
+ for (uint i = 0; i < experts_per_thread; ++i) {
+ uint idx = i * WARP_SIZE + lane;
+ if (idx < n_expert_used) {
+ weights[weights_offset + idx] = output_scale * output_weights[i] + output_bias;
+ }
+ }
+}