1#pragma OPENCL EXTENSION cl_khr_fp16 : enable
 2
 3//------------------------------------------------------------------------------
 4// gelu
 5//------------------------------------------------------------------------------
 6#define GELU_COEF_A     0.044715f
 7#define GELU_QUICK_COEF -1.702f
 8#define SQRT_2_OVER_PI  0.79788456080286535587989211986876f
 9#define SQRT_2_INV      0.70710678118654752440084436210484f
10
// GELU (tanh approximation), one float element per work-item:
//   dst[i] = 0.5 * x * (1 + tanh(sqrt(2/pi) * x * (1 + 0.044715 * x^2)))
// offset0 / offsetd are byte offsets added to the src0 / dst base pointers.
// NOTE(review): no bounds check on get_global_id(0); the host presumably
// launches exactly one work-item per element — confirm.
kernel void kernel_gelu(
    global float * src0,
    ulong offset0,
    global float * dst,
    ulong offsetd
) {
    global const float * src_ptr = (global const float *)((global char *)src0 + offset0);
    global float       * dst_ptr = (global float *)((global char *)dst + offsetd);
    const size_t i = get_global_id(0);

    const float x = src_ptr[i];
    // tanh argument: sqrt(2/pi) * (x + a*x^3), factored as x*(1 + a*x*x).
    const float inner = SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x);
    dst_ptr[i] = 0.5f*x*(1.0f + tanh(inner));
}
24
// GELU (tanh approximation), vectorized: one float4 (four elements) per
// work-item, same formula as the scalar kernel applied lane-wise.
// offset0 / offsetd are byte offsets added to the src0 / dst base pointers.
// NOTE(review): float4 accesses must be 16-byte aligned per the OpenCL C spec,
// so the offsets are presumably multiples of 16 — confirm on the host side.
// No bounds check on get_global_id(0).
kernel void kernel_gelu_4(
    global float4 * src0,
    ulong offset0,
    global float4 * dst,
    ulong offsetd
) {
    global const float4 * src_ptr = (global const float4 *)((global char *)src0 + offset0);
    global float4       * dst_ptr = (global float4 *)((global char *)dst + offsetd);
    const size_t i = get_global_id(0);

    const float4 x = src_ptr[i];
    // Scalar constants broadcast across all four lanes.
    const float4 inner = SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x);
    dst_ptr[i] = 0.5f*x*(1.0f + tanh(inner));
}
38
// Exact GELU via the error function, one float element per work-item:
//   dst[i] = 0.5 * x * (1 + erf(x / sqrt(2)))
// offset0 / offsetd are byte offsets added to the src0 / dst base pointers.
// NOTE(review): no bounds check on get_global_id(0).
kernel void kernel_gelu_erf(
    global float * src0,
    ulong offset0,
    global float * dst,
    ulong offsetd
) {
    global const float * src_ptr = (global const float *)((global char *)src0 + offset0);
    global float       * dst_ptr = (global float *)((global char *)dst + offsetd);
    const size_t i = get_global_id(0);

    const float x = src_ptr[i];
    const float erf_term = erf(x*SQRT_2_INV);
    dst_ptr[i] = 0.5f*x*(1.0f + erf_term);
}
51
// Exact GELU via erf, vectorized: one float4 (four elements) per work-item,
// dst = 0.5 * x * (1 + erf(x / sqrt(2))) lane-wise.
// offset0 / offsetd are byte offsets added to the src0 / dst base pointers.
// NOTE(review): offsets presumably keep float4 accesses 16-byte aligned —
// confirm on the host side. No bounds check on get_global_id(0).
kernel void kernel_gelu_erf_4(
    global float4 * src0,
    ulong offset0,
    global float4 * dst,
    ulong offsetd
) {
    global const float4 * src_ptr = (global const float4 *)((global char *)src0 + offset0);
    global float4       * dst_ptr = (global float4 *)((global char *)dst + offsetd);
    const size_t i = get_global_id(0);

    const float4 x = src_ptr[i];
    const float4 erf_term = erf(x*SQRT_2_INV);
    dst_ptr[i] = 0.5f*x*(1.0f + erf_term);
}
64
// "Quick" GELU, one float element per work-item:
//   dst[i] = x * 1/(1 + exp(-1.702 * x))  ==  x * sigmoid(1.702 * x)
// offset0 / offsetd are byte offsets added to the src0 / dst base pointers.
// NOTE(review): no bounds check on get_global_id(0).
kernel void kernel_gelu_quick(
    global float * src0,
    ulong offset0,
    global float * dst,
    ulong offsetd
) {
    global const float * src_ptr = (global const float *)((global char *)src0 + offset0);
    global float       * dst_ptr = (global float *)((global char *)dst + offsetd);
    const size_t i = get_global_id(0);

    const float x = src_ptr[i];
    // GELU_QUICK_COEF is negative, so this is exp(-1.702*x).
    const float neg_exp = exp(GELU_QUICK_COEF*x);
    dst_ptr[i] = x*(1.0f/(1.0f + neg_exp));
}
77
// "Quick" GELU, vectorized: one float4 (four elements) per work-item,
// dst = x * sigmoid(1.702 * x) lane-wise.
// offset0 / offsetd are byte offsets added to the src0 / dst base pointers.
// NOTE(review): offsets presumably keep float4 accesses 16-byte aligned —
// confirm on the host side. No bounds check on get_global_id(0).
kernel void kernel_gelu_quick_4(
    global float4 * src0,
    ulong offset0,
    global float4 * dst,
    ulong offsetd
) {
    global const float4 * src_ptr = (global const float4 *)((global char *)src0 + offset0);
    global float4       * dst_ptr = (global float4 *)((global char *)dst + offsetd);
    const size_t i = get_global_id(0);

    const float4 x = src_ptr[i];
    // GELU_QUICK_COEF is negative, so this is exp(-1.702*x) per lane.
    const float4 neg_exp = exp(GELU_QUICK_COEF*x);
    dst_ptr[i] = x*(1.0f/(1.0f + neg_exp));
}