summaryrefslogtreecommitdiff
path: root/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl
diff options
context:
space:
mode:
Diffstat (limited to 'llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl')
-rw-r--r--llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl109
1 files changed, 109 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl
new file mode 100644
index 0000000..99e9192
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl
@@ -0,0 +1,109 @@
+enable f16;
+
+#ifdef DST_F32
+#define DST_INNER_TYPE f32
+#else
+#define DST_INNER_TYPE f16
+#endif
+
+#ifdef VEC4
+#define SRC_TYPE vec4<f32>
+#define DST_TYPE vec4<DST_INNER_TYPE>
+#define VEC_SIZE 4
+#else
+#define SRC_TYPE f32
+#define DST_TYPE DST_INNER_TYPE
+#define VEC_SIZE 1
+#endif
+
+@group(0) @binding(0)
+var<storage, read_write> src: array<SRC_TYPE>;
+
+@group(0) @binding(1)
+var<storage, read_write> idx: array<u32>;
+
+@group(0) @binding(2)
+var<storage, read_write> dst: array<DST_TYPE>;
+
+#ifdef I64_IDX
+@group(0) @binding(3)
+var<storage, read_write> error: atomic<u32>;
+#define PARAMS_BINDING 4
+#else
+#define PARAMS_BINDING 3
+#endif
+
+struct Params {
+ offset_src: u32, // in elements
+ offset_idx: u32, // in elements
+ offset_dst: u32, // in elements
+
+ // Strides (in elements)
+ stride_src1: u32,
+ stride_src2: u32,
+ stride_src3: u32,
+
+ stride_idx0: u32,
+ stride_idx1: u32,
+ stride_idx2: u32,
+
+ stride_dst1: u32,
+ stride_dst2: u32,
+ stride_dst3: u32,
+
+ // Shape of src
+ ne0: u32,
+ n_rows: u32,
+ ne2: u32,
+ ne3: u32,
+
+ // Shape of idx
+ idx1: u32,
+ idx2: u32,
+};
+
+@group(0) @binding(PARAMS_BINDING)
+var<uniform> params: Params;
+
+@compute @workgroup_size(WG_SIZE)
+fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
+ if (gid.x >= (params.ne3 * params.ne2 * params.n_rows * params.ne0) / VEC_SIZE) {
+ return;
+ }
+
+ // getting the row from gid
+ let elems_per_row = params.ne0 / VEC_SIZE;
+ var i = gid.x / elems_per_row;
+
+ let i_src3 = i / (params.ne2 * params.n_rows);
+
+ i = i % (params.ne2 * params.n_rows);
+ let i_src2 = i / params.n_rows;
+ let i_src1 = i % params.n_rows;
+
+ let i_idx2 = i_src3 % params.idx2;
+ let i_idx1 = i_src2 % params.idx1;
+ let i_idx0 = i_src1;
+
+#ifdef I64_IDX
+ let idx_high = (params.offset_idx + i_idx0 * params.stride_idx0 + i_idx1 * params.stride_idx1 + i_idx2 * params.stride_idx2) * 2;
+
+ let idx_val = idx[idx_high];
+ let idx_low_val = idx[idx_high + 1];
+
+ if (idx_low_val != 0) {
+ // Upper bits of index are not zero, output will be incorrect
+ atomicStore(&error, 1);
+ return;
+ }
+#else
+ let idx_i = params.offset_idx + i_idx0 * params.stride_idx0 + i_idx1 * params.stride_idx1 + i_idx2 * params.stride_idx2;
+ let idx_val = idx[idx_i];
+#endif
+
+ let i_dst_row = params.offset_dst + idx_val * params.stride_dst1 + i_src2 * params.stride_dst2 + i_src3 * params.stride_dst3;
+ let i_src_row = params.offset_src + i_src1 * params.stride_src1 + i_src2 * params.stride_src2 + i_src3 * params.stride_src3;
+
+ let col_idx = (gid.x % elems_per_row);
+ dst[i_dst_row/VEC_SIZE + col_idx] = DST_TYPE(src[i_src_row/VEC_SIZE + col_idx]);
+}