aboutsummaryrefslogtreecommitdiff
path: root/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl
diff options
context:
space:
mode:
Diffstat (limited to 'llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl')
-rw-r--r--llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl90
1 files changed, 90 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl
new file mode 100644
index 0000000..040e80d
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl
@@ -0,0 +1,90 @@
1#define(VARIANTS)
2
3[
4 {
5 "SHADER_NAME": "scale_f32",
6 "DECLS": ["NOT_INPLACE"]
7 },
8 {
9 "SHADER_NAME": "scale_f32_inplace",
10 "DECLS": ["INPLACE"]
11 }
12]
13
14#end(VARIANTS)
15
16#define(DECLS)
17
18#decl(NOT_INPLACE)
19@group(0) @binding(1)
20var<storage, read_write> dst: array<f32>;
21
22@group(0) @binding(2)
23var<uniform> params: Params;
24
25fn store_scale(val: f32, offset: u32) {
26 dst[offset] = val;
27}
28#enddecl(NOT_INPLACE)
29
30#decl(INPLACE)
31@group(0) @binding(1)
32var<uniform> params: Params;
33
34fn store_scale(val: f32, offset: u32) {
35 src[offset] = val;
36}
37#enddecl(INPLACE)
38
39#end(DECLS)
40
41#define(SHADER)
42
43struct Params {
44 offset_src: u32,
45 offset_dst: u32,
46
47 // Strides (in elements)
48 stride_src1: u32,
49 stride_src2: u32,
50 stride_src3: u32,
51
52 stride_dst1: u32,
53 stride_dst2: u32,
54 stride_dst3: u32,
55
56 ne: u32,
57 ne0: u32,
58 ne1: u32,
59 ne2: u32,
60
61 scale: f32,
62 bias: f32
63};
64
65@group(0) @binding(0)
66var<storage, read_write> src: array<f32>;
67
68DECLS
69
70override wg_size: u32;
71@compute @workgroup_size(wg_size)
72fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
73 if (gid.x >= params.ne) {
74 return;
75 }
76
77 var i = gid.x;
78 let i3 = i / (params.ne2 * params.ne1 * params.ne0);
79 i = i % (params.ne2 * params.ne1 * params.ne0);
80 let i2 = i / (params.ne1 * params.ne0);
81 i = i % (params.ne1 * params.ne0);
82 let i1 = i / params.ne0;
83 let i0 = i % params.ne0;
84
85 let i_src = params.offset_src + i3 * params.stride_src3 + i2 * params.stride_src2 + i1 * params.stride_src1 + i0;
86 let i_dst = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1 + i0;
87
88 store_scale(src[i_src] * params.scale + params.bias, i_dst);
89}
90#end(SHADER)