1#version 450
2#extension GL_EXT_control_flow_attributes : enable
3#extension GL_EXT_debug_printf : enable
4#extension GL_KHR_shader_subgroup_basic : enable
5#extension GL_KHR_shader_subgroup_ballot : enable
6#extension GL_KHR_shader_subgroup_arithmetic : enable
7#extension GL_KHR_shader_subgroup_shuffle : enable
8
9#include "types.glsl"
10
// Specialization constants (the values below are only defaults).
// BLOCK_SIZE doubles as the workgroup width via local_size_x_id = 0.
// Bucket indices in topk() span (1 << SUBGROUP_SIZE_LOG2) values and index
// counts[SUBGROUP_SIZE], so the two subgroup constants must agree.
layout(constant_id = 0) const int BLOCK_SIZE         = 1024;
layout(constant_id = 1) const int SUBGROUP_SIZE      = 32;
layout(constant_id = 2) const int SUBGROUP_SIZE_LOG2 = 5;

layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

// Binding 0 is the input: raw source values (A) when p.first_pass != 0,
// otherwise intermediate (index, value-bits) pairs (S).
layout(binding = 0) readonly buffer A {
    A_TYPE data_a[];
};
layout(binding = 0) readonly buffer S {
    ivec2 data_s[];
};

// Binding 1 is the output: final indices (D) when p.last_pass != 0,
// otherwise intermediate (index, value-bits) pairs (T).
layout(binding = 1) writeonly buffer D {
    int data_d[];
};
layout(binding = 1) writeonly buffer T {
    ivec2 data_t[];
};

layout(push_constant) uniform parameter {
    uint orig_ncols;   // indices >= orig_ncols mark padding entries
    uint ncols_input;  // entries per row in the input buffer
    uint ncols_output; // entries per row in the intermediate output buffer
    uint k;            // number of results kept per workgroup
    uint nrows;        // number of rows to process
    uint first_pass;   // non-zero: read data_a instead of data_s
    uint last_pass;    // non-zero: write data_d instead of data_t
} p;

// Per-workgroup working set: pairs of (gid, value), value stored as float bits.
shared ivec2 dst_row[BLOCK_SIZE];

// Histogram of the current search level, one counter per bucket.
shared int counts[SUBGROUP_SIZE];
// Bucket index and running count published by the scanning invocation.
shared int sh_min_idx;
shared uint sh_total;
// Per-subgroup element counts used to compute compaction offsets.
shared uint offset_partials[BLOCK_SIZE / SUBGROUP_SIZE];
shared uint eq_min_partials[BLOCK_SIZE / SUBGROUP_SIZE];
42
// Turn a float into a uint key so that unsigned comparison of keys orders the
// original floats: numbers with the sign bit set have every bit flipped,
// all others just get the top bit switched on.
// Hence +0.0 -> 0x80000000 and -0.0 -> 0x7FFFFFFF sit next to each other.
uint f2ui(float x) {
    const uint bits = floatBitsToUint(x);
    const bool sign_set = (bits & 0x80000000u) != 0u;
    return sign_set ? ~bits : (bits | 0x80000000u);
}
55
// Select the p.k largest entries of this workgroup's BLOCK_SIZE-wide slice of
// `row` and write them out.
//
// Input is read from data_a (first pass) or data_s (later passes) into the
// shared array dst_row as (index, float-bits) pairs; slots past ncols_input
// are filled with (orig_ncols, -inf) padding. The selected entries are
// compacted to dst_row[0 .. k-1] (in no particular order) and then stored to
// data_d (indices only, last pass) or data_t (pairs, intermediate pass).
//
// Must be called from uniform control flow by every invocation of the
// workgroup, since it executes barrier().
void topk(const uint row) {
    const int tid = int(gl_LocalInvocationID.x);

    // initialize indices
    if (gl_GlobalInvocationID.x < p.ncols_input) {
        const uint row_offset = row * p.ncols_input;
        if (p.first_pass != 0) {
            dst_row[tid] = ivec2(gl_GlobalInvocationID.x, floatBitsToInt(data_a[row_offset + gl_GlobalInvocationID.x]));
        } else {
            dst_row[tid] = data_s[row_offset + gl_GlobalInvocationID.x];
        }
    } else {
        dst_row[tid] = ivec2(p.orig_ncols, 0xFF800000); // -inf
    }
    barrier();

    if (p.k == 1) {
        // Fast path for single output - just do a max reduction
        [[unroll]] for (int s = BLOCK_SIZE / 2; s >= 1; s /= 2) {
            if (tid < s) {
                ivec2 a = dst_row[tid];
                ivec2 b = dst_row[tid + s];
                // Compare through f2ui(): the raw float bits are only ordered
                // correctly as signed ints when at least one value is
                // non-negative; two negative floats would compare reversed.
                if (a.x >= p.orig_ncols ||
                    (b.x < p.orig_ncols &&
                     f2ui(intBitsToFloat(b.y)) > f2ui(intBitsToFloat(a.y)))) {
                    dst_row[tid] = b;
                }
            }
            barrier();
        }
    } else {
        // Do an N-ary search to find the K-th largest value.
        // We remap the float values to be comparable as unsigned integers,
        // and split the range into SUBGROUP_SIZE (= 2^SUBGROUP_SIZE_LOG2)
        // smaller ranges. Count how many values are in each range, if the K-th
        // largest value is in the middle of one of these ranges then repeat
        // and split again.

        // Mask is the current set of bits we're searching. Shift is the LSB index.
        int shift = 32 - SUBGROUP_SIZE_LOG2;
        uint mask = ((1 << SUBGROUP_SIZE_LOG2) - 1) << shift;

        // The current range [range_min, range_max).
        // NOTE(review): 0xFF800000 is f2ui(+inf), so +inf and positive NaNs are
        // never counted by the histogram - confirm this is intended.
        uint range_min = 0;
        uint range_max = 0xFF800000;
        // How many are above the current range, and how many we need to find.
        uint total = 0;
        uint limit = min(p.k, p.ncols_input - gl_WorkGroupID.x * BLOCK_SIZE);

        while (mask != 0) {
            // Everyone must be done reading counts[] from the previous
            // iteration before it is cleared.
            barrier();
            // Initialize bucket counts to zero.
            if (tid < SUBGROUP_SIZE) {
                counts[tid] = 0;
            }
            barrier();
            // Count how many values are in each bucket.
            if (tid < p.ncols_input) {
                float y = intBitsToFloat(dst_row[tid].y);
                uint fy = f2ui(y);
                if (fy >= range_min && fy < range_max) {
                    uint bucket = (fy & mask) >> shift;
                    atomicAdd(counts[bucket], 1);
                }
            }
            barrier();

            // On the first subgroup, do a scan to count (from the top down) how
            // many elements are in the top buckets. Find the index of the first
            // that is over the limit. Copy it to the other invocations through
            // shared memory.
            if (tid < SUBGROUP_SIZE) {
                uint partial_sum = counts[SUBGROUP_SIZE - 1 - tid];
                partial_sum = subgroupInclusiveAdd(partial_sum) + total;
                uint t = subgroupBallotFindLSB(subgroupBallot(partial_sum >= limit));
                if (tid == t) {
                    sh_min_idx = int(SUBGROUP_SIZE - 1 - t);
                    sh_total = partial_sum;
                }
            }
            barrier();
            const int min_idx = sh_min_idx;
            total = sh_total;

            // Narrow the range to the selected bucket.
            // bucket_end wraps around to 0 when the very top bucket of the
            // 32-bit key space is selected; in that case keep the previous
            // upper bound, otherwise the next histogram would be empty.
            const uint bucket_min = range_min + (uint(min_idx) << shift);
            const uint bucket_end = bucket_min + (1u << shift);
            if (bucket_end > bucket_min) {
                range_max = min(range_max, bucket_end);
            }
            range_min = bucket_min;

            // Break if we've found the K-th largest.
            if (total == p.k) {
                break;
            }
            // total now counts only the values strictly above the selected bucket.
            total -= counts[min_idx];
            mask >>= SUBGROUP_SIZE_LOG2;
            shift -= SUBGROUP_SIZE_LOG2;
            // The last level may cover fewer than SUBGROUP_SIZE_LOG2 bits.
            if (shift < 0) {
                shift = 0;
            }
        }

        // Every invocation captures its element before dst_row is overwritten
        // by the compaction below (the writes happen after a barrier).
        ivec2 v = dst_row[tid];

        // We need to compact these values to the start of the dst_row array.
        // Have each subgroup count how many items it'll store, so other
        // subgroups can compute their base offset.
        // Values strictly greater than range_min must be stored. For values equal
        // to range_min, there can be ties and it's possible we'll need to store
        // an arbitrary subset of them.
        // If total == p.k, have a fast path where we don't need to handle ties.
        if (total == p.k) {
            bool top = f2ui(intBitsToFloat(v.y)) >= range_min;
            uvec4 b = subgroupBallot(top);
            uint bit_count = subgroupBallotBitCount(b);
            if ((tid % SUBGROUP_SIZE) == 0) {
                offset_partials[tid / SUBGROUP_SIZE] = bit_count;
            }
            barrier();

            // Sum the counts of all lower-numbered subgroups.
            uint out_idx = 0;
            [[unroll]] for (int i = 0; i < BLOCK_SIZE / SUBGROUP_SIZE; ++i) {
                if (i < tid / SUBGROUP_SIZE) {
                    out_idx += offset_partials[i];
                }
            }

            uint bit_count_ex = subgroupBallotExclusiveBitCount(b);
            if (top) {
                // TODO: Copy directly to the output?
                dst_row[out_idx + bit_count_ex] = v;
            }
        } else {
            bool top = f2ui(intBitsToFloat(v.y)) > range_min;
            bool eq_min = f2ui(intBitsToFloat(v.y)) == range_min;
            uvec4 b_top = subgroupBallot(top);
            uvec4 b_eq_min = subgroupBallot(eq_min);
            uint bit_count_top = subgroupBallotBitCount(b_top);
            uint bit_count_eq_min = subgroupBallotBitCount(b_eq_min);
            if ((tid % SUBGROUP_SIZE) == 0) {
                offset_partials[tid / SUBGROUP_SIZE] = bit_count_top;
                eq_min_partials[tid / SUBGROUP_SIZE] = bit_count_eq_min;
            }
            barrier();

            // out_idx / eq_min_idx: counts from lower-numbered subgroups.
            // eq_min_base: count of strictly-greater values in the whole block.
            uint out_idx = 0;
            uint eq_min_base = 0;
            uint eq_min_idx = 0;
            [[unroll]] for (int i = 0; i < BLOCK_SIZE / SUBGROUP_SIZE; ++i) {
                if (i < tid / SUBGROUP_SIZE) {
                    out_idx += offset_partials[i];
                    eq_min_idx += eq_min_partials[i];
                }
                eq_min_base += offset_partials[i];
            }
            // range_min values are stored at the end
            eq_min_idx += eq_min_base;

            uint bit_count_ex_top = subgroupBallotExclusiveBitCount(b_top);
            uint bit_count_ex_eq_min = subgroupBallotExclusiveBitCount(b_eq_min);
            if (top) {
                // TODO: Copy directly to the output?
                dst_row[out_idx + bit_count_ex_top] = v;
            }
            // Ties only fill the slots that remain below k.
            if (eq_min && eq_min_idx + bit_count_ex_eq_min < p.k) {
                dst_row[eq_min_idx + bit_count_ex_eq_min] = v;
            }
        }

        barrier();
    }

    // The first k invocations write the results out.
    if (tid < p.k) {
        if (p.last_pass != 0) {
            if (gl_GlobalInvocationID.x < p.ncols_input) {
                const uint row_offset = row * p.k;
                data_d[row_offset + tid] = dst_row[tid].x;
            }
        } else {
            if (gl_WorkGroupID.x * p.k + tid < p.ncols_output) {
                const uint row_offset = row * p.ncols_output + gl_WorkGroupID.x * p.k;
                data_t[row_offset + tid] = dst_row[tid];
            }
        }
    }
}
239
// Entry point: start at the row given by gl_WorkGroupID.y and advance by the
// total number of invocations in y until p.nrows rows have been handled.
void main() {
    const uint row_stride = gl_WorkGroupSize.y * gl_NumWorkGroups.y;
    for (uint row = gl_WorkGroupID.y; row < p.nrows; row += row_stride) {
        topk(row);
    }
}