1#version 450
 2
 3#extension GL_EXT_control_flow_attributes : enable
 4
 5#include "types.glsl"
 6
 7layout (push_constant) uniform parameter
 8{
 9    uint32_t ne00;
10    uint32_t ne01;
11    uint32_t nb00;
12    uint32_t nb01;
13    uint32_t a_offset;
14} p;
15
16#define BLOCK_SIZE 256
17
18layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
19
20layout (binding = 0) readonly buffer A {uint data_a[];};
21layout (binding = 1) writeonly buffer D {uint data_d[];};
22
23shared uint vals[BLOCK_SIZE];
24
25void main() {
26    const uint expert_id = gl_WorkGroupID.x;
27    const uint num_elements = p.ne00 * p.ne01;
28    const uint tid = gl_LocalInvocationID.x;
29
30    uint count = 0;
31    for (uint idx = tid; idx < num_elements; idx += BLOCK_SIZE) {
32        const uint i01 = idx / p.ne00;
33        const uint i00 = idx % p.ne00;
34        const uint a = data_a[p.a_offset + i01 * p.nb01 + i00 * p.nb00];
35
36        count += uint(a == expert_id);
37    }
38
39    vals[tid] = count;
40    barrier();
41    [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
42        if (tid < s) {
43            vals[tid] += vals[tid + s];
44        }
45        barrier();
46    }
47
48    if (tid == 0) {
49        data_d[expert_id] = vals[0];
50    }
51}