1#version 450
  2#extension GL_EXT_control_flow_attributes : enable
  3#extension GL_EXT_debug_printf : enable
  4#extension GL_KHR_shader_subgroup_basic : enable
  5#extension GL_KHR_shader_subgroup_ballot : enable
  6#extension GL_KHR_shader_subgroup_arithmetic : enable
  7#extension GL_KHR_shader_subgroup_shuffle : enable
  8
  9#include "types.glsl"
 10
 11layout(constant_id = 0) const int BLOCK_SIZE = 1024;
 12layout(constant_id = 1) const int SUBGROUP_SIZE = 32;
 13layout(constant_id = 2) const int SUBGROUP_SIZE_LOG2 = 5;
 14
 15layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 16
 17// Input can either be the source (A) or intermediate values (S).
 18// Similarly, output can be either destination (D) or intermediate values (S).
 19layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
 20layout (binding = 0) readonly buffer S {ivec2 data_s[];};
 21layout (binding = 1) writeonly buffer D {int data_d[];};
 22layout (binding = 1) writeonly buffer T {ivec2 data_t[];};
 23
 24layout (push_constant) uniform parameter {
 25    uint orig_ncols;
 26    uint ncols_input;
 27    uint ncols_output;
 28    uint k;
 29    uint nrows;
 30    uint first_pass;
 31    uint last_pass;
 32} p;
 33
 34// pairs of (gid, value)
 35shared ivec2 dst_row[BLOCK_SIZE];
 36
 37shared int counts[SUBGROUP_SIZE];
 38shared int sh_min_idx;
 39shared uint sh_total;
 40shared uint offset_partials[BLOCK_SIZE / SUBGROUP_SIZE];
 41shared uint eq_min_partials[BLOCK_SIZE / SUBGROUP_SIZE];
 42
 43// Map float values to uint such that comparisons still work.
 44// Positive values set the high bit, negative values are inverted.
 45// +0.0 -> 0x80000000, -0.0 -> 0x7FFFFFFF are in the correct places.
 46uint f2ui(float x) {
 47    uint y = floatBitsToUint(x);
 48    if ((y & 0x80000000) != 0) {
 49        y ^= ~0;
 50    } else {
 51        y |= 0x80000000;
 52    }
 53    return y;
 54}
 55
 56void topk(const uint row) {
 57    const int tid = int(gl_LocalInvocationID.x);
 58
 59    // initialize indices
 60    if (gl_GlobalInvocationID.x < p.ncols_input) {
 61        if (p.first_pass != 0) {
 62            const uint row_offset = row * p.ncols_input;
 63            dst_row[tid] = ivec2(gl_GlobalInvocationID.x, floatBitsToInt(data_a[row_offset + gl_GlobalInvocationID.x]));
 64        } else {
 65            const uint row_offset = row * p.ncols_input;
 66            dst_row[tid] = data_s[row_offset + gl_GlobalInvocationID.x];
 67        }
 68    } else {
 69        dst_row[tid] = ivec2(p.orig_ncols, 0xFF800000); // -inf
 70    }
 71    barrier();
 72
 73    if (p.k == 1) {
 74        // Fast path for single output - just do a max reduction
 75        [[unroll]] for (int s = BLOCK_SIZE / 2; s >= 1; s /= 2) {
 76            if (tid < s) {
 77                ivec2 a = dst_row[tid];
 78                ivec2 b = dst_row[tid + s];
 79                if (a.x >= p.orig_ncols ||
 80                    b.x < p.orig_ncols && b.y > a.y) {
 81                    dst_row[tid] = b;
 82                }
 83            }
 84            barrier();
 85        }
 86    } else {
 87        // Do an N-ary search to find the K-th largest value.
 88        // We remap the float values to be comparable as unsigned integers,
 89        // and split the range into 2^N smaller ranges where N is the
 90        // subgroup size. Count how many values are in each range, if the K-th
 91        // largest value is in the middle of one of thee ranges then repeat
 92        // and split again.
 93
 94        // Mask is the current set of bits we're searching. Shift is the LSB index.
 95        int shift = 32 - SUBGROUP_SIZE_LOG2;
 96        uint mask = ((1 << SUBGROUP_SIZE_LOG2) - 1) << shift;
 97
 98        // The current range.
 99        uint range_min = 0;
100        uint range_max = 0xFF800000;
101        // How many are above the current range, and how many we need to find.
102        uint total = 0;
103        uint limit = min(p.k, p.ncols_input - gl_WorkGroupID.x * BLOCK_SIZE);
104
105        while (mask != 0) {
106            barrier();
107            // Initialize bucket counts to zero.
108            if (tid < SUBGROUP_SIZE) {
109                counts[tid] = 0;
110            }
111            barrier();
112            // Count how many values are in each bucket.
113            if (tid < p.ncols_input) {
114                float y = intBitsToFloat(dst_row[tid].y);
115                uint fy = f2ui(y);
116                if (fy >= range_min && fy < range_max) {
117                    uint bucket = (fy & mask) >> shift;
118                    atomicAdd(counts[bucket], 1);
119                }
120            }
121            barrier();
122
123            // On the first subgroup, do a scan to count (from the top down) how
124            // many elements are in the top N buckets. Find the index of the first
125            // that is over the limit. Copy it to the other invocations through
126            // shared memory.
127            if (tid < SUBGROUP_SIZE) {
128                uint partial_sum = counts[SUBGROUP_SIZE - 1 - tid];
129                partial_sum = subgroupInclusiveAdd(partial_sum) + total;
130                uint t = subgroupBallotFindLSB(subgroupBallot(partial_sum >= limit));
131                if (tid == t) {
132                    sh_min_idx = int(SUBGROUP_SIZE - 1 - t);
133                    sh_total = partial_sum;
134                }
135            }
136            barrier();
137            int min_idx = sh_min_idx;
138            total = sh_total;
139
140            // Update the range, and break if we've found the K-th largest.
141            range_max = range_min + ((min_idx + 1) << shift);
142            range_min = range_min + (min_idx << shift);
143
144            if (total == p.k) {
145                break;
146            }
147            total -= counts[min_idx];
148            mask >>= SUBGROUP_SIZE_LOG2;
149            shift -= SUBGROUP_SIZE_LOG2;
150            if (shift < 0) {
151                shift = 0;
152            }
153        }
154
155        ivec2 v = dst_row[tid];
156
157        // We need to compact these values to the start of the dst_row array.
158        // Have each subgroup count how many items it'll store, so other
159        // subgroups can compute their base offset.
160        // Values strictly greater than range_min must be stored. For values equal
161        // to range_min, there can be ties and it's possible we'll need to store
162        // an arbitrary subset of them.
163        // If total == p.k, have a fast path where we don't need to handle ties.
164        if (total == p.k) {
165            bool top = f2ui(intBitsToFloat(v.y)) >= range_min;
166            uvec4 b = subgroupBallot(top);
167            uint bit_count = subgroupBallotBitCount(b);
168            if ((tid % SUBGROUP_SIZE) == 0) {
169                offset_partials[tid / SUBGROUP_SIZE] = bit_count;
170            }
171            barrier();
172
173            uint out_idx = 0;
174            [[unroll]] for (int i = 0; i < BLOCK_SIZE / SUBGROUP_SIZE; ++i) {
175                if (i < tid / SUBGROUP_SIZE) {
176                    out_idx += offset_partials[i];
177                }
178            }
179
180            uint bit_count_ex = subgroupBallotExclusiveBitCount(b);
181            if (top) {
182                // TODO: Copy directly to the output?
183                dst_row[out_idx + bit_count_ex] = v;
184            }
185        } else {
186            bool top = f2ui(intBitsToFloat(v.y)) > range_min;
187            bool eq_min = f2ui(intBitsToFloat(v.y)) == range_min;
188            uvec4 b_top = subgroupBallot(top);
189            uvec4 b_eq_min = subgroupBallot(eq_min);
190            uint bit_count_top = subgroupBallotBitCount(b_top);
191            uint bit_count_eq_min = subgroupBallotBitCount(b_eq_min);
192            if ((tid % SUBGROUP_SIZE) == 0) {
193                offset_partials[tid / SUBGROUP_SIZE] = bit_count_top;
194                eq_min_partials[tid / SUBGROUP_SIZE] = bit_count_eq_min;
195            }
196            barrier();
197
198            uint out_idx = 0;
199            uint eq_min_base = 0;
200            uint eq_min_idx = 0;
201            [[unroll]] for (int i = 0; i < BLOCK_SIZE / SUBGROUP_SIZE; ++i) {
202                if (i < tid / SUBGROUP_SIZE) {
203                    out_idx += offset_partials[i];
204                    eq_min_idx += eq_min_partials[i];
205                }
206                eq_min_base += offset_partials[i];
207            }
208            // range_min values are stored at the end
209            eq_min_idx += eq_min_base;
210
211            uint bit_count_ex_top = subgroupBallotExclusiveBitCount(b_top);
212            uint bit_count_ex_eq_min = subgroupBallotExclusiveBitCount(b_eq_min);
213            if (top) {
214                // TODO: Copy directly to the output?
215                dst_row[out_idx + bit_count_ex_top] = v;
216            }
217            if (eq_min && eq_min_idx + bit_count_ex_eq_min < p.k) {
218                dst_row[eq_min_idx + bit_count_ex_eq_min] = v;
219            }
220        }
221
222        barrier();
223    }
224
225    if (tid < p.k) {
226        if (p.last_pass != 0) {
227            if (gl_GlobalInvocationID.x < p.ncols_input) {
228                const uint row_offset = row * p.k;
229                data_d[row_offset + tid] = dst_row[tid].x;
230            }
231        } else {
232            if (gl_WorkGroupID.x * p.k + tid < p.ncols_output) {
233                const uint row_offset = row * p.ncols_output + gl_WorkGroupID.x * p.k;
234                data_t[row_offset + tid] = dst_row[tid];
235            }
236        }
237    }
238}
239
// Entry point: each workgroup starts at the row given by its y index and
// advances by the total number of invocations in y until all p.nrows rows
// have been processed.
void main() {
    const uint row_stride = gl_WorkGroupSize.y * gl_NumWorkGroups.y;
    for (uint row = gl_WorkGroupID.y; row < p.nrows; row += row_stride) {
        topk(row);
    }
}