// Extensions required by this file: half-precision scalars (the per-block
// scales are `half`) and the sub-group built-ins (sub_group_broadcast,
// get_sub_group_local_id).
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#pragma OPENCL EXTENSION cl_khr_subgroups : enable

// Qualcomm-specific path: when the compiler advertises
// cl_qcom_reqd_sub_group_size, kernels can request a fixed sub-group size via
// an attribute. ADRENO_GPU is used further down to decide whether that
// attribute is attached to the kernel.
// NOTE(review): the macro name implies that "half" selects 64-wide
// sub-groups; confirm against the Qualcomm OpenCL documentation.
#ifdef cl_qcom_reqd_sub_group_size
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
#endif

// Number of weights per quantization block; the kernel walks K in units of
// QK8_0.
#define QK8_0 32
// Stride of the K-block loop: work-items with different get_local_id(1)
// values interleave over the blocks, N_SIMDGROUP blocks apart.
#define N_SIMDGROUP 4
12
/*
 * Accumulate one signed 8-bit weight into `total_sums`.
 *
 *   word   - 32-bit word holding four packed 8-bit weights
 *   shift  - bit offset (0, 8, 16 or 24) of the weight inside `word`
 *   scale  - scale applied to the weight
 *   y_elem - activation value; the copy held by sub-group lane `lane` is used
 *   lane   - sub-group lane whose `y_elem` is broadcast to every work-item
 *
 * The product is evaluated left to right as (weight * scale) * activation,
 * exactly like the original hand-unrolled code, so rounding is unchanged.
 */
#define Q8_0_NS_ACCUM_BYTE(total_sums, word, shift, scale, y_elem, lane) \
    do { \
        float q8_bcast_y_ = sub_group_broadcast((y_elem), (lane)); \
        char q8_elem_ = (char)(((word) >> (shift)) & 0xFFu); \
        (total_sums) += convert_int(q8_elem_) * (scale) * q8_bcast_y_; \
    } while (0)

/* Four weights of `word` (low byte first) paired with y.s0 .. y.s3 of `lane`. */
#define Q8_0_NS_ACCUM_WORD_LO(total_sums, word, scale, y, lane) \
    do { \
        Q8_0_NS_ACCUM_BYTE(total_sums, word,  0, scale, (y).s0, lane); \
        Q8_0_NS_ACCUM_BYTE(total_sums, word,  8, scale, (y).s1, lane); \
        Q8_0_NS_ACCUM_BYTE(total_sums, word, 16, scale, (y).s2, lane); \
        Q8_0_NS_ACCUM_BYTE(total_sums, word, 24, scale, (y).s3, lane); \
    } while (0)

/* Four weights of `word` (low byte first) paired with y.s4 .. y.s7 of `lane`. */
#define Q8_0_NS_ACCUM_WORD_HI(total_sums, word, scale, y, lane) \
    do { \
        Q8_0_NS_ACCUM_BYTE(total_sums, word,  0, scale, (y).s4, lane); \
        Q8_0_NS_ACCUM_BYTE(total_sums, word,  8, scale, (y).s5, lane); \
        Q8_0_NS_ACCUM_BYTE(total_sums, word, 16, scale, (y).s6, lane); \
        Q8_0_NS_ACCUM_BYTE(total_sums, word, 24, scale, (y).s7, lane); \
    } while (0)

/*
 * Dequantize one 32-weight block and accumulate its dot product.
 *
 *   total_sums - lvalue accumulator, updated in place
 *   bits8      - 8-component vector of 32-bit words; each word packs four
 *                signed 8-bit weights, least-significant byte first
 *   scale      - scale shared by all 32 weights of the block
 *   y          - 8-component activation vector; the values actually used are
 *                the ones held by sub-group lanes 0..3 (lane L supplies the
 *                activations for weights 8L .. 8L+7)
 *
 * Every work-item of the sub-group must reach this macro, because it calls
 * sub_group_broadcast. The body is wrapped in do { } while (0) so the macro
 * behaves as a single statement, keeps its temporaries out of the caller's
 * scope, and can be invoked more than once per scope. Invoke it with a
 * trailing semicolon.
 */
#define dequantizeBlockAccum_ns_sgbroadcast_1(total_sums, bits8, scale, y) \
    do { \
        Q8_0_NS_ACCUM_WORD_LO(total_sums, (bits8).s0, scale, y, 0); \
        Q8_0_NS_ACCUM_WORD_HI(total_sums, (bits8).s1, scale, y, 0); \
        Q8_0_NS_ACCUM_WORD_LO(total_sums, (bits8).s2, scale, y, 1); \
        Q8_0_NS_ACCUM_WORD_HI(total_sums, (bits8).s3, scale, y, 1); \
        Q8_0_NS_ACCUM_WORD_LO(total_sums, (bits8).s4, scale, y, 2); \
        Q8_0_NS_ACCUM_WORD_HI(total_sums, (bits8).s5, scale, y, 2); \
        Q8_0_NS_ACCUM_WORD_LO(total_sums, (bits8).s6, scale, y, 3); \
        Q8_0_NS_ACCUM_WORD_HI(total_sums, (bits8).s7, scale, y, 3); \
    } while (0)

/*
 * Matrix-vector product with 8-bit block-quantized weights: each work-item
 * accumulates the dot product of one weight row (row index = global id 0)
 * with the activation vector, then the partial sums of the work-items that
 * share that row (distinguished by local id 1) are combined in local memory
 * and the work-item with local id 1 == 0 writes dst[gid].
 *
 * Data access as performed by the code below:
 *   src0_q - one 32-bit texel component per read; for block k the eight words
 *            of row `gid` sit at gid + k * 8 * M + j * M, j = 0..7.
 *   src0_d - scale of row `gid`, block k, at index gid + k * M.
 *   src1   - activations read as float4 texels; block k occupies texels
 *            8 * k .. 8 * k + 7.
 *   dst    - float output, addressed after adding the byte offset `offsetd`.
 *
 * Only ne00 (K), ne01 (M), offsetd and the buffers are read; offset1, ne02,
 * ne10, ne12, ne0, ne1, r2 and r3 are accepted but unused here.
 *
 * NOTE(review): SIMDGROUP_WIDTH is not defined in the visible source; it is
 * presumably supplied as a build option (-DSIMDGROUP_WIDTH=...) - confirm.
 * NOTE(review): there is no `gid < M` guard, so the launch is assumed to use
 * a global size of exactly M in dimension 0, and a local size of
 * N_SIMDGROUP in dimension 1 - confirm against the host code.
 */
#ifdef ADRENO_GPU
REQD_SUBGROUP_SIZE_64
#endif
__kernel void kernel_gemv_noshuffle(
        __read_only image1d_buffer_t src0_q,  // quantized A
        global half * src0_d,                 // A scales
        __read_only image1d_buffer_t src1,    // B
        ulong offset1,                        // offset to B (0)
        global float * dst,                   // C
        ulong offsetd,                        // offset to C
        int ne00,                             // K
        int ne01,                             // M
        int ne02,                             // 1
        int ne10,                             // K
        int ne12,                             // 1
        int ne0,                              // M
        int ne1,                              // N
        int r2,                               // 1
        int r3)
{
    uint groupId = get_local_id(1);           // which wave of the work-group
    uint gid     = get_global_id(0);          // output row handled by this fiber
    ushort slid  = get_sub_group_local_id();  // lane index inside the wave

    uint K = ne00;
    uint M = ne01;

    uint LINE_STRIDE_A  = M;      // distance between consecutive words of one row
    uint BLOCK_STRIDE_A = 8 * M;  // 32 / 4 = 8 words per block

    __private uint8  regA;
    __private half   regS;
    __private float8 regB;

    __private float totalSum = (float)(0.0f);

    // Loop along K in block granularity; each wave starts at its own block and
    // skips N_SIMDGROUP blocks every iteration. A trailing partial block
    // (K % QK8_0 != 0) is not processed because of the integer division.
    #pragma unroll 1 /* tell compiler not to unroll */
    for (uint k = groupId; k < (K / QK8_0); k += N_SIMDGROUP) {
        // Each fiber loads the scale of its own row for block k.
        regS = src0_d[gid + k * LINE_STRIDE_A];

        // The first 4 fibers of each wave load 8 B values each (32 in total).
        // regB stays unset in the other fibers; that is fine because the
        // accumulate macro only broadcasts from lanes 0..3.
        if (slid < 4) {
            regB.s0123 = read_imagef(src1, (slid * 2 + k * 8));
            regB.s4567 = read_imagef(src1, (1 + slid * 2 + k * 8));
        }

        // Load the eight packed weight words of block k for this row.
        regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 0)).x;
        regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 1)).x;
        regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 2)).x;
        regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 3)).x;
        regA.s4 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 4)).x;
        regA.s5 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 5)).x;
        regA.s6 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x;
        regA.s7 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x;

        // Must be reached by every fiber of the wave (uses sub_group_broadcast).
        dequantizeBlockAccum_ns_sgbroadcast_1(totalSum, regA, regS, regB);
    }

    // Reduction in local memory: waves 1 .. N_SIMDGROUP-1 each publish their
    // partial sums into their own SIMDGROUP_WIDTH-wide slice, wave 0 adds the
    // slices in ascending wave order. Sized from N_SIMDGROUP so that it stays
    // consistent with the loop stride above.
    __local float reduceLM[SIMDGROUP_WIDTH * (N_SIMDGROUP - 1)];
    if (groupId > 0 && groupId < N_SIMDGROUP) {
        reduceLM[SIMDGROUP_WIDTH * (groupId - 1) + slid] = totalSum;
    }
    barrier(CLK_LOCAL_MEM_FENCE);

    // 1 output per fiber in wave 0.
    if (groupId == 0) {
        for (uint wave = 0; wave < N_SIMDGROUP - 1; ++wave) {
            totalSum += reduceLM[SIMDGROUP_WIDTH * wave + slid];
        }

        // offsetd is a byte offset, hence the detour through a char pointer.
        dst = (global float*)((global char*)dst + offsetd);
        dst[gid] = totalSum;
    }
}