// kernel_sum_rows_f32: reduce each row of src0 to a single float.
//
// One work-item per row. The row is selected by
//   i1 = get_global_id(0), i2 = get_global_id(1), i3 = get_global_id(2),
// and work-items with any index at or beyond (ne01, ne02, ne03) do nothing.
//
// Parameters:
//   src0, offset0      - input buffer and a byte offset added to it before indexing
//   dst,  offsetd      - output buffer and a byte offset added to it before indexing
//   ne00               - number of floats summed per row
//   ne01, ne02, ne03   - row counts along dimensions 1..3 (used as launch bounds)
//   nb01, nb02, nb03   - byte strides of src0 along dimensions 1..3
//   nb1,  nb2,  nb3    - byte strides of dst  along dimensions 1..3
//
// The ne00 elements of a row are read as consecutive floats, so dimension 0 of
// src0 must be contiguous (NOTE(review): presumably the host only dispatches
// contiguous rows to this kernel - confirm). The sum is accumulated
// sequentially in single precision and written to the first float of the
// matching dst row.
kernel void kernel_sum_rows_f32(
    global float *  src0,
    ulong           offset0,
    global float *  dst,
    ulong           offsetd,
    int             ne00,
    int             ne01,
    int             ne02,
    int             ne03,
    ulong           nb01,
    ulong           nb02,
    ulong           nb03,
    ulong           nb1,
    ulong           nb2,
    ulong           nb3
) {
    const int i1 = get_global_id(0);
    const int i2 = get_global_id(1);
    const int i3 = get_global_id(2);

    if (i1 < ne01 && i2 < ne02 && i3 < ne03) {
        // Resolve both row addresses in byte space: base + buffer offset + strides.
        global const char * in_bytes  = (global const char *) src0 + offset0
                                      + i1*nb01 + i2*nb02 + i3*nb03;
        global char *       out_bytes = (global char *) dst + offsetd
                                      + i1*nb1  + i2*nb2  + i3*nb3;

        global const float * in = (global const float *) in_bytes;

        // Plain left-to-right accumulation, in the same order as the elements in memory.
        float acc = 0.0f;
        for (int k = 0; k < ne00; ++k) {
            acc += in[k];
        }

        *(global float *) out_bytes = acc;
    }
}