#ifdef MUL_MAT_ID
// Workgroup-shared output of load_row_ids(): for each selected entry, its
// (ii0, ii1) position in the id grid, stored as 16-bit components.
// NOTE(review): u16vec2 narrows ii0/ii1 to 16 bits -- assumes p.nei0 and
// p.nei1 fit in that range; confirm against the host-side limits.
shared u16vec2 row_ids[BN];
// Per-invocation running count of ids matching the requested expert that
// load_row_ids() has seen so far. Every invocation derives it from the same
// shared ballots, so all invocations hold the same value.
uint _ne1;

#ifdef MUL_MAT_ID_USE_SUBGROUPS
// One ballot per subgroup, published so that every invocation can count the
// matches found by all subgroups in the current pass.
shared uvec4 ballots_sh[NUM_WARPS];

// Scans the p.nei1 x p.nei0 grid of expert ids (read from data_ids) and
// records in row_ids the coordinates of the entries equal to expert_idx whose
// rank among all matches lies in the window [ic * BN, (ic + 1) * BN).
//
//   expert_idx    expert id to search for; also indexes data_expert_count.
//   nei0_is_pow2  when true, the flat index is split with a shift by
//                 findLSB(p.nei0) instead of a division; the caller must only
//                 pass true when p.nei0 is a power of two.
//   ic            index of the BN-sized window of matches to collect.
//
// On return _ne1 holds the number of matches counted before the scan stopped,
// which may be larger than the number of entries written to row_ids.
//
// Contains barrier() calls, so it has to be called from uniform control flow
// by every invocation of the workgroup.
// NOTE(review): each pass covers BLOCK_SIZE consecutive flat indices, one per
// invocation -- presumably BLOCK_SIZE equals the workgroup size; confirm.
void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
    _ne1 = 0;
    uint num_elements = p.nei1 * p.nei0;
    uint nei0shift = findLSB(p.nei0);

    // Register cache of the ids this invocation will need over the next 16
    // passes; iter is the read position inside the cache.
    uint ids[16];
    uint iter = 0;

    // NOTE(review): used below as "total number of matches for this expert"
    // to stop scanning early -- confirm that is what data_expert_count holds.
    uint expert_count = data_expert_count[expert_idx];

    for (uint j = 0; j < num_elements; j += BLOCK_SIZE) {
        // prefetch up to 16 elements
        // (refilled whenever iter has wrapped back to 0, i.e. every 16 passes)
        if (iter == 0) {
            [[unroll]] for (uint k = 0; k < 16; ++k) {
                uint i = j + gl_LocalInvocationIndex + k*BLOCK_SIZE;
                bool in_range = i < num_elements;
                uint ii1;
                if (nei0_is_pow2) {
                    ii1 = i >> nei0shift;
                } else {
                    ii1 = i / p.nei0;
                }
                uint ii0 = i - ii1 * p.nei0;
                // Out-of-range slots get a dummy 0; they are never selected
                // because the ballot below re-checks in_range.
                ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
            }
        }
        // Flat index handled by this invocation in the current pass, split
        // into column (ii0) and row (ii1) of the id grid.
        uint i = j + gl_LocalInvocationIndex;
        bool in_range = i < num_elements;
        uint ii1;
        if (nei0_is_pow2) {
            ii1 = i >> nei0shift;
        } else {
            ii1 = i / p.nei0;
        }
        uint ii0 = i - ii1 * p.nei0;
        uint id = ids[iter++];
        uvec4 ballot = subgroupBallot(in_range && id == expert_idx);

        // Publish this subgroup's ballot and wait until all are available.
        ballots_sh[gl_SubgroupID] = ballot;
        barrier();

        // total         = matches found by the whole workgroup in this pass
        // subgroup_base = matches found by subgroups with a lower id
        uint subgroup_base = 0;
        uint total = 0;
        for (uint k = 0; k < gl_NumSubgroups; ++k) {
            if (k == gl_SubgroupID) {
                subgroup_base = total;
            }
            total += subgroupBallotBitCount(ballots_sh[k]);
        }
        // Everyone must finish reading ballots_sh before the next pass
        // overwrites it.
        barrier();

        // Rank of this invocation's match among the matches of this pass;
        // _ne1 + idx is its rank among all matches found so far.
        uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot);
        if (in_range && id == expert_idx && _ne1 + idx >= ic * BN && _ne1 + idx < (ic + 1) * BN) {
            row_ids[_ne1 + idx - ic * BN] = u16vec2(ii0, ii1);
        }
        _ne1 += total;
        iter &= 15;
        // Stop once the requested window is full or every match has been
        // counted. _ne1 is identical in all invocations, so the break is
        // taken uniformly and the barriers stay in uniform control flow.
        if (_ne1 >= (ic + 1) * BN || _ne1 == expert_count) {
            break;
        }
    }
    // Synchronize the workgroup so the row_ids writes above are complete
    // before any invocation goes on to use them.
    barrier();
}
#endif // MUL_MAT_ID_USE_SUBGROUPS
#endif // MUL_MAT_ID