1#pragma OPENCL EXTENSION cl_khr_fp16 : enable
2#pragma OPENCL EXTENSION cl_khr_subgroups : enable
3
4#ifdef cl_qcom_reqd_sub_group_size
5#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
6#define ADRENO_GPU 1
7#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
8#endif
9
// Assumed layout parameters: a q4_0 block packs 32 4-bit weights, and each
// workgroup runs 4 subgroups ("waves").
11#define QK4_0 32
12#define N_SIMDGROUP 4
13
// Accumulate the "hi" half of a block pair into total_sums using SCALAR
// sub_group_broadcast (one float element broadcast per step).
//
// Arguments (textually substituted at expansion):
//   total_sums - float2 accumulator; .s0/.s1 are two adjacent output rows
//   bits4      - ushort8 of packed 4-bit weights (two rows interleaved:
//                even .sN elements feed row 0, odd elements feed row 1)
//   scale      - half2 per-block q4_0 scales for the two rows
//   y          - float8 slice of the input vector held per-fiber; values
//                are broadcast from lanes 0 and 1 here
//
// Each nibble q is dequantized as (q - 8) * scale (q4_0 scheme) and
// multiplied by the broadcast y element.
//
// NOTE: this macro DECLARES `float shared_y` in the expanding scope; the
// companion ..._1_lo macro reuses that declaration, so this macro must be
// expanded first within the same block.
#define dequantizeBlockAccum_ns_sgbroadcast_1_hi(total_sums, bits4, scale, y) \
    float shared_y; \
    /* y elements held by lane 0 */ \
    shared_y = sub_group_broadcast(y.s0, 0); \
    total_sums.s0 += ((bits4.s0 & 0x000F) - 8) * scale.s0 * shared_y; \
    total_sums.s1 += ((bits4.s1 & 0x000F) - 8) * scale.s1 * shared_y; \
    shared_y = sub_group_broadcast(y.s1, 0); \
    total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
    total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
    shared_y = sub_group_broadcast(y.s2, 0); \
    total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
    total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
    shared_y = sub_group_broadcast(y.s3, 0); \
    total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
    total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \
    shared_y = sub_group_broadcast(y.s4, 0); \
    total_sums.s0 += ((bits4.s2 & 0x000F) - 8) * scale.s0 * shared_y; \
    total_sums.s1 += ((bits4.s3 & 0x000F) - 8) * scale.s1 * shared_y; \
    shared_y = sub_group_broadcast(y.s5, 0); \
    total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
    total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
    shared_y = sub_group_broadcast(y.s6, 0); \
    total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
    total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
    shared_y = sub_group_broadcast(y.s7, 0); \
    total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
    total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \
    /* y elements held by lane 1 */ \
    shared_y = sub_group_broadcast(y.s0, 1); \
    total_sums.s0 += ((bits4.s4 & 0x000F) - 8) * scale.s0 * shared_y; \
    total_sums.s1 += ((bits4.s5 & 0x000F) - 8) * scale.s1 * shared_y; \
    shared_y = sub_group_broadcast(y.s1, 1); \
    total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
    total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
    shared_y = sub_group_broadcast(y.s2, 1); \
    total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
    total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
    shared_y = sub_group_broadcast(y.s3, 1); \
    total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
    total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \
    shared_y = sub_group_broadcast(y.s4, 1); \
    total_sums.s0 += ((bits4.s6 & 0x000F) - 8) * scale.s0 * shared_y; \
    total_sums.s1 += ((bits4.s7 & 0x000F) - 8) * scale.s1 * shared_y; \
    shared_y = sub_group_broadcast(y.s5, 1); \
    total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
    total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
    shared_y = sub_group_broadcast(y.s6, 1); \
    total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
    total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
    shared_y = sub_group_broadcast(y.s7, 1); \
    total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
    total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \
64
65
// Companion to dequantizeBlockAccum_ns_sgbroadcast_1_hi: accumulates the
// "lo" half of the block pair, broadcasting y values from lanes 2 and 3.
//
// NOTE: this macro does NOT declare `shared_y` -- it reuses the
// `float shared_y` declared by the ..._1_hi expansion, so it must be
// expanded after _hi within the same scope. Argument meanings are
// identical to the _hi variant.
#define dequantizeBlockAccum_ns_sgbroadcast_1_lo(total_sums, bits4, scale, y) \
    /* y elements held by lane 2 */ \
    shared_y = sub_group_broadcast(y.s0, 2); \
    total_sums.s0 += ((bits4.s0 & 0x000F) - 8) * scale.s0 * shared_y; \
    total_sums.s1 += ((bits4.s1 & 0x000F) - 8) * scale.s1 * shared_y; \
    shared_y = sub_group_broadcast(y.s1, 2); \
    total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
    total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
    shared_y = sub_group_broadcast(y.s2, 2); \
    total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
    total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
    shared_y = sub_group_broadcast(y.s3, 2); \
    total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
    total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \
    shared_y = sub_group_broadcast(y.s4, 2); \
    total_sums.s0 += ((bits4.s2 & 0x000F) - 8) * scale.s0 * shared_y; \
    total_sums.s1 += ((bits4.s3 & 0x000F) - 8) * scale.s1 * shared_y; \
    shared_y = sub_group_broadcast(y.s5, 2); \
    total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
    total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
    shared_y = sub_group_broadcast(y.s6, 2); \
    total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
    total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
    shared_y = sub_group_broadcast(y.s7, 2); \
    total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
    total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \
    /* y elements held by lane 3 */ \
    shared_y = sub_group_broadcast(y.s0, 3); \
    total_sums.s0 += ((bits4.s4 & 0x000F) - 8) * scale.s0 * shared_y; \
    total_sums.s1 += ((bits4.s5 & 0x000F) - 8) * scale.s1 * shared_y; \
    shared_y = sub_group_broadcast(y.s1, 3); \
    total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
    total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
    shared_y = sub_group_broadcast(y.s2, 3); \
    total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
    total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
    shared_y = sub_group_broadcast(y.s3, 3); \
    total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
    total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \
    shared_y = sub_group_broadcast(y.s4, 3); \
    total_sums.s0 += ((bits4.s6 & 0x000F) - 8) * scale.s0 * shared_y; \
    total_sums.s1 += ((bits4.s7 & 0x000F) - 8) * scale.s1 * shared_y; \
    shared_y = sub_group_broadcast(y.s5, 3); \
    total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
    total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
    shared_y = sub_group_broadcast(y.s6, 3); \
    total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
    total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
    shared_y = sub_group_broadcast(y.s7, 3); \
    total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
    total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \
115
116
// Vectorized variant of the "hi" accumulation: broadcasts the whole float8
// y in a single sub_group_broadcast per lane (lanes 0 and 1) instead of one
// element at a time, then accumulates all 16 nibbles per lane. Requires the
// device's sub_group_broadcast to support float8 arguments (selected via
// the host-defined VECTOR_SUB_GROUP_BROADCAT option).
//
// Arguments match the _1_hi variant. NOTE: this macro DECLARES
// `float8 shared_y`, which the companion ..._8_lo macro reuses, so this
// macro must be expanded first within the same scope.
#define dequantizeBlockAccum_ns_sgbroadcast_8_hi(total_sums, bits4, scale, y) \
    float8 shared_y; \
    /* full float8 held by lane 0 */ \
    shared_y = sub_group_broadcast(y, 0); \
    total_sums.s0 += ((bits4.s0 & 0x000F) - 8) * scale.s0 * shared_y.s0; \
    total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s1; \
    total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s2; \
    total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s3; \
    total_sums.s0 += ((bits4.s2 & 0x000F) - 8) * scale.s0 * shared_y.s4; \
    total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s5; \
    total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s6; \
    total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s7; \
    total_sums.s1 += ((bits4.s1 & 0x000F) - 8) * scale.s1 * shared_y.s0; \
    total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s1; \
    total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s2; \
    total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s3; \
    total_sums.s1 += ((bits4.s3 & 0x000F) - 8) * scale.s1 * shared_y.s4; \
    total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s5; \
    total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s6; \
    total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s7; \
    /* full float8 held by lane 1 */ \
    shared_y = sub_group_broadcast(y, 1); \
    total_sums.s0 += ((bits4.s4 & 0x000F) - 8) * scale.s0 * shared_y.s0; \
    total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s1; \
    total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s2; \
    total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s3; \
    total_sums.s0 += ((bits4.s6 & 0x000F) - 8) * scale.s0 * shared_y.s4; \
    total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s5; \
    total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s6; \
    total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s7; \
    total_sums.s1 += ((bits4.s5 & 0x000F) - 8) * scale.s1 * shared_y.s0; \
    total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s1; \
    total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s2; \
    total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s3; \
    total_sums.s1 += ((bits4.s7 & 0x000F) - 8) * scale.s1 * shared_y.s4; \
    total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s5; \
    total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s6; \
    total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s7; \
153
154
// Companion to dequantizeBlockAccum_ns_sgbroadcast_8_hi: accumulates the
// "lo" half using float8 broadcasts from lanes 2 and 3.
//
// NOTE: does NOT declare `shared_y` -- it reuses the `float8 shared_y`
// declared by the ..._8_hi expansion, so it must be expanded after _hi
// within the same scope. Argument meanings are identical to the _hi
// variant.
#define dequantizeBlockAccum_ns_sgbroadcast_8_lo(total_sums, bits4, scale, y) \
    /* full float8 held by lane 2 */ \
    shared_y = sub_group_broadcast(y, 2); \
    total_sums.s0 += ((bits4.s0 & 0x000F) - 8) * scale.s0 * shared_y.s0; \
    total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s1; \
    total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s2; \
    total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s3; \
    total_sums.s0 += ((bits4.s2 & 0x000F) - 8) * scale.s0 * shared_y.s4; \
    total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s5; \
    total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s6; \
    total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s7; \
    total_sums.s1 += ((bits4.s1 & 0x000F) - 8) * scale.s1 * shared_y.s0; \
    total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s1; \
    total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s2; \
    total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s3; \
    total_sums.s1 += ((bits4.s3 & 0x000F) - 8) * scale.s1 * shared_y.s4; \
    total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s5; \
    total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s6; \
    total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s7; \
    /* full float8 held by lane 3 */ \
    shared_y = sub_group_broadcast(y, 3); \
    total_sums.s0 += ((bits4.s4 & 0x000F) - 8) * scale.s0 * shared_y.s0; \
    total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s1; \
    total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s2; \
    total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s3; \
    total_sums.s0 += ((bits4.s6 & 0x000F) - 8) * scale.s0 * shared_y.s4; \
    total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s5; \
    total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s6; \
    total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s7; \
    total_sums.s1 += ((bits4.s5 & 0x000F) - 8) * scale.s1 * shared_y.s0; \
    total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s1; \
    total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s2; \
    total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s3; \
    total_sums.s1 += ((bits4.s7 & 0x000F) - 8) * scale.s1 * shared_y.s4; \
    total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s5; \
    total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s6; \
    total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s7; \
190
#ifdef ADRENO_GPU
REQD_SUBGROUP_SIZE_64
#endif
// GEMV for a q4_0-quantized matrix A ("noshuffle" weight layout):
// dst = A * B, with A read as packed 4-bit weights from an image buffer and
// per-block half2 scales from src0_d. Each fiber (work-item) produces TWO
// float outputs (dst[2*gid] and dst[2*gid+1]). The N_SIMDGROUP (=4) waves
// of a workgroup each cover a strided subset of the K-dimension blocks and
// are combined through a local-memory reduction at the end.
//
// NOTE(review): SIMDGROUP_WIDTH and (optionally) VECTOR_SUB_GROUP_BROADCAT
// are expected to be injected via host-side program build options; confirm
// against the enqueue code. "BROADCAT" is presumably a typo for "BROADCAST"
// kept consistent with the host -- do not rename one side only.
__kernel void kernel_gemv_noshuffle(
        __read_only image1d_buffer_t src0_q,  // quantized A (packed 4-bit weights)
        global half2  * src0_d,               // A scales (one half2 covers two rows)
        __read_only image1d_buffer_t src1,    // B (input vector, float)
        ulong offset1,                        // offset to B (0) -- unused in this kernel
        global float * dst,                   // C (output)
        ulong offsetd,                        // byte offset to C (0)
        int ne00,                             // K
        int ne01,                             // M
        int ne02,                             // 1 (unused)
        int ne10,                             // K (unused)
        int ne12,                             // 1 (unused)
        int ne0,                              // M (unused)
        int ne1,                              // N (unused)
        int r2,                               // 1 (unused)
        int r3)                               // (unused)
{
    // Wave (subgroup) index within the workgroup, taken from local id dim 1.
    uint groupId = get_local_id(1);
    // Global fiber index along M; this fiber owns output pair 2*gid, 2*gid+1.
    uint gid = get_global_id(0);
    // Lane index within the subgroup.
    ushort slid = get_sub_group_local_id();

    uint K = ne00;  // reduction (inner) dimension
    uint M = ne01;  // number of output rows

    // Scales are half2 (two rows per element), hence M/2 elements per line.
    uint LINE_STRIDE_A = M / 2;
    // Distance (in image elements) between consecutive K-blocks of A.
    uint BLOCK_STRIDE_A = N_SIMDGROUP * M;

    __private uint4 regA;   // packed 4-bit weights loaded this step
    __private half2 regS;   // scales for this fiber's two rows
    __private float8 regB;  // 8 input-vector values (valid in lanes 0..3)

    __private float2 totalSum = (float2)(0.0f);

    // loop along K in block granularity, skip 4 blocks every iter
    for (uint k = groupId; k < (K / QK4_0); k += N_SIMDGROUP) {
        regS = src0_d[gid + k * LINE_STRIDE_A]; // each fiber loads scale of two rows
        // first 4 fibers in each wave load 8 B values to its private scope;
        // the dequantize macros broadcast them to the rest of the wave
        if (slid < 4) {
            regB.s0123 = read_imagef(src1, (slid * 2 + k * 8));
            regB.s4567 = read_imagef(src1, (1 + slid * 2 + k * 8));
        }

        // load half weights for two blocks in consecutive rows
        regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 0)).x;
        regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 1)).x;
        regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 2)).x;
        regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 3)).x;
#ifdef VECTOR_SUB_GROUP_BROADCAT
        dequantizeBlockAccum_ns_sgbroadcast_8_hi(totalSum, as_ushort8(regA), regS, regB);
#else
        dequantizeBlockAccum_ns_sgbroadcast_1_hi(totalSum, as_ushort8(regA), regS, regB);
#endif // VECTOR_SUB_GROUP_BROADCAT

        // second half of the weights for the same two blocks
        regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 4)).x;
        regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 5)).x;
        regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x;
        regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x;
#ifdef VECTOR_SUB_GROUP_BROADCAT
        dequantizeBlockAccum_ns_sgbroadcast_8_lo(totalSum, as_ushort8(regA), regS, regB);
#else
        dequantizeBlockAccum_ns_sgbroadcast_1_lo(totalSum, as_ushort8(regA), regS, regB);
#endif // VECTOR_SUB_GROUP_BROADCAT
    }

    // reduction in local memory, assumes #wave=4
    // Waves 1..3 publish their partial sums; wave 0 keeps its own in
    // registers, so only 3 slots per lane are needed.
    __local float2 reduceLM[SIMDGROUP_WIDTH * 3];
    if (groupId == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = totalSum;
    if (groupId == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = totalSum;
    if (groupId == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = totalSum;
    barrier(CLK_LOCAL_MEM_FENCE); // unconditional: every work-item reaches it
    if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 0 + slid];
    if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 1 + slid];
    if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 2 + slid];

    // 2 outputs per fiber in wave 0
    if (groupId == 0) {
        // offsetd is a byte offset, hence the char* detour.
        dst = (global float*)((global char*)dst + offsetd);
        vstore2(totalSum, 0, &(dst[gid * 2]));
    }

}