// Enable the half-precision floating-point extension for this program.
#pragma OPENCL EXTENSION cl_khr_fp16 : enable

// Vendor-specific "required subgroup size" kernel attributes.
//
// Exactly one of the two branches below is taken, depending on which
// extension macro the device compiler predefines. Each branch defines a
// vendor tag (INTEL_GPU or ADRENO_GPU) plus the REQD_SUBGROUP_SIZE_* macros
// available for that vendor. If neither extension is present, none of these
// macros are defined, so every use must be guarded by the vendor tag.
#ifdef cl_intel_required_subgroup_size
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
#define INTEL_GPU 1
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
#elif defined(cl_qcom_reqd_sub_group_size)
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
// The Qualcomm attribute takes "half"/"full" rather than a number; the macro
// names record the subgroup widths these are expected to correspond to.
#define REQD_SUBGROUP_SIZE_64  __attribute__((qcom_reqd_sub_group_size("half")))
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#endif
 14
 15//------------------------------------------------------------------------------
 16// norm
 17//------------------------------------------------------------------------------
 18kernel void kernel_norm(
 19        global void * src0,
 20        ulong offset0,
 21        global float * dst,
 22        ulong offsetd,
 23        int ne00,
 24        int ne01,
 25        int ne02,
 26        int ne03,
 27        ulong nb01,
 28        ulong nb02,
 29        ulong nb03,
 30        float eps,
 31        local float * sum
 32) {
 33    src0 = (global void*)((global char*)src0 + offset0);
 34    dst = (global void*)((global char*)dst + offsetd);
 35
 36    int i03 = get_group_id(2);
 37    int i02 = get_group_id(1);
 38    int i01 = get_group_id(0);
 39
 40    global float * x = (global float *) ((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01);
 41
 42    // MEAN
 43    // parallel sum
 44    sum[get_local_id(0)] = 0.0f;
 45    for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
 46        sum[get_local_id(0)] += x[i00];
 47    }
 48    // reduce
 49    barrier(CLK_LOCAL_MEM_FENCE);
 50    for (uint i = get_local_size(0)/2; i > 0; i /= 2) {
 51        if (get_local_id(0) < i) {
 52            sum[get_local_id(0)] += sum[get_local_id(0) + i];
 53        }
 54        barrier(CLK_LOCAL_MEM_FENCE);
 55    }
 56    float mean  = sum[0] / ne00;
 57
 58    // recenter and VARIANCE
 59    barrier(CLK_LOCAL_MEM_FENCE);
 60    global float * y = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
 61    sum[get_local_id(0)] = 0.0f;
 62    for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
 63        y[i00] = x[i00] - mean;
 64        sum[get_local_id(0)] += y[i00] * y[i00];
 65    }
 66
 67    // reduce
 68    barrier(CLK_LOCAL_MEM_FENCE);
 69    for (uint i = get_local_size(0)/2; i > 0; i /= 2) {
 70        if (get_local_id(0) < i) {
 71            sum[get_local_id(0)] += sum[get_local_id(0) + i];
 72        }
 73        barrier(CLK_LOCAL_MEM_FENCE);
 74    }
 75    float variance = sum[0] / ne00;
 76
 77    float scale = 1.0f/sqrt(variance + eps);
 78    for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
 79        y[i00] = y[i00] * scale;
 80    }
 81}
 82
 83//------------------------------------------------------------------------------
 84// norm_mul_add
 85//------------------------------------------------------------------------------
 86#ifdef INTEL_GPU
 87REQD_SUBGROUP_SIZE_32
 88#elif defined (ADRENO_GPU)
 89REQD_SUBGROUP_SIZE_64
 90#endif
 91kernel void kernel_norm_mul_add(
 92        global char * src0_ptr, ulong src0_offset,
 93        global char * src1_ptr, ulong src1_offset,
 94        global char * src2_ptr, ulong src2_offset,
 95        global char * dst_ptr,  ulong dst_offset,
 96        int ne00, int ne01, int ne02, int ne03,
 97        ulong nb01, ulong nb02, ulong nb03,
 98        int ne10, int ne11, int ne12, int ne13,
 99        ulong nb11, ulong nb12, ulong nb13,
100        int ne20, int ne21, int ne22, int ne23,
101        ulong nb21, ulong nb22, ulong nb23,
102        ulong nbd1, ulong nbd2, ulong nbd3,
103        float eps,
104        local float2 * sums
105) {
106    const int i03 = get_group_id(2);
107    const int i02 = get_group_id(1);
108    const int i01 = get_group_id(0);
109
110    global float4 * x = (global float4 *)(src0_ptr + src0_offset + i01*nb01 + i02*nb02 + i03*nb03);
111    global float4 * w = (global float4 *)(src1_ptr + src1_offset + (i01%ne11)*nb11 + (i02%ne12)*nb12 + (i03%ne13)*nb13);
112    global float4 * b = (global float4 *)(src2_ptr + src2_offset + (i01%ne21)*nb21 + (i02%ne22)*nb22 + (i03%ne23)*nb23);
113    global float4 * y = (global float4 *)(dst_ptr  + dst_offset  + i01*nbd1 + i02*nbd2 + i03*nbd3);
114
115    float p_sum = 0.0f;
116    float p_sum_sq = 0.0f;
117
118    const int n_chunks = ne00 / 4;
119    for (int i00 = get_local_id(0); i00 < n_chunks; i00 += get_local_size(0)) {
120        float4 val = x[i00];
121        p_sum += val.x + val.y + val.z + val.w;
122        p_sum_sq += dot(val, val);
123    }
124
125    p_sum = sub_group_reduce_add(p_sum);
126    p_sum_sq = sub_group_reduce_add(p_sum_sq);
127
128    if (get_sub_group_local_id() == 0) {
129        sums[get_sub_group_id()] = (float2)(p_sum, p_sum_sq);
130    }
131    barrier(CLK_LOCAL_MEM_FENCE);
132
133    if (get_local_id(0) == 0) {
134        float sum = 0.0f;
135        float sum_sq = 0.0f;
136        for (uint i = 0; i < get_num_sub_groups(); ++i) {
137            float2 s = sums[i];
138            sum += s.x;
139            sum_sq += s.y;
140        }
141
142        const float inv_ne00 = 1.0f / (float)ne00;
143        const float mean = sum * inv_ne00;
144        const float variance = mad(-mean, mean, sum_sq * inv_ne00);
145
146        sums[0] = (float2)(mean, rsqrt(variance + eps));
147    }
148    barrier(CLK_LOCAL_MEM_FENCE);
149
150    const float2 mean_scale = sums[0];
151    const float mean = mean_scale.x;
152    const float scale = mean_scale.y;
153    const float neg_mean_scale = -mean * scale;
154
155    for (int i00 = get_local_id(0); i00 < n_chunks; i00 += get_local_size(0)) {
156        const int w_idx = ne10 > 1 ? i00 : 0;
157        const int b_idx = ne20 > 1 ? i00 : 0;
158        const float4 norm_x = mad(x[i00], (float4)scale, (float4)neg_mean_scale);
159        y[i00] = mad(norm_x, w[w_idx], b[b_idx]);
160    }
161}