1//------------------------------------------------------------------------------
// This file contains kernels for data conversion.
// These kernels are used when loading the model, so their performance is less
// important.
5//------------------------------------------------------------------------------
// Half-precision scalars (half) are used for block scales throughout this file.
#pragma OPENCL EXTENSION cl_khr_fp16 : enable

// Vendor detection: if the device compiler advertises a required-subgroup-size
// extension, enable it, set a vendor flag, and define attribute helpers for it.
// In this file only ADRENO_GPU is consulted (to gate a driver workaround); the
// REQD_SUBGROUP_SIZE_* helpers are presumably shared with other kernel
// sources — confirm before removing.
#ifdef cl_intel_required_subgroup_size
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
#define INTEL_GPU 1
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
#elif defined(cl_qcom_reqd_sub_group_size)
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
// Qualcomm expresses the subgroup size as "half" or "full" wave rather than a
// number; the 64/128 in the macro names reflect the author's mapping.
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#endif
19
// Block sizes (QK*: values per block) and QR* ratios for the quantization
// formats. Of these, only QK4_0, QK8_0 and QK_K are referenced by the code
// visible in this file; the others are presumably kept for parity with other
// kernel sources — confirm before removing.
#define QK4_0 32
#define QR4_0 2
#define QK4_1 32
#define QR4_1 2
#define QK5_0 32
#define QR5_0 2
#define QK5_1 32
#define QR5_1 2
#define QK8_0 32
#define QR8_0 1
#define QK_K 256
#define K_QUANTS_PER_ITERATION 2

// Fixed-width integer names mapped onto OpenCL C built-in types, so that the
// block structs below can use the familiar <stdint.h> spellings.
// (OpenCL C `char` is a signed 8-bit type.)
typedef char int8_t;
typedef uchar uint8_t;
typedef short int16_t;
typedef ushort uint16_t;
typedef int int32_t;
typedef uint uint32_t;
39
//------------------------------------------------------------------------------
// block_q4_0
// One q4_0 block: a half-precision scale `d` followed by QK4_0/2 bytes, each
// byte packing two 4-bit quants (low nibble and high nibble).
//------------------------------------------------------------------------------
struct block_q4_0
{
    half d;                  // block scale
    uint8_t qs[QK4_0 / 2];   // packed 4-bit quants, two per byte
};
48
//------------------------------------------------------------------------------
// block_q6_K
// One q6_K super-block of QK_K values. Each 6-bit quant is stored split into
// its lower 4 bits (ql, two per byte) and upper 2 bits (qh); QK_K/16 signed
// 8-bit scales plus one half-precision super-block scale complete the block.
//------------------------------------------------------------------------------
struct block_q6_K {
    uint8_t ql[QK_K/2];      // quants, lower 4 bits
    uint8_t qh[QK_K/4];      // quants, upper 2 bits
    int8_t  scales[QK_K/16]; // scales, quantized with 8 bits
    half d;                  // super-block scale
};
58
//------------------------------------------------------------------------------
// kernel_convert_block_q4_0
// Split an array of block_q4_0 (AOS) into a packed quant-byte array and a
// scale array (SOA). One work-item handles one block. The nibbles are copied
// verbatim, i.e. this kernel does not deshuffle the bits.
//   src0  : input blocks, indexed by get_global_id(0)
//   dst_q : output quant bytes, QK4_0/2 per block
//   dst_d : output scales, one per block
//------------------------------------------------------------------------------
kernel void kernel_convert_block_q4_0(
    global struct block_q4_0 * src0,
    global uchar * dst_q,
    global half * dst_d
) {
    const size_t blk = get_global_id(0);
    global struct block_q4_0 * in = src0 + blk;
    global uchar * out_q = dst_q + blk * (QK4_0 / 2);

    dst_d[blk] = in->d;

    int k = 0;
    while (k < QK4_0 / 2) {
        out_q[k] = in->qs[k];
        ++k;
    }
}
79
// kernel_restore_block_q4_0
// Inverse of kernel_convert_block_q4_0: rebuild block_q4_0 structs (AOS) from
// the flattened quant-byte and scale arrays (SOA). One work-item per block.
//   src_q : input quant bytes, QK4_0/2 per block
//   src_d : input scales, one per block
//   dst   : output blocks, indexed by get_global_id(0)
kernel void kernel_restore_block_q4_0(
    global uchar * src_q,
    global half * src_d,
    global struct block_q4_0 * dst
) {
    const size_t blk = get_global_id(0);
    global struct block_q4_0 * out = &dst[blk];
    global uchar * in_q = &src_q[blk * (QK4_0 / 2)];

    out->d = src_d[blk];

    for (int k = 0; k != QK4_0 / 2; k++) {
        out->qs[k] = in_q[k];
    }
}
94
//------------------------------------------------------------------------------
// kernel_convert_block_q4_0_noshuffle
// Flatten q4_0 weights (AOS -> SOA) and regroup the nibbles.
// One work-item handles one block. For every pair of source bytes
// (qs[2*i], qs[2*i+1]):
//   - the two low nibbles are packed into output byte i
//     (qs[2*i] in bits 0-3, qs[2*i+1] in bits 4-7);
//   - the two high nibbles are packed into output byte i + QK4_0/4
//     (qs[2*i] in bits 0-3, qs[2*i+1] in bits 4-7).
// So the first QK4_0/4 output bytes hold only low nibbles and the last
// QK4_0/4 hold only high nibbles. kernel_restore_block_q4_0_noshuffle applies
// the inverse mapping.
//------------------------------------------------------------------------------

kernel void kernel_convert_block_q4_0_noshuffle(
    global struct block_q4_0 * src0,
    global uchar * dst_q,
    global half * dst_d
) {
    global struct block_q4_0 * b = (global struct block_q4_0 *) src0 + get_global_id(0);
    global uchar * q = (global uchar *) dst_q + QK4_0/2*get_global_id(0);
    global half * d = (global half *) dst_d + get_global_id(0);

    *d = b->d;
    for (int i = 0; i < QK4_0/4; ++i) {
        uchar x0 = b->qs[2*i + 0];
        uchar x1 = b->qs[2*i + 1];

        // Low nibbles of the pair -> byte i.
        q[i + 0      ] = convert_uchar(x0 & 0x0F) | convert_uchar((x1 & 0x0F) << 4);
        // High nibbles of the pair -> byte i + QK4_0/4.
        q[i + QK4_0/4] = convert_uchar((x0 & 0xF0) >> 4) | convert_uchar(x1 & 0xF0);

#ifdef ADRENO_GPU
        // Workaround for adreno - must have the following printf statement for
        // the kernel to work properly. Otherwise it produces incorrect result.
        // convert_uchar above also seems necessary.
        // Compare against a large number so that it does not print anything.
        // get_sub_group_local_id() also works.
        // Do not "simplify" this block or the conversions above without
        // re-validating on Adreno hardware.
        if (get_global_id(0) == 65536*4096) {
            printf("%04x - %02x\n", *(global ushort*)d, ((x0 & 0xF0) >> 4) | (x1 & 0xF0));
        }
#endif
    }
}
129
// kernel_restore_block_q4_0_noshuffle
// Inverse of kernel_convert_block_q4_0_noshuffle: rebuild block_q4_0 structs
// from the flattened layout, where the first QK4_0/4 bytes of each block hold
// low nibbles and the last QK4_0/4 bytes hold high nibbles.
// One work-item per block.
//   mask_0F / mask_F0 : nibble masks supplied by the host (expected to be 0x0F
//   and 0xF0). NOTE(review): passing them at runtime rather than as literals
//   looks like a deliberate compiler workaround — confirm before inlining them.
kernel void kernel_restore_block_q4_0_noshuffle(
    global uchar * src_q,
    global half * src_d,
    global struct block_q4_0 * dst,
    uchar mask_0F,
    uchar mask_F0
) {
    const size_t blk = get_global_id(0);
    global struct block_q4_0 * out = dst + blk;
    global uchar * lo = src_q + blk * (QK4_0 / 2); // bytes built from low nibbles
    global uchar * hi = lo + QK4_0 / 4;            // bytes built from high nibbles

    out->d = src_d[blk];

    for (int k = 0; k < QK4_0 / 4; ++k) {
        const uchar a = lo[k];
        const uchar c = hi[k];

        // uchar operands promote to int here, exactly as in the inline form.
        const int even = (a & mask_0F) | ((c & mask_0F) << 4);
        const int odd  = ((a & mask_F0) >> 4) | (c & mask_F0);

        out->qs[2*k    ] = convert_uchar(even);
        out->qs[2*k + 1] = convert_uchar(odd);
    }
}
150
//------------------------------------------------------------------------------
// block_mxfp4
// One MXFP4 block of QK_MXFP4 values: a shared E8M0 exponent byte followed by
// QK_MXFP4/2 bytes of packed 4-bit values. Every member is a uchar, so the
// struct is 17 bytes with byte alignment only; `qs` is therefore generally not
// aligned to anything wider than 1 byte.
//------------------------------------------------------------------------------
#define QK_MXFP4 32
struct block_mxfp4 {
    uchar e; // E8M0
    uchar qs[QK_MXFP4 / 2];
};
159
//------------------------------------------------------------------------------
// kernel_convert_block_mxfp4
// Split an array of block_mxfp4 (AOS) into a packed-quant array and an
// exponent array (SOA). One work-item per block; the quant bytes are copied
// as-is (no bit deshuffling).
//   src0  : input blocks, indexed by get_global_id(0)
//   dst_q : output quant bytes, QK_MXFP4/2 per block
//   dst_e : output E8M0 exponents, one byte per block
//------------------------------------------------------------------------------
kernel void kernel_convert_block_mxfp4(
    global struct block_mxfp4 * src0,
    global uchar * dst_q,
    global uchar * dst_e
) {
    const size_t blk = get_global_id(0);
    global struct block_mxfp4 * in = &src0[blk];
    global uchar * out_q = &dst_q[blk * (QK_MXFP4 / 2)];

    dst_e[blk] = in->e;

    for (int k = 0; k < QK_MXFP4 / 2; k++) {
        out_q[k] = in->qs[k];
    }
}
180
// kernel_convert_block_mxfp4_trans
// Flatten block_mxfp4 (AOS -> SOA) while transposing the block grid.
// Work-item (i01, i00, i02) = (get_global_id(0), get_global_id(1),
// get_global_id(2)) takes source block i00 of row i01 in matrix i02, where the
// source is row-major with ne00/QK_MXFP4 blocks per row. It writes the result
// at block index i01 + i00*ne01 (+ matrix offset), i.e. block-column-major.
//   dst_q : 16 quant bytes per block (one uint4 slot each)
//   dst_e : one E8M0 exponent byte per block
//
// block_mxfp4 is 17 bytes with byte alignment, so &b->qs[0] is almost never
// 16-byte aligned. Dereferencing it as a uint4 is an unaligned vector load,
// which the OpenCL C spec leaves undefined. vload16/vstore16 only require
// element (uchar) alignment and copy the same 16 bytes.
kernel void kernel_convert_block_mxfp4_trans(
    global struct block_mxfp4 * src0,
    __global uint4 * dst_q,
    __global uchar * dst_e,
    uint ne00,
    uint ne01
) {
    const uint i00 = get_global_id(1);
    const uint i01 = get_global_id(0);
    const uint i02 = get_global_id(2);

    const uint ne00_blk = ne00 / QK_MXFP4;
    const uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01;
    const uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01;

    global struct block_mxfp4 * b = src0 + src_blk_offset;

    // QK_MXFP4/2 == 16 quant bytes per block == one uint4 slot of dst_q.
    vstore16(vload16(0, b->qs), dst_blk_offset, (global uchar *) dst_q);
    dst_e[dst_blk_offset] = b->e;
}
201
// kernel_restore_block_mxfp4
// Inverse of kernel_convert_block_mxfp4: rebuild block_mxfp4 structs (AOS)
// from the flattened quant-byte and exponent arrays (SOA).
// One work-item per block.
//   src_q : input quant bytes, QK_MXFP4/2 per block
//   src_e : input E8M0 exponents, one byte per block. Previously declared as
//           `global half *` although it was always read byte-wise; the type
//           now matches kernel_convert_block_mxfp4's dst_e.
//   dst   : output blocks, indexed by get_global_id(0)
kernel void kernel_restore_block_mxfp4(
    global uchar * src_q,
    global uchar * src_e,
    global struct block_mxfp4 * dst
) {
    global struct block_mxfp4 * b = dst + get_global_id(0);
    global uchar * q = src_q + QK_MXFP4 / 2 * get_global_id(0);
    global uchar * e = src_e + get_global_id(0);

    b->e = *e;
    for (int i = 0; i < QK_MXFP4 / 2; ++i) {
        b->qs[i] = q[i];
    }
}
216
// kernel_restore_block_mxfp4_trans
// Inverse of kernel_convert_block_mxfp4_trans: read the block-column-major
// flattened arrays and rebuild row-major block_mxfp4 structs.
// Work-item (i01, i00, i02) = (get_global_id(0), get_global_id(1),
// get_global_id(2)) reads block i01 + i00*ne01 (+ matrix offset) and writes
// block i00 of row i01 in matrix i02.
//
// The 16 quant bytes are written with vstore16: b->qs sits at an odd byte
// offset inside the 17-byte block_mxfp4, so storing through a uint4 pointer
// would be an unaligned vector store, which is undefined in OpenCL C.
kernel void kernel_restore_block_mxfp4_trans(
    __global uint4 * src_q,
    __global uchar * src_e,
    global struct block_mxfp4 * dst,
    uint ne00,
    uint ne01
) {
    const uint i00 = get_global_id(1);
    const uint i01 = get_global_id(0);
    const uint i02 = get_global_id(2);

    const uint ne00_blk = ne00 / QK_MXFP4;
    const uint src_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01;
    const uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01;

    global struct block_mxfp4 * b = dst + dst_blk_offset;

    // Read uint4 slot src_blk_offset as 16 bytes (src_q itself is uint4-aligned).
    vstore16(vload16(src_blk_offset, (global uchar *) src_q), 0, b->qs);
    b->e = src_e[src_blk_offset];
}
237
//------------------------------------------------------------------------------
// block_q8_0
// One q8_0 block: a half-precision scale followed by QK8_0 signed 8-bit quants.
//------------------------------------------------------------------------------
typedef struct {
    half d;        // delta
    char qs[QK8_0]; // quants
} block_q8_0;
245
// kernel_convert_block_q8_0
// Split an array of block_q8_0 (AOS) into a quant-byte array and a scale
// array (SOA). One work-item per block.
//   src0  : input blocks, indexed by get_global_id(0)
//   dst_q : output quants, QK8_0 bytes per block (bit pattern of the chars)
//   dst_d : output scales, one per block
kernel void kernel_convert_block_q8_0(
    global block_q8_0 * src0,
    global uchar * dst_q,
    global half * dst_d
) {
    const size_t blk = get_global_id(0);
    global block_q8_0 * in = src0 + blk;
    global uchar * out_q = dst_q + blk * QK8_0;

    dst_d[blk] = in->d;

    for (int k = 0; k < QK8_0; k++) {
        out_q[k] = (uchar) in->qs[k];
    }
}
261
// kernel_restore_block_q8_0
// Inverse of kernel_convert_block_q8_0: rebuild block_q8_0 structs (AOS) from
// the flattened quant-byte and scale arrays (SOA). One work-item per block.
//   src_q : input quants, QK8_0 bytes per block
//   src_d : input scales, one per block
//   dst   : output blocks, indexed by get_global_id(0)
kernel void kernel_restore_block_q8_0(
    global uchar * src_q,
    global half * src_d,
    global block_q8_0 * dst
) {
    const size_t blk = get_global_id(0);
    global block_q8_0 * out = &dst[blk];
    global uchar * in_q = &src_q[blk * QK8_0];

    out->d = src_d[blk];

    int k = 0;
    while (k < QK8_0) {
        out->qs[k] = (char) in_q[k];
        k++;
    }
}
276
// kernel_restore_block_q8_0_trans
// Rebuild one row of block_q8_0 structs from a transposed flattened layout.
// Work-item `row` = get_global_id(0) writes the ne00/QK8_0 blocks of that row.
// Source layout, as implied by the indexing below:
//   src_q : quants in groups of 4 bytes; group g of this row starts at byte
//           (g * ne01 + row) * 4, with g = blk * (QK8_0/4) + j
//   src_d : scale of block blk of this row at index blk * ne01 + row
kernel void kernel_restore_block_q8_0_trans(
    global uchar * src_q,
    global half * src_d,
    global block_q8_0 * dst,
    uint ne00,
    uint ne01
) {
    const size_t row = get_global_id(0);
    const uint nblk = ne00 / QK8_0;
    const uint q_stride = 4 * ne01; // bytes between consecutive 4-byte groups of a row

    global block_q8_0 * out = dst + row * nblk;
    global uchar * src = src_q + row * 4;
    global half * scale = src_d + row;

    for (uint k = 0; k < nblk; ++k, ++out, scale += ne01) {
        out->d = *scale;

        for (uint g = 0; g < QK8_0 / 4; ++g, src += q_stride) {
            for (uint c = 0; c < 4; ++c) {
                out->qs[4*g + c] = src[c];
            }
        }
    }
}
307
//------------------------------------------------------------------------------
// kernel_convert_block_q6_K
// Split an array of block_q6_K (AOS) into four flat arrays (SOA): low quant
// bits, high quant bits, scales and super-block scales. One work-item handles
// one super-block; bytes are copied as-is (no bit deshuffling).
//   dst_ql : QK_K/2 bytes per super-block
//   dst_qh : QK_K/4 bytes per super-block
//   dst_s  : QK_K/16 scales per super-block
//   dst_d  : one half scale per super-block
//------------------------------------------------------------------------------
kernel void kernel_convert_block_q6_K(
    global struct block_q6_K * src0,
    global uchar * dst_ql,
    global uchar * dst_qh,
    global char * dst_s,
    global half * dst_d
) {
    const size_t blk = get_global_id(0);
    global struct block_q6_K * in = &src0[blk];

    global uchar * out_ql = &dst_ql[blk * (QK_K / 2)];
    global uchar * out_qh = &dst_qh[blk * (QK_K / 4)];
    global char  * out_s  = &dst_s[blk * (QK_K / 16)];

    dst_d[blk] = in->d;

    for (int k = 0; k < QK_K / 2; k++) {
        out_ql[k] = in->ql[k];
    }
    for (int k = 0; k < QK_K / 4; k++) {
        out_qh[k] = in->qh[k];
    }
    for (int k = 0; k < QK_K / 16; k++) {
        out_s[k] = in->scales[k];
    }
}
339
// kernel_restore_block_q6_K
// Inverse of kernel_convert_block_q6_K: rebuild block_q6_K structs (AOS) from
// the four flattened arrays (SOA). One work-item handles one super-block.
// Note: despite their dst_ prefix, dst_ql / dst_qh / dst_s / dst_d are the
// INPUTS of this kernel; `dst` is the only buffer written. The names are kept
// unchanged for interface stability.
kernel void kernel_restore_block_q6_K(
    global uchar * dst_ql,
    global uchar * dst_qh,
    global char * dst_s,
    global half * dst_d,
    global struct block_q6_K * dst
) {
    const size_t blk = get_global_id(0);
    global struct block_q6_K * out = &dst[blk];

    global uchar * in_ql = &dst_ql[blk * (QK_K / 2)];
    global uchar * in_qh = &dst_qh[blk * (QK_K / 4)];
    global char  * in_s  = &dst_s[blk * (QK_K / 16)];

    out->d = dst_d[blk];

    for (int k = 0; k < QK_K / 2; k++) {
        out->ql[k] = in_ql[k];
    }
    for (int k = 0; k < QK_K / 4; k++) {
        out->qh[k] = in_qh[k];
    }
    for (int k = 0; k < QK_K / 16; k++) {
        out->scales[k] = in_s[k];
    }
}