// Enables the `half` scalar type (and half overloads of the math builtins),
// which the *_f16 kernels in this file rely on.
#pragma OPENCL EXTENSION cl_khr_fp16 : enable

// Cubic coefficient of the tanh-approximated GELU:
//   gelu(x) = 0.5*x*(1 + tanh(SQRT_2_OVER_PI*x*(1 + GELU_COEF_A*x^2)))
#define GELU_COEF_A     0.044715f
// Scale of the "quick" GELU: x / (1 + exp(GELU_QUICK_COEF*x)) == x * sigmoid(1.702*x).
#define GELU_QUICK_COEF -1.702f
// sqrt(2/pi)
#define SQRT_2_OVER_PI  0.79788456080286535587989211986876f
// 1/sqrt(2); used by the erf-based (exact) GELU: 0.5*x*(1 + erf(x/sqrt(2))).
#define SQRT_2_INV      0.70710678118654752440084436210484f
  7
  8//------------------------------------------------------------------------------
  9// geglu
 10//------------------------------------------------------------------------------
 11kernel void kernel_geglu(
 12    global char * src0,
 13    ulong  offset0,
 14    global char * src1,
 15    ulong  offset1,
 16    global char * dst,
 17    ulong  offsetd,
 18    ulong nb01,
 19    ulong nb11,
 20    int ne0,
 21    ulong nb1,
 22    int ne00_off,
 23    int ne10_off
 24) {
 25    src0 = (global char*)((global char*)src0 + offset0);
 26    src1 = (global char*)((global char*)src1 + offset1);
 27    dst  = (global char*)((global char*)dst  + offsetd);
 28
 29    global float * src0_row = (global float *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
 30    global float * src1_row = (global float *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
 31    global float * dst_row  = (global float *) ((global char *) dst  + get_group_id(0)*nb1);
 32
 33    for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
 34        const float x0 = src0_row[i0];
 35        const float x1 = src1_row[i0];
 36
 37        const float gelu = 0.5f*x0*(1.0f + tanh(SQRT_2_OVER_PI*x0*(1.0f + GELU_COEF_A*x0*x0)));
 38
 39        dst_row[i0] = gelu*x1;
 40    }
 41}
 42
 43kernel void kernel_geglu_f16(
 44    global char * src0,
 45    ulong  offset0,
 46    global char * src1,
 47    ulong  offset1,
 48    global char * dst,
 49    ulong  offsetd,
 50    ulong nb01,
 51    ulong nb11,
 52    int ne0,
 53    ulong nb1,
 54    int ne00_off,
 55    int ne10_off
 56) {
 57    src0 = (global char*)((global char*)src0 + offset0);
 58    src1 = (global char*)((global char*)src1 + offset1);
 59    dst  = (global char*)((global char*)dst  + offsetd);
 60
 61    global half * src0_row = (global half *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
 62    global half * src1_row = (global half *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
 63    global half * dst_row  = (global half *) ((global char *) dst  + get_group_id(0)*nb1);
 64
 65    for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
 66        const half x0 = src0_row[i0];
 67        const half x1 = src1_row[i0];
 68
 69        const half gelu = 0.5f*x0*(1.0f + tanh(SQRT_2_OVER_PI*x0*(1.0f + GELU_COEF_A*x0*x0)));
 70
 71        dst_row[i0] = gelu*x1;
 72    }
 73}
 74
 75//------------------------------------------------------------------------------
 76// reglu
 77//------------------------------------------------------------------------------
 78kernel void kernel_reglu(
 79    global char * src0,
 80    ulong  offset0,
 81    global char * src1,
 82    ulong  offset1,
 83    global char * dst,
 84    ulong  offsetd,
 85    ulong nb01,
 86    ulong nb11,
 87    int ne0,
 88    ulong nb1,
 89    int ne00_off,
 90    int ne10_off
 91) {
 92    src0 = (global char*)((global char*)src0 + offset0);
 93    src1 = (global char*)((global char*)src1 + offset1);
 94    dst  = (global char*)((global char*)dst  + offsetd);
 95
 96    global float * src0_row = (global float *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
 97    global float * src1_row = (global float *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
 98    global float * dst_row  = (global float *) ((global char *) dst  + get_group_id(0)*nb1);
 99
100    for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
101        const float x0 = src0_row[i0];
102        const float x1 = src1_row[i0];
103
104        dst_row[i0] = x0*x1*(x0 > 0.0f);
105    }
106}
107
// ReGLU, f16 storage and arithmetic:
//   dst[r][i] = x0 > 0 ? x0 * x1 : 0,   x0 = src0[r][ne00_off + i], x1 = src1[r][ne10_off + i]
//
// Work split: work-group get_group_id(0) handles row r; each work-item starts
// at column get_local_id(0) and advances by get_local_size(0).
//
// offset0 / offset1 / offsetd : byte offsets applied to the src0 / src1 / dst buffers
// nb01 / nb11 / nb1           : byte strides between consecutive rows of src0 / src1 / dst
// ne0                         : number of output elements per row
// ne00_off / ne10_off         : element (half) offsets into each src0 / src1 row
kernel void kernel_reglu_f16(
    global char * src0,
    ulong  offset0,
    global char * src1,
    ulong  offset1,
    global char * dst,
    ulong  offsetd,
    ulong nb01,
    ulong nb11,
    int ne0,
    ulong nb1,
    int ne00_off,
    int ne10_off
) {
    src0 += offset0;
    src1 += offset1;
    dst  += offsetd;

    global half * src0_row = (global half *)(src0 + get_group_id(0)*nb01) + ne00_off;
    global half * src1_row = (global half *)(src1 + get_group_id(0)*nb11) + ne10_off;
    global half * dst_row  = (global half *)(dst  + get_group_id(0)*nb1);

    for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
        const half x0 = src0_row[i0];
        const half x1 = src1_row[i0];

        // Select instead of multiplying by the (x0 > 0) mask: under IEEE-754,
        // 0 * inf == NaN, and half products overflow to inf easily, so the mask
        // form could emit NaN for an inactive gate. A NaN x0 also yields 0 here.
        dst_row[i0] = x0 > 0.0f ? x0*x1 : (half)0.0f;
    }
}
137
138//------------------------------------------------------------------------------
139// swiglu
140//------------------------------------------------------------------------------
// SwiGLU, f32:
//   dst[r][i] = silu(src0[r][ne00_off + i]) * src1[r][ne10_off + i],   0 <= i < ne0
//   silu(x)   = x / (1 + exp(-x))
//
// Work split: work-group get_group_id(0) handles row r; each work-item starts
// at column get_local_id(0) and advances by get_local_size(0).
//
// offset0 / offset1 / offsetd : byte offsets applied to the src0 / src1 / dst buffers
// nb01 / nb11 / nb1           : byte strides between consecutive rows of src0 / src1 / dst
// ne0                         : number of output elements per row
// ne00_off / ne10_off         : element (float) offsets into each src0 / src1 row
kernel void kernel_swiglu(
    global char * src0,
    ulong  offset0,
    global char * src1,
    ulong  offset1,
    global char * dst,
    ulong  offsetd,
    ulong nb01,
    ulong nb11,
    int ne0,
    ulong nb1,
    int ne00_off,
    int ne10_off
) {
    const size_t row = get_group_id(0);

    global const float * gate = (global const float *)(src0 + offset0 + row*nb01) + ne00_off;
    global const float * lin  = (global const float *)(src1 + offset1 + row*nb11) + ne10_off;
    global float       * out  = (global float       *)(dst  + offsetd + row*nb1);

    int col = get_local_id(0);
    while (col < ne0) {
        const float g = gate[col];
        const float y = lin[col];

        out[col] = (g / (1.0f + exp(-g))) * y;
        col += get_local_size(0);
    }
}
172
// SwiGLU, f16 storage with f32 math:
//   dst[r][i] = silu(src0[r][ne00_off + i]) * src1[r][ne10_off + i],   0 <= i < ne0
//   silu(x)   = x / (1 + exp(-x))
//
// Work split: work-group get_group_id(0) handles row r; each work-item starts
// at column get_local_id(0) and advances by get_local_size(0).
//
// offset0 / offset1 / offsetd : byte offsets applied to the src0 / src1 / dst buffers
// nb01 / nb11 / nb1           : byte strides between consecutive rows of src0 / src1 / dst
// ne0                         : number of output elements per row
// ne00_off / ne10_off         : element (half) offsets into each src0 / src1 row
kernel void kernel_swiglu_f16(
    global char * src0,
    ulong  offset0,
    global char * src1,
    ulong  offset1,
    global char * dst,
    ulong  offsetd,
    ulong nb01,
    ulong nb11,
    int ne0,
    ulong nb1,
    int ne00_off,
    int ne10_off
) {
    src0 += offset0;
    src1 += offset1;
    dst  += offsetd;

    global half * src0_row = (global half *)(src0 + get_group_id(0)*nb01) + ne00_off;
    global half * src1_row = (global half *)(src1 + get_group_id(0)*nb11) + ne10_off;
    global half * dst_row  = (global half *)(dst  + get_group_id(0)*nb1);

    for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
        // Widen to float before any math. With a half argument, exp() resolves
        // to the half overload. Its result exceeds the half range (max 65504)
        // once x0 < about -11.09. That flushes silu(x0) to 0 even though the
        // true value (e.g. silu(-12) ~ -7.4e-5) is a representable half.
        const float x0 = src0_row[i0];
        const float x1 = src1_row[i0];

        const float silu = x0 / (1.0f + exp(-x0));

        // Single rounding to half on store (previously silu and the product
        // were each rounded to half).
        dst_row[i0] = silu*x1;
    }
}
204
205//------------------------------------------------------------------------------
206// swiglu_oai
207//------------------------------------------------------------------------------
// Clamped SwiGLU variant ("oai"), f32:
//   g = min(src0[r][ne00_off + i], limit)                 -- gate clamped from above only
//   l = max(min(src1[r][ne10_off + i], limit), -limit)    -- linear input clamped to [-limit, limit]
//                                                           (for limit >= 0; TODO confirm callers never pass limit < 0)
//   dst[r][i] = g / (1 + exp(-g*alpha)) * (1 + l)         -- i.e. g*sigmoid(alpha*g)*(l + 1)
//
// Work split: work-group get_group_id(0) handles row r; each work-item starts
// at column get_local_id(0) and advances by get_local_size(0).
//
// offset0 / offset1 / offsetd : byte offsets applied to the src0 / src1 / dst buffers
// nb01 / nb11 / nb1           : byte strides between consecutive rows of src0 / src1 / dst
// ne0                         : number of output elements per row
// ne00_off / ne10_off         : element (float) offsets into each src0 / src1 row
// limit, alpha                : clamp bound and sigmoid scale used above
kernel void kernel_swiglu_oai(
    global char * src0,
    ulong         offset0,
    global char * src1,
    ulong         offset1,
    global char * dst,
    ulong         offsetd,
    ulong         nb01,
    ulong         nb11,
    int           ne0,
    ulong         nb1,
    int           ne00_off,
    int           ne10_off,
    float         limit,
    float         alpha
) {
    const size_t row = get_group_id(0);

    global const float * gate = (global const float *)(src0 + offset0 + row*nb01) + ne00_off;
    global const float * lin  = (global const float *)(src1 + offset1 + row*nb11) + ne10_off;
    global float       * out  = (global float       *)(dst  + offsetd + row*nb1);

    for (int col = get_local_id(0); col < ne0; col += get_local_size(0)) {
        const float g = min(gate[col], limit);
        const float l = max(min(lin[col], limit), -limit);

        const float act = g / (1.0f + exp(-g * alpha));

        out[col] = act * (1.0f + l);
    }
}
245
246//------------------------------------------------------------------------------
247// geglu_erf
248//------------------------------------------------------------------------------
// GEGLU with the erf-based GELU as the gate, f32:
//   dst[r][i] = gelu_erf(src0[r][ne00_off + i]) * src1[r][ne10_off + i],   0 <= i < ne0
//   gelu_erf(x) = 0.5*x*(1 + erf(x*SQRT_2_INV))
//
// Work split: work-group get_group_id(0) handles row r; each work-item starts
// at column get_local_id(0) and advances by get_local_size(0).
//
// offset0 / offset1 / offsetd : byte offsets applied to the src0 / src1 / dst buffers
// nb01 / nb11 / nb1           : byte strides between consecutive rows of src0 / src1 / dst
// ne0                         : number of output elements per row
// ne00_off / ne10_off         : element (float) offsets into each src0 / src1 row
kernel void kernel_geglu_erf(
    global char * src0,
    ulong  offset0,
    global char * src1,
    ulong  offset1,
    global char * dst,
    ulong  offsetd,
    ulong nb01,
    ulong nb11,
    int ne0,
    ulong nb1,
    int ne00_off,
    int ne10_off
) {
    const size_t row = get_group_id(0);

    global const float * gate = (global const float *)(src0 + offset0 + row*nb01) + ne00_off;
    global const float * lin  = (global const float *)(src1 + offset1 + row*nb11) + ne10_off;
    global float       * out  = (global float       *)(dst  + offsetd + row*nb1);

    int col = get_local_id(0);
    while (col < ne0) {
        const float x = gate[col];
        const float y = lin[col];

        const float e = erf(x*SQRT_2_INV);
        const float g = 0.5f*x*(1.0f + e);

        out[col] = g*y;
        col += get_local_size(0);
    }
}
280
// GEGLU with the erf-based GELU as the gate, f16 storage:
//   dst[r][i] = gelu_erf(src0[r][ne00_off + i]) * src1[r][ne10_off + i],   0 <= i < ne0
//   gelu_erf(x) = 0.5*x*(1 + erf(x*SQRT_2_INV))
//
// Precision: inputs are loaded as half. The float constants promote the GELU
// expression (and erf) to float; the GELU value is rounded to half and then
// multiplied by the half src1 value in half arithmetic.
//
// Work split: work-group get_group_id(0) handles row r; each work-item starts
// at column get_local_id(0) and advances by get_local_size(0).
//
// offset0 / offset1 / offsetd : byte offsets applied to the src0 / src1 / dst buffers
// nb01 / nb11 / nb1           : byte strides between consecutive rows of src0 / src1 / dst
// ne0                         : number of output elements per row
// ne00_off / ne10_off         : element (half) offsets into each src0 / src1 row
kernel void kernel_geglu_erf_f16(
    global char * src0,
    ulong  offset0,
    global char * src1,
    ulong  offset1,
    global char * dst,
    ulong  offsetd,
    ulong nb01,
    ulong nb11,
    int ne0,
    ulong nb1,
    int ne00_off,
    int ne10_off
) {
    const size_t row = get_group_id(0);

    global const half * gate = (global const half *)(src0 + offset0 + row*nb01) + ne00_off;
    global const half * lin  = (global const half *)(src1 + offset1 + row*nb11) + ne10_off;
    global half       * out  = (global half       *)(dst  + offsetd + row*nb1);

    int col = get_local_id(0);
    while (col < ne0) {
        const half x = gate[col];
        const half y = lin[col];

        const float e = erf(x*SQRT_2_INV);
        const half  g = 0.5f*x*(1.0f + e);   // float result rounded to half here

        out[col] = g*y;                      // half * half
        col += get_local_size(0);
    }
}
312
313//------------------------------------------------------------------------------
314// geglu_quick
315//------------------------------------------------------------------------------
// GEGLU with the "quick" GELU as the gate, f32:
//   dst[r][i] = gelu_quick(src0[r][ne00_off + i]) * src1[r][ne10_off + i],   0 <= i < ne0
//   gelu_quick(x) = x * (1 / (1 + exp(GELU_QUICK_COEF*x)))   (= x * sigmoid(1.702*x))
//
// Work split: work-group get_group_id(0) handles row r; each work-item starts
// at column get_local_id(0) and advances by get_local_size(0).
//
// offset0 / offset1 / offsetd : byte offsets applied to the src0 / src1 / dst buffers
// nb01 / nb11 / nb1           : byte strides between consecutive rows of src0 / src1 / dst
// ne0                         : number of output elements per row
// ne00_off / ne10_off         : element (float) offsets into each src0 / src1 row
kernel void kernel_geglu_quick(
    global char * src0,
    ulong  offset0,
    global char * src1,
    ulong  offset1,
    global char * dst,
    ulong  offsetd,
    ulong nb01,
    ulong nb11,
    int ne0,
    ulong nb1,
    int ne00_off,
    int ne10_off
) {
    const size_t row = get_group_id(0);

    global const float * gate = (global const float *)(src0 + offset0 + row*nb01) + ne00_off;
    global const float * lin  = (global const float *)(src1 + offset1 + row*nb11) + ne10_off;
    global float       * out  = (global float       *)(dst  + offsetd + row*nb1);

    int col = get_local_id(0);
    while (col < ne0) {
        const float x = gate[col];
        const float y = lin[col];

        const float sig = 1.0f/(1.0f + exp(GELU_QUICK_COEF*x));
        const float g   = x*sig;

        out[col] = g*y;
        col += get_local_size(0);
    }
}
347
// GEGLU with the "quick" GELU as the gate, f16 storage:
//   dst[r][i] = gelu_quick(src0[r][ne00_off + i]) * src1[r][ne10_off + i],   0 <= i < ne0
//   gelu_quick(x) = x * (1 / (1 + exp(GELU_QUICK_COEF*x)))   (= x * sigmoid(1.702*x))
//
// Precision: inputs are loaded as half. GELU_QUICK_COEF is a float constant,
// so exp and the sigmoid run in float; the gelu_quick value is rounded to half
// and then multiplied by the half src1 value in half arithmetic.
//
// Work split: work-group get_group_id(0) handles row r; each work-item starts
// at column get_local_id(0) and advances by get_local_size(0).
//
// offset0 / offset1 / offsetd : byte offsets applied to the src0 / src1 / dst buffers
// nb01 / nb11 / nb1           : byte strides between consecutive rows of src0 / src1 / dst
// ne0                         : number of output elements per row
// ne00_off / ne10_off         : element (half) offsets into each src0 / src1 row
kernel void kernel_geglu_quick_f16(
    global char * src0,
    ulong  offset0,
    global char * src1,
    ulong  offset1,
    global char * dst,
    ulong  offsetd,
    ulong nb01,
    ulong nb11,
    int ne0,
    ulong nb1,
    int ne00_off,
    int ne10_off
) {
    const size_t row = get_group_id(0);

    global const half * gate = (global const half *)(src0 + offset0 + row*nb01) + ne00_off;
    global const half * lin  = (global const half *)(src1 + offset1 + row*nb11) + ne10_off;
    global half       * out  = (global half       *)(dst  + offsetd + row*nb1);

    int col = get_local_id(0);
    while (col < ne0) {
        const half x = gate[col];
        const half y = lin[col];

        const float sig = 1.0f/(1.0f + exp(GELU_QUICK_COEF*x));
        const half  g   = x*sig;             // float result rounded to half here

        out[col] = g*y;                      // half * half
        col += get_local_size(0);
    }
}