1#version 450
  2#extension GL_EXT_control_flow_attributes : enable
  3#extension GL_KHR_memory_scope_semantics : enable
  4#pragma use_vulkan_memory_model
  5
  6#include "types.glsl"
  7
// Specialization constants: BLOCK_SIZE is the workgroup width in x
// (wired to local_size_x_id = 0 below); WG_UNROLL_FACTOR is how many
// elements each invocation handles per pass.
layout(constant_id = 0) const int BLOCK_SIZE = 1024;
layout(constant_id = 1) const int WG_UNROLL_FACTOR = 2;
// p.order value meaning ascending sort; any other value sorts descending.
#define ASC 0

layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

// Values to sort: p.nrows rows of p.ncols elements each.
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
// Scratch state: per element an ivec2 of (source index, floatBitsToInt(value)),
// p.ncols_padded entries per row; persists between split dispatches.
layout (binding = 1) workgroupcoherent buffer B {ivec2 tmp_idx[];};
// Output: sorted source indices, p.ncols per row.
layout (binding = 2) workgroupcoherent buffer D {int data_d[];};

layout (push_constant) uniform parameter {
    uint ncols;             // valid columns per row
    uint ncols_padded;      // presumably ncols rounded up to a power of two — TODO confirm with host code
    uint ncols_padded_log2; // log2(ncols_padded)
    uint nrows;
    uint order;             // ASC (0) = ascending, otherwise descending
    // The bitonic network can be split across multiple dispatches; these
    // select the half-open range of outer stages and inner passes this
    // dispatch executes.
    uint outer_start;
    uint outer_end;
    uint inner_start;
    uint inner_end;
} p;
 29
// Sorts one row with a bitonic sorting network over p.ncols_padded slots,
// writing the sorted source indices to data_d. Each invocation owns
// WG_UNROLL_FACTOR elements spaced BLOCK_SIZE apart. Element state lives
// in tmp_idx as (source index, value bits) pairs so the network can be
// resumed across dispatches via p.outer/inner_start/end.
// needs_bounds_check: true when ncols != ncols_padded, i.e. padding slots
// exist and must be kept out of the final output.
void argsort(bool needs_bounds_check, const uint row) {
    // bitonic sort
    int col = int(gl_GlobalInvocationID.x);
    // Remap the global id so this invocation's elements are
    // u*BLOCK_SIZE + col for u = 0..WG_UNROLL_FACTOR-1.
    col = (col % BLOCK_SIZE) + (col / BLOCK_SIZE) * BLOCK_SIZE * WG_UNROLL_FACTOR;

    const uint row_offset = row * p.ncols;   // row base in data_a / data_d
    uint idx_offset = row * p.ncols_padded;  // row base in tmp_idx

    bool need_barrier = false;

    // initialize indices
    // Only the first dispatch of the network (stage 0, pass 0) seeds tmp_idx.
    // NOTE(review): for c in [ncols, ncols_padded) this still reads
    // data_a[row_offset + c]; assumes the input buffer is sized so that read
    // is in bounds — the loaded value itself is never used for padding slots
    // (they are keyed off .x >= ncols below). TODO confirm with host code.
    if (p.outer_start == 0 && p.inner_start == 0) {
        [[unroll]] for (int u = 0; u < WG_UNROLL_FACTOR; ++u) {
            uint c = u*BLOCK_SIZE + col;
            if (c < p.ncols_padded) {
                ivec2 v = ivec2(c, floatBitsToInt(data_a[row_offset + c]));
                tmp_idx[idx_offset + c] = v;
            }
        }
        need_barrier = true;
    }

    // Outer stages: k = 2 << outer_idx is the sorted-run size being built.
    // Inner passes: j is the compare distance, halved every pass; the start
    // values of k and j resume mid-network when this is a follow-up dispatch.
    [[unroll]] for (uint outer_idx = p.outer_start, k = (2 << outer_idx); outer_idx < p.outer_end; k *= 2, outer_idx++) {
        uint inner_end = min(p.inner_end, outer_idx + 1); // stage outer_idx has outer_idx+1 passes
        for (uint j = k >> (p.inner_start + 1), inner_idx = p.inner_start; inner_idx < inner_end; j /= 2, inner_idx++) {
            // Make the previous pass's exchanges visible before reading.
            // Skipped only on the very first pass of a dispatch that did not
            // run the init loop above (nothing written in this dispatch yet).
            if (need_barrier) {
                controlBarrier(gl_ScopeWorkgroup, gl_ScopeWorkgroup, gl_StorageSemanticsBuffer, gl_SemanticsAcquireRelease);
            }
            need_barrier = true;
            [[unroll]] for (int u = 0; u < WG_UNROLL_FACTOR; ++u) {
                int c = u*BLOCK_SIZE + col;
                const int ixj = int(c ^ j);

                // Each pair (c, ixj) is processed once, by its lower index.
                if (ixj < c) {
                    continue;
                }

                // Compare-exchange direction: ascending when (c & k) == 0,
                // descending otherwise; idx_0 receives the smaller element.
                int idx_0 = (c & k) == 0 ? c : ixj;
                int idx_1 = (c & k) == 0 ? ixj : c;

                ivec2 sh_idx_0 = tmp_idx[idx_offset + idx_0];
                ivec2 sh_idx_1 = tmp_idx[idx_offset + idx_1];
                // Padding slots carry a source index >= ncols; treat them as
                // +infinity so they sink to the end of the row.
                bool idx_0_oob = needs_bounds_check ? sh_idx_0.x >= p.ncols : false;
                bool idx_1_oob = needs_bounds_check ? sh_idx_1.x >= p.ncols : false;

                // Swap when out of order (value compare is done on the
                // original float bit patterns converted back to float).
                if ((idx_0_oob ||
                    (!idx_1_oob && intBitsToFloat(sh_idx_0.y) > intBitsToFloat(sh_idx_1.y)))) {
                    tmp_idx[idx_offset + idx_0] = sh_idx_1;
                    tmp_idx[idx_offset + idx_1] = sh_idx_0;
                }
            }
        }
    }

    // If this dispatch completed the final stage of the network, publish the
    // sorted source indices. The network sorts ascending; for descending
    // output the destination positions are mirrored.
    if (p.outer_end == p.ncols_padded_log2 &&
        p.inner_end >= p.ncols_padded_log2 + 1) {
        controlBarrier(gl_ScopeWorkgroup, gl_ScopeWorkgroup, gl_StorageSemanticsBuffer, gl_SemanticsAcquireRelease);
        [[unroll]] for (int u = 0; u < WG_UNROLL_FACTOR; ++u) {
            uint c = u*BLOCK_SIZE + col;
            if (c < p.ncols) {
                if (p.order == ASC) {
                    data_d[row_offset + c] = tmp_idx[idx_offset + c].x;
                } else {
                    data_d[row_offset + p.ncols - c - 1] = tmp_idx[idx_offset + c].x;
                }
            }
        }
    }
}
 99
void main() {
    // Workgroups are laid out along y; each one walks rows with a stride of
    // the total y extent. The padding condition is uniform, so branch on it
    // once: after inlining, each argsort specialization sees a constant
    // bounds-check flag and the unpadded path compiles without range tests.
    if (p.ncols != p.ncols_padded) {
        for (uint r = gl_WorkGroupID.y; r < p.nrows; r += gl_WorkGroupSize.y * gl_NumWorkGroups.y) {
            argsort(true, r);
        }
    } else {
        for (uint r = gl_WorkGroupID.y; r < p.nrows; r += gl_WorkGroupSize.y * gl_NumWorkGroups.y) {
            argsort(false, r);
        }
    }
}