diff options
| author | Mitja Felicijan <mitja.felicijan@gmail.com> | 2026-02-12 20:57:17 +0100 |
|---|---|---|
| committer | Mitja Felicijan <mitja.felicijan@gmail.com> | 2026-02-12 20:57:17 +0100 |
| commit | b333b06772c89d96aacb5490d6a219fba7c09cc6 (patch) | |
| tree | 211df60083a5946baa2ed61d33d8121b7e251b06 /llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl | |
| download | llmnpc-b333b06772c89d96aacb5490d6a219fba7c09cc6.tar.gz | |
Engage!
Diffstat (limited to 'llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl')
| -rw-r--r-- | llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl | 907 |
1 files changed, 907 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl new file mode 100644 index 0000000..0f8e6e5 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl | |||
| @@ -0,0 +1,907 @@ | |||
| 1 | #define(VARIANTS) | ||
| 2 | |||
| 3 | [ | ||
| 4 | { | ||
| 5 | "REPLS": { | ||
| 6 | "SRC0_TYPE" : "f32", | ||
| 7 | "SRC1_TYPE" : "f32", | ||
| 8 | "BLOCK_SIZE" : 1 | ||
| 9 | }, | ||
| 10 | "DECLS" : ["FLOAT"] | ||
| 11 | }, | ||
| 12 | { | ||
| 13 | "REPLS": { | ||
| 14 | "SRC0_TYPE" : "f16", | ||
| 15 | "SRC1_TYPE" : "f16", | ||
| 16 | "BLOCK_SIZE" : 1 | ||
| 17 | }, | ||
| 18 | "DECLS" : ["FLOAT"] | ||
| 19 | }, | ||
| 20 | { | ||
| 21 | "REPLS": { | ||
| 22 | "SRC0_TYPE" : "f16", | ||
| 23 | "SRC1_TYPE" : "f32", | ||
| 24 | "BLOCK_SIZE" : 1 | ||
| 25 | }, | ||
| 26 | "DECLS" : ["FLOAT"] | ||
| 27 | }, | ||
| 28 | { | ||
| 29 | "REPLS": { | ||
| 30 | "SRC0_TYPE": "q4_0", | ||
| 31 | "SRC1_TYPE": "f32", | ||
| 32 | "BLOCK_SIZE": 32 | ||
| 33 | }, | ||
| 34 | "DECLS": ["BYTE_HELPERS", "Q4_0_T", "Q4_0"] | ||
| 35 | }, | ||
| 36 | { | ||
| 37 | "REPLS": { | ||
| 38 | "SRC0_TYPE": "q4_1", | ||
| 39 | "SRC1_TYPE": "f32", | ||
| 40 | "BLOCK_SIZE": 32 | ||
| 41 | }, | ||
| 42 | "DECLS": ["BYTE_HELPERS", "Q4_1_T", "Q4_1"] | ||
| 43 | }, | ||
| 44 | { | ||
| 45 | "REPLS": { | ||
| 46 | "SRC0_TYPE": "q5_0", | ||
| 47 | "SRC1_TYPE": "f32", | ||
| 48 | "BLOCK_SIZE": 32 | ||
| 49 | }, | ||
| 50 | "DECLS": ["BYTE_HELPERS", "Q5_0_T", "Q5_0"] | ||
| 51 | }, | ||
| 52 | { | ||
| 53 | "REPLS": { | ||
| 54 | "SRC0_TYPE": "q5_1", | ||
| 55 | "SRC1_TYPE": "f32", | ||
| 56 | "BLOCK_SIZE": 32 | ||
| 57 | }, | ||
| 58 | "DECLS": ["BYTE_HELPERS", "Q5_1_T", "Q5_1"] | ||
| 59 | }, | ||
| 60 | { | ||
| 61 | "REPLS": { | ||
| 62 | "SRC0_TYPE": "q8_0", | ||
| 63 | "SRC1_TYPE": "f32", | ||
| 64 | "BLOCK_SIZE": 32 | ||
| 65 | }, | ||
| 66 | "DECLS": ["BYTE_HELPERS", "Q8_0_T", "Q8_0"] | ||
| 67 | }, | ||
| 68 | { | ||
| 69 | "REPLS": { | ||
| 70 | "SRC0_TYPE": "q2_k", | ||
| 71 | "SRC1_TYPE": "f32", | ||
| 72 | "BLOCK_SIZE": 256 | ||
| 73 | }, | ||
| 74 | "DECLS": ["BYTE_HELPERS", "Q2_K_T", "Q2_K"] | ||
| 75 | }, | ||
| 76 | { | ||
| 77 | "REPLS": { | ||
| 78 | "SRC0_TYPE": "q3_k", | ||
| 79 | "SRC1_TYPE": "f32", | ||
| 80 | "BLOCK_SIZE": 256 | ||
| 81 | }, | ||
| 82 | "DECLS": ["BYTE_HELPERS", "Q3_K_T", "Q3_K"] | ||
| 83 | }, | ||
| 84 | { | ||
| 85 | "REPLS": { | ||
| 86 | "SRC0_TYPE": "q4_k", | ||
| 87 | "SRC1_TYPE": "f32", | ||
| 88 | "BLOCK_SIZE": 256 | ||
| 89 | }, | ||
| 90 | "DECLS": ["Q45_K_SCALE_MIN", "BYTE_HELPERS", "Q4_K_T", "Q4_K"] | ||
| 91 | }, | ||
| 92 | { | ||
| 93 | "REPLS": { | ||
| 94 | "SRC0_TYPE": "q5_k", | ||
| 95 | "SRC1_TYPE": "f32", | ||
| 96 | "BLOCK_SIZE": 256 | ||
| 97 | }, | ||
| 98 | "DECLS": ["Q45_K_SCALE_MIN", "BYTE_HELPERS", "Q5_K_T", "Q5_K"] | ||
| 99 | }, | ||
| 100 | { | ||
| 101 | "REPLS": { | ||
| 102 | "SRC0_TYPE": "q6_k", | ||
| 103 | "SRC1_TYPE": "f32", | ||
| 104 | "BLOCK_SIZE": 256 | ||
| 105 | }, | ||
| 106 | "DECLS": ["BYTE_HELPERS", "Q6_K_T", "Q6_K"] | ||
| 107 | }, | ||
| 108 | { | ||
| 109 | "REPLS": { | ||
| 110 | "SRC0_TYPE": "iq2_xxs", | ||
| 111 | "SRC1_TYPE": "f32", | ||
| 112 | "BLOCK_SIZE": 256 | ||
| 113 | }, | ||
| 114 | "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_XXS_GRID", "IQ2_XXS_T", "IQ2_XXS"] | ||
| 115 | }, | ||
| 116 | { | ||
| 117 | "REPLS": { | ||
| 118 | "SRC0_TYPE": "iq2_xs", | ||
| 119 | "SRC1_TYPE": "f32", | ||
| 120 | "BLOCK_SIZE": 256 | ||
| 121 | }, | ||
| 122 | "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_XS_GRID", "IQ2_XS_T", "IQ2_XS"] | ||
| 123 | }, | ||
| 124 | { | ||
| 125 | "REPLS": { | ||
| 126 | "SRC0_TYPE": "iq2_s", | ||
| 127 | "SRC1_TYPE": "f32", | ||
| 128 | "BLOCK_SIZE": 256 | ||
| 129 | }, | ||
| 130 | "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_S_GRID", "IQ2_S_T", "IQ2_S"] | ||
| 131 | }, | ||
| 132 | { | ||
| 133 | "REPLS": { | ||
| 134 | "SRC0_TYPE": "iq3_xxs", | ||
| 135 | "SRC1_TYPE": "f32", | ||
| 136 | "BLOCK_SIZE": 256 | ||
| 137 | }, | ||
| 138 | "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ3_XSS_GRID", "IQ3_XSS_T", "IQ3_XSS"] | ||
| 139 | }, | ||
| 140 | { | ||
| 141 | "REPLS": { | ||
| 142 | "SRC0_TYPE": "iq3_s", | ||
| 143 | "SRC1_TYPE": "f32", | ||
| 144 | "BLOCK_SIZE": 256 | ||
| 145 | }, | ||
| 146 | "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ3_S_GRID", "IQ3_S_T", "IQ3_S"] | ||
| 147 | }, | ||
| 148 | { | ||
| 149 | "REPLS": { | ||
| 150 | "SRC0_TYPE": "iq1_s", | ||
| 151 | "SRC1_TYPE": "f32", | ||
| 152 | "BLOCK_SIZE": 256 | ||
| 153 | }, | ||
| 154 | "DECLS": ["BYTE_HELPERS", "IQ1_GRID", "IQ1_S_T", "IQ1_S"] | ||
| 155 | }, | ||
| 156 | { | ||
| 157 | "REPLS": { | ||
| 158 | "SRC0_TYPE": "iq1_m", | ||
| 159 | "SRC1_TYPE": "f32", | ||
| 160 | "BLOCK_SIZE": 256 | ||
| 161 | }, | ||
| 162 | "DECLS": ["BYTE_HELPERS", "IQ1_GRID", "IQ1_M_T", "IQ1_M"] | ||
| 163 | }, | ||
| 164 | { | ||
| 165 | "REPLS": { | ||
| 166 | "SRC0_TYPE": "iq4_nl", | ||
| 167 | "SRC1_TYPE": "f32", | ||
| 168 | "BLOCK_SIZE": 32 | ||
| 169 | }, | ||
| 170 | "DECLS": ["BYTE_HELPERS", "IQ4_GRID", "IQ4_NL_T", "IQ4_NL"] | ||
| 171 | }, | ||
| 172 | { | ||
| 173 | "REPLS": { | ||
| 174 | "SRC0_TYPE": "iq4_xs", | ||
| 175 | "SRC1_TYPE": "f32", | ||
| 176 | "BLOCK_SIZE": 256 | ||
| 177 | }, | ||
| 178 | "DECLS": ["BYTE_HELPERS", "IQ4_GRID", "IQ4_XS_T", "IQ4_XS"] | ||
| 179 | } | ||
| 180 | ] | ||
| 181 | |||
| 182 | #end(VARIANTS) | ||
| 183 | |||
| 184 | #define(DECLS) | ||
| 185 | |||
| 186 | #decl(FLOAT) | ||
| | // FLOAT variant: returns the product of one src0 element and one src1 element, | ||
| | // both read at (base + offset) and converted to f32. | ||
| | // Despite the name, no addition happens here; presumably the caller accumulates | ||
| | // the returned value (the caller is not visible in this chunk). | ||
| 187 | fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { | ||
| 188 | return f32(src0[src0_idx_base + offset]) * f32(src1[src1_idx_base + offset]); | ||
| 189 | } | ||
| 190 | #enddecl(FLOAT) | ||
| 191 | |||
| 192 | #decl(Q4_0) | ||
| | // Q4_0 variant: partial dot product of one quantized src0 block with 32 src1 values. | ||
| | // The block is src0[src0_idx_base + offset]; the matching src1 values start at | ||
| | // src1_idx_base + offset * 32. Returns the accumulated sum as f32. | ||
| | // get_byte is declared outside this chunk (presumably by the BYTE_HELPERS decl). | ||
| 193 | fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { | ||
| 194 | let block_q4_0 = src0[src0_idx_base + offset]; | ||
| 195 | let d = f32(block_q4_0.d); | ||
| 196 | var sum: f32 = 0.0; | ||
| 197 | for (var j: u32 = 0; j < 4; j++) { | ||
| | // Reinterpret two adjacent qs entries as a single 32-bit word (4 bytes). | ||
| 198 | let q_packed = bitcast<u32>(vec2(block_q4_0.qs[2 * j], block_q4_0.qs[2 * j + 1])); | ||
| 199 | for (var k: u32 = 0; k < 4; k++) { | ||
| 200 | let q_byte = get_byte(q_packed, k); | ||
| | // Each byte carries two 4-bit values; both are shifted by -8 and scaled by d. | ||
| 201 | let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0f) * d; | ||
| 202 | let q_lo = (f32(q_byte & 0xF) - 8.0f) * d; | ||
| 203 | let src1_offset = src1_idx_base + offset * 32 + j * 4 + k; | ||
| | // Low nibble pairs with element (j * 4 + k); high nibble with the element 16 further on. | ||
| 204 | sum += q_lo * f32(src1[src1_offset]); | ||
| 205 | sum += q_hi * f32(src1[src1_offset + 16]); | ||
| 206 | } | ||
| 207 | } | ||
| 208 | return sum; | ||
| 209 | } | ||
| 210 | #enddecl(Q4_0) | ||
| 211 | |||
| 212 | #decl(Q4_1) | ||
| | // Q4_1 variant: partial dot product of one quantized src0 block with 32 src1 values | ||
| | // starting at src1_idx_base + offset * 32. Each 4-bit value q is decoded as q * d + m. | ||
| | // Unlike the Q4_0 variant, qs[j] is passed to get_byte directly, without a bitcast. | ||
| 213 | fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { | ||
| 214 | let block_q4_1 = src0[src0_idx_base + offset]; | ||
| 215 | let d = f32(block_q4_1.d); | ||
| 216 | let m = f32(block_q4_1.m); | ||
| 217 | var sum: f32 = 0.0; | ||
| 218 | for (var j: u32 = 0; j < 4; j++) { | ||
| 219 | let q_packed = block_q4_1.qs[j]; | ||
| 220 | for (var k: u32 = 0; k < 4; k++) { | ||
| 221 | let q_byte = get_byte(q_packed, k); | ||
| | // High and low nibble of the same byte; scale by d, then add the offset m. | ||
| 222 | let q_hi = f32((q_byte >> 4) & 0xF) * d + m; | ||
| 223 | let q_lo = f32(q_byte & 0xF) * d + m; | ||
| 224 | let src1_offset = src1_idx_base + offset * 32 + j * 4 + k; | ||
| | // Low nibble pairs with element (j * 4 + k); high nibble with the element 16 further on. | ||
| 225 | sum += q_lo * f32(src1[src1_offset]); | ||
| 226 | sum += q_hi * f32(src1[src1_offset + 16]); | ||
| 227 | } | ||
| 228 | } | ||
| 229 | return sum; | ||
| 230 | } | ||
| 231 | #enddecl(Q4_1) | ||
| 232 | |||
| 233 | #decl(Q5_0) | ||
| | // Q5_0 variant: partial dot product of one quantized src0 block with 32 src1 values | ||
| | // starting at src1_idx_base + offset * 32. Each value is 5 bits: a nibble from qs plus | ||
| | // one extra bit from qh; it is decoded as (q - 16) * d. | ||
| 234 | fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { | ||
| 235 | let block_q5_0 = src0[src0_idx_base + offset]; | ||
| 236 | let d = f32(block_q5_0.d); | ||
| 237 | var sum: f32 = 0.0; | ||
| | // Combine the two qh entries into one 32-bit word of extra (fifth) bits. | ||
| 238 | let qh_packed = bitcast<u32>(vec2(block_q5_0.qh[0], block_q5_0.qh[1])); | ||
| 239 | for (var j: u32 = 0; j < 4; j++) { | ||
| 240 | let q_packed = bitcast<u32>(vec2(block_q5_0.qs[2 * j], block_q5_0.qs[2 * j + 1])); | ||
| 241 | for (var k: u32 = 0; k < 4; k++) { | ||
| 242 | let q_byte = get_byte(q_packed, k); | ||
| | // Moves qh bit (j * 4 + k + 16) into bit position 4 for the high-nibble value. | ||
| 243 | let qh_hi = (qh_packed >> (j * 4 + k + 12)) & 0x10; | ||
| 244 | let q_hi = (f32(((q_byte >> 4) & 0xF) | qh_hi) - 16.0) * d; | ||
| | // Moves qh bit (j * 4 + k) into bit position 4 for the low-nibble value. | ||
| 245 | let qh_lo = ((qh_packed >> (j * 4 + k)) << 4) & 0x10; | ||
| 246 | let q_lo = (f32((q_byte & 0xF) | qh_lo) - 16.0) * d; | ||
| 247 | let src1_offset = src1_idx_base + offset * 32 + j * 4 + k; | ||
| 248 | sum += q_lo * f32(src1[src1_offset]); | ||
| 249 | sum += q_hi * f32(src1[src1_offset + 16]); | ||
| 250 | } | ||
| 251 | } | ||
| 252 | return sum; | ||
| 253 | } | ||
| 254 | #enddecl(Q5_0) | ||
| 255 | |||
| 256 | #decl(Q5_1) | ||
| | // Q5_1 variant: partial dot product of one quantized src0 block with 32 src1 values | ||
| | // starting at src1_idx_base + offset * 32. Each value is 5 bits (nibble from qs plus | ||
| | // one bit from qh) and is decoded as q * d + m. | ||
| | // Here qh and qs[j] are used directly as packed words, without a bitcast. | ||
| 257 | fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { | ||
| 258 | let block_q5_1 = src0[src0_idx_base + offset]; | ||
| 259 | let d = f32(block_q5_1.d); | ||
| 260 | let m = f32(block_q5_1.m); | ||
| 261 | var sum: f32 = 0.0; | ||
| 262 | for (var j: u32 = 0; j < 4; j++) { | ||
| 263 | let q_packed = block_q5_1.qs[j]; | ||
| 264 | for (var k: u32 = 0; k < 4; k++) { | ||
| 265 | let q_byte = get_byte(q_packed, k); | ||
| | // Moves qh bit (j * 4 + k + 16) into bit position 4 for the high-nibble value. | ||
| 266 | let qh_hi = (block_q5_1.qh >> (j * 4 + k + 12)) & 0x10; | ||
| 267 | let q_hi = f32(((q_byte >> 4) & 0xF) | qh_hi) * d + m; | ||
| | // Moves qh bit (j * 4 + k) into bit position 4 for the low-nibble value. | ||
| 268 | let qh_lo = ((block_q5_1.qh >> (j * 4 + k)) << 4) & 0x10; | ||
| 269 | let q_lo = f32((q_byte & 0xF) | qh_lo) * d + m; | ||
| 270 | let src1_offset = src1_idx_base + offset * 32 + j * 4 + k; | ||
| 271 | sum += q_lo * f32(src1[src1_offset]); | ||
| 272 | sum += q_hi * f32(src1[src1_offset + 16]); | ||
| 273 | } | ||
| 274 | } | ||
| 275 | return sum; | ||
| 276 | } | ||
| 277 | #enddecl(Q5_1) | ||
| 278 | |||
| 279 | #decl(Q8_0) | ||
| | // Q8_0 variant: partial dot product of one quantized src0 block with 32 src1 values | ||
| | // starting at src1_idx_base + offset * 32. Each byte is decoded as q * d. | ||
| | // get_byte_i32 is declared outside this chunk; presumably it returns the byte as a | ||
| | // signed value - confirm against BYTE_HELPERS. | ||
| 280 | fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { | ||
| 281 | let block_q8_0 = src0[src0_idx_base + offset]; | ||
| 282 | let d = f32(block_q8_0.d); | ||
| 283 | var sum: f32 = 0.0; | ||
| 284 | for (var j: u32 = 0; j < 8; j++) { | ||
| | // Reinterpret two adjacent qs entries as a single 32-bit word (4 bytes). | ||
| 285 | let q_packed = bitcast<u32>(vec2(block_q8_0.qs[2 * j], block_q8_0.qs[2 * j + 1])); | ||
| 286 | for (var k: u32 = 0; k < 4; k++) { | ||
| 287 | let q_byte = get_byte_i32(q_packed, k); | ||
| 288 | let q_val = f32(q_byte) * d; | ||
| | // Bytes map to src1 elements in order: word j, byte k -> element j * 4 + k. | ||
| 289 | let src1_offset = src1_idx_base + offset * 32 + j * 4 + k; | ||
| 290 | sum += q_val * f32(src1[src1_offset]); | ||
| 291 | } | ||
| 292 | } | ||
| 293 | return sum; | ||
| 294 | } | ||
| 295 | #enddecl(Q8_0) | ||
| 296 | |||
| 297 | #decl(Q8_1) | ||
| | // Q8_1 variant: partial dot product of one quantized src0 block with 32 src1 values | ||
| | // starting at src1_idx_base + offset * 32. Each byte is decoded as q * d + m. | ||
| | // NOTE(review): no entry in the VARIANTS list of this file selects the Q8_1 decl, | ||
| | // so this function appears to be unused here - confirm before relying on it. | ||
| 298 | fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { | ||
| 299 | let block_q8_1 = src0[src0_idx_base + offset]; | ||
| 300 | let d = f32(block_q8_1.d); | ||
| 301 | let m = f32(block_q8_1.m); | ||
| 302 | var sum: f32 = 0.0; | ||
| 303 | for (var j: u32 = 0; j < 8; j++) { | ||
| | // qs[j] is used directly as a packed word of 4 bytes. | ||
| 304 | let q_packed = block_q8_1.qs[j]; | ||
| 305 | for (var k: u32 = 0; k < 4; k++) { | ||
| 306 | let q_byte = get_byte_i32(q_packed, k); | ||
| 307 | let q_val = f32(q_byte) * d + m; | ||
| 308 | let src1_offset = src1_idx_base + offset * 32 + j * 4 + k; | ||
| 309 | sum += q_val * f32(src1[src1_offset]); | ||
| 310 | } | ||
| 311 | } | ||
| 312 | return sum; | ||
| 313 | } | ||
| 314 | #enddecl(Q8_1) | ||
| 315 | |||
| 316 | #decl(Q2_K) | ||
| | // Q2_K variant: partial dot product of one 256-element src0 block with the 256 src1 | ||
| | // values starting at src1_idx_base + offset * 256 (read sequentially via src1_i). | ||
| | // Each 2-bit value q is decoded as q * dl - ml, with dl and ml taken per 16 elements. | ||
| 317 | // 16 blocks of 16 elements each | ||
| 318 | fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { | ||
| 319 | let block = src0[src0_idx_base + offset]; | ||
| 320 | let d = f32(block.d); | ||
| 321 | let m = f32(block.dmin); | ||
| 322 | var sum = 0.0; | ||
| 323 | var src1_i = src1_idx_base + offset * 256; | ||
| | // Running index of the current scale byte; advances once per 16-element sub-block. | ||
| 324 | var is: u32 = 0; | ||
| 325 | // 2 halves of the block (128 elements each) | ||
| 326 | for (var q_b_idx: u32 = 0; q_b_idx < 64; q_b_idx += 32) { | ||
| 327 | // 4 groups (each group has 2 blocks of 16 elements) | ||
| 328 | for (var shift: u32 = 0; shift < 8; shift += 2) { | ||
| 329 | // 2 blocks | ||
| 330 | for (var k: u32 = 0; k < 32; k += 16) { | ||
| 331 | let sc = get_byte(block.scales[is / 4], is % 4); | ||
| 332 | is++; | ||
| | // Low nibble of the scale byte multiplies d, high nibble multiplies dmin. | ||
| 333 | let dl = d * f32(sc & 0xF); | ||
| 334 | let ml = m * f32(sc >> 4); | ||
| 335 | for (var l: u32 = 0u; l < 16; l++) { | ||
| 336 | let q_idx = q_b_idx + k + l; | ||
| 337 | let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4); | ||
| | // The same 32 bytes are revisited for each shift, yielding a different 2-bit field. | ||
| 338 | let qs_val = (q_byte >> shift) & 3; | ||
| 339 | sum += (f32(qs_val) * dl - ml) * src1[src1_i]; | ||
| 340 | src1_i++; | ||
| 341 | } | ||
| 342 | } | ||
| 343 | } | ||
| 344 | } | ||
| 345 | return sum; | ||
| 346 | } | ||
| 347 | |||
| 348 | #enddecl(Q2_K) | ||
| 349 | |||
| 350 | #decl(Q3_K) | ||
| | // Q3_K variant: partial dot product of one 256-element src0 block with the 256 src1 | ||
| | // values starting at src1_idx_base + offset * 256 (read sequentially via src1_i). | ||
| | // Each value is a 2-bit field from qs, minus 4 when its hmask bit is clear, times a | ||
| | // per-16-element scale of d * (sc - 32). | ||
| 351 | // 16 blocks of 16 elements each | ||
| 352 | fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { | ||
| 353 | let block = src0[src0_idx_base + offset]; | ||
| 354 | let d = f32(block.d); | ||
| 355 | |||
| 356 | // extract 6-bit scales, which consist of 4-bits from first 8 bytes of scale, | ||
| 357 | // and 2-bits from the last 4 bytes | ||
| 358 | let kmask1: u32 = 0x03030303; | ||
| 359 | let kmask2: u32 = 0x0f0f0f0f; | ||
| 360 | var scale_vals: array<u32, 4>; | ||
| 361 | for (var i: u32 = 0; i < 4; i++) { | ||
| 362 | scale_vals[i] = bitcast<u32>(vec2(block.scales[2 * i], block.scales[2 * i + 1])); | ||
| 363 | } | ||
| | // scale_vals[2] is saved first because it supplies the 2 high bits for all four words. | ||
| | // Words 2 and 3 must be rebuilt before words 0 and 1 are overwritten below. | ||
| 364 | var tmp: u32 = scale_vals[2]; | ||
| 365 | scale_vals[2] = ((scale_vals[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4); | ||
| 366 | scale_vals[3] = ((scale_vals[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4); | ||
| 367 | scale_vals[0] = (scale_vals[0] & kmask2) | ((tmp & kmask1) << 4); | ||
| 368 | scale_vals[1] = (scale_vals[1] & kmask2) | (((tmp >> 2) & kmask1) << 4); | ||
| 369 | |||
| 370 | // convert arrays of f16 -> u32 | ||
| 371 | var hmask_vals: array<u32, 8>; | ||
| 372 | for (var i: u32 = 0; i < 8; i++) { | ||
| 373 | hmask_vals[i] = bitcast<u32>(vec2(block.hmask[2 * i], block.hmask[2 * i + 1])); | ||
| 374 | } | ||
| 375 | var qs_vals: array<u32, 16>; | ||
| 376 | for (var i: u32 = 0; i < 16; i++) { | ||
| 377 | qs_vals[i] = bitcast<u32>(vec2(block.qs[2 * i], block.qs[2 * i + 1])); | ||
| 378 | } | ||
| 379 | |||
| 380 | var sum = 0.0; | ||
| 381 | var src1_i = src1_idx_base + offset * 256; | ||
| 382 | var is: u32 = 0; | ||
| | // Single-bit mask into each hmask byte; shifted left once per 32-element group, | ||
| | // and not reset between the two halves (so it covers bits 0..7 overall). | ||
| 383 | var m: u32 = 1; | ||
| 384 | // 2 halves of the block (128 elements each) | ||
| 385 | for (var q_b_idx: u32 = 0; q_b_idx < 64; q_b_idx += 32) { | ||
| 386 | // 4 groups (each group has 2 blocks of 16 elements) | ||
| 387 | for (var shift: u32 = 0; shift < 8; shift += 2) { | ||
| 388 | // 2 blocks | ||
| 389 | for (var k: u32 = 0; k < 32; k += 16) { | ||
| 390 | let sc = get_byte(scale_vals[is / 4], is % 4); | ||
| 391 | is++; | ||
| 392 | let dl = d * (f32(sc) - 32.0); | ||
| 393 | for (var l: u32 = 0u; l < 16u; l++) { | ||
| 394 | let q_idx = q_b_idx + k + l; | ||
| | // The hmask byte index ignores q_b_idx: both halves read the same 32 bytes, | ||
| | // distinguished only by the bit selected through m. | ||
| 395 | let hm_idx = k + l; | ||
| 396 | let q_byte = get_byte(qs_vals[q_idx / 4], q_idx % 4); | ||
| 397 | let hmask_byte = get_byte(hmask_vals[hm_idx / 4], hm_idx % 4); | ||
| | // select(f, t, cond): yields 0.0 when the mask bit is set, 4.0 when it is clear. | ||
| 398 | let hm = select(4.0, 0.0, (hmask_byte & m) != 0); | ||
| 399 | let qs_val = (q_byte >> shift) & 3; | ||
| 400 | sum += ((f32(qs_val) - hm) * dl) * src1[src1_i]; | ||
| 401 | src1_i++; | ||
| 402 | } | ||
| 403 | } | ||
| 404 | m <<= 1; | ||
| 405 | } | ||
| 406 | } | ||
| 407 | return sum; | ||
| 408 | } | ||
| 409 | |||
| 410 | #enddecl(Q3_K) | ||
| 411 | |||
| 412 | #decl(Q4_K) | ||
| | // Q4_K variant: partial dot product of one 256-element src0 block with the 256 src1 | ||
| | // values starting at src1_idx_base + offset * 256 (read sequentially via src1_i). | ||
| | // Each 4-bit value q is decoded as q * dl - ml, with dl and ml taken per 32 elements. | ||
| | // get_scale_min is declared outside this chunk (presumably by Q45_K_SCALE_MIN); | ||
| | // only its .x (scale factor) and .y (min factor) components are used here. | ||
| 413 | // 8 blocks of 32 elements each | ||
| 414 | fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { | ||
| 415 | let block = src0[src0_idx_base + offset]; | ||
| 416 | let d = f32(block.d); | ||
| 417 | let m = f32(block.dmin); | ||
| 418 | var sum = 0.0; | ||
| 419 | var src1_i = src1_idx_base + offset * 256; | ||
| 420 | var is: u32 = 0; | ||
| 421 | // 2 blocks each iteration | ||
| 422 | for (var q_b_idx: u32 = 0; q_b_idx < 128; q_b_idx += 32) { | ||
| | // shift 0 reads the low nibbles of these 32 bytes, shift 4 the high nibbles. | ||
| 423 | for (var shift: u32 = 0; shift < 8; shift += 4) { | ||
| 424 | let scale_min = get_scale_min(is, block.scales); | ||
| 425 | is++; | ||
| 426 | let dl = d * scale_min.x; | ||
| 427 | let ml = m * scale_min.y; | ||
| 428 | for (var l: u32 = 0; l < 32; l++) { | ||
| 429 | let q_idx = q_b_idx + l; | ||
| 430 | let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4); | ||
| 431 | let qs_val = (q_byte >> shift) & 0xF; | ||
| 432 | sum += (f32(qs_val) * dl - ml) * src1[src1_i]; | ||
| 433 | src1_i++; | ||
| 434 | } | ||
| 435 | } | ||
| 436 | } | ||
| 437 | return sum; | ||
| 438 | } | ||
| 439 | |||
| 440 | #enddecl(Q4_K) | ||
| 441 | |||
| 442 | #decl(Q5_K) | ||
| | // Q5_K variant: partial dot product of one 256-element src0 block with the 256 src1 | ||
| | // values starting at src1_idx_base + offset * 256 (read sequentially via src1_i). | ||
| | // Each value is a nibble from qs plus 16 when its qh bit is set, decoded as | ||
| | // (q + qh_val) * dl - ml with dl and ml taken per 32 elements. | ||
| | // get_scale_min is declared outside this chunk (presumably by Q45_K_SCALE_MIN). | ||
| 443 | // 8 blocks of 32 elements each | ||
| 444 | fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { | ||
| 445 | let block = src0[src0_idx_base + offset]; | ||
| 446 | let d = f32(block.d); | ||
| 447 | let m = f32(block.dmin); | ||
| 448 | var sum = 0.0; | ||
| 449 | var src1_i = src1_idx_base + offset * 256; | ||
| 450 | var is: u32 = 0; | ||
| | // Single-bit mask into each qh byte; shifted left after every 32-element sub-block, | ||
| | // so the 8 sub-blocks use bits 0..7 in turn. | ||
| 451 | var u: u32 = 1; | ||
| 452 | // 2 blocks each iteration | ||
| 453 | for (var q_b_idx: u32 = 0; q_b_idx < 128; q_b_idx += 32) { | ||
| 454 | for (var shift: u32 = 0; shift < 8; shift += 4) { | ||
| 455 | let scale_min = get_scale_min(is, block.scales); | ||
| 456 | is++; | ||
| 457 | let dl = d * scale_min.x; | ||
| 458 | let ml = m * scale_min.y; | ||
| 459 | for (var l: u32 = 0; l < 32; l++) { | ||
| 460 | let q_idx = q_b_idx + l; | ||
| 461 | let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4); | ||
| | // The qh byte index depends only on l: every sub-block reads the same 32 bytes. | ||
| 462 | let qh_byte = get_byte(block.qh[l / 4], l % 4); | ||
| 463 | let qs_val = (q_byte >> shift) & 0xF; | ||
| | // select(f, t, cond): yields 16.0 when the qh bit is set, 0.0 otherwise. | ||
| 464 | let qh_val = select(0.0, 16.0, (qh_byte & u) != 0); | ||
| 465 | sum += ((f32(qs_val) + qh_val) * dl - ml) * src1[src1_i]; | ||
| 466 | src1_i++; | ||
| 467 | } | ||
| 468 | u <<= 1; | ||
| 469 | } | ||
| 470 | } | ||
| 471 | return sum; | ||
| 472 | } | ||
| 473 | |||
| 474 | #enddecl(Q5_K) | ||
| 475 | |||
| 476 | #decl(Q6_K) | ||
| | // Q6_K variant: partial dot product of one 256-element src0 block with the 256 src1 | ||
| | // values starting at src1_idx_base + offset * 256. | ||
| | // Each value is 6 bits (a nibble from ql plus 2 bits from qh), shifted by -32 and | ||
| | // multiplied by d and a per-16-element scale byte read through get_byte_i32. | ||
| 477 | // 16 blocks of 16 elements each | ||
| 478 | fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { | ||
| 479 | let block = src0[src0_idx_base + offset]; | ||
| 480 | let d = f32(block.d); | ||
| 481 | |||
| 482 | // convert arrays of f16 -> u32 | ||
| 483 | var ql_vals: array<u32, 32>; | ||
| 484 | for (var i: u32 = 0; i < 32; i++) { | ||
| 485 | ql_vals[i] = bitcast<u32>(vec2(block.ql[2 * i], block.ql[2 * i + 1])); | ||
| 486 | } | ||
| 487 | var qh_vals: array<u32, 16>; | ||
| 488 | for (var i: u32 = 0; i < 16; i++) { | ||
| 489 | qh_vals[i] = bitcast<u32>(vec2(block.qh[2 * i], block.qh[2 * i + 1])); | ||
| 490 | } | ||
| 491 | var scale_vals: array<u32, 4>; | ||
| 492 | for (var i: u32 = 0; i < 4; i++) { | ||
| 493 | scale_vals[i] = bitcast<u32>(vec2(block.scales[2 * i], block.scales[2 * i + 1])); | ||
| 494 | } | ||
| 495 | |||
| 496 | var sum = 0.0; | ||
| 497 | var src1_i = src1_idx_base + offset * 256; | ||
| 498 | var qh_b_idx: u32 = 0; | ||
| 499 | var sc_b_idx: u32 = 0; | ||
| | // Two passes of 128 elements each; per pass ql advances 64 bytes, qh 32 bytes, | ||
| | // and the scale index 8 bytes (see the increments at the end of the loop body). | ||
| 500 | for (var ql_b_idx: u32 = 0; ql_b_idx < 128; ql_b_idx += 64) { | ||
| 501 | for (var l: u32 = 0; l < 32; l++) { | ||
| | // ql13_b feeds q1 (low nibble) and q3 (high nibble); ql24_b feeds q2 and q4. | ||
| 502 | let ql13_b = get_byte(ql_vals[(ql_b_idx + l) / 4], (ql_b_idx + l) % 4); | ||
| 503 | let ql24_b = get_byte(ql_vals[(ql_b_idx + l + 32) / 4], (ql_b_idx + l + 32) % 4); | ||
| 504 | let qh_b = get_byte(qh_vals[(qh_b_idx + l) / 4], (qh_b_idx + l) % 4); | ||
| 505 | |||
| | // One qh byte supplies the two high bits for all four values (bit pairs 0-1, 2-3, 4-5, 6-7). | ||
| 506 | let q1 = f32((ql13_b & 0xF) | ((qh_b & 3) << 4)) - 32.0; | ||
| 507 | let q2 = f32((ql24_b & 0xF) | (((qh_b >> 2) & 3) << 4)) - 32.0; | ||
| 508 | let q3 = f32((ql13_b >> 4) | (((qh_b >> 4) & 3) << 4)) - 32.0; | ||
| 509 | let q4 = f32((ql24_b >> 4) | (((qh_b >> 6) & 3) << 4)) - 32.0; | ||
| 510 | |||
| | // is is 0 for l < 16 and 1 otherwise; the four scales sit 2 bytes apart. | ||
| 511 | let is = l/16; | ||
| 512 | let is1 = sc_b_idx + is; | ||
| 513 | let sc1 = get_byte_i32(scale_vals[is1 / 4], is1 % 4); | ||
| 514 | let is2 = sc_b_idx + is + 2; | ||
| 515 | let sc2 = get_byte_i32(scale_vals[is2 / 4], is2 % 4); | ||
| 516 | let is3 = sc_b_idx + is + 4; | ||
| 517 | let sc3 = get_byte_i32(scale_vals[is3 / 4], is3 % 4); | ||
| 518 | let is4 = sc_b_idx + is + 6; | ||
| 519 | let sc4 = get_byte_i32(scale_vals[is4 / 4], is4 % 4); | ||
| 520 | |||
| | // q1..q4 pair with src1 elements l, l + 32, l + 64 and l + 96 of the current pass. | ||
| 521 | sum += d * f32(sc1) * q1 * src1[src1_i + l]; | ||
| 522 | sum += d * f32(sc2) * q2 * src1[src1_i + l + 32]; | ||
| 523 | sum += d * f32(sc3) * q3 * src1[src1_i + l + 64]; | ||
| 524 | sum += d * f32(sc4) * q4 * src1[src1_i + l + 96]; | ||
| 525 | } | ||
| 526 | src1_i += 128; | ||
| 527 | qh_b_idx += 32; | ||
| 528 | sc_b_idx += 8; | ||
| 529 | } | ||
| 530 | return sum; | ||
| 531 | } | ||
| 532 | |||
| 533 | #enddecl(Q6_K) | ||
| 534 | |||
| 535 | #decl(IQ2_XXS) | ||
| | // IQ2_XXS variant: partial dot product of one 256-element src0 block with the 256 src1 | ||
| | // values starting at src1_idx_base + offset * 256 (read sequentially via src1_i). | ||
| | // Magnitudes come from the iq2xxs_grid table and signs from ksigns_iq2xs / kmask_iq2xs; | ||
| | // all three tables are declared outside this chunk. | ||
| 536 | fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { | ||
| 537 | let block = src0[src0_idx_base + offset]; | ||
| 538 | let d = f32(block.d); | ||
| 539 | var src1_i = src1_idx_base + offset * 256; | ||
| 540 | var sum = 0.0; | ||
| | // Each iteration consumes four qs entries (two 32-bit words) and produces 32 values. | ||
| 541 | for (var ib: u32 = 0; ib < 32; ib += 4) { | ||
| | // aux0: four one-byte grid indices. aux1: four 7-bit sign indices plus a scale in the top 4 bits. | ||
| 542 | let aux0 = bitcast<u32>(vec2(block.qs[ib], block.qs[ib + 1])); | ||
| 543 | let aux1 = bitcast<u32>(vec2(block.qs[ib + 2], block.qs[ib + 3])); | ||
| 544 | let db = d * (0.5 + f32(aux1 >> 28)) * 0.25; | ||
| 545 | for (var l: u32 = 0; l < 4; l++) { | ||
| | // Grid entries are 8 bytes each, hence the byte offset ig = index * 8. | ||
| 546 | let ig = get_byte(aux0, l) * 8; | ||
| 547 | let is = (aux1 >> (7 * l)) & 127; | ||
| 548 | let signs = get_byte(ksigns_iq2xs[is / 4], is % 4); | ||
| 549 | for (var j: u32 = 0; j < 8; j++) { | ||
| 550 | let g = get_byte(iq2xxs_grid[(ig + j) / 4], (ig + j) % 4); | ||
| | // select(f, t, cond): yields -1.0 when the mask byte for j overlaps signs, else 1.0. | ||
| 551 | let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0); | ||
| 552 | sum += db * f32(g) * m * src1[src1_i]; | ||
| 553 | src1_i++; | ||
| 554 | } | ||
| 555 | } | ||
| 556 | } | ||
| 557 | return sum; | ||
| 558 | } | ||
| 559 | |||
| 560 | #enddecl(IQ2_XXS) | ||
| 561 | |||
| 562 | #decl(IQ2_XS) | ||
| | // IQ2_XS variant: partial dot product of one 256-element src0 block with the 256 src1 | ||
| | // values starting at src1_idx_base + offset * 256 (read sequentially via src1_i). | ||
| | // Magnitudes come from the iq2xs_grid table and signs from ksigns_iq2xs / kmask_iq2xs; | ||
| | // all three tables are declared outside this chunk. | ||
| 563 | fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { | ||
| 564 | let block = src0[src0_idx_base + offset]; | ||
| 565 | let d = f32(block.d); | ||
| 566 | var src1_i = src1_idx_base + offset * 256; | ||
| | // Pack the four scales entries into two 32-bit words (8 scale bytes in total). | ||
| 567 | var scale_vals = array<u32, 2>( | ||
| 568 | bitcast<u32>(vec2(block.scales[0], block.scales[1])), | ||
| 569 | bitcast<u32>(vec2(block.scales[2], block.scales[3])) | ||
| 570 | ); | ||
| 571 | var sum = 0.0; | ||
| 572 | for (var ib: u32 = 0; ib < 32; ib += 4) { | ||
| | // Selects scale byte number ib / 4: word ib / 16, byte (ib % 16) / 4. | ||
| 573 | let s = get_byte(scale_vals[ib / 16], (ib % 16) / 4); | ||
| | // Low nibble scales the first two groups (l = 0, 1), high nibble the last two. | ||
| 574 | let db = array<f32, 2>( | ||
| 575 | d * (0.5 + f32(s & 0xF)) * 0.25, | ||
| 576 | d * (0.5 + f32(s >> 4)) * 0.25 | ||
| 577 | ); | ||
| 578 | for (var l: u32 = 0; l < 4; l++) { | ||
| | // Pairing the entry with 0.0 leaves the upper 16 bits of qs_val zero. | ||
| 579 | let qs_val = bitcast<u32>(vec2(block.qs[ib + l], 0.0)); | ||
| | // Low 9 bits: grid index (entries are 8 bytes each). Remaining bits: sign index. | ||
| 580 | let ig = (qs_val & 511) * 8; | ||
| 581 | let is = qs_val >> 9; | ||
| 582 | let signs = get_byte(ksigns_iq2xs[is / 4], is % 4); | ||
| 583 | let dl = db[l/2]; | ||
| 584 | for (var j: u32 = 0; j < 8; j++) { | ||
| 585 | let g = get_byte(iq2xs_grid[(ig + j) / 4], (ig + j) % 4); | ||
| 586 | let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0); | ||
| 587 | sum += dl * f32(g) * m * src1[src1_i]; | ||
| 588 | src1_i++; | ||
| 589 | } | ||
| 590 | } | ||
| 591 | } | ||
| 592 | return sum; | ||
| 593 | } | ||
| 594 | |||
| 595 | #enddecl(IQ2_XS) | ||
| 596 | |||
| 597 | #decl(IQ2_S) | ||
| | // IQ2_S variant: partial dot product of one 256-element src0 block with the 256 src1 | ||
| | // values starting at src1_idx_base + offset * 256 (read sequentially via src1_i). | ||
| | // Magnitudes come from the iq2s_grid table; sign bytes are stored inside qs itself. | ||
| | // iq2s_grid and kmask_iq2xs are declared outside this chunk. | ||
| 598 | fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { | ||
| 599 | let block = src0[src0_idx_base + offset]; | ||
| 600 | let d = f32(block.d); | ||
| 601 | var src1_i = src1_idx_base + offset * 256; | ||
| | // Repack qs into 16 words: words 0..7 are read as grid-index bytes, words 8..15 as sign bytes. | ||
| 602 | var qs_vals : array<u32, 16>; | ||
| 603 | for (var i: u32 = 0; i < 16; i++) { | ||
| 604 | qs_vals[i] = bitcast<u32>(vec2(block.qs[i * 2], block.qs[i * 2 + 1])); | ||
| 605 | } | ||
| 606 | var qh_vals = array<u32, 2>( | ||
| 607 | bitcast<u32>(vec2(block.qh[0], block.qh[1])), | ||
| 608 | bitcast<u32>(vec2(block.qh[2], block.qh[3])) | ||
| 609 | ); | ||
| 610 | var scale_vals = array<u32, 2>( | ||
| 611 | bitcast<u32>(vec2(block.scales[0], block.scales[1])), | ||
| 612 | bitcast<u32>(vec2(block.scales[2], block.scales[3])) | ||
| 613 | ); | ||
| 614 | var sum = 0.0; | ||
| | // Eight sub-blocks of 32 values; each uses one scale byte, one qh byte and one qs word. | ||
| 615 | for (var ib: u32 = 0; ib < 8; ib ++) { | ||
| 616 | let s = get_byte(scale_vals[ib / 4], ib % 4); | ||
| | // Low nibble scales the first two groups (l = 0, 1), high nibble the last two. | ||
| 617 | let db = array<f32, 2>( | ||
| 618 | d * (0.5 + f32(s & 0xF)) * 0.25, | ||
| 619 | d * (0.5 + f32(s >> 4)) * 0.25 | ||
| 620 | ); | ||
| 621 | let qs_w = qs_vals[ib]; | ||
| 622 | for (var l: u32 = 0; l < 4; l++) { | ||
| | // Moves qh bits (2 * l, 2 * l + 1) into bit positions 8 and 9 of the grid index. | ||
| 623 | let qh_b = (get_byte(qh_vals[ib / 4], ib % 4) << (8 - 2 * l)) & 0x300; | ||
| 624 | let ig = (get_byte(qs_w, l) | qh_b) * 8; | ||
| 625 | let signs = get_byte(qs_vals[ib + 8], l); | ||
| 626 | let dl = db[l/2]; | ||
| 627 | for (var j: u32 = 0; j < 8; j++) { | ||
| 628 | let g = get_byte(iq2s_grid[(ig + j) / 4], (ig + j) % 4); | ||
| 629 | let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0); | ||
| 630 | sum += dl * f32(g) * m * src1[src1_i]; | ||
| 631 | src1_i++; | ||
| 632 | } | ||
| 633 | } | ||
| 634 | } | ||
| 635 | return sum; | ||
| 636 | } | ||
| 637 | |||
| 638 | |||
| 639 | #enddecl(IQ2_S) | ||
| 640 | |||
| 641 | #decl(IQ3_XSS) | ||
| | // IQ3_XSS variant: partial dot product of one 256-element src0 block with the 256 src1 | ||
| | // values starting at src1_idx_base + offset * 256. | ||
| | // NOTE(review): the decl is spelled IQ3_XSS while the source type is iq3_xxs; the | ||
| | // VARIANTS entry uses the same spelling, so any rename must change both together. | ||
| | // iq3xxs_grid, ksigns_iq2xs and kmask_iq2xs are declared outside this chunk. | ||
| 642 | fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { | ||
| 643 | let block = src0[src0_idx_base + offset]; | ||
| 644 | let d = f32(block.d); | ||
| 645 | var src1_i = src1_idx_base + offset * 256; | ||
| 646 | var sum = 0.0; | ||
| 647 | for (var ib: u32 = 0; ib < 16; ib += 2) { | ||
| | // qs entries 32 and up hold, per sub-block, a word of four 7-bit sign indices | ||
| | // with a scale in its top 4 bits; entries below 32 hold the grid indices. | ||
| 648 | let sc_sign = bitcast<u32>(vec2(block.qs[ib + 32], block.qs[ib + 33])); | ||
| 649 | let db = d * (0.5 + f32(sc_sign >> 28)) * 0.5; | ||
| 650 | for (var l: u32 = 0; l < 4; l++) { | ||
| 651 | let is = (sc_sign >> (7 * l)) & 127; | ||
| 652 | let signs = get_byte(ksigns_iq2xs[is / 4], is % 4); | ||
| | // One qs entry carries two one-byte grid indices (ig1, ig2). | ||
| 653 | let ig_val = bitcast<u32>(vec2(block.qs[ib * 2 + l], 0.0)); | ||
| 654 | let ig1 = get_byte(ig_val, 0); | ||
| 655 | let ig2 = get_byte(ig_val, 1); | ||
| 656 | for (var j: u32 = 0; j < 4; j++) { | ||
| | // Each grid word holds 4 byte-sized magnitudes; g1 covers elements 0..3, g2 elements 4..7. | ||
| 657 | let g1 = get_byte(iq3xxs_grid[ig1], j); | ||
| 658 | let g2 = get_byte(iq3xxs_grid[ig2], j); | ||
| 659 | let m1 = select(1.0, -1.0, (get_byte(kmask_iq2xs[0], j) & signs) != 0); | ||
| 660 | let m2 = select(1.0, -1.0, (get_byte(kmask_iq2xs[1], j) & signs) != 0); | ||
| 661 | sum += db * f32(g1) * m1 * src1[src1_i]; | ||
| 662 | sum += db * f32(g2) * m2 * src1[src1_i + 4]; | ||
| 663 | src1_i++; | ||
| 664 | } | ||
| | // Skip the 4 elements already consumed through the src1_i + 4 reads above. | ||
| 665 | src1_i += 4; | ||
| 666 | } | ||
| 667 | } | ||
| 668 | return sum; | ||
| 669 | } | ||
| 670 | |||
| 671 | #enddecl(IQ3_XSS) | ||
| 672 | |||
| 673 | #decl(IQ3_S) | ||
| | // IQ3_S variant: partial dot product of one 256-element src0 block with the 256 src1 | ||
| | // values starting at src1_idx_base + offset * 256. | ||
| | // Magnitudes come from the iq3s_grid table, indexed by a qs byte plus one high bit | ||
| | // from qh; signs come from block.signs. | ||
| | // iq3s_grid and kmask_iq2xs are declared outside this chunk. | ||
| 674 | fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { | ||
| 675 | let block = src0[src0_idx_base + offset]; | ||
| 676 | let d = f32(block.d); | ||
| 677 | var src1_i = src1_idx_base + offset * 256; | ||
| 678 | var qh_vals = array<u32, 2>( | ||
| 679 | bitcast<u32>(vec2(block.qh[0], block.qh[1])), | ||
| 680 | bitcast<u32>(vec2(block.qh[2], block.qh[3])) | ||
| 681 | ); | ||
| 682 | var sign_vals: array<u32, 8>; | ||
| 683 | for (var i: u32 = 0; i < 8; i++) { | ||
| 684 | sign_vals[i] = bitcast<u32>(vec2(block.signs[i * 2], block.signs[i * 2 + 1])); | ||
| 685 | } | ||
| 686 | var scale_vals = bitcast<u32>(vec2(block.scales[0], block.scales[1])); | ||
| 687 | var sum = 0.0; | ||
| | // Four scale bytes, each covering two 32-value halves (k = 0 low nibble, k = 1 high nibble). | ||
| 688 | for (var ib: u32 = 0; ib < 4; ib++) { | ||
| 689 | let s = get_byte(scale_vals, ib); | ||
| 690 | let db = array<f32, 2>( | ||
| 691 | d * (1.0 + 2.0 * f32(s & 0xF)), | ||
| 692 | d * (1.0 + 2.0 * f32(s >> 4)) | ||
| 693 | ); | ||
| 694 | for (var k: u32 = 0; k < 2; k++) { | ||
| 695 | let dl = db[k]; | ||
| | // qh byte number ib * 2 + k, located as word ib / 2, byte (ib % 2) * 2 + k. | ||
| 696 | let qh_byte = get_byte(qh_vals[ib / 2], (ib % 2) * 2 + k); | ||
| 697 | let sign_w = sign_vals[ib * 2 + k]; | ||
| 698 | for (var l: u32 = 0; l < 4; l++) { | ||
| 699 | let signs = get_byte(sign_w, l); | ||
| 700 | let ig_val = bitcast<u32>(vec2(block.qs[ib * 8 + k * 4 + l], 0.0)); | ||
| | // qh bit 2 * l becomes bit 8 of ig1; qh bit 2 * l + 1 becomes bit 8 of ig2. | ||
| 701 | let ig1 = get_byte(ig_val, 0) | ((qh_byte << ((8 - (2 * l)))) & 256); | ||
| 702 | let ig2 = get_byte(ig_val, 1) | ((qh_byte << ((7 - (2 * l)))) & 256); | ||
| 703 | for (var j: u32 = 0; j < 4; j++) { | ||
| 704 | let g1 = get_byte(iq3s_grid[ig1], j); | ||
| 705 | let g2 = get_byte(iq3s_grid[ig2], j); | ||
| 706 | let m1 = select(1.0, -1.0, (get_byte(kmask_iq2xs[0], j) & signs) != 0); | ||
| 707 | let m2 = select(1.0, -1.0, (get_byte(kmask_iq2xs[1], j) & signs) != 0); | ||
| 708 | sum += dl * f32(g1) * m1 * src1[src1_i]; | ||
| 709 | sum += dl * f32(g2) * m2 * src1[src1_i + 4]; | ||
| 710 | src1_i++; | ||
| 711 | } | ||
| | // Skip the 4 elements already consumed through the src1_i + 4 reads above. | ||
| 712 | src1_i += 4; | ||
| 713 | } | ||
| 714 | } | ||
| 715 | } | ||
| 716 | return sum; | ||
| 717 | } | ||
| 718 | #enddecl(IQ3_S) | ||
| 719 | |||
| 720 | #decl(IQ1_S) | ||
| | // IQ1_S variant: partial dot product of one 256-element src0 block with the 256 src1 | ||
| | // values starting at src1_idx_base + offset * 256 (read sequentially via src1_i). | ||
| | // Each value is a signed 2-bit grid entry plus a +/- IQ1_DELTA offset, times dl. | ||
| | // iq1_grid and IQ1_DELTA are declared outside this chunk. | ||
| 721 | fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { | ||
| 722 | let block = src0[src0_idx_base + offset]; | ||
| 723 | let d = f32(block.d); | ||
| 724 | var src1_i = src1_idx_base + offset * 256; | ||
| 725 | var sum = 0.0; | ||
| 726 | for (var ib: u32 = 0; ib < 8; ib++) { | ||
| | // qh layout as used below: bits 0..11 = four 3-bit index extensions, | ||
| | // bits 12..14 = scale, bit 15 = sign of the delta offset. | ||
| 727 | let qh = bitcast<u32>(vec2(block.qh[ib], 0.0)); | ||
| 728 | let dl = d * (2 * f32((qh >> 12) & 7) + 1); | ||
| 729 | let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000) != 0); | ||
| 730 | let qs_w = bitcast<u32>(vec2(block.qs[ib * 2], block.qs[ib * 2 + 1])); | ||
| 731 | for (var l: u32 = 0; l < 4; l++) { | ||
| | // 11-bit grid index (8 bits from qs, 3 from qh), times 8 values per grid entry. | ||
| 732 | let ig = (get_byte(qs_w, l) | (((qh >> (3 * l)) & 7) << 8)) * 8; | ||
| 733 | for (var j: u32 = 0; j < 8; j++) { | ||
| | // The grid packs sixteen 2-bit values per 32-bit word. | ||
| 734 | let gw = iq1_grid[(ig + j) / 16]; | ||
| 735 | let g = (gw >> (((ig + j) % 16) * 2)) & 3; | ||
| | // Sign-extend the 2-bit field: shift it to the top, then shift back as i32. | ||
| 736 | let gs = bitcast<i32>(g << 30) >> 30; | ||
| 737 | sum += dl * (f32(gs) + delta) * src1[src1_i]; | ||
| 738 | src1_i++; | ||
| 739 | } | ||
| 740 | } | ||
| 741 | } | ||
| 742 | return sum; | ||
| 743 | } | ||
| 744 | |||
| 745 | #enddecl(IQ1_S) | ||
| 746 | |||
| 747 | #decl(IQ1_M) | ||
| | // IQ1_M variant: partial dot product of one 256-element src0 block with the 256 src1 | ||
| | // values starting at src1_idx_base + offset * 256 (read sequentially via src1_i). | ||
| | // The block has no separate d field here: the f16 scale is reassembled from bits | ||
| | // spread across block.scales. iq1_grid and IQ1_DELTA are declared outside this chunk. | ||
| 748 | fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { | ||
| 749 | let block = src0[src0_idx_base + offset]; | ||
| 750 | |||
| | // Gather the top 4 bits of each 16-bit half of scales[0] and scales[1] into one | ||
| | // 16-bit pattern, then reinterpret that pattern as an f16. | ||
| 751 | let scale = ((block.scales[0] >> 12) & 0xF) | ((block.scales[0] >> 24) & 0x00F0) | ((block.scales[1] >> 4) & 0x0F00) | ((block.scales[1] >> 16) & 0xF000); | ||
| 752 | let d = f32(bitcast<vec2<f16>>(scale).x); | ||
| 753 | var src1_i = src1_idx_base + offset * 256; | ||
| 754 | var sum = 0.0; | ||
| 755 | for (var ib: u32 = 0; ib < 8; ib++) { | ||
| | // sw: the 16-bit half of the scales word that covers sub-block ib. | ||
| | // s1 / s2: two 3-bit scales inside it, at bit offset 6 * (ib % 2) and 3 bits later. | ||
| 756 | let sw = (block.scales[ib / 4] >> (16 * ((ib / 2) % 2))) & 0xFFFF; | ||
| 757 | let s1 : u32 = (sw >> (6 * (ib % 2))) & 0x7; | ||
| 758 | let s2 : u32 = (sw >> (6 * (ib % 2) + 3)) & 0x7; | ||
| 759 | var dl = array<f32, 2>( | ||
| 760 | d * f32(2 * s1 + 1), | ||
| 761 | d * f32(2 * s2 + 1) | ||
| 762 | ); | ||
| 763 | |||
| | // qh is not masked to 16 bits; every use below masks the bits it needs. | ||
| 764 | let qh = block.qh[ib / 2] >> (16 * (ib % 2)); | ||
| | // Each index: 8 bits from a qs byte plus 3 bits taken from one nibble of qh. | ||
| 765 | var idx = array<u32, 4>( | ||
| 766 | get_byte(block.qs[ib], 0) | ((qh << 8) & 0x700), | ||
| 767 | get_byte(block.qs[ib], 1) | ((qh << 4) & 0x700), | ||
| 768 | get_byte(block.qs[ib], 2) | ((qh) & 0x700), | ||
| 769 | get_byte(block.qs[ib], 3) | ((qh >> 4) & 0x700) | ||
| 770 | ); | ||
| | // The top bit of each qh nibble picks the sign of the delta offset for that group. | ||
| 771 | var delta = array<f32, 4>( | ||
| 772 | select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x08) != 0), | ||
| 773 | select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x80) != 0), | ||
| 774 | select(IQ1_DELTA, -IQ1_DELTA, ((qh >> 8) & 0x08) != 0), | ||
| 775 | select(IQ1_DELTA, -IQ1_DELTA, ((qh >> 8) & 0x80) != 0) | ||
| 776 | ); | ||
| 777 | for (var l: u32 = 0; l < 4; l++) { | ||
| 778 | let ig = idx[l] * 8; | ||
| 779 | for (var j: u32 = 0; j < 8; j++) { | ||
| | // The grid packs sixteen 2-bit values per 32-bit word. | ||
| 780 | let gw = iq1_grid[(ig + j) / 16]; | ||
| 781 | let g = (gw >> (((ig + j) % 16) * 2)) & 3; | ||
| | // Sign-extend the 2-bit field: shift it to the top, then shift back as i32. | ||
| 782 | let gs = bitcast<i32>(g << 30) >> 30; | ||
| 783 | sum += dl[l/2] * (f32(gs) + delta[l]) * src1[src1_i]; | ||
| 784 | src1_i++; | ||
| 785 | } | ||
| 786 | } | ||
| 787 | } | ||
| 788 | return sum; | ||
| 789 | } | ||
| 790 | |||
| 791 | #enddecl(IQ1_M) | ||
| 792 | |||
| 793 | #decl(IQ4_NL) | ||
// Dot product of one IQ4_NL block of src0 with 32 consecutive src1 values.
//   src0_idx_base: index of the first src0 block of the row being multiplied
//   src1_idx_base: index of the first src1 element of the vector being multiplied
//   offset:        block index within the row; the block is paired with
//                  src1 elements [offset * 32, offset * 32 + 32)
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let blk = src0[src0_idx_base + offset];
    let d = f32(blk.d);
    let src1_start = src1_idx_base + offset * 32u;

    var acc = 0.0;
    for (var w: u32 = 0u; w < 4u; w++) {
        // Fuse two adjacent 16-bit qs entries into one 32-bit word so that
        // individual bytes can be pulled out with get_byte.
        let packed = bitcast<u32>(vec2(blk.qs[2u * w], blk.qs[2u * w + 1u]));
        for (var b: u32 = 0u; b < 4u; b++) {
            let q = get_byte(packed, b);
            let lo = src1_start + 4u * w + b;
            // Byte j holds two 4-bit table indices: the low nibble belongs to
            // element j, the high nibble to element j + 16.
            acc += d * f32(kvalues_iq4nl[q & 0xF]) * src1[lo];
            acc += d * f32(kvalues_iq4nl[q >> 4]) * src1[lo + 16u];
        }
    }
    return acc;
}
| 811 | |||
| 812 | #enddecl(IQ4_NL) | ||
| 813 | |||
| 814 | #decl(IQ4_XS) | ||
// Dot product of one IQ4_XS block of src0 with 256 consecutive src1 values.
//   src0_idx_base: index of the first src0 block of the row being multiplied
//   src1_idx_base: index of the first src1 element of the vector being multiplied
//   offset:        block index within the row; the block is paired with
//                  src1 elements [offset * 256, offset * 256 + 256)
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let blk = src0[src0_idx_base + offset];
    let d = f32(blk.d);
    // Reinterpret the 16-bit scales_h field as the low half of a u32 so it can
    // be shifted and masked.
    let scales_hi = bitcast<u32>(vec2(blk.scales_h, 0.0));
    let src1_start = src1_idx_base + offset * 256u;

    var acc = 0.0;
    for (var sub: u32 = 0u; sub < 8u; sub++) {
        // 6-bit sub-block scale: low 4 bits come from scales_l, high 2 bits
        // from scales_h; the stored value is biased by 32.
        let lo4 = (get_byte(blk.scales_l, sub / 2u) >> (4u * (sub % 2u))) & 0xF;
        let hi2 = (scales_hi >> (2u * sub)) & 3;
        let dl = d * (f32(lo4 | (hi2 << 4)) - 32.0);

        // Each sub-block covers 32 src1 values: 16 bytes, two nibbles each.
        let src1_sub = src1_start + sub * 32u;
        for (var w: u32 = 0u; w < 4u; w++) {
            let qs_word = blk.qs[sub * 4u + w];
            for (var b: u32 = 0u; b < 4u; b++) {
                let q = get_byte(qs_word, b);
                let lo = src1_sub + w * 4u + b;
                // Low nibble -> element j, high nibble -> element j + 16.
                acc += dl * f32(kvalues_iq4nl[q & 0xF]) * src1[lo];
                acc += dl * f32(kvalues_iq4nl[q >> 4]) * src1[lo + 16u];
            }
        }
    }
    return acc;
}
| 835 | |||
| 836 | #enddecl(IQ4_XS) | ||
| 837 | |||
| 838 | #end(DECLS) | ||
| 839 | |||
| 840 | #define(SHADER) | ||
| 841 | |||
| 842 | enable f16; | ||
| 843 | |||
| 844 | DECLS | ||
| 845 | |||
// Uniform parameters of the matrix multiplication, read by main() through the
// `params` binding. Offsets and strides are expressed in elements for plain
// float types and in blocks for quantized src0 types.
struct MulMatParams {
    offset_src0: u32, // in elements/blocks; added to every src0 index
    offset_src1: u32, // in elements/blocks; added to every src1 index
    offset_dst: u32, // in elements/blocks; added to every dst index
    m: u32, // fastest-varying dst extent; its index selects the src0 row (via stride_01)
    n: u32, // second dst extent; its index selects the src1 row (via stride_11)
    k: u32, // dot-product length in elements; main() walks k / BLOCK_SIZE blocks
    // all strides are in elements/blocks
    stride_01: u32, // src0 step per index along m
    stride_11: u32, // src1 step per index along n
    stride_02: u32, // src0 step per dim-2 index
    stride_12: u32, // src1 step per dim-2 index
    stride_03: u32, // src0 step per dim-3 index
    stride_13: u32, // src1 step per dim-3 index

    // dst holds bs02 * broadcast2 slices along dim 2 and bs03 * broadcast3
    // along dim 3; the src0 index on each of those dims is the dst index
    // divided by the matching broadcast factor.
    bs02: u32,
    bs03: u32,
    broadcast2: u32,
    broadcast3: u32
};
| 866 | |||
// Storage buffers. {{SRC0_TYPE}} and {{SRC1_TYPE}} are template placeholders
// substituted per shader variant. All three buffers are declared read_write;
// in the code visible here src0 and src1 are only read and dst is only written.
@group(0) @binding(0) var<storage, read_write> src0: array<{{SRC0_TYPE}}>; // M rows, K columns
@group(0) @binding(1) var<storage, read_write> src1: array<{{SRC1_TYPE}}>; // K rows, N columns (transposed)
@group(0) @binding(2) var<storage, read_write> dst: array<f32>; // M rows, N columns

// Offsets, extents, strides and broadcast factors for this dispatch (see MulMatParams).
@group(0) @binding(3) var<uniform> params: MulMatParams;
| 872 | |||
// Entry point: each invocation computes exactly one element of dst as the sum
// of multiply_add() over all k / BLOCK_SIZE blocks. Invocations whose flat
// index lies past the end of dst return without touching memory.
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let gid = global_id.x;

    // Element counts of one 2D output slice and of one full dim-2 batch.
    let slice_len = params.m * params.n;
    let batch_len = slice_len * params.bs02 * params.broadcast2;
    if (gid >= batch_len * params.bs03 * params.broadcast3) {
        return;
    }

    // Split the flat index into (dim-3 index, dim-2 index, position in slice).
    let i3 = gid / batch_len;
    let in_batch = gid % batch_len;
    let i2 = in_batch / slice_len;
    let in_slice = in_batch % slice_len;

    // m is the fastest-varying output dimension.
    let src1_row = in_slice / params.m;
    let src0_row = in_slice % params.m;

    // src0 may be broadcast along dims 2 and 3, so its batch indices are the
    // dst indices divided by the broadcast factors; src1 is not broadcast and
    // uses the dst batch indices directly.
    let src0_idx_base = params.offset_src0
        + (i3 / params.broadcast3) * params.stride_03
        + (i2 / params.broadcast2) * params.stride_02
        + src0_row * params.stride_01;
    let src1_idx_base = params.offset_src1
        + i3 * params.stride_13
        + i2 * params.stride_12
        + src1_row * params.stride_11;

    let num_blocks = params.k / {{BLOCK_SIZE}};
    var acc = 0.0;
    for (var blk: u32 = 0u; blk < num_blocks; blk++) {
        acc += multiply_add(src0_idx_base, src1_idx_base, blk);
    }

    // Recompose the flat dst position from its parts.
    let dst_idx = params.offset_dst + i3 * batch_len + i2 * slice_len + src1_row * params.m + src0_row;
    dst[dst_idx] = acc;
}
| 906 | |||
| 907 | #end(SHADER) | ||
