summaryrefslogtreecommitdiff
path: root/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl
diff options
context:
space:
mode:
Diffstat (limited to 'llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl')
-rw-r--r--llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl106
1 files changed, 106 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl
new file mode 100644
index 0000000..46ed19f
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl
@@ -0,0 +1,106 @@
+// Input values to be ranked. The entry point in this file only reads from
+// this buffer (indexed in f32 elements), although it is declared read_write.
+@group(0) @binding(0)
+var<storage, read_write> src: array<f32>;
+
+// Output buffer: receives, per row, the source element indices (stored as
+// i32) in sorted order for each tile.
+@group(0) @binding(1)
+var<storage, read_write> dst: array<i32>;
+
+// Uniform parameters for the per-tile argsort pass.
+// NOTE(review): field order/types presumably mirror a host-side struct —
+// confirm against the code that fills this uniform buffer before reordering.
+struct Params {
+ offset_src: u32, // base offset into src, in elements
+ offset_dst: u32, // base offset into dst, in elements
+
+ // Strides (in elements) applied to the row coordinates i1, i2, i3 when
+ // locating a row in src.
+ stride_src1: u32,
+ stride_src2: u32,
+ stride_src3: u32,
+
+ // Strides (in elements) applied to i1, i2, i3 when locating a row in dst.
+ stride_dst1: u32,
+ stride_dst2: u32,
+ stride_dst3: u32,
+
+ // src/dst dimensions
+ // src_ne0: number of valid elements in a source row; indices >= src_ne0
+ // are treated as padding by the sort.
+ src_ne0: u32,
+ // ne1, ne2: extents used to split the linear row number into (i1, i2, i3).
+ ne1: u32,
+ ne2: u32,
+
+ // ne0: upper bound on the per-row output position written to dst.
+ ne0: u32,
+ // top_k: number of leading sorted entries each tile writes to dst.
+ top_k: u32,
+
+ npr: u32, // tiles per row
+ // nrows: number of rows; npr * nrows is the count of workgroups that do work.
+ nrows: u32
+};
+
+// Uniform block holding the Params described above.
+@group(0) @binding(2)
+var<uniform> params: Params;
+
+// Workgroup-shared scratch: one source index per invocation. The sort
+// permutes these indices; the values themselves are re-read from src.
+var<workgroup> shmem_idx: array<u32, WG_SIZE>;
+
+// Compile-time sort direction, selected by the ORDER define:
+//   ORDER == 0 : "up" comparator is '>' and padding slots take the value 1e30
+//   otherwise  : "up" comparator is '<' and padding slots take the value -1e30
+// In both cases the padding value is the one that the "up" comparator moves
+// towards the end of the tile. WG_SIZE and ORDER are expected to be supplied
+// by the shader preprocessor before compilation.
+#if ORDER == 0
+#define EXTREME_VALUE 1e30
+#define SWAP_COMPARE_UP >
+#define SWAP_COMPARE_DOWN <
+#else
+#define EXTREME_VALUE -1e30
+#define SWAP_COMPARE_UP <
+#define SWAP_COMPARE_DOWN >
+#endif
+
+// Per-tile bitonic argsort.
+//
+// Each workgroup handles one tile of WG_SIZE consecutive elements of one row:
+//   1. every invocation seeds shared memory with its own element index
+//      (or the padding marker params.src_ne0 when past the end of the row),
+//   2. the workgroup runs a bitonic sorting network over those indices,
+//      comparing the src values they refer to,
+//   3. the first params.top_k sorted indices of the tile are written to dst at
+//      position tile * top_k + lane within the destination row.
+//
+// Padding slots always sort to the end of the tile, independent of the values
+// stored in src. Ordering them via the sentinel value alone is not enough:
+// real inputs beyond EXTREME_VALUE (e.g. +/-inf) would sort after the padding,
+// so the out-of-range padding index would be emitted and a real index dropped.
+//
+// Assumes WG_SIZE is a power of two: the partner lane `lid.x ^ j` must stay
+// below WG_SIZE for the shared-memory accesses to be in range.
+@compute @workgroup_size(WG_SIZE)
+fn main(@builtin(workgroup_id) wid: vec3<u32>,
+        @builtin(num_workgroups) num_wg: vec3<u32>,
+        @builtin(local_invocation_id) lid: vec3<u32>) {
+    let linear = wid.x + wid.y * num_wg.x;
+    // guard against overprovisioned workgroups
+    // (the condition depends only on workgroup-wide values, so a whole
+    // workgroup exits together and the barriers below stay uniform)
+    if (linear >= params.npr * params.nrows) {
+        return;
+    }
+
+    // Split the linear workgroup number into (tile within row, row), then the
+    // row into its (i1, i2, i3) coordinates.
+    let tile = linear % params.npr;
+    var row = linear / params.npr;
+    let i3 = row / (params.ne2 * params.ne1);
+    row = row % (params.ne2 * params.ne1);
+    let i2 = row / params.ne1;
+    let i1 = row % params.ne1;
+
+    let row_base = params.offset_src +
+        i1 * params.stride_src1 +
+        i2 * params.stride_src2 +
+        i3 * params.stride_src3;
+
+    // Seed: lane -> element index, or src_ne0 as the padding marker.
+    let tile_base = tile * WG_SIZE;
+    let idx = tile_base + lid.x;
+    shmem_idx[lid.x] = select(params.src_ne0, idx, idx < params.src_ne0);
+    workgroupBarrier();
+
+    // Bitonic network: k is the size of the sequences being merged, j the
+    // compare-exchange distance within the current merge step.
+    var k = 2u;
+    while (k <= WG_SIZE) {
+        var j = k >> 1;
+        while (j > 0) {
+            let ixj = lid.x ^ j;
+            // Only the lower lane of each pair performs the exchange.
+            if (ixj > lid.x) {
+                let dir_up = (lid.x & k) == 0;
+                let a_idx = shmem_idx[lid.x];
+                let b_idx = shmem_idx[ixj];
+                let a_pad = a_idx >= params.src_ne0;
+                let b_pad = b_idx >= params.src_ne0;
+                let a_val = select(EXTREME_VALUE, src[row_base + a_idx], !a_pad);
+                let b_val = select(EXTREME_VALUE, src[row_base + b_idx], !b_pad);
+
+                // Values are only compared when both slots hold real elements;
+                // a padding slot is always ordered after a real one.
+                let both_real = !a_pad && !b_pad;
+                let a_after_b = (a_pad && !b_pad) ||
+                    (both_real && (a_val SWAP_COMPARE_UP b_val));
+                let b_after_a = (b_pad && !a_pad) ||
+                    (both_real && (a_val SWAP_COMPARE_DOWN b_val));
+                let should_swap = select(b_after_a, a_after_b, dir_up);
+                if (should_swap) {
+                    shmem_idx[lid.x] = b_idx;
+                    shmem_idx[ixj] = a_idx;
+                }
+            }
+            // All exchanges of this step must land before the next step reads.
+            workgroupBarrier();
+            j >>= 1;
+        }
+        k <<= 1;
+    }
+
+    // Emit the leading top_k entries of this tile into the destination row.
+    let out_idx = tile * params.top_k + lid.x;
+    if (out_idx < params.ne0 && lid.x < params.top_k) {
+        let row_dst = params.offset_dst +
+            i1 * params.stride_dst1 +
+            i2 * params.stride_dst2 +
+            i3 * params.stride_dst3;
+        dst[row_dst + out_idx] = i32(shmem_idx[lid.x]);
+    }
+}