1#version 450
2#extension GL_EXT_control_flow_attributes : enable
3
4#include "types.glsl"
5
// Specialization constant 0: length of the shared sort buffer dst_row. The same
// constant id also supplies the workgroup's local_size_x (local_size_x_id = 0 below).
6layout(constant_id = 0) const int BLOCK_SIZE = 1024;
// Specialization constant 1: number of outer passes the bitonic network in argsort() runs.
// NOTE(review): presumably log2 of the padded column count - confirm against the host code.
7layout(constant_id = 1) const int NCOLS_PADDED_LOG2 = 10;
// Value of p.order that selects ascending output; any other value writes the result reversed.
8#define ASC 0
9
10layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
11
12layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; // input values; argsort() passes each element to floatBitsToInt (A_TYPE is defined outside this file)
13layout (binding = 2) writeonly buffer D {int data_d[];}; // output: sorted column indices, p.ncols entries per row
14
// Push constants supplied by the host for one dispatch.
15layout (push_constant) uniform parameter {
16 uint ncols; // elements per row: used as the row stride and as the count of valid columns
17 uint ncols_padded; // not read in this file - NOTE(review): presumably ncols rounded up to a power of two
18 uint ncols_padded_log2; // not read in this file; the loops use the NCOLS_PADDED_LOG2 specialization constant instead
19 uint nrows; // number of rows to sort
20 uint order; // ASC (0) writes ascending; any other value writes the result reversed
21 uint outer_start; // not read in this file - NOTE(review): purpose not visible here, confirm before relying on it
22 uint outer_end; // not read in this file (see note on outer_start)
23 uint inner_start; // not read in this file (see note on outer_start)
24 uint inner_end; // not read in this file (see note on outer_start)
25} p;
26
// Workgroup-shared sort buffer, one slot per invocation:
// .x = column index, .y = raw bits of the float value read from data_a.
27shared ivec2 dst_row[BLOCK_SIZE];
28
// Argsort of a single row using a bitonic sorting network held in shared memory.
//
// Each invocation owns one slot of dst_row, holding (column index, raw float
// bits of the value). Once the network has run, slot i holds the index of the
// i-th smallest value of the row. The indices are written to data_d in that
// order when p.order == ASC, and in reversed order otherwise.
//
// needs_bounds_check: pass true when some slots have a column index >= p.ncols
//                     (padding lanes). Such slots are ordered after every
//                     in-range slot and their value payload is never compared.
// row:                row to sort; its elements live at data_a[row * p.ncols + col].
//
// The function executes barrier(), so every invocation of the workgroup must
// call it together.
void argsort(bool needs_bounds_check, const uint row) {
    const int col = int(gl_LocalInvocationID.x);

    const uint row_offset = row * p.ncols;

    // Initialize the slot. Padding lanes must not touch data_a: row_offset + col
    // would address the following row or, for the last row, memory past the end
    // of the input. Their payload is never compared, so a placeholder is enough.
    const bool col_in_range = !needs_bounds_check || col < p.ncols;
    int value_bits = 0;
    if (col_in_range) {
        value_bits = floatBitsToInt(data_a[row_offset + col]);
    }
    dst_row[col] = ivec2(col, value_bits);
    barrier();

    // Bitonic network: k is the size of the sequences being merged, j the
    // compare-exchange distance inside the current merge step.
    uint num_outer_loop_iters = NCOLS_PADDED_LOG2;
    [[unroll]] for (uint k = 2, outer_idx = 0; outer_idx < num_outer_loop_iters; k *= 2, outer_idx++) {
        uint num_inner_loop_iters = outer_idx + 1;
        [[unroll]] for (uint j = k / 2, inner_idx = 0; inner_idx < num_inner_loop_iters; j /= 2, inner_idx++) {
            // Partner slot for this compare-exchange step.
            const int ixj = int(col ^ j);

            // idx_0 is the slot that must end up with the smaller element; the
            // (col & k) test flips the direction for alternating sequences.
            int idx_0 = (col & k) == 0 ? col : ixj;
            int idx_1 = (col & k) == 0 ? ixj : col;

            ivec2 sh_idx_0 = dst_row[idx_0];
            ivec2 sh_idx_1 = dst_row[idx_1];
            bool idx_0_oob = needs_bounds_check ? sh_idx_0.x >= p.ncols : false;
            bool idx_1_oob = needs_bounds_check ? sh_idx_1.x >= p.ncols : false;

            // Out-of-range entries compare greater than everything else. Only the
            // lower lane of each pair (ixj > col) performs the exchange, so every
            // pair is handled exactly once.
            if ((idx_0_oob ||
                (!idx_1_oob && intBitsToFloat(sh_idx_0.y) > intBitsToFloat(sh_idx_1.y))) && (ixj > col)) {
                dst_row[idx_0] = sh_idx_1;
                dst_row[idx_1] = sh_idx_0;
            }

            // All exchanges of this step must be visible before the next step reads.
            barrier();
        }
    }

    // Only the first p.ncols slots hold real column indices.
    if (col < p.ncols) {
        if (p.order == ASC) {
            data_d[row_offset + col] = dst_row[col].x;
        } else {
            // Any other order: mirror the ascending result.
            data_d[row_offset + p.ncols - col - 1] = dst_row[col].x;
        }
    }
}
71
// Entry point: each workgroup sorts the rows gl_WorkGroupID.y, then that plus
// the total number of workgroups along y, and so on until p.nrows is reached.
// The bounds-check flag is passed to argsort() as a literal in each branch,
// with checking skipped only when a row fills the workgroup exactly.
void main() {
    const uint row_stride = gl_WorkGroupSize.y * gl_NumWorkGroups.y;

    if (p.ncols != BLOCK_SIZE) {
        // Row does not fill the workgroup exactly: bounds checking is required.
        for (uint r = gl_WorkGroupID.y; r < p.nrows; r += row_stride) {
            argsort(true, r);
        }
    } else {
        // Row width equals the workgroup width: no bounds checking needed.
        for (uint r = gl_WorkGroupID.y; r < p.nrows; r += row_stride) {
            argsort(false, r);
        }
    }
}