1enable f16;
2
// Parameters for the element-wise binary operation, bound as the uniform `params`.
// Field order determines the uniform buffer layout, so fields must not be reordered.
struct Params {
    // Number of elements to process; main() skips invocations with gid.x >= ne.
    ne: u32,

    // offsets in elements
    // Added to the per-invocation index before accessing src0, src1 and the destination.
    offset_src0: u32,
    offset_src1: u32,
    offset_dst: u32,

    // Per-dimension multipliers used by src1_index() to turn the four src1
    // coordinates into a flat element index.
    stride_src1_0: u32,
    stride_src1_1: u32,
    stride_src1_2: u32,
    stride_src1_3: u32,

    // Extents used by src1_index() to split the flat invocation index into four
    // coordinates (presumably the shape of src0/dst -- confirm against the host code).
    a_ne0: u32,
    a_ne1: u32,
    a_ne2: u32,

    // Extents of src1 in each dimension; each coordinate is taken modulo these,
    // so src1 repeats along a dimension once its elements are exhausted.
    b_ne0: u32,
    b_ne1: u32,
    b_ne2: u32,
    b_ne3: u32,
};
25
// Maps a flat element index to the corresponding flat index into src1.
//
// The index is first split into four coordinates using the a_ne0/a_ne1/a_ne2
// extents, then each coordinate is wrapped by the matching src1 extent so that
// src1 repeats along any dimension where it has fewer elements.
//
// _i: flat index of the element being processed.
// Returns the stride-weighted sum of the wrapped coordinates; the caller is
// responsible for adding params.offset_src1.
fn src1_index(_i: u32) -> u32 {
    // Element counts of one 2-D slice and one 3-D slab of the index space.
    let slice_len = params.a_ne1 * params.a_ne0;
    let slab_len = params.a_ne2 * slice_len;

    // Peel off coordinates from the outermost dimension inwards.
    let a_i3 = _i / slab_len;
    let within_slab = _i % slab_len;
    let a_i2 = within_slab / slice_len;
    let within_slice = within_slab % slice_len;
    let a_i1 = within_slice / params.a_ne0;
    let a_i0 = within_slice % params.a_ne0;

    // Wrap each coordinate by the src1 extent (modulo = repeat from the start)
    // and scale it by the stride of that dimension.
    let term0 = (a_i0 % params.b_ne0) * params.stride_src1_0;
    let term1 = (a_i1 % params.b_ne1) * params.stride_src1_1;
    let term2 = (a_i2 % params.b_ne2) * params.stride_src1_2;
    let term3 = (a_i3 % params.b_ne3) * params.stride_src1_3;

    return term0 + term1 + term2 + term3;
}
48
// Element type used by the storage buffers and by op():
// f32 when TYPE_F32 is defined, f16 when TYPE_F16 is defined.
#ifdef TYPE_F32
#define DataType f32
#endif
#ifdef TYPE_F16
#define DataType f16
#endif
55
// First operand. update() also stores the result here when INPLACE is defined.
@group(0) @binding(0)
var<storage, read_write> src0: array<DataType>;

// Second operand, addressed through src1_index().
// update() also stores the result here when OVERLAP is defined.
@group(0) @binding(1)
var<storage, read_write> src1 : array<DataType>;

#ifdef INPLACE
// No separate destination buffer: params occupies binding 2.
@group(0) @binding(2)
var<uniform> params: Params;

#elif defined(OVERLAP)
// No separate destination buffer: params occupies binding 2.
@group(0) @binding(2)
var<uniform> params: Params;

#else
// Separate destination buffer at binding 2; params moves to binding 3.
@group(0) @binding(2)
var<storage, read_write> dst: array<DataType>;

@group(0) @binding(3)
var<uniform> params: Params;
#endif
77
// Applies the binary operator selected at preprocessing time to one pair of
// elements: OP_ADD -> a + b, OP_SUB -> a - b, OP_MUL -> a * b, OP_DIV -> a / b.
// Exactly one OP_* macro is expected to be defined; with none defined the
// function does not compile.
fn op(a: DataType, b: DataType) -> DataType {
#ifdef OP_ADD
    let combined = a + b;
#elif defined(OP_SUB)
    let combined = a - b;
#elif defined(OP_MUL)
    let combined = a * b;
#elif defined(OP_DIV)
    let combined = a / b;
#endif
    return combined;
}
89
// Computes op(src0[src0_i], src1[src1_i]) and stores it at index dst_i of the
// output buffer for the active variant:
//   INPLACE -> src0, OVERLAP -> src1, otherwise -> dst.
fn update(dst_i: u32, src0_i: u32, src1_i: u32){
    // Load both operands before the store, since the target may alias an input.
    let lhs = src0[src0_i];
    let rhs = src1[src1_i];
    let value = op(lhs, rhs);

#ifdef INPLACE
    src0[dst_i] = value;
#elif defined(OVERLAP)
    src1[dst_i] = value;
#else
    dst[dst_i] = value;
#endif
}
101
// Entry point: each invocation handles the element whose flat index equals its
// global x id. Invocations at or beyond params.ne do nothing.
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let elem = gid.x;
    if (elem >= params.ne) {
        return;
    }

    // dst and src0 are addressed linearly; src1 goes through src1_index().
    let dst_i = params.offset_dst + elem;
    let src0_i = params.offset_src0 + elem;
    let src1_i = params.offset_src1 + src1_index(elem);
    update(dst_i, src0_i, src1_i);
}