// Source tensor data, viewed as a flat array of f32 elements.
// NOTE(review): this file only reads `src`; read_write is presumably kept to
// match the host-side bind group layout — confirm before narrowing to `read`.
@group(0) @binding(0) var<storage, read_write> src: array<f32>;

// Destination buffer; `main` writes one f32 row sum per workgroup into it.
@group(0) @binding(1) var<storage, read_write> dst: array<f32>;
6
// Per-dispatch parameters, bound below as a uniform buffer.
// Member order fixes the uniform layout; presumably the host fills it in this
// exact order (eight consecutive u32 values) — keep them in sync.
struct Params {
    // Starting offsets (in elements) into `src` and `dst`.
    offset_src: u32,
    offset_dst: u32,

    // Element strides of `src` along dimensions 1, 2 and 3
    // (dimension 0 is walked contiguously).
    stride_src1: u32,
    stride_src2: u32,
    stride_src3: u32,

    // ne0: number of elements summed per row.
    // ne1, ne2: extents used to split the flat workgroup index into (i1, i2, i3).
    ne0: u32,
    ne1: u32,
    ne2: u32,
}
20
// Binding 2: the Params block describing this dispatch.
@group(0) @binding(2) var<uniform> params: Params;

// One partial-sum slot per invocation of the workgroup.
// NOTE(review): WG_SIZE is not defined in this file; presumably it is
// substituted or overridden when the pipeline is built — confirm.
var<workgroup> shared_sum: array<f32, WG_SIZE>;
25
// Row-sum kernel: each workgroup reduces one row of `src` (params.ne0
// elements) and stores the total at dst[params.offset_dst + wid.x].
//
// The flat workgroup index wid.x is split into (i1, i2, i3), with i1 varying
// fastest, and the row start in `src` is located via the element strides in
// `params`. The shared-memory reduction is valid for any WG_SIZE, not only
// powers of two.
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(workgroup_id) wid: vec3<u32>,
        @builtin(local_invocation_id) lid: vec3<u32>) {

    // Split the flat row index into its (i1, i2, i3) coordinates.
    let ne12 = params.ne2 * params.ne1;
    let i3 = wid.x / ne12;
    let i12 = wid.x % ne12;
    let i2 = i12 / params.ne1;
    let i1 = i12 % params.ne1;
    let i_src_row = params.offset_src
        + i3 * params.stride_src3
        + i2 * params.stride_src2
        + i1 * params.stride_src1;

    // Each invocation accumulates columns lid.x, lid.x + WG_SIZE, ...
    var local_sum: f32 = 0.0;
    for (var col = lid.x; col < params.ne0; col += WG_SIZE) {
        local_sum += src[i_src_row + col];
    }
    shared_sum[lid.x] = local_sum;
    workgroupBarrier();

    // Tree reduction. shared_sum[0 .. n_partial) holds the live partial sums;
    // each step folds the upper (n_partial - n_lower) entries onto the lower
    // ones, with n_lower = ceil(n_partial / 2) so odd counts lose nothing.
    // Writes (index < n_partial - n_lower <= n_lower) and cross-lane reads
    // (index >= n_lower) never overlap within a step.
    // n_partial is identical across the workgroup, so every barrier is
    // reached in uniform control flow.
    var n_partial: u32 = WG_SIZE;
    while (n_partial > 1u) {
        let n_lower = (n_partial + 1u) / 2u;
        if (lid.x + n_lower < n_partial) {
            shared_sum[lid.x] += shared_sum[lid.x + n_lower];
        }
        workgroupBarrier();
        n_partial = n_lower;
    }

    if (lid.x == 0u) {
        dst[params.offset_dst + wid.x] = shared_sum[0];
    }
}