1
/*
 * kernel_sum_rows_f32: reduce each row of src0 along its first dimension.
 *
 * Each work-item handles exactly one row:
 *   global id 0 -> row index along dimension 1 (bounded by ne01)
 *   global id 1 -> index along dimension 2     (bounded by ne02)
 *   global id 2 -> index along dimension 3     (bounded by ne03)
 * The ne00 floats of that row are added left to right in single precision.
 * The total is stored as one float at the matching position in dst.
 *
 * offset0 / offsetd and all nb* strides are applied to char pointers, so they
 * are byte quantities. Row elements are read at consecutive float indices,
 * so the innermost dimension of src0 is assumed to be contiguous.
 * Work-items whose ids lie outside [ne01, ne02, ne03] return without writing.
 */
kernel void kernel_sum_rows_f32(
        global float * src0,
        ulong offset0,
        global float * dst,
        ulong offsetd,
        int ne00,
        int ne01,
        int ne02,
        int ne03,
        ulong nb01,
        ulong nb02,
        ulong nb03,
        ulong nb1,
        ulong nb2,
        ulong nb3
) {
    const int j1 = get_global_id(0);
    const int j2 = get_global_id(1);
    const int j3 = get_global_id(2);

    // Work-items past the tensor extent have no row to reduce.
    if (!(j1 < ne01 && j2 < ne02 && j3 < ne03)) {
        return;
    }

    // Byte offsets of this row in src0 and of its single output slot in dst.
    const ulong in_off  = offset0 + j1*nb01 + j2*nb02 + j3*nb03;
    const ulong out_off = offsetd + j1*nb1  + j2*nb2  + j3*nb3;

    global const float * in_row = (global const float *) ((global const char *) src0 + in_off);
    global float * out = (global float *) ((global char *) dst + out_off);

    // Plain sequential accumulation; element order is index 0 .. ne00-1.
    float acc = 0.0f;
    for (int k = 0; k < ne00; ++k) {
        acc += in_row[k];
    }

    *out = acc;
}