1#pragma OPENCL EXTENSION cl_khr_fp16 : enable
  2
  3//------------------------------------------------------------------------------
  4// add
  5//------------------------------------------------------------------------------
  6
  7// general-purpose kernel for addition of two tensors
  8// pros: works for non-contiguous tensors, supports broadcast across dims 1, 2 and 3
  9// cons: not very efficient
 10kernel void kernel_add(
 11        global char * src0,
 12        ulong  offset0,
 13        global char * src1,
 14        ulong  offset1,
 15        global char * dst,
 16        ulong  offsetd,
 17        int   ne00,
 18        int   ne01,
 19        int   ne02,
 20        int   ne03,
 21        ulong nb00,
 22        ulong nb01,
 23        ulong nb02,
 24        ulong nb03,
 25        int   ne10,
 26        int   ne11,
 27        int   ne12,
 28        int   ne13,
 29        ulong nb10,
 30        ulong nb11,
 31        ulong nb12,
 32        ulong nb13,
 33        int   ne0,
 34        int   ne1,
 35        int   ne2,
 36        int   ne3,
 37        ulong nb0,
 38        ulong nb1,
 39        ulong nb2,
 40        ulong nb3
 41) {
 42    src0 = src0 + offset0;
 43    src1 = src1 + offset1;
 44    dst = dst + offsetd;
 45
 46    int i03 = get_group_id(2);
 47    int i02 = get_group_id(1);
 48    int i01 = get_group_id(0);
 49
 50    int i13 = i03 % ne13;
 51    int i12 = i02 % ne12;
 52    int i11 = i01 % ne11;
 53
 54    global char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
 55    global char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
 56    global char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1;
 57
 58    for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
 59        const int i10 = i0 % ne10;
 60        *((global float *)(dst_ptr + i0*nb0)) = *((global float *)(src0_ptr + i0*nb00)) + *((global float *)(src1_ptr + i10*nb10));
 61    }
 62}
 63
 64// assumption: src1 is a row
 65// broadcast src1 into src0
 66kernel void kernel_add_row(
 67        global float4 * src0,
 68        ulong  offset0,
 69        global float4 * src1,
 70        ulong  offset1,
 71        global float4 * dst,
 72        ulong  offsetd,
 73        int ne
 74) {
 75    src0 = (global float4*)((global char*)src0 + offset0);
 76    src1 = (global float4*)((global char*)src1 + offset1);
 77    dst = (global float4*)((global char*)dst + offsetd);
 78
 79    // This performs better than using %.
 80    uint gid = get_global_id(0);
 81    uint idx1 = gid - (gid/ne)*ne; // get_global_id(0) % ne
 82    dst[gid] = src0[gid] + src1[idx1];
 83}
 84
 85kernel void kernel_add_f16(
 86        global char * src0,
 87        ulong  offset0,
 88        global char * src1,
 89        ulong  offset1,
 90        global char * dst,
 91        ulong  offsetd,
 92        int   ne00,
 93        int   ne01,
 94        int   ne02,
 95        int   ne03,
 96        ulong nb00,
 97        ulong nb01,
 98        ulong nb02,
 99        ulong nb03,
100        int   ne10,
101        int   ne11,
102        int   ne12,
103        int   ne13,
104        ulong nb10,
105        ulong nb11,
106        ulong nb12,
107        ulong nb13,
108        int   ne0,
109        int   ne1,
110        int   ne2,
111        int   ne3,
112        ulong nb0,
113        ulong nb1,
114        ulong nb2,
115        ulong nb3,
116        int type_src0,
117        int type_src1
118) {
119    src0 = src0 + offset0;
120    src1 = src1 + offset1;
121    dst = dst + offsetd;
122
123    int i03 = get_group_id(2);
124    int i02 = get_group_id(1);
125    int i01 = get_group_id(0);
126
127    int i13 = i03 % ne13;
128    int i12 = i02 % ne12;
129    int i11 = i01 % ne11;
130
131    global char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
132    global char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
133    global char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1;
134
135    for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
136        const int i10 = i0 % ne10;
137
138        half v0, v1;
139        if (type_src0 == 1) {
140            v0 = convert_half(*((global float *)(src0_ptr + i0*nb00)));
141        } else {
142            v0 = *((global half *)(src0_ptr + i0*nb00));
143        }
144
145        if (type_src1 == 1) {
146            v1 = convert_half(*((global float *)(src1_ptr + i10*nb10)));
147        } else {
148            v1 = *((global half *)(src1_ptr + i10*nb10));
149        }
150
151        *((global half *)(dst_ptr + i0*nb0)) = v0 + v1;
152    }
153}
154
155kernel void kernel_add_row_f16(
156        global char * src0,
157        ulong  offset0,
158        global char * src1,
159        ulong  offset1,
160        global half4 * dst,
161        ulong  offsetd,
162        int ne,
163        int type_src0,
164        int type_src1
165) {
166    dst = (global half4*)((global char*)dst + offsetd);
167
168    // This performs better than using %.
169    uint gid = get_global_id(0);
170    uint idx1 = gid - (gid/ne)*ne; // get_global_id(0) % ne
171
172    half4 v0, v1;
173    if (type_src0 == 1) {
174        global float4* src0_f32 = (global float4*)((global char*)src0 + offset0);
175        v0 = convert_half4(src0_f32[gid]);
176    } else {
177        global half4* src0_f16 = (global half4*)((global char*)src0 + offset0);
178        v0 = src0_f16[gid];
179    }
180
181    if (type_src1 == 1) {
182        global float4* src1_f32 = (global float4*)((global char*)src1 + offset1);
183        v1 = convert_half4(src1_f32[idx1]);
184    } else {
185        global half4* src1_f16 = (global half4*)((global char*)src1 + offset1);
186        v1 = src1_f16[idx1];
187    }
188
189    dst[gid] = v0 + v1;
190}