aboutsummaryrefslogtreecommitdiff
path: root/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl
diff options
context:
space:
mode:
Diffstat (limited to 'llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl')
-rw-r--r--llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl86
1 files changed, 86 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl
new file mode 100644
index 0000000..ea63b9a
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl
@@ -0,0 +1,86 @@
// Input tensor data; this shader only reads from it.
@group(0) @binding(0) var<storage, read_write> src: array<f32>;

// Output tensor data; each invocation writes one element of it.
@group(0) @binding(1) var<storage, read_write> dst: array<f32>;
6
// Uniform parameter block for the pad shader.
// Field order and types define the uniform buffer layout, so they must stay
// in sync with whatever fills this buffer on the host side.
struct Params {
    // Total number of elements to write; invocations with gid.x >= ne exit early.
    ne: u32,
    // Starting element offset into `src`.
    offset_src: u32,
    // Starting element offset into `dst`.
    offset_dst: u32,

    // Source strides per dimension, measured in elements.
    stride_src0: u32,
    stride_src1: u32,
    stride_src2: u32,
    stride_src3: u32,

    // Logical extent of the source in each dimension.
    src_ne0: u32,
    src_ne1: u32,
    src_ne2: u32,
    src_ne3: u32,

    // Logical extent of the destination in each dimension.
    dst_ne0: u32,
    dst_ne1: u32,
    dst_ne2: u32,
    dst_ne3: u32,

    // Padding per dimension, measured in elements:
    // lpN is added before the source data, rpN after it.
    lp0: u32,
    rp0: u32,
    lp1: u32,
    rp1: u32,
    lp2: u32,
    rp2: u32,
    lp3: u32,
    rp3: u32,
};
39
// Shapes, strides, offsets and pad sizes for this dispatch.
@group(0) @binding(2) var<uniform> params: Params;
42
// Maps a possibly negative or out-of-range coordinate onto [0, n) by
// wrapping it around an axis of length `n` (floor-modulo).
//
//   idx: signed coordinate relative to the start of the axis.
//   n:   axis length; must fit in an i32 for the result to be meaningful.
//
// Returns the wrapped coordinate as a u32.
//
// A single "add n, then take the remainder" step is only correct while
// idx >= -n; below that the sum is still negative and converting it to u32
// produces a value near 2^32 whose remainder is wrong unless n divides 2^32.
// Taking the signed remainder first and then correcting its sign handles any
// idx, which matters when the padding is larger than the axis itself.
fn wrap_around(idx: i32, n: u32) -> u32 {
    let size = i32(n);
    // Signed `%` truncates toward zero, so `rem` lies in (-size, size) and
    // carries the sign of `idx`.
    let rem = idx % size;
    // Shift negative remainders up by one period to land in [0, size).
    return u32(select(rem, rem + size, rem < 0));
}
46
// Pad kernel entry point: each invocation produces exactly one element of
// `dst`, addressed by its flat index gid.x.
//
// The flat index is split into 4-D destination coordinates using the
// destination extents. Without CIRCULAR, coordinates that fall inside the
// un-padded region copy the matching source element and everything else is
// written as 0.0. With CIRCULAR, every coordinate is wrapped back onto the
// source extents, so the padding repeats the source periodically.
//
// NOTE(review): WG_SIZE and the #ifdef/#else/#endif lines are not WGSL;
// presumably they are resolved by a preprocessing step before compilation.
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let flat = gid.x;
    if (flat >= params.ne) {
        return;
    }

    // Element counts of one destination row, one 2-D slice and one 3-D volume.
    let row_len = params.dst_ne0;
    let slice_len = params.dst_ne1 * row_len;
    let volume_len = params.dst_ne2 * slice_len;

    // Peel off coordinates from the outermost dimension to the innermost.
    let i3 = flat / volume_len;
    let in_volume = flat % volume_len;
    let i2 = in_volume / slice_len;
    let in_slice = in_volume % slice_len;
    let i1 = in_slice / row_len;
    let i0 = in_slice % row_len;

    // Padding value used when no source element applies.
    var out_value: f32 = 0.0;

#ifdef CIRCULAR
    // Shift by the left pad, then wrap each coordinate onto the source extent.
    let w0 = wrap_around(i32(i0) - i32(params.lp0), params.src_ne0);
    let w1 = wrap_around(i32(i1) - i32(params.lp1), params.src_ne1);
    let w2 = wrap_around(i32(i2) - i32(params.lp2), params.src_ne2);
    let w3 = wrap_around(i32(i3) - i32(params.lp3), params.src_ne3);
    let wrapped_offset = w0 * params.stride_src0 + w1 * params.stride_src1 +
                         w2 * params.stride_src2 + w3 * params.stride_src3;
    out_value = src[params.offset_src + wrapped_offset];
#else
    // Per-dimension test: is the coordinate past the left pad and before the
    // right pad?
    let inside0 = (i0 >= params.lp0) && (i0 < params.dst_ne0 - params.rp0);
    let inside1 = (i1 >= params.lp1) && (i1 < params.dst_ne1 - params.rp1);
    let inside2 = (i2 >= params.lp2) && (i2 < params.dst_ne2 - params.rp2);
    let inside3 = (i3 >= params.lp3) && (i3 < params.dst_ne3 - params.rp3);
    if (inside0 && inside1 && inside2 && inside3) {
        // Translate destination coordinates to source coordinates.
        let s0 = i0 - params.lp0;
        let s1 = i1 - params.lp1;
        let s2 = i2 - params.lp2;
        let s3 = i3 - params.lp3;
        let src_offset = s0 * params.stride_src0 + s1 * params.stride_src1 +
                         s2 * params.stride_src2 + s3 * params.stride_src3;
        out_value = src[params.offset_src + src_offset];
    }
#endif

    // The destination is written at consecutive flat indices from offset_dst.
    dst[params.offset_dst + flat] = out_value;
}