summaryrefslogtreecommitdiff
path: root/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl
diff options
context:
space:
mode:
authorMitja Felicijan <mitja.felicijan@gmail.com>2026-02-12 20:57:17 +0100
committerMitja Felicijan <mitja.felicijan@gmail.com>2026-02-12 20:57:17 +0100
commitb333b06772c89d96aacb5490d6a219fba7c09cc6 (patch)
tree211df60083a5946baa2ed61d33d8121b7e251b06 /llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl
downloadllmnpc-b333b06772c89d96aacb5490d6a219fba7c09cc6.tar.gz
Engage!
Diffstat (limited to 'llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl')
-rw-r--r--llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl107
1 files changed, 107 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl
new file mode 100644
index 0000000..b5e93b8
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl
@@ -0,0 +1,107 @@
+#define(VARIANTS)
+
+[
+ {
+ "REPLS": {
+ "SRC_TYPE": "f32",
+ "DST_TYPE": "f32"
+ }
+ },
+ {
+ "REPLS": {
+ "SRC_TYPE": "f32",
+ "DST_TYPE": "i32"
+ }
+ },
+ {
+ "REPLS": {
+ "SRC_TYPE": "f32",
+ "DST_TYPE": "f16"
+ }
+ },
+ {
+ "REPLS": {
+ "SRC_TYPE": "f16",
+ "DST_TYPE": "f16"
+ }
+ },
+ {
+ "REPLS": {
+ "SRC_TYPE": "f16",
+ "DST_TYPE": "f32"
+ }
+ }
+]
+
+#end(VARIANTS)
+
+#define(SHADER)
+enable f16;
+
+@group(0) @binding(0)
+var<storage, read_write> src: array<{{SRC_TYPE}}>;
+
+@group(0) @binding(1)
+var<storage, read_write> dst: array<{{DST_TYPE}}>;
+
+struct Params {
+ ne: u32, // total number of elements
+ offset_src: u32, // in elements
+ offset_dst: u32, // in elements
+
+ // Strides (in elements) — may be permuted
+ stride_src0: u32,
+ stride_src1: u32,
+ stride_src2: u32,
+ stride_src3: u32,
+
+ stride_dst0: u32,
+ stride_dst1: u32,
+ stride_dst2: u32,
+ stride_dst3: u32,
+
+ // Logical shapes
+ src_ne0: u32,
+ src_ne1: u32,
+ src_ne2: u32,
+
+ dst_ne0: u32,
+ dst_ne1: u32,
+ dst_ne2: u32
+};
+
+@group(0) @binding(2)
+var<uniform> params: Params;
+
+override wg_size: u32;
+@compute @workgroup_size(wg_size)
+fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
+ if (gid.x >= params.ne) {
+ return;
+ }
+
+ var i = gid.x;
+ let i3 = i / (params.src_ne2 * params.src_ne1 * params.src_ne0);
+ i = i % (params.src_ne2 * params.src_ne1 * params.src_ne0);
+ let i2 = i / (params.src_ne1 * params.src_ne0);
+ i = i % (params.src_ne1 * params.src_ne0);
+ let i1 = i / params.src_ne0;
+ let i0 = i % params.src_ne0;
+
+ var j = gid.x;
+ let j3 = j / (params.dst_ne2 * params.dst_ne1 * params.dst_ne0);
+ j = j % (params.dst_ne2 * params.dst_ne1 * params.dst_ne0);
+ let j2 = j / (params.dst_ne1 * params.dst_ne0);
+ j = j % (params.dst_ne1 * params.dst_ne0);
+ let j1 = j / params.dst_ne0;
+ let j0 = j % params.dst_ne0;
+
+ let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 +
+ i2 * params.stride_src2 + i3 * params.stride_src3;
+
+ let dst_idx = j0 * params.stride_dst0 + j1 * params.stride_dst1 +
+ j2 * params.stride_dst2 + j3 * params.stride_dst3;
+
+ dst[params.offset_dst + dst_idx] = {{DST_TYPE}}((src[params.offset_src + src_idx]));
+}
+#end(SHADER)