1#define(VARIANTS)
  2
  3[
  4  {
  5    "REPLS": {
  6      "SRC0_TYPE" : "f32",
  7      "SRC1_TYPE" : "f32",
  8      "BLOCK_SIZE" : 1
  9    },
 10    "DECLS" : ["FLOAT"]
 11  },
 12  {
 13    "REPLS": {
 14      "SRC0_TYPE" : "f16",
 15      "SRC1_TYPE" : "f16",
 16      "BLOCK_SIZE" : 1
 17    },
 18    "DECLS" : ["FLOAT"]
 19  },
 20  {
 21    "REPLS": {
 22      "SRC0_TYPE" : "f16",
 23      "SRC1_TYPE" : "f32",
 24      "BLOCK_SIZE" : 1
 25    },
 26    "DECLS" : ["FLOAT"]
 27  },
 28  {
 29    "REPLS": {
 30      "SRC0_TYPE": "q4_0",
 31      "SRC1_TYPE": "f32",
 32      "BLOCK_SIZE": 32
 33    },
 34    "DECLS": ["BYTE_HELPERS", "Q4_0_T", "Q4_0"]
 35  },
 36  {
 37    "REPLS": {
 38      "SRC0_TYPE": "q4_1",
 39      "SRC1_TYPE": "f32",
 40      "BLOCK_SIZE": 32
 41    },
 42    "DECLS": ["BYTE_HELPERS", "Q4_1_T", "Q4_1"]
 43  },
 44  {
 45    "REPLS": {
 46      "SRC0_TYPE": "q5_0",
 47      "SRC1_TYPE": "f32",
 48      "BLOCK_SIZE": 32
 49    },
 50    "DECLS": ["BYTE_HELPERS", "Q5_0_T", "Q5_0"]
 51  },
 52  {
 53    "REPLS": {
 54      "SRC0_TYPE": "q5_1",
 55      "SRC1_TYPE": "f32",
 56      "BLOCK_SIZE": 32
 57    },
 58    "DECLS": ["BYTE_HELPERS", "Q5_1_T", "Q5_1"]
 59  },
 60  {
 61    "REPLS": {
 62      "SRC0_TYPE": "q8_0",
 63      "SRC1_TYPE": "f32",
 64      "BLOCK_SIZE": 32
 65    },
 66    "DECLS": ["BYTE_HELPERS", "Q8_0_T", "Q8_0"]
 67  },
 68  {
 69    "REPLS": {
 70      "SRC0_TYPE": "q2_k",
 71      "SRC1_TYPE": "f32",
 72      "BLOCK_SIZE": 256
 73    },
 74    "DECLS": ["BYTE_HELPERS", "Q2_K_T", "Q2_K"]
 75  },
 76  {
 77    "REPLS": {
 78      "SRC0_TYPE": "q3_k",
 79      "SRC1_TYPE": "f32",
 80      "BLOCK_SIZE": 256
 81    },
 82    "DECLS": ["BYTE_HELPERS", "Q3_K_T", "Q3_K"]
 83  },
 84  {
 85    "REPLS": {
 86      "SRC0_TYPE": "q4_k",
 87      "SRC1_TYPE": "f32",
 88      "BLOCK_SIZE": 256
 89    },
 90    "DECLS": ["Q45_K_SCALE_MIN", "BYTE_HELPERS", "Q4_K_T", "Q4_K"]
 91  },
 92  {
 93    "REPLS": {
 94      "SRC0_TYPE": "q5_k",
 95      "SRC1_TYPE": "f32",
 96      "BLOCK_SIZE": 256
 97    },
 98    "DECLS": ["Q45_K_SCALE_MIN", "BYTE_HELPERS", "Q5_K_T", "Q5_K"]
 99  },
100  {
101    "REPLS": {
102      "SRC0_TYPE": "q6_k",
103      "SRC1_TYPE": "f32",
104      "BLOCK_SIZE": 256
105    },
106    "DECLS": ["BYTE_HELPERS", "Q6_K_T", "Q6_K"]
107  },
108  {
109    "REPLS": {
110      "SRC0_TYPE": "iq2_xxs",
111      "SRC1_TYPE": "f32",
112      "BLOCK_SIZE": 256
113    },
114    "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_XXS_GRID", "IQ2_XXS_T", "IQ2_XXS"]
115  },
116  {
117    "REPLS": {
118      "SRC0_TYPE": "iq2_xs",
119      "SRC1_TYPE": "f32",
120      "BLOCK_SIZE": 256
121    },
122    "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_XS_GRID", "IQ2_XS_T", "IQ2_XS"]
123  },
124  {
125    "REPLS": {
126      "SRC0_TYPE": "iq2_s",
127      "SRC1_TYPE": "f32",
128      "BLOCK_SIZE": 256
129    },
130    "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_S_GRID", "IQ2_S_T", "IQ2_S"]
131  },
132  {
133    "REPLS": {
134      "SRC0_TYPE": "iq3_xxs",
135      "SRC1_TYPE": "f32",
136      "BLOCK_SIZE": 256
137    },
138    "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ3_XSS_GRID", "IQ3_XSS_T", "IQ3_XSS"]
139  },
140  {
141    "REPLS": {
142      "SRC0_TYPE": "iq3_s",
143      "SRC1_TYPE": "f32",
144      "BLOCK_SIZE": 256
145    },
146    "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ3_S_GRID", "IQ3_S_T", "IQ3_S"]
147  },
148  {
149    "REPLS": {
150      "SRC0_TYPE": "iq1_s",
151      "SRC1_TYPE": "f32",
152      "BLOCK_SIZE": 256
153    },
154    "DECLS": ["BYTE_HELPERS", "IQ1_GRID", "IQ1_S_T", "IQ1_S"]
155  },
156  {
157    "REPLS": {
158      "SRC0_TYPE": "iq1_m",
159      "SRC1_TYPE": "f32",
160      "BLOCK_SIZE": 256
161    },
162    "DECLS": ["BYTE_HELPERS", "IQ1_GRID", "IQ1_M_T", "IQ1_M"]
163  },
164  {
165    "REPLS": {
166      "SRC0_TYPE": "iq4_nl",
167      "SRC1_TYPE": "f32",
168      "BLOCK_SIZE": 32,
169    },
170    "DECLS": ["BYTE_HELPERS", "IQ4_GRID", "IQ4_NL_T", "IQ4_NL"]
171  },
172  {
173    "REPLS": {
174      "SRC0_TYPE": "iq4_xs",
175      "SRC1_TYPE": "f32",
176      "BLOCK_SIZE": 256,
177    },
178    "DECLS": ["BYTE_HELPERS", "IQ4_GRID", "IQ4_XS_T", "IQ4_XS"]
179  }
180]
181
182#end(VARIANTS)
183
184#define(DECLS)
185
186#decl(FLOAT)
// Returns the product of one src0 element and one src1 element, each read at
// `offset` past its base index and converted to f32 before multiplying.
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let lhs = f32(src0[src0_idx_base + offset]);
    let rhs = f32(src1[src1_idx_base + offset]);
    return lhs * rhs;
}
190#enddecl(FLOAT)
191
192#decl(Q4_0)
// Dot product of the src0 block at `src0_idx_base + offset` with the 32 src1
// values starting at `src1_idx_base + offset * 32`.
// Every packed byte holds two 4-bit quants: the low nibble is paired with src1
// element i and the high nibble with element i + 16. Each quant has 8 subtracted
// and is then multiplied by the block's `d`.
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let blk = src0[src0_idx_base + offset];
    let scale = f32(blk.d);
    // First src1 element covered by this block.
    let src1_start = src1_idx_base + offset * 32u;
    var acc: f32 = 0.0;
    for (var word: u32 = 0u; word < 4u; word += 1u) {
        // Two consecutive qs entries are reinterpreted as one 32-bit word (4 bytes).
        let packed = bitcast<u32>(vec2(blk.qs[2u * word], blk.qs[2u * word + 1u]));
        for (var lane: u32 = 0u; lane < 4u; lane += 1u) {
            let b = get_byte(packed, lane);
            let lo = (f32(b & 0xF) - 8.0f) * scale;
            let hi = (f32((b >> 4) & 0xF) - 8.0f) * scale;
            let idx = src1_start + word * 4u + lane;
            acc += lo * f32(src1[idx]);
            acc += hi * f32(src1[idx + 16u]);
        }
    }
    return acc;
}
210#enddecl(Q4_0)
211
212#decl(Q4_1)
// Dot product of the src0 block at `src0_idx_base + offset` with the 32 src1
// values starting at `src1_idx_base + offset * 32`.
// Each of the four qs words carries four bytes, and each byte two 4-bit quants
// (low nibble -> src1 element i, high nibble -> element i + 16). A quant is
// reconstructed as `quant * d + m`.
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let blk = src0[src0_idx_base + offset];
    let scale = f32(blk.d);
    let bias = f32(blk.m);
    // First src1 element covered by this block.
    let src1_start = src1_idx_base + offset * 32u;
    var acc: f32 = 0.0;
    for (var word: u32 = 0u; word < 4u; word += 1u) {
        let packed = blk.qs[word];
        for (var lane: u32 = 0u; lane < 4u; lane += 1u) {
            let b = get_byte(packed, lane);
            let lo = f32(b & 0xF) * scale + bias;
            let hi = f32((b >> 4) & 0xF) * scale + bias;
            let idx = src1_start + word * 4u + lane;
            acc += lo * f32(src1[idx]);
            acc += hi * f32(src1[idx + 16u]);
        }
    }
    return acc;
}
231#enddecl(Q4_1)
232
233#decl(Q5_0)
// Dot product of the src0 block at `src0_idx_base + offset` with the 32 src1
// values starting at `src1_idx_base + offset * 32`.
// Each quant is 5 bits: a nibble from qs plus one extra bit taken from the
// 32-bit word built out of qh[0] and qh[1]. Bit i supplies the fifth bit of
// element i and bit i + 16 that of element i + 16. Each quant has 16 subtracted
// and is then multiplied by `d`.
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let blk = src0[src0_idx_base + offset];
    let scale = f32(blk.d);
    let high_bits = bitcast<u32>(vec2(blk.qh[0], blk.qh[1]));
    // First src1 element covered by this block.
    let src1_start = src1_idx_base + offset * 32u;
    var acc: f32 = 0.0;
    for (var word: u32 = 0u; word < 4u; word += 1u) {
        // Two consecutive qs entries are reinterpreted as one 32-bit word (4 bytes).
        let packed = bitcast<u32>(vec2(blk.qs[2u * word], blk.qs[2u * word + 1u]));
        for (var lane: u32 = 0u; lane < 4u; lane += 1u) {
            let i = word * 4u + lane;
            let b = get_byte(packed, lane);
            // Move the selected high bit into bit position 4 (value 0x10).
            let lo_bit = ((high_bits >> i) << 4) & 0x10;
            let hi_bit = (high_bits >> (i + 12u)) & 0x10;
            let lo = (f32((b & 0xF) | lo_bit) - 16.0) * scale;
            let hi = (f32(((b >> 4) & 0xF) | hi_bit) - 16.0) * scale;
            acc += lo * f32(src1[src1_start + i]);
            acc += hi * f32(src1[src1_start + i + 16u]);
        }
    }
    return acc;
}
254#enddecl(Q5_0)
255
256#decl(Q5_1)
// Dot product of the src0 block at `src0_idx_base + offset` with the 32 src1
// values starting at `src1_idx_base + offset * 32`.
// Each quant is 5 bits: a nibble from qs plus one extra bit from `qh`. Bit i
// supplies the fifth bit of element i and bit i + 16 that of element i + 16.
// A quant is reconstructed as `quant * d + m`.
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let blk = src0[src0_idx_base + offset];
    let scale = f32(blk.d);
    let bias = f32(blk.m);
    // First src1 element covered by this block.
    let src1_start = src1_idx_base + offset * 32u;
    var acc: f32 = 0.0;
    for (var word: u32 = 0u; word < 4u; word += 1u) {
        let packed = blk.qs[word];
        for (var lane: u32 = 0u; lane < 4u; lane += 1u) {
            let i = word * 4u + lane;
            let b = get_byte(packed, lane);
            // Move the selected high bit into bit position 4 (value 0x10).
            let lo_bit = ((blk.qh >> i) << 4) & 0x10;
            let hi_bit = (blk.qh >> (i + 12u)) & 0x10;
            let lo = f32((b & 0xF) | lo_bit) * scale + bias;
            let hi = f32(((b >> 4) & 0xF) | hi_bit) * scale + bias;
            acc += lo * f32(src1[src1_start + i]);
            acc += hi * f32(src1[src1_start + i + 16u]);
        }
    }
    return acc;
}
277#enddecl(Q5_1)
278
279#decl(Q8_0)
// Dot product of the src0 block at `src0_idx_base + offset` with the 32 src1
// values starting at `src1_idx_base + offset * 32`.
// The block stores 32 bytes, each read through get_byte_i32 and multiplied by `d`.
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let blk = src0[src0_idx_base + offset];
    let scale = f32(blk.d);
    // First src1 element covered by this block.
    let src1_start = src1_idx_base + offset * 32u;
    var acc: f32 = 0.0;
    for (var word: u32 = 0u; word < 8u; word += 1u) {
        // Two consecutive qs entries are reinterpreted as one 32-bit word (4 bytes).
        let packed = bitcast<u32>(vec2(blk.qs[2u * word], blk.qs[2u * word + 1u]));
        for (var lane: u32 = 0u; lane < 4u; lane += 1u) {
            let q = get_byte_i32(packed, lane);
            let value = f32(q) * scale;
            let idx = src1_start + word * 4u + lane;
            acc += value * f32(src1[idx]);
        }
    }
    return acc;
}
295#enddecl(Q8_0)
296
297#decl(Q8_1)
// Dot product of the src0 block at `src0_idx_base + offset` with the 32 src1
// values starting at `src1_idx_base + offset * 32`.
// The block stores eight 32-bit qs words (32 bytes); each byte is read through
// get_byte_i32 and reconstructed as `byte * d + m`.
// NOTE(review): no entry in this file's VARIANTS list selects the Q8_1
// declaration - confirm whether it is still needed.
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let blk = src0[src0_idx_base + offset];
    let scale = f32(blk.d);
    let bias = f32(blk.m);
    // First src1 element covered by this block.
    let src1_start = src1_idx_base + offset * 32u;
    var acc: f32 = 0.0;
    for (var word: u32 = 0u; word < 8u; word += 1u) {
        let packed = blk.qs[word];
        for (var lane: u32 = 0u; lane < 4u; lane += 1u) {
            let q = get_byte_i32(packed, lane);
            let value = f32(q) * scale + bias;
            let idx = src1_start + word * 4u + lane;
            acc += value * f32(src1[idx]);
        }
    }
    return acc;
}
314#enddecl(Q8_1)
315
316#decl(Q2_K)
// 16 sub-blocks of 16 elements each.
// Dot product of the src0 block at `src0_idx_base + offset` with the 256 src1
// values starting at `src1_idx_base + offset * 256`.
// Sub-block `sub` uses scale byte `sub`: its low nibble multiplies `d`, its high
// nibble multiplies `dmin`, and each 2-bit quant q contributes
// (q * dl - ml) * src1.
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let blk = src0[src0_idx_base + offset];
    let scale = f32(blk.d);
    let min_scale = f32(blk.dmin);
    let src1_start = src1_idx_base + offset * 256u;
    var acc = 0.0;
    for (var sub: u32 = 0u; sub < 16u; sub += 1u) {
        // Sub-blocks 0..7 read qs bytes 0..31, sub-blocks 8..15 read bytes 32..63;
        // odd sub-blocks use the upper 16 bytes of their 32-byte half.
        let qs_base = (sub / 8u) * 32u + (sub % 2u) * 16u;
        // Each pair of sub-blocks moves on to the next 2-bit field of the same bytes.
        let shift = ((sub % 8u) / 2u) * 2u;
        let sc = get_byte(blk.scales[sub / 4u], sub % 4u);
        let dl = scale * f32(sc & 0xF);
        let ml = min_scale * f32(sc >> 4);
        for (var l: u32 = 0u; l < 16u; l += 1u) {
            let q_idx = qs_base + l;
            let q_byte = get_byte(blk.qs[q_idx / 4u], q_idx % 4u);
            let q = (q_byte >> shift) & 3;
            acc += (f32(q) * dl - ml) * src1[src1_start + sub * 16u + l];
        }
    }
    return acc;
}
347
348#enddecl(Q2_K)
349
350#decl(Q3_K)
// 16 blocks of 16 elements each
// Dot product of the src0 block at `src0_idx_base + offset` with the 256 src1
// values starting at `src1_idx_base + offset * 256`.
// Each quant is a 2-bit value from qs; 4 is subtracted when the matching hmask
// bit is clear. The result is multiplied by d * (scale - 32), where the scale is
// a 6-bit value unpacked below.
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let block = src0[src0_idx_base + offset];
    let d = f32(block.d);

    // extract 6-bit scales, which consist of 4-bits from first 8 bytes of scale,
    // and 2-bits from the last 4 bytes
    let kmask1: u32 = 0x03030303;
    let kmask2: u32 = 0x0f0f0f0f;
    // The packed scale data is those 12 bytes, i.e. block.scales[0..5] viewed as
    // three 32-bit words. Only these three words are read; the fourth unpacked
    // word is derived entirely from them.
    let packed_lo0 = bitcast<u32>(vec2(block.scales[0], block.scales[1]));
    let packed_lo1 = bitcast<u32>(vec2(block.scales[2], block.scales[3]));
    let packed_hi = bitcast<u32>(vec2(block.scales[4], block.scales[5]));
    // 16 unpacked scale bytes: low 4 bits from packed_lo*, upper 2 bits from packed_hi.
    var scale_vals = array<u32, 4>(
        (packed_lo0 & kmask2) | ((packed_hi & kmask1) << 4),
        (packed_lo1 & kmask2) | (((packed_hi >> 2) & kmask1) << 4),
        ((packed_lo0 >> 4) & kmask2) | (((packed_hi >> 4) & kmask1) << 4),
        ((packed_lo1 >> 4) & kmask2) | (((packed_hi >> 6) & kmask1) << 4)
    );

    // convert arrays of f16 -> u32
    var hmask_vals: array<u32, 8>;
    for (var i: u32 = 0; i < 8; i++) {
        hmask_vals[i] = bitcast<u32>(vec2(block.hmask[2 * i], block.hmask[2 * i + 1]));
    }
    var qs_vals: array<u32, 16>;
    for (var i: u32 = 0; i < 16; i++) {
        qs_vals[i] = bitcast<u32>(vec2(block.qs[2 * i], block.qs[2 * i + 1]));
    }

    var sum = 0.0;
    var src1_i = src1_idx_base + offset * 256;
    var is: u32 = 0;
    // Bit of each hmask byte that applies to the current group; it advances
    // once per group and is not reset between the two halves.
    var m: u32 = 1;
    // 2 halves of the block (128 elements each)
    for (var q_b_idx: u32 = 0; q_b_idx < 64; q_b_idx += 32) {
        // 4 groups (each group has 2 blocks of 16 elements)
        for (var shift: u32 = 0; shift < 8; shift += 2) {
            // 2 blocks
            for (var k: u32 = 0; k < 32; k += 16) {
                let sc = get_byte(scale_vals[is / 4], is % 4);
                is++;
                let dl = d * (f32(sc) - 32.0);
                for (var l: u32 = 0u; l < 16u; l++) {
                    let q_idx = q_b_idx + k + l;
                    let hm_idx = k + l;
                    let q_byte = get_byte(qs_vals[q_idx / 4], q_idx % 4);
                    let hmask_byte = get_byte(hmask_vals[hm_idx / 4], hm_idx % 4);
                    // Subtract 4 only when the high-mask bit is NOT set.
                    let hm = select(4.0, 0.0, (hmask_byte & m) != 0);
                    let qs_val = (q_byte >> shift) & 3;
                    sum += ((f32(qs_val) - hm) * dl) * src1[src1_i];
                    src1_i++;
                }
            }
            m <<= 1;
        }
    }
    return sum;
}
409
410#enddecl(Q3_K)
411
412#decl(Q4_K)
// 8 sub-blocks of 32 elements each.
// Dot product of the src0 block at `src0_idx_base + offset` with the 256 src1
// values starting at `src1_idx_base + offset * 256`.
// Two consecutive sub-blocks share the same 32 qs bytes: the even one takes the
// low nibbles and the odd one the high nibbles. Each 4-bit quant q contributes
// (q * dl - ml) * src1, with dl/ml derived from get_scale_min.
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let blk = src0[src0_idx_base + offset];
    let scale = f32(blk.d);
    let min_scale = f32(blk.dmin);
    let src1_start = src1_idx_base + offset * 256u;
    var acc = 0.0;
    for (var sub: u32 = 0u; sub < 8u; sub += 1u) {
        let qs_base = (sub / 2u) * 32u;
        let shift = (sub % 2u) * 4u;
        let scale_min = get_scale_min(sub, blk.scales);
        let dl = scale * scale_min.x;
        let ml = min_scale * scale_min.y;
        for (var l: u32 = 0u; l < 32u; l += 1u) {
            let q_idx = qs_base + l;
            let q_byte = get_byte(blk.qs[q_idx / 4u], q_idx % 4u);
            let q = (q_byte >> shift) & 0xF;
            acc += (f32(q) * dl - ml) * src1[src1_start + sub * 32u + l];
        }
    }
    return acc;
}
439
440#enddecl(Q4_K)
441
442#decl(Q5_K)
// 8 sub-blocks of 32 elements each.
// Dot product of the src0 block at `src0_idx_base + offset` with the 256 src1
// values starting at `src1_idx_base + offset * 256`.
// Two consecutive sub-blocks share the same 32 qs bytes (low nibbles, then high
// nibbles). Sub-block `sub` adds 16 to a quant when bit `sub` of the matching
// qh byte is set. Each quant q contributes (q * dl - ml) * src1.
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let blk = src0[src0_idx_base + offset];
    let scale = f32(blk.d);
    let min_scale = f32(blk.dmin);
    let src1_start = src1_idx_base + offset * 256u;
    var acc = 0.0;
    for (var sub: u32 = 0u; sub < 8u; sub += 1u) {
        let qs_base = (sub / 2u) * 32u;
        let shift = (sub % 2u) * 4u;
        // qh bit that belongs to this sub-block.
        let high_mask = 1u << sub;
        let scale_min = get_scale_min(sub, blk.scales);
        let dl = scale * scale_min.x;
        let ml = min_scale * scale_min.y;
        for (var l: u32 = 0u; l < 32u; l += 1u) {
            let q_idx = qs_base + l;
            let q_byte = get_byte(blk.qs[q_idx / 4u], q_idx % 4u);
            let qh_byte = get_byte(blk.qh[l / 4u], l % 4u);
            let q = (q_byte >> shift) & 0xF;
            let high = select(0.0, 16.0, (qh_byte & high_mask) != 0);
            acc += ((f32(q) + high) * dl - ml) * src1[src1_start + sub * 32u + l];
        }
    }
    return acc;
}
473
474#enddecl(Q5_K)
475
476#decl(Q6_K)
// 16 sub-blocks of 16 elements each.
// Dot product of the src0 block at `src0_idx_base + offset` with the 256 src1
// values starting at `src1_idx_base + offset * 256`.
// Each quant is 6 bits (a nibble from ql plus two bits from qh), has 32
// subtracted, and is multiplied by `d` and a per-sub-block scale byte read
// through get_byte_i32.
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let blk = src0[src0_idx_base + offset];
    let scale = f32(blk.d);

    // Reinterpret the ql / qh / scales arrays as 32-bit words so that single
    // bytes can be extracted with get_byte / get_byte_i32.
    var ql_words: array<u32, 32>;
    for (var w: u32 = 0u; w < 32u; w += 1u) {
        ql_words[w] = bitcast<u32>(vec2(blk.ql[2u * w], blk.ql[2u * w + 1u]));
    }
    var qh_words: array<u32, 16>;
    for (var w: u32 = 0u; w < 16u; w += 1u) {
        qh_words[w] = bitcast<u32>(vec2(blk.qh[2u * w], blk.qh[2u * w + 1u]));
    }
    var scale_words: array<u32, 4>;
    for (var w: u32 = 0u; w < 4u; w += 1u) {
        scale_words[w] = bitcast<u32>(vec2(blk.scales[2u * w], blk.scales[2u * w + 1u]));
    }

    var acc = 0.0;
    // Two halves of 128 elements; each half consumes 64 ql bytes, 32 qh bytes
    // and 8 scale bytes.
    for (var half: u32 = 0u; half < 2u; half += 1u) {
        let ql_base = half * 64u;
        let qh_base = half * 32u;
        let sc_base = half * 8u;
        let src1_base = src1_idx_base + offset * 256u + half * 128u;
        for (var l: u32 = 0u; l < 32u; l += 1u) {
            let lo_idx = ql_base + l;
            let hi_idx = lo_idx + 32u;
            let qh_idx = qh_base + l;
            let ql13_b = get_byte(ql_words[lo_idx / 4u], lo_idx % 4u);
            let ql24_b = get_byte(ql_words[hi_idx / 4u], hi_idx % 4u);
            let qh_b = get_byte(qh_words[qh_idx / 4u], qh_idx % 4u);

            // One qh byte supplies the upper two bits of four different quants.
            let q1 = f32((ql13_b & 0xF) | ((qh_b & 3) << 4)) - 32.0;
            let q2 = f32((ql24_b & 0xF) | (((qh_b >> 2) & 3) << 4)) - 32.0;
            let q3 = f32((ql13_b >> 4) | (((qh_b >> 4) & 3) << 4)) - 32.0;
            let q4 = f32((ql24_b >> 4) | (((qh_b >> 6) & 3) << 4)) - 32.0;

            // Scale bytes for the four quants sit 2 apart within this half.
            let sc_idx1 = sc_base + l / 16u;
            let sc_idx2 = sc_idx1 + 2u;
            let sc_idx3 = sc_idx1 + 4u;
            let sc_idx4 = sc_idx1 + 6u;
            let sc1 = get_byte_i32(scale_words[sc_idx1 / 4u], sc_idx1 % 4u);
            let sc2 = get_byte_i32(scale_words[sc_idx2 / 4u], sc_idx2 % 4u);
            let sc3 = get_byte_i32(scale_words[sc_idx3 / 4u], sc_idx3 % 4u);
            let sc4 = get_byte_i32(scale_words[sc_idx4 / 4u], sc_idx4 % 4u);

            acc += scale * f32(sc1) * q1 * src1[src1_base + l];
            acc += scale * f32(sc2) * q2 * src1[src1_base + l + 32u];
            acc += scale * f32(sc3) * q3 * src1[src1_base + l + 64u];
            acc += scale * f32(sc4) * q4 * src1[src1_base + l + 96u];
        }
    }
    return acc;
}
532
533#enddecl(Q6_K)
534
535#decl(IQ2_XXS)
// Dot product of the src0 block at `src0_idx_base + offset` with the 256 src1
// values starting at `src1_idx_base + offset * 256`.
// The block is processed as 8 groups of 32 elements. Per group, one 32-bit word
// holds four grid indices (one byte each) and a second word holds four 7-bit
// sign indices plus a 4-bit scale in its top bits.
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let blk = src0[src0_idx_base + offset];
    let scale = f32(blk.d);
    var out_idx = src1_idx_base + offset * 256u;
    var acc = 0.0;
    for (var group: u32 = 0u; group < 8u; group += 1u) {
        let qs_base = group * 4u;
        let grid_word = bitcast<u32>(vec2(blk.qs[qs_base], blk.qs[qs_base + 1u]));
        let sign_word = bitcast<u32>(vec2(blk.qs[qs_base + 2u], blk.qs[qs_base + 3u]));
        let db = scale * (0.5 + f32(sign_word >> 28)) * 0.25;
        for (var l: u32 = 0u; l < 4u; l += 1u) {
            // Each grid entry spans 8 consecutive bytes of iq2xxs_grid.
            let grid_base = get_byte(grid_word, l) * 8;
            let sign_idx = (sign_word >> (7 * l)) & 127;
            let signs = get_byte(ksigns_iq2xs[sign_idx / 4], sign_idx % 4);
            for (var j: u32 = 0u; j < 8u; j += 1u) {
                let g_idx = grid_base + j;
                let g = get_byte(iq2xxs_grid[g_idx / 4], g_idx % 4);
                // Negate when the sign byte has the bit selected by kmask_iq2xs byte j.
                let sgn = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0);
                acc += db * f32(g) * sgn * src1[out_idx];
                out_idx += 1u;
            }
        }
    }
    return acc;
}
559
560#enddecl(IQ2_XXS)
561
562#decl(IQ2_XS)
// Dot product of the src0 block at `src0_idx_base + offset` with the 256 src1
// values starting at `src1_idx_base + offset * 256`.
// The block is processed as 8 groups of 32 elements. Each group has one scale
// byte (low nibble for its first 16 elements, high nibble for the last 16) and
// four 16-bit qs entries that pack a 9-bit grid index with a 7-bit sign index.
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let blk = src0[src0_idx_base + offset];
    let scale = f32(blk.d);
    var out_idx = src1_idx_base + offset * 256u;
    var scale_words = array<u32, 2>(
        bitcast<u32>(vec2(blk.scales[0], blk.scales[1])),
        bitcast<u32>(vec2(blk.scales[2], blk.scales[3]))
    );
    var acc = 0.0;
    for (var group: u32 = 0u; group < 8u; group += 1u) {
        let s = get_byte(scale_words[group / 4u], group % 4u);
        let db = array<f32, 2>(
            scale * (0.5 + f32(s & 0xF)) * 0.25,
            scale * (0.5 + f32(s >> 4)) * 0.25
        );
        for (var l: u32 = 0u; l < 4u; l += 1u) {
            // Widen the 16-bit qs entry to 32 bits by pairing it with a zero half.
            let packed = bitcast<u32>(vec2(blk.qs[group * 4u + l], 0.0));
            let grid_base = (packed & 511) * 8;
            let sign_idx = packed >> 9;
            let signs = get_byte(ksigns_iq2xs[sign_idx / 4], sign_idx % 4);
            let dl = db[l / 2u];
            for (var j: u32 = 0u; j < 8u; j += 1u) {
                let g_idx = grid_base + j;
                let g = get_byte(iq2xs_grid[g_idx / 4], g_idx % 4);
                // Negate when the sign byte has the bit selected by kmask_iq2xs byte j.
                let sgn = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0);
                acc += dl * f32(g) * sgn * src1[out_idx];
                out_idx += 1u;
            }
        }
    }
    return acc;
}
594
595#enddecl(IQ2_XS)
596
597#decl(IQ2_S)
// Dot product of the src0 block at `src0_idx_base + offset` with the 256 src1
// values starting at `src1_idx_base + offset * 256`.
// The block is processed as 8 groups of 32 elements. The first 8 qs words hold
// the low 8 bits of the grid indices and the next 8 words hold the sign bytes.
// qh contributes two extra index bits per entry, and each scale byte provides
// two 4-bit scales.
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let blk = src0[src0_idx_base + offset];
    let scale = f32(blk.d);
    var out_idx = src1_idx_base + offset * 256u;
    var qs_words : array<u32, 16>;
    for (var w: u32 = 0u; w < 16u; w += 1u) {
        qs_words[w] = bitcast<u32>(vec2(blk.qs[w * 2u], blk.qs[w * 2u + 1u]));
    }
    var qh_words = array<u32, 2>(
        bitcast<u32>(vec2(blk.qh[0], blk.qh[1])),
        bitcast<u32>(vec2(blk.qh[2], blk.qh[3]))
    );
    var scale_words = array<u32, 2>(
        bitcast<u32>(vec2(blk.scales[0], blk.scales[1])),
        bitcast<u32>(vec2(blk.scales[2], blk.scales[3]))
    );
    var acc = 0.0;
    for (var group: u32 = 0u; group < 8u; group += 1u) {
        let s = get_byte(scale_words[group / 4u], group % 4u);
        let db = array<f32, 2>(
            scale * (0.5 + f32(s & 0xF)) * 0.25,
            scale * (0.5 + f32(s >> 4)) * 0.25
        );
        let index_word = qs_words[group];
        let sign_word = qs_words[group + 8u];
        let qh_byte = get_byte(qh_words[group / 4u], group % 4u);
        for (var l: u32 = 0u; l < 4u; l += 1u) {
            // Bits 2l and 2l+1 of the qh byte become bits 8 and 9 of the grid index.
            let high = (qh_byte << (8 - 2 * l)) & 0x300;
            let grid_base = (get_byte(index_word, l) | high) * 8;
            let signs = get_byte(sign_word, l);
            let dl = db[l / 2u];
            for (var j: u32 = 0u; j < 8u; j += 1u) {
                let g_idx = grid_base + j;
                let g = get_byte(iq2s_grid[g_idx / 4], g_idx % 4);
                // Negate when the sign byte has the bit selected by kmask_iq2xs byte j.
                let sgn = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0);
                acc += dl * f32(g) * sgn * src1[out_idx];
                out_idx += 1u;
            }
        }
    }
    return acc;
}
637
638
639#enddecl(IQ2_S)
640
641#decl(IQ3_XSS)
// Dot product of the src0 block at `src0_idx_base + offset` with the 256 src1
// values starting at `src1_idx_base + offset * 256`.
// The block is processed as 8 groups of 32 elements. qs entries 0..31 hold two
// grid indices each (one per byte). Starting at entry 32, each pair of entries
// forms a 32-bit word with four 7-bit sign indices and a 4-bit scale in its top
// bits.
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let blk = src0[src0_idx_base + offset];
    let scale = f32(blk.d);
    let src1_start = src1_idx_base + offset * 256u;
    var acc = 0.0;
    for (var group: u32 = 0u; group < 8u; group += 1u) {
        let sign_word = bitcast<u32>(vec2(blk.qs[group * 2u + 32u], blk.qs[group * 2u + 33u]));
        let db = scale * (0.5 + f32(sign_word >> 28)) * 0.5;
        for (var l: u32 = 0u; l < 4u; l += 1u) {
            let sign_idx = (sign_word >> (7 * l)) & 127;
            let signs = get_byte(ksigns_iq2xs[sign_idx / 4], sign_idx % 4);
            // Widen the 16-bit qs entry to 32 bits by pairing it with a zero half.
            let index_pair = bitcast<u32>(vec2(blk.qs[group * 4u + l], 0.0));
            let ig1 = get_byte(index_pair, 0);
            let ig2 = get_byte(index_pair, 1);
            // Each l covers 8 outputs: 4 from the first grid word, 4 from the second.
            let out_base = src1_start + group * 32u + l * 8u;
            for (var j: u32 = 0u; j < 4u; j += 1u) {
                let g1 = get_byte(iq3xxs_grid[ig1], j);
                let g2 = get_byte(iq3xxs_grid[ig2], j);
                let m1 = select(1.0, -1.0, (get_byte(kmask_iq2xs[0], j) & signs) != 0);
                let m2 = select(1.0, -1.0, (get_byte(kmask_iq2xs[1], j) & signs) != 0);
                acc += db * f32(g1) * m1 * src1[out_base + j];
                acc += db * f32(g2) * m2 * src1[out_base + j + 4u];
            }
        }
    }
    return acc;
}
670
671#enddecl(IQ3_XSS)
672
673#decl(IQ3_S)
// Dot product of the src0 block at `src0_idx_base + offset` with the 256 src1
// values starting at `src1_idx_base + offset * 256`.
// The block is processed as 4 pairs of 32-element groups. Each pair shares one
// scale byte (low nibble for the first group, high nibble for the second); each
// group has its own qh byte (a ninth bit for every grid index) and its own
// 32-bit word of sign bytes.
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let blk = src0[src0_idx_base + offset];
    let scale = f32(blk.d);
    var out_idx = src1_idx_base + offset * 256u;
    var qh_words = array<u32, 2>(
        bitcast<u32>(vec2(blk.qh[0], blk.qh[1])),
        bitcast<u32>(vec2(blk.qh[2], blk.qh[3]))
    );
    var sign_words: array<u32, 8>;
    for (var w: u32 = 0u; w < 8u; w += 1u) {
        sign_words[w] = bitcast<u32>(vec2(blk.signs[w * 2u], blk.signs[w * 2u + 1u]));
    }
    var scale_word = bitcast<u32>(vec2(blk.scales[0], blk.scales[1]));
    var acc = 0.0;
    for (var pair: u32 = 0u; pair < 4u; pair += 1u) {
        let s = get_byte(scale_word, pair);
        let db = array<f32, 2>(
            scale * (1.0 + 2.0 * f32(s & 0xF)),
            scale * (1.0 + 2.0 * f32(s >> 4))
        );
        for (var k: u32 = 0u; k < 2u; k += 1u) {
            let dl = db[k];
            let qh_byte = get_byte(qh_words[pair / 2u], (pair % 2u) * 2u + k);
            let sign_word = sign_words[pair * 2u + k];
            for (var l: u32 = 0u; l < 4u; l += 1u) {
                let signs = get_byte(sign_word, l);
                // Widen the 16-bit qs entry to 32 bits by pairing it with a zero half.
                let index_pair = bitcast<u32>(vec2(blk.qs[pair * 8u + k * 4u + l], 0.0));
                // qh bits 2l and 2l+1 become bit 8 of the first and second grid index.
                let ig1 = get_byte(index_pair, 0) | ((qh_byte << ((8 - (2 * l)))) & 256);
                let ig2 = get_byte(index_pair, 1) | ((qh_byte << ((7 - (2 * l)))) & 256);
                for (var j: u32 = 0u; j < 4u; j += 1u) {
                    let g1 = get_byte(iq3s_grid[ig1], j);
                    let g2 = get_byte(iq3s_grid[ig2], j);
                    let m1 = select(1.0, -1.0, (get_byte(kmask_iq2xs[0], j) & signs) != 0);
                    let m2 = select(1.0, -1.0, (get_byte(kmask_iq2xs[1], j) & signs) != 0);
                    acc += dl * f32(g1) * m1 * src1[out_idx];
                    acc += dl * f32(g2) * m2 * src1[out_idx + 4u];
                    out_idx += 1u;
                }
                // Skip the 4 outputs already written through the `+ 4` accesses above.
                out_idx += 4u;
            }
        }
    }
    return acc;
}
718#enddecl(IQ3_S)
719
720#decl(IQ1_S)
// Dot product of the src0 block at `src0_idx_base + offset` with the 256 src1
// values starting at `src1_idx_base + offset * 256`.
// The block is processed as 8 groups of 32 elements. Each 16-bit qh entry packs
// four 3-bit index extensions (bits 0..11), a 3-bit scale (bits 12..14) and a
// bit that selects the sign of IQ1_DELTA (bit 15).
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let blk = src0[src0_idx_base + offset];
    let scale = f32(blk.d);
    var out_idx = src1_idx_base + offset * 256u;
    var acc = 0.0;
    for (var group: u32 = 0u; group < 8u; group += 1u) {
        // Widen the 16-bit qh entry to 32 bits by pairing it with a zero half.
        let qh = bitcast<u32>(vec2(blk.qh[group], 0.0));
        let dl = scale * (2 * f32((qh >> 12) & 7) + 1);
        let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000) != 0);
        let index_word = bitcast<u32>(vec2(blk.qs[group * 2u], blk.qs[group * 2u + 1u]));
        for (var l: u32 = 0u; l < 4u; l += 1u) {
            let grid_base = (get_byte(index_word, l) | (((qh >> (3 * l)) & 7) << 8)) * 8;
            for (var j: u32 = 0u; j < 8u; j += 1u) {
                // iq1_grid packs sixteen 2-bit entries per 32-bit word.
                let g_idx = grid_base + j;
                let grid_word = iq1_grid[g_idx / 16];
                let g = (grid_word >> ((g_idx % 16) * 2)) & 3;
                // Sign-extend the 2-bit entry through an arithmetic right shift.
                let gs = bitcast<i32>(g << 30) >> 30;
                acc += dl * (f32(gs) + delta) * src1[out_idx];
                out_idx += 1u;
            }
        }
    }
    return acc;
}
744
745#enddecl(IQ1_S)
746
747#decl(IQ1_M)
// Dot product of the src0 block at `src0_idx_base + offset` with the 256 src1
// values starting at `src1_idx_base + offset * 256`.
// The block is processed as 8 groups of 32 elements. The f16 super-scale is
// assembled from the top 4 bits of each 16-bit half of the two scale words;
// the remaining bits hold 3-bit per-16-element scales.
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let blk = src0[src0_idx_base + offset];

    let scales_lo = blk.scales[0];
    let scales_hi = blk.scales[1];
    let packed_d = ((scales_lo >> 12) & 0xF) | ((scales_lo >> 24) & 0x00F0) | ((scales_hi >> 4) & 0x0F00) | ((scales_hi >> 16) & 0xF000);
    let d = f32(bitcast<vec2<f16>>(packed_d).x);
    var out_idx = src1_idx_base + offset * 256u;
    var acc = 0.0;
    for (var group: u32 = 0u; group < 8u; group += 1u) {
        // 16-bit scale field shared by a pair of groups; each group uses 6 bits of it.
        let sw = (blk.scales[group / 4] >> (16 * ((group / 2) % 2))) & 0xFFFF;
        let sub_shift = 6 * (group % 2);
        let sc_first : u32 = (sw >> sub_shift) & 0x7;
        let sc_second : u32 = (sw >> (sub_shift + 3)) & 0x7;
        var dl = array<f32, 2>(
            d * f32(2 * sc_first + 1),
            d * f32(2 * sc_second + 1)
        );

        // 16 bits of qh per group: for each of the four grid indices, 3 extra
        // index bits plus one bit selecting the sign of IQ1_DELTA.
        let qh = blk.qh[group / 2] >> (16 * (group % 2));
        let index_word = blk.qs[group];
        var idx = array<u32, 4>(
            get_byte(index_word, 0) | ((qh << 8) & 0x700),
            get_byte(index_word, 1) | ((qh << 4) & 0x700),
            get_byte(index_word, 2) | ((qh) & 0x700),
            get_byte(index_word, 3) | ((qh >> 4) & 0x700)
        );
        var delta = array<f32, 4>(
            select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x08) != 0),
            select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x80) != 0),
            select(IQ1_DELTA, -IQ1_DELTA, ((qh >> 8) & 0x08) != 0),
            select(IQ1_DELTA, -IQ1_DELTA, ((qh >> 8) & 0x80) != 0)
        );
        for (var l: u32 = 0u; l < 4u; l += 1u) {
            let grid_base = idx[l] * 8;
            for (var j: u32 = 0u; j < 8u; j += 1u) {
                // iq1_grid packs sixteen 2-bit entries per 32-bit word.
                let g_idx = grid_base + j;
                let grid_word = iq1_grid[g_idx / 16];
                let g = (grid_word >> ((g_idx % 16) * 2)) & 3;
                // Sign-extend the 2-bit entry through an arithmetic right shift.
                let gs = bitcast<i32>(g << 30) >> 30;
                acc += dl[l / 2] * (f32(gs) + delta[l]) * src1[out_idx];
                out_idx += 1u;
            }
        }
    }
    return acc;
}
790
791#enddecl(IQ1_M)
792
793#decl(IQ4_NL)
// Dot product of the src0 block at `src0_idx_base + offset` with the 32 src1
// values starting at `src1_idx_base + offset * 32`.
// Each of the 16 packed bytes holds two 4-bit indices into kvalues_iq4nl: the
// low nibble is paired with src1 element i, the high nibble with element i + 16.
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let blk = src0[src0_idx_base + offset];
    let scale = f32(blk.d);
    let src1_start = src1_idx_base + offset * 32u;
    var acc = 0.0;
    // Reinterpret pairs of qs entries as 32-bit words for byte extraction.
    var words: array<u32, 4>;
    for (var w: u32 = 0u; w < 4u; w += 1u) {
        words[w] = bitcast<u32>(vec2(blk.qs[w * 2u], blk.qs[w * 2u + 1u]));
    }
    for (var i: u32 = 0u; i < 16u; i += 1u) {
        let packed = get_byte(words[i / 4u], i % 4u);
        acc += scale * f32(kvalues_iq4nl[packed & 0xF]) * src1[src1_start + i];
        acc += scale * f32(kvalues_iq4nl[packed >> 4]) * src1[src1_start + i + 16u];
    }
    return acc;
}
811
812#enddecl(IQ4_NL)
813
814#decl(IQ4_XS)
// Dot product of the src0 block at `src0_idx_base + offset` with the 256 src1
// values starting at `src1_idx_base + offset * 256`.
// The block is processed as 8 groups of 32 elements. Each group has a 6-bit
// scale (4 low bits from scales_l, 2 high bits from scales_h) that is offset by
// -32 and multiplied by `d`; quants are 4-bit indices into kvalues_iq4nl.
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let blk = src0[src0_idx_base + offset];
    let scale = f32(blk.d);
    // Widen the 16-bit scales_h field to 32 bits by pairing it with a zero half.
    let scales_high = bitcast<u32>(vec2(blk.scales_h, 0.0));
    let src1_start = src1_idx_base + offset * 256u;
    var acc = 0.0;
    for (var group: u32 = 0u; group < 8u; group += 1u) {
        let low_bits = (get_byte(blk.scales_l, group / 2) >> (4 * (group % 2))) & 0xF;
        let high_bits = ((scales_high >> (2 * group)) & 3) << 4;
        let ls = low_bits | high_bits;
        let dl = scale * (f32(ls) - 32.0);
        let out_base = src1_start + group * 32u;
        for (var j: u32 = 0u; j < 16u; j += 1u) {
            let byte_idx = group * 16u + j;
            let packed = get_byte(blk.qs[byte_idx / 4u], byte_idx % 4u);
            // Low nibble -> element j of the group, high nibble -> element j + 16.
            acc += dl * f32(kvalues_iq4nl[packed & 0xF]) * src1[out_base + j];
            acc += dl * f32(kvalues_iq4nl[packed >> 4]) * src1[out_base + j + 16u];
        }
    }
    return acc;
}
835
836#enddecl(IQ4_XS)
837
838#end(DECLS)
839
840#define(SHADER)
841
842enable f16;
843
844DECLS
845
// Uniform parameters for the matrix multiplication, read by main().
// All offsets and strides are expressed in elements/blocks of the respective
// buffer, not in bytes.
struct MulMatParams {
    // Starting index into src0, src1 and dst.
    offset_src0: u32,
    offset_src1: u32,
    offset_dst: u32,
    // main() treats each dst plane as m * n elements and runs the reduction loop
    // k / BLOCK_SIZE times.
    m: u32,
    n: u32,
    k: u32,
    // stride_0x applies to src0, stride_1x to src1; the trailing digit is the
    // dimension the stride advances (1 = row/column, 2 and 3 = outer dimensions).
    stride_01: u32,
    stride_11: u32,
    stride_02: u32,
    stride_12: u32,
    stride_03: u32,
    stride_13: u32,

    // dst has bs02 * broadcast2 planes in dimension 2 and bs03 * broadcast3 in
    // dimension 3; main() divides the dst index by broadcastN to get the src0 index.
    bs02: u32,
    bs03: u32,
    broadcast2: u32,
    broadcast3: u32
};
866
// Resource bindings (all in bind group 0).
// {{SRC0_TYPE}} and {{SRC1_TYPE}} are substituted per variant from the REPLS table.
// src0: M rows, K columns
@group(0) @binding(0) var<storage, read_write> src0: array<{{SRC0_TYPE}}>;
// src1: K rows, N columns (transposed)
@group(0) @binding(1) var<storage, read_write> src1: array<{{SRC1_TYPE}}>;
// dst: M rows, N columns; main() writes element (row, col) at row * m + col
@group(0) @binding(2) var<storage, read_write> dst: array<f32>;

// Offsets, sizes, strides and broadcast factors for this dispatch.
@group(0) @binding(3) var<uniform> params: MulMatParams;
872
// Entry point: each invocation computes one element of dst.
// global_id.x is decomposed into (dimension-3 index, dimension-2 index, row, col).
// The invocation then accumulates multiply_add over k / BLOCK_SIZE blocks and
// stores the result. Invocations past the total element count exit immediately.
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let gid = global_id.x;

    // Element counts of one dst plane, one dst volume, and the whole dst tensor.
    let plane_size = params.m * params.n;
    let volume_size = plane_size * params.bs02 * params.broadcast2;
    let element_count = volume_size * params.bs03 * params.broadcast3;
    if (gid >= element_count) {
        return;
    }

    // Peel off the outer dimensions first, then row and column within a plane.
    let i3 = gid / volume_size;
    let rem3 = gid % volume_size;
    let i2 = rem3 / plane_size;
    let rem2 = rem3 % plane_size;
    let row = rem2 / params.m; // output row
    let col = rem2 % params.m; // output column

    // src0 may be broadcast along dimensions 2 and 3; src1 is indexed directly.
    let src0_i3 = i3 / params.broadcast3;
    let src0_i2 = i2 / params.broadcast2;

    let src0_idx_base = params.offset_src0 + src0_i3 * params.stride_03 + src0_i2 * params.stride_02 + col * params.stride_01;
    let src1_idx_base = params.offset_src1 + i3 * params.stride_13 + i2 * params.stride_12 + row * params.stride_11;

    // Each multiply_add call handles one element or one quantized block.
    let block_count = params.k / {{BLOCK_SIZE}};
    var acc = 0.0;
    for (var b: u32 = 0u; b < block_count; b += 1u) {
        acc += multiply_add(src0_idx_base, src1_idx_base, b);
    }
    dst[params.offset_dst + i3 * volume_size + i2 * plane_size + row * params.m + col] = acc;
}
906
907#end(SHADER)