1//------------------------------------------------------------------------------
// This file contains kernels for data conversion.
// These kernels are used when loading the model, so their performance is less
// important.
  5//------------------------------------------------------------------------------
// Enable the fp16 extension; `half` is used for block scales throughout this file.
#pragma OPENCL EXTENSION cl_khr_fp16 : enable

// Vendor detection. Each branch enables the vendor's required-subgroup-size
// extension, sets a vendor flag (INTEL_GPU or ADRENO_GPU) and defines
// REQD_SUBGROUP_SIZE_* attribute macros. Only ADRENO_GPU is tested in this
// file (by the q4_0 noshuffle conversion kernel). NOTE(review): the
// REQD_SUBGROUP_SIZE_* macros are not referenced here — presumably kept for
// parity with other kernel sources; confirm before removing.
#ifdef cl_intel_required_subgroup_size
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
#define INTEL_GPU 1
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
#elif defined(cl_qcom_reqd_sub_group_size)
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_64  __attribute__((qcom_reqd_sub_group_size("half")))
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#endif
 19
// QK*: number of weights per quantization block (e.g. block_q4_0 stores
// QK4_0/2 bytes of packed 4-bit quants; block_q8_0 stores QK8_0 bytes).
// QK_K: number of weights per k-quant super block (used by block_q6_K).
// NOTE(review): the QR* constants and K_QUANTS_PER_ITERATION are not referenced
// in this file — presumably kept for consistency with the other ggml kernels.
#define QK4_0                   32
#define QR4_0                   2
#define QK4_1                   32
#define QR4_1                   2
#define QK5_0                   32
#define QR5_0                   2
#define QK5_1                   32
#define QR5_1                   2
#define QK8_0                   32
#define QR8_0                   1
#define QK_K                    256
#define K_QUANTS_PER_ITERATION  2

// <stdint.h>-style fixed-width aliases for the OpenCL built-in integer types.
// OpenCL C defines `char` as a signed 8-bit type, so int8_t is signed.
typedef char int8_t;
typedef uchar uint8_t;
typedef short int16_t;
typedef ushort uint16_t;
typedef int int32_t;
typedef uint uint32_t;
 39
//------------------------------------------------------------------------------
// block_q4_0
// One block of QK4_0 weights: an fp16 scale `d` followed by QK4_0/2 bytes,
// each byte holding two 4-bit quants (low and high nibble).
// NOTE(review): field order and sizes presumably must match the host-side
// ggml block_q4_0 byte for byte, since these kernels read/write it in place.
//------------------------------------------------------------------------------
struct block_q4_0
{
    half d;
    uint8_t qs[QK4_0 / 2];
};
 48
//------------------------------------------------------------------------------
// block_q6_K
// One super block of QK_K weights. With QK_K == 256 the layout is
// 128 + 64 + 16 bytes of quants/scales followed by a 2-byte fp16 scale
// (210 bytes total). NOTE(review): presumably must match the host-side
// ggml block_q6_K layout exactly; confirm before reordering fields.
//------------------------------------------------------------------------------
struct block_q6_K {
    uint8_t ql[QK_K/2];      // quants, lower 4 bits
    uint8_t qh[QK_K/4];      // quants, upper 2 bits
    int8_t  scales[QK_K/16]; // scales, quantized with 8 bits
    half d;                  // super-block scale
};
 58
 59//------------------------------------------------------------------------------
 60// kernel_convert_block_q4_0
 61// Convert the block_q4_0 format to 2 separate arrays (AOS -> SOA).
 62// This kernel does not deshuffle the bits.
 63//------------------------------------------------------------------------------
 64kernel void kernel_convert_block_q4_0(
 65    global struct block_q4_0 * src0,
 66    global uchar * dst_q,
 67    global half  * dst_d
 68) {
 69    global struct block_q4_0 * b = (global struct block_q4_0 *) src0 + get_global_id(0);
 70    global uchar * q = (global uchar *) dst_q + QK4_0/2*get_global_id(0);
 71    global half  * d = (global half *) dst_d + get_global_id(0);
 72
 73    *d = b->d;
 74
 75    for (int i = 0; i < QK4_0/2; ++i) {
 76        q[i] = b->qs[i];
 77    }
 78}
 79
 80kernel void kernel_restore_block_q4_0(
 81    global uchar * src_q,
 82    global half  * src_d,
 83    global struct block_q4_0 * dst
 84) {
 85    global struct block_q4_0 * b = (global struct block_q4_0 *) dst + get_global_id(0);
 86    global uchar * q = (global uchar *) src_q + QK4_0/2*get_global_id(0);
 87    global half  * d = (global half *) src_d + get_global_id(0);
 88
 89    b->d = *d;
 90    for (int i = 0; i < QK4_0/2; ++i) {
 91        b->qs[i] = q[i];
 92    }
 93}
 94
//------------------------------------------------------------------------------
// kernel_convert_block_q4_0_noshuffle
// Flatten q4_0 blocks (AOS -> SOA) and regroup the nibbles of each block.
// One work-item per block (gid = get_global_id(0)). For i in [0, QK4_0/4),
// with lo()/hi() denoting the low/high nibble of a byte:
//   dst_q[gid*QK4_0/2 + i]           <- lo(qs[2i]) | lo(qs[2i+1]) << 4
//   dst_q[gid*QK4_0/2 + QK4_0/4 + i] <- hi(qs[2i]) | hi(qs[2i+1]) << 4
//   dst_d[gid]                       <- d
// So the first QK4_0/4 output bytes hold every low nibble of the source block
// and the last QK4_0/4 hold every high nibble, each output byte pairing two
// adjacent source bytes. kernel_restore_block_q4_0_noshuffle undoes this when
// called with mask_0F = 0x0F and mask_F0 = 0xF0.
//------------------------------------------------------------------------------

kernel void kernel_convert_block_q4_0_noshuffle(
    global struct block_q4_0 * src0,
    global uchar * dst_q,
    global half  * dst_d
) {
    global struct block_q4_0 * b = (global struct block_q4_0 *) src0 + get_global_id(0);
    global uchar * q = (global uchar *) dst_q + QK4_0/2*get_global_id(0);
    global half  * d = (global half *) dst_d + get_global_id(0);

    *d = b->d;
    // Each iteration consumes two adjacent source bytes and emits one byte
    // into each half of the output.
    for (int i = 0; i < QK4_0/4; ++i) {
        uchar x0 = b->qs[2*i + 0];
        uchar x1 = b->qs[2*i + 1];

        // Low nibbles of x0 and x1 -> first half of q.
        q[i + 0      ] = convert_uchar(x0 & 0x0F) | convert_uchar((x1 & 0x0F) << 4);
        // High nibbles of x0 and x1 -> second half of q.
        q[i + QK4_0/4] = convert_uchar((x0 & 0xF0) >> 4) | convert_uchar(x1 & 0xF0);

#ifdef ADRENO_GPU
        // Workaround for adreno - must have the following printf statement for
        // the kernel to work properly. Otherwise it produces incorrect result.
        // convert_uchar above also seems necessary.
        // Compare against a large number so that it does not print anything.
        // get_sub_group_local_id() also works.
        // Do not restructure the statements in this loop without re-testing on
        // Adreno: the workaround depends on their exact form.
        if (get_global_id(0) == 65536*4096) {
            printf("%04x - %02x\n", *(global ushort*)d, ((x0 & 0xF0) >> 4) | (x1 & 0xF0));
        }
#endif
    }
}
129
//------------------------------------------------------------------------------
// kernel_restore_block_q4_0_noshuffle
// Rebuild block_q4_0 structs from the flattened scale array and the
// nibble-regrouped quant array. One work-item per block (gid =
// get_global_id(0)). For i in [0, QK4_0/4), with
// x0 = src_q[gid*QK4_0/2 + i] and x1 = src_q[gid*QK4_0/2 + QK4_0/4 + i]:
//   qs[2i]   = (x0 & mask_0F)      | (x1 & mask_0F) << 4
//   qs[2i+1] = (x0 & mask_F0) >> 4 | (x1 & mask_F0)
// With mask_0F == 0x0F and mask_F0 == 0xF0 this is the exact inverse of
// kernel_convert_block_q4_0_noshuffle.
// NOTE(review): the masks are runtime arguments rather than literals —
// presumably to sidestep a device-compiler issue like the Adreno one noted in
// the convert kernel; confirm with the host code before inlining them.
//------------------------------------------------------------------------------
kernel void kernel_restore_block_q4_0_noshuffle(
    global uchar * src_q,
    global half  * src_d,
    global struct block_q4_0 * dst,
    uchar mask_0F,
    uchar mask_F0
) {
    global struct block_q4_0 * b = (global struct block_q4_0 *) dst + get_global_id(0);
    global uchar * q = (global uchar *) src_q + QK4_0/2*get_global_id(0);
    global half  * d = (global half *) src_d + get_global_id(0);

    b->d = *d;
    for (int i = 0; i < QK4_0/4; ++i) {
        // x0 comes from the low-nibble half, x1 from the high-nibble half.
        uchar x0 = q[i + 0      ] ;
        uchar x1 = q[i + QK4_0/4];

        b->qs[2*i + 0] = convert_uchar((x0 & mask_0F) | ((x1 & mask_0F) << 4));
        b->qs[2*i + 1] = convert_uchar(((x0 & mask_F0) >> 4) | (x1 & mask_F0));
    }
}
150
//------------------------------------------------------------------------------
// block_mxfp4
// One block of QK_MXFP4 weights: a shared E8M0 exponent byte `e` followed by
// QK_MXFP4/2 bytes of packed 4-bit values. Every member is a uchar, so the
// struct is 17 bytes with 1-byte alignment. In an array of blocks, `qs` is
// therefore generally NOT 16-byte aligned.
//------------------------------------------------------------------------------
#define QK_MXFP4 32
struct block_mxfp4 {
    uchar e; // E8M0
    uchar qs[QK_MXFP4 / 2];
};
159
//------------------------------------------------------------------------------
// kernel_convert_block_mxfp4
// Split block_mxfp4 data into two parallel arrays (AOS -> SOA) without
// touching the bit layout. Work-item `blk` handles block `blk`:
//   dst_e[blk]                            <- E8M0 exponent byte
//   dst_q[blk*QK_MXFP4/2 .. +QK_MXFP4/2)  <- packed quant bytes
//------------------------------------------------------------------------------
kernel void kernel_convert_block_mxfp4(
    global struct block_mxfp4 * src0,
    global uchar * dst_q,
    global uchar * dst_e
) {
    const size_t blk = get_global_id(0);
    global const struct block_mxfp4 * src_blk = src0 + blk;
    global uchar * q_out = dst_q + blk * (QK_MXFP4 / 2);

    dst_e[blk] = src_blk->e;

    for (int j = 0; j < QK_MXFP4 / 2; j++) {
        q_out[j] = src_blk->qs[j];
    }
}
180
//------------------------------------------------------------------------------
// kernel_convert_block_mxfp4_trans
// AOS -> SOA conversion of block_mxfp4 that also transposes the block grid of
// each slice. Work-item mapping: dim 0 -> i01 (row), dim 1 -> i00 (block within
// the row), dim 2 -> i02 (slice).
//   source block index:      i00 + i01*ne00_blk + i02*ne00_blk*ne01
//   destination block index: i01 + i00*ne01     + i02*ne00_blk*ne01
// so after conversion, blocks with the same i00 from consecutive rows are
// contiguous. Each block's 16 quant bytes are written as one uint4 to dst_q,
// and its exponent byte to dst_e.
// ne00: row length in elements (ne00_blk = ne00 / QK_MXFP4 blocks per row;
//       NOTE(review): assumed to be a multiple of QK_MXFP4 — confirm on host).
// ne01: number of rows per slice.
//------------------------------------------------------------------------------
kernel void kernel_convert_block_mxfp4_trans(
    global struct block_mxfp4 * src0,
    __global uint4 * dst_q,
    __global uchar * dst_e,
    uint ne00,
    uint ne01
) {
    const uint i00 = (uint) get_global_id(1);
    const uint i01 = (uint) get_global_id(0);
    const uint i02 = (uint) get_global_id(2);

    const uint ne00_blk = ne00 / QK_MXFP4;
    const uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01;
    const uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01;

    global struct block_mxfp4 * b = src0 + src_blk_offset;

    // struct block_mxfp4 is 17 bytes, so b->qs is generally not 16-byte
    // aligned. Dereferencing it through a uint4 pointer is a misaligned access
    // (undefined in OpenCL C). vload16 only requires uchar alignment, and
    // as_uint4 reinterprets the same 16 bytes.
    dst_q[dst_blk_offset] = as_uint4(vload16(0, b->qs));
    dst_e[dst_blk_offset] = b->e;
}
201
//------------------------------------------------------------------------------
// kernel_restore_block_mxfp4
// Inverse of kernel_convert_block_mxfp4 (SOA -> AOS). Work-item `blk` rebuilds
// block `blk` from the exponent byte src_e[blk] and the packed quant bytes
// src_q[blk*QK_MXFP4/2 .. +QK_MXFP4/2).
// src_e holds one E8M0 byte per block, the same layout kernel_convert_block_mxfp4
// writes through its `global uchar *dst_e`. It is therefore declared as uchar
// here; the previous `half *` declaration mis-described the buffer. Kernel
// arguments are bound positionally, so host callers are unaffected.
//------------------------------------------------------------------------------
kernel void kernel_restore_block_mxfp4(
    global uchar * src_q,
    global uchar * src_e,
    global struct block_mxfp4 * dst
) {
    global struct block_mxfp4 * b = dst + get_global_id(0);
    global uchar * q = src_q + QK_MXFP4 / 2 * get_global_id(0);
    global uchar * e = src_e + get_global_id(0);

    b->e = *e;
    for (int i = 0; i < QK_MXFP4 / 2; ++i) {
        b->qs[i] = q[i];
    }
}
216
//------------------------------------------------------------------------------
// kernel_restore_block_mxfp4_trans
// Inverse of kernel_convert_block_mxfp4_trans: gathers transposed SOA data
// back into row-major block_mxfp4 structs. Work-item mapping: dim 0 -> i01
// (row), dim 1 -> i00 (block within the row), dim 2 -> i02 (slice).
//   source index (src_q/src_e): i01 + i00*ne01     + i02*ne00_blk*ne01
//   destination block index:    i00 + i01*ne00_blk + i02*ne00_blk*ne01
// ne00: row length in elements; ne01: number of rows per slice.
//------------------------------------------------------------------------------
kernel void kernel_restore_block_mxfp4_trans(
    __global uint4 * src_q,
    __global uchar * src_e,
    global struct block_mxfp4 * dst,
    uint ne00,
    uint ne01
) {
    const uint i00 = (uint) get_global_id(1);
    const uint i01 = (uint) get_global_id(0);
    const uint i02 = (uint) get_global_id(2);

    const uint ne00_blk = ne00 / QK_MXFP4;
    const uint src_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01;
    const uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01;

    global struct block_mxfp4 * b = dst + dst_blk_offset;

    // b->qs sits at a 17-byte stride, so it is generally not 16-byte aligned.
    // A store through a uint4 pointer would be misaligned (undefined in
    // OpenCL C). vstore16 only requires uchar alignment.
    vstore16(as_uchar16(src_q[src_blk_offset]), 0, b->qs);
    b->e = src_e[src_blk_offset];
}
237
//------------------------------------------------------------------------------
// block_q8_0
// One block of QK8_0 weights: an fp16 scale `d` followed by QK8_0 signed 8-bit
// quants (34 bytes with QK8_0 == 32). NOTE(review): presumably mirrors the
// host-side ggml block_q8_0 layout; confirm before changing fields.
//------------------------------------------------------------------------------
typedef struct {
    half d;       // delta
    char qs[QK8_0]; // quants
} block_q8_0;
245
//------------------------------------------------------------------------------
// kernel_convert_block_q8_0
// Split block_q8_0 data into two parallel arrays (AOS -> SOA). Work-item `blk`
// copies the scale into dst_d[blk] and the QK8_0 quant bytes, unchanged, into
// dst_q[blk*QK8_0 .. +QK8_0).
//------------------------------------------------------------------------------
kernel void kernel_convert_block_q8_0(
    global block_q8_0 * src0,
    global uchar * dst_q,
    global half  * dst_d
) {
    const size_t blk = get_global_id(0);
    global const block_q8_0 * src_blk = src0 + blk;
    global uchar * q_out = dst_q + blk * QK8_0;

    dst_d[blk] = src_blk->d;

    for (int j = 0; j < QK8_0; j++) {
        q_out[j] = src_blk->qs[j];
    }
}
261
//------------------------------------------------------------------------------
// kernel_restore_block_q8_0
// Inverse of kernel_convert_block_q8_0 (SOA -> AOS). Work-item `blk` rebuilds
// block `blk` from src_d[blk] and src_q[blk*QK8_0 .. +QK8_0).
//------------------------------------------------------------------------------
kernel void kernel_restore_block_q8_0(
    global uchar * src_q,
    global half  * src_d,
    global block_q8_0 * dst
) {
    const size_t blk = get_global_id(0);
    global block_q8_0 * dst_blk = dst + blk;
    global const uchar * q_in = src_q + blk * QK8_0;

    dst_blk->d = src_d[blk];

    for (int j = 0; j < QK8_0; j++) {
        dst_blk->qs[j] = q_in[j];
    }
}
276
//------------------------------------------------------------------------------
// kernel_restore_block_q8_0_trans
// Rebuild row-major block_q8_0 data from a transposed SOA layout.
// Work-item `row` (get_global_id(0); presumably in [0, ne01)) writes all
// ne00/QK8_0 blocks of that row. The source layout read here is:
//   src_d[blk*ne01 + row]             scale of block `blk` in `row`
//   src_q[(g*ne01 + row)*4 + j]       quant byte 4*g + j of `row`, j in [0,4)
// i.e. quants are stored in 4-byte groups, and consecutive groups of one row
// are 4*ne01 bytes apart.
//------------------------------------------------------------------------------
kernel void kernel_restore_block_q8_0_trans(
    global uchar * src_q,
    global half  * src_d,
    global block_q8_0 * dst,
    uint ne00,
    uint ne01
){
    const size_t row = get_global_id(0);
    const uint blocks_per_row = ne00 / QK8_0;

    global block_q8_0 * out   = dst + row * blocks_per_row;
    global uchar      * group = src_q + row * 4;   // current 4-byte group of this row
    global half       * scale = src_d + row;

    for (uint blk = 0; blk < blocks_per_row; ++blk) {
        out[blk].d = *scale;
        scale += ne01;                              // next block's scale, same row

        for (uint j = 0; j < QK8_0; j += 4) {
            out[blk].qs[j + 0] = group[0];
            out[blk].qs[j + 1] = group[1];
            out[blk].qs[j + 2] = group[2];
            out[blk].qs[j + 3] = group[3];
            group += 4 * ne01;                      // next group, same row
        }
    }
}
307
//------------------------------------------------------------------------------
// kernel_convert_block_q6_K
// Split block_q6_K super blocks into four parallel arrays (AOS -> SOA),
// copying bytes verbatim (no bit deshuffling). Work-item `sb` handles super
// block `sb`:
//   dst_d [sb]                  <- d
//   dst_ql[sb*QK_K/2  .. ]      <- ql     (QK_K/2 bytes)
//   dst_qh[sb*QK_K/4  .. ]      <- qh     (QK_K/4 bytes)
//   dst_s [sb*QK_K/16 .. ]      <- scales (QK_K/16 bytes)
//------------------------------------------------------------------------------
kernel void kernel_convert_block_q6_K(
    global struct block_q6_K * src0,
    global uchar * dst_ql,
    global uchar * dst_qh,
    global char  * dst_s,
    global half  * dst_d
) {
    const size_t sb = get_global_id(0);
    global const struct block_q6_K * blk = src0 + sb;

    dst_d[sb] = blk->d;

    global uchar * ql_out = dst_ql + sb * (QK_K / 2);
    for (int j = 0; j < QK_K / 2; j++) {
        ql_out[j] = blk->ql[j];
    }

    global uchar * qh_out = dst_qh + sb * (QK_K / 4);
    for (int j = 0; j < QK_K / 4; j++) {
        qh_out[j] = blk->qh[j];
    }

    global char * s_out = dst_s + sb * (QK_K / 16);
    for (int j = 0; j < QK_K / 16; j++) {
        s_out[j] = blk->scales[j];
    }
}
339
//------------------------------------------------------------------------------
// kernel_restore_block_q6_K
// Inverse of kernel_convert_block_q6_K (SOA -> AOS). Work-item `sb` rebuilds
// super block `sb` of `dst`.
// Despite their `dst_` prefix, dst_ql/dst_qh/dst_s/dst_d are the flattened
// INPUT arrays (only read here); `dst` is the output. The names are kept
// unchanged to preserve the kernel's argument list.
//------------------------------------------------------------------------------
kernel void kernel_restore_block_q6_K(
    global uchar * dst_ql,
    global uchar * dst_qh,
    global char  * dst_s,
    global half  * dst_d,
    global struct block_q6_K * dst
) {
    const size_t sb = get_global_id(0);
    global struct block_q6_K * blk = dst + sb;

    blk->d = dst_d[sb];

    global const uchar * ql_in = dst_ql + sb * (QK_K / 2);
    for (int j = 0; j < QK_K / 2; j++) {
        blk->ql[j] = ql_in[j];
    }

    global const uchar * qh_in = dst_qh + sb * (QK_K / 4);
    for (int j = 0; j < QK_K / 4; j++) {
        blk->qh[j] = qh_in[j];
    }

    global const char * s_in = dst_s + sb * (QK_K / 16);
    for (int j = 0; j < QK_K / 16; j++) {
        blk->scales[j] = s_in[j];
    }
}