1#version 450
2#extension GL_EXT_control_flow_attributes : enable
3#extension GL_KHR_memory_scope_semantics : enable
4#pragma use_vulkan_memory_model
5
6#include "types.glsl"
7
8layout(constant_id = 0) const int BLOCK_SIZE = 1024;
9layout(constant_id = 1) const int WG_UNROLL_FACTOR = 2;
10#define ASC 0
11
12layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
13
14layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
15layout (binding = 1) workgroupcoherent buffer B {ivec2 tmp_idx[];};
16layout (binding = 2) workgroupcoherent buffer D {int data_d[];};
17
// Push constants: sort geometry plus the slice of the bitonic network
// (outer stages / inner steps) that this particular dispatch executes.
// The full sort may be split across several dispatches, with tmp_idx
// carrying intermediate state between them.
layout (push_constant) uniform parameter {
    uint ncols;             // number of valid columns per row in data_a / data_d
    uint ncols_padded;      // padded column count; presumably ncols rounded up to a
                            // power of two (bitonic sort requires it) — set by host
    uint ncols_padded_log2; // log2(ncols_padded); total number of outer stages
    uint nrows;             // number of rows to sort
    uint order;             // ASC (0) = ascending indices out; otherwise descending
    uint outer_start;       // first outer stage run by this dispatch
    uint outer_end;         // one past the last outer stage run by this dispatch
    uint inner_start;       // first inner step within the starting outer stage
    uint inner_end;         // bound on inner steps for this dispatch (see argsort)
} p;
29
// Bitonic argsort of one row of data_a: writes the column indices of the row,
// ordered by value, into data_d[row * ncols ...].
//
// The bitonic network over ncols_padded elements is (potentially) split across
// multiple dispatches; p.outer_start/outer_end and p.inner_start/inner_end
// select which compare-swap stages this dispatch performs. tmp_idx holds the
// working (index, floatBitsToInt(value)) pairs between dispatches.
//
// needs_bounds_check: true when ncols < ncols_padded, i.e. padding slots exist;
//     entries with .x >= ncols are then treated as +infinity so they sink to
//     the end of the ascending order.
// row: which row of the input/output to operate on.
void argsort(bool needs_bounds_check, const uint row) {
    // bitonic sort
    // Each invocation handles WG_UNROLL_FACTOR columns: remap the flat global id
    // so that one workgroup covers a contiguous span of
    // BLOCK_SIZE * WG_UNROLL_FACTOR columns (element u is at u*BLOCK_SIZE + col).
    int col = int(gl_GlobalInvocationID.x);
    col = (col % BLOCK_SIZE) + (col / BLOCK_SIZE) * BLOCK_SIZE * WG_UNROLL_FACTOR;

    const uint row_offset = row * p.ncols;        // row base in data_a / data_d
    uint idx_offset = row * p.ncols_padded;       // row base in tmp_idx

    // Barriers are only needed between compare-swap steps that both execute in
    // this dispatch (writes from a previous dispatch are ordered by the dispatch
    // boundary itself), so the first step skips its barrier unless we just
    // initialized tmp_idx below.
    bool need_barrier = false;

    // initialize indices
    // Only the very first slice of the network seeds tmp_idx with
    // (column index, value bits) pairs; padding slots keep whatever
    // data_a holds there and rely on the .x >= ncols bounds check.
    if (p.outer_start == 0 && p.inner_start == 0) {
        [[unroll]] for (int u = 0; u < WG_UNROLL_FACTOR; ++u) {
            uint c = u*BLOCK_SIZE + col;
            if (c < p.ncols_padded) {
                ivec2 v = ivec2(c, floatBitsToInt(data_a[row_offset + c]));
                tmp_idx[idx_offset + c] = v;
            }
        }
        need_barrier = true;
    }

    // Outer stages: stage outer_idx merges bitonic runs of length k = 2^(outer_idx+1).
    // NOTE(review): the inner loop restarts at p.inner_start for *every* outer
    // stage — presumably the host only sets inner_start != 0 when the dispatch
    // covers a single outer stage; confirm against the dispatching code.
    [[unroll]] for (uint outer_idx = p.outer_start, k = (2 << outer_idx); outer_idx < p.outer_end; k *= 2, outer_idx++) {
        // Stage outer_idx has outer_idx+1 inner steps in total.
        uint inner_end = min(p.inner_end, outer_idx + 1);
        // Inner steps: compare-swap partners at XOR distance j = k/2, k/4, ..., 1.
        for (uint j = k >> (p.inner_start + 1), inner_idx = p.inner_start; inner_idx < inner_end; j /= 2, inner_idx++) {
            if (need_barrier) {
                // Make the previous step's tmp_idx writes visible workgroup-wide
                // before reading partners written by other invocations.
                controlBarrier(gl_ScopeWorkgroup, gl_ScopeWorkgroup, gl_StorageSemanticsBuffer, gl_SemanticsAcquireRelease);
            }
            need_barrier = true;
            [[unroll]] for (int u = 0; u < WG_UNROLL_FACTOR; ++u) {
                int c = u*BLOCK_SIZE + col;
                const int ixj = int(c ^ j);

                // Each unordered pair (c, c^j) is handled exactly once, by the
                // element with the smaller index — so the swap below is race-free.
                if (ixj < c) {
                    continue;
                }

                // Direction of the compare depends on which half of the k-sized
                // run c falls in: idx_0 should end up with the smaller value.
                int idx_0 = (c & k) == 0 ? c : ixj;
                int idx_1 = (c & k) == 0 ? ixj : c;

                ivec2 sh_idx_0 = tmp_idx[idx_offset + idx_0];
                ivec2 sh_idx_1 = tmp_idx[idx_offset + idx_1];
                // Padding entries (.x >= ncols) compare as greater than everything.
                bool idx_0_oob = needs_bounds_check ? sh_idx_0.x >= p.ncols : false;
                bool idx_1_oob = needs_bounds_check ? sh_idx_1.x >= p.ncols : false;

                // Swap so that the ascending invariant holds for this pair.
                if ((idx_0_oob ||
                    (!idx_1_oob && intBitsToFloat(sh_idx_0.y) > intBitsToFloat(sh_idx_1.y)))) {
                    tmp_idx[idx_offset + idx_0] = sh_idx_1;
                    tmp_idx[idx_offset + idx_1] = sh_idx_0;
                }
            }
        }
    }

    // Only the dispatch that completes the final stage of the network writes
    // the result out. tmp_idx is sorted ascending by value; for descending
    // order the indices are simply written back in reverse.
    if (p.outer_end == p.ncols_padded_log2 &&
        p.inner_end >= p.ncols_padded_log2 + 1) {
        controlBarrier(gl_ScopeWorkgroup, gl_ScopeWorkgroup, gl_StorageSemanticsBuffer, gl_SemanticsAcquireRelease);
        [[unroll]] for (int u = 0; u < WG_UNROLL_FACTOR; ++u) {
            uint c = u*BLOCK_SIZE + col;
            if (c < p.ncols) {
                if (p.order == ASC) {
                    data_d[row_offset + c] = tmp_idx[idx_offset + c].x;
                } else {
                    data_d[row_offset + p.ncols - c - 1] = tmp_idx[idx_offset + c].x;
                }
            }
        }
    }
}
99
// Entry point: grid-Y-strided loop over rows, each workgroup sorting one row
// at a time. The padded/unpadded decision is hoisted out of the loop so that
// argsort() is called with a literal bool in each branch, letting the compiler
// specialize away the bounds checks when no padding exists.
void main() {
    const uint stride = gl_WorkGroupSize.y * gl_NumWorkGroups.y;

    if (p.ncols != p.ncols_padded) {
        // Padding present: padded slots must be filtered during the sort.
        for (uint r = gl_WorkGroupID.y; r < p.nrows; r += stride) {
            argsort(true, r);
        }
    } else {
        // Exact power-of-two width: no bounds checks needed.
        for (uint r = gl_WorkGroupID.y; r < p.nrows; r += stride) {
            argsort(false, r);
        }
    }
}