1enable f16;
  2
// Scalar element type written to dst: f32 when DST_F32 is defined,
// otherwise f16 (requires the `enable f16;` directive at the top of the file).
#ifdef DST_F32
#define DST_INNER_TYPE f32
#else
#define DST_INNER_TYPE f16
#endif

// Per-slot types for src/dst. With VEC4, every array slot holds four
// consecutive elements and each invocation converts/copies a whole vec4;
// otherwise each slot (and invocation) handles a single scalar.
// src is always f32-based. VEC_SIZE = number of elements per array slot.
#ifdef VEC4
#define SRC_TYPE vec4<f32>
#define DST_TYPE vec4<DST_INNER_TYPE>
#define VEC_SIZE 4
#else
#define SRC_TYPE f32
#define DST_TYPE DST_INNER_TYPE
#define VEC_SIZE 1
#endif
 18
// Source tensor data (only read by main in this file, though bound read_write).
@group(0) @binding(0)
var<storage, read_write> src: array<SRC_TYPE>;

// Row indices, viewed as raw u32 words. With I64_IDX each index occupies
// two consecutive words; otherwise each index is one word.
@group(0) @binding(1)
var<storage, read_write> idx: array<u32>;

// Destination tensor; main writes converted src rows into it at the rows selected by idx.
@group(0) @binding(2)
var<storage, read_write> dst: array<DST_TYPE>;

#ifdef I64_IDX
// Error flag: main stores 1 here when a 64-bit index has non-zero upper
// 32 bits (such rows are skipped). Occupying binding 3 pushes the uniform
// params binding to slot 4.
@group(0) @binding(3)
var<storage, read_write> error: atomic<u32>;
#define PARAMS_BINDING 4
#else
#define PARAMS_BINDING 3
#endif
 35
// Per-dispatch offsets, strides and shapes supplied by the host.
// Field order/types must match the host-side uniform buffer layout.
struct Params {
    offset_src: u32, // in elements
    // In elements of idx; with I64_IDX these are 64-bit indices (main doubles the
    // resulting offset to address u32 words).
    offset_idx: u32, // in elements
    offset_dst: u32, // in elements

    // Strides (in elements)
    stride_src1: u32,
    stride_src2: u32,
    stride_src3: u32,

    // Strides of idx along its dims 0..2 (same unit as offset_idx).
    stride_idx0: u32,
    stride_idx1: u32,
    stride_idx2: u32,

    // stride_dst1 is multiplied by the looked-up row index; dst2/dst3 follow src dims 2/3.
    stride_dst1: u32,
    stride_dst2: u32,
    stride_dst3: u32,

    // Shape of src
    ne0: u32,    // elements per row
    n_rows: u32, // size of dim 1 (rows per 2-D slice)
    ne2: u32,
    ne3: u32,

    // Shape of idx (dims 1 and 2); src dims 2/3 are taken modulo these, i.e. idx is broadcast.
    idx1: u32,
    idx2: u32,
};
 64
// Uniform parameters; the slot is 3, or 4 when the I64_IDX error binding is present.
@group(0) @binding(PARAMS_BINDING)
var<uniform> params: Params;
 67
 68@compute @workgroup_size(WG_SIZE)
 69fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
 70    if (gid.x >= (params.ne3 * params.ne2 * params.n_rows * params.ne0) / VEC_SIZE) {
 71        return;
 72    }
 73
 74    // getting the row from gid
 75    let elems_per_row = params.ne0 / VEC_SIZE;
 76    var i = gid.x / elems_per_row;
 77
 78    let i_src3 = i / (params.ne2 * params.n_rows);
 79
 80    i = i % (params.ne2 * params.n_rows);
 81    let i_src2 = i / params.n_rows;
 82    let i_src1 = i % params.n_rows;
 83
 84    let i_idx2 = i_src3 % params.idx2;
 85    let i_idx1 = i_src2 % params.idx1;
 86    let i_idx0 = i_src1;
 87
 88#ifdef I64_IDX
 89    let idx_high = (params.offset_idx + i_idx0 * params.stride_idx0 + i_idx1 * params.stride_idx1 + i_idx2 * params.stride_idx2) * 2;
 90
 91    let idx_val = idx[idx_high];
 92    let idx_low_val = idx[idx_high + 1];
 93
 94    if (idx_low_val != 0) {
 95        // Upper bits of index are not zero, output will be incorrect
 96        atomicStore(&error, 1);
 97        return;
 98    }
 99#else
100    let idx_i = params.offset_idx + i_idx0 * params.stride_idx0 + i_idx1 * params.stride_idx1 + i_idx2 * params.stride_idx2;
101    let idx_val = idx[idx_i];
102#endif
103
104    let i_dst_row = params.offset_dst + idx_val * params.stride_dst1 + i_src2 * params.stride_dst2 + i_src3 * params.stride_dst3;
105    let i_src_row = params.offset_src + i_src1 * params.stride_src1 + i_src2 * params.stride_src2 + i_src3 * params.stride_src3;
106
107    let col_idx = (gid.x % elems_per_row);
108    dst[i_dst_row/VEC_SIZE + col_idx] = DST_TYPE(src[i_src_row/VEC_SIZE + col_idx]);
109}