aboutsummaryrefslogtreecommitdiff
path: root/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl
diff options
context:
space:
mode:
authorMitja Felicijan <mitja.felicijan@gmail.com>2026-02-12 20:57:17 +0100
committerMitja Felicijan <mitja.felicijan@gmail.com>2026-02-12 20:57:17 +0100
commitb333b06772c89d96aacb5490d6a219fba7c09cc6 (patch)
tree211df60083a5946baa2ed61d33d8121b7e251b06 /llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl
downloadllmnpc-b333b06772c89d96aacb5490d6a219fba7c09cc6.tar.gz
Engage!
Diffstat (limited to 'llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl')
-rw-r--r--llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl179
1 files changed, 179 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl
new file mode 100644
index 0000000..d639d98
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl
@@ -0,0 +1,179 @@
1#ifdef TYPE_F16
2enable f16;
3#define TYPE f16
4#else
5#define TYPE f32
6#endif
7
8
9@group(0) @binding(0)
10var<storage, read_write> src: array<TYPE>;
11
12#ifndef INPLACE
13@group(0) @binding(1)
14var<storage, read_write> dst: array<TYPE>;
15#define PARAMS_BINDING 2
16#else
17#define PARAMS_BINDING 1
18#endif
19
20struct Params {
21 ne: u32, // total number of elements
22 offset_src: u32, // in elements
23 offset_dst: u32, // in elements
24
25 // Strides (in elements)
26 stride_src0: u32,
27 stride_src1: u32,
28 stride_src2: u32,
29 stride_src3: u32,
30
31 // Logical shapes
32 ne0: u32,
33 ne1: u32,
34 ne2: u32,
35#ifdef CLAMP
36 clamp_min: f32,
37 clamp_max: f32,
38#endif
39#ifdef FILL
40 fill_val: f32,
41#endif
42#ifdef XIELU
43 alpha_n: f32,
44 alpha_p: f32,
45 beta: f32,
46 eps: f32,
47#endif
48
49};
50
51@group(0) @binding(PARAMS_BINDING)
52var<uniform> params: Params;
53
54@compute @workgroup_size(WG_SIZE)
55fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
56 if (gid.x >= params.ne) {
57 return;
58 }
59 var i = gid.x;
60 let i3 = i / (params.ne2 * params.ne1 * params.ne0);
61 i = i % (params.ne2 * params.ne1 * params.ne0);
62 let i2 = i / (params.ne1 * params.ne0);
63 i = i % (params.ne1 * params.ne0);
64 let i1 = i / params.ne0;
65 let i0 = i % params.ne0;
66
67 let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 +
68 i2 * params.stride_src2 + i3 * params.stride_src3;
69
70#ifdef ABS
71 let res = abs(src[params.offset_src + src_idx]);
72#endif
73#ifdef SGN
74 let res = select(TYPE(select(0.0, -1.0, src[params.offset_src + src_idx] < 0.0)), TYPE(1.0),
75 src[params.offset_src + src_idx] > 0.0);
76#endif
77#ifdef NEG
78 let res = -src[params.offset_src + src_idx];
79#endif
80#ifdef STEP
81 let res = TYPE(select(0.0, 1.0, src[params.offset_src + src_idx] > 0.0));
82#endif
83#ifdef TANH
84 let res = tanh(clamp(src[params.offset_src + src_idx], -9.010913, 9.010913));
85#endif
86#ifdef RELU
87 let res = select(0.0, src[params.offset_src + src_idx], src[params.offset_src + src_idx] > 0.0);
88#endif
89#ifdef ELU
90 let res = select(exp(src[params.offset_src + src_idx]) - 1.0, src[params.offset_src + src_idx],
91 src[params.offset_src + src_idx] > 0.0);
92#endif
93#ifdef HARDSIGMOID
94 let res = min(1.0, max(0.0, (src[params.offset_src + src_idx] + 3.0) / 6.0));
95#endif
96#ifdef SIGMOID
97 let res = 1.0 / (1.0 + exp(-src[params.offset_src + src_idx]));
98#endif
99#ifdef SILU
100 let res = src[params.offset_src + src_idx] / (1.0 + exp(-src[params.offset_src + src_idx]));
101#endif
102#ifdef EXP
103 let res = exp(src[params.offset_src + src_idx]);
104#endif
105#ifdef LOG
106 let res = TYPE(log(f32(src[params.offset_src + src_idx])));
107#endif
108#ifdef CLAMP
109 let res = clamp(src[params.offset_src + src_idx], TYPE(params.clamp_min), TYPE(params.clamp_max));
110#endif
111#ifdef FILL
112 let res = TYPE(params.fill_val);
113#endif
114#ifdef HARDSWISH
115 let res = src[params.offset_src + src_idx] *
116 min(1.0, max(0.0, (src[params.offset_src + src_idx] + 3.0) / 6.0));
117#endif
118#ifdef GELU
119 let res = 0.5 * src[params.offset_src + src_idx] *
120 (1.0 + tanh(clamp(sqrt(2.0 / 3.14159265) *
121 (src[params.offset_src + src_idx] +
122 0.044715 * pow(src[params.offset_src + src_idx], 3.0)),
123 -9.010913, 9.010913)));
124#endif
125#ifdef GELU_QUICK
126 let res = src[params.offset_src + src_idx] * 0.5 *
127 (1.0 + tanh(clamp(0.79788456 *
128 (src[params.offset_src + src_idx] +
129 0.044715 * src[params.offset_src + src_idx] *
130 src[params.offset_src + src_idx] * src[params.offset_src + src_idx]),
131 -9.010913, 9.010913)));
132#endif
133#ifdef GELU_ERF
134 let res = 0.5 * src[params.offset_src + src_idx] *
135 (1.0 + tanh(clamp(0.79788456 *
136 (src[params.offset_src + src_idx] +
137 0.044715 * src[params.offset_src + src_idx] *
138 src[params.offset_src + src_idx] * src[params.offset_src + src_idx]),
139 -9.010913, 9.010913)));
140#endif
141#ifdef XIELU
142 let res =
143 select(((exp(min(src[params.offset_src + src_idx], TYPE(params.eps))) - 1.0) -
144 src[params.offset_src + src_idx]) *
145 TYPE(params.alpha_n) +
146 TYPE(params.beta) * src[params.offset_src + src_idx],
147 TYPE(params.alpha_p) * src[params.offset_src + src_idx] *
148 src[params.offset_src + src_idx] +
149 TYPE(params.beta) * src[params.offset_src + src_idx],
150 src[params.offset_src + src_idx] > 0.0);
151#endif
152#ifdef SOFTPLUS
153 let src_f32 = f32(src[params.offset_src + src_idx]);
154 let res = TYPE(select(log(1.0 + exp(src_f32)), src_f32, src_f32 > 20.0));
155#endif
156#ifdef EXPM1
157 let res = exp(src[params.offset_src + src_idx]) - 1.0;
158#endif
159#ifdef FLOOR
160 let res = floor(src[params.offset_src + src_idx]);
161#endif
162#ifdef CEIL
163 let res = ceil(src[params.offset_src + src_idx]);
164#endif
165#ifdef ROUND
166 let src_f32 = f32(src[params.offset_src + src_idx]);
167 let result = select(ceil(src_f32 - 0.5), floor(src_f32 + 0.5), src_f32 >= 0.0);
168 let res = TYPE(result);
169#endif
170#ifdef TRUNC
171 let res = trunc(src[params.offset_src + src_idx]);
172#endif
173
174#ifdef INPLACE
175 src[params.offset_src + src_idx] = res;
176#else
177 dst[params.offset_dst + gid.x] = res;
178#endif
179}