1#pragma OPENCL EXTENSION cl_khr_fp16 : enable
  2
  3#ifdef cl_intel_subgroups
  4#pragma OPENCL EXTENSION cl_intel_subgroups : enable
  5#else
  6#pragma OPENCL EXTENSION cl_khr_subgroups : enable
  7#endif
  8
  9#ifdef cl_intel_required_subgroup_size
 10#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
 11#define INTEL_GPU 1
 12#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
 13#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
 14#elif defined(cl_qcom_reqd_sub_group_size)
 15#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
 16#define ADRENO_GPU 1
 17#define REQD_SUBGROUP_SIZE_64  __attribute__((qcom_reqd_sub_group_size("half")))
 18#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
 19#endif
 20
 21//------------------------------------------------------------------------------
 22// rms_norm
 23//------------------------------------------------------------------------------
 24// This kernel depends on subgroup size.
 25#ifdef INTEL_GPU
 26REQD_SUBGROUP_SIZE_32
 27#elif defined (ADRENO_GPU)
 28REQD_SUBGROUP_SIZE_64
 29#endif
 30kernel void kernel_rms_norm(
 31        global void * src0,
 32        ulong offset0,
 33        global float * dst,
 34        ulong offsetd,
 35        int ne00,
 36        int ne01,
 37        int ne02,
 38        int ne03,
 39        ulong nb01,
 40        ulong nb02,
 41        ulong nb03,
 42        float eps,
 43        local float * sum // Note, the size depends on number of subgroups
 44) {
 45    src0 = (global void*)((global char*)src0 + offset0);
 46    dst = (global float*)((global char*)dst + offsetd);
 47
 48    int i03 = get_group_id(2);
 49    int i02 = get_group_id(1);
 50    int i01 = get_group_id(0);
 51
 52    global float4 * x = (global float4 *) ((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01);
 53    global float * x_scalar = (global float *) x;
 54    float4 sumf = 0;
 55    float all_sum = 0;
 56
 57    // parallel sum
 58    for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
 59        sumf += x[i00] * x[i00];
 60    }
 61    all_sum = sumf.s0 + sumf.s1 + sumf.s2 + sumf.s3;
 62    all_sum = sub_group_reduce_add(all_sum);
 63    if (get_sub_group_local_id() == 0) {
 64        sum[get_sub_group_id()] = all_sum;
 65    }
 66
 67    barrier(CLK_LOCAL_MEM_FENCE);
 68    // broadcast
 69    for (uint i = get_local_size(0) / get_max_sub_group_size() / 2; i > 0; i /= 2) {
 70       if (get_local_id(0) < i) {
 71           sum[get_local_id(0)] += sum[get_local_id(0) + i];
 72       }
 73    }
 74    if (get_local_id(0) == 0) {
 75        for (int i = 4 * (ne00 / 4); i < ne00; i++) {
 76            sum[0] += x_scalar[i];
 77        }
 78        sum[0] /= ne00;
 79    }
 80
 81    barrier(CLK_LOCAL_MEM_FENCE);
 82
 83    const float mean  = sum[0];
 84    const float scale = 1.0f/sqrt(mean + eps);
 85
 86    global float4 * y = (global float4 *) (dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
 87    global float * y_scalar = (global float *) y;
 88    for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
 89        y[i00] = x[i00] * scale;
 90    }
 91    if (get_local_id(0) == 0) {
 92        for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {
 93            y_scalar[i00] = x_scalar[i00] * scale;
 94        }
 95    }
 96}
 97
 98//------------------------------------------------------------------------------
 99// rms_norm_mul
100//------------------------------------------------------------------------------
// rms_norm_mul: fused RMS normalization and element-wise multiply,
//   dst = (x * 1/sqrt(mean(x^2) + eps)) * w
// x is one row of src0 and w the matching row of src1. The src1 row indices
// wrap modulo ne11/ne12/ne13, and along the row w wraps modulo ne10/4 float4
// elements. One work-group handles one row; work-items stride in float4 units.
#ifdef INTEL_GPU
REQD_SUBGROUP_SIZE_32
#elif defined (ADRENO_GPU)
REQD_SUBGROUP_SIZE_64
#endif
kernel void kernel_rms_norm_mul(
        global char * src0,
        ulong offset0,
        global char * src1,
        ulong offset1,
        global char * dst,
        ulong offsetd,
        int ne00,
        int ne01,
        int ne02,
        int ne03,
        ulong nb01,
        ulong nb02,
        ulong nb03,
        int ne10,
        int ne11,
        int ne12,
        int ne13,
        ulong nb11,
        ulong nb12,
        ulong nb13,
        ulong nb1,
        ulong nb2,
        ulong nb3,
        float eps,
        local float * sum
) {
    // Scratch layout: one float per subgroup lane (sizeof(float) * subgroup size).
    // Subgroup 0 clears every slot. Each subgroup later stores its partial sum
    // at sum[subgroup id], so a work-group may contain at most "subgroup size"
    // subgroups. For example, with subgroup size 64 that is at most 4096
    // work-items, well above the usual 1024 limit.
    if (get_sub_group_id() == 0) {
        sum[get_sub_group_local_id()] = 0.0f;
    }

    // Row coordinates of this work-group.
    int i3 = get_group_id(2);
    int i2 = get_group_id(1);
    int i1 = get_group_id(0);

    global float4 * src_row = (global float4 *) (src0 + offset0 + i3*nb03 + i2*nb02 + i1*nb01);
    global float4 * w_row   = (global float4 *) (src1 + offset1 + (i3%ne13)*nb13 + (i2%ne12)*nb12 + (i1%ne11)*nb11);
    global float4 * dst_row = (global float4 *) (dst  + offsetd + i3*nb3 + i2*nb2 + i1*nb1);

    // Partial sum of squares for this work-item, then a subgroup-wide total.
    float acc = 0;
    for (int k = get_local_id(0); k < ne00/4; k += get_local_size(0)) {
        acc += dot(src_row[k], src_row[k]);
    }
    acc = sub_group_reduce_add(acc);

    // Make the zeroed scratch visible before subgroups publish their totals.
    barrier(CLK_LOCAL_MEM_FENCE);

    if (get_sub_group_local_id() == 0) {
        sum[get_sub_group_id()] = acc;
    }

    barrier(CLK_LOCAL_MEM_FENCE);

    // Every subgroup re-reduces the published partials, lane i reading slot i.
    // Unused slots are still zero from the clearing step above.
    const float total = sub_group_reduce_add(sum[get_sub_group_local_id()]);

    const float mean  = total / ne00;
    const float scale = 1.0f/sqrt(mean + eps);

    for (int k = get_local_id(0); k < ne00/4; k += get_local_size(0)) {
        dst_row[k] = (src_row[k] * scale) * w_row[k%(ne10/4)];
    }
}