1enable f16;
  2
// Uniform parameters for one dispatch of the element-wise binary operation.
// NOTE(review): the field order defines the uniform buffer layout; presumably
// the host fills it in this exact order — confirm before reordering anything.
struct Params {
    // total number of output elements; main() ignores invocations with gid.x >= ne
    ne: u32,

    // offsets in elements
    // (added to the per-element index before src0 / src1 / destination are accessed)
    offset_src0: u32,
    offset_src1: u32,
    offset_dst: u32,

    // per-dimension strides of src1, multiplied by the wrapped src1 coordinate
    // in src1_index() to form a flat element index
    stride_src1_0: u32,
    stride_src1_1: u32,
    stride_src1_2: u32,
    stride_src1_3: u32,

    // extents of the first three dimensions of the iteration space; used by
    // src1_index() to split a flat index into 4 coordinates (the outermost
    // coordinate is whatever remains, so no fourth extent is needed)
    a_ne0: u32,
    a_ne1: u32,
    a_ne2: u32,

    // extents of src1 in all four dimensions; each coordinate is taken modulo
    // the matching extent so src1 repeats along that dimension
    b_ne0: u32,
    b_ne1: u32,
    b_ne2: u32,
    b_ne3: u32,
};
 25
 26fn src1_index(_i: u32) -> u32 {
 27    var i = _i;
 28    let a_i3 = i / (params.a_ne2 * params.a_ne1 * params.a_ne0);
 29    i = i % (params.a_ne2 * params.a_ne1 * params.a_ne0);
 30    let a_i2 = i / (params.a_ne1 * params.a_ne0);
 31    i = i % (params.a_ne1 * params.a_ne0);
 32    let a_i1 = i / params.a_ne0;
 33    let a_i0 = i % params.a_ne0;
 34
 35    // handle repetition of b
 36    // index loops back to the beginning and repeats after elements are exhausted = modulo
 37    let b_i0 = a_i0 % params.b_ne0;
 38    let b_i1 = a_i1 % params.b_ne1;
 39    let b_i2 = a_i2 % params.b_ne2;
 40    let b_i3 = a_i3 % params.b_ne3;
 41
 42    // compute index for position in b's flat array
 43    return b_i0 * params.stride_src1_0 +
 44           b_i1 * params.stride_src1_1 +
 45           b_i2 * params.stride_src1_2 +
 46           b_i3 * params.stride_src1_3;
 47}
 48
// Select the element type used by the storage buffers and by op().
// NOTE(review): within this file DataType is only defined when TYPE_F32 or
// TYPE_F16 is set; presumably the build always defines exactly one — confirm.
#ifdef TYPE_F32
#define DataType f32
#endif
#ifdef TYPE_F16
#define DataType f16
#endif
 55
// Operand buffers. Both are declared read_write because update() stores the
// result into src0 when INPLACE is defined and into src1 when OVERLAP is defined.
@group(0) @binding(0)
var<storage, read_write> src0: array<DataType>;

@group(0) @binding(1)
var<storage, read_write> src1 : array<DataType>;

// INPLACE / OVERLAP: the destination is one of the operand buffers, so there
// is no dst binding and params sits at binding 2.
#ifdef INPLACE
@group(0) @binding(2)
var<uniform> params: Params;

#elif defined(OVERLAP)
@group(0) @binding(2)
var<uniform> params: Params;

// Default: a separate destination buffer at binding 2, params at binding 3.
#else
@group(0) @binding(2)
var<storage, read_write> dst: array<DataType>;

@group(0) @binding(3)
var<uniform> params: Params;
#endif
 77
// Applies the binary operation selected at preprocessing time to one pair of
// elements: a + b (OP_ADD), a - b (OP_SUB), a * b (OP_MUL) or a / b (OP_DIV).
// The first defined macro in that order wins.
// NOTE(review): if none of the OP_* macros is defined the body is empty and the
// function has no return statement; presumably the shader then fails to compile.
fn op(a: DataType, b: DataType) -> DataType {
#ifdef OP_ADD
    return a + b;
#elif defined(OP_SUB)
    return a - b;
#elif defined(OP_MUL)
    return a * b;
#elif defined(OP_DIV)
    return a / b;
#endif
}
 89
 90fn update(dst_i: u32, src0_i: u32, src1_i: u32){
 91    let result = op(src0[src0_i], src1[src1_i]);
 92
 93#ifdef INPLACE
 94    src0[dst_i] = result;
 95#elif defined(OVERLAP)
 96    src1[dst_i] = result;
 97#else
 98    dst[dst_i] = result;
 99#endif
100}
101
// Compute entry point: each invocation handles the element whose flat index is
// gid.x. Invocations at or beyond params.ne do nothing.
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let elem = gid.x;
    if (elem >= params.ne) {
        return;
    }

    // src0 and the destination are indexed linearly; src1 goes through
    // src1_index() so it can repeat along each dimension
    let dst_i = params.offset_dst + elem;
    let src0_i = params.offset_src0 + elem;
    let src1_i = params.offset_src1 + src1_index(elem);
    update(dst_i, src0_i, src1_i);
}