1#ifdef MUL_MAT_ID
// Workgroup-shared table of up to BN entries filled by load_row_ids().
// Each entry is u16vec2(ii0, ii1): the position of a matching id inside
// data_ids, split into its index along nei0 (ii0) and along nei1 (ii1).
// Both components are narrowed to 16 bits when stored.
2shared u16vec2 row_ids[BN];
// Running count of matching ids seen so far by load_row_ids(). It is not
// declared `shared`, so each invocation keeps its own copy; the copies stay
// equal because every invocation adds the same per-step total, which is
// derived from the shared ballots.
3uint _ne1;
4
5#ifdef MUL_MAT_ID_USE_SUBGROUPS
// One subgroupBallot() result per subgroup, indexed by gl_SubgroupID.
// load_row_ids() writes its own slot and then reads all gl_NumSubgroups slots
// to rank matches across the whole workgroup.
// NOTE(review): assumes NUM_WARPS >= gl_NumSubgroups — confirm at the dispatch site.
6shared uvec4 ballots_sh[NUM_WARPS];
7
// Finds the positions in data_ids whose value equals `expert_idx` and stores
// the ones belonging to output tile `ic` into row_ids.
//
// The p.nei1 * p.nei0 ids are scanned in steps of BLOCK_SIZE; in each step an
// invocation inspects position `base + gl_LocalInvocationIndex`. Matches are
// numbered in scan order across the whole workgroup (subgroup ballots are
// exchanged through ballots_sh), and a match with running number n is written
// to row_ids[n - ic * BN] when ic * BN <= n < (ic + 1) * BN.
//
// Parameters:
//   expert_idx    id value to search for; also indexes data_expert_count.
//   nei0_is_pow2  when true, positions are split with a shift by
//                 findLSB(p.nei0) instead of a division by p.nei0. This is
//                 only equivalent if p.nei0 really is a power of two.
//   ic            index of the BN-sized tile of matches to collect.
//
// Effects: resets _ne1 to 0 and leaves it at the number of matches counted
// when the scan stopped. The scan stops early once _ne1 reaches
// (ic + 1) * BN or equals data_expert_count[expert_idx].
//
// The function contains barrier() calls, so all invocations of the workgroup
// are expected to call it together.
// NOTE(review): presumably BLOCK_SIZE is the workgroup size — confirm.
void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
    // Number of ids each invocation fetches ahead of time (power of two).
    const uint PREFETCH_DEPTH = 16u;

    _ne1 = 0;

    uint id_count   = p.nei1 * p.nei0;
    uint nei0_log2  = findLSB(p.nei0);
    uint tile_begin = ic * BN;
    uint tile_end   = (ic + 1) * BN;
    uint wanted     = data_expert_count[expert_idx];

    uint cached_ids[PREFETCH_DEPTH];
    uint slot = 0;

    for (uint base = 0; base < id_count; base += BLOCK_SIZE) {
        // Refill the per-invocation cache whenever it has been consumed:
        // entry n holds the id this invocation will inspect n steps from now.
        if (slot == 0) {
            [[unroll]] for (uint n = 0; n < PREFETCH_DEPTH; ++n) {
                uint ahead     = base + gl_LocalInvocationIndex + n * BLOCK_SIZE;
                uint ahead_row = nei0_is_pow2 ? (ahead >> nei0_log2) : (ahead / p.nei0);
                uint ahead_col = ahead - ahead_row * p.nei0;
                // Positions past the end are never read; they cache 0.
                cached_ids[n] = (ahead < id_count) ? data_ids[ahead_row * p.nbi1 + ahead_col] : 0;
            }
        }

        // Position handled by this invocation in the current step.
        uint pos = base + gl_LocalInvocationIndex;
        uint row = nei0_is_pow2 ? (pos >> nei0_log2) : (pos / p.nei0);
        uint col = pos - row * p.nei0;

        uint id = cached_ids[slot];
        slot = (slot + 1u) & (PREFETCH_DEPTH - 1u);

        bool  is_match = (pos < id_count) && (id == expert_idx);
        uvec4 ballot   = subgroupBallot(is_match);

        // Publish this subgroup's ballot so every subgroup can see all of them.
        ballots_sh[gl_SubgroupID] = ballot;
        barrier();

        // preceding: matches found by lower-numbered subgroups in this step.
        // found:     matches found by the whole workgroup in this step.
        uint preceding = 0;
        uint found     = 0;
        for (uint sg = 0; sg < gl_NumSubgroups; ++sg) {
            uint sg_matches = subgroupBallotBitCount(ballots_sh[sg]);
            if (sg < gl_SubgroupID) {
                preceding += sg_matches;
            }
            found += sg_matches;
        }
        // All reads of ballots_sh must finish before any subgroup overwrites
        // its slot in the next step.
        barrier();

        // Running number of this invocation's match across the whole scan.
        uint dst = _ne1 + preceding + subgroupBallotExclusiveBitCount(ballot);
        if (is_match && dst >= tile_begin && dst < tile_end) {
            row_ids[dst - tile_begin] = u16vec2(col, row);
        }

        _ne1 += found;

        bool tile_filled = _ne1 >= tile_end;
        bool all_found   = _ne1 == wanted;
        if (tile_filled || all_found) {
            break;
        }
    }

    // Make the completed row_ids table visible to the whole workgroup.
    barrier();
}
71#endif // MUL_MAT_ID_USE_SUBGROUPS
72#endif // MUL_MAT_ID