diff options
| author | Mitja Felicijan <mitja.felicijan@gmail.com> | 2026-02-12 20:57:17 +0100 |
|---|---|---|
| committer | Mitja Felicijan <mitja.felicijan@gmail.com> | 2026-02-12 20:57:17 +0100 |
| commit | b333b06772c89d96aacb5490d6a219fba7c09cc6 (patch) | |
| tree | 211df60083a5946baa2ed61d33d8121b7e251b06 /llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp | |
| download | llmnpc-b333b06772c89d96aacb5490d6a219fba7c09cc6.tar.gz | |
Engage!
Diffstat (limited to 'llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp')
| -rw-r--r-- | llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp | 51 |
1 files changed, 51 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp b/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp new file mode 100644 index 0000000..ffc8608 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp @@ -0,0 +1,51 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable + +#include "types.glsl" + +layout (push_constant) uniform parameter +{ + uint32_t ne00; + uint32_t ne01; + uint32_t nb00; + uint32_t nb01; + uint32_t a_offset; +} p; + +#define BLOCK_SIZE 256 + +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {uint data_a[];}; +layout (binding = 1) writeonly buffer D {uint data_d[];}; + +shared uint vals[BLOCK_SIZE]; + +void main() { + const uint expert_id = gl_WorkGroupID.x; + const uint num_elements = p.ne00 * p.ne01; + const uint tid = gl_LocalInvocationID.x; + + uint count = 0; + for (uint idx = tid; idx < num_elements; idx += BLOCK_SIZE) { + const uint i01 = idx / p.ne00; + const uint i00 = idx % p.ne00; + const uint a = data_a[p.a_offset + i01 * p.nb01 + i00 * p.nb00]; + + count += uint(a == expert_id); + } + + vals[tid] = count; + barrier(); + [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + vals[tid] += vals[tid + s]; + } + barrier(); + } + + if (tid == 0) { + data_d[expert_id] = vals[0]; + } +} |
