summaryrefslogtreecommitdiff
path: root/llama.cpp/ggml/src/ggml-opencl/kernels/rms_norm.cl
diff options
context:
space:
mode:
Diffstat (limited to 'llama.cpp/ggml/src/ggml-opencl/kernels/rms_norm.cl')
-rw-r--r--llama.cpp/ggml/src/ggml-opencl/kernels/rms_norm.cl190
1 files changed, 190 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-opencl/kernels/rms_norm.cl b/llama.cpp/ggml/src/ggml-opencl/kernels/rms_norm.cl
new file mode 100644
index 0000000..4b18d17
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-opencl/kernels/rms_norm.cl
@@ -0,0 +1,190 @@
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+
+#ifdef cl_intel_subgroups
+#pragma OPENCL EXTENSION cl_intel_subgroups : enable
+#else
+#pragma OPENCL EXTENSION cl_khr_subgroups : enable
+#endif
+
+#ifdef cl_intel_required_subgroup_size
+#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
+#define INTEL_GPU 1
+#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
+#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
+#elif defined(cl_qcom_reqd_sub_group_size)
+#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
+#define ADRENO_GPU 1
+#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
+#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
+#endif
+
+//------------------------------------------------------------------------------
+// rms_norm
+//------------------------------------------------------------------------------
+// This kernel depends on subgroup size.
+#ifdef INTEL_GPU
+REQD_SUBGROUP_SIZE_32
+#elif defined (ADRENO_GPU)
+REQD_SUBGROUP_SIZE_64
+#endif
+kernel void kernel_rms_norm(
+ global void * src0,
+ ulong offset0,
+ global float * dst,
+ ulong offsetd,
+ int ne00,
+ int ne01,
+ int ne02,
+ int ne03,
+ ulong nb01,
+ ulong nb02,
+ ulong nb03,
+ float eps,
+ local float * sum // Note, the size depends on number of subgroups
+) {
+ src0 = (global void*)((global char*)src0 + offset0);
+ dst = (global float*)((global char*)dst + offsetd);
+
+ int i03 = get_group_id(2);
+ int i02 = get_group_id(1);
+ int i01 = get_group_id(0);
+
+ global float4 * x = (global float4 *) ((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01);
+ global float * x_scalar = (global float *) x;
+ float4 sumf = 0;
+ float all_sum = 0;
+
+ // parallel sum
+ for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
+ sumf += x[i00] * x[i00];
+ }
+ all_sum = sumf.s0 + sumf.s1 + sumf.s2 + sumf.s3;
+ all_sum = sub_group_reduce_add(all_sum);
+ if (get_sub_group_local_id() == 0) {
+ sum[get_sub_group_id()] = all_sum;
+ }
+
+ barrier(CLK_LOCAL_MEM_FENCE);
+ // broadcast
+ for (uint i = get_local_size(0) / get_max_sub_group_size() / 2; i > 0; i /= 2) {
+ if (get_local_id(0) < i) {
+ sum[get_local_id(0)] += sum[get_local_id(0) + i];
+ }
+ }
+ if (get_local_id(0) == 0) {
+ for (int i = 4 * (ne00 / 4); i < ne00; i++) {
+ sum[0] += x_scalar[i];
+ }
+ sum[0] /= ne00;
+ }
+
+ barrier(CLK_LOCAL_MEM_FENCE);
+
+ const float mean = sum[0];
+ const float scale = 1.0f/sqrt(mean + eps);
+
+ global float4 * y = (global float4 *) (dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+ global float * y_scalar = (global float *) y;
+ for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
+ y[i00] = x[i00] * scale;
+ }
+ if (get_local_id(0) == 0) {
+ for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {
+ y_scalar[i00] = x_scalar[i00] * scale;
+ }
+ }
+}
+
+//------------------------------------------------------------------------------
+// rms_norm_mul
+//------------------------------------------------------------------------------
+#ifdef INTEL_GPU
+REQD_SUBGROUP_SIZE_32
+#elif defined (ADRENO_GPU)
+REQD_SUBGROUP_SIZE_64
+#endif
+kernel void kernel_rms_norm_mul(
+ global char * src0,
+ ulong offset0,
+ global char * src1,
+ ulong offset1,
+ global char * dst,
+ ulong offsetd,
+ int ne00,
+ int ne01,
+ int ne02,
+ int ne03,
+ ulong nb01,
+ ulong nb02,
+ ulong nb03,
+ int ne10,
+ int ne11,
+ int ne12,
+ int ne13,
+ ulong nb11,
+ ulong nb12,
+ ulong nb13,
+ ulong nb1,
+ ulong nb2,
+ ulong nb3,
+ float eps,
+ local float * sum
+) {
+ src0 = src0 + offset0;
+ src1 = src1 + offset1;
+ dst = dst + offsetd;
+
+ // The size of sum is sizeof(float)*subgroup_size.
+ // Each subgroup writes its partial sum to this array.
+ // So the number of subgroups per workgroup for this kernel cannot exceed the subgroup size.
+ // This is generally true -
+ // for subgroup size 64, workgroup size should be less than 4096 (the max is usually 1024).
+ if (get_sub_group_id() == 0) {
+ sum[get_sub_group_local_id()] = 0.0f;
+ }
+
+ int i03 = get_group_id(2);
+ int i02 = get_group_id(1);
+ int i01 = get_group_id(0);
+
+ global float4 * x = (global float4 *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
+ global float4 * f = (global float4 *) (src1 + (i03%ne13)*nb13 + (i02%ne12)*nb12 + (i01%ne11)*nb11);
+
+ float sumf = 0;
+
+ // parallel sum
+ for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
+ sumf += dot(x[i00], x[i00]);
+ }
+ sumf = sub_group_reduce_add(sumf);
+
+ barrier(CLK_LOCAL_MEM_FENCE);
+
+ if (get_sub_group_local_id() == 0) {
+ sum[get_sub_group_id()] = sumf;
+ }
+
+ barrier(CLK_LOCAL_MEM_FENCE);
+
+ //for (uint i = get_local_size(0) / get_max_sub_group_size() / 2; i > 0; i /= 2) {
+ // if (get_local_id(0) < i) {
+ // sum[get_local_id(0)] += sum[get_local_id(0) + i];
+ // }
+ //}
+ //if (get_local_id(0) == 0) {
+ // sum[0] /= ne00;
+ //}
+
+ //barrier(CLK_LOCAL_MEM_FENCE);
+
+ sumf = sum[get_sub_group_local_id()];
+ sumf = sub_group_reduce_add(sumf);
+
+ float mean = sumf / ne00;
+ float scale = 1.0f/sqrt(mean + eps);
+
+ global float4 * y = (global float4 *) (dst + i03*nb3 + i02*nb2 + i01*nb1);
+ for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
+ y[i00] = (x[i00] * scale) * f[i00%(ne10/4)];
+ }
+}