summaryrefslogtreecommitdiff
path: root/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl
diff options
context:
space:
mode:
Diffstat (limited to 'llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl')
-rw-r--r--llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl378
1 files changed, 378 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl b/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl
new file mode 100644
index 0000000..059a4bb
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl
@@ -0,0 +1,378 @@
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+
+#define GELU_COEF_A 0.044715f
+#define GELU_QUICK_COEF -1.702f
+#define SQRT_2_OVER_PI 0.79788456080286535587989211986876f
+#define SQRT_2_INV 0.70710678118654752440084436210484f
+
+//------------------------------------------------------------------------------
+// geglu
+//------------------------------------------------------------------------------
+kernel void kernel_geglu(
+ global char * src0,
+ ulong offset0,
+ global char * src1,
+ ulong offset1,
+ global char * dst,
+ ulong offsetd,
+ ulong nb01,
+ ulong nb11,
+ int ne0,
+ ulong nb1,
+ int ne00_off,
+ int ne10_off
+) {
+ src0 = (global char*)((global char*)src0 + offset0);
+ src1 = (global char*)((global char*)src1 + offset1);
+ dst = (global char*)((global char*)dst + offsetd);
+
+ global float * src0_row = (global float *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
+ global float * src1_row = (global float *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
+ global float * dst_row = (global float *) ((global char *) dst + get_group_id(0)*nb1);
+
+ for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
+ const float x0 = src0_row[i0];
+ const float x1 = src1_row[i0];
+
+ const float gelu = 0.5f*x0*(1.0f + tanh(SQRT_2_OVER_PI*x0*(1.0f + GELU_COEF_A*x0*x0)));
+
+ dst_row[i0] = gelu*x1;
+ }
+}
+
+kernel void kernel_geglu_f16(
+ global char * src0,
+ ulong offset0,
+ global char * src1,
+ ulong offset1,
+ global char * dst,
+ ulong offsetd,
+ ulong nb01,
+ ulong nb11,
+ int ne0,
+ ulong nb1,
+ int ne00_off,
+ int ne10_off
+) {
+ src0 = (global char*)((global char*)src0 + offset0);
+ src1 = (global char*)((global char*)src1 + offset1);
+ dst = (global char*)((global char*)dst + offsetd);
+
+ global half * src0_row = (global half *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
+ global half * src1_row = (global half *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
+ global half * dst_row = (global half *) ((global char *) dst + get_group_id(0)*nb1);
+
+ for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
+ const half x0 = src0_row[i0];
+ const half x1 = src1_row[i0];
+
+ const half gelu = 0.5f*x0*(1.0f + tanh(SQRT_2_OVER_PI*x0*(1.0f + GELU_COEF_A*x0*x0)));
+
+ dst_row[i0] = gelu*x1;
+ }
+}
+
+//------------------------------------------------------------------------------
+// reglu
+//------------------------------------------------------------------------------
+kernel void kernel_reglu(
+ global char * src0,
+ ulong offset0,
+ global char * src1,
+ ulong offset1,
+ global char * dst,
+ ulong offsetd,
+ ulong nb01,
+ ulong nb11,
+ int ne0,
+ ulong nb1,
+ int ne00_off,
+ int ne10_off
+) {
+ src0 = (global char*)((global char*)src0 + offset0);
+ src1 = (global char*)((global char*)src1 + offset1);
+ dst = (global char*)((global char*)dst + offsetd);
+
+ global float * src0_row = (global float *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
+ global float * src1_row = (global float *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
+ global float * dst_row = (global float *) ((global char *) dst + get_group_id(0)*nb1);
+
+ for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
+ const float x0 = src0_row[i0];
+ const float x1 = src1_row[i0];
+
+ dst_row[i0] = x0*x1*(x0 > 0.0f);
+ }
+}
+
+kernel void kernel_reglu_f16(
+ global char * src0,
+ ulong offset0,
+ global char * src1,
+ ulong offset1,
+ global char * dst,
+ ulong offsetd,
+ ulong nb01,
+ ulong nb11,
+ int ne0,
+ ulong nb1,
+ int ne00_off,
+ int ne10_off
+) {
+ src0 = (global char*)((global char*)src0 + offset0);
+ src1 = (global char*)((global char*)src1 + offset1);
+ dst = (global char*)((global char*)dst + offsetd);
+
+ global half * src0_row = (global half *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
+ global half * src1_row = (global half *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
+ global half * dst_row = (global half *) ((global char *) dst + get_group_id(0)*nb1);
+
+ for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
+ const half x0 = src0_row[i0];
+ const half x1 = src1_row[i0];
+
+ dst_row[i0] = x0*x1*(x0 > 0.0f);
+ }
+}
+
+//------------------------------------------------------------------------------
+// swiglu
+//------------------------------------------------------------------------------
+kernel void kernel_swiglu(
+ global char * src0,
+ ulong offset0,
+ global char * src1,
+ ulong offset1,
+ global char * dst,
+ ulong offsetd,
+ ulong nb01,
+ ulong nb11,
+ int ne0,
+ ulong nb1,
+ int ne00_off,
+ int ne10_off
+) {
+ src0 = (global char*)((global char*)src0 + offset0);
+ src1 = (global char*)((global char*)src1 + offset1);
+ dst = (global char*)((global char*)dst + offsetd);
+
+ global float * src0_row = (global float *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
+ global float * src1_row = (global float *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
+ global float * dst_row = (global float *) ((global char *) dst + get_group_id(0)*nb1);
+
+ for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
+ const float x0 = src0_row[i0];
+ const float x1 = src1_row[i0];
+
+ const float silu = x0 / (1.0f + exp(-x0));
+
+ dst_row[i0] = silu*x1;
+ }
+}
+
+kernel void kernel_swiglu_f16(
+ global char * src0,
+ ulong offset0,
+ global char * src1,
+ ulong offset1,
+ global char * dst,
+ ulong offsetd,
+ ulong nb01,
+ ulong nb11,
+ int ne0,
+ ulong nb1,
+ int ne00_off,
+ int ne10_off
+) {
+ src0 = (global char*)((global char*)src0 + offset0);
+ src1 = (global char*)((global char*)src1 + offset1);
+ dst = (global char*)((global char*)dst + offsetd);
+
+ global half * src0_row = (global half *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
+ global half * src1_row = (global half *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
+ global half * dst_row = (global half *) ((global char *) dst + get_group_id(0)*nb1);
+
+ for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
+ const half x0 = src0_row[i0];
+ const half x1 = src1_row[i0];
+
+ const half silu = x0 / (1.0f + exp(-x0));
+
+ dst_row[i0] = silu*x1;
+ }
+}
+
+//------------------------------------------------------------------------------
+// swiglu_oai
+//------------------------------------------------------------------------------
+kernel void kernel_swiglu_oai(
+ global char * src0,
+ ulong offset0,
+ global char * src1,
+ ulong offset1,
+ global char * dst,
+ ulong offsetd,
+ ulong nb01,
+ ulong nb11,
+ int ne0,
+ ulong nb1,
+ int ne00_off,
+ int ne10_off,
+ float limit,
+ float alpha
+) {
+ src0 = (global char*)((global char*)src0 + offset0);
+ src1 = (global char*)((global char*)src1 + offset1);
+ dst = (global char*)((global char*)dst + offsetd);
+
+ global float * src0_row = (global float *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
+ global float * src1_row = (global float *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
+ global float * dst_row = (global float *) ((global char *) dst + get_group_id(0)*nb1);
+
+ for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
+ float x0 = src0_row[i0];
+ float x1 = src1_row[i0];
+
+ x0 = min(x0, limit);
+ x1 = max(min(x1, limit), -limit);
+
+ float out_glu = x0 / (1.0f + exp(-x0 * alpha));
+ out_glu = out_glu * (1.0f + x1);
+
+ dst_row[i0] = out_glu;
+ }
+}
+
+//------------------------------------------------------------------------------
+// geglu_erf
+//------------------------------------------------------------------------------
+kernel void kernel_geglu_erf(
+ global char * src0,
+ ulong offset0,
+ global char * src1,
+ ulong offset1,
+ global char * dst,
+ ulong offsetd,
+ ulong nb01,
+ ulong nb11,
+ int ne0,
+ ulong nb1,
+ int ne00_off,
+ int ne10_off
+) {
+ src0 = (global char*)((global char*)src0 + offset0);
+ src1 = (global char*)((global char*)src1 + offset1);
+ dst = (global char*)((global char*)dst + offsetd);
+
+ global float * src0_row = (global float *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
+ global float * src1_row = (global float *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
+ global float * dst_row = (global float *) ((global char *) dst + get_group_id(0)*nb1);
+
+ for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
+ const float x0 = src0_row[i0];
+ const float x1 = src1_row[i0];
+
+ const float gelu_erf = 0.5f*x0*(1.0f + erf(x0*SQRT_2_INV));
+
+ dst_row[i0] = gelu_erf*x1;
+ }
+}
+
+kernel void kernel_geglu_erf_f16(
+ global char * src0,
+ ulong offset0,
+ global char * src1,
+ ulong offset1,
+ global char * dst,
+ ulong offsetd,
+ ulong nb01,
+ ulong nb11,
+ int ne0,
+ ulong nb1,
+ int ne00_off,
+ int ne10_off
+) {
+ src0 = (global char*)((global char*)src0 + offset0);
+ src1 = (global char*)((global char*)src1 + offset1);
+ dst = (global char*)((global char*)dst + offsetd);
+
+ global half * src0_row = (global half *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
+ global half * src1_row = (global half *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
+ global half * dst_row = (global half *) ((global char *) dst + get_group_id(0)*nb1);
+
+ for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
+ const half x0 = src0_row[i0];
+ const half x1 = src1_row[i0];
+
+ const half gelu_erf = 0.5f*x0*(1.0f + erf(x0*SQRT_2_INV));
+
+ dst_row[i0] = gelu_erf*x1;
+ }
+}
+
+//------------------------------------------------------------------------------
+// geglu_quick
+//------------------------------------------------------------------------------
+kernel void kernel_geglu_quick(
+ global char * src0,
+ ulong offset0,
+ global char * src1,
+ ulong offset1,
+ global char * dst,
+ ulong offsetd,
+ ulong nb01,
+ ulong nb11,
+ int ne0,
+ ulong nb1,
+ int ne00_off,
+ int ne10_off
+) {
+ src0 = (global char*)((global char*)src0 + offset0);
+ src1 = (global char*)((global char*)src1 + offset1);
+ dst = (global char*)((global char*)dst + offsetd);
+
+ global float * src0_row = (global float *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
+ global float * src1_row = (global float *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
+ global float * dst_row = (global float *) ((global char *) dst + get_group_id(0)*nb1);
+
+ for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
+ const float x0 = src0_row[i0];
+ const float x1 = src1_row[i0];
+
+ const float gelu_quick = x0*(1.0f/(1.0f + exp(GELU_QUICK_COEF*x0)));
+
+ dst_row[i0] = gelu_quick*x1;
+ }
+}
+
+kernel void kernel_geglu_quick_f16(
+ global char * src0,
+ ulong offset0,
+ global char * src1,
+ ulong offset1,
+ global char * dst,
+ ulong offsetd,
+ ulong nb01,
+ ulong nb11,
+ int ne0,
+ ulong nb1,
+ int ne00_off,
+ int ne10_off
+) {
+ src0 = (global char*)((global char*)src0 + offset0);
+ src1 = (global char*)((global char*)src1 + offset1);
+ dst = (global char*)((global char*)dst + offsetd);
+
+ global half * src0_row = (global half *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
+ global half * src1_row = (global half *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
+ global half * dst_row = (global half *) ((global char *) dst + get_group_id(0)*nb1);
+
+ for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
+ const half x0 = src0_row[i0];
+ const half x1 = src1_row[i0];
+
+ const half gelu_quick = x0*(1.0f/(1.0f + exp(GELU_QUICK_COEF*x0)));
+
+ dst_row[i0] = gelu_quick*x1;
+ }
+}