diff options
Diffstat (limited to 'llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl')
| -rw-r--r-- | llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl | 90 |
1 files changed, 90 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl new file mode 100644 index 0000000..040e80d --- /dev/null +++ b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl | |||
| @@ -0,0 +1,90 @@ | |||
#define(VARIANTS)

[
  { "SHADER_NAME": "scale_f32",         "DECLS": ["NOT_INPLACE"] },
  { "SHADER_NAME": "scale_f32_inplace", "DECLS": ["INPLACE"] }
]

#end(VARIANTS)
| 15 | |||
| 16 | #define(DECLS) | ||
| 17 | |||
#decl(NOT_INPLACE)
// Out-of-place variant: results go to a separate destination buffer, so the
// uniform parameters sit at binding 2 (binding 1 is taken by `dst`).
@group(0) @binding(1) var<storage, read_write> dst: array<f32>;
@group(0) @binding(2) var<uniform> params: Params;

// Writes `val` to element `offset` of the destination buffer.
fn store_scale(val: f32, offset: u32) {
    dst[offset] = val;
}
#enddecl(NOT_INPLACE)
| 29 | |||
#decl(INPLACE)
// In-place variant: there is no separate destination buffer, so the uniform
// parameters sit at binding 1.
@group(0) @binding(1) var<uniform> params: Params;

// Writes `val` to element `offset` of the source buffer itself (`src` is
// declared read_write in the shared shader body).
fn store_scale(val: f32, offset: u32) {
    src[offset] = val;
}
#enddecl(INPLACE)
| 38 | |||
| 39 | #end(DECLS) | ||
| 40 | |||
| 41 | #define(SHADER) | ||
| 42 | |||
// Uniform parameters for the scale kernel. Field order defines the uniform
// buffer layout and must not be rearranged.
struct Params {
    // Base element index added to every source / destination access.
    offset_src: u32,
    offset_dst: u32,

    // Source strides (in elements) for dimensions 1..3.
    // Dimension 0 has no stride field; the kernel adds i0 directly.
    stride_src1: u32,
    stride_src2: u32,
    stride_src3: u32,

    // Destination strides (in elements) for dimensions 1..3.
    stride_dst1: u32,
    stride_dst2: u32,
    stride_dst3: u32,

    // Total element count; invocations with gid.x >= ne exit early.
    ne: u32,
    // Extents of dimensions 0..2, used to split the flat index into (i0..i3).
    ne0: u32,
    ne1: u32,
    ne2: u32,

    // Each output element is computed as src * scale + bias.
    scale: f32,
    bias: f32
};
| 64 | |||
// Input buffer. Declared read_write because the in-place variant also stores
// its results here.
@group(0) @binding(0) var<storage, read_write> src: array<f32>;
| 67 | |||
| 68 | DECLS | ||
| 69 | |||
// Workgroup size is a pipeline-overridable constant with no default here.
override wg_size: u32;

// Computes src * scale + bias for one element per invocation.
//
// gid.x is treated as a flat element index, split into four coordinates
// (i0..i3) using ne0/ne1/ne2, and then mapped to source and destination
// element indices through the per-dimension strides. i0 is added without a
// stride multiplier on both sides.
@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let flat = gid.x;
    if (flat >= params.ne) {
        // Surplus invocation past the last element: nothing to do.
        return;
    }

    // Number of elements spanned by one step of i2 and of i3 respectively.
    let plane = params.ne1 * params.ne0;
    let volume = params.ne2 * plane;

    // Peel coordinates off from the outermost dimension inward.
    let i3 = flat / volume;
    let rem3 = flat % volume;
    let i2 = rem3 / plane;
    let rem2 = rem3 % plane;
    let i1 = rem2 / params.ne0;
    let i0 = rem2 % params.ne0;

    let src_index = params.offset_src
        + i3 * params.stride_src3
        + i2 * params.stride_src2
        + i1 * params.stride_src1
        + i0;
    let dst_index = params.offset_dst
        + i3 * params.stride_dst3
        + i2 * params.stride_dst2
        + i1 * params.stride_dst1
        + i0;

    // store_scale is supplied by the selected variant and decides whether the
    // result lands in the destination buffer or back in the source buffer.
    store_scale(src[src_index] * params.scale + params.bias, dst_index);
}
| 90 | #end(SHADER) | ||
