1#version 450
 2#extension GL_EXT_control_flow_attributes : enable
 3
 4#include "types.glsl"
 5
// Number of invocations per workgroup (local_size_x_id = 0 below binds the
// workgroup x-size to this constant) and number of sort slots per row.
// argsort() runs NCOLS_PADDED_LOG2 bitonic stages, which covers every slot only
// if BLOCK_SIZE == 1 << NCOLS_PADDED_LOG2 — presumably guaranteed by the host
// when it sets both specialization constants; TODO confirm.
layout(constant_id = 0) const int BLOCK_SIZE = 1024;
layout(constant_id = 1) const int NCOLS_PADDED_LOG2 = 10;
// Value of p.order that selects ascending output; any other value makes
// argsort() write the indices in descending order.
#define ASC 0

layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

// Input rows, p.ncols elements each. A_TYPE comes from types.glsl; argsort()
// reinterprets elements with floatBitsToInt, so it is assumed to be float —
// confirm against types.glsl.
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
// Output: for each row, p.ncols column indices into the corresponding input row.
// (Binding 1 is not used by this shader.)
layout (binding = 2) writeonly buffer D {int data_d[];};

layout (push_constant) uniform parameter {
    uint ncols;             // real columns per row; also the row stride of data_a/data_d
    uint ncols_padded;      // not read by this shader
    uint ncols_padded_log2; // not read by this shader (NCOLS_PADDED_LOG2 is used instead)
    uint nrows;             // number of rows to sort
    uint order;             // ASC for ascending, otherwise descending
    // The fields below are not read by this shader; presumably they are kept
    // so the block matches the host-side push-constant layout — verify.
    uint outer_start;
    uint outer_end;
    uint inner_start;
    uint inner_end;
} p;

// Per-slot sort state for the current row: .x = original column index,
// .y = bit pattern of that column's value (floatBitsToInt).
shared ivec2 dst_row[BLOCK_SIZE];
28
// Argsorts one row of data_a with a bitonic sorting network held in the shared
// array dst_row, then writes the sorted column indices to data_d.
//
// Each invocation owns slot `col` = gl_LocalInvocationID.x. The network runs
// NCOLS_PADDED_LOG2 stages over the BLOCK_SIZE slots; this assumes
// BLOCK_SIZE == 1 << NCOLS_PADDED_LOG2 (both are specialization constants set
// by the host — TODO confirm they are always kept in sync).
//
// needs_bounds_check: must be true whenever p.ncols < BLOCK_SIZE, i.e. when some
//                     slots are padding (main() passes true iff
//                     p.ncols != BLOCK_SIZE). Padding slots compare greater than
//                     every real element, so they end up past index
//                     p.ncols - 1 and are never written out.
// row:                row to sort; its elements start at data_a[row * p.ncols].
//
// Executes barrier(), so every invocation of the workgroup must call this
// with the same arguments.
void argsort(bool needs_bounds_check, const uint row) {
    const int col = int(gl_LocalInvocationID.x);
    const uint row_offset = row * p.ncols;

    // Seed each slot with (column index, value bits) so the comparisons below
    // never touch global memory. Padding slots must not load from data_a: for
    // the last row that would read past the end of the buffer. Their value is
    // irrelevant (they are recognised by index), so they get 0.
    const bool is_padding = needs_bounds_check && col >= p.ncols;
    int value_bits = 0;
    if (!is_padding) {
        value_bits = floatBitsToInt(data_a[row_offset + col]);
    }
    dst_row[col] = ivec2(col, value_bits);
    barrier();

    uint num_outer_loop_iters = NCOLS_PADDED_LOG2;
    [[unroll]] for (uint k = 2, outer_idx = 0; outer_idx < num_outer_loop_iters; k *= 2, outer_idx++) {
        uint num_inner_loop_iters = outer_idx + 1;
        [[unroll]] for (uint j = k / 2, inner_idx = 0; inner_idx < num_inner_loop_iters; j /= 2, inner_idx++) {
            const int ixj = int(col ^ j);

            // Only the lower invocation of each (col, ixj) pair performs the
            // compare-exchange. The upper one skips it entirely, so it neither
            // repeats the shared-memory loads nor reads the two slots its
            // partner may be writing during this same barrier phase.
            if (ixj > col) {
                // Bit k of col selects the direction of this bitonic subsequence:
                // idx_0 is the slot that must end up holding the smaller element.
                const int idx_0 = (col & k) == 0 ? col : ixj;
                const int idx_1 = (col & k) == 0 ? ixj : col;

                const ivec2 sh_idx_0 = dst_row[idx_0];
                const ivec2 sh_idx_1 = dst_row[idx_1];
                // Padding entries (original index >= p.ncols) compare greater
                // than every real entry.
                const bool idx_0_oob = needs_bounds_check ? sh_idx_0.x >= p.ncols : false;
                const bool idx_1_oob = needs_bounds_check ? sh_idx_1.x >= p.ncols : false;

                if (idx_0_oob ||
                    (!idx_1_oob && intBitsToFloat(sh_idx_0.y) > intBitsToFloat(sh_idx_1.y))) {
                    dst_row[idx_0] = sh_idx_1;
                    dst_row[idx_1] = sh_idx_0;
                }
            }

            // Kept outside the `if` so every invocation reaches it
            // (barrier() requires uniform control flow).
            barrier();
        }
    }

    // After sorting, only the first p.ncols slots hold real column indices.
    if (col < p.ncols) {
        if (p.order == ASC) {
            data_d[row_offset + col] = dst_row[col].x;
        } else {
            // Descending order is the ascending result written back-to-front.
            data_d[row_offset + p.ncols - col - 1] = dst_row[col].x;
        }
    }
}
71
// Entry point: each workgroup argsorts rows gl_WorkGroupID.y,
// gl_WorkGroupID.y + stride, ... (stride = total workgroup rows dispatched in y)
// until p.nrows is reached.
//
// The bounds-check decision is made once, on a push-constant value, and each
// loop passes argsort() a literal flag — presumably so the compiler can
// specialize the unpadded (p.ncols == BLOCK_SIZE) path. The row loop bounds
// depend only on workgroup-uniform values, keeping argsort()'s barrier() calls
// in uniform control flow.
void main() {
    const uint row_stride = gl_WorkGroupSize.y * gl_NumWorkGroups.y;

    if (p.ncols != BLOCK_SIZE) {
        // Some slots are padding: argsort() must check indices against p.ncols.
        for (uint r = gl_WorkGroupID.y; r < p.nrows; r += row_stride) {
            argsort(true, r);
        }
    } else {
        // Row exactly fills the workgroup: no padding slots exist.
        for (uint r = gl_WorkGroupID.y; r < p.nrows; r += row_stride) {
            argsort(false, r);
        }
    }
}