1#pragma OPENCL EXTENSION cl_khr_fp16 : enable
2
3#ifdef cl_intel_required_subgroup_size
4#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
5#define INTEL_GPU 1
6#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
7#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
8#elif defined(cl_qcom_reqd_sub_group_size)
9#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
10#define ADRENO_GPU 1
11#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
12#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
13#endif
14
15//------------------------------------------------------------------------------
16// norm
17//------------------------------------------------------------------------------
// kernel_norm: normalize each row of src0 to zero mean / unit variance.
//
// Work-group (i01, i02, i03) = (get_group_id(0), get_group_id(1),
// get_group_id(2)) processes one row x of ne00 floats and writes
//     y = (x - mean(x)) / sqrt(var(x) + eps)
// where mean and variance are computed in two passes (recenter first, then
// sum of squared deviations).
//
//   src0, offset0    : source buffer and byte offset of the tensor in it; rows
//                      are located with the byte strides nb01/nb02/nb03.
//   dst, offsetd     : destination buffer and byte offset; dst is written as a
//                      densely packed float tensor (row stride ne00 floats).
//   ne00..ne02       : tensor extents used for indexing (ne03 is not read).
//   eps              : added to the variance before the square root.
//   sum              : local scratch, at least get_local_size(0) floats.
//
// The local reductions are correct for any work-group size, not only powers
// of two.
kernel void kernel_norm(
        global void * src0,
        ulong offset0,
        global float * dst,
        ulong offsetd,
        int ne00,
        int ne01,
        int ne02,
        int ne03,
        ulong nb01,
        ulong nb02,
        ulong nb03,
        float eps,
        local float * sum
) {
    src0 = (global void *)((global char *)src0 + offset0);
    dst  = (global float *)((global char *)dst + offsetd);

    const int i03 = get_group_id(2);
    const int i02 = get_group_id(1);
    const int i01 = get_group_id(0);

    const int lid = get_local_id(0);
    const int nth = get_local_size(0);

    global const float * x = (global const float *)((global char *)src0 + i03*nb03 + i02*nb02 + i01*nb01);

    // dst is densely packed; build the row index in 64-bit so that
    // i03*ne02*ne01*ne00 cannot overflow 32-bit int arithmetic.
    const ulong row = ((ulong)i03*ne02 + i02)*ne01 + i01;
    global float * y = dst + row*ne00;

    // Pass 1: mean. Each work-item sums a strided slice of the row in a
    // private accumulator, then the partials are tree-reduced in local memory.
    float acc = 0.0f;
    for (int i00 = lid; i00 < ne00; i00 += nth) {
        acc += x[i00];
    }
    sum[lid] = acc;
    barrier(CLK_LOCAL_MEM_FENCE);
    // Fold the upper part onto the lower part with a rounded-up stride, so no
    // partial sum is dropped when the work-group size is not a power of two.
    for (int n = nth; n > 1; n = (n + 1)/2) {
        const int half = (n + 1)/2;
        if (lid < n - half) {
            sum[lid] += sum[lid + half];
        }
        barrier(CLK_LOCAL_MEM_FENCE);
    }
    const float mean = sum[0] / ne00;

    // Every work-item must have read sum[0] before the scratch is reused.
    barrier(CLK_LOCAL_MEM_FENCE);

    // Pass 2: recenter into y and accumulate the squared deviations.
    acc = 0.0f;
    for (int i00 = lid; i00 < ne00; i00 += nth) {
        const float d = x[i00] - mean;
        y[i00] = d;
        acc += d*d;
    }
    sum[lid] = acc;
    barrier(CLK_LOCAL_MEM_FENCE);
    for (int n = nth; n > 1; n = (n + 1)/2) {
        const int half = (n + 1)/2;
        if (lid < n - half) {
            sum[lid] += sum[lid + half];
        }
        barrier(CLK_LOCAL_MEM_FENCE);
    }
    const float variance = sum[0] / ne00;

    // Pass 3: scale the recentered values in place. Each work-item touches
    // exactly the elements it wrote in pass 2.
    const float scale = 1.0f/sqrt(variance + eps);
    for (int i00 = lid; i00 < ne00; i00 += nth) {
        y[i00] = y[i00] * scale;
    }
}
82
83//------------------------------------------------------------------------------
84// norm_mul_add
85//------------------------------------------------------------------------------
#ifdef INTEL_GPU
REQD_SUBGROUP_SIZE_32
#elif defined (ADRENO_GPU)
REQD_SUBGROUP_SIZE_64
#endif
// kernel_norm_mul_add: fused  y = norm(x) * w + b  over the rows of src0, with
// norm(x) = (x - mean(x)) * rsqrt(var(x) + eps).
//
// Work-group (i01, i02, i03) = (get_group_id(0), get_group_id(1),
// get_group_id(2)) processes one row. All rows are addressed with byte
// offsets/strides; the w (src1) and b (src2) rows are chosen modulo their
// extents ne11..ne13 / ne21..ne23, i.e. broadcast over dims 1..3.
//
// Along dim 0, w (resp. b) is either a full row (ne10 > 1 / ne20 > 1, assumed
// to have ne00 elements) or a single scalar (ne10 == 1 / ne20 == 1) applied to
// every element.
//
// The row is processed as float4 chunks: only the first (ne00/4)*4 elements
// are read and written, and row starts must be float4-aligned.
// NOTE(review): presumably the host only dispatches this kernel when ne00 is
// a multiple of 4 with aligned strides/offsets; confirm against the caller.
//
// Mean and variance come from a single pass (sum and sum of squares), reduced
// within each sub-group and then across sub-groups through `sums`, which must
// hold at least get_num_sub_groups() float2 values.
kernel void kernel_norm_mul_add(
        global char * src0_ptr, ulong src0_offset,
        global char * src1_ptr, ulong src1_offset,
        global char * src2_ptr, ulong src2_offset,
        global char * dst_ptr,  ulong dst_offset,
        int ne00, int ne01, int ne02, int ne03,
        ulong nb01, ulong nb02, ulong nb03,
        int ne10, int ne11, int ne12, int ne13,
        ulong nb11, ulong nb12, ulong nb13,
        int ne20, int ne21, int ne22, int ne23,
        ulong nb21, ulong nb22, ulong nb23,
        ulong nbd1, ulong nbd2, ulong nbd3,
        float eps,
        local float2 * sums
) {
    const int i03 = get_group_id(2);
    const int i02 = get_group_id(1);
    const int i01 = get_group_id(0);

    global const float4 * x = (global const float4 *)(src0_ptr + src0_offset + i01*nb01 + i02*nb02 + i03*nb03);
    global const float4 * w = (global const float4 *)(src1_ptr + src1_offset + (i01%ne11)*nb11 + (i02%ne12)*nb12 + (i03%ne13)*nb13);
    global const float4 * b = (global const float4 *)(src2_ptr + src2_offset + (i01%ne21)*nb21 + (i02%ne22)*nb22 + (i03%ne23)*nb23);
    global float4       * y = (global float4 *)(dst_ptr + dst_offset + i01*nbd1 + i02*nbd2 + i03*nbd3);

    const int n_chunks = ne00 / 4;

    // Per-work-item partial sum and sum of squares over a strided set of chunks.
    float p_sum = 0.0f;
    float p_sum_sq = 0.0f;
    for (int i00 = get_local_id(0); i00 < n_chunks; i00 += get_local_size(0)) {
        const float4 val = x[i00];
        p_sum    += val.x + val.y + val.z + val.w;
        p_sum_sq += dot(val, val);
    }

    // Reduce within the sub-group, then publish one float2 per sub-group.
    p_sum    = sub_group_reduce_add(p_sum);
    p_sum_sq = sub_group_reduce_add(p_sum_sq);

    if (get_sub_group_local_id() == 0) {
        sums[get_sub_group_id()] = (float2)(p_sum, p_sum_sq);
    }
    barrier(CLK_LOCAL_MEM_FENCE);

    // Work-item 0 combines the sub-group partials and stores (mean, scale).
    if (get_local_id(0) == 0) {
        float sum = 0.0f;
        float sum_sq = 0.0f;
        for (uint i = 0; i < get_num_sub_groups(); ++i) {
            const float2 s = sums[i];
            sum    += s.x;
            sum_sq += s.y;
        }

        const float inv_ne00 = 1.0f / (float)ne00;
        const float mean = sum * inv_ne00;
        // E[x^2] - mean^2 can round to a small negative value through
        // cancellation (e.g. a near-constant row with large |mean|). Clamp it
        // so rsqrt never receives a negative argument and returns NaN.
        const float variance = fmax(mad(-mean, mean, sum_sq * inv_ne00), 0.0f);

        sums[0] = (float2)(mean, rsqrt(variance + eps));
    }
    barrier(CLK_LOCAL_MEM_FENCE);

    const float2 mean_scale = sums[0];
    const float mean = mean_scale.x;
    const float scale = mean_scale.y;
    const float neg_mean_scale = -mean * scale;

    // A dim-0 extent of 1 means a single scalar for the whole row. It must be
    // splatted from one float: reading w[0]/b[0] as a float4 would pull in
    // three floats past the scalar and apply them lane-wise.
    const float4 w_bcast = (float4)(((global const float *)w)[0]);
    const float4 b_bcast = (float4)(((global const float *)b)[0]);

    for (int i00 = get_local_id(0); i00 < n_chunks; i00 += get_local_size(0)) {
        const float4 wv = ne10 > 1 ? w[i00] : w_bcast;
        const float4 bv = ne20 > 1 ? b[i00] : b_bcast;
        const float4 norm_x = mad(x[i00], (float4)scale, (float4)neg_mean_scale);
        y[i00] = mad(norm_x, wv, bv);
    }
}