diff options
Diffstat (limited to 'llama.cpp/ggml/src/ggml-opencl/kernels/rms_norm.cl')
| -rw-r--r-- | llama.cpp/ggml/src/ggml-opencl/kernels/rms_norm.cl | 190 |
1 files changed, 190 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-opencl/kernels/rms_norm.cl b/llama.cpp/ggml/src/ggml-opencl/kernels/rms_norm.cl new file mode 100644 index 0000000..4b18d17 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-opencl/kernels/rms_norm.cl @@ -0,0 +1,190 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +//------------------------------------------------------------------------------ +// rms_norm +//------------------------------------------------------------------------------ +// This kernel depends on subgroup size. +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_32 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_rms_norm( + global void * src0, + ulong offset0, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb01, + ulong nb02, + ulong nb03, + float eps, + local float * sum // Note, the size depends on number of subgroups +) { + src0 = (global void*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + global float4 * x = (global float4 *) ((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01); + global float * x_scalar = (global float *) x; + float4 sumf = 0; + float all_sum = 0; + + // parallel sum + for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { + sumf += x[i00] * x[i00]; + } + all_sum = sumf.s0 + sumf.s1 + sumf.s2 + sumf.s3; + all_sum = sub_group_reduce_add(all_sum); + if (get_sub_group_local_id() == 0) { + sum[get_sub_group_id()] = all_sum; + } + + barrier(CLK_LOCAL_MEM_FENCE); + // broadcast + for (uint i = get_local_size(0) / get_max_sub_group_size() / 2; i > 0; i /= 2) { + if (get_local_id(0) < i) { + sum[get_local_id(0)] += sum[get_local_id(0) + i]; + } + } + if (get_local_id(0) == 0) { + for (int i = 4 * (ne00 / 4); i < ne00; i++) { + sum[0] += x_scalar[i]; + } + sum[0] /= ne00; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + const float mean = sum[0]; + const float scale = 1.0f/sqrt(mean + eps); + + global float4 * y = (global float4 *) (dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + global float * y_scalar = (global float *) y; + for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { + y[i00] = x[i00] * scale; + } + if (get_local_id(0) == 0) { + for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) { + y_scalar[i00] = x_scalar[i00] * scale; + } + } +} + +//------------------------------------------------------------------------------ +// rms_norm_mul +//------------------------------------------------------------------------------ +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_32 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_rms_norm_mul( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global char * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb01, + ulong nb02, + ulong nb03, + int ne10, + int ne11, + int ne12, + int ne13, + ulong nb11, + ulong nb12, + ulong nb13, + ulong nb1, + ulong nb2, + ulong nb3, + float eps, + local float * sum +) { + src0 = src0 + offset0; + src1 = src1 + offset1; + dst = dst + offsetd; + + // The size of sum is sizeof(float)*subgroup_size. + // Each subgroup writes its partial sum to this array. + // So the number of subgroups per workgroup for this kernel cannot exceed the subgroup size. + // This is generally true - + // for subgroup size 64, workgroup size should be less than 4096 (the max is usually 1024). + if (get_sub_group_id() == 0) { + sum[get_sub_group_local_id()] = 0.0f; + } + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + global float4 * x = (global float4 *) (src0 + i03*nb03 + i02*nb02 + i01*nb01); + global float4 * f = (global float4 *) (src1 + (i03%ne13)*nb13 + (i02%ne12)*nb12 + (i01%ne11)*nb11); + + float sumf = 0; + + // parallel sum + for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { + sumf += dot(x[i00], x[i00]); + } + sumf = sub_group_reduce_add(sumf); + + barrier(CLK_LOCAL_MEM_FENCE); + + if (get_sub_group_local_id() == 0) { + sum[get_sub_group_id()] = sumf; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + //for (uint i = get_local_size(0) / get_max_sub_group_size() / 2; i > 0; i /= 2) { + // if (get_local_id(0) < i) { + // sum[get_local_id(0)] += sum[get_local_id(0) + i]; + // } + //} + //if (get_local_id(0) == 0) { + // sum[0] /= ne00; + //} + + //barrier(CLK_LOCAL_MEM_FENCE); + + sumf = sum[get_sub_group_local_id()]; + sumf = sub_group_reduce_add(sumf); + + float mean = sumf / ne00; + float scale = 1.0f/sqrt(mean + eps); + + global float4 * y = (global float4 *) (dst + i03*nb3 + i02*nb2 + i01*nb1); + for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { + y[i00] = (x[i00] * scale) * f[i00%(ne10/4)]; + } +} |
