// Input tensor data, addressed as flat f32 elements. This shader only reads it.
// NOTE(review): declared read_write although it is never written here; possibly to
// match the host-side bind group layout, or because src and dst may share one
// buffer (mixing read-only and writable bindings of the same buffer is a usage
// conflict in WebGPU) — confirm before changing the access mode to `read`.
@group(0) @binding(0) var<storage, read_write> src: array<f32>;

// Output: main() writes one f32 per workgroup at params.offset_dst + workgroup index.
@group(0) @binding(1) var<storage, read_write> dst: array<f32>;

 7struct Params {
 8    offset_src: u32, // in elements
 9    offset_dst: u32, // in elements
10
11    // Strides (in elements)
12    stride_src1: u32,
13    stride_src2: u32,
14    stride_src3: u32,
15
16    ne0: u32,
17    ne1: u32,
18    ne2: u32
19};
20
// Per-dispatch parameters (layout given by struct Params).
@group(0) @binding(2) var<uniform> params: Params;

// Workgroup scratch: one partial sum per invocation, combined by main()'s reduction.
// NOTE(review): WG_SIZE is not declared in this view; presumably it is substituted
// or provided as an override when the pipeline is built — confirm.
var<workgroup> shared_sum: array<f32, WG_SIZE>;

// Sums each row of `src` along dimension 0 and writes one f32 per row to `dst`.
//
// One workgroup handles one row. The flat workgroup index wid.x is split into
// (i1, i2, i3) using ne1 and ne2; the row starts at
//   offset_src + i3*stride_src3 + i2*stride_src2 + i1*stride_src1
// and its ne0 elements are read contiguously (element stride 1).
//
// Each invocation accumulates a strided partial sum over columns
// lid.x, lid.x + WG_SIZE, ...; the WG_SIZE partial sums are then combined by a
// tree reduction in workgroup memory, and invocation 0 stores the total at
// dst[offset_dst + wid.x].
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(workgroup_id) wid: vec3<u32>,
        @builtin(local_invocation_id) lid: vec3<u32>) {

    // Split the flat row index into (i1, i2, i3).
    let rows_per_i3 = params.ne2 * params.ne1;
    let i3 = wid.x / rows_per_i3;
    let row_in_i3 = wid.x % rows_per_i3;
    let i2 = row_in_i3 / params.ne1;
    let i1 = row_in_i3 % params.ne1;
    let i_src_row = params.offset_src + i3 * params.stride_src3 + i2 * params.stride_src2 + i1 * params.stride_src1;

    // Per-invocation partial sum over a strided subset of this row's columns.
    var local_sum: f32 = 0.0;
    for (var col = lid.x; col < params.ne0; col += WG_SIZE) {
        local_sum += src[i_src_row + col];
    }
    shared_sum[lid.x] = local_sum;
    workgroupBarrier();

    // Tree reduction over shared_sum[0 .. n_active).
    // Each step keeps n_keep = ceil(n_active / 2) entries and folds the upper
    // n_active - n_keep entries onto the lower ones, so an odd count never
    // drops an element. This makes the reduction correct for any WG_SIZE
    // (the previous floor-halving loop required a power of two); for a
    // power-of-two WG_SIZE the additions performed are identical.
    // The loop bounds depend only on WG_SIZE, so every invocation executes the
    // same sequence of workgroupBarrier() calls (uniform control flow).
    var n_active: u32 = WG_SIZE;
    while (n_active > 1u) {
        let n_keep = (n_active + 1u) / 2u;
        if (lid.x < n_active - n_keep) {
            shared_sum[lid.x] = shared_sum[lid.x] + shared_sum[lid.x + n_keep];
        }
        workgroupBarrier();
        n_active = n_keep;
    }

    // One writer per workgroup stores the row total.
    if (lid.x == 0u) {
        dst[params.offset_dst + wid.x] = shared_sum[0];
    }
}