1
  2#include "types.glsl"
  3
  4layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufF32 {
  5   vec4 block;
  6};
  7
  8float16_t dequantFuncF32(const in decodeBufF32 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
  9{
 10    const vec4 v = bl.block;
 11    const uint idx = coordInBlock[1];
 12    const f16vec4 vf16 = f16vec4(v);
 13    return vf16[idx];
 14}
 15
 16layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ4_0 {
 17   block_q4_0_packed16 block;
 18};
 19
 20float16_t dequantFuncQ4_0(const in decodeBufQ4_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
 21{
 22    const float16_t d = bl.block.d;
 23    const uint idx = coordInBlock[1];
 24    const uint shift = (idx & 0x10) >> 2;
 25    uint32_t qs = uint32_t(bl.block.qs[(idx & 0xE) >> 1]);
 26    qs >>= shift;
 27    qs &= 0x0F0F;
 28    qs = unpack8(qs)[idx & 1];
 29    float16_t ret = (float16_t(qs) - float16_t(8)) * d;
 30    return ret;
 31}
 32
 33layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ4_1 {
 34   block_q4_1 block;
 35};
 36
 37float16_t dequantFuncQ4_1(const in decodeBufQ4_1 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
 38{
 39    const float16_t d = bl.block.d;
 40    const float16_t m = bl.block.m;
 41    const uint idx = coordInBlock[1];
 42    const uint iqs = idx & 0xF;
 43    const uint shift = (idx & 0x10) >> 2;
 44    uint32_t qs = bl.block.qs[iqs];
 45    qs >>= shift;
 46    qs &= 0xF;
 47    float16_t ret = float16_t(qs) * d + m;
 48    return ret;
 49}
 50
 51layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ5_0 {
 52   block_q5_0 block;
 53};
 54
 55float16_t dequantFuncQ5_0(const in decodeBufQ5_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
 56{
 57    const float16_t d = bl.block.d;
 58    const uint idx = coordInBlock[1];
 59    const uint iqs = idx & 0xF;
 60
 61    const uint uint_qh = uint(bl.block.qh[1]) << 16 | bl.block.qh[0];
 62    const uint qh = ((uint_qh >> idx) << 4) & 0x10;
 63
 64    const uint shift = (idx & 0x10) >> 2;
 65    uint32_t qs = bl.block.qs[iqs];
 66    qs >>= shift;
 67    qs &= 0xF;
 68
 69    float16_t ret = (float16_t(qs | qh) - float16_t(16)) * d;
 70    return ret;
 71}
 72
 73layout(buffer_reference, std430, buffer_reference_align = 8) buffer decodeBufQ5_1 {
 74   block_q5_1 block;
 75};
 76
 77float16_t dequantFuncQ5_1(const in decodeBufQ5_1 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
 78{
 79    const float16_t d = bl.block.d;
 80    const float16_t m = bl.block.m;
 81    const uint idx = coordInBlock[1];
 82    const uint iqs = idx & 0xF;
 83
 84    const uint uint_qh = bl.block.qh;
 85    const uint qh = ((uint_qh >> idx) << 4) & 0x10;
 86
 87    const uint shift = (idx & 0x10) >> 2;
 88    uint32_t qs = bl.block.qs[iqs];
 89    qs >>= shift;
 90    qs &= 0xF;
 91
 92    float16_t ret = float16_t(qs | qh) * d + m;
 93    return ret;
 94}
 95
 96layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ8_0 {
 97   block_q8_0_packed16 block;
 98};
 99
100float16_t dequantFuncQ8_0(const in decodeBufQ8_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
101{
102    const float16_t d = bl.block.d;
103    const uint idx = coordInBlock[1];
104    const uint iqs = idx;
105
106    // Load 16b and select the byte for this element
107    int32_t qs = unpack8(bl.block.qs[(iqs & 0x1E) >> 1])[iqs & 1];
108    float16_t ret = float16_t(qs) * d;
109    return ret;
110}
111
layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ2_K {
   block_q2_K block;
};

layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ2_K_packed16 {
   block_q2_K_packed16 block;
};

// Q2_K: 2-bit code q per element. Each group of 16 elements has a scale byte:
// low nibble multiplies dm.x, high nibble multiplies dm.y (subtracted).
// Result is dm.x * sc * q - dm.y * mn.
float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
    decodeBufQ2_K_packed16 bl16 = decodeBufQ2_K_packed16(bl);
    const uint i = coordInBlock[1];

    // Each 128-element half uses 16 words of qs. Within a half, the four
    // 32-element quarters occupy bit pairs 0, 2, 4 and 6 of the same bytes.
    const uint word = ((i & 0x80) >> 3) + ((i & 0x1E) >> 1);
    const uint bitShift = (i & 0x60) >> 4;
    const uint q = unpack8((uint32_t(bl16.block.qs[word]) >> bitShift) & 0x0303)[i & 1];

    const uint sm = bl.block.scales[(i & 0xF0) >> 4];
    const f16vec2 dm = bl.block.dm;
    return dm.x * float16_t(sm & 0xF) * float16_t(q) - dm.y * float16_t(sm >> 4);
}
137
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ3_K {
   block_q3_K block;
};

// Q3_K: each element is a 2-bit value from qs plus one bit from hmask. A
// cleared hmask bit subtracts 4 from the value. Every 16 elements share a
// 6-bit scale (stored offset by 32). Its low 4 bits come from a nibble of
// scales[0..7] and its high 2 bits from a bit pair in scales[8..11].
float16_t dequantFuncQ3_K(const in decodeBufQ3_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
    const uint idx = coordInBlock[1];
    const uint iqs = idx;

    const uint n = iqs / 128;                    // 0,1
    const uint qsi = n * 32 + (iqs % 32);        // 0..63
    const uint hmi =          (iqs % 32);        // 0..31
    const uint is = iqs / 16;                    // 0..15
    const uint halfsplit = ((iqs % 128) / 32);   // 0,1,2,3
    const uint qsshift = halfsplit * 2;          // 0,2,4,6
    const uint m = 1 << (4 * n + halfsplit);     // 1,2,4,8,16,32,64,128

    // Low nibble of the scale: scales[is] low nibble for is < 8, else scales[is-8] high nibble.
    uint32_t scaleidx0 = (is < 8) ? is : (is-8);
    uint32_t scaleidx0shift = (is < 8) ? 0 : 4;
    // High 2 bits: byte 8 + (is % 4), bit pair selected by is / 4.
    uint32_t scaleidx1 = is + 8 - (is/4)*4;
    uint32_t scaleidx1shift = (is/4)*2;

    const int8_t us = int8_t(((bl.block.scales[scaleidx0] >> scaleidx0shift) & 0xF) | (((bl.block.scales[scaleidx1] >> scaleidx1shift) & 3) << 4));

    const float16_t dl = bl.block.d * float16_t(us - 32);

    float16_t ret = dl * float16_t(int8_t((bl.block.qs[qsi    ] >> qsshift) & 3) - (((bl.block.hmask[hmi    ] & m) != 0) ? 0 : 4));

    return ret;
}
169
// Views of the same Q4_K block used by dequantFuncQ4_K:
//  - decodeBufQ4_K: the plain block struct (the callback's parameter type),
//  - decodeBufQ4_K_packed16: quants read as 16-bit words,
//  - decodeBufQ4_K_packed128: read as uvec4; q4k[0] supplies d/dmin (x) and
//    the packed scale/min bytes (y, z, w).
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K {
   block_q4_K block;
};

layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K_packed16 {
   block_q4_K_packed16 block;
};

layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K_packed128 {
   block_q4_K_packed128 block;
};
181
#if defined(IS_MUL_MM2)

// For Q4_K and Q5_K in the mat-mul shader, we decode a tile's worth of scales
// into shared memory and then process the whole tile using those scales.
// There is a fetch function that loads into private variables and then a store
// function that stores into shared memory.
// Q4_K and Q5_K have the same encoding of scales, so everything is shared except
// the part that fetches from the structure (which has a different block layout).
#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K)
// Stride (in vec2 entries) between scale indices in shAscales; entries are
// addressed as [scale index * shAscales_stride + row within the BM-row tile].
// NOTE(review): the +2 padding is presumably meant to reduce shared-memory bank conflicts - confirm.
const uint shAscales_stride = (BM + 2);
// 1 scale per 32 elements -> 8 scales per block, per row
// Each entry is (d * scale, dmin * min) for one 32-element sub-block.
shared vec2 shAscales[8 * shAscales_stride];
// Per-invocation (non-shared global) copy of the 16-byte block header loaded
// by fetch_scales and expanded by store_scalesQ4_K.
uvec4 row_v;
#endif
196
#if defined(DATA_A_Q4_K)
layout (binding = 0) readonly buffer A_Q4_K_128 {block_q4_K_packed128 data_a_q4_k_packed128[];};

// Load the 16-byte header (d/dmin and packed scales/mins) of the Q4_K block
// needed by this invocation's row for the K tile starting at block_k into
// row_v. store_scalesQ4_K later expands it into shAscales.
// ir_BM: first row of the tile; pos_a/stride_a: block offset and row stride
// in A; in_bounds: when true, the row < p.M check is skipped.
void fetch_scalesQ4_K(uint ir_BM, uint pos_a, uint stride_a, uint block_k, uint tid, bool in_bounds)
{
    // BLOCK_SIZE / BM invocations map to the same row and load the same header.
    const uint tids_per_row = BLOCK_SIZE / BM;
    const uint tid_row = tid / tids_per_row;

    const uint row = ir_BM + tid_row;
    const uint block_index = pos_a + row * stride_a + (block_k / QUANT_K);
    if (in_bounds || row < p.M) {
        row_v = data_a_q4_k_packed128[block_index].q4k[0];
    }
}
#endif
#if defined(DATA_A_Q5_K)
layout (binding = 0) readonly buffer A_Q5_K_128 {block_q5_K_packed128 data_a_q5_k_packed128[];};

// Load the 16-byte header (d/dmin and packed scales/mins) of the Q5_K block
// needed by this invocation's row for the K tile starting at block_k into
// row_v. store_scalesQ4_K later expands it into shAscales (same scale encoding as Q4_K).
// ir_BM: first row of the tile; pos_a/stride_a: block offset and row stride
// in A; in_bounds: when true, the row < p.M check is skipped.
void fetch_scalesQ5_K(uint ir_BM, uint pos_a, uint stride_a, uint block_k, uint tid, bool in_bounds)
{
    // BLOCK_SIZE / BM invocations map to the same row and load the same header.
    const uint tids_per_row = BLOCK_SIZE / BM;
    const uint tid_row = tid / tids_per_row;

    const uint row = ir_BM + tid_row;
    const uint block_index = pos_a + row * stride_a + (block_k / QUANT_K);
    if (in_bounds || row < p.M) {
        row_v = data_a_q5_k_packed128[block_index].q5k[0];
    }
}
#endif
231
#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K)
// Expand the header held in row_v into (d * scale, dmin * min) pairs for the
// 8 sub-blocks of this invocation's row, and write them to shAscales. The
// BLOCK_SIZE / BM invocations sharing a row each write a disjoint run of
// 8 / (BLOCK_SIZE / BM) sub-block entries. The writes are bracketed by
// workgroup barriers.
void store_scalesQ4_K(uint tid)
{
    barrier();

    const uint lanesPerRow = BLOCK_SIZE / BM;
    const uint countPerLane = 8 / lanesPerRow;
    const uint firstScale = countPerLane * (tid % lanesPerRow);
    const uint rowInTile = tid / lanesPerRow;

    // The header does not change across iterations, so unpack it once.
    const uvec4 hdr = row_v;
    const vec2 dmin = vec2(unpackFloat2x16(hdr.x));
    // Sub-blocks 0..3 use 6-bit values stored directly in bytes 0..7; sub-blocks
    // 4..7 combine a nibble from bytes 8..11 with the top two bits of bytes 0..7.
    const uint scLow   = hdr.y;
    const uint minLow  = hdr.z;
    const uint scHigh  = (hdr.w & 0x0F0F0F0F) | ((hdr.y & 0xC0C0C0C0) >> 2);
    const uint minHigh = ((hdr.w & 0xF0F0F0F0) >> 4) | ((hdr.z & 0xC0C0C0C0) >> 2);

    [[unroll]] for (uint k = 0; k < countPerLane; ++k) {
        const uint s = firstScale + k;
        const uint byteShift = 8 * (s & 3);
        const uint sc    = ((s < 4 ? scLow  : scHigh)  >> byteShift) & 0x3F;
        const uint mbyte = ((s < 4 ? minLow : minHigh) >> byteShift) & 0x3F;
        shAscales[s * shAscales_stride + rowInTile] = vec2(dmin.x * float(sc), dmin.y * float(mbyte));
    }

    barrier();
}
#endif

#endif
276
// Q4_K: element coordInBlock[1] is a 4-bit code q, dequantized as
// (d * sc) * q - (dmin * mn). sc and mn are the 6-bit scale/min of its
// 32-element sub-block. With IS_MUL_MM2 and DATA_A_Q4_K, these products come
// from shAscales, indexed by the row within the BM tile.
float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
    decodeBufQ4_K_packed16 bl16 = decodeBufQ4_K_packed16(bl);
    const uint i = coordInBlock[1];

    const uint sub = (i & 0xE0) >> 5;   // 32-element sub-block, 0..7
    const uint nib = (i & 0x20) >> 5;   // 0: low nibble, 1: high nibble

#if defined(IS_MUL_MM2) && defined(DATA_A_Q4_K)
    const vec2 dm = shAscales[sub * shAscales_stride + (blockCoords[0] % BM)];
    const float d = dm.x;
    const float m = dm.y;
#else
    decodeBufQ4_K_packed128 bl128 = decodeBufQ4_K_packed128(bl);
    const uvec4 hdr = bl128.block.q4k[0];
    const vec2 dmin = vec2(unpackFloat2x16(hdr.x));

    uint sc;
    uint mbyte;
    if (sub < 4) {
        sc    = hdr.y;
        mbyte = hdr.z;
    } else {
        sc    = (hdr.w & 0x0F0F0F0F) | ((hdr.y & 0xC0C0C0C0) >> 2);
        mbyte = ((hdr.w & 0xF0F0F0F0) >> 4) | ((hdr.z & 0xC0C0C0C0) >> 2);
    }
    const uint byteShift = 8 * (sub & 3);
    sc    = (sc >> byteShift) & 0x3F;
    mbyte = (mbyte >> byteShift) & 0x3F;

    const float d = dmin.x * float(sc);
    const float m = dmin.y * float(mbyte);
#endif

    // 32 bytes of qs per 64 elements; the upper 32 of each 64 use the high nibble.
    const uint word = ((i & 0xC0) >> 2) + ((i & 0x1E) >> 1);
    const uint q = (uint32_t(bl16.block.qs[word]) >> (nib * 4 + 8 * (i & 1))) & 0xF;

    return float16_t(d * float(q) - m);
}
324
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K {
   block_q5_K block;
};

layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K_packed16 {
   block_q5_K_packed16 block;
};

layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K_packed128 {
   block_q5_K_packed128 block;
};

// Q5_K: like Q4_K, but each code gets a fifth bit from qh. Result is
// (d * sc) * q - (dmin * mn), where sc/mn are the 6-bit scale/min of the
// element's 32-element sub-block.
float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
    decodeBufQ5_K_packed16 bl16 = decodeBufQ5_K_packed16(bl);
    decodeBufQ5_K_packed128 bl128 = decodeBufQ5_K_packed128(bl);
    const uint idx = coordInBlock[1];

    const uint b = (idx & 0x20) >> 5;          // 0,1
    const uint is = (idx & 0xE0) >> 5;         // 0..7

#if defined(IS_MUL_MM2) && defined(DATA_A_Q5_K)
    // Scale/min products pre-expanded into shared memory (in fp32).
    vec2 v = shAscales[is * shAscales_stride + (blockCoords[0] % BM)];
    float d = v.x;
    float m = v.y;
#else
    uvec4 v = bl128.block.q5k[0];

    // Form d * sc and dmin * mn in fp32, as dequantFuncQ4_K and the
    // shared-memory path do. Rounding these products to fp16 first would
    // make this path disagree with the IS_MUL_MM2 path for the same data.
    const vec2 loadd = vec2(unpackFloat2x16(v.x));

    uint32_t sc;
    uint32_t mbyte;

    uint32_t scale0 = v.y;
    uint32_t scale4 = v.z;
    uint32_t scale8 = v.w;

    uint32_t sc_lo = scale0;
    uint32_t mb_lo = scale4;
    uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2);
    uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2);

    sc = is < 4 ? sc_lo : sc_hi;
    mbyte = is < 4 ? mb_lo : mb_hi;
    sc = sc >> (8 * (is & 3));
    mbyte = mbyte >> (8 * (is & 3));
    sc &= 0x3F;
    mbyte &= 0x3F;

    const float d = loadd.x * float(sc);
    const float m = loadd.y * float(mbyte);
#endif

    // qh byte (idx % 32) holds one bit per 32-element sub-block (bit `is`);
    // move it to bit 4 of each byte in the 16-bit pair.
    uint qh = uint32_t(bl16.block.qh[(idx & 0x1E) >> 1]);
    qh = ((qh >> is) & 0x101) << 4;

    // Low four bits: 32 bytes of qs per 64 elements; upper 32 use the high nibble.
    uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]);
    qs = (qs >> (b * 4)) & 0x0F0F;
    qs = unpack8(qs | qh)[idx & 1];

    float ret = d * float(qs) - m;

    return float16_t(ret);
}
389
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ6_K {
   block_q6_K block;
};

layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ6_K_packed16 {
   block_q6_K_packed16 block;
};

// Q6_K: 6-bit code q (4 low bits from ql, 2 high bits from qh), dequantized
// as (d * scale) * (q - 32). Every 16 elements share one scale byte.
float16_t dequantFuncQ6_K(const in decodeBufQ6_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
    decodeBufQ6_K_packed16 bl16 = decodeBufQ6_K_packed16(bl);
    const uint i = coordInBlock[1];

    // Low bits: 64 bytes of ql per 128 elements; elements 64..127 of each
    // half reuse those bytes' high nibbles.
    const uint qlWord = ((i & 0x80) >> 2) + ((i & 0x3E) >> 1);
    const uint lo = (uint32_t(bl16.block.ql[qlWord]) >> (((i & 0x40) >> 6) * 4)) & 0x0F0F;

    // High bits: 32 bytes of qh per 128 elements, one bit pair per 32-element quarter.
    const uint qhWord = ((i & 0x80) >> 3) + ((i & 0x1E) >> 1);
    const uint hi = ((uint32_t(bl16.block.qh[qhWord]) >> ((i & 0x60) >> 4)) & 0x0303) << 4;

    const int q = unpack8(lo | hi)[i & 1];

    const float16_t dscale = bl.block.d * float16_t(bl.block.scales[(i & 0xF0) >> 4]);
    return dscale * float16_t(q - 32);
}
421
#if defined(DATA_A_IQ1_S)
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ1_S {
   block_iq1_s block;
};

// IQ1_S dequantization of element coordInBlock[1].
// Each 32-element group (ib32) has a qh word. Bits 12..14 hold a group multiplier
// and bit 15 holds the delta sign. For each of the group's four 8-element
// sub-groups, bits 3*k..3*k+2 extend that sub-group's grid index. qs[ib8]
// holds the low 8 bits of the grid index for sub-group ib8.
float16_t dequantFuncIQ1_S(const in decodeBufIQ1_S bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
    const float16_t d = bl.block.d;
    const uint idx = coordInBlock[1];

    const uint ib32 = (idx & 0xE0) >> 5;   // 32-element group
    const uint ib8 = (idx & 0xF8) >> 3;    // 8-element sub-group

    const uint qh = bl.block.qh[ib32];
    const uint qs = bl.block.qs[ib8];
    // Odd multiplier 1,3,...,15 of d.
    const float dl = d * float(2 * bitfieldExtract(qh, 12, 3) + 1);
    const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
    // 11-bit grid index: 8 bits from qs plus the 3 bits for this sub-group from qh.
    const uint grid = iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)];

    // The grid entry packs one 2-bit field per element. bitfieldExtract on an
    // int sign-extends, so each field is read as a signed 2-bit value.
    float16_t ret = float16_t(dl) * (float16_t(bitfieldExtract(int(grid), 2 * int(idx % 8), 2)) + float16_t(delta));
    return ret;
}
#endif
445
#if defined(DATA_A_IQ1_M)
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ1_M {
   block_iq1_m block;
};

// 64-bit view of the same block, used to read all scale bits at once.
layout(buffer_reference, std430, buffer_reference_align = 8) buffer decodeBufIQ1_M_packed64 {
   block_iq1_m_packed64 block;
};

// IQ1_M dequantization of element coordInBlock[1].
// The fp16 block scale d is reassembled from bits 12..15 of the four 16-bit
// scale words. The other scale bits hold a 3-bit multiplier per 16 elements.
float16_t dequantFuncIQ1_M(const in decodeBufIQ1_M bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
    decodeBufIQ1_M_packed64 bl64 = decodeBufIQ1_M_packed64(bl);
    const uint idx = coordInBlock[1];

    uvec2 scales = unpack32(bl64.block.scales);
    // Nibble of word 0 -> bits 0..3, word 1 -> 4..7, word 2 -> 8..11, word 3 -> 12..15
    // (presumably word 0 is the low half of scales.x - confirm against the struct layout).
    const float16_t d = uint16BitsToHalf(uint16_t(((scales.x & 0xF000) >> 12) | ((scales.x & 0xF0000000) >> 24) | ((scales.y & 0xF000) >> 4) | ((scales.y & 0xF0000000) >> 16)));

    const uint ib8 = (idx & 0xF8) >> 3;   // 8-element sub-group (one grid index)
    const uint ib16 = (idx & 0xF0) >> 4;  // 16-element group (one multiplier, one qh byte)
    const int i8 = int(idx % 8);          // element within the sub-group
    const uint sc = bl.block.scales[ib8 / 8];   // scale word covering 64 elements
    const uint qs = bl.block.qs[ib8];
    // qh nibble for this sub-group: bits 0..2 extend the grid index, bit 3 is the delta sign.
    const uint qh = bl.block.qh[ib16] >> (4 * (ib8 & 1));
    // Odd multiplier 1,3,...,15 from the 3-bit field of this 16-element group.
    const float dl = 2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1;
    const float delta = ((qh & 8) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
    const uint grid = iq1s_grid[qs | ((qh & 7) << 8)];

    // Signed (sign-extended) 2-bit field for this element from the grid entry.
    float16_t ret = d * float16_t(dl) * (float16_t(bitfieldExtract(int(grid), 2 * i8, 2)) + float16_t(delta));
    return ret;
}
#endif
477
#if defined(DATA_A_IQ2_XXS)
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ2_XXS {
   block_iq2_xxs block;
};

layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ2_XXS_packed16 {
   block_iq2_xxs_packed16 block;
};

// IQ2_XXS dequantization of element coordInBlock[1].
// Each 32-element group ib32 occupies 8 bytes of qs. Bytes 0..3 are grid
// indices (one per 8 elements). Bytes 4..7 form a 32-bit word: bits 28..31
// are the group scale and bits 0..27 are four 7-bit sign fields.
float16_t dequantFuncIQ2_XXS(const in decodeBufIQ2_XXS bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
    decodeBufIQ2_XXS_packed16 bl16 = decodeBufIQ2_XXS_packed16(bl);
    const float16_t d = bl.block.d;
    const uint idx = coordInBlock[1];

    const uint ib32 = (idx & 0xE0) >> 5; // 0..7
    const uint ib8 = (idx & 0x18) >> 3;  // 0..3
    const uint iqs = 8 * ib32 + ib8;

    const uint qs = bl.block.qs[iqs];
    const uint signscale = pack32(u16vec2(bl16.block.qs[4*ib32+2], bl16.block.qs[4*ib32+3]));

    const float dscale = float(d) * 0.25 * (0.5 + float(signscale >> 28));
    uint sign = bitfieldExtract(signscale, 7 * int(ib8), 7);
    // Eighth sign bit is implied by parity: bit 7 of bitCount(sign) << 7 is set
    // exactly when an odd number of the 7 stored bits are set.
    sign |= bitCount(sign) << 7;

    // Grid entry = 8 byte values read as two 32-bit words; pick the word, then the byte pair.
    uint g2 = iq2xxs_grid[qs][(idx & 4) >> 2];
    g2 >>= (idx & 2) * 8;
    const vec2 g = vec2(unpack8(g2));

    vec2 ret = dscale * g * ((sign & (1 << (idx & 7))) != 0 ? -1.0hf : 1.0hf);
    return float16_t(ret[idx & 1]);
}
#endif
512
#if defined(DATA_A_IQ2_XS)
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ2_XS {
   block_iq2_xs block;
};

// IQ2_XS dequantization of element coordInBlock[1].
// Each 8-element sub-group has one 16-bit qs entry. Bits 0..8 are a grid
// index and bits 9..15 hold 7 sign bits (the 8th is implied by parity).
// Each 32-element group has a scale byte whose two nibbles scale its two
// 16-element halves.
float16_t dequantFuncIQ2_XS(const in decodeBufIQ2_XS bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
    const float16_t d = bl.block.d;
    const uint idx = coordInBlock[1];

    const uint is = (idx & 0xE0) >> 5;     // 0..7
    const uint sshift = (idx & 0x10) >> 2; // 0,4
    const uint iqs = (idx & 0xF8) >> 3;    // 0..31

    const uint16_t qs = bl.block.qs[iqs];
    const float dscale = float(d) * 0.25 * (0.5 + float((bl.block.scales[is] >> sshift) & 0xF));

    // Bit 7 of bitCount(sign) << 7 supplies the parity-implied eighth sign bit.
    uint sign = uint(qs >> 9);
    sign |= bitCount(sign) << 7;
    // Grid entry = 8 byte values read as two 32-bit words; pick the word, then the byte pair.
    uint g2 = iq2xs_grid[qs & 0x1FF][(idx & 4) >> 2];
    g2 >>= (idx & 2) * 8;
    const vec2 g = vec2(unpack8(g2));

    vec2 ret = dscale * g * ((sign & (1 << (idx & 7))) != 0 ? -1.0hf : 1.0hf);
    return float16_t(ret[idx & 1]);
}
#endif
540
#if defined(DATA_A_IQ2_S)
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ2_S {
   block_iq2_s block;
};

// IQ2_S dequantization of element coordInBlock[1].
// qs[0 .. QUANT_K/8) holds the low 8 bits of one grid index per 8 elements.
// qs[QUANT_K/8 ..) holds one sign byte per 8 elements. qh[ib32] supplies 2 extra
// index bits per 8-element sub-group. scales[ib32] has one nibble per 16 elements.
float16_t dequantFuncIQ2_S(const in decodeBufIQ2_S bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
    uint idx = coordInBlock[1];

    const uint ib32 = (idx & 0xE0) >> 5;        // 0..7
    const uint ib8 = (idx & 0xF8) >> 3;         // 0..31
    const uint qhshift = 2 * (ib8 % 4);

    const uint scale = (bl.block.scales[ib32] >> ((idx & 0x10) >> 2)) & 0xf;
    const uint qs = bl.block.qs[ib8];
    const uint qh = bl.block.qh[ib32];
    // Shift so that bits 0 and 1 are the signs of the element pair containing idx.
    const uint sign = bl.block.qs[QUANT_K / 8 + ib8] >> (idx & 0x6);

    const float d = float(bl.block.d);
    const float db = d * 0.25 * (0.5 + scale);
    // 1 - (2 & x) maps bit 1 of x to -1 (set) or +1 (clear): component 0 tests
    // sign bit 0 (moved up by << 1), component 1 tests sign bit 1.
    const ivec2 sign01 = 1 - (2 & ivec2(sign << 1, sign));
    // 10-bit grid index: qs plus this sub-group's 2 qh bits moved to bits 8..9.
    uint g2 = iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(idx & 4) >> 2];
    g2 >>= (idx & 2) * 8;
    const vec2 v = db * vec2(sign01) * vec2(unpack8(g2));
    return float16_t(v[idx & 1]);
}
#endif
568
#if defined(DATA_A_IQ3_XXS)
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ3_XXS {
   block_iq3_xxs block;
};

layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ3_XXS_packed16 {
   block_iq3_xxs_packed16 block;
};

// IQ3_XXS dequantization of element coordInBlock[1].
// qs starts with one grid-index byte per 4 elements (QUANT_K/4 bytes). It is
// followed by one 32-bit word per 32-element group: bits 28..31 are the group
// scale and bits 0..27 are four 7-bit sign fields (one per 8 elements).
float16_t dequantFuncIQ3_XXS(const in decodeBufIQ3_XXS bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
    decodeBufIQ3_XXS_packed16 bl16 = decodeBufIQ3_XXS_packed16(bl);
    uint idx = coordInBlock[1];

    const uint iqs = (idx & 0xFC) >> 2;             // 0..63
    // Byte offset of this 32-element group's scale/sign word (after the grid indices).
    const uint is = QUANT_K / 4 + ((idx & 0xE0) >> 3);// 8 values

    const float d = float(bl.block.d);
    const uint qs = bl.block.qs[iqs];
    const uint signs = pack32(u16vec2(
        bl16.block.qs[is/2+0],
        bl16.block.qs[is/2+1]
    ));
    const float db = d * 0.5 * (0.5 + (signs >> 28));
    // 7 stored sign bits of this 8-element sub-group; bit 7 added from parity below.
    const uint32_t sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7);
    const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (idx & 0x6);
    // 1 - (2 & x) maps bit 1 to -1/+1; component 0 tests sign bit 0, component 1 bit 1.
    const ivec2 sign01 = ivec2(1 - (2 & ivec2(sign << 1, sign)));
    // Each grid entry packs 4 byte values; keep the 16-bit half holding this element pair.
    const uint grid = iq3xxs_grid[qs] >> (16 * ((idx & 2) >> 1));
    const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy);
    return float16_t(v[idx & 1]);
}
#endif
601
#if defined(DATA_A_IQ3_S)
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ3_S {
   block_iq3_s block;
};

// IQ3_S dequantization of element coordInBlock[1].
// qs has one grid-index byte per 4 elements. qh[ib32] adds a ninth index bit
// per grid entry. signs has one byte per 8 elements. scales packs a 4-bit
// multiplier per 32 elements, two per byte.
float16_t dequantFuncIQ3_S(const in decodeBufIQ3_S bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
    uint idx = coordInBlock[1];

    const uint iqs = (idx & 0xFC) >> 2;           // 0..63
    const uint iqh = (idx & 0xE0) >> 5;           // 32-element group

    const float d = float(bl.block.d);
    const uint qs = bl.block.qs[iqs];
    const uint qh = bl.block.qh[iqh];
    // Shift so bits 0 and 1 are the signs of the element pair containing idx.
    const int8_t sign = int8_t(bl.block.signs[iqs / 2] >> (idx & 0x6));
    const uint scale = bl.block.scales[iqs / 16];
    // 1 - (2 & x) maps bit 1 to -1/+1; component 0 tests sign bit 0, component 1 bit 1.
    const ivec2 sign01 = ivec2(1 - (2 & ivec2(sign << 1, sign)));
    // Odd multiplier 1,3,...,31 from this group's nibble.
    const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf));
    // Bit (iqs % 8) of qh becomes bit 8 of the grid index; then keep the
    // 16-bit half of the 4-byte entry that holds this element pair.
    const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)] >> ((idx & 2) << 3);
    const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy);

    return float16_t(v[idx & 1]);
}
#endif
627
#if defined(DATA_A_IQ4_XS)
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ4_XS {
   block_iq4_xs block;
};

// IQ4_XS: 4-bit code looked up in kvalues_iq4nl and scaled by
// d * (s - 32). s is the 6-bit scale of the element's 32-element sub-block:
// the low nibble comes from scales_l and the top two bits from scales_h.
float16_t dequantFuncIQ4_XS(const in decodeBufIQ4_XS bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
    const uint i = coordInBlock[1];
    const uint sub = (i & 0xE0) >> 5; // 0..7

    const uint scaleLo = (bl.block.scales_l[sub/2] >> (4 * (sub & 1))) & 0xF;
    const uint scaleHi = ((bl.block.scales_h) >> (2 * sub)) & 3;
    const int scale = int(scaleLo | (scaleHi << 4)) - 32;

    // 16 bytes per sub-block; its elements 16..31 use the high nibble.
    const uint code = (bl.block.qs[16 * sub + (i % 16)] >> ((i & 16) >> 2)) & 0xF;

    return bl.block.d * float16_t(scale) * float16_t(kvalues_iq4nl[code]);
}
#endif
649
#if defined(DATA_A_IQ4_NL)
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ4_NL {
   block_iq4_nl block;
};

// IQ4_NL: 4-bit code mapped through the kvalues_iq4nl table and scaled by d.
// Elements 0..15 use the low nibbles of qs, 16..31 the high nibbles.
float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
    const uint i = coordInBlock[1];
    const uint code = (uint32_t(bl.block.qs[i & 0xF]) >> ((i & 0x10) >> 2)) & 0xF;
    return float16_t(kvalues_iq4nl[code]) * bl.block.d;
}
#endif
668
#if defined(DATA_A_MXFP4)
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufMXFP4 {
   block_mxfp4 block;
};

// MXFP4: 4-bit code mapped through kvalues_mxfp4, multiplied by the block's
// E8M0 exponent scale (converted with e8m0_to_fp32) and by 0.5.
// Elements 0..15 use the low nibbles of qs, 16..31 the high nibbles.
float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
    const float scale = e8m0_to_fp32(bl.block.e);
    const uint i = coordInBlock[1];
    const uint code = (uint32_t(bl.block.qs[i & 0xF]) >> ((i & 0x10) >> 2)) & 0xF;
    return float16_t(kvalues_mxfp4[code] * scale * 0.5);
}
#endif
687
// Map the generic dequantFuncA callback to the function for the A-matrix type
// this shader variant is compiled for (exactly one DATA_A_* macro is expected).
// Q4_K and Q5_K also map the fetch_scales/store_scales hooks. Q5_K reuses
// store_scalesQ4_K because both formats encode their scales the same way.
// NOTE(review): there is no #else, so dequantFuncA stays undefined if no DATA_A_* macro matches.
#if defined(DATA_A_Q4_0)
#define dequantFuncA dequantFuncQ4_0
#elif defined(DATA_A_Q4_1)
#define dequantFuncA dequantFuncQ4_1
#elif defined(DATA_A_Q5_0)
#define dequantFuncA dequantFuncQ5_0
#elif defined(DATA_A_Q5_1)
#define dequantFuncA dequantFuncQ5_1
#elif defined(DATA_A_Q8_0)
#define dequantFuncA dequantFuncQ8_0
#elif defined(DATA_A_Q2_K)
#define dequantFuncA dequantFuncQ2_K
#elif defined(DATA_A_Q3_K)
#define dequantFuncA dequantFuncQ3_K
#elif defined(DATA_A_Q4_K)
#define dequantFuncA dequantFuncQ4_K
#define fetch_scales fetch_scalesQ4_K
#define store_scales store_scalesQ4_K
#elif defined(DATA_A_Q5_K)
#define dequantFuncA dequantFuncQ5_K
#define fetch_scales fetch_scalesQ5_K
#define store_scales store_scalesQ4_K
#elif defined(DATA_A_Q6_K)
#define dequantFuncA dequantFuncQ6_K
#elif defined(DATA_A_IQ1_S)
#define dequantFuncA dequantFuncIQ1_S
#elif defined(DATA_A_IQ1_M)
#define dequantFuncA dequantFuncIQ1_M
#elif defined(DATA_A_IQ2_XXS)
#define dequantFuncA dequantFuncIQ2_XXS
#elif defined(DATA_A_IQ2_XS)
#define dequantFuncA dequantFuncIQ2_XS
#elif defined(DATA_A_IQ2_S)
#define dequantFuncA dequantFuncIQ2_S
#elif defined(DATA_A_IQ3_XXS)
#define dequantFuncA dequantFuncIQ3_XXS
#elif defined(DATA_A_IQ3_S)
#define dequantFuncA dequantFuncIQ3_S
#elif defined(DATA_A_IQ4_XS)
#define dequantFuncA dequantFuncIQ4_XS
#elif defined(DATA_A_IQ4_NL)
#define dequantFuncA dequantFuncIQ4_NL
#elif defined(DATA_A_MXFP4)
#define dequantFuncA dequantFuncMXFP4
#elif defined(DATA_A_F32)
#define dequantFuncA dequantFuncF32
#endif