aboutsummaryrefslogtreecommitdiff
path: root/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl
diff options
context:
space:
mode:
authorMitja Felicijan <mitja.felicijan@gmail.com>2026-02-12 20:57:17 +0100
committerMitja Felicijan <mitja.felicijan@gmail.com>2026-02-12 20:57:17 +0100
commitb333b06772c89d96aacb5490d6a219fba7c09cc6 (patch)
tree211df60083a5946baa2ed61d33d8121b7e251b06 /llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl
downloadllmnpc-b333b06772c89d96aacb5490d6a219fba7c09cc6.tar.gz
Engage!
Diffstat (limited to 'llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl')
-rw-r--r--llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl907
1 files changed, 907 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl
new file mode 100644
index 0000000..0f8e6e5
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl
@@ -0,0 +1,907 @@
1#define(VARIANTS)
2
3[
4 {
5 "REPLS": {
6 "SRC0_TYPE" : "f32",
7 "SRC1_TYPE" : "f32",
8 "BLOCK_SIZE" : 1
9 },
10 "DECLS" : ["FLOAT"]
11 },
12 {
13 "REPLS": {
14 "SRC0_TYPE" : "f16",
15 "SRC1_TYPE" : "f16",
16 "BLOCK_SIZE" : 1
17 },
18 "DECLS" : ["FLOAT"]
19 },
20 {
21 "REPLS": {
22 "SRC0_TYPE" : "f16",
23 "SRC1_TYPE" : "f32",
24 "BLOCK_SIZE" : 1
25 },
26 "DECLS" : ["FLOAT"]
27 },
28 {
29 "REPLS": {
30 "SRC0_TYPE": "q4_0",
31 "SRC1_TYPE": "f32",
32 "BLOCK_SIZE": 32
33 },
34 "DECLS": ["BYTE_HELPERS", "Q4_0_T", "Q4_0"]
35 },
36 {
37 "REPLS": {
38 "SRC0_TYPE": "q4_1",
39 "SRC1_TYPE": "f32",
40 "BLOCK_SIZE": 32
41 },
42 "DECLS": ["BYTE_HELPERS", "Q4_1_T", "Q4_1"]
43 },
44 {
45 "REPLS": {
46 "SRC0_TYPE": "q5_0",
47 "SRC1_TYPE": "f32",
48 "BLOCK_SIZE": 32
49 },
50 "DECLS": ["BYTE_HELPERS", "Q5_0_T", "Q5_0"]
51 },
52 {
53 "REPLS": {
54 "SRC0_TYPE": "q5_1",
55 "SRC1_TYPE": "f32",
56 "BLOCK_SIZE": 32
57 },
58 "DECLS": ["BYTE_HELPERS", "Q5_1_T", "Q5_1"]
59 },
60 {
61 "REPLS": {
62 "SRC0_TYPE": "q8_0",
63 "SRC1_TYPE": "f32",
64 "BLOCK_SIZE": 32
65 },
66 "DECLS": ["BYTE_HELPERS", "Q8_0_T", "Q8_0"]
67 },
68 {
69 "REPLS": {
70 "SRC0_TYPE": "q2_k",
71 "SRC1_TYPE": "f32",
72 "BLOCK_SIZE": 256
73 },
74 "DECLS": ["BYTE_HELPERS", "Q2_K_T", "Q2_K"]
75 },
76 {
77 "REPLS": {
78 "SRC0_TYPE": "q3_k",
79 "SRC1_TYPE": "f32",
80 "BLOCK_SIZE": 256
81 },
82 "DECLS": ["BYTE_HELPERS", "Q3_K_T", "Q3_K"]
83 },
84 {
85 "REPLS": {
86 "SRC0_TYPE": "q4_k",
87 "SRC1_TYPE": "f32",
88 "BLOCK_SIZE": 256
89 },
90 "DECLS": ["Q45_K_SCALE_MIN", "BYTE_HELPERS", "Q4_K_T", "Q4_K"]
91 },
92 {
93 "REPLS": {
94 "SRC0_TYPE": "q5_k",
95 "SRC1_TYPE": "f32",
96 "BLOCK_SIZE": 256
97 },
98 "DECLS": ["Q45_K_SCALE_MIN", "BYTE_HELPERS", "Q5_K_T", "Q5_K"]
99 },
100 {
101 "REPLS": {
102 "SRC0_TYPE": "q6_k",
103 "SRC1_TYPE": "f32",
104 "BLOCK_SIZE": 256
105 },
106 "DECLS": ["BYTE_HELPERS", "Q6_K_T", "Q6_K"]
107 },
108 {
109 "REPLS": {
110 "SRC0_TYPE": "iq2_xxs",
111 "SRC1_TYPE": "f32",
112 "BLOCK_SIZE": 256
113 },
114 "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_XXS_GRID", "IQ2_XXS_T", "IQ2_XXS"]
115 },
116 {
117 "REPLS": {
118 "SRC0_TYPE": "iq2_xs",
119 "SRC1_TYPE": "f32",
120 "BLOCK_SIZE": 256
121 },
122 "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_XS_GRID", "IQ2_XS_T", "IQ2_XS"]
123 },
124 {
125 "REPLS": {
126 "SRC0_TYPE": "iq2_s",
127 "SRC1_TYPE": "f32",
128 "BLOCK_SIZE": 256
129 },
130 "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_S_GRID", "IQ2_S_T", "IQ2_S"]
131 },
132 {
133 "REPLS": {
134 "SRC0_TYPE": "iq3_xxs",
135 "SRC1_TYPE": "f32",
136 "BLOCK_SIZE": 256
137 },
138 "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ3_XSS_GRID", "IQ3_XSS_T", "IQ3_XSS"]
139 },
140 {
141 "REPLS": {
142 "SRC0_TYPE": "iq3_s",
143 "SRC1_TYPE": "f32",
144 "BLOCK_SIZE": 256
145 },
146 "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ3_S_GRID", "IQ3_S_T", "IQ3_S"]
147 },
148 {
149 "REPLS": {
150 "SRC0_TYPE": "iq1_s",
151 "SRC1_TYPE": "f32",
152 "BLOCK_SIZE": 256
153 },
154 "DECLS": ["BYTE_HELPERS", "IQ1_GRID", "IQ1_S_T", "IQ1_S"]
155 },
156 {
157 "REPLS": {
158 "SRC0_TYPE": "iq1_m",
159 "SRC1_TYPE": "f32",
160 "BLOCK_SIZE": 256
161 },
162 "DECLS": ["BYTE_HELPERS", "IQ1_GRID", "IQ1_M_T", "IQ1_M"]
163 },
164 {
165 "REPLS": {
166 "SRC0_TYPE": "iq4_nl",
167 "SRC1_TYPE": "f32",
      "BLOCK_SIZE": 32
169 },
170 "DECLS": ["BYTE_HELPERS", "IQ4_GRID", "IQ4_NL_T", "IQ4_NL"]
171 },
172 {
173 "REPLS": {
174 "SRC0_TYPE": "iq4_xs",
175 "SRC1_TYPE": "f32",
      "BLOCK_SIZE": 256
177 },
178 "DECLS": ["BYTE_HELPERS", "IQ4_GRID", "IQ4_XS_T", "IQ4_XS"]
179 }
180]
181
182#end(VARIANTS)
183
184#define(DECLS)
185
186#decl(FLOAT)
// Unquantized path: one src0 element times the matching src1 element,
// both widened to f32 before the multiply.
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let a = f32(src0[src0_idx_base + offset]);
    let b = f32(src1[src1_idx_base + offset]);
    return a * b;
}
190#enddecl(FLOAT)
191
192#decl(Q4_0)
// One q4_0 block: 32 weights packed as 16 bytes of 4-bit values plus an f16
// scale d; each nibble decodes to (q - 8) * d. Low nibbles of the 16 bytes
// are elements 0..15, high nibbles are elements 16..31 — hence the `+ 16`
// on the second src1 access.
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let block_q4_0 = src0[src0_idx_base + offset];
    let d = f32(block_q4_0.d);
    var sum: f32 = 0.0;
    for (var j: u32 = 0; j < 4; j++) {
        // qs is stored as f16 pairs; reassemble 4 quant bytes into one u32.
        let q_packed = bitcast<u32>(vec2(block_q4_0.qs[2 * j], block_q4_0.qs[2 * j + 1]));
        for (var k: u32 = 0; k < 4; k++) {
            let q_byte = get_byte(q_packed, k);
            let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0f) * d;
            let q_lo = (f32(q_byte & 0xF) - 8.0f) * d;
            // `offset` is the block index; each block spans 32 src1 elements.
            let src1_offset = src1_idx_base + offset * 32 + j * 4 + k;
            sum += q_lo * f32(src1[src1_offset]);
            sum += q_hi * f32(src1[src1_offset + 16]);
        }
    }
    return sum;
}
210#enddecl(Q4_0)
211
212#decl(Q4_1)
// One q4_1 block: 32 weights as 4-bit values with f16 scale d and min m;
// each nibble decodes to q * d + m. Same low/high nibble layout as q4_0.
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let block_q4_1 = src0[src0_idx_base + offset];
    let d = f32(block_q4_1.d);
    let m = f32(block_q4_1.m);
    var sum: f32 = 0.0;
    for (var j: u32 = 0; j < 4; j++) {
        // qs is indexed directly here (already u32-packed, unlike q4_0's
        // f16-pair storage).
        let q_packed = block_q4_1.qs[j];
        for (var k: u32 = 0; k < 4; k++) {
            let q_byte = get_byte(q_packed, k);
            let q_hi = f32((q_byte >> 4) & 0xF) * d + m;
            let q_lo = f32(q_byte & 0xF) * d + m;
            // `offset` is the block index; each block spans 32 src1 elements.
            let src1_offset = src1_idx_base + offset * 32 + j * 4 + k;
            sum += q_lo * f32(src1[src1_offset]);
            sum += q_hi * f32(src1[src1_offset + 16]);
        }
    }
    return sum;
}
231#enddecl(Q4_1)
232
233#decl(Q5_0)
// One q5_0 block: 32 weights with 4 low bits in qs and the 5th bit taken
// from the 32-bit qh mask; value = (q - 16) * d with f16 scale d.
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let block_q5_0 = src0[src0_idx_base + offset];
    let d = f32(block_q5_0.d);
    var sum: f32 = 0.0;
    // qh is stored as two f16s; reassemble the 32-bit high-bit mask once.
    let qh_packed = bitcast<u32>(vec2(block_q5_0.qh[0], block_q5_0.qh[1]));
    for (var j: u32 = 0; j < 4; j++) {
        let q_packed = bitcast<u32>(vec2(block_q5_0.qs[2 * j], block_q5_0.qs[2 * j + 1]));
        for (var k: u32 = 0; k < 4; k++) {
            let q_byte = get_byte(q_packed, k);
            // 5th bit of element (j*4+k+16): qh bit (j*4+k+16) moved down to
            // bit position 4 (shift by idx+12, mask 0x10).
            let qh_hi = (qh_packed >> (j * 4 + k + 12)) & 0x10;
            let q_hi = (f32(((q_byte >> 4) & 0xF) | qh_hi) - 16.0) * d;
            // 5th bit of element (j*4+k): qh bit (j*4+k) moved up to bit 4.
            let qh_lo = ((qh_packed >> (j * 4 + k)) << 4) & 0x10;
            let q_lo = (f32((q_byte & 0xF) | qh_lo) - 16.0) * d;
            let src1_offset = src1_idx_base + offset * 32 + j * 4 + k;
            sum += q_lo * f32(src1[src1_offset]);
            sum += q_hi * f32(src1[src1_offset + 16]);
        }
    }
    return sum;
}
254#enddecl(Q5_0)
255
256#decl(Q5_1)
// One q5_1 block: like q5_0 but affine — value = q * d + m with f16 scale d
// and min m. qh is a plain u32 member here (no f16-pair reassembly needed).
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let block_q5_1 = src0[src0_idx_base + offset];
    let d = f32(block_q5_1.d);
    let m = f32(block_q5_1.m);
    var sum: f32 = 0.0;
    for (var j: u32 = 0; j < 4; j++) {
        let q_packed = block_q5_1.qs[j];
        for (var k: u32 = 0; k < 4; k++) {
            let q_byte = get_byte(q_packed, k);
            // 5th bit of element (j*4+k+16) from qh, landed at bit 4.
            let qh_hi = (block_q5_1.qh >> (j * 4 + k + 12)) & 0x10;
            let q_hi = f32(((q_byte >> 4) & 0xF) | qh_hi) * d + m;
            // 5th bit of element (j*4+k) from qh, raised to bit 4.
            let qh_lo = ((block_q5_1.qh >> (j * 4 + k)) << 4) & 0x10;
            let q_lo = f32((q_byte & 0xF) | qh_lo) * d + m;
            let src1_offset = src1_idx_base + offset * 32 + j * 4 + k;
            sum += q_lo * f32(src1[src1_offset]);
            sum += q_hi * f32(src1[src1_offset + 16]);
        }
    }
    return sum;
}
277#enddecl(Q5_1)
278
279#decl(Q8_0)
// One q8_0 block: 32 signed 8-bit weights and a single f16 scale;
// each byte decodes to q * d and is accumulated against the matching
// 32-element span of src1.
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let blk = src0[src0_idx_base + offset];
    let scale = f32(blk.d);
    let src1_start = src1_idx_base + offset * 32;
    var acc: f32 = 0.0;
    for (var i: u32 = 0; i < 32; i++) {
        let word = i / 4;
        // qs is stored as f16 pairs; reassemble 4 quant bytes into one u32.
        let packed = bitcast<u32>(vec2(blk.qs[2 * word], blk.qs[2 * word + 1]));
        let q = get_byte_i32(packed, i % 4);
        acc += f32(q) * scale * f32(src1[src1_start + i]);
    }
    return acc;
}
295#enddecl(Q8_0)
296
297#decl(Q8_1)
// q8_1: like q8_0 but affine — each byte decodes to q * d + m.
// NOTE(review): no entry in this file's VARIANTS list selects Q8_1;
// presumably kept for parity with other templates — confirm before removing.
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let block_q8_1 = src0[src0_idx_base + offset];
    let d = f32(block_q8_1.d);
    let m = f32(block_q8_1.m);
    var sum: f32 = 0.0;
    for (var j: u32 = 0; j < 8; j++) {
        // qs is indexed directly (u32-packed), unlike q8_0's f16-pair storage.
        let q_packed = block_q8_1.qs[j];
        for (var k: u32 = 0; k < 4; k++) {
            let q_byte = get_byte_i32(q_packed, k);
            let q_val = f32(q_byte) * d + m;
            let src1_offset = src1_idx_base + offset * 32 + j * 4 + k;
            sum += q_val * f32(src1[src1_offset]);
        }
    }
    return sum;
}
315
316#decl(Q2_K)
317// 16 blocks of 16 elements each
// One q2_k super-block: 256 elements as 16 sub-blocks of 16, each sub-block
// with a 4-bit scale and 4-bit min packed into one scales byte. Element
// value = q2 * (d * sc_lo) - (dmin * sc_hi), q2 being a 2-bit quant.
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let block = src0[src0_idx_base + offset];
    let d = f32(block.d);
    let m = f32(block.dmin);
    var sum = 0.0;
    var src1_i = src1_idx_base + offset * 256;
    var is: u32 = 0; // running sub-block (scale byte) index, 0..15
    // 2 halves of the block (128 elements each)
    for (var q_b_idx: u32 = 0; q_b_idx < 64; q_b_idx += 32) {
        // 4 groups (each group has 2 blocks of 16 elements)
        for (var shift: u32 = 0; shift < 8; shift += 2) {
            // 2 blocks
            for (var k: u32 = 0; k < 32; k += 16) {
                let sc = get_byte(block.scales[is / 4], is % 4);
                is++;
                let dl = d * f32(sc & 0xF);  // per-sub-block scale
                let ml = m * f32(sc >> 4);   // per-sub-block min
                for (var l: u32 = 0u; l < 16; l++) {
                    let q_idx = q_b_idx + k + l;
                    let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4);
                    // Extract the 2-bit quant for this group.
                    let qs_val = (q_byte >> shift) & 3;
                    sum += (f32(qs_val) * dl - ml) * src1[src1_i];
                    src1_i++;
                }
            }
        }
    }
    return sum;
}
347
348#enddecl(Q2_K)
349
350#decl(Q3_K)
351// 16 blocks of 16 elements each
// One q3_k super-block: 256 elements as 16 sub-blocks of 16, with 3-bit
// quants (2 bits in qs, sign/high bit via hmask) and 6-bit sub-block scales.
// Element value = (q3 - hm) * d * (sc - 32), where hm subtracts 4 when the
// hmask bit for this group is clear.
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let block = src0[src0_idx_base + offset];
    let d = f32(block.d);

    // extract 6-bit scales, which consist of 4-bits from first 8 bytes of scale,
    // and 2-bits from the last 4 bytes
    let kmask1: u32 = 0x03030303;
    let kmask2: u32 = 0x0f0f0f0f;
    var scale_vals: array<u32, 4>;
    for (var i: u32 = 0; i < 4; i++) {
        scale_vals[i] = bitcast<u32>(vec2(block.scales[2 * i], block.scales[2 * i + 1]));
    }
    // Recombine: low nibbles from words 0-1, top 2 bits from word 2 (saved in
    // tmp before word 2 is overwritten).
    var tmp: u32 = scale_vals[2];
    scale_vals[2] = ((scale_vals[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
    scale_vals[3] = ((scale_vals[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
    scale_vals[0] = (scale_vals[0] & kmask2) | ((tmp & kmask1) << 4);
    scale_vals[1] = (scale_vals[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);

    // convert arrays of f16 -> u32
    var hmask_vals: array<u32, 8>;
    for (var i: u32 = 0; i < 8; i++) {
        hmask_vals[i] = bitcast<u32>(vec2(block.hmask[2 * i], block.hmask[2 * i + 1]));
    }
    var qs_vals: array<u32, 16>;
    for (var i: u32 = 0; i < 16; i++) {
        qs_vals[i] = bitcast<u32>(vec2(block.qs[2 * i], block.qs[2 * i + 1]));
    }

    var sum = 0.0;
    var src1_i = src1_idx_base + offset * 256;
    var is: u32 = 0; // running sub-block (scale byte) index, 0..15
    var m: u32 = 1;  // hmask bit for the current group; doubles per group
    // 2 halves of the block (128 elements each)
    for (var q_b_idx: u32 = 0; q_b_idx < 64; q_b_idx += 32) {
        // 4 groups (each group has 2 blocks of 16 elements)
        for (var shift: u32 = 0; shift < 8; shift += 2) {
            // 2 blocks
            for (var k: u32 = 0; k < 32; k += 16) {
                let sc = get_byte(scale_vals[is / 4], is % 4);
                is++;
                let dl = d * (f32(sc) - 32.0); // 6-bit scale re-centered at 32
                for (var l: u32 = 0u; l < 16u; l++) {
                    let q_idx = q_b_idx + k + l;
                    let hm_idx = k + l;
                    let q_byte = get_byte(qs_vals[q_idx / 4], q_idx % 4);
                    let hmask_byte = get_byte(hmask_vals[hm_idx / 4], hm_idx % 4);
                    // Subtract 4 when the high-mask bit is NOT set.
                    let hm = select(4.0, 0.0, (hmask_byte & m) != 0);
                    let qs_val = (q_byte >> shift) & 3;
                    sum += ((f32(qs_val) - hm) * dl) * src1[src1_i];
                    src1_i++;
                }
            }
        }
        m <<= 1;
    }
    return sum;
}
409
410#enddecl(Q3_K)
411
412#decl(Q4_K)
413// 8 blocks of 32 elements each
// One q4_k super-block: 256 elements as 8 sub-blocks of 32 with 6-bit
// scales/mins (decoded by the shared get_scale_min helper). Element value =
// q4 * (d * sc) - (dmin * min). Two sub-blocks are decoded per qs span via
// the low/high nibble shift.
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let block = src0[src0_idx_base + offset];
    let d = f32(block.d);
    let m = f32(block.dmin);
    var sum = 0.0;
    var src1_i = src1_idx_base + offset * 256;
    var is: u32 = 0; // sub-block index passed to get_scale_min, 0..7
    // 2 blocks each iteration
    for (var q_b_idx: u32 = 0; q_b_idx < 128; q_b_idx += 32) {
        for (var shift: u32 = 0; shift < 8; shift += 4) {
            // scale_min.x = scale, scale_min.y = min for this sub-block.
            let scale_min = get_scale_min(is, block.scales);
            is++;
            let dl = d * scale_min.x;
            let ml = m * scale_min.y;
            for (var l: u32 = 0; l < 32; l++) {
                let q_idx = q_b_idx + l;
                let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4);
                let qs_val = (q_byte >> shift) & 0xF;
                sum += (f32(qs_val) * dl - ml) * src1[src1_i];
                src1_i++;
            }
        }
    }
    return sum;
}
439
440#enddecl(Q4_K)
441
442#decl(Q5_K)
443// 8 blocks of 32 elements each
// One q5_k super-block: like q4_k (8 sub-blocks of 32, shared scale/min
// decoding) plus a 5th quant bit per element taken from qh. Element value =
// (q4 + 16*qh_bit) * (d * sc) - (dmin * min).
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let block = src0[src0_idx_base + offset];
    let d = f32(block.d);
    let m = f32(block.dmin);
    var sum = 0.0;
    var src1_i = src1_idx_base + offset * 256;
    var is: u32 = 0; // sub-block index passed to get_scale_min, 0..7
    var u: u32 = 1;  // qh bit mask for the current sub-block; doubles each time
    // 2 blocks each iteration
    for (var q_b_idx: u32 = 0; q_b_idx < 128; q_b_idx += 32) {
        for (var shift: u32 = 0; shift < 8; shift += 4) {
            // scale_min.x = scale, scale_min.y = min for this sub-block.
            let scale_min = get_scale_min(is, block.scales);
            is++;
            let dl = d * scale_min.x;
            let ml = m * scale_min.y;
            for (var l: u32 = 0; l < 32; l++) {
                let q_idx = q_b_idx + l;
                let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4);
                let qh_byte = get_byte(block.qh[l / 4], l % 4);
                let qs_val = (q_byte >> shift) & 0xF;
                // Adds 16 when this sub-block's qh bit is set.
                let qh_val = select(0.0, 16.0, (qh_byte & u) != 0);
                sum += ((f32(qs_val) + qh_val) * dl - ml) * src1[src1_i];
                src1_i++;
            }
            u <<= 1;
        }
    }
    return sum;
}
473
474#enddecl(Q5_K)
475
476#decl(Q6_K)
477// 16 blocks of 16 elements each
// One q6_k super-block: 256 elements with 6-bit quants (4 low bits in ql,
// 2 high bits in qh) and 16 signed 8-bit sub-block scales. Element value =
// d * sc * (q6 - 32). Each inner iteration decodes four interleaved
// elements, 32 apart in src1 (offsets +0, +32, +64, +96).
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let block = src0[src0_idx_base + offset];
    let d = f32(block.d);

    // convert arrays of f16 -> u32
    var ql_vals: array<u32, 32>;
    for (var i: u32 = 0; i < 32; i++) {
        ql_vals[i] = bitcast<u32>(vec2(block.ql[2 * i], block.ql[2 * i + 1]));
    }
    var qh_vals: array<u32, 16>;
    for (var i: u32 = 0; i < 16; i++) {
        qh_vals[i] = bitcast<u32>(vec2(block.qh[2 * i], block.qh[2 * i + 1]));
    }
    var scale_vals: array<u32, 4>;
    for (var i: u32 = 0; i < 4; i++) {
        scale_vals[i] = bitcast<u32>(vec2(block.scales[2 * i], block.scales[2 * i + 1]));
    }

    var sum = 0.0;
    var src1_i = src1_idx_base + offset * 256;
    var qh_b_idx: u32 = 0; // byte base into qh for this 128-element half
    var sc_b_idx: u32 = 0; // scale-byte base for this 128-element half
    // 2 halves of 128 elements; ql advances 64 bytes per half.
    for (var ql_b_idx: u32 = 0; ql_b_idx < 128; ql_b_idx += 64) {
        for (var l: u32 = 0; l < 32; l++) {
            // ql13_b holds low nibbles for q1/q3, ql24_b for q2/q4;
            // one qh byte supplies the two high bits of all four.
            let ql13_b = get_byte(ql_vals[(ql_b_idx + l) / 4], (ql_b_idx + l) % 4);
            let ql24_b = get_byte(ql_vals[(ql_b_idx + l + 32) / 4], (ql_b_idx + l + 32) % 4);
            let qh_b = get_byte(qh_vals[(qh_b_idx + l) / 4], (qh_b_idx + l) % 4);

            let q1 = f32((ql13_b & 0xF) | ((qh_b & 3) << 4)) - 32.0;
            let q2 = f32((ql24_b & 0xF) | (((qh_b >> 2) & 3) << 4)) - 32.0;
            let q3 = f32((ql13_b >> 4) | (((qh_b >> 4) & 3) << 4)) - 32.0;
            let q4 = f32((ql24_b >> 4) | (((qh_b >> 6) & 3) << 4)) - 32.0;

            // Signed scale bytes: one per 16-element sub-block, strided by 2.
            let is = l/16;
            let is1 = sc_b_idx + is;
            let sc1 = get_byte_i32(scale_vals[is1 / 4], is1 % 4);
            let is2 = sc_b_idx + is + 2;
            let sc2 = get_byte_i32(scale_vals[is2 / 4], is2 % 4);
            let is3 = sc_b_idx + is + 4;
            let sc3 = get_byte_i32(scale_vals[is3 / 4], is3 % 4);
            let is4 = sc_b_idx + is + 6;
            let sc4 = get_byte_i32(scale_vals[is4 / 4], is4 % 4);

            sum += d * f32(sc1) * q1 * src1[src1_i + l];
            sum += d * f32(sc2) * q2 * src1[src1_i + l + 32];
            sum += d * f32(sc3) * q3 * src1[src1_i + l + 64];
            sum += d * f32(sc4) * q4 * src1[src1_i + l + 96];
        }
        src1_i += 128;
        qh_b_idx += 32;
        sc_b_idx += 8;
    }
    return sum;
}
532
533#enddecl(Q6_K)
534
535#decl(IQ2_XXS)
// One iq2_xxs super-block: 256 elements in groups of 32. Each group packs
// four 8-byte grid-codebook rows (indices in aux0) plus per-row 7-bit sign
// entries and a 4-bit group scale in aux1. Magnitudes come from
// iq2xxs_grid, signs from the shared ksigns_iq2xs/kmask_iq2xs tables.
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let block = src0[src0_idx_base + offset];
    let d = f32(block.d);
    var src1_i = src1_idx_base + offset * 256;
    var sum = 0.0;
    for (var ib: u32 = 0; ib < 32; ib += 4) {
        // qs is stored as f16 pairs; reassemble two u32 words per group.
        let aux0 = bitcast<u32>(vec2(block.qs[ib], block.qs[ib + 1]));
        let aux1 = bitcast<u32>(vec2(block.qs[ib + 2], block.qs[ib + 3]));
        // Group scale: top 4 bits of aux1, mapped to d * (s + 0.5) / 4.
        let db = d * (0.5 + f32(aux1 >> 28)) * 0.25;
        for (var l: u32 = 0; l < 4; l++) {
            let ig = get_byte(aux0, l) * 8;       // grid row, 8 bytes wide
            let is = (aux1 >> (7 * l)) & 127;     // 7-bit sign-table index
            let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);
            for (var j: u32 = 0; j < 8; j++) {
                let g = get_byte(iq2xxs_grid[(ig + j) / 4], (ig + j) % 4);
                let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0);
                sum += db * f32(g) * m * src1[src1_i];
                src1_i++;
            }
        }
    }
    return sum;
}
559
560#enddecl(IQ2_XXS)
561
562#decl(IQ2_XS)
// One iq2_xs super-block: like iq2_xxs but each 16-bit qs entry carries a
// 9-bit grid index plus a 7-bit sign-table index, and explicit scale bytes
// hold two 4-bit scales per byte (low nibble = first half of the group).
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let block = src0[src0_idx_base + offset];
    let d = f32(block.d);
    var src1_i = src1_idx_base + offset * 256;
    // scales is stored as f16 pairs; reassemble into two u32 words.
    var scale_vals = array<u32, 2>(
        bitcast<u32>(vec2(block.scales[0], block.scales[1])),
        bitcast<u32>(vec2(block.scales[2], block.scales[3]))
    );
    var sum = 0.0;
    for (var ib: u32 = 0; ib < 32; ib += 4) {
        let s = get_byte(scale_vals[ib / 16], (ib % 16) / 4);
        // Two sub-scales per group: d * (nibble + 0.5) / 4.
        let db = array<f32, 2>(
            d * (0.5 + f32(s & 0xF)) * 0.25,
            d * (0.5 + f32(s >> 4)) * 0.25
        );
        for (var l: u32 = 0; l < 4; l++) {
            // Widen one f16-stored 16-bit qs entry to u32 (high half zeroed).
            let qs_val = bitcast<u32>(vec2(block.qs[ib + l], 0.0));
            let ig = (qs_val & 511) * 8; // 9-bit grid row index, 8 bytes wide
            let is = qs_val >> 9;        // 7-bit sign-table index
            let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);
            let dl = db[l/2];
            for (var j: u32 = 0; j < 8; j++) {
                let g = get_byte(iq2xs_grid[(ig + j) / 4], (ig + j) % 4);
                let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0);
                sum += dl * f32(g) * m * src1[src1_i];
                src1_i++;
            }
        }
    }
    return sum;
}
594
595#enddecl(IQ2_XS)
596
597#decl(IQ2_S)
// One iq2_s super-block: 256 elements in 8 groups of 32. Grid indices are
// 8 bits from qs plus 2 high bits from qh; sign bytes live in the second
// half of qs (qs_vals[ib + 8]); two 4-bit scales per scale byte.
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let block = src0[src0_idx_base + offset];
    let d = f32(block.d);
    var src1_i = src1_idx_base + offset * 256;
    // qs/qh/scales are stored as f16 pairs; reassemble into u32 words.
    var qs_vals : array<u32, 16>;
    for (var i: u32 = 0; i < 16; i++) {
        qs_vals[i] = bitcast<u32>(vec2(block.qs[i * 2], block.qs[i * 2 + 1]));
    }
    var qh_vals = array<u32, 2>(
        bitcast<u32>(vec2(block.qh[0], block.qh[1])),
        bitcast<u32>(vec2(block.qh[2], block.qh[3]))
    );
    var scale_vals = array<u32, 2>(
        bitcast<u32>(vec2(block.scales[0], block.scales[1])),
        bitcast<u32>(vec2(block.scales[2], block.scales[3]))
    );
    var sum = 0.0;
    for (var ib: u32 = 0; ib < 8; ib ++) {
        let s = get_byte(scale_vals[ib / 4], ib % 4);
        // Two sub-scales per group: d * (nibble + 0.5) / 4.
        let db = array<f32, 2>(
            d * (0.5 + f32(s & 0xF)) * 0.25,
            d * (0.5 + f32(s >> 4)) * 0.25
        );
        let qs_w = qs_vals[ib];
        for (var l: u32 = 0; l < 4; l++) {
            // Two high index bits for row l, shifted into bits 8-9.
            let qh_b = (get_byte(qh_vals[ib / 4], ib % 4) << (8 - 2 * l)) & 0x300;
            let ig = (get_byte(qs_w, l) | qh_b) * 8; // grid row, 8 bytes wide
            let signs = get_byte(qs_vals[ib + 8], l);
            let dl = db[l/2];
            for (var j: u32 = 0; j < 8; j++) {
                let g = get_byte(iq2s_grid[(ig + j) / 4], (ig + j) % 4);
                let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0);
                sum += dl * f32(g) * m * src1[src1_i];
                src1_i++;
            }
        }
    }
    return sum;
}
637
638
639#enddecl(IQ2_S)
640
641#decl(IQ3_XSS)
// One iq3_xxs super-block (decl named IQ3_XSS throughout this template):
// 256 elements in groups of 32. The first 32 qs entries hold 8-bit grid
// indices (two per 16-bit entry); a trailing word per group (qs[ib + 32])
// packs four 7-bit sign indices plus a 4-bit scale in its top bits.
// Each inner j-loop step emits two elements 4 apart in src1.
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let block = src0[src0_idx_base + offset];
    let d = f32(block.d);
    var src1_i = src1_idx_base + offset * 256;
    var sum = 0.0;
    for (var ib: u32 = 0; ib < 16; ib += 2) {
        // Scale-and-sign word for this group (stored as f16 pairs).
        let sc_sign = bitcast<u32>(vec2(block.qs[ib + 32], block.qs[ib + 33]));
        let db = d * (0.5 + f32(sc_sign >> 28)) * 0.5;
        for (var l: u32 = 0; l < 4; l++) {
            let is = (sc_sign >> (7 * l)) & 127; // 7-bit sign-table index
            let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);
            // Widen one f16-stored 16-bit entry; it holds two grid indices.
            let ig_val = bitcast<u32>(vec2(block.qs[ib * 2 + l], 0.0));
            let ig1 = get_byte(ig_val, 0);
            let ig2 = get_byte(ig_val, 1);
            for (var j: u32 = 0; j < 4; j++) {
                let g1 = get_byte(iq3xxs_grid[ig1], j);
                let g2 = get_byte(iq3xxs_grid[ig2], j);
                let m1 = select(1.0, -1.0, (get_byte(kmask_iq2xs[0], j) & signs) != 0);
                let m2 = select(1.0, -1.0, (get_byte(kmask_iq2xs[1], j) & signs) != 0);
                sum += db * f32(g1) * m1 * src1[src1_i];
                sum += db * f32(g2) * m2 * src1[src1_i + 4];
                src1_i++;
            }
            src1_i += 4; // skip the half consumed via the +4 accesses above
        }
    }
    return sum;
}
670
671#enddecl(IQ3_XSS)
672
673#decl(IQ3_S)
// One iq3_s super-block: 256 elements in 4 groups of 64. Grid indices are
// 8 bits from qs plus a 9th bit from qh; signs come from explicit sign
// bytes (not the shared ksigns table); scales decode as d * (1 + 2*nibble).
// Each inner j-loop step emits two elements 4 apart in src1.
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let block = src0[src0_idx_base + offset];
    let d = f32(block.d);
    var src1_i = src1_idx_base + offset * 256;
    // qh/signs/scales are stored as f16 pairs; reassemble into u32 words.
    var qh_vals = array<u32, 2>(
        bitcast<u32>(vec2(block.qh[0], block.qh[1])),
        bitcast<u32>(vec2(block.qh[2], block.qh[3]))
    );
    var sign_vals: array<u32, 8>;
    for (var i: u32 = 0; i < 8; i++) {
        sign_vals[i] = bitcast<u32>(vec2(block.signs[i * 2], block.signs[i * 2 + 1]));
    }
    var scale_vals = bitcast<u32>(vec2(block.scales[0], block.scales[1]));
    var sum = 0.0;
    for (var ib: u32 = 0; ib < 4; ib++) {
        let s = get_byte(scale_vals, ib);
        // Two sub-scales per group: d * (1 + 2 * nibble).
        let db = array<f32, 2>(
            d * (1.0 + 2.0 * f32(s & 0xF)),
            d * (1.0 + 2.0 * f32(s >> 4))
        );
        for (var k: u32 = 0; k < 2; k++) {
            let dl = db[k];
            let qh_byte = get_byte(qh_vals[ib / 2], (ib % 2) * 2 + k);
            let sign_w = sign_vals[ib * 2 + k];
            for (var l: u32 = 0; l < 4; l++) {
                let signs = get_byte(sign_w, l);
                // Widen one f16-stored 16-bit entry; it holds two grid indices.
                let ig_val = bitcast<u32>(vec2(block.qs[ib * 8 + k * 4 + l], 0.0));
                // 9th index bit comes from consecutive qh bits (2l and 2l+1).
                let ig1 = get_byte(ig_val, 0) | ((qh_byte << ((8 - (2 * l)))) & 256);
                let ig2 = get_byte(ig_val, 1) | ((qh_byte << ((7 - (2 * l)))) & 256);
                for (var j: u32 = 0; j < 4; j++) {
                    let g1 = get_byte(iq3s_grid[ig1], j);
                    let g2 = get_byte(iq3s_grid[ig2], j);
                    let m1 = select(1.0, -1.0, (get_byte(kmask_iq2xs[0], j) & signs) != 0);
                    let m2 = select(1.0, -1.0, (get_byte(kmask_iq2xs[1], j) & signs) != 0);
                    sum += dl * f32(g1) * m1 * src1[src1_i];
                    sum += dl * f32(g2) * m2 * src1[src1_i + 4];
                    src1_i++;
                }
                src1_i += 4; // skip the half consumed via the +4 accesses above
            }
        }
    }
    return sum;
}
718#enddecl(IQ3_S)
719
720#decl(IQ1_S)
// One iq1_s super-block: 256 elements in 8 groups of 32. Each group has a
// 16-bit qh word: bits 0-11 extend four 8-bit grid indices to 11 bits,
// bits 12-14 are a 3-bit scale, bit 15 selects the sign of the IQ1_DELTA
// shift. Grid entries are 2-bit values sign-extended to {-1, 0, 1, -2}.
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let block = src0[src0_idx_base + offset];
    let d = f32(block.d);
    var src1_i = src1_idx_base + offset * 256;
    var sum = 0.0;
    for (var ib: u32 = 0; ib < 8; ib++) {
        // Widen one f16-stored 16-bit qh entry to u32.
        let qh = bitcast<u32>(vec2(block.qh[ib], 0.0));
        let dl = d * (2 * f32((qh >> 12) & 7) + 1);
        let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000) != 0);
        let qs_w = bitcast<u32>(vec2(block.qs[ib * 2], block.qs[ib * 2 + 1]));
        for (var l: u32 = 0; l < 4; l++) {
            // 11-bit grid row index: qs byte plus 3 bits from qh; 8 entries/row.
            let ig = (get_byte(qs_w, l) | (((qh >> (3 * l)) & 7) << 8)) * 8;
            for (var j: u32 = 0; j < 8; j++) {
                // 16 two-bit entries per grid word; sign-extend 2 bits to i32.
                let gw = iq1_grid[(ig + j) / 16];
                let g = (gw >> (((ig + j) % 16) * 2)) & 3;
                let gs = bitcast<i32>(g << 30) >> 30;
                sum += dl * (f32(gs) + delta) * src1[src1_i];
                src1_i++;
            }
        }
    }
    return sum;
}
744
745#enddecl(IQ1_S)
746
747#decl(IQ1_M)
// One iq1_m super-block: like iq1_s, but the block-level f16 scale is
// scattered across the four nibble positions of the scales words and must
// be reassembled; each group carries two 3-bit sub-scales and four per-row
// IQ1_DELTA sign bits in qh.
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let block = src0[src0_idx_base + offset];

    // Gather the four 4-bit pieces of the f16 block scale and reinterpret.
    let scale = ((block.scales[0] >> 12) & 0xF) | ((block.scales[0] >> 24) & 0x00F0) | ((block.scales[1] >> 4) & 0x0F00) | ((block.scales[1] >> 16) & 0xF000);
    let d = f32(bitcast<vec2<f16>>(scale).x);
    var src1_i = src1_idx_base + offset * 256;
    var sum = 0.0;
    for (var ib: u32 = 0; ib < 8; ib++) {
        // 16-bit scale word for this group; two 3-bit sub-scales inside.
        let sw = (block.scales[ib / 4] >> (16 * ((ib / 2) % 2))) & 0xFFFF;
        let s1 : u32 = (sw >> (6 * (ib % 2))) & 0x7;
        let s2 : u32 = (sw >> (6 * (ib % 2) + 3)) & 0x7;
        var dl = array<f32, 2>(
            d * f32(2 * s1 + 1),
            d * f32(2 * s2 + 1)
        );

        // 16-bit qh word: 3 high index bits + 1 delta-sign bit per row.
        let qh = block.qh[ib / 2] >> (16 * (ib % 2));
        var idx = array<u32, 4>(
            get_byte(block.qs[ib], 0) | ((qh << 8) & 0x700),
            get_byte(block.qs[ib], 1) | ((qh << 4) & 0x700),
            get_byte(block.qs[ib], 2) | ((qh) & 0x700),
            get_byte(block.qs[ib], 3) | ((qh >> 4) & 0x700)
        );
        var delta = array<f32, 4>(
            select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x08) != 0),
            select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x80) != 0),
            select(IQ1_DELTA, -IQ1_DELTA, ((qh >> 8) & 0x08) != 0),
            select(IQ1_DELTA, -IQ1_DELTA, ((qh >> 8) & 0x80) != 0)
        );
        for (var l: u32 = 0; l < 4; l++) {
            let ig = idx[l] * 8; // grid row, 8 entries wide
            for (var j: u32 = 0; j < 8; j++) {
                // 16 two-bit entries per grid word; sign-extend 2 bits to i32.
                let gw = iq1_grid[(ig + j) / 16];
                let g = (gw >> (((ig + j) % 16) * 2)) & 3;
                let gs = bitcast<i32>(g << 30) >> 30;
                sum += dl[l/2] * (f32(gs) + delta[l]) * src1[src1_i];
                src1_i++;
            }
        }
    }
    return sum;
}
790
791#enddecl(IQ1_M)
792
793#decl(IQ4_NL)
// One iq4_nl block: 32 elements as 16 bytes of 4-bit indices into the
// non-linear kvalues_iq4nl codebook, scaled by f16 d. Low nibbles map to
// elements 0..15, high nibbles to elements 16..31.
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let block = src0[src0_idx_base + offset];
    let d = f32(block.d);
    var src1_i = src1_idx_base + offset * 32;
    var sum = 0.0;
    // qs is stored as f16 pairs; reassemble into four u32 words.
    var qs: array<u32, 4>;
    for (var i: u32 = 0; i < 4; i++) {
        qs[i] = bitcast<u32>(vec2(block.qs[i * 2], block.qs[i * 2 + 1]));
    }
    for (var j: u32 = 0; j < 16; j++) {
        let qsb = get_byte(qs[j / 4], j % 4);
        sum += d * f32(kvalues_iq4nl[qsb & 0xF]) * src1[src1_i];
        sum += d * f32(kvalues_iq4nl[qsb >> 4]) * src1[src1_i + 16];
        src1_i++;
    }
    return sum;
}
811
812#enddecl(IQ4_NL)
813
814#decl(IQ4_XS)
// One iq4_xs super-block: 256 elements in 8 sub-blocks of 32, using the
// same kvalues_iq4nl codebook as iq4_nl but with per-sub-block 6-bit scales
// (4 bits from scales_l, 2 from scales_h), re-centered at 32.
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let block = src0[src0_idx_base + offset];
    let d = f32(block.d);
    // Widen the f16-stored 16-bit scales_h field to u32.
    let scales_h = bitcast<u32>(vec2(block.scales_h, 0.0));
    var src1_i = src1_idx_base + offset * 256;
    var sum = 0.0;
    for (var ib: u32 = 0; ib < 8; ib++) {
        // 6-bit sub-block scale: low nibble + 2 high bits.
        let ls = ((get_byte(block.scales_l, ib / 2) >> (4 * (ib % 2))) & 0xF) | (((scales_h >> (2 * ib)) & 3) << 4);
        let dl = d * (f32(ls) - 32.0);
        for (var j: u32 = 0; j < 16; j++) {
            let iqs = ib * 16 + j;
            let qsb = get_byte(block.qs[iqs / 4], iqs % 4);
            sum += dl * f32(kvalues_iq4nl[qsb & 0xF]) * src1[src1_i];
            sum += dl * f32(kvalues_iq4nl[qsb >> 4]) * src1[src1_i + 16];
            src1_i++;
        }
        src1_i += 16; // skip the half consumed via the +16 accesses above
    }
    return sum;
}
835
836#enddecl(IQ4_XS)
837
838#end(DECLS)
839
840#define(SHADER)
841
842enable f16;
843
844DECLS
845
// Uniform parameters for one mul_mat dispatch. src0 is M x K, src1 is
// K x N (transposed), dst is M x N, with up to two extra batch dimensions.
struct MulMatParams {
    offset_src0: u32, // in elements/blocks
    offset_src1: u32, // in elements/blocks
    offset_dst: u32, // in elements/blocks
    m: u32,
    n: u32,
    k: u32,
    // all strides are in elements/blocks
    stride_01: u32,
    stride_11: u32,
    stride_02: u32,
    stride_12: u32,
    stride_03: u32,
    stride_13: u32,

    // Batch sizes of src0 along dims 2 and 3.
    bs02: u32,
    bs03: u32,
    // Broadcast factors: src1/dst dims 2 and 3 are bs02*broadcast2 and
    // bs03*broadcast3 respectively; src0 indices are divided back down.
    broadcast2: u32,
    broadcast3: u32
};
866
867@group(0) @binding(0) var<storage, read_write> src0: array<{{SRC0_TYPE}}>; // M rows, K columns
868@group(0) @binding(1) var<storage, read_write> src1: array<{{SRC1_TYPE}}>; // K rows, N columns (transposed)
869@group(0) @binding(2) var<storage, read_write> dst: array<f32>; // M rows, N columns
870
871@group(0) @binding(3) var<uniform> params: MulMatParams;
872
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    // One invocation per output element across all batch dimensions.
    let out_count = params.m * params.n * params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3;
    let thread_idx = global_id.x;
    if (thread_idx >= out_count) {
        return;
    }

    // Flattened output layout: [dim3][dim2][m*n plane].
    let plane_elems = params.m * params.n;
    let slab_elems = plane_elems * params.bs02 * params.broadcast2;

    let dim3_idx = thread_idx / slab_elems;
    let slab_rem = thread_idx % slab_elems;
    let dim2_idx = slab_rem / plane_elems;
    let plane_rem = slab_rem % plane_elems;

    // src0 may be broadcast along dims 2 and 3; src1 never is.
    let src0_dim3 = dim3_idx / params.broadcast3;
    let src1_dim3 = dim3_idx;
    let src0_dim2 = dim2_idx / params.broadcast2;
    let src1_dim2 = dim2_idx;

    let row = plane_rem / params.m; // output row
    let col = plane_rem % params.m; // output column

    let src0_base = params.offset_src0 + src0_dim3 * params.stride_03 + src0_dim2 * params.stride_02 + col * params.stride_01;
    let src1_base = params.offset_src1 + src1_dim3 * params.stride_13 + src1_dim2 * params.stride_12 + row * params.stride_11;

    // Walk the K dimension one block at a time; multiply_add performs the
    // per-variant dequantization and dot-product accumulation.
    var acc = 0.0;
    for (var blk: u32 = 0u; blk < params.k / {{BLOCK_SIZE}}; blk++) {
        acc += multiply_add(src0_base, src1_base, blk);
    }
    dst[params.offset_dst + dim3_idx * slab_elems + dim2_idx * plane_elems + row * params.m + col] = acc;
}
906
907#end(SHADER)