1#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
  2#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
  3#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
  4
  5#include "types.glsl"
  6
  7// Each iqs value maps to a 32-bit integer
  8
  9#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1)
 10// 2-byte loads for Q4_0 blocks (18 bytes)
 11// 4-byte loads for Q4_1 blocks (20 bytes)
 12void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
 13#ifdef DATA_A_Q4_0
 14    buf_a[buf_ib].qs[iqs] = pack32(u16vec2(data_a_packed16[ib].qs[iqs * 2],
 15                                           data_a_packed16[ib].qs[iqs * 2 + 1]));
 16
 17    if (iqs == 0) {
 18        buf_a[buf_ib].dm = FLOAT_TYPE(data_a_packed16[ib].d);
 19    }
 20#else // DATA_A_Q4_1
 21    buf_a[buf_ib].qs[iqs] = data_a_packed32[ib].qs[iqs];
 22
 23    if (iqs == 0) {
 24        buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
 25    }
 26#endif
 27}
 28
 29void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
 30    cache_a[reg_ib].dm = buf_a[buf_ib].dm;
 31
 32    [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
 33        cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
 34    }
 35}
 36
 37ACC_TYPE mmq_dot_product(const uint ib_a) {
 38    int32_t q_sum = 0;
 39    [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
 40        const uint32_t vui = cache_a[ib_a].qs[iqs];
 41        const i32vec2 qs_a = i32vec2( vui       & 0x0F0F0F0F,
 42                                     (vui >> 4) & 0x0F0F0F0F);
 43
 44        const int32_t qs_b0 = cache_b.qs[iqs];
 45        const int32_t qs_b1 = cache_b.qs[iqs + 4];
 46
 47        q_sum += dotPacked4x8EXT(qs_a.x, qs_b0);
 48        q_sum += dotPacked4x8EXT(qs_a.y, qs_b1);
 49    }
 50
 51#ifdef DATA_A_Q4_0
 52    return ACC_TYPE(float(cache_a[ib_a].dm) * (float(q_sum) * float(cache_b.ds.x) - 8.0 * float(cache_b.ds.y)));
 53#else // DATA_A_Q4_1
 54    return ACC_TYPE(float(q_sum) * float(cache_a[ib_a].dm.x) * float(cache_b.ds.x) + float(cache_a[ib_a].dm.y) * float(cache_b.ds.y));
 55#endif
 56}
 57#endif
 58
 59#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
 60// 2-byte loads for Q5_0 blocks (22 bytes)
 61// 4-byte loads for Q5_1 blocks (24 bytes)
 62void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
 63#ifdef DATA_A_Q5_0
 64    buf_a[buf_ib].qs[iqs] = pack32(u16vec2(data_a_packed16[ib].qs[iqs * 2],
 65                                           data_a_packed16[ib].qs[iqs * 2 + 1]));
 66
 67    if (iqs == 0) {
 68        buf_a[buf_ib].dm = FLOAT_TYPE(data_a_packed16[ib].d);
 69        buf_a[buf_ib].qh = pack32(u16vec2(data_a_packed16[ib].qh[0], data_a_packed16[ib].qh[1]));
 70    }
 71#else // DATA_A_Q5_1
 72    buf_a[buf_ib].qs[iqs] = data_a_packed32[ib].qs[iqs];
 73
 74    if (iqs == 0) {
 75        buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
 76        buf_a[buf_ib].qh = data_a_packed32[ib].qh;
 77    }
 78#endif
 79}
 80
 81void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
 82    cache_a[reg_ib].dm = buf_a[buf_ib].dm;
 83    cache_a[reg_ib].qh = buf_a[buf_ib].qh;
 84
 85    [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
 86        cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
 87    }
 88}
 89
 90ACC_TYPE mmq_dot_product(const uint ib_a) {
 91    int32_t q_sum = 0;
 92    [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
 93        const uint32_t vui = cache_a[ib_a].qs[iqs];
 94        const int32_t qh = int32_t(cache_a[ib_a].qh >> (4 * iqs));
 95        const int32_t qs_a0 = int32_t(vui & 0x0F0F0F0F)
 96                         | ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)
 97        const int32_t qs_a1 = int32_t((vui >> 4) & 0x0F0F0F0F)
 98                         | (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28)
 99
100        const int32_t qs_b0 = cache_b.qs[iqs];
101        const int32_t qs_b1 = cache_b.qs[iqs + 4];
102
103        q_sum += dotPacked4x8EXT(qs_a0, qs_b0);
104        q_sum += dotPacked4x8EXT(qs_a1, qs_b1);
105    }
106
107#ifdef DATA_A_Q5_0
108    return ACC_TYPE(float(cache_a[ib_a].dm) * (float(q_sum) * float(cache_b.ds.x) - 16.0 * float(cache_b.ds.y)));
109#else // DATA_A_Q5_1
110    return ACC_TYPE(float(q_sum) * float(cache_a[ib_a].dm.x) * float(cache_b.ds.x) + float(cache_a[ib_a].dm.y) * float(cache_b.ds.y));
111#endif
112}
113#endif
114
115#if defined(DATA_A_Q8_0)
// Q8_0 blocks are 34 bytes, so they are read with 2-byte loads.
//
// Copies 32-bit word `iqs` (four 8-bit quants) of A block `ib` into buf_a[buf_ib];
// the invocation with iqs == 0 additionally stores the block scale d.
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
    const uint lo_half = 2 * iqs;
    const i16vec2 halves = i16vec2(data_a_packed16[ib].qs[lo_half],
                                   data_a_packed16[ib].qs[lo_half + 1]);
    buf_a[buf_ib].qs[iqs] = pack32(halves);

    if (iqs != 0) {
        return;
    }
    buf_a[buf_ib].dm = FLOAT_TYPE(data_a_packed16[ib].d);
}
125
// Copies the A block cached in buf_a[buf_ib] into cache_a[reg_ib]:
// all eight quant words, then the scale.
void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
    [[unroll]] for (uint w = 0; w < 8; ++w) {
        cache_a[reg_ib].qs[w] = buf_a[buf_ib].qs[w];
    }
    cache_a[reg_ib].dm = buf_a[buf_ib].dm;
}
133
// Dot product of cached A block `ib_a` with the cached B block: the packed 8-bit
// products of all eight word pairs, scaled by d_a * ds.x.
ACC_TYPE mmq_dot_product(const uint ib_a) {
    int32_t acc = 0;
    [[unroll]] for (uint w = 0; w < 8; ++w) {
        acc += dotPacked4x8EXT(int32_t(cache_a[ib_a].qs[w]), int32_t(cache_b.qs[w]));
    }

    return ACC_TYPE(float(acc) * float(cache_a[ib_a].dm) * float(cache_b.ds.x));
}
145#endif
146
147#if defined(DATA_A_MXFP4)
// MXFP4 blocks are 17 bytes, so they are read with 1-byte loads.
//
// Reads four bytes (eight 4-bit codes) of A block `ib` and maps each code through
// kvalues_mxfp4 to an 8-bit value. Low-nibble results go to word `iqs` and
// high-nibble results to word `iqs + 4` of buf_a[buf_ib].
// The invocation with iqs == 0 additionally stores the block scale multiplied by 0.5.
// NOTE(review): presumably the 0.5 compensates for kvalues_mxfp4 holding doubled values -- confirm.
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
    const uint first = 4 * iqs;
    const uint32_t codes = pack32(u8vec4(data_a[ib].qs[first    ],
                                         data_a[ib].qs[first + 1],
                                         data_a[ib].qs[first + 2],
                                         data_a[ib].qs[first + 3]));

    const u8vec4 lo = unpack8( codes       & 0x0F0F0F0F);
    const u8vec4 hi = unpack8((codes >> 4) & 0x0F0F0F0F);

    const i8vec4 lo_vals = i8vec4(kvalues_mxfp4[lo.x], kvalues_mxfp4[lo.y], kvalues_mxfp4[lo.z], kvalues_mxfp4[lo.w]);
    const i8vec4 hi_vals = i8vec4(kvalues_mxfp4[hi.x], kvalues_mxfp4[hi.y], kvalues_mxfp4[hi.z], kvalues_mxfp4[hi.w]);

    buf_a[buf_ib].qs[iqs    ] = pack32(lo_vals);
    buf_a[buf_ib].qs[iqs + 4] = pack32(hi_vals);

    if (iqs != 0) {
        return;
    }
    buf_a[buf_ib].d = FLOAT_TYPE(e8m0_to_fp32(data_a[ib].e) * 0.5);
}
165
// Copies the A block cached in buf_a[buf_ib] into cache_a[reg_ib]:
// all eight decoded quant words, then the scale d.
void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
    [[unroll]] for (uint w = 0; w < 8; ++w) {
        cache_a[reg_ib].qs[w] = buf_a[buf_ib].qs[w];
    }
    cache_a[reg_ib].d = buf_a[buf_ib].d;
}
173
// Dot product of cached A block `ib_a` with the cached B block: the packed 8-bit
// products of all eight word pairs, scaled by d_a * ds.x.
ACC_TYPE mmq_dot_product(const uint ib_a) {
    int32_t acc = 0;
    [[unroll]] for (uint w = 0; w < 8; ++w) {
        acc += dotPacked4x8EXT(int32_t(cache_a[ib_a].qs[w]), cache_b.qs[w]);
    }

    return ACC_TYPE(float(cache_a[ib_a].d) * float(cache_b.ds.x) * float(acc));
}
184#endif
185
// For k-quants, ib and iqs still assume 32-wide blocks even though k-quant super-blocks are 256 wide
// (so ib / 8 selects the super-block and ib % 8 the 32-wide slice inside it).
// iqs still refers to a 32-bit integer, i.e. 0..7 for a 32-wide block.
188#if defined(DATA_A_Q2_K)
// Q2_K blocks are 84 bytes, so they are read with 4-byte loads.
//
// Builds word `iqs` of buf_a[buf_ib] from super-block ib / 8.
// - It reads four consecutive source words starting at `src_word`.
// - From each source word it keeps the 2-bit field selected by `bit_shift`.
// - The field from source word k is placed at bits 2k..2k+1 of every byte.
// The invocation with iqs == 0 additionally stores dm and the two 8-bit scale
// bytes of this 32-quant slice.
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
    const uint ib_k  = ib / 8;
    const uint iqs_k = (ib % 8) * 8 + iqs * QUANT_R_MMQ;

    const uint src_word  = (iqs_k / 32) * 8 + (iqs_k % 8);
    const uint bit_shift = ((iqs_k % 32) / 8) * 2;

    uint32_t merged = 0u;
    [[unroll]] for (uint k = 0; k < 4; ++k) {
        merged |= uint32_t((data_a_packed32[ib_k].qs[src_word + k] >> bit_shift) & 0x03030303) << (2 * k);
    }
    buf_a[buf_ib].qs[iqs] = merged;

    if (iqs != 0) {
        return;
    }
    buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm);
    // vec4 unpack8 used due to #12147
    buf_a[buf_ib].scales = unpack8(uint32_t(data_a_packed16[ib_k].scales[iqs_k / 8])).xy;
}
210
// Copies the A block cached in buf_a[buf_ib] into cache_a[reg_ib]:
// the two packed quant words, dm and the two scale bytes.
void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
    [[unroll]] for (uint w = 0; w < 2; ++w) {
        cache_a[reg_ib].qs[w] = buf_a[buf_ib].qs[w];
    }
    cache_a[reg_ib].scales = buf_a[buf_ib].scales;
    cache_a[reg_ib].dm = buf_a[buf_ib].dm;
}
219
// Dot product of cached A block `ib_a` with the cached B block, returned as ACC_TYPE.
// Each of the two 16-quant halves h uses scale byte h:
// - the low nibble weights sum(q_a * q_b);
// - the high nibble (the min) weights sum(q_b).
// Result: ds.x * (dm.x * sum_d - dm.y * sum_m).
ACC_TYPE mmq_dot_product(const uint ib_a) {
    int32_t sum_d = 0;
    int32_t sum_m = 0;

    [[unroll]] for (uint h = 0; h < 2; ++h) {
        const uint8_t scale = cache_a[ib_a].scales[h];
        // Broadcast the 4-bit min into all four bytes so one packed dot product yields min * sum(q_b).
        const int32_t scale_m = int32_t(scale >> 4) * 0x01010101;

        [[unroll]] for (uint k = 0; k < 4; ++k) {
            const uint w = 4 * h + k;
            const int32_t qs_a = int32_t((cache_a[ib_a].qs[h] >> (k * 2)) & 0x03030303);

            sum_d += dotPacked4x8EXT(qs_a, cache_b.qs[w]) * (scale & 0xF);
            sum_m += dotPacked4x8EXT(scale_m, cache_b.qs[w]);
        }
    }

    return ACC_TYPE(float(cache_b.ds.x) * (float(cache_a[ib_a].dm.x) * float(sum_d) - float(cache_a[ib_a].dm.y) * float(sum_m)));
}
235#endif
236
237#if defined(DATA_A_Q3_K)
// Q3_K blocks are 110 bytes, so they are read with 2-byte loads.
//
// Builds word `iqs` of buf_a[buf_ib] from super-block ib / 8.
// - Four consecutive 16-bit source words are read.
// - In each, the 2-bit low part (qs) is OR-ed with the hmask bit moved up to bit 2,
//   giving values in [0, 7].
// - The 3rd bit is added instead of subtracted so the quants stay non-negative.
//   That lets two 4-bit groups be packed per byte; mmq_dot_product subtracts 4 later.
// The invocation with iqs == 0 additionally stores d * (scale - 32) for both
// 16-quant halves of this slice.
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
    const uint ib_k   = ib / 8;
    const uint hm_idx = iqs * QUANT_R_MMQ;
    const uint iqs_k  = (ib % 8) * 8 + hm_idx;

    const uint qs_idx   = (iqs_k / 32) * 8 + (iqs_k % 8);
    const uint qs_shift = ((iqs_k % 32) / 8) * 2;
    const uint hm_shift = iqs_k / 8;

    // vec4 for unpack8 used due to #12147
    i8vec2 q3[4];
    [[unroll]] for (uint k = 0; k < 4; ++k) {
        const i8vec2 low2  = unpack8(int32_t(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + k] >> qs_shift) & uint16_t(0x0303)))).xy;
        const i8vec2 high1 = unpack8(int32_t(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + k] >> hm_shift) & uint16_t(0x0101))) << 2)).xy;
        q3[k] = low2 | high1;
    }

    // Source words 0/1 fill the low nibbles, source words 2/3 the high nibbles.
    const uint32_t nib_lo = pack32(u8vec4(q3[0].x, q3[0].y, q3[1].x, q3[1].y));
    const uint32_t nib_hi = pack32(u8vec4(q3[2].x, q3[2].y, q3[3].x, q3[3].y));
    buf_a[buf_ib].qs[iqs] = nib_lo | (nib_hi << 4);

    if (iqs != 0) {
        return;
    }

    // Index of the first of the two 16-quant sub-blocks covered by this 32-quant slice.
    const uint sc = iqs_k / 4;
    // 6-bit scales are rebuilt as follows (vec4 used due to #12147):
    // - low 4 bits come from scale bytes 0..7;
    // - high 2 bits come from scale bytes 8..11.
    const i8vec2 scales = i8vec2(unpack8(uint32_t(((data_a_packed16[ib_k].scales[(sc % 8      ) / 2] >> (4 * (sc / 8))) & 0x0F0F) |
                                                 (((data_a_packed16[ib_k].scales[(8 + (sc % 4)) / 2] >> (2 * (sc / 4))) & 0x0303) << 4))).xy);

    buf_a[buf_ib].d_scales = FLOAT_TYPE_VEC2(float(data_a_packed16[ib_k].d) * vec2(scales - 32));
}
270
// Copies the A block cached in buf_a[buf_ib] into cache_a[reg_ib]:
// the four packed quant words, then the two combined d * scale factors.
void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
    [[unroll]] for (uint w = 0; w < 4; ++w) {
        cache_a[reg_ib].qs[w] = buf_a[buf_ib].qs[w];
    }
    cache_a[reg_ib].d_scales = buf_a[buf_ib].d_scales;
}
278
// Dot product of cached A block `ib_a` with the cached B block, returned as ACC_TYPE.
// Quant group w lives in cached word w / 2:
// - even w uses the low nibbles, odd w the high nibbles;
// - subtracting 4 corrects the 3rd-bit offset added when the block was cached.
// Each 16-quant half h is weighted by d_scales[h], and the total is scaled by ds.x.
ACC_TYPE mmq_dot_product(const uint ib_a) {
    float result = 0.0;

    [[unroll]] for (uint h = 0; h < 2; ++h) {
        int32_t partial = 0;
        [[unroll]] for (uint k = 0; k < 4; ++k) {
            const uint w = 4 * h + k;
            const int32_t qs_a = pack32(unpack8(int32_t((cache_a[ib_a].qs[w / 2] >> ((w % 2) * 4)) & 0x0F0F0F0F)) - int8_t(4));

            partial += dotPacked4x8EXT(qs_a, cache_b.qs[w]);
        }
        result += float(cache_a[ib_a].d_scales[h]) * float(partial);
    }

    return ACC_TYPE(float(cache_b.ds.x) * result);
}
301#endif
302
303#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K)
// Q4_K blocks are 144 bytes and Q5_K blocks 176 bytes; both are read with 4-byte loads.
//
// Q4_K: word `iqs` of buf_a[buf_ib] holds the 4-bit quants of two consecutive source
//       words (the first in the low nibbles, the second in the high nibbles).
// Q5_K: word `iqs` holds one source word's 4-bit quants plus the matching qh bit
//       moved to bit 4 of each byte.
// The invocation with iqs == 0 additionally stores (dm.x * scale, dm.y * min) for
// this 32-quant slice, decoded from the 6-bit packed scale/min bytes.
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
    const uint ib_k  = ib / 8;
    const uint iqs_k = (ib % 8) * 8 + iqs * QUANT_R_MMQ;

    const uint src_word  = (iqs_k / 16) * 8 + (iqs_k % 8);
    const uint nib_shift = ((iqs_k % 16) / 8) * 4;

#if defined(DATA_A_Q4_K)
    const uint32_t first  = (data_a_packed32[ib_k].qs[src_word    ] >> nib_shift) & 0x0F0F0F0F;
    const uint32_t second = (data_a_packed32[ib_k].qs[src_word + 1] >> nib_shift) & 0x0F0F0F0F;

    buf_a[buf_ib].qs[iqs] = (second << 4) | first;
#else // defined(DATA_A_Q5_K)
    const uint qh_idx   = iqs * QUANT_R_MMQ;
    const uint qh_shift = iqs_k / 8;

    const uint32_t low4 = (data_a_packed32[ib_k].qs[src_word] >> nib_shift) & 0x0F0F0F0F;
    const uint32_t bit5 = ((data_a_packed32[ib_k].qh[qh_idx] >> qh_shift) & 0x01010101) << 4;

    buf_a[buf_ib].qs[iqs] = int32_t(low4 | bit5);
#endif

    if (iqs != 0) {
        return;
    }

    // 32-quant sub-block index within the super-block.
    const uint sub = iqs_k / 8;
    u8vec2 scale_dm;
    if (sub >= 4) {
        // Sub-blocks 4..7 are split across two places:
        // - low 4 bits come from bytes 8..11;
        // - top 2 bits come from bits 6..7 of bytes 0..7.
        scale_dm = u8vec2((data_a[ib_k].scales[sub+4] & 0xF) | ((data_a[ib_k].scales[sub-4] & 0xC0) >> 2),
                          (data_a[ib_k].scales[sub+4] >>  4) | ((data_a[ib_k].scales[sub  ] & 0xC0) >> 2));
    } else {
        // Sub-blocks 0..3 are stored directly as 6-bit values:
        // - the scale in bytes 0..3;
        // - the min in bytes 4..7.
        scale_dm = u8vec2(data_a[ib_k].scales[sub] & 0x3F, data_a[ib_k].scales[sub + 4] & 0x3F);
    }

    buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(vec2(data_a_packed32[ib_k].dm) * vec2(scale_dm));
}
340
// Copies the A block cached in buf_a[buf_ib] into cache_a[reg_ib]:
// 8 / QUANT_R_MMQ packed quant words, then the (scale, min) pair.
void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
    [[unroll]] for (uint w = 0; w < 8 / QUANT_R_MMQ; ++w) {
        cache_a[reg_ib].qs[w] = buf_a[buf_ib].qs[w];
    }
    cache_a[reg_ib].dm = buf_a[buf_ib].dm;
}
348
// Dot product of cached A block `ib_a` with the cached B block, returned as ACC_TYPE.
// Q4_K stores two quant groups per cached word:
// - even groups are in the low nibbles;
// - odd groups are in the high nibbles.
// Q5_K stores one group per word.
// Result: ds.x * dm.x * sum(q_a * q_b) - dm.y * ds.y.
ACC_TYPE mmq_dot_product(const uint ib_a) {
    int32_t acc = 0;

    [[unroll]] for (uint w = 0; w < 8; ++w) {
#if defined(DATA_A_Q4_K)
        const int32_t a_word = int32_t((cache_a[ib_a].qs[w / 2] >> (4 * (w % 2))) & 0x0F0F0F0F);
#else // defined(DATA_A_Q5_K)
        const int32_t a_word = cache_a[ib_a].qs[w];
#endif
        acc += dotPacked4x8EXT(a_word, cache_b.qs[w]);
    }

    return ACC_TYPE(float(cache_b.ds.x) * float(cache_a[ib_a].dm.x) * float(acc) - float(cache_a[ib_a].dm.y) * float(cache_b.ds.y));
}
364#endif
365
366#if defined(DATA_A_Q6_K)
// Q6_K blocks are 210 bytes, so they are read with 2-byte loads.
//
// Word `iqs` of buf_a[buf_ib] receives four 6-bit quants of super-block ib / 8.
// - Each quant is assembled from a 4-bit ql field and a 2-bit qh field (placed at bits 4..5).
// - 32 is then subtracted, so the quants are stored as signed 8-bit values.
// The invocation with iqs == 0 additionally stores d * scale for both 16-quant
// halves of this slice.
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
    const uint ib_k  = ib / 8;
    const uint iqs_k = (ib % 8) * 8 + iqs;

    const uint ql_idx   = (iqs_k / 32) * 16 + iqs_k % 16;
    const uint ql_shift = ((iqs_k % 32) / 16) * 4;

    const uint qh_idx   = (iqs_k / 32) * 8 + iqs;
    const uint qh_shift = ((iqs_k % 32) / 8) * 2;

    i8vec2 q6[2];
    [[unroll]] for (uint k = 0; k < 2; ++k) {
        const i8vec2 low4  = unpack8(int32_t((data_a_packed16[ib_k].ql[ql_idx * 2 + k] >> ql_shift) & uint16_t(0x0F0F))).xy;
        const i8vec2 high2 = unpack8(int32_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + k] >> qh_shift) & uint16_t(0x0303)) << 4)).xy;
        q6[k] = (low4 | high2) - int8_t(32);
    }
    buf_a[buf_ib].qs[iqs] = pack32(i8vec4(q6[0].x, q6[0].y, q6[1].x, q6[1].y));

    if (iqs != 0) {
        return;
    }

    // Index of the first of the two 16-quant sub-blocks covered by this 32-quant slice.
    const uint sc = iqs_k / 4;
    const i8vec2 scales = unpack8(int32_t(data_a_packed16[ib_k].scales[sc / 2])).xy;

    buf_a[buf_ib].d_scales = FLOAT_TYPE_VEC2(float(data_a_packed16[ib_k].d) * vec2(scales));
}
391
// Copies the A block cached in buf_a[buf_ib] into cache_a[reg_ib]:
// all eight quant words, then the two combined d * scale factors.
void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
    [[unroll]] for (uint w = 0; w < 8; ++w) {
        cache_a[reg_ib].qs[w] = buf_a[buf_ib].qs[w];
    }
    cache_a[reg_ib].d_scales = buf_a[buf_ib].d_scales;
}
399
// Dot product of cached A block `ib_a` with the cached B block, returned as ACC_TYPE.
// Words 4h..4h+3 form 16-quant half h, which is weighted by d_scales[h]; the total is
// scaled by ds.x.
ACC_TYPE mmq_dot_product(const uint ib_a) {
    float result = 0.0;

    [[unroll]] for (uint h = 0; h < 2; ++h) {
        int32_t partial = 0;
        [[unroll]] for (uint k = 0; k < 4; ++k) {
            const uint w = 4 * h + k;
            partial += dotPacked4x8EXT(int32_t(cache_a[ib_a].qs[w]), cache_b.qs[w]);
        }
        result += float(cache_a[ib_a].d_scales[h]) * float(partial);
    }

    return ACC_TYPE(float(cache_b.ds.x) * result);
}
421#endif
422
// Stores 16-byte chunk `iqs` (four 32-bit words of packed 8-bit quants) of B block `ib`
// into buf_b[buf_ib]. The invocation with iqs == 0 additionally stores the block's ds pair.
// B blocks are grouped four per data_b entry:
// - ib / 4 selects the entry;
// - ib % 4 selects the block inside it.
// Out-of-bounds blocks are zero-filled, so their dot products and scales are zero.
void block_b_to_shmem(const uint buf_ib, const uint ib, const uint iqs, const bool is_in_bounds) {
    ivec4 words = ivec4(0);
    FLOAT_TYPE_VEC2 ds = FLOAT_TYPE_VEC2(0.0f);

    if (is_in_bounds) {
        const uint ib_outer = ib / 4;
        const uint ib_inner = ib % 4;

        words = data_b[ib_outer].qs[ib_inner * 2 + iqs];
        if (iqs == 0) {
            ds = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]);
        }
    }

    if (iqs == 0) {
        buf_b[buf_ib].ds = ds;
    }

    const uint dst = iqs * 4;
    [[unroll]] for (uint c = 0; c < 4; ++c) {
        buf_b[buf_ib].qs[dst + c] = words[c];
    }
}
448
// Copies the B block cached in buf_b[ib] into cache_b: BK / 4 quant words and the ds pair.
void block_b_to_registers(const uint ib) {
    [[unroll]] for (uint w = 0; w < BK / 4; ++w) {
        cache_b.qs[w] = buf_b[ib].qs[w];
    }
    cache_b.ds = buf_b[ib].ds;
}