1#version 450
2#extension GL_EXT_control_flow_attributes : enable
3
4#include "types.glsl"
5
6layout(constant_id = 0) const int BLOCK_SIZE = 1024;
7layout(constant_id = 1) const int NCOLS_PADDED_LOG2 = 10;
8
9layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
10
11// Input can either be the source (A) or intermediate values (S).
12// Similarly, output can be either destination (D) or intermediate values (S).
13layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
14layout (binding = 0) readonly buffer S {ivec2 data_s[];};
15layout (binding = 1) writeonly buffer D {int data_d[];};
16layout (binding = 1) writeonly buffer T {ivec2 data_t[];};
17
// Push constants shared by every pass of the multi-pass top-k reduction.
layout (push_constant) uniform parameter {
    uint orig_ncols;   // column count of the original input; any gid >= orig_ncols marks an out-of-bounds sentinel element
    uint ncols_input;  // number of columns read by this pass
    uint ncols_output; // number of columns written by this pass (intermediate passes only)
    uint k;            // number of top elements each workgroup keeps
    uint nrows;        // number of rows to process
    uint first_pass;   // non-zero: read raw floats from A; zero: read (gid, value) pairs from S
    uint last_pass;    // non-zero: write bare indices to D; zero: write (gid, value) pairs to T
} p;
27
28// pairs of (gid, value)
29shared ivec2 dst_row[BLOCK_SIZE];
30
// Compute the top-k elements of one row using shared memory.
//
// Each of the BLOCK_SIZE invocations loads one (gid, value) pair into
// dst_row[], where the value travels as its raw float bit pattern
// (floatBitsToInt) so index + value fit in a single ivec2. Invocations past
// the end of the input store the sentinel (p.orig_ncols, 0), which every
// comparison treats as smaller than any in-bounds element.
//
// needs_bounds_check: false when p.ncols_input is a multiple of BLOCK_SIZE,
//                     letting the sort skip the sentinel tests.
// row:                row index handled by this workgroup.
void topk(bool needs_bounds_check, const uint row) {
    const int col = int(gl_LocalInvocationID.x);

    // Load one element per invocation (or the OOB sentinel).
    if (gl_GlobalInvocationID.x < p.ncols_input) {
        const uint row_offset = row * p.ncols_input;
        if (p.first_pass != 0) {
            // First pass: read raw values from A, pairing each with its
            // global column index.
            dst_row[col] = ivec2(gl_GlobalInvocationID.x, floatBitsToInt(data_a[row_offset + gl_GlobalInvocationID.x]));
        } else {
            // Subsequent passes: read (gid, value) pairs produced earlier.
            dst_row[col] = data_s[row_offset + gl_GlobalInvocationID.x];
        }
    } else {
        dst_row[col] = ivec2(p.orig_ncols, 0);
    }
    barrier();

    if (p.k == 1) {
        // Fast path for single output - just do a max reduction
        [[unroll]] for (int s = BLOCK_SIZE / 2; s >= 1; s /= 2) {
            if (col < s) {
                ivec2 a = dst_row[col];
                ivec2 b = dst_row[col + s];
                // Keep b if a is the OOB sentinel, or if b is in bounds and
                // strictly greater. Compare as floats (matching the bitonic
                // path below): comparing the raw bit patterns as signed ints
                // orders negative values backwards.
                if (a.x >= p.orig_ncols ||
                    (b.x < p.orig_ncols && intBitsToFloat(b.y) > intBitsToFloat(a.y))) {
                    dst_row[col] = b;
                }
            }
            barrier();
        }
    } else {
        // Bitonic sort of the BLOCK_SIZE shared elements so the largest
        // values end up in the lowest indices.
        uint num_outer_loop_iters = NCOLS_PADDED_LOG2;
        for (uint k = 2, outer_idx = 0; outer_idx < num_outer_loop_iters; k *= 2, outer_idx++) {
            uint num_inner_loop_iters = outer_idx + 1;
            for (uint j = k / 2, inner_idx = 0; inner_idx < num_inner_loop_iters; j /= 2, inner_idx++) {
                const int ixj = int(col ^ j);

                // idx_0/idx_1 pick the comparison direction so the network's
                // alternating ascending/descending subsequences line up.
                int idx_0 = (col & k) == 0 ? col : ixj;
                int idx_1 = (col & k) == 0 ? ixj : col;

                ivec2 sh_idx_0 = dst_row[idx_0];
                ivec2 sh_idx_1 = dst_row[idx_1];
                bool idx_0_oob = needs_bounds_check ? sh_idx_0.x >= p.orig_ncols : false;
                bool idx_1_oob = needs_bounds_check ? sh_idx_1.x >= p.orig_ncols : false;

                // Only the lower invocation of each pair (ixj > col) performs
                // the swap, so each slot is written by exactly one invocation.
                if ((idx_0_oob ||
                    (!idx_1_oob && intBitsToFloat(sh_idx_0.y) < intBitsToFloat(sh_idx_1.y))) && (ixj > col)) {
                    dst_row[idx_0] = sh_idx_1;
                    dst_row[idx_1] = sh_idx_0;
                }

                barrier();
            }
        }
    }

    // Emit the k best elements (they occupy dst_row[0 .. k-1] after the
    // reduction/sort above).
    if (col < p.k) {
        if (p.last_pass != 0) {
            // Final pass: write bare source indices to D.
            if (gl_GlobalInvocationID.x < p.ncols_input) {
                const uint row_offset = row * p.k;
                data_d[row_offset + col] = dst_row[col].x;
            }
        } else {
            // Intermediate pass: keep (gid, value) pairs for the next pass.
            if (gl_WorkGroupID.x * p.k + col < p.ncols_output) {
                const uint row_offset = row * p.ncols_output + gl_WorkGroupID.x * p.k;
                data_t[row_offset + col] = dst_row[col];
            }
        }
    }
}
102
void main() {
    // Rows are strided across workgroups along the y dimension.
    const uint row_stride = gl_WorkGroupSize.y * gl_NumWorkGroups.y;

    if ((p.ncols_input % BLOCK_SIZE) == 0) {
        // Fully occupied workgroups: needs_bounds_check is a literal false at
        // this call site, so the per-element sentinel tests can be elided.
        for (uint row = gl_WorkGroupID.y; row < p.nrows; row += row_stride) {
            topk(false, row);
        }
    } else {
        // Partially occupied workgroups need the bounds checks.
        for (uint row = gl_WorkGroupID.y; row < p.nrows; row += row_stride) {
            topk(true, row);
        }
    }
}