1#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
  2#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
  3#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
  4
  5#include "types.glsl"
  6
  7#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
// Per-block scale for formats with a single delta: returns the block's
// `d` widened to FLOAT_TYPE.
FLOAT_TYPE get_dm(uint ib) {
    return FLOAT_TYPE(data_a[ib].d);
}
 11#endif
 12
 13#if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1)
// Scale/min formats (Q4_1/Q5_1): returns the packed (d, m) pair of block
// `ib` as a FLOAT_TYPE 2-vector.
FLOAT_TYPE_VEC2 get_dm(uint ib) {
    return FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
}
 17#endif
 18
 19#if defined(DATA_A_MXFP4)
// MXFP4 stores its scale as a single e8m0 exponent byte `e`; expand it to
// a full float scale.
FLOAT_TYPE get_dm(uint ib) {
    return FLOAT_TYPE(e8m0_to_fp32(data_a[ib].e));
}
 23#endif
 24
 25#if defined(DATA_A_Q2_K)
 26FLOAT_TYPE_VEC2 get_dm(uint ib) {
 27    const uint ib_k = ib / 8;
 28    return FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm);
 29}
 30#endif
 31
 32// Each iqs value maps to a 32-bit integer
 33#if defined(DATA_A_Q4_0)
 34// 2-byte loads for Q4_0 blocks (18 bytes)
 35i32vec2 repack(uint ib, uint iqs) {
 36    const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2    ],
 37                                   data_a_packed16[ib].qs[iqs * 2 + 1]);
 38    const uint32_t vui = pack32(quants);
 39    return i32vec2( vui       & 0x0F0F0F0F,
 40                   (vui >> 4) & 0x0F0F0F0F);
 41}
 42
// Combine an integer dot product with the per-block scales.
// Q4_0 nibbles carry a +8 bias; the `8 * dsb.y` term removes it (dsb is the
// q8_1 (d, s) pair — NOTE(review): s is presumably the sum of the b values,
// confirm against the q8_1 quantize shader). `sum_divisor` shrinks the bias
// term when q_sum covers only a fraction of the block (integer division,
// callers pass divisors of 8).
FLOAT_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
    return FLOAT_TYPE(da * (float(q_sum) * dsb.x - (8 / sum_divisor) * dsb.y));
}
 46#endif
 47
 48#if defined(DATA_A_Q4_1)
 49// 4-byte loads for Q4_1 blocks (20 bytes)
 50i32vec2 repack(uint ib, uint iqs) {
 51    const uint32_t vui = data_a_packed32[ib].qs[iqs];
 52    return i32vec2( vui       & 0x0F0F0F0F,
 53                   (vui >> 4) & 0x0F0F0F0F);
 54}
 55
// Q4_1 has (d, m): result = d_a * d_b * q_sum + m_a * s_b, with the min term
// scaled down by `sum_divisor` when q_sum covers only part of the block.
FLOAT_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
    return FLOAT_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
}
 59#endif
 60
 61#if defined(DATA_A_Q5_0)
 62// 2-byte loads for Q5_0 blocks (22 bytes)
// Rebuilds eight 5-bit Q5_0 values as two packed words of four bytes each:
// the low nibbles come from qs, the fifth bit of each value from qh.
i32vec2 repack(uint ib, uint iqs) {
    const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2    ],
                                   data_a_packed16[ib].qs[iqs * 2 + 1]);
    const uint32_t vui = pack32(quants);
    // qh is stored as two 16-bit halves; reassemble them and pre-shift so the
    // 4 bits for this group land in the low nibble (high group at bits 16..19).
    const int32_t qh = int32_t((uint32_t(data_a_packed16[ib].qh[1]) << 16 | data_a_packed16[ib].qh[0]) >> (4 * iqs));
    // Multiplying a 4-bit field by 0x02040810 moves bit k to bit 8*k + 4;
    // masking with 0x10101010 keeps exactly one "bit 4" per byte lane.
    const int32_t v0 = int32_t(vui & 0x0F0F0F0F)
                     | ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)

    const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F)
                     | (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28)

    return i32vec2(v0, v1);
}
 76
// Same shape as the Q4_0 variant, but 5-bit values carry a +16 bias, hence
// the `16 * dsb.y` correction (scaled by 1/sum_divisor for partial blocks).
FLOAT_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
    return FLOAT_TYPE(da * (float(q_sum) * dsb.x - (16 / sum_divisor) * dsb.y));
}
 80#endif
 81
 82#if defined(DATA_A_Q5_1)
 83// 4-byte loads for Q5_1 blocks (24 bytes)
// Rebuilds eight 5-bit Q5_1 values; identical bit gymnastics to the Q5_0
// variant except qh is a single aligned 32-bit load here.
i32vec2 repack(uint ib, uint iqs) {
    const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2    ],
                                   data_a_packed16[ib].qs[iqs * 2 + 1]);
    const uint32_t vui = pack32(quants);
    const int32_t qh = int32_t(data_a_packed32[ib].qh >> (4 * iqs));
    // Multiplying a 4-bit field by 0x02040810 moves bit k to bit 8*k + 4;
    // masking with 0x10101010 keeps exactly one "bit 4" per byte lane.
    const int32_t v0 = int32_t(vui & 0x0F0F0F0F)
                     | ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)

    const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F)
                     | (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28)

    return i32vec2(v0, v1);
}
 97
// Q5_1 has (d, m), same combine as Q4_1: d_a * d_b * q_sum + m_a * s_b,
// with the min term scaled down by `sum_divisor` for partial blocks.
FLOAT_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
    return FLOAT_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
}
101#endif
102
103#if defined(DATA_A_Q8_0)
104// 2-byte loads for Q8_0 blocks (34 bytes)
// Two adjacent 16-bit halves form one packed 32-bit group of four int8 quants;
// Q8_0 values are already full bytes, so no nibble splitting is needed.
int32_t repack(uint ib, uint iqs) {
    const i16vec2 halves = i16vec2(data_a_packed16[ib].qs[2 * iqs    ],
                                   data_a_packed16[ib].qs[2 * iqs + 1]);
    return pack32(halves);
}
109
// Q8_0 is unbiased, so only the d_a * d_b scaling applies; sum_divisor is
// unused (kept for signature parity with the other variants).
FLOAT_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
    return FLOAT_TYPE(float(q_sum) * da * dsb.x);
}
113#endif
114
115#if defined(DATA_A_MXFP4)
116// 1-byte loads for mxfp4 blocks (17 bytes)
// Loads 4 bytes (8 nibble codes) and maps each code through the
// kvalues_mxfp4 lookup table, returning two words of four signed int8
// values each (low nibbles first, then high nibbles).
i32vec2 repack(uint ib, uint iqs) {
    const uint32_t qs = pack32(u8vec4(data_a[ib].qs[iqs * 4    ],
                                      data_a[ib].qs[iqs * 4 + 1],
                                      data_a[ib].qs[iqs * 4 + 2],
                                      data_a[ib].qs[iqs * 4 + 3]));

    const u8vec4 i_a0 = unpack8( qs       & 0x0F0F0F0F);
    const u8vec4 i_a1 = unpack8((qs >> 4) & 0x0F0F0F0F);

    return i32vec2(pack32(i8vec4(kvalues_mxfp4[i_a0.x], kvalues_mxfp4[i_a0.y], kvalues_mxfp4[i_a0.z], kvalues_mxfp4[i_a0.w])),
                   pack32(i8vec4(kvalues_mxfp4[i_a1.x], kvalues_mxfp4[i_a1.y], kvalues_mxfp4[i_a1.z], kvalues_mxfp4[i_a1.w])));
}
129
// NOTE(review): the 0.5 factor presumably compensates for the kvalues_mxfp4
// table entries being stored at twice their nominal value — confirm against
// the table definition in types.glsl.
FLOAT_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
    return FLOAT_TYPE(da * dsb.x * float(q_sum) * 0.5);
}
133#endif
134
135#if defined(DATA_A_QUANT_LEGACY) || defined(DATA_A_MXFP4)
// Dot product of one group of A quants against the two cached q8_1 words.
// When QUANT_R == 2 the format interleaves values (low/high nibbles), so a
// single repack yields both 4-value groups; otherwise two consecutive
// repacks are issued.
FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) {
    int32_t q_sum = 0;
#if QUANT_R == 2
    const i32vec2 data_a_qs = repack(ib_a, iqs);
    q_sum += dotPacked4x8EXT(data_a_qs.x,
                             cache_b_qs[0]);
    q_sum += dotPacked4x8EXT(data_a_qs.y,
                             cache_b_qs[1]);
#else
    int32_t data_a_qs = repack(ib_a, iqs * 2);
    q_sum += dotPacked4x8EXT(data_a_qs,
                             cache_b_qs[0]);
    data_a_qs = repack(ib_a, iqs * 2 + 1);
    q_sum += dotPacked4x8EXT(data_a_qs,
                             cache_b_qs[1]);
#endif

    // 2 quants per call => divide sums by 8/2 = 4
    return mul_q8_1(q_sum, get_dm(ib_a), cache_b_ds, 4);
}
156#endif
157
158#if defined(DATA_A_Q2_K)
159// 4-byte loads for Q2_K blocks (84 bytes)
// Extracts sixteen 2-bit Q2_K values as four packed byte-words.
// `ib` indexes 32-value sub-blocks; eight of them share one 256-value
// K-quant block. qs_idx picks the 32-bit word and qs_shift the 2-bit lane
// within each byte.
i32vec4 repack4(uint ib, uint iqs) {
    const uint ib_k = ib / 8;
    const uint iqs_k = (ib % 8) * 8 + iqs;

    const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8);
    const uint qs_shift = ((iqs_k % 32) / 8) * 2;

    return i32vec4((data_a_packed32[ib_k].qs[qs_idx    ] >> qs_shift) & 0x03030303,
                   (data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x03030303,
                   (data_a_packed32[ib_k].qs[qs_idx + 2] >> qs_shift) & 0x03030303,
                   (data_a_packed32[ib_k].qs[qs_idx + 3] >> qs_shift) & 0x03030303);
}
172
// Fetch the Q2_K scale byte covering the four values at `iqs` within
// sub-block `ib` (eight 32-wide sub-blocks per 256-wide K-quant block).
uint8_t get_scale(uint ib, uint iqs) {
    const uint block  = ib / 8;
    const uint offset = (ib % 8) * 8 + iqs;
    return data_a[block].scales[offset / 4];
}
179
// Q2_K dot product. Each scale byte packs a 4-bit multiplier (low nibble)
// and a 4-bit minimum (high nibble). sum_d accumulates q * b terms scaled by
// the multiplier; sum_m accumulates min * sum(b) by dotting the broadcast
// minimum against the raw b values. Final result: d * sum_d - m * sum_m,
// all scaled by the q8_1 delta cache_b_ds.x.
FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) {
    int32_t sum_d = 0;
    int32_t sum_m = 0;

    const i32vec4 qs_a = repack4(ib_a, iqs * 4);
    const uint8_t scale = get_scale(ib_a, iqs * 4);
    const vec2 dm = vec2(get_dm(ib_a));
    const int32_t scale_m = int32_t(scale >> 4) * 0x01010101; // Duplicate 8-bit value across 32-bits.

    sum_d += dotPacked4x8EXT(qs_a.x, cache_b_qs[0]) * (scale & 0xF);
    sum_m += dotPacked4x8EXT(scale_m, cache_b_qs[0]);

    sum_d += dotPacked4x8EXT(qs_a.y, cache_b_qs[1]) * (scale & 0xF);
    sum_m += dotPacked4x8EXT(scale_m, cache_b_qs[1]);

    sum_d += dotPacked4x8EXT(qs_a.z, cache_b_qs[2]) * (scale & 0xF);
    sum_m += dotPacked4x8EXT(scale_m, cache_b_qs[2]);

    sum_d += dotPacked4x8EXT(qs_a.w, cache_b_qs[3]) * (scale & 0xF);
    sum_m += dotPacked4x8EXT(scale_m, cache_b_qs[3]);

    return FLOAT_TYPE(float(cache_b_ds.x) * (float(dm.x) * float(sum_d) - float(dm.y) * float(sum_m)));
}
203#endif
204
205#if defined(DATA_A_Q3_K)
206// 2-byte loads for Q3_K blocks (110 bytes)
// Extracts sixteen 3-bit Q3_K values as four packed int8 words.
// Each value is (2 low bits from qs) + (4 if the hmask bit is set) - 4,
// i.e. an unset hmask bit subtracts 4 from the 2-bit value.
i32vec4 repack4(uint ib, uint iqs) {
    const uint ib_k = ib / 8;
    const uint iqs_k = (ib % 8) * 8 + iqs;

    const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8);
    const uint qs_shift = ((iqs_k % 32) / 8) * 2;
    const uint hm_shift = iqs_k / 8;

    // bitwise OR to add 4 if hmask is set, subtract later
    const i8vec2 vals00 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx  * 2    ] >> qs_shift) & uint16_t(0x0303))) |
                          unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2    ] >> hm_shift) & uint16_t(0x0101)) << 2));
    const i8vec2 vals01 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx  * 2 + 1] >> qs_shift) & uint16_t(0x0303))) |
                          unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 1] >> hm_shift) & uint16_t(0x0101)) << 2));
    const i8vec2 vals10 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx  * 2 + 2] >> qs_shift) & uint16_t(0x0303))) |
                          unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 2] >> hm_shift) & uint16_t(0x0101)) << 2));
    const i8vec2 vals11 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx  * 2 + 3] >> qs_shift) & uint16_t(0x0303))) |
                          unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 3] >> hm_shift) & uint16_t(0x0101)) << 2));
    const i8vec2 vals20 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx  * 2 + 4] >> qs_shift) & uint16_t(0x0303))) |
                          unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 4] >> hm_shift) & uint16_t(0x0101)) << 2));
    const i8vec2 vals21 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx  * 2 + 5] >> qs_shift) & uint16_t(0x0303))) |
                          unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 5] >> hm_shift) & uint16_t(0x0101)) << 2));
    const i8vec2 vals30 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx  * 2 + 6] >> qs_shift) & uint16_t(0x0303))) |
                          unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 6] >> hm_shift) & uint16_t(0x0101)) << 2));
    const i8vec2 vals31 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx  * 2 + 7] >> qs_shift) & uint16_t(0x0303))) |
                          unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 7] >> hm_shift) & uint16_t(0x0101)) << 2));

    return i32vec4(pack32(i8vec4(vals00.x, vals00.y, vals01.x, vals01.y) - int8_t(4)),
                   pack32(i8vec4(vals10.x, vals10.y, vals11.x, vals11.y) - int8_t(4)),
                   pack32(i8vec4(vals20.x, vals20.y, vals21.x, vals21.y) - int8_t(4)),
                   pack32(i8vec4(vals30.x, vals30.y, vals31.x, vals31.y) - int8_t(4)));
}
238
// Q3_K combined scale: unpacks the 6-bit sub-scale (low 4 bits from the
// first scale region, high 2 bits from the second), recenters it by -32,
// and multiplies by the fp16 super-block scale d.
float get_d_scale(uint ib, uint iqs) {
    const uint ib_k = ib / 8;
    const uint iqs_k = (ib % 8) * 8 + iqs;
    const uint is = iqs_k / 4;

    const int8_t scale = int8_t(((data_a[ib_k].scales[is % 8      ] >> (4 * (is / 8))) & 0x0F0F) |
                               (((data_a[ib_k].scales[8 + (is % 4)] >> (2 * (is / 4))) & 0x0303) << 4));
    return float(data_a[ib_k].d) * float(scale - 32);
}
248
// Q3_K dot product: repack4 already recenters the values, so a single
// integer dot scaled by the combined (d * sub-scale) factor suffices.
FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) {
    int32_t q_sum = 0;

    const i32vec4 qs_a = repack4(ib_a, iqs * 4);
    const float d_scale = get_d_scale(ib_a, iqs * 4);

    q_sum += dotPacked4x8EXT(qs_a.x, cache_b_qs[0]);
    q_sum += dotPacked4x8EXT(qs_a.y, cache_b_qs[1]);
    q_sum += dotPacked4x8EXT(qs_a.z, cache_b_qs[2]);
    q_sum += dotPacked4x8EXT(qs_a.w, cache_b_qs[3]);

    return FLOAT_TYPE(float(cache_b_ds.x) * d_scale * float(q_sum));
}
262#endif
263
264#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K)
265// 4-byte loads for Q4_K blocks (144 bytes) and Q5_K blocks (176 bytes)
// Extracts sixteen quant values as four packed byte-words.
// Q4_K: plain 4-bit nibbles selected by qs_idx/qs_shift.
// Q5_K: same nibbles, plus a per-value high bit from qh shifted into bit 4.
i32vec4 repack4(uint ib, uint iqs) {
    const uint ib_k = ib / 8;
    const uint iqs_k = (ib % 8) * 8 + iqs;

    const uint qs_idx = (iqs_k / 16) * 8 + (iqs_k % 8);
    const uint qs_shift = ((iqs_k % 16) / 8) * 4;

#if defined(DATA_A_Q4_K)
    const uint32_t vals0 = (data_a_packed32[ib_k].qs[qs_idx    ] >> qs_shift) & 0x0F0F0F0F;
    const uint32_t vals1 = (data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x0F0F0F0F;
    const uint32_t vals2 = (data_a_packed32[ib_k].qs[qs_idx + 2] >> qs_shift) & 0x0F0F0F0F;
    const uint32_t vals3 = (data_a_packed32[ib_k].qs[qs_idx + 3] >> qs_shift) & 0x0F0F0F0F;

    return i32vec4(vals0, vals1, vals2, vals3);
#else // defined(DATA_A_Q5_K)
    const uint qh_idx = iqs;
    const uint qh_shift = iqs_k / 8;

    return i32vec4(((data_a_packed32[ib_k].qs[qs_idx    ] >> qs_shift) & 0x0F0F0F0F) |
                  (((data_a_packed32[ib_k].qh[qh_idx    ] >> qh_shift) & 0x01010101) << 4),
                   ((data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x0F0F0F0F) |
                  (((data_a_packed32[ib_k].qh[qh_idx + 1] >> qh_shift) & 0x01010101) << 4),
                   ((data_a_packed32[ib_k].qs[qs_idx + 2] >> qs_shift) & 0x0F0F0F0F) |
                  (((data_a_packed32[ib_k].qh[qh_idx + 2] >> qh_shift) & 0x01010101) << 4),
                   ((data_a_packed32[ib_k].qs[qs_idx + 3] >> qs_shift) & 0x0F0F0F0F) |
                  (((data_a_packed32[ib_k].qh[qh_idx + 3] >> qh_shift) & 0x01010101) << 4));
#endif
}
294
// K-quant 6-bit scale/min unpack: the first four sub-blocks store scale and
// min directly in the low 6 bits of scales[is] / scales[is+4]; the last four
// pack the low 4 bits in scales[is+4] and reuse the top 2 bits of the
// earlier bytes. Returns the (d * scale, m * min) pair.
vec2 get_dm_scale(uint ib, uint iqs) {
    const uint ib_k = ib / 8;
    const uint iqs_k = (ib % 8) * 8 + iqs;
    const uint is = iqs_k / 8;
    u8vec2 scale_dm;
    if (is < 4) {
        scale_dm = u8vec2(data_a[ib_k].scales[is] & 0x3F, data_a[ib_k].scales[is + 4] & 0x3F);
    } else {
        scale_dm = u8vec2((data_a[ib_k].scales[is+4] & 0xF) | ((data_a[ib_k].scales[is-4] & 0xC0) >> 2),
                          (data_a[ib_k].scales[is+4] >>  4) | ((data_a[ib_k].scales[is  ] & 0xC0) >> 2));
    }

    return FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm) * FLOAT_TYPE_VEC2(scale_dm);
}
309
// Q4_K / Q5_K dot product: scaled integer dot minus the min contribution.
// NOTE(review): the min term uses cache_b_ds.y / 2 — presumably cache_b_ds.y
// stores a doubled b-sum here; confirm against the q8_1 quantize/cache path.
FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) {
    int32_t q_sum = 0;

    const i32vec4 qs_a = repack4(ib_a, iqs * 4);
    const vec2 dm_scale = get_dm_scale(ib_a, iqs * 4);

    q_sum += dotPacked4x8EXT(qs_a.x, cache_b_qs[0]);
    q_sum += dotPacked4x8EXT(qs_a.y, cache_b_qs[1]);
    q_sum += dotPacked4x8EXT(qs_a.z, cache_b_qs[2]);
    q_sum += dotPacked4x8EXT(qs_a.w, cache_b_qs[3]);

    return FLOAT_TYPE(float(cache_b_ds.x) * float(dm_scale.x) * float(q_sum) - float(dm_scale.y) * float(cache_b_ds.y / 2));
}
323#endif
324
325#if defined(DATA_A_Q6_K)
326// 2-byte loads for Q6_K blocks (210 bytes)
// Extracts sixteen 6-bit Q6_K values as four packed int8 words.
// Each value is (4 low bits from ql) | (2 high bits from qh << 4), then
// recentered by -32 to make it signed.
i32vec4 repack4(uint ib, uint iqs) {
    const uint ib_k = ib / 8;
    const uint iqs_k = (ib % 8) * 8 + iqs;

    const uint ql_idx = (iqs_k / 32) * 16 + iqs_k % 16;
    const uint ql_shift = ((iqs_k % 32) / 16) * 4;

    const uint qh_idx = (iqs_k / 32) * 8 + iqs;
    const uint qh_shift = ((iqs_k % 32) / 8) * 2;

    const i8vec2 vals00 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2    ] >> ql_shift) & uint16_t(0x0F0F))) |
                          unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2    ] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
    const i8vec2 vals01 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 1] >> ql_shift) & uint16_t(0x0F0F))) |
                          unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 1] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
    const i8vec2 vals10 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 2] >> ql_shift) & uint16_t(0x0F0F))) |
                          unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 2] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
    const i8vec2 vals11 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 3] >> ql_shift) & uint16_t(0x0F0F))) |
                          unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 3] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
    const i8vec2 vals20 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 4] >> ql_shift) & uint16_t(0x0F0F))) |
                          unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 4] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
    const i8vec2 vals21 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 5] >> ql_shift) & uint16_t(0x0F0F))) |
                          unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 5] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
    const i8vec2 vals30 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 6] >> ql_shift) & uint16_t(0x0F0F))) |
                          unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 6] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
    const i8vec2 vals31 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 7] >> ql_shift) & uint16_t(0x0F0F))) |
                          unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 7] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);

    return i32vec4(pack32(i8vec4(vals00.x, vals00.y, vals01.x, vals01.y)),
                   pack32(i8vec4(vals10.x, vals10.y, vals11.x, vals11.y)),
                   pack32(i8vec4(vals20.x, vals20.y, vals21.x, vals21.y)),
                   pack32(i8vec4(vals30.x, vals30.y, vals31.x, vals31.y)));
}
359
// Q6_K combined scale: the fp16 super-block scale times the signed int8
// sub-scale covering this group of four values.
float get_d_scale(uint ib, uint iqs) {
    const uint block  = ib / 8;
    const uint offset = (ib % 8) * 8 + iqs;
    return float(data_a[block].d) * float(data_a[block].scales[offset / 4]);
}
365
// Q6_K dot product: values come out of repack4 already signed/recentered,
// so a single integer dot scaled by (d * sub-scale * d_b) suffices.
FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) {
    int32_t q_sum = 0;

    const i32vec4 qs_a = repack4(ib_a, iqs * 4);
    const float d_scale = get_d_scale(ib_a, iqs * 4);

    q_sum += dotPacked4x8EXT(qs_a.x, cache_b_qs[0]);
    q_sum += dotPacked4x8EXT(qs_a.y, cache_b_qs[1]);
    q_sum += dotPacked4x8EXT(qs_a.z, cache_b_qs[2]);
    q_sum += dotPacked4x8EXT(qs_a.w, cache_b_qs[3]);

    return FLOAT_TYPE(float(cache_b_ds.x) * float(d_scale) * float(q_sum));
}
379#endif
380
381#if defined(DATA_A_IQ1_S)
// Decodes 32 IQ1_S values for the 32-value group containing `iqs`.
// Each of the 4 codebook indices combines an 8-bit qs byte with 3 extra
// bits from qh; the looked-up grid word packs 8 values as nibbles, which
// are split into two packed byte-words apiece (8 words total in out0/out1).
void repack8(uint ib, uint iqs, out i32vec4 out0, out i32vec4 out1) {
    const uint ib32 = iqs / 32;

    const uint qh = data_a[ib].qh[ib32];

    const uint qs16_0 = data_a_packed16[ib].qs[(4 * ib32 + 0) / 2];
    const uint qs16_1 = data_a_packed16[ib].qs[(4 * ib32 + 2) / 2];

    const uint qs0 = qs16_0 & 0xFF;
    const uint qs1 = qs16_0 >> 8;
    const uint qs2 = qs16_1 & 0xFF;
    const uint qs3 = qs16_1 >> 8;

    // Three qh bits per codebook entry, stored consecutively from bit 0.
    const uint hi0 = bitfieldExtract(qh, 3 * int(0), 3);
    const uint hi1 = bitfieldExtract(qh, 3 * int(1), 3);
    const uint hi2 = bitfieldExtract(qh, 3 * int(2), 3);
    const uint hi3 = bitfieldExtract(qh, 3 * int(3), 3);

    const int32_t grid0 = int32_t(iq1s_grid_gpu[qs0 | (hi0 << 8)]);
    const int32_t grid1 = int32_t(iq1s_grid_gpu[qs1 | (hi1 << 8)]);
    const int32_t grid2 = int32_t(iq1s_grid_gpu[qs2 | (hi2 << 8)]);
    const int32_t grid3 = int32_t(iq1s_grid_gpu[qs3 | (hi3 << 8)]);

    out0 = i32vec4((grid0 >> 0) & 0x0F0F0F0F,
                   (grid0 >> 4) & 0x0F0F0F0F,
                   (grid1 >> 0) & 0x0F0F0F0F,
                   (grid1 >> 4) & 0x0F0F0F0F);
    out1 = i32vec4((grid2 >> 0) & 0x0F0F0F0F,
                   (grid2 >> 4) & 0x0F0F0F0F,
                   (grid3 >> 0) & 0x0F0F0F0F,
                   (grid3 >> 4) & 0x0F0F0F0F);
}
414
// Returns (dl, dl * (delta - 1)) for the 32-value group containing `iqs`:
// dl is the block scale d times the odd-valued 3-bit sub-scale held in
// qh bits 12..14; qh bit 15 selects the sign of the IQ1_S delta.
vec2 get_dm(uint ib, uint iqs) {
    const uint ib32 = iqs / 32;

    const uint qh = data_a[ib].qh[ib32];
    const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;

    const float d = float(data_a[ib].d);
    const float dl = d * float(2 * bitfieldExtract(qh, 12, 3) + 1);

    // the -1 cancels out the bias in iq1s_grid_gpu
    return FLOAT_TYPE_VEC2(dl, dl * (delta - 1));
}
427
// IQ1_S dot product: decode 32 values via repack8 and dot them against the
// 8 cached q8_1 words. The dm.y * cache_b_ds.y term applies the constant
// per-group delta (cache_b_ds.y is presumably the sum of the b values —
// NOTE(review): confirm against the q8_1 cache path).
FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) {
    int32_t q_sum = 0;

    const uint ib_k = ib_a / 8;
    const uint iqs_k = (ib_a % 8) * 32 + iqs * 32;

    i32vec4 qs_a0;
    i32vec4 qs_a1;
    repack8(ib_k, iqs_k, qs_a0, qs_a1);

    const vec2 dm = get_dm(ib_k, iqs_k);

    q_sum += dotPacked4x8EXT(qs_a0.x, cache_b_qs[0]);
    q_sum += dotPacked4x8EXT(qs_a0.y, cache_b_qs[1]);
    q_sum += dotPacked4x8EXT(qs_a0.z, cache_b_qs[2]);
    q_sum += dotPacked4x8EXT(qs_a0.w, cache_b_qs[3]);
    q_sum += dotPacked4x8EXT(qs_a1.x, cache_b_qs[4]);
    q_sum += dotPacked4x8EXT(qs_a1.y, cache_b_qs[5]);
    q_sum += dotPacked4x8EXT(qs_a1.z, cache_b_qs[6]);
    q_sum += dotPacked4x8EXT(qs_a1.w, cache_b_qs[7]);

    return FLOAT_TYPE(float(cache_b_ds.x) * float(dm.x) * float(q_sum) + float(dm.y) * float(cache_b_ds.y));
}
451#endif
452
453#if defined(DATA_A_IQ1_M)
// IQ1_M dot product: computes one 32-value group of A against the 8 cached
// q8_1 words. Four codebook lookups per group, each contributing 8 values.
FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) {
    const uint ib_k = ib_a / 8;
    const uint iqs_k = (ib_a % 8) * 32 + iqs * 32;

    const uint ib32 = iqs_k / 32;
    const uint ib64 = ib32 / 2;

    // The block-wide fp16 scale is scattered across the top 4 bits of the
    // four 16-bit scale words; gather and reassemble it.
    const uint16_t[4] scales = data_a[ib_k].scales;
    const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12;
    const float d = float(unpackHalf2x16(s.x | (s.y << 4) | (s.z << 8) | (s.w << 12)).x);

    const uint qs32 = data_a_packed32[ib_k].qs[ib32];
    const uint qh16 = data_a_packed16[ib_k].qh[ib32];

    float sum = 0;
    const uint sc = data_a[ib_k].scales[ib64];
    [[unroll]] for (int l = 0; l < 4; ++l) {
        const uint ib16 = 2 * ib32 + l / 2;
        // 3-bit sub-scale mapped to the odd values 1,3,...,15.
        const float dl = d * (2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1);
        const uint qh = qh16 >> (4 * l);
        const uint qs = (qs32 >> (8 * l)) & 0xFF;
        // Bit 3 of the qh nibble selects the sign of the per-group delta.
        const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA;

        // 8 codebook values packed as nibbles in one 32-bit grid entry.
        const int32_t grid = int32_t(iq1s_grid_gpu[qs | ((qh & 7) << 8)]);

        int32_t q_sum = 0;
        q_sum += dotPacked4x8EXT((grid >> 0) & 0x0F0F0F0F, cache_b_qs[2 * l + 0]);
        q_sum += dotPacked4x8EXT((grid >> 4) & 0x0F0F0F0F, cache_b_qs[2 * l + 1]);

        // Sum of the 8 b values, used to apply the constant delta term.
        int32_t y_sum = 0;
        y_sum += dotPacked4x8EXT(int(0x01010101), cache_b_qs[2 * l + 0]);
        y_sum += dotPacked4x8EXT(int(0x01010101), cache_b_qs[2 * l + 1]);

        // the -1 cancels out the bias in iq1s_grid_gpu
        sum += dl * (q_sum + y_sum * (delta - 1));
    }
    sum *= float(cache_b_ds.x);

    // Explicit conversion: FLOAT_TYPE may be float16_t, to which a plain
    // float does not implicitly convert. Matches every other variant's
    // `return FLOAT_TYPE(...)`.
    return FLOAT_TYPE(sum);
}
494#endif