1#pragma OPENCL EXTENSION cl_khr_fp16 : enable
2
3#ifdef cl_intel_subgroups
4#pragma OPENCL EXTENSION cl_intel_subgroups : enable
5#else
6#pragma OPENCL EXTENSION cl_khr_subgroups : enable
7#endif
8
9#ifdef cl_intel_required_subgroup_size
10#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
11#define INTEL_GPU 1
12#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
13#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
14#elif defined(cl_qcom_reqd_sub_group_size)
15#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
16#define ADRENO_GPU 1
17#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
18#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
19#endif
20
21//------------------------------------------------------------------------------
22// rms_norm
23//------------------------------------------------------------------------------
24// This kernel depends on subgroup size.
#ifdef INTEL_GPU
REQD_SUBGROUP_SIZE_32
#elif defined (ADRENO_GPU)
REQD_SUBGROUP_SIZE_64
#endif
// RMS-normalize one src0 row per workgroup:
//   dst[i] = src0[i] / sqrt(mean(src0^2) + eps)
// Launched with one workgroup per (i01, i02, i03) row of ne00 floats.
// src0 is addressed with byte strides nb01/nb02/nb03; dst rows are contiguous.
// `sum` is workgroup-local scratch and must hold at least one float per subgroup.
kernel void kernel_rms_norm(
        global void * src0,
        ulong offset0,
        global float * dst,
        ulong offsetd,
        int ne00,
        int ne01,
        int ne02,
        int ne03,
        ulong nb01,
        ulong nb02,
        ulong nb03,
        float eps,
        local float * sum // Note, the size depends on number of subgroups
) {
    src0 = (global void*)((global char*)src0 + offset0);
    dst = (global float*)((global char*)dst + offsetd);

    int i03 = get_group_id(2);
    int i02 = get_group_id(1);
    int i01 = get_group_id(0);

    global float4 * x = (global float4 *) ((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01);
    global float * x_scalar = (global float *) x;
    float4 sumf = 0;
    float all_sum = 0;

    // Parallel partial sum of squares over float4 chunks, strided across the workgroup.
    for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
        sumf += x[i00] * x[i00];
    }
    all_sum = sumf.s0 + sumf.s1 + sumf.s2 + sumf.s3;
    all_sum = sub_group_reduce_add(all_sum);
    if (get_sub_group_local_id() == 0) {
        sum[get_sub_group_id()] = all_sum;
    }

    barrier(CLK_LOCAL_MEM_FENCE);
    // Tree-reduce the per-subgroup partials into sum[0].
    // NOTE(review): assumes local_size / max_sub_group_size is a power of two —
    // confirm against the host-side workgroup-size selection.
    for (uint i = get_local_size(0) / get_max_sub_group_size() / 2; i > 0; i /= 2) {
        if (get_local_id(0) < i) {
            sum[get_local_id(0)] += sum[get_local_id(0) + i];
        }
        // Fix: without a barrier between iterations, work-items in different
        // subgroups race on `sum` whenever more than one subgroup reduces.
        // The trip count is uniform, so this barrier is reached by all work-items.
        barrier(CLK_LOCAL_MEM_FENCE);
    }
    if (get_local_id(0) == 0) {
        // Scalar tail for ne00 not divisible by 4.
        // Fix: accumulate the *square* of each tail element — `sum` holds a sum
        // of squares; the original added the raw value, skewing the mean.
        for (int i = 4 * (ne00 / 4); i < ne00; i++) {
            sum[0] += x_scalar[i] * x_scalar[i];
        }
        sum[0] /= ne00;
    }

    barrier(CLK_LOCAL_MEM_FENCE);

    const float mean = sum[0];
    const float scale = 1.0f/sqrt(mean + eps);

    // dst uses a contiguous element-stride layout, unlike src0's byte strides.
    global float4 * y = (global float4 *) (dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
    global float * y_scalar = (global float *) y;
    for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
        y[i00] = x[i00] * scale;
    }
    if (get_local_id(0) == 0) {
        // Scalar tail of the output row.
        for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {
            y_scalar[i00] = x_scalar[i00] * scale;
        }
    }
}
97
98//------------------------------------------------------------------------------
99// rms_norm_mul
100//------------------------------------------------------------------------------
#ifdef INTEL_GPU
REQD_SUBGROUP_SIZE_32
#elif defined (ADRENO_GPU)
REQD_SUBGROUP_SIZE_64
#endif
// Fused RMS-norm + elementwise multiply:
//   dst = (src0 / sqrt(mean(src0^2) + eps)) * src1
// One src0 row per workgroup; src1 is broadcast along dims 1-3 via the
// modulo indexing on (i01, i02, i03) and along dim 0 via i00 % (ne10/4).
// NOTE(review): unlike kernel_rms_norm there is no scalar tail — both the sum
// and the store loop assume ne00 is a multiple of 4; confirm the caller
// guarantees this.
kernel void kernel_rms_norm_mul(
        global char * src0,
        ulong offset0,
        global char * src1,
        ulong offset1,
        global char * dst,
        ulong offsetd,
        int ne00,
        int ne01,
        int ne02,
        int ne03,
        ulong nb01,
        ulong nb02,
        ulong nb03,
        int ne10,
        int ne11,
        int ne12,
        int ne13,
        ulong nb11,
        ulong nb12,
        ulong nb13,
        ulong nb1,
        ulong nb2,
        ulong nb3,
        float eps,
        local float * sum
) {
    src0 = src0 + offset0;
    src1 = src1 + offset1;
    dst = dst + offsetd;

    // The size of sum is sizeof(float)*subgroup_size.
    // Each subgroup writes its partial sum to this array.
    // So the number of subgroups per workgroup for this kernel cannot exceed the subgroup size.
    // This is generally true -
    // for subgroup size 64, workgroup size should be less than 4096 (the max is usually 1024).
    // Zero the whole array so slots beyond the actual subgroup count contribute
    // nothing to the final reduction below.
    if (get_sub_group_id() == 0) {
        sum[get_sub_group_local_id()] = 0.0f;
    }

    int i03 = get_group_id(2);
    int i02 = get_group_id(1);
    int i01 = get_group_id(0);

    global float4 * x = (global float4 *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
    global float4 * f = (global float4 *) (src1 + (i03%ne13)*nb13 + (i02%ne12)*nb12 + (i01%ne11)*nb11);

    float sumf = 0;

    // Parallel sum of squares over float4 chunks, strided across the workgroup.
    for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
        sumf += dot(x[i00], x[i00]);
    }
    sumf = sub_group_reduce_add(sumf);

    // Ensure the zero-initialization above is visible before partials land.
    barrier(CLK_LOCAL_MEM_FENCE);

    if (get_sub_group_local_id() == 0) {
        sum[get_sub_group_id()] = sumf;
    }

    barrier(CLK_LOCAL_MEM_FENCE);

    // Every subgroup reduces the full partial-sum array itself (one lane per
    // slot), so each work-item ends up holding the row's total sum of squares.
    sumf = sum[get_sub_group_local_id()];
    sumf = sub_group_reduce_add(sumf);

    float mean = sumf / ne00;
    float scale = 1.0f/sqrt(mean + eps);

    // Normalize and multiply by the (broadcast) src1 row.
    global float4 * y = (global float4 *) (dst + i03*nb3 + i02*nb2 + i01*nb1);
    for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
        y[i00] = (x[i00] * scale) * f[i00%(ne10/4)];
    }
}