1#version 450
  2#extension GL_EXT_control_flow_attributes : enable
  3
  4#include "types.glsl"
  5
// Workgroup width (local_size_x is bound to this constant) and the length of
// the shared dst_row scratch array: one invocation per slot.
layout(constant_id = 0) const int BLOCK_SIZE = 1024;
// Number of outer stages run by the bitonic sort in topk.
// NOTE(review): presumably log2(BLOCK_SIZE); the sort indexes dst_row with
// col ^ j, so a larger value would index past dst_row — confirm on the host.
layout(constant_id = 1) const int NCOLS_PADDED_LOG2 = 10;

layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

// Input can either be the source (A) or intermediate values (S).
// Similarly, output can be either destination (D) or intermediate values (T).
// Each pair aliases the same binding: topk reads data_a when p.first_pass != 0
// and data_s otherwise, and writes data_d when p.last_pass != 0 and data_t
// otherwise.
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 0) readonly buffer S {ivec2 data_s[];};
layout (binding = 1) writeonly buffer D {int data_d[];};
layout (binding = 1) writeonly buffer T {ivec2 data_t[];};

layout (push_constant) uniform parameter {
    uint orig_ncols;   // index used as the padding sentinel; indices >= this are treated as out of bounds (presumably the original row length)
    uint ncols_input;  // row stride and valid length of the input buffer for this pass
    uint ncols_output; // row stride of the intermediate output (data_t)
    uint k;            // number of (index, value) pairs kept per workgroup; row stride of data_d
    uint nrows;        // number of rows to process
    uint first_pass;   // nonzero: read values from data_a; zero: read pairs from data_s
    uint last_pass;    // nonzero: write indices to data_d; zero: write pairs to data_t
} p;

// pairs of (gid, value)
// x: column index into the original row (or the sentinel p.orig_ncols);
// y: the element's float bits as produced by floatBitsToInt.
shared ivec2 dst_row[BLOCK_SIZE];
 30
 31void topk(bool needs_bounds_check, const uint row) {
 32    const int col = int(gl_LocalInvocationID.x);
 33
 34    // initialize indices
 35    if (gl_GlobalInvocationID.x < p.ncols_input) {
 36        if (p.first_pass != 0) {
 37            const uint row_offset = row * p.ncols_input;
 38            dst_row[col] = ivec2(gl_GlobalInvocationID.x, floatBitsToInt(data_a[row_offset + gl_GlobalInvocationID.x]));
 39        } else {
 40            const uint row_offset = row * p.ncols_input;
 41            dst_row[col] = data_s[row_offset + gl_GlobalInvocationID.x];
 42        }
 43    } else {
 44        dst_row[col] = ivec2(p.orig_ncols, 0);
 45    }
 46    barrier();
 47
 48    if (p.k == 1) {
 49        // Fast path for single output - just do a max reduction
 50        [[unroll]] for (int s = BLOCK_SIZE / 2; s >= 1; s /= 2) {
 51            if (col < s) {
 52                ivec2 a = dst_row[col];
 53                ivec2 b = dst_row[col + s];
 54                if (a.x >= p.orig_ncols ||
 55                    b.x < p.orig_ncols && b.y > a.y) {
 56                    dst_row[col] = b;
 57                }
 58            }
 59            barrier();
 60        }
 61    } else {
 62        // bitonic sort on this group of elements
 63        uint num_outer_loop_iters = NCOLS_PADDED_LOG2;
 64        for (uint k = 2, outer_idx = 0; outer_idx < num_outer_loop_iters; k *= 2, outer_idx++) {
 65            uint num_inner_loop_iters = outer_idx + 1;
 66            for (uint j = k / 2, inner_idx = 0; inner_idx < num_inner_loop_iters; j /= 2, inner_idx++) {
 67                const int ixj = int(col ^ j);
 68
 69                int idx_0 = (col & k) == 0 ? col : ixj;
 70                int idx_1 = (col & k) == 0 ? ixj : col;
 71
 72                ivec2 sh_idx_0 = dst_row[idx_0];
 73                ivec2 sh_idx_1 = dst_row[idx_1];
 74                bool idx_0_oob = needs_bounds_check ? sh_idx_0.x >= p.orig_ncols : false;
 75                bool idx_1_oob = needs_bounds_check ? sh_idx_1.x >= p.orig_ncols : false;
 76
 77                if ((idx_0_oob ||
 78                    (!idx_1_oob && intBitsToFloat(sh_idx_0.y) < intBitsToFloat(sh_idx_1.y))) && (ixj > col)) {
 79                    dst_row[idx_0] = sh_idx_1;
 80                    dst_row[idx_1] = sh_idx_0;
 81                }
 82
 83                barrier();
 84            }
 85        }
 86    }
 87
 88    if (col < p.k) {
 89        if (p.last_pass != 0) {
 90            if (gl_GlobalInvocationID.x < p.ncols_input) {
 91                const uint row_offset = row * p.k;
 92                data_d[row_offset + col] = dst_row[col].x;
 93            }
 94        } else {
 95            if (gl_WorkGroupID.x * p.k + col < p.ncols_output) {
 96                const uint row_offset = row * p.ncols_output + gl_WorkGroupID.x * p.k;
 97                data_t[row_offset + col] = dst_row[col];
 98            }
 99        }
100    }
101}
102
// Entry point. Each workgroup processes rows gl_WorkGroupID.y,
// gl_WorkGroupID.y + gl_WorkGroupSize.y * gl_NumWorkGroups.y, ... up to p.nrows.
//
// topk's bitonic sentinel test is skipped only when it cannot matter. That
// holds in the first pass when the row length is a multiple of BLOCK_SIZE:
// every slot is then a real element read from data_a. Later passes read
// intermediate pairs from data_s, and those pairs can themselves be sentinels
// (p.orig_ncols, 0). A workgroup with fewer than p.k valid elements still
// writes p.k pairs. Without the check, such a sentinel would sort as the
// value 0.0 and could outrank real negative values. Later passes therefore
// always keep the check.
void main() {
    if ((p.ncols_input % BLOCK_SIZE) == 0 && p.first_pass != 0) {
        // Fast path for fully occupied workgroups with sentinel-free input.
        // The literal argument lets the compiler drop the bounds test.
        for (uint row = gl_WorkGroupID.y; row < p.nrows; row += gl_WorkGroupSize.y * gl_NumWorkGroups.y) {
            topk(false, row);
        }
    } else {
        for (uint row = gl_WorkGroupID.y; row < p.nrows; row += gl_WorkGroupSize.y * gl_NumWorkGroups.y) {
            topk(true, row);
        }
    }
}