summaryrefslogtreecommitdiff
path: root/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl
diff options
context:
space:
mode:
authorMitja Felicijan <mitja.felicijan@gmail.com>2026-02-12 20:57:17 +0100
committerMitja Felicijan <mitja.felicijan@gmail.com>2026-02-12 20:57:17 +0100
commitb333b06772c89d96aacb5490d6a219fba7c09cc6 (patch)
tree211df60083a5946baa2ed61d33d8121b7e251b06 /llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl
downloadllmnpc-b333b06772c89d96aacb5490d6a219fba7c09cc6.tar.gz
Engage!
Diffstat (limited to 'llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl')
-rw-r--r--llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl72
1 files changed, 72 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl
new file mode 100644
index 0000000..ca5bfcc
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl
@@ -0,0 +1,72 @@
+// Binding 0: the input values searched for their maximum. This shader only reads
+// from it. When the VEC4 variant is compiled the buffer is viewed as vec4<f32>
+// elements (four f32 values per array element), otherwise as plain f32 values;
+// VEC_SIZE records how many f32 values one array element holds.
+@group(0) @binding(0)
+#ifdef VEC4
+var<storage, read_write> src: array<vec4<f32>>;
+#define VEC_SIZE 4
+#else
+var<storage, read_write> src: array<f32>;
+#define VEC_SIZE 1
+#endif
+
+// Binding 1: output buffer. Each workgroup stores one i32 (the winning column
+// index of its row) at params.offset_dst + workgroup_id.x.
+@group(0) @binding(1)
+var<storage, read_write> dst: array<i32>;
+
+// Per-dispatch parameters read from the uniform buffer at binding 2.
+struct Params {
+ offset_src: u32, // start of the first row in src, counted in f32 elements (the VEC4 path divides it by VEC_SIZE)
+ offset_dst: u32, // start of the output region in dst, counted in i32 elements
+ ne0: u32, // number of f32 values per row; also the stride between consecutive rows
+};
+
+// Binding 2: uniform block carrying the buffer offsets and the row length.
+@group(0) @binding(2)
+var<uniform> params: Params;
+
+// Seed for every invocation's running maximum. It must not exceed any value the
+// input can hold, otherwise the seed itself (paired with index -1) would win the
+// comparison; so it is the lowest finite f32 rather than an arbitrary "very
+// negative" number.
+const FLOAT_MIN: f32 = -3.40282346638528859812e+38;
+
+// A candidate maximum: the value together with the column it was found at.
+// main seeds candidates with index -1 before any element has been accepted.
+struct Pair {
+ value: f32,
+ index: i32
+};
+
+// Workgroup scratch space: one candidate per invocation, combined pairwise in main.
+// WG_SIZE is not defined in this file; it has to be supplied when the shader is built.
+var<workgroup> shared_max: array<Pair, WG_SIZE>;
+
+// Argmax over one row per workgroup.
+//
+// Workgroup `wid.x` scans the ne0 values that start at offset_src + wid.x * ne0 and
+// writes the column index of the largest one to dst[offset_dst + wid.x]. When the
+// maximum occurs more than once, the largest column index is reported. If no value
+// is accepted (ne0 == 0, or nothing in the row compares as a number), -1 is written.
+//
+// A candidate whose index is negative means "nothing accepted yet". Emptiness is
+// decided from the index, never from the seed value, so rows whose values lie
+// below FLOAT_MIN still produce a valid column instead of -1.
+//
+// NOTE(review): the pairwise reduction halves `offset` each step, which only covers
+// every slot when WG_SIZE is a power of two - confirm against the host code.
+// NOTE(review): the VEC4 path visits ne0 / VEC_SIZE whole vectors and indexes src
+// at row_idx / VEC_SIZE, so it presumably requires ne0 and the row start to be
+// multiples of 4 - confirm the host only selects VEC4 in that case.
+@compute @workgroup_size(WG_SIZE)
+fn main(@builtin(workgroup_id) wid: vec3<u32>,
+        @builtin(local_invocation_id) lid: vec3<u32>) {
+    let row_idx = params.offset_src + wid.x * params.ne0;
+    var local_pair = Pair(FLOAT_MIN, -1);
+
+    // Phase 1: every invocation walks the row with a stride of WG_SIZE and keeps its
+    // own best candidate. `>=` lets a later equal value replace an earlier one; the
+    // second clause lets a still-empty candidate accept a value below the seed.
+    // NaN fails both comparisons and is therefore never accepted.
+#ifdef VEC4
+    for (var col = lid.x; col < params.ne0/VEC_SIZE; col += WG_SIZE) {
+        let vec_val = src[row_idx / VEC_SIZE + col];
+        for (var v = 0u; v < VEC_SIZE; v++) {
+            let val = vec_val[v];
+            if (val >= local_pair.value || (local_pair.index < 0 && val < local_pair.value)) {
+                local_pair = Pair(val, i32(col * VEC_SIZE + v));
+            }
+        }
+    }
+#else
+    for (var col = lid.x; col < params.ne0; col += WG_SIZE) {
+        // Load once and reuse for both the comparison and the stored candidate.
+        let val = src[row_idx + col];
+        if (val >= local_pair.value || (local_pair.index < 0 && val < local_pair.value)) {
+            local_pair = Pair(val, i32(col));
+        }
+    }
+#endif
+
+    shared_max[lid.x] = local_pair;
+    workgroupBarrier();
+
+    // Phase 2: pairwise reduction in workgroup memory. Slot lid.x absorbs slot
+    // lid.x + offset when that slot holds a real candidate that either fills an
+    // empty slot, is strictly larger, or ties with a larger column index.
+    var offset: u32 = WG_SIZE >> 1;
+    while (offset > 0) {
+        if (lid.x < offset) {
+            let a = shared_max[lid.x];
+            let b = shared_max[lid.x + offset];
+            let b_beats_a = a.index < 0
+                || b.value > a.value
+                || (b.value == a.value && b.index > a.index);
+            if (b.index >= 0 && b_beats_a) {
+                shared_max[lid.x] = b;
+            }
+        }
+        // Every invocation reaches this barrier: `offset` takes the same values in
+        // all of them, and the barrier sits outside the lid-dependent branch.
+        workgroupBarrier();
+        offset >>= 1;
+    }
+
+    // Slot 0 now holds the row's result; a single invocation publishes it.
+    if (lid.x == 0u) {
+        dst[params.offset_dst + wid.x] = shared_max[0].index;
+    }
+}