diff options
Diffstat (limited to 'llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl')
| -rw-r--r-- | llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl | 86 |
1 files changed, 86 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl new file mode 100644 index 0000000..ea63b9a --- /dev/null +++ b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl | |||
| @@ -0,0 +1,86 @@ | |||
// Source tensor storage. This file only reads from it; the binding is
// nevertheless declared read_write.
@group(0) @binding(0) var<storage, read_write> src: array<f32>;

// Destination tensor storage. Each invocation writes exactly one f32 here.
@group(0) @binding(1) var<storage, read_write> dst: array<f32>;
| 6 | |||
// Uniform parameter block for the pad kernel.
// Field order and types define the buffer layout and must not be rearranged.
struct Params {
    ne: u32,         // number of elements to produce (one invocation each)
    offset_src: u32, // start of the source data inside `src`, in elements
    offset_dst: u32, // start of the destination data inside `dst`, in elements

    // Source strides per axis, in elements.
    stride_src0: u32,
    stride_src1: u32,
    stride_src2: u32,
    stride_src3: u32,

    // Logical extent of the source along each axis.
    src_ne0: u32,
    src_ne1: u32,
    src_ne2: u32,
    src_ne3: u32,

    // Logical extent of the destination along each axis.
    dst_ne0: u32,
    dst_ne1: u32,
    dst_ne2: u32,
    dst_ne3: u32,

    // Padding per axis, in elements: lpK is added before the source data on
    // axis K, rpK after it.
    lp0: u32,
    rp0: u32,
    lp1: u32,
    rp1: u32,
    lp2: u32,
    rp2: u32,
    lp3: u32,
    rp3: u32,
};
| 39 | |||
// Kernel parameters (shapes, strides, offsets and pad sizes), bound as a uniform.
@group(0) @binding(2) var<uniform> params: Params;
| 42 | |||
// Maps a signed index onto [0, n) with periodic (modulo-n) wrapping.
//
//   idx: position relative to the start of the source axis; may be negative
//        or greater than or equal to n.
//   n:   length of the axis (the wrapping period).
//
// Returns the equivalent in-range index.
//
// The reduction is done in signed arithmetic before converting to u32.
// Adding n once and converting straight away is only correct for idx >= -n:
// below that the sum is still negative, u32() reinterprets its bits as a value
// close to 2^32, and the remainder is wrong whenever n does not divide 2^32
// (e.g. idx = -4, n = 3 must give 2, not 0). That situation arises when the
// leading pad is larger than the source extent.
fn wrap_around(idx: i32, n: u32) -> u32 {
    let period = i32(n);
    // Integer % keeps the sign of the dividend, so the first remainder lies in
    // (-period, period); shifting by period and reducing again lands in [0, period).
    let reduced = idx % period;
    return u32((reduced + period) % period);
}
| 46 | |||
// Pad kernel entry point: one invocation per destination element.
//
// The flat invocation index is split into 4-D destination coordinates, which
// are then mapped back to a source element:
//   - with CIRCULAR defined, every destination element reads the source at the
//     coordinate wrapped periodically on each axis;
//   - otherwise, elements inside the un-padded region copy the matching source
//     element and all others are written as 0.0.
//
// WG_SIZE and the #ifdef/#else/#endif directives are not WGSL themselves;
// presumably they are resolved by a preprocessing step before compilation.
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let flat = gid.x;
    if (flat >= params.ne) {
        return;
    }

    // Element counts of one destination row, one 2-D slice and one 3-D volume.
    let row_len = params.dst_ne0;
    let slice_len = params.dst_ne1 * row_len;
    let volume_len = params.dst_ne2 * slice_len;

    // Peel off the coordinates from the slowest-varying axis to the fastest.
    let d3 = flat / volume_len;
    let in_volume = flat % volume_len;
    let d2 = in_volume / slice_len;
    let in_slice = in_volume % slice_len;
    let d1 = in_slice / row_len;
    let d0 = in_slice % row_len;

    // Padding value used when no source element is selected.
    var padded: f32 = 0.0;

#ifdef CIRCULAR
    // Shift by the leading pad and wrap each coordinate into the source extent.
    let s0 = wrap_around(i32(d0) - i32(params.lp0), params.src_ne0);
    let s1 = wrap_around(i32(d1) - i32(params.lp1), params.src_ne1);
    let s2 = wrap_around(i32(d2) - i32(params.lp2), params.src_ne2);
    let s3 = wrap_around(i32(d3) - i32(params.lp3), params.src_ne3);
    let wrapped_off = s0 * params.stride_src0
                    + s1 * params.stride_src1
                    + s2 * params.stride_src2
                    + s3 * params.stride_src3;
    padded = src[params.offset_src + wrapped_off];
#else
    // A coordinate is inside the source region when it lies past the leading
    // pad and before the trailing pad on its axis.
    let inside0 = d0 >= params.lp0 && d0 < params.dst_ne0 - params.rp0;
    let inside1 = d1 >= params.lp1 && d1 < params.dst_ne1 - params.rp1;
    let inside2 = d2 >= params.lp2 && d2 < params.dst_ne2 - params.rp2;
    let inside3 = d3 >= params.lp3 && d3 < params.dst_ne3 - params.rp3;
    if (inside0 && inside1 && inside2 && inside3) {
        let src_off = (d0 - params.lp0) * params.stride_src0
                    + (d1 - params.lp1) * params.stride_src1
                    + (d2 - params.lp2) * params.stride_src2
                    + (d3 - params.lp3) * params.stride_src3;
        padded = src[params.offset_src + src_off];
    }
#endif

    // The destination is addressed by the flat index directly.
    dst[params.offset_dst + flat] = padded;
}
