1#pragma OPENCL EXTENSION cl_khr_fp16 : enable
2
3//------------------------------------------------------------------------------
4// add
5//------------------------------------------------------------------------------
6
// Generic element-wise f32 addition: dst = src0 + src1.
// - All three tensors are addressed through their byte strides (nb*), so no
//   particular memory layout is assumed.
// - src1 indices are taken modulo ne10..ne13, which repeats src1 along a
//   dimension whenever the running index exceeds src1's extent.
// - Work-group ids (0, 1, 2) select the row indices (i01, i02, i03); the
//   work-items of that group step through the ne0 elements of the row.
// - ne00..ne03 and ne1..ne3 are part of the signature but are not read here.
// cons: not very efficient
kernel void kernel_add(
        global char * src0,
        ulong  offset0,
        global char * src1,
        ulong  offset1,
        global char * dst,
        ulong  offsetd,
        int   ne00,
        int   ne01,
        int   ne02,
        int   ne03,
        ulong nb00,
        ulong nb01,
        ulong nb02,
        ulong nb03,
        int   ne10,
        int   ne11,
        int   ne12,
        int   ne13,
        ulong nb10,
        ulong nb11,
        ulong nb12,
        ulong nb13,
        int   ne0,
        int   ne1,
        int   ne2,
        int   ne3,
        ulong nb0,
        ulong nb1,
        ulong nb2,
        ulong nb3
) {
    // Row coordinates handled by this work-group.
    const int row1 = get_group_id(0);
    const int row2 = get_group_id(1);
    const int row3 = get_group_id(2);

    // Matching src1 coordinates, wrapped into src1's extents.
    const int bcast1 = row1 % ne11;
    const int bcast2 = row2 % ne12;
    const int bcast3 = row3 % ne13;

    // Byte address of the first element of each row (buffer offset included).
    global char * lhs_row = src0 + offset0 + row3*nb03   + row2*nb02   + row1*nb01;
    global char * rhs_row = src1 + offset1 + bcast3*nb13 + bcast2*nb12 + bcast1*nb11;
    global char * out_row = dst  + offsetd + row3*nb3    + row2*nb2    + row1*nb1;

    // Each work-item starts at its local id and advances by the group size.
    for (int col = get_local_id(0); col < ne0; col += get_local_size(0)) {
        const int bcast0 = col % ne10;

        const float lhs = *((global float *)(lhs_row + col*nb00));
        const float rhs = *((global float *)(rhs_row + bcast0*nb10));

        global float * out = (global float *)(out_row + col*nb0);
        *out = lhs + rhs;
    }
}
63
// Row-broadcast f32 addition on float4 elements.
// assumption: src1 is a single row; it is indexed with the global id wrapped
// to ne, where ne counts float4 elements, while src0 and dst are indexed with
// the global id directly.
kernel void kernel_add_row(
        global float4 * src0,
        ulong  offset0,
        global float4 * src1,
        ulong  offset1,
        global float4 * dst,
        ulong  offsetd,
        int ne
) {
    // The offsets are in bytes, so they are applied through char pointers.
    global float4 * lhs = (global float4 *)((global char *)src0 + offset0);
    global float4 * rhs = (global float4 *)((global char *)src1 + offset1);
    global float4 * out = (global float4 *)((global char *)dst  + offsetd);

    const uint gid = get_global_id(0);

    // gid % ne written as a divide/multiply/subtract, which performs better
    // than using the % operator.
    const uint full_rows = gid / ne;
    const uint col       = gid - full_rows * ne;

    out[gid] = lhs[gid] + rhs[col];
}
84
// Generic element-wise addition producing half results: dst = src0 + src1.
// - Each source is read as float and converted to half when its type_src*
//   argument equals 1; for any other value it is read as half.
// - The sum is computed in half and stored as half.
// - All tensors are addressed through their byte strides (nb*); src1 indices
//   are taken modulo ne10..ne13.
// - Work-group ids (0, 1, 2) select the row indices (i01, i02, i03); the
//   work-items of that group step through the ne0 elements of the row.
// - ne00..ne03 and ne1..ne3 are part of the signature but are not read here.
kernel void kernel_add_f16(
        global char * src0,
        ulong  offset0,
        global char * src1,
        ulong  offset1,
        global char * dst,
        ulong  offsetd,
        int   ne00,
        int   ne01,
        int   ne02,
        int   ne03,
        ulong nb00,
        ulong nb01,
        ulong nb02,
        ulong nb03,
        int   ne10,
        int   ne11,
        int   ne12,
        int   ne13,
        ulong nb10,
        ulong nb11,
        ulong nb12,
        ulong nb13,
        int   ne0,
        int   ne1,
        int   ne2,
        int   ne3,
        ulong nb0,
        ulong nb1,
        ulong nb2,
        ulong nb3,
        int type_src0,
        int type_src1
) {
    // The element type of each source does not change inside the loop.
    const int lhs_is_f32 = (type_src0 == 1);
    const int rhs_is_f32 = (type_src1 == 1);

    // Row coordinates handled by this work-group.
    const int row1 = get_group_id(0);
    const int row2 = get_group_id(1);
    const int row3 = get_group_id(2);

    // Matching src1 coordinates, wrapped into src1's extents.
    const int bcast1 = row1 % ne11;
    const int bcast2 = row2 % ne12;
    const int bcast3 = row3 % ne13;

    // Byte address of the first element of each row (buffer offset included).
    global char * lhs_row = src0 + offset0 + row3*nb03   + row2*nb02   + row1*nb01;
    global char * rhs_row = src1 + offset1 + bcast3*nb13 + bcast2*nb12 + bcast1*nb11;
    global char * out_row = dst  + offsetd + row3*nb3    + row2*nb2    + row1*nb1;

    // Each work-item starts at its local id and advances by the group size.
    for (int col = get_local_id(0); col < ne0; col += get_local_size(0)) {
        const int bcast0 = col % ne10;

        global char * lhs_elem = lhs_row + col*nb00;
        global char * rhs_elem = rhs_row + bcast0*nb10;

        half lhs;
        if (lhs_is_f32) {
            lhs = convert_half(*((global float *)lhs_elem));
        } else {
            lhs = *((global half *)lhs_elem);
        }

        half rhs;
        if (rhs_is_f32) {
            rhs = convert_half(*((global float *)rhs_elem));
        } else {
            rhs = *((global half *)rhs_elem);
        }

        global half * out = (global half *)(out_row + col*nb0);
        *out = lhs + rhs;
    }
}
154
// Row-broadcast addition producing half4 results.
// - src0 and dst are indexed with the global id; src1 is indexed with the
//   global id wrapped to ne, where ne counts 4-element vectors.
// - Each source is read as float4 and converted to half4 when its type_src*
//   argument equals 1; for any other value it is read as half4.
kernel void kernel_add_row_f16(
        global char * src0,
        ulong  offset0,
        global char * src1,
        ulong  offset1,
        global half4 * dst,
        ulong  offsetd,
        int ne,
        int type_src0,
        int type_src1
) {
    const int lhs_is_f32 = (type_src0 == 1);
    const int rhs_is_f32 = (type_src1 == 1);

    const uint gid = get_global_id(0);

    // gid % ne written as a divide/multiply/subtract, which performs better
    // than using the % operator.
    const uint full_rows = gid / ne;
    const uint col       = gid - full_rows * ne;

    // The offsets are in bytes; the vector type is chosen at load time.
    global char * lhs_base = (global char *)src0 + offset0;
    global char * rhs_base = (global char *)src1 + offset1;

    half4 lhs;
    if (lhs_is_f32) {
        lhs = convert_half4(((global float4 *)lhs_base)[gid]);
    } else {
        lhs = ((global half4 *)lhs_base)[gid];
    }

    half4 rhs;
    if (rhs_is_f32) {
        rhs = convert_half4(((global float4 *)rhs_base)[col]);
    } else {
        rhs = ((global half4 *)rhs_base)[col];
    }

    global half4 * out = (global half4 *)((global char *)dst + offsetd);
    out[gid] = lhs + rhs;
}