1#pragma OPENCL EXTENSION cl_khr_fp16 : enable
2
3//------------------------------------------------------------------------------
4// gelu
5//------------------------------------------------------------------------------
// Cubic coefficient of the tanh-based GELU approximation used by kernel_gelu*:
//   gelu(x) ~= 0.5*x*(1 + tanh(SQRT_2_OVER_PI * x * (1 + GELU_COEF_A*x*x)))
#define GELU_COEF_A 0.044715f
// Exponent scale of the "quick" GELU variant: x * 1/(1 + exp(GELU_QUICK_COEF*x)),
// i.e. x * sigmoid(1.702*x). The negative literal is parenthesized (CERT PRE02-C)
// so that an accidental adjacency such as `a GELU_QUICK_COEF` fails to compile
// instead of silently parsing as the subtraction `a - 1.702f`.
#define GELU_QUICK_COEF (-1.702f)
// sqrt(2/pi): outer scale of the tanh argument in the approximation above.
#define SQRT_2_OVER_PI 0.79788456080286535587989211986876f
// 1/sqrt(2): 0.5*x*(1 + erf(x*SQRT_2_INV)) is the erf-based GELU.
#define SQRT_2_INV 0.70710678118654752440084436210484f
10
// Tanh-approximated GELU, one float element per work-item:
//   dst[i] = 0.5*x*(1 + tanh(sqrt(2/pi)*x*(1 + GELU_COEF_A*x*x))), x = src0[i]
// offset0 / offsetd are byte offsets added to the raw buffer pointers.
// NOTE(review): there is no bounds check, so the host presumably launches
// exactly one work-item per element. Confirm at the enqueue site.
kernel void kernel_gelu(
    global float * src0,
    ulong offset0,
    global float * dst,
    ulong offsetd
) {
    global const float * in  = (global const float *)((global const char *)src0 + offset0);
    global float       * out = (global float *)((global char *)dst + offsetd);

    const size_t i = get_global_id(0);
    const float  v = in[i];

    // Argument of tanh(); same operation order as the reference formula.
    const float inner = SQRT_2_OVER_PI*v*(1.0f + GELU_COEF_A*v*v);
    out[i] = 0.5f*v*(1.0f + tanh(inner));
}
24
// Tanh-approximated GELU, vectorized: each work-item handles one float4
// (four consecutive floats) with the same formula as kernel_gelu.
// offset0 / offsetd are byte offsets added to the raw buffer pointers.
// NOTE(review): no bounds check; the host presumably launches exactly
// (element count / 4) work-items. Confirm at the enqueue site.
kernel void kernel_gelu_4(
    global float4 * src0,
    ulong offset0,
    global float4 * dst,
    ulong offsetd
) {
    global const float4 * in  = (global const float4 *)((global const char *)src0 + offset0);
    global float4       * out = (global float4 *)((global char *)dst + offsetd);

    const size_t i = get_global_id(0);
    const float4 v = in[i];

    // Component-wise argument of tanh().
    const float4 inner = SQRT_2_OVER_PI*v*(1.0f + GELU_COEF_A*v*v);
    out[i] = 0.5f*v*(1.0f + tanh(inner));
}
38
// Erf-based GELU, one float element per work-item:
//   dst[i] = 0.5*x*(1 + erf(x/sqrt(2))), x = src0[i]
// offset0 / offsetd are byte offsets added to the raw buffer pointers.
// NOTE(review): no bounds check; one work-item per element is assumed.
kernel void kernel_gelu_erf(
    global float * src0,
    ulong offset0,
    global float * dst,
    ulong offsetd
) {
    global const float * in  = (global const float *)((global const char *)src0 + offset0);
    global float       * out = (global float *)((global char *)dst + offsetd);

    const size_t i = get_global_id(0);
    const float  v = in[i];

    const float e = erf(v*SQRT_2_INV);
    out[i] = 0.5f*v*(1.0f + e);
}
51
// Erf-based GELU, vectorized: each work-item handles one float4 with the
// same formula as kernel_gelu_erf, applied component-wise.
// offset0 / offsetd are byte offsets added to the raw buffer pointers.
// NOTE(review): no bounds check; (element count / 4) work-items assumed.
kernel void kernel_gelu_erf_4(
    global float4 * src0,
    ulong offset0,
    global float4 * dst,
    ulong offsetd
) {
    global const float4 * in  = (global const float4 *)((global const char *)src0 + offset0);
    global float4       * out = (global float4 *)((global char *)dst + offsetd);

    const size_t i = get_global_id(0);
    const float4 v = in[i];

    const float4 e = erf(v*SQRT_2_INV);
    out[i] = 0.5f*v*(1.0f + e);
}
64
// "Quick" GELU, one float element per work-item:
//   dst[i] = x * 1/(1 + exp(GELU_QUICK_COEF*x)), x = src0[i]
// The reciprocal-then-multiply form is kept deliberately so rounding matches
// the reference expression exactly.
// offset0 / offsetd are byte offsets added to the raw buffer pointers.
// NOTE(review): no bounds check; one work-item per element is assumed.
kernel void kernel_gelu_quick(
    global float * src0,
    ulong offset0,
    global float * dst,
    ulong offsetd
) {
    global const float * in  = (global const float *)((global const char *)src0 + offset0);
    global float       * out = (global float *)((global char *)dst + offsetd);

    const size_t i = get_global_id(0);
    const float  v = in[i];

    const float e = exp(GELU_QUICK_COEF*v);
    out[i] = v*(1.0f/(1.0f + e));
}
77
// "Quick" GELU, vectorized: each work-item handles one float4 with the same
// formula as kernel_gelu_quick, applied component-wise.
// offset0 / offsetd are byte offsets added to the raw buffer pointers.
// NOTE(review): no bounds check; (element count / 4) work-items assumed.
kernel void kernel_gelu_quick_4(
    global float4 * src0,
    ulong offset0,
    global float4 * dst,
    ulong offsetd
) {
    global const float4 * in  = (global const float4 *)((global const char *)src0 + offset0);
    global float4       * out = (global float4 *)((global char *)dst + offsetd);

    const size_t i = get_global_id(0);
    const float4 v = in[i];

    const float4 e = exp(GELU_QUICK_COEF*v);
    out[i] = v*(1.0f/(1.0f + e));
}