1#define(VARIANTS)
  2
  3[
  4  {
  5    "REPLS": {
  6      "SRC_TYPE": "f32",
  7      "DST_TYPE": "f32"
  8    }
  9  },
 10  {
 11    "REPLS": {
 12      "SRC_TYPE": "f32",
 13      "DST_TYPE": "i32"
 14    }
 15  },
 16  {
 17    "REPLS": {
 18      "SRC_TYPE": "f32",
 19      "DST_TYPE": "f16"
 20    }
 21  },
 22  {
 23    "REPLS": {
 24      "SRC_TYPE": "f16",
 25      "DST_TYPE": "f16"
 26    }
 27  },
 28  {
 29    "REPLS": {
 30      "SRC_TYPE": "f16",
 31      "DST_TYPE": "f32"
 32    }
 33  }
 34]
 35
 36#end(VARIANTS)
 37
 38#define(SHADER)
// Required by the f16 variants. The directive is emitted unconditionally, so
// every variant (including pure f32/i32 copies) needs the shader-f16 feature.
enable f16;

// Source tensor elements, typed by the variant's SRC_TYPE.
// This shader only reads from `src`. NOTE(review): presumably declared
// read_write so the same GPU buffer may also be bound as `dst` in one dispatch
// (WebGPU forbids mixing read-only and writable storage usage of one buffer);
// confirm against the host-side bind group before changing the access mode.
@group(0) @binding(0)
var<storage, read_write> src: array<{{SRC_TYPE}}>;

// Destination tensor elements, typed by the variant's DST_TYPE.
@group(0) @binding(1)
var<storage, read_write> dst: array<{{DST_TYPE}}>;
 46
// Per-dispatch copy parameters, bound as a uniform at binding 2.
// Every field is a u32, so field k lives at byte offset 4*k in declaration
// order. The host-side writer (not visible here) must fill them in this same
// order; reordering or inserting fields changes the layout.
struct Params {
    ne: u32,            // total number of elements; invocations with gid.x >= ne do nothing
    offset_src: u32,    // in elements; added to every computed src index
    offset_dst: u32,    // in elements; added to every computed dst index

    // Strides (in elements) — may be permuted.
    // Index 0 belongs to the fastest-varying logical dimension, index 3 to
    // the slowest (the one whose coordinate is whatever remains of the index).
    stride_src0: u32,
    stride_src1: u32,
    stride_src2: u32,
    stride_src3: u32,

    stride_dst0: u32,
    stride_dst1: u32,
    stride_dst2: u32,
    stride_dst3: u32,

    // Logical shapes: extents of dimensions 0..2 only. The dimension-3 extent
    // is not passed; its coordinate is the quotient of the index by
    // ne0*ne1*ne2, bounded implicitly by `ne`.
    src_ne0: u32,
    src_ne1: u32,
    src_ne2: u32,

    dst_ne0: u32,
    dst_ne1: u32,
    dst_ne2: u32
};

// Single uniform instance of the copy parameters for this dispatch.
@group(0) @binding(2)
var<uniform> params: Params;
 75
 76override wg_size: u32;
 77@compute @workgroup_size(wg_size)
 78fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
 79    if (gid.x >= params.ne) {
 80        return;
 81    }
 82
 83    var i = gid.x;
 84    let i3 = i / (params.src_ne2 * params.src_ne1 * params.src_ne0);
 85    i = i % (params.src_ne2 * params.src_ne1 * params.src_ne0);
 86    let i2 = i / (params.src_ne1 * params.src_ne0);
 87    i = i % (params.src_ne1 * params.src_ne0);
 88    let i1 = i / params.src_ne0;
 89    let i0 = i % params.src_ne0;
 90
 91    var j = gid.x;
 92    let j3 = j / (params.dst_ne2 * params.dst_ne1 * params.dst_ne0);
 93    j = j % (params.dst_ne2 * params.dst_ne1 * params.dst_ne0);
 94    let j2 = j / (params.dst_ne1 * params.dst_ne0);
 95    j = j % (params.dst_ne1 * params.dst_ne0);
 96    let j1 = j / params.dst_ne0;
 97    let j0 = j % params.dst_ne0;
 98
 99    let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 +
100                  i2 * params.stride_src2 + i3 * params.stride_src3;
101
102    let dst_idx = j0 * params.stride_dst0 + j1 * params.stride_dst1 +
103                  j2 * params.stride_dst2 + j3 * params.stride_dst3;
104
105    dst[params.offset_dst + dst_idx] = {{DST_TYPE}}((src[params.offset_src + src_idx]));
106}
107#end(SHADER)