diff options
Diffstat (limited to 'llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl')
| -rw-r--r-- | llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl | 107 |
1 files changed, 107 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl new file mode 100644 index 0000000..b5e93b8 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl | |||
| @@ -0,0 +1,107 @@ | |||
#define(VARIANTS)

[
  { "REPLS": { "SRC_TYPE": "f32", "DST_TYPE": "f32" } },
  { "REPLS": { "SRC_TYPE": "f32", "DST_TYPE": "i32" } },
  { "REPLS": { "SRC_TYPE": "f32", "DST_TYPE": "f16" } },
  { "REPLS": { "SRC_TYPE": "f16", "DST_TYPE": "f16" } },
  { "REPLS": { "SRC_TYPE": "f16", "DST_TYPE": "f32" } }
]

#end(VARIANTS)
| 37 | |||
#define(SHADER)
// Element-wise copy with type conversion: each invocation reads one element
// of `src`, converts it to {{DST_TYPE}} and stores it in `dst`. Source and
// destination are addressed through independent shapes and strides.
enable f16;

// Input elements. Declared read_write, although this shader only reads it.
@group(0) @binding(0)
var<storage, read_write> src: array<{{SRC_TYPE}}>;

// Output elements; one element is written per in-range invocation.
@group(0) @binding(1)
var<storage, read_write> dst: array<{{DST_TYPE}}>;

// Copy description read from the uniform binding declared below.
struct Params {
    ne: u32,         // total number of elements
    offset_src: u32, // in elements
    offset_dst: u32, // in elements

    // Strides (in elements) — may be permuted
    stride_src0: u32,
    stride_src1: u32,
    stride_src2: u32,
    stride_src3: u32,

    stride_dst0: u32,
    stride_dst1: u32,
    stride_dst2: u32,
    stride_dst3: u32,

    // Logical shapes (extents of dims 0..2; the dim-3 coordinate is whatever
    // remains after dividing by the product of these three)
    src_ne0: u32,
    src_ne1: u32,
    src_ne2: u32,

    dst_ne0: u32,
    dst_ne1: u32,
    dst_ne2: u32
};

@group(0) @binding(2)
var<uniform> params: Params;

// Workgroup size along x; declared without a default value.
override wg_size: u32;

// Splits a flat element index into 4-D coordinates for a tensor whose three
// innermost extents are ne0, ne1 and ne2 (dim 0 varies fastest).
// Returns (i0, i1, i2, i3) in the (x, y, z, w) components.
fn unravel_index(flat: u32, ne0: u32, ne1: u32, ne2: u32) -> vec4<u32> {
    let plane = ne1 * ne0;      // elements per dim-2 slice
    let volume = ne2 * plane;   // elements per dim-3 slice
    let within_volume = flat % volume;
    let within_plane = within_volume % plane;
    return vec4<u32>(
        within_plane % ne0,
        within_plane / ne0,
        within_volume / plane,
        flat / volume
    );
}

@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let elem = gid.x;

    // Invocations past the last element have nothing to copy.
    if (elem >= params.ne) {
        return;
    }

    // The same flat index is decomposed twice: once against the source
    // shape and once against the destination shape.
    let s = unravel_index(elem, params.src_ne0, params.src_ne1, params.src_ne2);
    let d = unravel_index(elem, params.dst_ne0, params.dst_ne1, params.dst_ne2);

    // Apply the per-dimension strides to obtain element offsets.
    let src_idx = s.x * params.stride_src0 + s.y * params.stride_src1 +
                  s.z * params.stride_src2 + s.w * params.stride_src3;

    let dst_idx = d.x * params.stride_dst0 + d.y * params.stride_dst1 +
                  d.z * params.stride_dst2 + d.w * params.stride_dst3;

    // Load, convert to the destination element type, store.
    let value = src[params.offset_src + src_idx];
    dst[params.offset_dst + dst_idx] = {{DST_TYPE}}(value);
}
#end(SHADER)
