summaryrefslogtreecommitdiff
path: root/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl
diff options
context:
space:
mode:
authorMitja Felicijan <mitja.felicijan@gmail.com>2026-02-12 20:57:17 +0100
committerMitja Felicijan <mitja.felicijan@gmail.com>2026-02-12 20:57:17 +0100
commitb333b06772c89d96aacb5490d6a219fba7c09cc6 (patch)
tree211df60083a5946baa2ed61d33d8121b7e251b06 /llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl
downloadllmnpc-b333b06772c89d96aacb5490d6a219fba7c09cc6.tar.gz
Engage!
Diffstat (limited to 'llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl')
-rw-r--r--llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl134
1 files changed, 134 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl
new file mode 100644
index 0000000..9a77f6e
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl
@@ -0,0 +1,134 @@
+@group(0) @binding(0)
+var<storage, read_write> src: array<f32>;
+
+@group(0) @binding(1)
+var<storage, read_write> idx_in: array<i32>;
+
+@group(0) @binding(2)
+var<storage, read_write> idx_out: array<i32>;
+
+struct Params {
+ offset_src: u32, // in elements
+ offset_in: u32, // in elements
+ offset_out: u32, // in elements
+
+ stride_src1: u32,
+ stride_src2: u32,
+ stride_src3: u32,
+
+ stride_idx1: u32,
+ stride_idx2: u32,
+ stride_idx3: u32,
+
+ stride_out1: u32,
+ stride_out2: u32,
+ stride_out3: u32,
+
+ ne0: u32,
+ ne1: u32,
+ ne2: u32,
+
+ top_k: u32,
+
+ len: u32,
+ nm: u32,
+ nrows: u32
+};
+
+@group(0) @binding(3)
+var<uniform> params: Params;
+
+fn take_left(a_idx: i32, b_idx: i32, row_base: u32) -> bool {
+ let a_val = src[row_base + u32(a_idx)];
+ let b_val = src[row_base + u32(b_idx)];
+#if ORDER == 0
+ return a_val <= b_val;
+#else
+ return a_val >= b_val;
+#endif
+}
+
+@compute @workgroup_size(WG_SIZE)
+fn main(@builtin(workgroup_id) wid: vec3<u32>,
+ @builtin(num_workgroups) num_wg: vec3<u32>,
+ @builtin(local_invocation_id) lid: vec3<u32>) {
+ let linear = wid.x + wid.y * num_wg.x;
+ // guard against overprovisioned workgroups
+ if (linear >= params.nm * params.nrows) {
+ return;
+ }
+
+ let start = (linear % params.nm) * params.len * 2;
+ let len0 = min(params.len, params.ne0 - start);
+ let rem1 = select(0, params.ne0 - (start + params.len), params.ne0 > (start + params.len));
+ let len1 = min(params.len, rem1);
+ let total = len0 + len1;
+ let chunk = (total + WG_SIZE - 1u) / WG_SIZE;
+ let k0 = lid.x * chunk;
+ let k1 = min(min(k0 + chunk, total), params.top_k);
+ // guard against overprovisioned threads
+ if (k0 >= params.top_k || k0 >= total) {
+ return;
+ }
+
+ var row = linear / params.nm;
+ let i3 = row / (params.ne2 * params.ne1);
+ row = row % (params.ne2 * params.ne1);
+ let i2 = row / params.ne1;
+ let i1 = row % params.ne1;
+
+ let row_src = params.offset_src +
+ i1 * params.stride_src1 +
+ i2 * params.stride_src2 +
+ i3 * params.stride_src3;
+
+ let row_in = params.offset_in +
+ i1 * params.stride_idx1 +
+ i2 * params.stride_idx2 +
+ i3 * params.stride_idx3;
+
+ let row_out = params.offset_out +
+ i1 * params.stride_out1 +
+ i2 * params.stride_out2 +
+ i3 * params.stride_out3;
+
+
+ var low: u32 = select(0, k0 - len1, k0 > len1);
+ var high: u32 = min(k0, len0);
+
+ while (low < high) {
+ let mid = (low + high) >> 1;
+ let idx0 = idx_in[row_in + start + mid];
+ let idx1 = idx_in[row_in + start + params.len + (k0 - mid - 1)];
+ if (take_left(idx0, idx1, row_src)) {
+ low = mid + 1;
+ } else {
+ high = mid;
+ }
+ }
+
+ var i = low;
+ var j = k0 - i;
+ var k = k0;
+ while (k < k1) {
+ var take_l = false;
+ if (i >= len0) {
+ take_l = false;
+ } else if (j >= len1) {
+ take_l = true;
+ } else {
+ let idx0 = idx_in[row_in + start + i];
+ let idx1 = idx_in[row_in + start + params.len + j];
+ take_l = take_left(idx0, idx1, row_src);
+ }
+
+ let out_idx = select(
+ idx_in[row_in + start + params.len + j],
+ idx_in[row_in + start + i],
+ take_l);
+ idx_out[row_out + start + k] = out_idx;
+ i = select(i, i + 1, take_l);
+ j = select(j + 1, j, take_l);
+ k += 1;
+ }
+}