aboutsummaryrefslogtreecommitdiff
path: root/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl
diff options
context:
space:
mode:
Diffstat (limited to 'llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl')
-rw-r--r--llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl134
1 files changed, 134 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl
new file mode 100644
index 0000000..9a77f6e
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl
@@ -0,0 +1,134 @@
// Values that define the ordering. In this shader they are only read, by
// take_left, at row_base + index.
1@group(0) @binding(0)
2var<storage, read_write> src: array<f32>;
3
// Input index runs. main merges two adjacent runs of up to params.len entries
// each; the merge presumes each run is already ordered consistently with
// take_left (NOTE(review): produced by an earlier pass - confirm with the host code).
4@group(0) @binding(1)
5var<storage, read_write> idx_in: array<i32>;
6
// Output indices: main writes the merged run here at row_out + start + k.
7@group(0) @binding(2)
8var<storage, read_write> idx_out: array<i32>;
9
// Uniform parameters for one merge pass; read by main through `params`.
10struct Params {
11 offset_src: u32, // in elements; base offset added to every row of src
12 offset_in: u32, // in elements; base offset added to every row of idx_in
13 offset_out: u32, // in elements; base offset added to every row of idx_out
14
15 stride_src1: u32, // src stride multiplied by row coordinate i1
16 stride_src2: u32, // src stride multiplied by row coordinate i2
17 stride_src3: u32, // src stride multiplied by row coordinate i3
18
19 stride_idx1: u32, // idx_in stride multiplied by row coordinate i1
20 stride_idx2: u32, // idx_in stride multiplied by row coordinate i2
21 stride_idx3: u32, // idx_in stride multiplied by row coordinate i3
22
23 stride_out1: u32, // idx_out stride multiplied by row coordinate i1
24 stride_out2: u32, // idx_out stride multiplied by row coordinate i2
25 stride_out3: u32, // idx_out stride multiplied by row coordinate i3
26
27 ne0: u32, // row length; both runs of a merge are clipped to it
28 ne1: u32, // extent used to split the flat row number into i1
29 ne2: u32, // extent used (with ne1) to split the flat row number into i2 / i3
30
31 top_k: u32, // only merged positions k < top_k are written
32
33 len: u32, // length of each input run; one merge covers 2 * len positions
34 nm: u32, // number of merges per row (flat workgroup id = row * nm + merge)
35 nrows: u32 // number of rows; workgroups with id >= nm * nrows exit early
36};
37
// Uniform instance of Params consumed by main.
38@group(0) @binding(3)
39var<uniform> params: Params;
40
// Decides whether the left candidate is emitted before the right one.
//   a_idx, b_idx: indices (relative to row_base) of the two candidates in src
//   row_base:     element offset of the current row in src
// With ORDER == 0 the left candidate wins when its value is not larger than
// the right one (ascending); otherwise when it is not smaller (descending).
// Equal values always favour the left candidate.
fn take_left(a_idx: i32, b_idx: i32, row_base: u32) -> bool {
    let lhs = src[row_base + u32(a_idx)];
    let rhs = src[row_base + u32(b_idx)];
#if ORDER == 0
    let left_first = rhs >= lhs;
#else
    let left_first = rhs <= lhs;
#endif
    return left_first;
}
50
// Merges two adjacent runs of idx_in into idx_out for one row.
//
// Each workgroup handles one (row, merge) pair: the left run starts at
// `start` and the right run at `start + params.len`; both are clipped to the
// row length params.ne0. The merged output positions [0, total) are split
// into contiguous chunks, one per invocation. An invocation first binary
// searches for how many left-run elements precede its first output position,
// then merges sequentially from there. Only positions below params.top_k are
// written.
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(workgroup_id) wid: vec3<u32>,
        @builtin(num_workgroups) num_wg: vec3<u32>,
        @builtin(local_invocation_id) lid: vec3<u32>) {
    let linear = wid.x + wid.y * num_wg.x;
    // guard against overprovisioned workgroups
    if (linear >= params.nm * params.nrows) {
        return;
    }

    let start = (linear % params.nm) * params.len * 2u;
    // Clip both runs to the row. select() keeps a wrapped u32 difference from
    // being used as a length when the run starts at or beyond ne0.
    let rem0 = select(0u, params.ne0 - start, params.ne0 > start);
    let len0 = min(params.len, rem0);
    let rem1 = select(0u, params.ne0 - (start + params.len), params.ne0 > (start + params.len));
    let len1 = min(params.len, rem1);
    let total = len0 + len1;
    // Output positions handled per invocation (ceil division).
    let chunk = (total + WG_SIZE - 1u) / WG_SIZE;
    let k0 = lid.x * chunk;
    let k1 = min(min(k0 + chunk, total), params.top_k);
    // guard against overprovisioned threads
    if (k0 >= params.top_k || k0 >= total) {
        return;
    }

    // Split the flat row number into its three coordinates.
    var row = linear / params.nm;
    let i3 = row / (params.ne2 * params.ne1);
    row = row % (params.ne2 * params.ne1);
    let i2 = row / params.ne1;
    let i1 = row % params.ne1;

    let row_src = params.offset_src +
                  i1 * params.stride_src1 +
                  i2 * params.stride_src2 +
                  i3 * params.stride_src3;

    let row_in = params.offset_in +
                 i1 * params.stride_idx1 +
                 i2 * params.stride_idx2 +
                 i3 * params.stride_idx3;

    let row_out = params.offset_out +
                  i1 * params.stride_out1 +
                  i2 * params.stride_out2 +
                  i3 * params.stride_out3;

    // First element of the left and right run inside idx_in.
    let left_base = row_in + start;
    let right_base = left_base + params.len;

    // Binary search for the number of left-run elements among the first k0
    // merged outputs. The bounds keep both mid and k0 - mid - 1 inside their
    // runs: mid < min(k0, len0) and k0 - mid - 1 < len1.
    var low: u32 = select(0u, k0 - len1, k0 > len1);
    var high: u32 = min(k0, len0);

    while (low < high) {
        let mid = (low + high) >> 1u;
        let idx0 = idx_in[left_base + mid];
        let idx1 = idx_in[right_base + (k0 - mid - 1u)];
        if (take_left(idx0, idx1, row_src)) {
            low = mid + 1u;
        } else {
            high = mid;
        }
    }

    // Sequential merge of output positions [k0, k1).
    var i = low;       // next unread element of the left run
    var j = k0 - i;    // next unread element of the right run
    var k = k0;        // next output position
    while (k < k1) {
        var take_l = false;
        var out_idx: i32 = 0;
        if (i >= len0) {
            // Left run exhausted: only the right head is loaded.
            out_idx = idx_in[right_base + j];
        } else if (j >= len1) {
            // Right run exhausted: only the left head is loaded, so nothing
            // is read past the end of the right run.
            take_l = true;
            out_idx = idx_in[left_base + i];
        } else {
            // Both heads are loaded once and reused for the output value.
            let idx0 = idx_in[left_base + i];
            let idx1 = idx_in[right_base + j];
            take_l = take_left(idx0, idx1, row_src);
            out_idx = select(idx1, idx0, take_l);
        }

        idx_out[row_out + start + k] = out_idx;
        if (take_l) {
            i += 1u;
        } else {
            j += 1u;
        }
        k += 1u;
    }
}