1#define(VARIANTS)
  2
  3[
  4  {
  5    "SHADER_SUFFIX": "f32_vec",
  6    "REPLS": {
  7      "TYPE" : "vec4<f32>",
  8      "DST_TYPE": "vec4<f32>",
  9      "BLOCK_SIZE": 4
 10    },
 11    "DECLS": ["F32_VEC"]
 12  },
 13  {
 14    "REPLS": {
 15      "TYPE" : "f32",
 16      "DST_TYPE": "f32",
 17      "BLOCK_SIZE": 1
 18    },
 19    "DECLS": ["F32"]
 20  },
 21  {
 22    "REPLS": {
 23      "TYPE" : "f16",
 24      "DST_TYPE": "f32",
 25      "BLOCK_SIZE": 1
 26    },
 27    "DECLS": ["F16"]
 28  },
 29  {
 30    "REPLS": {
 31      "TYPE" : "i32",
 32      "DST_TYPE": "i32",
 33      "BLOCK_SIZE": 1
 34    },
 35    "DECLS": ["I32"]
 36  },
 37  {
 38    "REPLS": {
 39      "TYPE" : "q4_0",
 40      "DST_TYPE": "f32",
 41      "BLOCK_SIZE": 32
 42    },
 43    "DECLS": ["BYTE_HELPERS", "Q4_0_T", "Q4_0"]
 44  },
 45  {
 46    "REPLS": {
 47      "TYPE" : "q4_1",
 48      "DST_TYPE": "f32",
 49      "BLOCK_SIZE": 32
 50    },
 51    "DECLS": ["BYTE_HELPERS", "Q4_1_T", "Q4_1"]
 52  },
 53  {
 54    "REPLS": {
 55      "TYPE" : "q5_0",
 56      "DST_TYPE": "f32",
 57      "BLOCK_SIZE": 32
 58    },
 59    "DECLS": ["BYTE_HELPERS", "Q5_0_T", "Q5_0"]
 60  },
 61  {
 62    "REPLS": {
 63      "TYPE" : "q5_1",
 64      "DST_TYPE": "f32",
 65      "BLOCK_SIZE": 32
 66    },
 67    "DECLS": ["BYTE_HELPERS", "Q5_1_T", "Q5_1"]
 68  },
 69  {
 70    "REPLS": {
 71      "TYPE" : "q8_0",
 72      "DST_TYPE": "f32",
 73      "BLOCK_SIZE": 32
 74    },
 75    "DECLS": ["BYTE_HELPERS", "Q8_0_T", "Q8_0"]
 76  },
 77  {
 78    "REPLS": {
 79      "TYPE" : "q2_k",
 80      "DST_TYPE": "f32",
 81      "BLOCK_SIZE": 256
 82    },
 83    "DECLS": ["BYTE_HELPERS", "Q2_K_T", "Q2_K"]
 84  },
 85  {
 86    "REPLS": {
 87      "TYPE" : "q3_k",
 88      "DST_TYPE": "f32",
 89      "BLOCK_SIZE": 256
 90    },
 91    "DECLS": ["BYTE_HELPERS", "Q3_K_T", "Q3_K"]
 92  },
 93  {
 94    "REPLS": {
 95      "TYPE" : "q4_k",
 96      "DST_TYPE": "f32",
 97      "BLOCK_SIZE": 256
 98    },
 99    "DECLS": ["Q45_K_SCALE_MIN", "BYTE_HELPERS", "Q4_K_T", "Q4_K"]
100  },
101  {
102    "REPLS": {
103      "TYPE" : "q5_k",
104      "DST_TYPE": "f32",
105      "BLOCK_SIZE": 256
106    },
107    "DECLS": ["Q45_K_SCALE_MIN", "BYTE_HELPERS", "Q5_K_T", "Q5_K"]
108  },
109  {
110    "REPLS": {
111      "TYPE" : "q6_k",
112      "DST_TYPE": "f32",
113      "BLOCK_SIZE": 256
114    },
115    "DECLS": ["BYTE_HELPERS", "Q6_K_T", "Q6_K"]
116  },
117  {
118    "REPLS": {
119      "TYPE" : "iq2_xxs",
120      "DST_TYPE": "f32",
121      "BLOCK_SIZE": 256
122    },
123    "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_XXS_GRID", "IQ2_XXS_T", "IQ2_XXS"]
124  },
125  {
126    "REPLS": {
127      "TYPE" : "iq2_xs",
128      "DST_TYPE": "f32",
129      "BLOCK_SIZE": 256
130    },
131    "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_XS_GRID", "IQ2_XS_T", "IQ2_XS"]
132  },
133  {
134    "REPLS": {
135      "TYPE": "iq2_s",
136      "DST_TYPE": "f32",
137      "BLOCK_SIZE": 256
138    },
139    "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_S_GRID", "IQ2_S_T", "IQ2_S"]
140  },
141  {
142    "REPLS": {
143      "TYPE": "iq3_xxs",
144      "DST_TYPE": "f32",
145      "BLOCK_SIZE": 256
146    },
147    "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ3_XSS_GRID", "IQ3_XSS_T", "IQ3_XSS"]
148  },
149  {
150    "REPLS": {
151      "TYPE": "iq3_s",
152      "DST_TYPE": "f32",
153      "BLOCK_SIZE": 256
154    },
155    "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ3_S_GRID", "IQ3_S_T", "IQ3_S"]
156  },
157  {
158    "REPLS": {
159      "TYPE": "iq1_s",
160      "DST_TYPE": "f32",
161      "BLOCK_SIZE": 256
162    },
163    "DECLS": ["BYTE_HELPERS", "IQ1_GRID", "IQ1_S_T", "IQ1_S"]
164  },
165  {
166    "REPLS": {
167      "TYPE": "iq1_m",
168      "DST_TYPE": "f32",
169      "BLOCK_SIZE": 256
170    },
171    "DECLS": ["BYTE_HELPERS", "IQ1_GRID", "IQ1_M_T", "IQ1_M"]
172  },
173  {
174    "REPLS": {
175      "TYPE": "iq4_nl",
176      "DST_TYPE": "f32",
      "BLOCK_SIZE": 32
178    },
179    "DECLS": ["BYTE_HELPERS", "IQ4_GRID", "IQ4_NL_T", "IQ4_NL"]
180  },
181  {
182    "REPLS": {
183      "TYPE": "iq4_xs",
184      "DST_TYPE": "f32",
      "BLOCK_SIZE": 256
186    },
187    "DECLS": ["BYTE_HELPERS", "IQ4_GRID", "IQ4_XS_T", "IQ4_XS"]
188  }
189]
190
191#end(VARIANTS)
192
193#define(DECLS)
194
195#decl(F32_VEC)
// Copy one vec4<f32> from src to dst.
// The row bases are divided by 4 to turn them into vec4 indices; `offset` is
// already a vec4 index within the row.
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
    let src_vec = src_base / 4u + offset;
    let dst_vec = dst_base / 4u + offset;
    dst[dst_vec] = src[src_vec];
}
199#enddecl(F32_VEC)
200
201#decl(F32)
// Copy a single f32 element from src to dst unchanged.
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
    let value = src[src_base + offset];
    dst[dst_base + offset] = value;
}
205#enddecl(F32)
206
207#decl(F16)
// Copy a single element, widening the f16 source value to f32.
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
    let half_value = src[src_base + offset];
    dst[dst_base + offset] = f32(half_value);
}
211#enddecl(F16)
212
213#decl(I32)
// Copy a single i32 element from src to dst unchanged.
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
    let value = src[src_base + offset];
    dst[dst_base + offset] = value;
}
217#enddecl(I32)
218
219#decl(Q4_0)
// Dequantize one q4_0 block (32 values): value = (nibble - 8) * d.
// Byte b of qs feeds element b (low nibble) and element b + 16 (high nibble).
//   offset: block index within the row; output starts at dst_base + offset * 32
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
    let blk = src[src_base + offset];
    let scale = f32(blk.d);
    let out_base = dst_base + offset * 32;
    for (var w: u32 = 0; w < 4; w++) {
        // Two consecutive f16 slots reinterpreted as four packed bytes.
        let packed = bitcast<u32>(vec2(blk.qs[2 * w], blk.qs[2 * w + 1]));
        for (var b: u32 = 0; b < 4; b++) {
            let byte_val = get_byte(packed, b);
            let lo_idx = out_base + w * 4 + b;
            dst[lo_idx] = (f32(byte_val & 0xF) - 8.0f) * scale;
            dst[lo_idx + 16] = (f32((byte_val >> 4) & 0xF) - 8.0f) * scale;
        }
    }
}
235#enddecl(Q4_0)
236
237#decl(Q4_1)
// Dequantize one q4_1 block (32 values): value = nibble * d + m.
// Byte b of qs feeds element b (low nibble) and element b + 16 (high nibble).
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
    let blk = src[src_base + offset];
    let scale = f32(blk.d);
    let bias = f32(blk.m);
    let out_base = dst_base + offset * 32;
    for (var w: u32 = 0; w < 4; w++) {
        let packed = blk.qs[w];
        for (var b: u32 = 0; b < 4; b++) {
            let byte_val = get_byte(packed, b);
            let lo_idx = out_base + w * 4 + b;
            dst[lo_idx] = f32(byte_val & 0xF) * scale + bias;
            dst[lo_idx + 16] = f32((byte_val >> 4) & 0xF) * scale + bias;
        }
    }
}
254#enddecl(Q4_1)
255
256#decl(Q5_0)
// Dequantize one q5_0 block (32 values): value = (5-bit quant - 16) * d.
// The low 4 bits come from qs (byte e -> elements e and e + 16); the fifth bit of
// element i is bit i of the 32-bit word formed from qh.
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
    let blk = src[src_base + offset];
    let scale = f32(blk.d);
    let high_bits = bitcast<u32>(vec2(blk.qh[0], blk.qh[1]));
    let out_base = dst_base + offset * 32;
    for (var w: u32 = 0; w < 4; w++) {
        let packed = bitcast<u32>(vec2(blk.qs[2 * w], blk.qs[2 * w + 1]));
        for (var b: u32 = 0; b < 4; b++) {
            let elem = w * 4 + b;
            let byte_val = get_byte(packed, b);
            // Move bit elem (low half) / bit elem + 16 (high half) into bit position 4.
            let lo_h = ((high_bits >> elem) << 4) & 0x10;
            let hi_h = (high_bits >> (elem + 12)) & 0x10;
            dst[out_base + elem] = (f32((byte_val & 0xF) | lo_h) - 16.0) * scale;
            dst[out_base + elem + 16] = (f32(((byte_val >> 4) & 0xF) | hi_h) - 16.0) * scale;
        }
    }
}
275
276#enddecl(Q5_0)
277
278#decl(Q5_1)
// Dequantize one q5_1 block (32 values): value = 5-bit quant * d + m.
// The low 4 bits come from qs (byte e -> elements e and e + 16); the fifth bit of
// element i is bit i of qh.
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
    let blk = src[src_base + offset];
    let scale = f32(blk.d);
    let bias = f32(blk.m);
    let out_base = dst_base + offset * 32;
    for (var w: u32 = 0; w < 4; w++) {
        let packed = blk.qs[w];
        for (var b: u32 = 0; b < 4; b++) {
            let elem = w * 4 + b;
            let byte_val = get_byte(packed, b);
            let lo_h = ((blk.qh >> elem) << 4) & 0x10;
            let hi_h = (blk.qh >> (elem + 12)) & 0x10;
            dst[out_base + elem] = f32((byte_val & 0xF) | lo_h) * scale + bias;
            dst[out_base + elem + 16] = f32(((byte_val >> 4) & 0xF) | hi_h) * scale + bias;
        }
    }
}
297#enddecl(Q5_1)
298
299#decl(Q8_0)
// Dequantize one q8_0 block (32 values): each qs byte, read with get_byte_i32,
// is multiplied by d. Element order follows byte order.
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
    let blk = src[src_base + offset];
    let scale = f32(blk.d);
    let out_base = dst_base + offset * 32;
    for (var w: u32 = 0; w < 8; w++) {
        let packed = bitcast<u32>(vec2(blk.qs[2 * w], blk.qs[2 * w + 1]));
        for (var b: u32 = 0; b < 4; b++) {
            dst[out_base + w * 4 + b] = f32(get_byte_i32(packed, b)) * scale;
        }
    }
}
313#enddecl(Q8_0)
314
315#decl(Q2_K)
// Dequantize one q2_k super-block (256 values) as 16 sub-blocks of 16 values.
// Sub-block sb uses scale byte sb: low nibble multiplies d, high nibble multiplies
// dmin; value = 2-bit quant * dl - ml.
// Sub-block sb reads qs bytes starting at (sb / 8) * 32 + (sb % 2) * 16 and takes
// the 2-bit field at bit position ((sb / 2) % 4) * 2.
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
    let blk = src[src_base + offset];
    let d_all = f32(blk.d);
    let m_all = f32(blk.dmin);
    let out_base = dst_base + offset * 256;
    for (var sb: u32 = 0; sb < 16; sb++) {
        let first_q = (sb / 8) * 32 + (sb % 2) * 16;
        let shift = ((sb / 2) % 4) * 2;
        let sc = get_byte(blk.scales[sb / 4], sb % 4);
        let dl = d_all * f32(sc & 0xF);
        let ml = m_all * f32(sc >> 4);
        for (var l: u32 = 0; l < 16; l++) {
            let q_idx = first_q + l;
            let bits = (get_byte(blk.qs[q_idx / 4], q_idx % 4) >> shift) & 3;
            dst[out_base + sb * 16 + l] = f32(bits) * dl - ml;
        }
    }
}
343#enddecl(Q2_K)
344
345#decl(Q3_K)
// Dequantize one q3_k super-block (256 values) into dst.
//   src_base: index of the row's first block in src
//   dst_base: index of the row's first value in dst
//   offset:   block index within the row; output starts at dst_base + offset * 256
// value = (2-bit quant - (hmask bit clear ? 4 : 0)) * d * (6-bit scale - 32)
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
    let block = src[src_base + offset];
    let d = f32(block.d);

    // extract 6-bit scales, which consist of 4-bits from first 8 bytes of scale,
    // and 2-bits from the last 4 bytes.
    // The packed scales are 12 bytes = three u32 words, so only three words are
    // read. Word 3 is rebuilt entirely by the shuffle below. Reading a fourth word
    // would index block.scales[6] and block.scales[7], past the 12 scale bytes.
    // NOTE(review): assumes the q3_k block struct stores scales as 6 f16 — confirm.
    let kmask1: u32 = 0x03030303;
    let kmask2: u32 = 0x0f0f0f0f;
    var scale_vals: array<u32, 4>;
    for (var i: u32 = 0; i < 3; i++) {
        scale_vals[i] = bitcast<u32>(vec2(block.scales[2 * i], block.scales[2 * i + 1]));
    }
    let tmp: u32 = scale_vals[2];
    scale_vals[2] = ((scale_vals[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
    scale_vals[3] = ((scale_vals[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
    scale_vals[0] = (scale_vals[0] & kmask2) | ((tmp & kmask1) << 4);
    scale_vals[1] = (scale_vals[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);

    // convert arrays of f16 -> u32
    var hmask_vals: array<u32, 8>;
    for (var i: u32 = 0; i < 8; i++) {
        hmask_vals[i] = bitcast<u32>(vec2(block.hmask[2 * i], block.hmask[2 * i + 1]));
    }
    var qs_vals: array<u32, 16>;
    for (var i: u32 = 0; i < 16; i++) {
        qs_vals[i] = bitcast<u32>(vec2(block.qs[2 * i], block.qs[2 * i + 1]));
    }

    var dst_i = dst_base + offset * 256;
    var is: u32 = 0;
    // m selects the hmask bit for the current 2-bit plane; it doubles per plane.
    var m: u32 = 1;
    // 2 halves of the block (128 elements each)
    for (var q_b_idx: u32 = 0; q_b_idx < 64; q_b_idx += 32) {
        // 4 groups (each group has 2 blocks of 16 elements)
        for (var shift: u32 = 0; shift < 8; shift += 2) {
            // 2 blocks
            for (var k: u32 = 0; k < 32; k += 16) {
                let sc = get_byte(scale_vals[is / 4], is % 4);
                is++;
                let dl = d * (f32(sc) - 32.0);
                for (var l: u32 = 0u; l < 16u; l++) {
                    let q_idx = q_b_idx + k + l;
                    // hmask bytes are shared by both halves; only the bit m changes.
                    let hm_idx = k + l;
                    let q_byte = get_byte(qs_vals[q_idx / 4], q_idx % 4);
                    let hmask_byte = get_byte(hmask_vals[hm_idx / 4], hm_idx % 4);
                    let hm = select(4.0, 0.0, (hmask_byte & m) != 0);
                    let qs_val = (q_byte >> shift) & 3;
                    dst[dst_i] = (f32(qs_val) - hm) * dl;
                    dst_i++;
                }
            }
            m <<= 1;
        }
    }
}
401#enddecl(Q3_K)
402
403#decl(Q4_K)
// Dequantize one q4_k super-block (256 values) as 8 sub-blocks of 32 values.
// Sub-blocks 2c and 2c + 1 share qs bytes 32c..32c+31: low nibbles feed the even
// sub-block and high nibbles the odd one. get_scale_min(sb, scales) provides the
// (scale, min) pair for sub-block sb; value = nibble * d * scale - dmin * min.
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
    let blk = src[src_base + offset];
    let d_all = f32(blk.d);
    let m_all = f32(blk.dmin);
    let out_base = dst_base + offset * 256;
    for (var sb: u32 = 0; sb < 8; sb++) {
        let chunk = (sb / 2) * 32;
        let nib_shift = (sb % 2) * 4;
        let sm = get_scale_min(sb, blk.scales);
        let dl = d_all * sm.x;
        let ml = m_all * sm.y;
        for (var l: u32 = 0; l < 32; l++) {
            let q_idx = chunk + l;
            let nib = (get_byte(blk.qs[q_idx / 4], q_idx % 4) >> nib_shift) & 0xF;
            dst[out_base + sb * 32 + l] = f32(nib) * dl - ml;
        }
    }
}
428#enddecl(Q4_K)
429
430#decl(Q5_K)
// Dequantize one q5_k super-block (256 values) as 8 sub-blocks of 32 values.
// Sub-blocks 2c and 2c + 1 share qs bytes 32c..32c+31 (low / high nibble). The
// fifth bit of element l in sub-block sb is bit sb of qh byte l, adding 16.
// value = (nibble + 16 * bit) * d * scale - dmin * min, with (scale, min) from
// get_scale_min(sb, scales).
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
    let blk = src[src_base + offset];
    let d_all = f32(blk.d);
    let m_all = f32(blk.dmin);
    let out_base = dst_base + offset * 256;
    for (var sb: u32 = 0; sb < 8; sb++) {
        let chunk = (sb / 2) * 32;
        let nib_shift = (sb % 2) * 4;
        let high_mask = 1u << sb;
        let sm = get_scale_min(sb, blk.scales);
        let dl = d_all * sm.x;
        let ml = m_all * sm.y;
        for (var l: u32 = 0; l < 32; l++) {
            let q_idx = chunk + l;
            let nib = (get_byte(blk.qs[q_idx / 4], q_idx % 4) >> nib_shift) & 0xF;
            let qh_byte = get_byte(blk.qh[l / 4], l % 4);
            let hi = select(0.0, 16.0, (qh_byte & high_mask) != 0);
            dst[out_base + sb * 32 + l] = (f32(nib) + hi) * dl - ml;
        }
    }
}
458#enddecl(Q5_K)
459
460#decl(Q6_K)
// Dequantize one q6_k super-block (256 values) into dst: 16 sub-blocks of 16
// elements, processed as two halves of 128.
// Each value is a 6-bit quant (low 4 bits from ql, high 2 bits from qh) minus 32,
// multiplied by its sub-block's scale byte (read with get_byte_i32, presumably
// sign-extending — confirm) and by d.
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
    let block = src[src_base + offset];
    let d = f32(block.d);

    // convert arrays of f16 -> u32
    // (loop bounds: ql = 128 bytes, qh = 64 bytes, scales = 16 bytes)
    var ql_vals: array<u32, 32>;
    for (var i: u32 = 0; i < 32; i++) {
        ql_vals[i] = bitcast<u32>(vec2(block.ql[2 * i], block.ql[2 * i + 1]));
    }
    var qh_vals: array<u32, 16>;
    for (var i: u32 = 0; i < 16; i++) {
        qh_vals[i] = bitcast<u32>(vec2(block.qh[2 * i], block.qh[2 * i + 1]));
    }
    var scale_vals: array<u32, 4>;
    for (var i: u32 = 0; i < 4; i++) {
        scale_vals[i] = bitcast<u32>(vec2(block.scales[2 * i], block.scales[2 * i + 1]));
    }

    var dst_i = dst_base + offset * 256;
    // Per 128-value half: ql advances 64 bytes, qh 32 bytes, scales 8 bytes.
    var qh_b_idx: u32 = 0;
    var sc_b_idx: u32 = 0;
    for (var ql_b_idx: u32 = 0; ql_b_idx < 128; ql_b_idx += 64) {
        for (var l: u32 = 0; l < 32; l++) {
            // ql byte l feeds outputs l (low nibble) and l + 64 (high nibble);
            // ql byte l + 32 feeds outputs l + 32 and l + 96.
            let ql13_b = get_byte(ql_vals[(ql_b_idx + l) / 4], (ql_b_idx + l) % 4);
            let ql24_b = get_byte(ql_vals[(ql_b_idx + l + 32) / 4], (ql_b_idx + l + 32) % 4);
            // qh byte l supplies 2 high bits to each of the four outputs.
            let qh_b = get_byte(qh_vals[(qh_b_idx + l) / 4], (qh_b_idx + l) % 4);

            let q1 = f32((ql13_b & 0xF) | ((qh_b & 3) << 4)) - 32.0;
            let q2 = f32((ql24_b & 0xF) | (((qh_b >> 2) & 3) << 4)) - 32.0;
            let q3 = f32((ql13_b >> 4) | (((qh_b >> 4) & 3) << 4)) - 32.0;
            let q4 = f32((ql24_b >> 4) | (((qh_b >> 6) & 3) << 4)) - 32.0;

            // One scale per 16 outputs: the four 32-wide output ranges use scale
            // bytes is, is + 2, is + 4 and is + 6 of this half.
            let is = l/16;
            let is1 = sc_b_idx + is;
            let sc1 = get_byte_i32(scale_vals[is1 / 4], is1 % 4);
            let is2 = sc_b_idx + is + 2;
            let sc2 = get_byte_i32(scale_vals[is2 / 4], is2 % 4);
            let is3 = sc_b_idx + is + 4;
            let sc3 = get_byte_i32(scale_vals[is3 / 4], is3 % 4);
            let is4 = sc_b_idx + is + 6;
            let sc4 = get_byte_i32(scale_vals[is4 / 4], is4 % 4);

            dst[dst_i + l] = (q1 * f32(sc1)) * d;
            dst[dst_i + l + 32] = (q2 * f32(sc2)) * d;
            dst[dst_i + l + 64] = (q3 * f32(sc3)) * d;
            dst[dst_i + l + 96] = (q4 * f32(sc4)) * d;
        }
        dst_i += 128;
        qh_b_idx += 32;
        sc_b_idx += 8;
    }
}
514
515#enddecl(Q6_K)
516
517#decl(IQ2_XXS)
// Dequantize one iq2_xxs super-block (256 values) as 8 groups of 32.
// Each group is encoded in four 16-bit qs slots:
//   aux0 (slots 0-1): four byte indices; index * 8 is the byte offset of 8
//                     magnitudes in iq2xxs_grid.
//   aux1 (slots 2-3): four 7-bit sign selectors (bits 7l..7l+6) into
//                     ksigns_iq2xs, plus a 4-bit scale in bits 28-31.
// value = d * (0.5 + scale) * 0.25 * grid byte, negated where the sign mask bit is set.
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
    let block = src[src_base + offset];
    let d = f32(block.d);
    var dst_i = dst_base + offset * 256;
    for (var ib: u32 = 0; ib < 32; ib += 4) {
        let aux0 = bitcast<u32>(vec2(block.qs[ib], block.qs[ib + 1]));
        let aux1 = bitcast<u32>(vec2(block.qs[ib + 2], block.qs[ib + 3]));
        let db = d * (0.5 + f32(aux1 >> 28)) * 0.25;
        for (var l: u32 = 0; l < 4; l++) {
            let ig = get_byte(aux0, l) * 8;
            let is = (aux1 >> (7 * l)) & 127;
            let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);
            for (var j: u32 = 0; j < 8; j++) {
                let g = get_byte(iq2xxs_grid[(ig + j) / 4], (ig + j) % 4);
                // kmask_iq2xs byte j selects which bit of `signs` applies to value j.
                let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0);
                dst[dst_i] = db * f32(g) * m;
                dst_i++;
            }
        }
    }
}
539#enddecl(IQ2_XXS)
540
541#decl(IQ2_XS)
// Dequantize one iq2_xs super-block (256 values) as 8 groups of 32.
// Group ib / 4 uses scale byte ib / 4: its low nibble scales the first 16 values
// (l = 0, 1) and its high nibble the last 16 (l = 2, 3).
// Each 16-bit qs slot encodes 8 values: low 9 bits index iq2xs_grid (byte offset
// index * 8), high 7 bits select a sign byte from ksigns_iq2xs.
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
    let block = src[src_base + offset];
    let d = f32(block.d);
    var dst_i = dst_base + offset * 256;
    var scale_vals = array<u32, 2>(
        bitcast<u32>(vec2(block.scales[0], block.scales[1])),
        bitcast<u32>(vec2(block.scales[2], block.scales[3]))
    );
    // ib counts 16-bit qs slots, four per group of 32 values.
    for (var ib: u32 = 0; ib < 32; ib += 4) {
        let s = get_byte(scale_vals[ib / 16], (ib % 16) / 4);
        let db = array<f32, 2>(
            d * (0.5 + f32(s & 0xF)) * 0.25,
            d * (0.5 + f32(s >> 4)) * 0.25
        );
        for (var l: u32 = 0; l < 4; l++) {
            // Pair with 0.0 so the single f16 slot can be bitcast to u32 (upper half zero).
            let qs_val = bitcast<u32>(vec2(block.qs[ib + l], 0.0));
            let ig = (qs_val & 511) * 8;
            let is = qs_val >> 9;
            let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);
            let dl = db[l/2];
            for (var j: u32 = 0; j < 8; j++) {
                let g = get_byte(iq2xs_grid[(ig + j) / 4], (ig + j) % 4);
                let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0);
                dst[dst_i] = dl * f32(g) * m;
                dst_i++;
            }
        }
    }
}
571#enddecl(IQ2_XS)
572
573#decl(IQ2_S)
// Dequantize one iq2_s super-block (256 values) as 8 groups of 32.
// qs (64 bytes): bytes 0..31 are the low 8 bits of grid indices (word ib for
// group ib); bytes 32..63 are explicit sign bytes (word ib + 8).
// qh byte ib supplies 2 extra index bits (bits 8-9) for each of the 4 sub-groups.
// Scale byte ib: low nibble for the first 16 values, high nibble for the last 16.
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
    let block = src[src_base + offset];
    let d = f32(block.d);
    var dst_i = dst_base + offset * 256;
    var qs_vals : array<u32, 16>;
    for (var i: u32 = 0; i < 16; i++) {
        qs_vals[i] = bitcast<u32>(vec2(block.qs[i * 2], block.qs[i * 2 + 1]));
    }
    var qh_vals = array<u32, 2>(
        bitcast<u32>(vec2(block.qh[0], block.qh[1])),
        bitcast<u32>(vec2(block.qh[2], block.qh[3]))
    );
    var scale_vals = array<u32, 2>(
        bitcast<u32>(vec2(block.scales[0], block.scales[1])),
        bitcast<u32>(vec2(block.scales[2], block.scales[3]))
    );
    for (var ib: u32 = 0; ib < 8; ib ++) {
        let s = get_byte(scale_vals[ib / 4], ib % 4);
        let db = array<f32, 2>(
            d * (0.5 + f32(s & 0xF)) * 0.25,
            d * (0.5 + f32(s >> 4)) * 0.25
        );
        let qs_w = qs_vals[ib];
        for (var l: u32 = 0; l < 4; l++) {
            // Bits 2l and 2l+1 of qh byte ib become index bits 8 and 9.
            let qh_b = (get_byte(qh_vals[ib / 4], ib % 4) << (8 - 2 * l)) & 0x300;
            let ig = (get_byte(qs_w, l) | qh_b) * 8;
            let signs = get_byte(qs_vals[ib + 8], l);
            let dl = db[l/2];
            for (var j: u32 = 0; j < 8; j++) {
                let g = get_byte(iq2s_grid[(ig + j) / 4], (ig + j) % 4);
                let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0);
                dst[dst_i] = dl * f32(g) * m;
                dst_i++;
            }
        }
    }
}
611
612#enddecl(IQ2_S)
613
614#decl(IQ3_XSS)
// Dequantize one iq3_xxs super-block (256 values) as 8 groups of 32.
// qs slots 0..31 hold grid byte indices (8 per group); slots 32..47 hold one u32
// per group with four 7-bit sign selectors (bits 7l..7l+6) and a 4-bit scale in
// bits 28-31. Each grid index selects a u32 of four magnitude bytes in iq3xxs_grid.
// value = d * (0.5 + scale) * 0.5 * grid byte, negated where the sign mask bit is set.
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
    let block = src[src_base + offset];
    let d = f32(block.d);
    var dst_i = dst_base + offset * 256;
    // ib steps by 2 16-bit slots = one u32 of scale/sign data per group.
    for (var ib: u32 = 0; ib < 16; ib += 2) {
        let sc_sign = bitcast<u32>(vec2(block.qs[ib + 32], block.qs[ib + 33]));
        let db = d * (0.5 + f32(sc_sign >> 28)) * 0.5;
        for (var l: u32 = 0; l < 4; l++) {
            let is = (sc_sign >> (7 * l)) & 127;
            let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);
            // One 16-bit slot = two grid indices: ig1 -> values 0..3, ig2 -> values 4..7.
            let ig_val = bitcast<u32>(vec2(block.qs[ib * 2 + l], 0.0));
            let ig1 = get_byte(ig_val, 0);
            let ig2 = get_byte(ig_val, 1);
            for (var j: u32 = 0; j < 4; j++) {
                let g1 = get_byte(iq3xxs_grid[ig1], j);
                let g2 = get_byte(iq3xxs_grid[ig2], j);
                // Mask bytes j of kmask_iq2xs words 0 and 1 apply to values j and j + 4.
                let m1 = select(1.0, -1.0, (get_byte(kmask_iq2xs[0], j) & signs) != 0);
                let m2 = select(1.0, -1.0, (get_byte(kmask_iq2xs[1], j) & signs) != 0);
                dst[dst_i] = db * f32(g1) * m1;
                dst[dst_i + 4] = db * f32(g2) * m2;
                dst_i++;
            }
            // Skip the four outputs already written through dst_i + 4.
            dst_i += 4;
        }
    }
}
641#enddecl(IQ3_XSS)
642
643#decl(IQ3_S)
// Dequantize one iq3_s super-block (256 values) as 4 pairs of 32-value groups.
// Scale byte ib: low nibble scales group 2ib, high nibble group 2ib + 1;
// dl = d * (1 + 2 * nibble).
// For group (ib, k): qh byte 2ib + k provides bit 8 of each grid index, sign word
// 2ib + k gives one sign byte per 8 values, and qs slots ib*8 + k*4 + l give pairs
// of grid byte indices into iq3s_grid (four magnitude bytes per entry).
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
    let block = src[src_base + offset];
    let d = f32(block.d);
    var dst_i = dst_base + offset * 256;
    var qh_vals = array<u32, 2>(
        bitcast<u32>(vec2(block.qh[0], block.qh[1])),
        bitcast<u32>(vec2(block.qh[2], block.qh[3]))
    );
    var sign_vals: array<u32, 8>;
    for (var i: u32 = 0; i < 8; i++) {
        sign_vals[i] = bitcast<u32>(vec2(block.signs[i * 2], block.signs[i * 2 + 1]));
    }
    var scale_vals = bitcast<u32>(vec2(block.scales[0], block.scales[1]));
    for (var ib: u32 = 0; ib < 4; ib++) {
        let s = get_byte(scale_vals, ib);
        let db = array<f32, 2>(
            d * (1.0 + 2.0 * f32(s & 0xF)),
            d * (1.0 + 2.0 * f32(s >> 4))
        );
        for (var k: u32 = 0; k < 2; k++) {
            let dl = db[k];
            let qh_byte = get_byte(qh_vals[ib / 2], (ib % 2) * 2 + k);
            let sign_w = sign_vals[ib * 2 + k];
            for (var l: u32 = 0; l < 4; l++) {
                let signs = get_byte(sign_w, l);
                let ig_val = bitcast<u32>(vec2(block.qs[ib * 8 + k * 4 + l], 0.0));
                // qh bit 2l -> bit 8 of ig1, qh bit 2l+1 -> bit 8 of ig2.
                let ig1 = get_byte(ig_val, 0) | ((qh_byte << ((8 - (2 * l)))) & 256);
                let ig2 = get_byte(ig_val, 1) | ((qh_byte << ((7 - (2 * l)))) & 256);
                for (var j: u32 = 0; j < 4; j++) {
                    let g1 = get_byte(iq3s_grid[ig1], j);
                    let g2 = get_byte(iq3s_grid[ig2], j);
                    let m1 = select(1.0, -1.0, (get_byte(kmask_iq2xs[0], j) & signs) != 0);
                    let m2 = select(1.0, -1.0, (get_byte(kmask_iq2xs[1], j) & signs) != 0);
                    dst[dst_i] = dl * f32(g1) * m1;
                    dst[dst_i + 4] = dl * f32(g2) * m2;
                    dst_i++;
                }
                // Skip the four outputs already written through dst_i + 4.
                dst_i += 4;
            }
        }
    }
}
686#enddecl(IQ3_S)
687
688#decl(IQ1_S)
// Dequantize one iq1_s super-block (256 values) as 8 groups of 32.
// Per group, the 16-bit qh word holds four 3-bit high grid-index parts (bits
// 3l..3l+2), a 3-bit scale (bits 12-14) and the delta sign (bit 15); four qs bytes
// hold the low 8 index bits. iq1_grid is read as packed 2-bit fields, sign-extended
// to -2..1. value = d * (2 * scale + 1) * (grid value + delta).
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
    let block = src[src_base + offset];
    let d = f32(block.d);
    var dst_i = dst_base + offset * 256;
    for (var ib: u32 = 0; ib < 8; ib++) {
        // Pair with 0.0 so the single f16 slot can be bitcast to u32 (upper half zero).
        let qh = bitcast<u32>(vec2(block.qh[ib], 0.0));
        let dl = d * (2 * f32((qh >> 12) & 7) + 1);
        let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000) != 0);
        let qs_w = bitcast<u32>(vec2(block.qs[ib * 2], block.qs[ib * 2 + 1]));
        for (var l: u32 = 0; l < 4; l++) {
            // 11-bit grid index times 8 = index of the first 2-bit field.
            let ig = (get_byte(qs_w, l) | (((qh >> (3 * l)) & 7) << 8)) * 8;
            for (var j: u32 = 0; j < 8; j++) {
                let gw = iq1_grid[(ig + j) / 16];
                let g = (gw >> (((ig + j) % 16) * 2)) & 3;
                // Shift into the top bits and arithmetic-shift back to sign-extend.
                let gs = bitcast<i32>(g << 30) >> 30;
                dst[dst_i] = dl * (f32(gs) + delta);
                dst_i++;
            }
        }
    }
}
710
711#enddecl(IQ1_S)
712
713#decl(IQ1_M)
// Dequantize one iq1_m super-block (256 values) as 8 groups of 32.
// There is no separate d field. The f16 super-scale is assembled from the top
// nibble of each of the four 16-bit halves of scales[0..1].
// Per group ib, 16-bit half ib / 2 of scales holds two 3-bit sub-scales at bit
// 6 * (ib % 2) and 6 * (ib % 2) + 3: the first for values 0..15, the second for 16..31.
// qh nibbles give 3 high index bits plus a delta sign bit (0x08 / 0x80) per 8 values.
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
    let block = src[src_base + offset];

    let scale = ((block.scales[0] >> 12) & 0xF) | ((block.scales[0] >> 24) & 0x00F0) | ((block.scales[1] >> 4) & 0x0F00) | ((block.scales[1] >> 16) & 0xF000);
    // Low 16 bits of `scale` reinterpreted as an f16.
    let d = f32(bitcast<vec2<f16>>(scale).x);
    var dst_i = dst_base + offset * 256;
    for (var ib: u32 = 0; ib < 8; ib++) {
        let sw = (block.scales[ib / 4] >> (16 * ((ib / 2) % 2))) & 0xFFFF;
        let s1 : u32 = (sw >> (6 * (ib % 2))) & 0x7;
        let s2 : u32 = (sw >> (6 * (ib % 2) + 3)) & 0x7;
        var dl = array<f32, 2>(
            d * f32(2 * s1 + 1),
            d * f32(2 * s2 + 1)
        );

        // Two qh bytes for this group: bits 0-7 and 8-15 of qh.
        let qh = block.qh[ib / 2] >> (16 * (ib % 2));
        var idx = array<u32, 4>(
            get_byte(block.qs[ib], 0) | ((qh << 8) & 0x700),
            get_byte(block.qs[ib], 1) | ((qh << 4) & 0x700),
            get_byte(block.qs[ib], 2) | ((qh) & 0x700),
            get_byte(block.qs[ib], 3) | ((qh >> 4) & 0x700)
        );
        var delta = array<f32, 4>(
            select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x08) != 0),
            select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x80) != 0),
            select(IQ1_DELTA, -IQ1_DELTA, ((qh >> 8) & 0x08) != 0),
            select(IQ1_DELTA, -IQ1_DELTA, ((qh >> 8) & 0x80) != 0)
        );
        for (var l: u32 = 0; l < 4; l++) {
            let ig = idx[l] * 8;
            for (var j: u32 = 0; j < 8; j++) {
                let gw = iq1_grid[(ig + j) / 16];
                let g = (gw >> (((ig + j) % 16) * 2)) & 3;
                // Shift into the top bits and arithmetic-shift back to sign-extend.
                let gs = bitcast<i32>(g << 30) >> 30;
                dst[dst_i] = dl[l/2] * (f32(gs) + delta[l]);
                dst_i++;
            }
        }
    }
}
754
755#enddecl(IQ1_M)
756
757#decl(IQ4_NL)
// Dequantize one iq4_nl block (32 values): each nibble indexes the lookup table
// kvalues_iq4nl and the entry is scaled by d.
// Byte b of qs feeds element b (low nibble) and element b + 16 (high nibble).
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
    let blk = src[src_base + offset];
    let scale = f32(blk.d);
    let out_base = dst_base + offset * 32;
    for (var w: u32 = 0; w < 4; w++) {
        let packed = bitcast<u32>(vec2(blk.qs[w * 2], blk.qs[w * 2 + 1]));
        for (var b: u32 = 0; b < 4; b++) {
            let byte_val = get_byte(packed, b);
            let lo_idx = out_base + w * 4 + b;
            dst[lo_idx] = scale * f32(kvalues_iq4nl[byte_val & 0xF]);
            dst[lo_idx + 16] = scale * f32(kvalues_iq4nl[byte_val >> 4]);
        }
    }
}
773#enddecl(IQ4_NL)
774
775#decl(IQ4_XS)
// Dequantize one iq4_xs super-block (256 values) as 8 sub-blocks of 32.
// The 6-bit scale of sub-block sb combines nibble sb % 2 of scales_l byte sb / 2
// (low 4 bits) with bits 2sb..2sb+1 of scales_h (high 2 bits); dl = d * (ls - 32).
// qs byte 16sb + j feeds outputs j (low nibble) and j + 16 (high nibble) via kvalues_iq4nl.
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
    let blk = src[src_base + offset];
    let scale = f32(blk.d);
    let high_scale_bits = bitcast<u32>(vec2(blk.scales_h, 0.0));
    let out_base = dst_base + offset * 256;
    for (var sb: u32 = 0; sb < 8; sb++) {
        let low4 = (get_byte(blk.scales_l, sb / 2) >> (4 * (sb % 2))) & 0xF;
        let high2 = (high_scale_bits >> (2 * sb)) & 3;
        let dl = scale * (f32(low4 | (high2 << 4)) - 32.0);
        let sb_out = out_base + sb * 32;
        for (var j: u32 = 0; j < 16; j++) {
            let q_idx = sb * 16 + j;
            let byte_val = get_byte(blk.qs[q_idx / 4], q_idx % 4);
            dst[sb_out + j] = dl * f32(kvalues_iq4nl[byte_val & 0xF]);
            dst[sb_out + j + 16] = dl * f32(kvalues_iq4nl[byte_val >> 4]);
        }
    }
}
794#enddecl(IQ4_XS)
795
796#end(DECLS)
797
798#define(SHADER)
799
800enable f16;
801
802DECLS
803
// Source tensor. copy_elements reads one array element per call, so for
// block-quantized variants each element is a whole quantized block.
@group(0) @binding(0)
var<storage, read_write> src: array<{{TYPE}}>;

// Row indices to gather; main reads one i32 per dst row and converts it to u32.
@group(0) @binding(1)
var<storage, read_write> idx: array<i32>;

// Output tensor.
@group(0) @binding(2)
var<storage, read_write> dst: array<{{DST_TYPE}}>;

// Uniform parameters for the gather.
// NOTE(review): field order/types must match what the host writes into
// binding 3; the host-side layout is not visible here — confirm before editing.
struct Params {
    offset_src: u32, // in elements
    offset_idx: u32, // in elements
    offset_dst: u32, // in elements

    // Strides (in elements)
    // stride_src1 multiplies the row index read from idx; stride_src2/3
    // multiply the dst coordinates i_dst2/i_dst3 in main.
    stride_src1: u32,
    stride_src2: u32,
    stride_src3: u32,

    stride_idx0: u32,
    stride_idx1: u32,
    stride_idx2: u32,

    stride_dst1: u32,
    stride_dst2: u32,
    stride_dst3: u32,

    // Shape of dst
    // ne0 = values per row (main runs ne0 / block size copy_elements calls);
    // n_rows, ne2, ne3 = extents of the three row coordinates.
    ne0: u32,
    n_rows: u32,
    ne2: u32,
    ne3: u32,

    // Shape of idx
    // Used in main as moduli for i_dst2 and i_dst3 when indexing idx.
    idx1: u32,
    idx2: u32,
};

@group(0) @binding(3)
var<uniform> params: Params;
844
override wg_size: u32;

// Gather one dst row per invocation.
// gid.x is split into (i_dst1, i_dst2, i_dst3) with i_dst1 varying fastest.
// The idx entry at (i_dst1, i_dst2 % idx1, i_dst3 % idx2) names the src row to
// read, and that row is converted into dst one block at a time by copy_elements.
@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let rows_per_dim3 = params.ne2 * params.n_rows;
    let row_id = gid.x;
    // Invocations beyond the last dst row have no work.
    if (row_id >= rows_per_dim3 * params.ne3) {
        return;
    }

    let i_dst3 = row_id / rows_per_dim3;
    let rem = row_id % rows_per_dim3;
    let i_dst2 = rem / params.n_rows;
    let i_dst1 = rem % params.n_rows;

    // The modulo wraps dst coordinates 2 and 3 onto the extent of idx.
    let idx_pos = params.offset_idx
        + i_dst1 * params.stride_idx0
        + (i_dst2 % params.idx1) * params.stride_idx1
        + (i_dst3 % params.idx2) * params.stride_idx2;
    let src_row_id = u32(idx[idx_pos]);

    let src_row = params.offset_src
        + src_row_id * params.stride_src1
        + i_dst2 * params.stride_src2
        + i_dst3 * params.stride_src3;
    let dst_row = params.offset_dst
        + i_dst1 * params.stride_dst1
        + i_dst2 * params.stride_dst2
        + i_dst3 * params.stride_dst3;

    let n_blocks = params.ne0 / {{BLOCK_SIZE}};
    for (var blk: u32 = 0u; blk < n_blocks; blk++) {
        copy_elements(src_row, dst_row, blk);
    }
}
873
874#end(SHADER)