// Input buffer the kernel reads from, addressed in f32 elements.
@group(0) @binding(0) var<storage, read_write> src: array<f32>;

// Output buffer the kernel writes to, addressed in f32 elements.
@group(0) @binding(1) var<storage, read_write> dst: array<f32>;
 6
 7struct Params {
 8    ne: u32,            // total number of elements
 9    offset_src: u32,    // in elements
10    offset_dst: u32,    // in elements
11
12    // Strides (in elements)
13    stride_src0: u32,
14    stride_src1: u32,
15    stride_src2: u32,
16    stride_src3: u32,
17
18    // Logical shapes
19    src_ne0: u32,
20    src_ne1: u32,
21    src_ne2: u32,
22    src_ne3: u32,
23
24    dst_ne0: u32,
25    dst_ne1: u32,
26    dst_ne2: u32,
27    dst_ne3: u32,
28
29    // Pad sizes (in elements)
30    lp0: u32,
31    rp0: u32,
32    lp1: u32,
33    rp1: u32,
34    lp2: u32,
35    rp2: u32,
36    lp3: u32,
37    rp3: u32,
38};
39
// Uniform block holding the kernel's Params (shapes, strides, offsets, pads).
@group(0) @binding(2) var<uniform> params: Params;
42
// Maps a signed coordinate onto the range [0, n) with periodic wrapping.
//
//   idx - coordinate to wrap; may be negative or >= n.
//   n   - period (extent of the dimension). Assumed to fit in an i32.
//
// Returns idx modulo n as a non-negative value.
//
// A single "add n, then take the unsigned remainder" step is only correct
// for idx >= -n: a more negative idx leaves the sum negative, the u32
// conversion turns it into a huge value, and the remainder is wrong unless
// n divides 2^32. Taking the signed remainder first handles any idx.
fn wrap_around(idx: i32, n: u32) -> u32 {
    let period = i32(n);
    // Signed `%` truncates toward zero, so the remainder has the sign of idx
    // and lies in (-period, period).
    let rem = idx % period;
    // Shift negative remainders up by one period to land in [0, period).
    return u32(select(rem, rem + period, rem < 0));
}
46
// Pad kernel entry point. Each invocation produces one destination element,
// identified by its flat index gid.x, and writes it to
// dst[offset_dst + gid.x].
//
// With CIRCULAR defined, the element is fetched from the source at the
// destination coordinate shifted by the leading pad and wrapped per
// dimension. Otherwise, coordinates inside the padded border keep the
// initial value 0.0 and the rest are copied from the source.
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let flat = gid.x;
    // Surplus invocations past the last element have nothing to do.
    if (flat >= params.ne) {
        return;
    }

    // Element counts of one row (dim 0), one plane (dims 0-1) and one
    // cube (dims 0-2) of the destination.
    let row_len = params.dst_ne0;
    let plane_len = params.dst_ne1 * row_len;
    let cube_len = params.dst_ne2 * plane_len;

    // Split the flat index into a 4-D destination coordinate (x = dim 0).
    let in_cube = flat % cube_len;
    let in_plane = in_cube % plane_len;
    let coord = vec4<u32>(
        in_plane % row_len,
        in_plane / row_len,
        in_cube / plane_len,
        flat / cube_len
    );

    let left_pad = vec4<u32>(params.lp0, params.lp1, params.lp2, params.lp3);
    let src_stride = vec4<u32>(params.stride_src0, params.stride_src1,
                               params.stride_src2, params.stride_src3);

    var out_value: f32 = 0.0;

#ifdef CIRCULAR
    // Shift by the leading pad (may go negative), then wrap each dimension
    // back into the source extent.
    let shifted = vec4<i32>(coord) - vec4<i32>(left_pad);
    let src_coord = vec4<u32>(
        wrap_around(shifted.x, params.src_ne0),
        wrap_around(shifted.y, params.src_ne1),
        wrap_around(shifted.z, params.src_ne2),
        wrap_around(shifted.w, params.src_ne3)
    );
    out_value = src[params.offset_src + dot(src_coord, src_stride)];
#else
    let dst_shape = vec4<u32>(params.dst_ne0, params.dst_ne1, params.dst_ne2, params.dst_ne3);
    let right_pad = vec4<u32>(params.rp0, params.rp1, params.rp2, params.rp3);
    // Read the source only when every coordinate lies between the leading
    // and trailing pad; otherwise the 0.0 initial value is written.
    if (all(coord >= left_pad) && all(coord < dst_shape - right_pad)) {
        out_value = src[params.offset_src + dot(coord - left_pad, src_stride)];
    }
#endif

    dst[params.offset_dst + flat] = out_value;
}