aboutsummaryrefslogtreecommitdiff
path: root/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl
diff options
context:
space:
mode:
Diffstat (limited to 'llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl')
-rw-r--r--llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl107
1 files changed, 107 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl
new file mode 100644
index 0000000..b5e93b8
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl
@@ -0,0 +1,107 @@
1#define(VARIANTS)
2
3[
4 {
5 "REPLS": {
6 "SRC_TYPE": "f32",
7 "DST_TYPE": "f32"
8 }
9 },
10 {
11 "REPLS": {
12 "SRC_TYPE": "f32",
13 "DST_TYPE": "i32"
14 }
15 },
16 {
17 "REPLS": {
18 "SRC_TYPE": "f32",
19 "DST_TYPE": "f16"
20 }
21 },
22 {
23 "REPLS": {
24 "SRC_TYPE": "f16",
25 "DST_TYPE": "f16"
26 }
27 },
28 {
29 "REPLS": {
30 "SRC_TYPE": "f16",
31 "DST_TYPE": "f32"
32 }
33 }
34]
35
36#end(VARIANTS)
37
38#define(SHADER)
39enable f16;
40
41@group(0) @binding(0)
42var<storage, read_write> src: array<{{SRC_TYPE}}>;
43
44@group(0) @binding(1)
45var<storage, read_write> dst: array<{{DST_TYPE}}>;
46
47struct Params {
48 ne: u32, // total number of elements
49 offset_src: u32, // in elements
50 offset_dst: u32, // in elements
51
52 // Strides (in elements) — may be permuted
53 stride_src0: u32,
54 stride_src1: u32,
55 stride_src2: u32,
56 stride_src3: u32,
57
58 stride_dst0: u32,
59 stride_dst1: u32,
60 stride_dst2: u32,
61 stride_dst3: u32,
62
63 // Logical shapes
64 src_ne0: u32,
65 src_ne1: u32,
66 src_ne2: u32,
67
68 dst_ne0: u32,
69 dst_ne1: u32,
70 dst_ne2: u32
71};
72
73@group(0) @binding(2)
74var<uniform> params: Params;
75
76override wg_size: u32;
77@compute @workgroup_size(wg_size)
78fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
79 if (gid.x >= params.ne) {
80 return;
81 }
82
83 var i = gid.x;
84 let i3 = i / (params.src_ne2 * params.src_ne1 * params.src_ne0);
85 i = i % (params.src_ne2 * params.src_ne1 * params.src_ne0);
86 let i2 = i / (params.src_ne1 * params.src_ne0);
87 i = i % (params.src_ne1 * params.src_ne0);
88 let i1 = i / params.src_ne0;
89 let i0 = i % params.src_ne0;
90
91 var j = gid.x;
92 let j3 = j / (params.dst_ne2 * params.dst_ne1 * params.dst_ne0);
93 j = j % (params.dst_ne2 * params.dst_ne1 * params.dst_ne0);
94 let j2 = j / (params.dst_ne1 * params.dst_ne0);
95 j = j % (params.dst_ne1 * params.dst_ne0);
96 let j1 = j / params.dst_ne0;
97 let j0 = j % params.dst_ne0;
98
99 let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 +
100 i2 * params.stride_src2 + i3 * params.stride_src3;
101
102 let dst_idx = j0 * params.stride_dst0 + j1 * params.stride_dst1 +
103 j2 * params.stride_dst2 + j3 * params.stride_dst3;
104
105 dst[params.offset_dst + dst_idx] = {{DST_TYPE}}((src[params.offset_src + src_idx]));
106}
107#end(SHADER)