1#pragma OPENCL EXTENSION cl_khr_fp16 : enable
2
// Coefficient of the cubic term inside the tanh-based GELU approximation
// (used as 1 + GELU_COEF_A*x*x in kernel_geglu / kernel_geglu_f16).
3#define GELU_COEF_A 0.044715f
// Scale applied to x inside exp() by the "quick" GELU kernels, i.e.
// x * 1/(1 + exp(GELU_QUICK_COEF*x)).
4#define GELU_QUICK_COEF -1.702f
// sqrt(2/pi): outer scale of the tanh argument in the tanh-based GELU.
5#define SQRT_2_OVER_PI 0.79788456080286535587989211986876f
// 1/sqrt(2): scale of the erf argument in the erf-based GELU.
6#define SQRT_2_INV 0.70710678118654752440084436210484f
7
8//------------------------------------------------------------------------------
9// geglu
10//------------------------------------------------------------------------------
11kernel void kernel_geglu(
12 global char * src0,
13 ulong offset0,
14 global char * src1,
15 ulong offset1,
16 global char * dst,
17 ulong offsetd,
18 ulong nb01,
19 ulong nb11,
20 int ne0,
21 ulong nb1,
22 int ne00_off,
23 int ne10_off
24) {
25 src0 = (global char*)((global char*)src0 + offset0);
26 src1 = (global char*)((global char*)src1 + offset1);
27 dst = (global char*)((global char*)dst + offsetd);
28
29 global float * src0_row = (global float *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
30 global float * src1_row = (global float *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
31 global float * dst_row = (global float *) ((global char *) dst + get_group_id(0)*nb1);
32
33 for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
34 const float x0 = src0_row[i0];
35 const float x1 = src1_row[i0];
36
37 const float gelu = 0.5f*x0*(1.0f + tanh(SQRT_2_OVER_PI*x0*(1.0f + GELU_COEF_A*x0*x0)));
38
39 dst_row[i0] = gelu*x1;
40 }
41}
42
43kernel void kernel_geglu_f16(
44 global char * src0,
45 ulong offset0,
46 global char * src1,
47 ulong offset1,
48 global char * dst,
49 ulong offsetd,
50 ulong nb01,
51 ulong nb11,
52 int ne0,
53 ulong nb1,
54 int ne00_off,
55 int ne10_off
56) {
57 src0 = (global char*)((global char*)src0 + offset0);
58 src1 = (global char*)((global char*)src1 + offset1);
59 dst = (global char*)((global char*)dst + offsetd);
60
61 global half * src0_row = (global half *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
62 global half * src1_row = (global half *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
63 global half * dst_row = (global half *) ((global char *) dst + get_group_id(0)*nb1);
64
65 for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
66 const half x0 = src0_row[i0];
67 const half x1 = src1_row[i0];
68
69 const half gelu = 0.5f*x0*(1.0f + tanh(SQRT_2_OVER_PI*x0*(1.0f + GELU_COEF_A*x0*x0)));
70
71 dst_row[i0] = gelu*x1;
72 }
73}
74
75//------------------------------------------------------------------------------
76// reglu
77//------------------------------------------------------------------------------
78kernel void kernel_reglu(
79 global char * src0,
80 ulong offset0,
81 global char * src1,
82 ulong offset1,
83 global char * dst,
84 ulong offsetd,
85 ulong nb01,
86 ulong nb11,
87 int ne0,
88 ulong nb1,
89 int ne00_off,
90 int ne10_off
91) {
92 src0 = (global char*)((global char*)src0 + offset0);
93 src1 = (global char*)((global char*)src1 + offset1);
94 dst = (global char*)((global char*)dst + offsetd);
95
96 global float * src0_row = (global float *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
97 global float * src1_row = (global float *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
98 global float * dst_row = (global float *) ((global char *) dst + get_group_id(0)*nb1);
99
100 for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
101 const float x0 = src0_row[i0];
102 const float x1 = src1_row[i0];
103
104 dst_row[i0] = x0*x1*(x0 > 0.0f);
105 }
106}
107
108kernel void kernel_reglu_f16(
109 global char * src0,
110 ulong offset0,
111 global char * src1,
112 ulong offset1,
113 global char * dst,
114 ulong offsetd,
115 ulong nb01,
116 ulong nb11,
117 int ne0,
118 ulong nb1,
119 int ne00_off,
120 int ne10_off
121) {
122 src0 = (global char*)((global char*)src0 + offset0);
123 src1 = (global char*)((global char*)src1 + offset1);
124 dst = (global char*)((global char*)dst + offsetd);
125
126 global half * src0_row = (global half *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
127 global half * src1_row = (global half *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
128 global half * dst_row = (global half *) ((global char *) dst + get_group_id(0)*nb1);
129
130 for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
131 const half x0 = src0_row[i0];
132 const half x1 = src1_row[i0];
133
134 dst_row[i0] = x0*x1*(x0 > 0.0f);
135 }
136}
137
138//------------------------------------------------------------------------------
139// swiglu
140//------------------------------------------------------------------------------
141kernel void kernel_swiglu(
142 global char * src0,
143 ulong offset0,
144 global char * src1,
145 ulong offset1,
146 global char * dst,
147 ulong offsetd,
148 ulong nb01,
149 ulong nb11,
150 int ne0,
151 ulong nb1,
152 int ne00_off,
153 int ne10_off
154) {
155 src0 = (global char*)((global char*)src0 + offset0);
156 src1 = (global char*)((global char*)src1 + offset1);
157 dst = (global char*)((global char*)dst + offsetd);
158
159 global float * src0_row = (global float *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
160 global float * src1_row = (global float *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
161 global float * dst_row = (global float *) ((global char *) dst + get_group_id(0)*nb1);
162
163 for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
164 const float x0 = src0_row[i0];
165 const float x1 = src1_row[i0];
166
167 const float silu = x0 / (1.0f + exp(-x0));
168
169 dst_row[i0] = silu*x1;
170 }
171}
172
173kernel void kernel_swiglu_f16(
174 global char * src0,
175 ulong offset0,
176 global char * src1,
177 ulong offset1,
178 global char * dst,
179 ulong offsetd,
180 ulong nb01,
181 ulong nb11,
182 int ne0,
183 ulong nb1,
184 int ne00_off,
185 int ne10_off
186) {
187 src0 = (global char*)((global char*)src0 + offset0);
188 src1 = (global char*)((global char*)src1 + offset1);
189 dst = (global char*)((global char*)dst + offsetd);
190
191 global half * src0_row = (global half *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
192 global half * src1_row = (global half *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
193 global half * dst_row = (global half *) ((global char *) dst + get_group_id(0)*nb1);
194
195 for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
196 const half x0 = src0_row[i0];
197 const half x1 = src1_row[i0];
198
199 const half silu = x0 / (1.0f + exp(-x0));
200
201 dst_row[i0] = silu*x1;
202 }
203}
204
205//------------------------------------------------------------------------------
206// swiglu_oai
207//------------------------------------------------------------------------------
208kernel void kernel_swiglu_oai(
209 global char * src0,
210 ulong offset0,
211 global char * src1,
212 ulong offset1,
213 global char * dst,
214 ulong offsetd,
215 ulong nb01,
216 ulong nb11,
217 int ne0,
218 ulong nb1,
219 int ne00_off,
220 int ne10_off,
221 float limit,
222 float alpha
223) {
224 src0 = (global char*)((global char*)src0 + offset0);
225 src1 = (global char*)((global char*)src1 + offset1);
226 dst = (global char*)((global char*)dst + offsetd);
227
228 global float * src0_row = (global float *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
229 global float * src1_row = (global float *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
230 global float * dst_row = (global float *) ((global char *) dst + get_group_id(0)*nb1);
231
232 for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
233 float x0 = src0_row[i0];
234 float x1 = src1_row[i0];
235
236 x0 = min(x0, limit);
237 x1 = max(min(x1, limit), -limit);
238
239 float out_glu = x0 / (1.0f + exp(-x0 * alpha));
240 out_glu = out_glu * (1.0f + x1);
241
242 dst_row[i0] = out_glu;
243 }
244}
245
246//------------------------------------------------------------------------------
247// geglu_erf
248//------------------------------------------------------------------------------
249kernel void kernel_geglu_erf(
250 global char * src0,
251 ulong offset0,
252 global char * src1,
253 ulong offset1,
254 global char * dst,
255 ulong offsetd,
256 ulong nb01,
257 ulong nb11,
258 int ne0,
259 ulong nb1,
260 int ne00_off,
261 int ne10_off
262) {
263 src0 = (global char*)((global char*)src0 + offset0);
264 src1 = (global char*)((global char*)src1 + offset1);
265 dst = (global char*)((global char*)dst + offsetd);
266
267 global float * src0_row = (global float *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
268 global float * src1_row = (global float *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
269 global float * dst_row = (global float *) ((global char *) dst + get_group_id(0)*nb1);
270
271 for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
272 const float x0 = src0_row[i0];
273 const float x1 = src1_row[i0];
274
275 const float gelu_erf = 0.5f*x0*(1.0f + erf(x0*SQRT_2_INV));
276
277 dst_row[i0] = gelu_erf*x1;
278 }
279}
280
281kernel void kernel_geglu_erf_f16(
282 global char * src0,
283 ulong offset0,
284 global char * src1,
285 ulong offset1,
286 global char * dst,
287 ulong offsetd,
288 ulong nb01,
289 ulong nb11,
290 int ne0,
291 ulong nb1,
292 int ne00_off,
293 int ne10_off
294) {
295 src0 = (global char*)((global char*)src0 + offset0);
296 src1 = (global char*)((global char*)src1 + offset1);
297 dst = (global char*)((global char*)dst + offsetd);
298
299 global half * src0_row = (global half *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
300 global half * src1_row = (global half *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
301 global half * dst_row = (global half *) ((global char *) dst + get_group_id(0)*nb1);
302
303 for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
304 const half x0 = src0_row[i0];
305 const half x1 = src1_row[i0];
306
307 const half gelu_erf = 0.5f*x0*(1.0f + erf(x0*SQRT_2_INV));
308
309 dst_row[i0] = gelu_erf*x1;
310 }
311}
312
313//------------------------------------------------------------------------------
314// geglu_quick
315//------------------------------------------------------------------------------
316kernel void kernel_geglu_quick(
317 global char * src0,
318 ulong offset0,
319 global char * src1,
320 ulong offset1,
321 global char * dst,
322 ulong offsetd,
323 ulong nb01,
324 ulong nb11,
325 int ne0,
326 ulong nb1,
327 int ne00_off,
328 int ne10_off
329) {
330 src0 = (global char*)((global char*)src0 + offset0);
331 src1 = (global char*)((global char*)src1 + offset1);
332 dst = (global char*)((global char*)dst + offsetd);
333
334 global float * src0_row = (global float *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
335 global float * src1_row = (global float *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
336 global float * dst_row = (global float *) ((global char *) dst + get_group_id(0)*nb1);
337
338 for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
339 const float x0 = src0_row[i0];
340 const float x1 = src1_row[i0];
341
342 const float gelu_quick = x0*(1.0f/(1.0f + exp(GELU_QUICK_COEF*x0)));
343
344 dst_row[i0] = gelu_quick*x1;
345 }
346}
347
348kernel void kernel_geglu_quick_f16(
349 global char * src0,
350 ulong offset0,
351 global char * src1,
352 ulong offset1,
353 global char * dst,
354 ulong offsetd,
355 ulong nb01,
356 ulong nb11,
357 int ne0,
358 ulong nb1,
359 int ne00_off,
360 int ne10_off
361) {
362 src0 = (global char*)((global char*)src0 + offset0);
363 src1 = (global char*)((global char*)src1 + offset1);
364 dst = (global char*)((global char*)dst + offsetd);
365
366 global half * src0_row = (global half *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
367 global half * src1_row = (global half *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
368 global half * dst_row = (global half *) ((global char *) dst + get_group_id(0)*nb1);
369
370 for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
371 const half x0 = src0_row[i0];
372 const half x1 = src1_row[i0];
373
374 const half gelu_quick = x0*(1.0f/(1.0f + exp(GELU_QUICK_COEF*x0)));
375
376 dst_row[i0] = gelu_quick*x1;
377 }
378}