diff options
| author | Mitja Felicijan <mitja.felicijan@gmail.com> | 2026-02-12 20:57:17 +0100 |
|---|---|---|
| committer | Mitja Felicijan <mitja.felicijan@gmail.com> | 2026-02-12 20:57:17 +0100 |
| commit | b333b06772c89d96aacb5490d6a219fba7c09cc6 (patch) | |
| tree | 211df60083a5946baa2ed61d33d8121b7e251b06 /llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp | |
| download | llmnpc-b333b06772c89d96aacb5490d6a219fba7c09cc6.tar.gz | |
Engage!
Diffstat (limited to 'llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp')
| -rw-r--r-- | llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp | 86 |
1 files changed, 86 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp b/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp new file mode 100644 index 0000000..0fc2b9b --- /dev/null +++ b/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp @@ -0,0 +1,86 @@ +#version 450 +#extension GL_EXT_control_flow_attributes : enable + +#include "types.glsl" + +layout(constant_id = 0) const int BLOCK_SIZE = 1024; +layout(constant_id = 1) const int NCOLS_PADDED_LOG2 = 10; +#define ASC 0 + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 2) writeonly buffer D {int data_d[];}; + +layout (push_constant) uniform parameter { + uint ncols; + uint ncols_padded; + uint ncols_padded_log2; + uint nrows; + uint order; + uint outer_start; + uint outer_end; + uint inner_start; + uint inner_end; +} p; + +shared ivec2 dst_row[BLOCK_SIZE]; + +void argsort(bool needs_bounds_check, const uint row) { + // bitonic sort + const int col = int(gl_LocalInvocationID.x); + + const uint row_offset = row * p.ncols; + + // initialize indices + dst_row[col] = ivec2(col, floatBitsToInt(data_a[row_offset + col])); + barrier(); + + uint num_outer_loop_iters = NCOLS_PADDED_LOG2; + [[unroll]] for (uint k = 2, outer_idx = 0; outer_idx < num_outer_loop_iters; k *= 2, outer_idx++) { + uint num_inner_loop_iters = outer_idx + 1; + [[unroll]] for (uint j = k / 2, inner_idx = 0; inner_idx < num_inner_loop_iters; j /= 2, inner_idx++) { + const int ixj = int(col ^ j); + + int idx_0 = (col & k) == 0 ? col : ixj; + int idx_1 = (col & k) == 0 ? ixj : col; + + ivec2 sh_idx_0 = dst_row[idx_0]; + ivec2 sh_idx_1 = dst_row[idx_1]; + bool idx_0_oob = needs_bounds_check ? sh_idx_0.x >= p.ncols : false; + bool idx_1_oob = needs_bounds_check ? sh_idx_1.x >= p.ncols : false; + + if ((idx_0_oob || + (!idx_1_oob && intBitsToFloat(sh_idx_0.y) > intBitsToFloat(sh_idx_1.y))) && (ixj > col)) { + dst_row[idx_0] = sh_idx_1; + dst_row[idx_1] = sh_idx_0; + } + + barrier(); + } + } + + if (col < p.ncols) { + if (p.order == ASC) { + data_d[row_offset + col] = dst_row[col].x; + } else { + data_d[row_offset + p.ncols - col - 1] = dst_row[col].x; + } + } +} + +void main() { + if (p.ncols == BLOCK_SIZE) { + uint row = gl_WorkGroupID.y; + while (row < p.nrows) { + argsort(false, row); + row += gl_WorkGroupSize.y * gl_NumWorkGroups.y; + } + } else { + uint row = gl_WorkGroupID.y; + while (row < p.nrows) { + argsort(true, row); + row += gl_WorkGroupSize.y * gl_NumWorkGroups.y; + } + } +} |
