1#version 450
2
3#extension GL_EXT_control_flow_attributes : enable
4
5#include "types.glsl"
6
// Push constants for this dispatch.
// main() reads data_a as a 2-D view of ne00 x ne01 elements, where element
// (i00, i01) is data_a[a_offset + i01 * nb01 + i00 * nb00]. Because these
// values index data_a directly, strides and offset are in data_a elements
// (uint), not bytes. NOTE(review): confirm the host converts byte strides.
layout (push_constant) uniform parameter
{
    uint32_t ne00;     // size of the inner (fastest-varying) dimension
    uint32_t ne01;     // size of the outer dimension
    uint32_t nb00;     // element stride between consecutive i00
    uint32_t nb01;     // element stride between consecutive i01
    uint32_t a_offset; // element index in data_a where the view starts
} p;
15
// Invocations per workgroup; also the length of the shared reduction buffer.
// The reduction in main() halves its active range on every step, so this
// value must remain a power of two.
#define BLOCK_SIZE 256

layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;

// Input values; main() compares each one against gl_WorkGroupID.x.
layout (binding = 0) readonly buffer A {uint data_a[];};
// Output: main() writes one count per workgroup, at data_d[gl_WorkGroupID.x].
layout (binding = 1) writeonly buffer D {uint data_d[];};

// Per-invocation partial counts, summed by the reduction in main().
shared uint vals[BLOCK_SIZE];
24
// Each workgroup handles one expert index (gl_WorkGroupID.x). It counts how
// many elements of the ne00 x ne01 view of data_a equal that index and stores
// the total in data_d[index].
void main() {
    const uint lane   = gl_LocalInvocationID.x;
    const uint target = gl_WorkGroupID.x;
    const uint total  = p.ne00 * p.ne01;

    // Phase 1: every invocation starts at its own lane index, advances through
    // the flattened element range in steps of BLOCK_SIZE, and keeps a private
    // tally of matches.
    uint hits = 0u;
    for (uint pos = lane; pos < total; pos += BLOCK_SIZE) {
        const uint row = pos / p.ne00;
        const uint col = pos - row * p.ne00; // same as pos % p.ne00
        const uint elem_index = p.a_offset + row * p.nb01 + col * p.nb00;
        if (data_a[elem_index] == target) {
            hits++;
        }
    }

    // Phase 2: sum the BLOCK_SIZE tallies in shared memory by repeatedly
    // folding the upper half of the active range onto the lower half.
    vals[lane] = hits;
    barrier();
    [[unroll]] for (uint stride = BLOCK_SIZE >> 1; stride != 0u; stride >>= 1) {
        if (lane < stride) {
            vals[lane] += vals[lane + stride];
        }
        // Loop bounds are identical for all invocations, so every invocation
        // reaches this barrier on every step.
        barrier();
    }

    // After the last step vals[0] holds the workgroup total; one invocation
    // writes it out.
    if (lane == 0u) {
        data_d[target] = vals[0];
    }
}