summaryrefslogtreecommitdiff
path: root/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp
diff options
context:
space:
mode:
authorMitja Felicijan <mitja.felicijan@gmail.com>2026-02-12 20:57:17 +0100
committerMitja Felicijan <mitja.felicijan@gmail.com>2026-02-12 20:57:17 +0100
commitb333b06772c89d96aacb5490d6a219fba7c09cc6 (patch)
tree211df60083a5946baa2ed61d33d8121b7e251b06 /llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp
downloadllmnpc-b333b06772c89d96aacb5490d6a219fba7c09cc6.tar.gz
Engage!
Diffstat (limited to 'llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp')
-rw-r--r--llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp51
1 files changed, 51 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp b/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp
new file mode 100644
index 0000000..ffc8608
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp
@@ -0,0 +1,51 @@
+#version 450
+
+#extension GL_EXT_control_flow_attributes : enable
+
+#include "types.glsl"
+
+layout (push_constant) uniform parameter // Push-constant inputs; every member below is read by main() through `p`.
+{
+ uint32_t ne00; // Divisor/modulus that splits a flat element index into (i01, i00); ne00 * ne01 is the number of elements scanned.
+ uint32_t ne01; // Second factor of the scanned element count (ne00 * ne01).
+ uint32_t nb00; // Multiplied by i00 when indexing data_a, i.e. a stride counted in data_a (uint) elements.
+ uint32_t nb01; // Multiplied by i01 when indexing data_a, i.e. a stride counted in data_a (uint) elements.
+ uint32_t a_offset; // Added to every data_a index before the read.
+} p;
+
+#define BLOCK_SIZE 256 // Invocations per workgroup; also the scan-loop stride in main() and the length of vals[].
+
+layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; // 1-D workgroup of BLOCK_SIZE invocations.
+
+layout (binding = 0) readonly buffer A {uint data_a[];}; // Input: values that main() compares against gl_WorkGroupID.x.
+layout (binding = 1) writeonly buffer D {uint data_d[];}; // Output: main() stores one count per workgroup at index gl_WorkGroupID.x.
+
+shared uint vals[BLOCK_SIZE]; // One partial count per invocation; main() reduces it in place so the total ends up in vals[0].
+
+void main() {
+    // Each workgroup counts occurrences of its own index (the expert id) across the ne00 * ne01 input elements.
+    const uint lane   = gl_LocalInvocationID.x;
+    const uint expert = gl_WorkGroupID.x;
+    const uint total  = p.ne01 * p.ne00;
+
+    uint matches = 0;
+    for (uint flat = lane; flat < total; flat += BLOCK_SIZE) {
+        const uint row = flat / p.ne00;
+        const uint col = flat % p.ne00;
+        const bool hit = data_a[p.a_offset + row * p.nb01 + col * p.nb00] == expert;
+        matches += hit ? 1u : 0u;
+    }
+
+    // Publish the per-invocation tallies, then add the upper half onto the lower half until one value remains.
+    vals[lane] = matches;
+    barrier();
+    [[unroll]] for (uint half_width = BLOCK_SIZE >> 1; half_width != 0u; half_width >>= 1) {
+        if (lane < half_width) {
+            vals[lane] = vals[lane] + vals[lane + half_width];
+        }
+        barrier(); // every invocation reaches this, so the next pass sees the completed sums
+    }
+
+    // vals[0] now holds the workgroup-wide count; a single invocation writes it out.
+    if (lane == 0u) { data_d[expert] = vals[0]; }
+}