aboutsummaryrefslogtreecommitdiff
path: root/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl
diff options
context:
space:
mode:
Diffstat (limited to 'llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl')
-rw-r--r--llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl345
1 files changed, 345 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl
new file mode 100644
index 0000000..c74dc4c
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl
@@ -0,0 +1,345 @@
1#define(VARIANTS)
2[
3 {
4 "SHADER_NAME": "soft_max_f32",
5 "DECLS": ["BASE_BINDINGS", "NOT_INPLACE", "NO_MASK", "NO_SINK"]
6 },
7 {
8 "SHADER_NAME": "soft_max_f32_inplace",
9 "DECLS": ["BASE_BINDINGS_INPLACE", "INPLACE", "NO_MASK", "NO_SINK"]
10 },
11 {
12 "SHADER_NAME": "soft_max_f32_sink",
13 "DECLS": ["SINK_BINDINGS", "NOT_INPLACE", "NO_MASK", "SINK"]
14 },
15 {
16 "SHADER_NAME": "soft_max_f32_sink_inplace",
17 "DECLS": ["SINK_BINDINGS_INPLACE", "INPLACE", "NO_MASK", "SINK"]
18 },
19 {
20 "SHADER_NAME": "soft_max_f32_mask_f32",
21 "REPLS": {
22 "MASK_TYPE" : "f32",
23 },
24 "DECLS": ["MASK_BINDINGS", "NOT_INPLACE", "MASK", "NO_SINK"]
25 },
26 {
27 "SHADER_NAME": "soft_max_f32_mask_f32_inplace",
28 "REPLS": {
29 "MASK_TYPE" : "f32",
30 },
31 "DECLS": ["MASK_BINDINGS_INPLACE", "INPLACE", "MASK", "NO_SINK"]
32 },
33 {
34 "SHADER_NAME": "soft_max_f32_mask_f16",
35 "REPLS": {
36 "MASK_TYPE" : "f16",
37 },
38 "DECLS": ["MASK_BINDINGS", "NOT_INPLACE", "MASK", "NO_SINK"]
39 },
40 {
41 "SHADER_NAME": "soft_max_f32_mask_f16_inplace",
42 "REPLS": {
43 "MASK_TYPE" : "f16",
44 },
45 "DECLS": ["MASK_BINDINGS_INPLACE", "INPLACE", "MASK", "NO_SINK"]
46 },
47 {
48 "SHADER_NAME": "soft_max_f32_mask_f32_sink",
49 "REPLS": {
50 "MASK_TYPE" : "f32",
51 },
52 "DECLS": ["MASK_SINK_BINDINGS", "NOT_INPLACE", "MASK", "SINK"]
53 },
54 {
55 "SHADER_NAME": "soft_max_f32_mask_f32_sink_inplace",
56 "REPLS": {
57 "MASK_TYPE" : "f32",
58 },
59 "DECLS": ["MASK_SINK_BINDINGS_INPLACE", "INPLACE", "MASK", "SINK"]
60 },
61 {
62 "SHADER_NAME": "soft_max_f32_mask_f16_sink",
63 "REPLS": {
64 "MASK_TYPE" : "f16",
65 },
66 "DECLS": ["MASK_SINK_BINDINGS", "NOT_INPLACE", "MASK", "SINK"]
67 },
68 {
69 "SHADER_NAME": "soft_max_f32_mask_f16_sink_inplace",
70 "REPLS": {
71 "MASK_TYPE" : "f16",
72 },
73 "DECLS": ["MASK_SINK_BINDINGS_INPLACE", "INPLACE", "MASK", "SINK"]
74 }
75]
76#end(VARIANTS)
77
78#define(DECLS)
79
80#decl(BASE_BINDINGS)
81@group(0) @binding(1)
82var<storage, read_write> dst: array<f32>;
83
84@group(0) @binding(2)
85var<uniform> params: Params;
86#enddecl(BASE_BINDINGS)
87
88#decl(BASE_BINDINGS_INPLACE)
89@group(0) @binding(1)
90var<uniform> params: Params;
91#enddecl(BASE_BINDINGS_INPLACE)
92
93#decl(SINK_BINDINGS)
94@group(0) @binding(1)
95var<storage, read_write> sinks: array<f32>;
96
97@group(0) @binding(2)
98var<storage, read_write> dst: array<f32>;
99
100@group(0) @binding(3)
101var<uniform> params: Params;
102#enddecl(SINK_BINDINGS)
103
104#decl(SINK_BINDINGS_INPLACE)
105@group(0) @binding(1)
106var<storage, read_write> sinks: array<f32>;
107
108@group(0) @binding(2)
109var<uniform> params: Params;
110#enddecl(SINK_BINDINGS_INPLACE)
111
112#decl(MASK_BINDINGS)
113@group(0) @binding(1)
114var<storage, read_write> mask: array<{{MASK_TYPE}}>;
115
116@group(0) @binding(2)
117var<storage, read_write> dst: array<f32>;
118
119@group(0) @binding(3)
120var<uniform> params: Params;
121#enddecl(MASK_BINDINGS)
122
123#decl(MASK_BINDINGS_INPLACE)
124@group(0) @binding(1)
125var<storage, read_write> mask: array<{{MASK_TYPE}}>;
126
127@group(0) @binding(2)
128var<uniform> params: Params;
129#enddecl(MASK_BINDINGS_INPLACE)
130
131#decl(MASK_SINK_BINDINGS)
132@group(0) @binding(1)
133var<storage, read_write> mask: array<{{MASK_TYPE}}>;
134
135@group(0) @binding(2)
136var<storage, read_write> sinks: array<f32>;
137
138@group(0) @binding(3)
139var<storage, read_write> dst: array<f32>;
140
141@group(0) @binding(4)
142var<uniform> params: Params;
143#enddecl(MASK_SINK_BINDINGS)
144
145#decl(MASK_SINK_BINDINGS_INPLACE)
146@group(0) @binding(1)
147var<storage, read_write> mask: array<{{MASK_TYPE}}>;
148
149@group(0) @binding(2)
150var<storage, read_write> sinks: array<f32>;
151
152@group(0) @binding(3)
153var<uniform> params: Params;
154#enddecl(MASK_SINK_BINDINGS_INPLACE)
155
156#decl(NOT_INPLACE)
// Out-of-place variant: values are read back from, and written to, the
// separate `dst` buffer.

// Returns the value currently stored at element index `i` of `dst`.
157fn inter_value(i: u32) -> f32 {
158 return dst[i];
159}
160
// Stores `val` at element index `i` of `dst`.
161fn update(i: u32, val: f32) {
162 dst[i] = val;
163}
164#enddecl(NOT_INPLACE)
165
166#decl(INPLACE)
// In-place variant: no `dst` binding is used; values are read back from, and
// written to, the `src` buffer itself.

// Returns the value currently stored at element index `i` of `src`.
167fn inter_value(i: u32) -> f32 {
168 return src[i];
169}
170
// Overwrites element index `i` of `src` with `val`.
171fn update(i: u32, val: f32) {
172 src[i] = val;
173}
174#enddecl(INPLACE)
175
176#decl(NO_MASK)
// Variant without a mask buffer: the mask term is always zero, so the index
// `i` is ignored.
177fn mask_val(i: u32) -> f32 {
178 return 0.0;
179}
180#enddecl(NO_MASK)
181
182#decl(MASK)
// Variant with a mask buffer: loads element `i` of `mask` and converts it to
// f32 (the buffer's element type is the template's MASK_TYPE).
183fn mask_val(i: u32) -> f32 {
184 return f32(mask[i]);
185}
186#enddecl(MASK)
187
188#decl(NO_SINK)
// Variant without sinks.

// Starting value for the running maximum: a large negative constant, so any
// ordinary input value replaces it. `i2` is ignored.
189fn lower_max_bound(i2: u32) -> f32 {
190 return -1e30;
191}
192
// No sink contribution: returns `val` unchanged; `i2` and `max_val` are ignored.
193fn add_sinks(val: f32, i2: u32, max_val: f32) -> f32 {
194 return val;
195}
196#enddecl(NO_SINK)
197
198#decl(SINK)
// Variant with sinks: one sink value per `i2` index, read from `sinks` at
// `params.offset_sinks + i2`.

// Starting value for the running maximum: the sink value for `i2`, so the
// resulting maximum is never below it.
199fn lower_max_bound(i2: u32) -> f32 {
200 return sinks[params.offset_sinks + i2];
201}
202
// Adds the sink's exponential, shifted by `max_val`, to the sum `val`.
203fn add_sinks(val: f32, i2: u32, max_val: f32) -> f32 {
204 return val + exp(sinks[params.offset_sinks + i2] - max_val);
205}
206#enddecl(SINK)
207
208#end(DECLS)
209
210#define(SHADER)
211enable f16;
212
// Uniform parameters for the softmax shader. Field roles below describe how
// the shader code in this template uses them.
213struct Params {
 // Base element offsets added when indexing src, mask, sinks and dst.
214 offset_src0: u32,
215 offset_src1: u32,
216 offset_sinks: u32,
217 offset_dst: u32,
218
219 // Strides (in elements)
 // src strides for row index i1, i2 and i3 respectively.
220 stride_src01: u32,
221 stride_src02: u32,
222 stride_src03: u32,
223
 // mask strides for i1, (i2 % ne12) and (i3 % ne13) respectively.
224 stride_src11: u32,
225 stride_src12: u32,
226 stride_src13: u32,
227
 // dst strides for i1, i2 and i3 respectively.
228 stride_dst1: u32,
229 stride_dst2: u32,
230 stride_dst3: u32,
231
232 // shape of src0/dst
 // ne0 is the row length; ne1 and ne2 are used to split the workgroup index
 // into (i3, i2, i1). `ne` is not referenced by the shader code in this file.
233 ne: u32,
234 ne0: u32,
235 ne1: u32,
236 ne2: u32,
237
238 // shape of src1
 // i2 and i3 are taken modulo these when indexing the mask, so the mask
 // repeats along those dimensions.
239 ne12: u32,
240 ne13: u32,
241
 // scale multiplies every src element. When max_bias > 0 a per-i2 slope is
 // derived from m0, m1 and n_head_log2 and multiplies the mask value;
 // otherwise the slope is 1.
242 scale: f32,
243 max_bias: f32,
244 n_head_log2: f32,
245 m0: f32,
246 m1: f32,
247};
248
249@group(0) @binding(0)
250var<storage, read_write> src: array<f32>;
251
252DECLS
253
254const CACHE_SIZE: u32 = 16;
255
256override wg_size: u32;
257var<workgroup> scratch: array<f32, wg_size>;
258
// Softmax entry point. Each workgroup processes one row of `src`: wid.x
// selects the row, and the invocations of the workgroup stride across the
// row's ne0 columns in steps of wg_size.
//
// The row is processed in three passes:
//   1. val = src * scale + slope * mask; reduce to the row maximum.
//   2. ex = exp(val - row_max); reduce to the row sum (plus the sink term).
//   3. write ex / row_sum through update().
// Columns below CACHE_SIZE keep their intermediate value in the invocation's
// private `cache` array; all other columns are staged through
// update()/inter_value().
//
// NOTE: both reductions halve `offset` starting from wg_size / 2, so every
// lane is only folded in when wg_size is a power of two.
@compute @workgroup_size(wg_size)
fn main(@builtin(workgroup_id) wid: vec3<u32>,
        @builtin(local_invocation_id) lid: vec3<u32>) {

    // Split the flat workgroup index into (i3, i2, i1) row coordinates.
    var i = wid.x;
    let i3 = i / (params.ne2 * params.ne1);
    i = i % (params.ne2 * params.ne1);
    let i2 = i / params.ne1;
    let i1 = i % params.ne1;
    let i_src0_row = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01;
    // The mask index wraps i2 / i3 by the mask's own extents.
    let i_src1_row = params.offset_src1 + (i3 % params.ne13) * params.stride_src13 + (i2 % params.ne12) * params.stride_src12 + i1 * params.stride_src11;
    let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1;
    // Maximum number of columns a single invocation visits: ceil(ne0 / wg_size).
    let elems = (params.ne0 + wg_size - 1) / wg_size;

    // Per-i2 slope applied to the mask; 1 when max_bias is not positive.
    let head = f32(i2);
    let slope = select(1, select(pow(params.m1, 2 * (head - params.n_head_log2) + 1), pow(params.m0, head + 1), head < params.n_head_log2), params.max_bias > 0);

    var cache: array<f32, CACHE_SIZE>;

    // Pass 1: compute the scaled and masked values and this invocation's maximum.
    var max_val = lower_max_bound(i2);
    var col = lid.x;
    for (var j: u32 = 0; j < elems; j++) {
        if (col >= params.ne0) {
            break;
        }
        let val = src[i_src0_row + col] * params.scale + slope * mask_val(i_src1_row + col);
        max_val = max(max_val, val);
        if (col < CACHE_SIZE) {
            cache[col] = val;
        }
        col += wg_size;
    }

    // Reduce the per-invocation maxima to the row maximum in scratch[0].
    scratch[lid.x] = max_val;
    workgroupBarrier();
    var offset = wg_size / 2;
    while (offset > 0) {
        if (lid.x < offset) {
            scratch[lid.x] = max(scratch[lid.x], scratch[lid.x + offset]);
        }
        offset = offset / 2;
        workgroupBarrier();
    }
    let row_max = scratch[0];
    // Every invocation must have read scratch[0] before scratch is reused below.
    workgroupBarrier();

    // Pass 2: exponentiate and accumulate this invocation's partial sum.
    var sum = 0.0f;
    col = lid.x;
    for (var j: u32 = 0; j < elems; j++) {
        if (col >= params.ne0) {
            break;
        }
        // Branch instead of select(): select() evaluates both operands, which
        // would index `cache` past CACHE_SIZE and still issue the src/mask
        // loads that the cache exists to avoid.
        let cached = col < CACHE_SIZE;
        var val: f32;
        if (cached) {
            val = cache[col];
        } else {
            val = src[i_src0_row + col] * params.scale + slope * mask_val(i_src1_row + col);
        }
        let ex = exp(val - row_max);
        sum += ex;
        if (cached) {
            cache[col] = ex;
        } else {
            update(i_dst_row + col, ex);
        }
        col += wg_size;
    }

    // Reduce the partial sums to the row sum in scratch[0].
    scratch[lid.x] = sum;
    workgroupBarrier();
    offset = wg_size / 2;
    while (offset > 0) {
        if (lid.x < offset) {
            scratch[lid.x] += scratch[lid.x + offset];
        }
        offset = offset / 2;
        workgroupBarrier();
    }
    let row_sum = add_sinks(scratch[0], i2, row_max);

    // Pass 3: normalize and write the final values.
    let sum_recip = 1.0 / row_sum;
    col = lid.x;
    for (var j: u32 = 0; j < elems; j++) {
        if (col >= params.ne0) {
            break;
        }
        // Same reasoning as pass 2: only touch the staging buffer for columns
        // that were not cached, and never index `cache` out of bounds.
        var ex: f32;
        if (col < CACHE_SIZE) {
            ex = cache[col];
        } else {
            ex = inter_value(i_dst_row + col);
        }
        update(i_dst_row + col, ex * sum_recip);
        col += wg_size;
    }
}
345#end(SHADER)