Diffstat (limited to 'llama.cpp/ggml/src/ggml-cpu/arch/wasm')
 -rw-r--r--  llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c  1221
 1 file changed, 1221 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c b/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c
new file mode 100644
index 0000000..74a359e
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c
@@ -0,0 +1,1221 @@
#define GGML_COMMON_IMPL_C
#include "ggml-common.h"
#include "ggml-quants.h"
#include "ggml-impl.h"
#include "ggml-cpu.h"
#include "simd-mappings.h"

#include "../../quants.h"
#include "../../ggml-cpu-impl.h"

#include <math.h>
#include <string.h>
#include <assert.h>
#include <float.h>
#include <stdlib.h> // for qsort
#include <stdio.h>  // for GGML_ASSERT

#define GROUP_MAX_EPS 1e-15f
#define GROUP_MAX_EPS_IQ3_XXS 1e-8f
#define GROUP_MAX_EPS_IQ2_S 1e-8f
#define GROUP_MAX_EPS_IQ1_M 1e-7f
#define GROUP_MAX_EPS_IQ1_S 1e-12f

#define UNUSED GGML_UNUSED

#if defined(__wasm_simd128__)
#define B1(c,s,n) 0x ## n ## c , 0x ## n ## s
#define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s)
#define B3(c,s,n) B2(c,s,n ## c), B2(c,s,n ## s)
#define B4(c,s,n) B3(c,s,n ## c), B3(c,s,n ## s)
#define B5(c,s,n) B4(c,s,n ## c), B4(c,s,n ## s)
#define B6(c,s,n) B5(c,s,n ## c), B5(c,s,n ## s)
#define B7(c,s,n) B6(c,s,n ## c), B6(c,s,n ## s)
#define B8(c,s ) B7(c,s, c), B7(c,s, s)

// precomputed tables for expanding 8bits to 8 bytes:
static const uint64_t table_b2b_0[1 << 8] = { B8(00, 10) }; // ( b) << 4
static const uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4
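// Added note: each entry expands an 8-bit index into 8 bytes, one byte per
// bit, little-endian: for table_b2b_0, byte j is 0x10 iff bit j of the index
// is set (e.g. table_b2b_0[0x05] == 0x0000000000100010), and table_b2b_1 is
// the complement (byte j is 0x10 iff bit j is clear). They are used below to
// splat the fifth bit of q5_0/q5_1 quants across 16-byte vectors.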
#endif

void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
    assert(QK8_0 == 32);
    assert(k % QK8_0 == 0);
    const int nb = k / QK8_0;

    block_q8_0 * GGML_RESTRICT y = vy;

#if defined __wasm_simd128__
    for (int i = 0; i < nb; i++) {
        v128_t srcv [8];
        v128_t asrcv[8];
        v128_t amaxv[8];

        for (int j = 0; j < 8; j++) srcv[j]  = wasm_v128_load(x + i*32 + 4*j);
        for (int j = 0; j < 8; j++) asrcv[j] = wasm_f32x4_abs(srcv[j]);

        for (int j = 0; j < 4; j++) amaxv[2*j] = wasm_f32x4_max(asrcv[2*j], asrcv[2*j+1]);
        for (int j = 0; j < 2; j++) amaxv[4*j] = wasm_f32x4_max(amaxv[4*j], amaxv[4*j+2]);
        for (int j = 0; j < 1; j++) amaxv[8*j] = wasm_f32x4_max(amaxv[8*j], amaxv[8*j+4]);

        const float amax = MAX(MAX(wasm_f32x4_extract_lane(amaxv[0], 0),
                                   wasm_f32x4_extract_lane(amaxv[0], 1)),
                               MAX(wasm_f32x4_extract_lane(amaxv[0], 2),
                                   wasm_f32x4_extract_lane(amaxv[0], 3)));

        const float d = amax / ((1 << 7) - 1);
        const float id = d ? 1.0f/d : 0.0f;

        y[i].d = GGML_CPU_FP32_TO_FP16(d);

        for (int j = 0; j < 8; j++) {
            const v128_t v  = wasm_f32x4_mul(srcv[j], wasm_f32x4_splat(id));
            const v128_t vi = wasm_i32x4_trunc_sat_f32x4(v);

            y[i].qs[4*j + 0] = wasm_i32x4_extract_lane(vi, 0);
            y[i].qs[4*j + 1] = wasm_i32x4_extract_lane(vi, 1);
            y[i].qs[4*j + 2] = wasm_i32x4_extract_lane(vi, 2);
            y[i].qs[4*j + 3] = wasm_i32x4_extract_lane(vi, 3);
        }
    }
#else
    GGML_UNUSED(nb);
    // scalar
    quantize_row_q8_0_ref(x, y, k);
#endif
}
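// Usage sketch (illustrative, not part of the original file): quantizing one
// 32-float block. d is the per-block scale amax/127, so scaled values land in
// [-127, 127]; note that this SIMD path truncates toward zero instead of
// rounding to nearest.
//
//   float src[QK8_0] = { /* 32 floats */ };
//   block_q8_0 dst;
//   quantize_row_q8_0(src, &dst, QK8_0);
//   // now dst.qs[j] == (int8_t) (src[j] / GGML_CPU_FP16_TO_FP32(dst.d)), truncated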

void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
    assert(k % QK8_1 == 0);
    const int nb = k / QK8_1;

    block_q8_1 * GGML_RESTRICT y = vy;
#if defined __wasm_simd128__
    for (int i = 0; i < nb; i++) {
        v128_t srcv [8];
        v128_t asrcv[8];
        v128_t amaxv[8];

        for (int j = 0; j < 8; j++) srcv[j]  = wasm_v128_load(x + i*32 + 4*j);
        for (int j = 0; j < 8; j++) asrcv[j] = wasm_f32x4_abs(srcv[j]);

        for (int j = 0; j < 4; j++) amaxv[2*j] = wasm_f32x4_max(asrcv[2*j], asrcv[2*j+1]);
        for (int j = 0; j < 2; j++) amaxv[4*j] = wasm_f32x4_max(amaxv[4*j], amaxv[4*j+2]);
        for (int j = 0; j < 1; j++) amaxv[8*j] = wasm_f32x4_max(amaxv[8*j], amaxv[8*j+4]);

        const float amax = MAX(MAX(wasm_f32x4_extract_lane(amaxv[0], 0),
                                   wasm_f32x4_extract_lane(amaxv[0], 1)),
                               MAX(wasm_f32x4_extract_lane(amaxv[0], 2),
                                   wasm_f32x4_extract_lane(amaxv[0], 3)));

        const float d = amax / ((1 << 7) - 1);
        const float id = d ? 1.0f/d : 0.0f;

        y[i].d = GGML_CPU_FP32_TO_FP16(d);

        v128_t accv = wasm_i32x4_splat(0);

        for (int j = 0; j < 8; j++) {
            const v128_t v  = wasm_f32x4_mul(srcv[j], wasm_f32x4_splat(id));
            const v128_t vi = wasm_i32x4_trunc_sat_f32x4(v);

            y[i].qs[4*j + 0] = wasm_i32x4_extract_lane(vi, 0);
            y[i].qs[4*j + 1] = wasm_i32x4_extract_lane(vi, 1);
            y[i].qs[4*j + 2] = wasm_i32x4_extract_lane(vi, 2);
            y[i].qs[4*j + 3] = wasm_i32x4_extract_lane(vi, 3);

            accv = wasm_i32x4_add(accv, vi);
        }

        y[i].s = GGML_CPU_FP32_TO_FP16(
            d * (wasm_i32x4_extract_lane(accv, 0) +
                 wasm_i32x4_extract_lane(accv, 1) +
                 wasm_i32x4_extract_lane(accv, 2) +
                 wasm_i32x4_extract_lane(accv, 3)));
    }
#else
    GGML_UNUSED(nb);
    // scalar
    quantize_row_q8_1_ref(x, y, k);
#endif
}
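// Added note: compared to q8_0, q8_1 also stores s = d * sum(qs). Dot products
// against formats with a per-block minimum use it to fold the minimum term
// min_x * sum(y) into one multiply per block; see the summs accumulation in
// ggml_vec_dot_q5_1_q8_1 below.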

//===================================== Q8_K ==============================================

void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {
#ifdef __wasm_simd128__
    assert(k % QK_K == 0);
    const int64_t nb = k / QK_K;
    block_q8_K * GGML_RESTRICT yc = y; // Cast to proper type

    for (int i = 0; i < nb; i++) {
        const float * x_block = x + i * QK_K;

        v128_t min_vec = wasm_v128_load(x_block);
        v128_t max_vec = min_vec;

        for (int j = 4; j < QK_K; j += 4) {
            v128_t x_vec = wasm_v128_load(x_block + j);
            max_vec = wasm_f32x4_pmax(max_vec, x_vec);
            min_vec = wasm_f32x4_pmin(min_vec, x_vec);
        }
        max_vec = wasm_f32x4_pmax(max_vec, wasm_i32x4_shuffle(max_vec, max_vec, 2, 3, 0, 1));
        max_vec = wasm_f32x4_pmax(max_vec, wasm_i32x4_shuffle(max_vec, max_vec, 1, 0, 3, 2));
        min_vec = wasm_f32x4_pmin(min_vec, wasm_i32x4_shuffle(min_vec, min_vec, 2, 3, 0, 1));
        min_vec = wasm_f32x4_pmin(min_vec, wasm_i32x4_shuffle(min_vec, min_vec, 1, 0, 3, 2));
        float max = wasm_f32x4_extract_lane(max_vec, 0);
        float min = wasm_f32x4_extract_lane(min_vec, 0);
        float amax = -min > max ? min : max;

        if (amax == 0.0f) {
            yc[i].d = 0.0f;
            const v128_t zero = wasm_i8x16_splat(0);
            for (int j = 0; j < QK_K; j += 16) {
                wasm_v128_store(yc[i].qs + j, zero);
            }
            continue;
        }

        const float iscale = -127.0f / amax;
        const v128_t scale_vec = wasm_f32x4_splat(iscale);

        // Process 16 elements per iteration
        for (int j = 0, jb = 0; j < QK_K; j += 16, jb++) {
            // Load and quantize 16 floats
            v128_t x0 = wasm_v128_load(x_block + j);
            v128_t x1 = wasm_v128_load(x_block + j + 4);
            v128_t x2 = wasm_v128_load(x_block + j + 8);
            v128_t x3 = wasm_v128_load(x_block + j + 12);

            v128_t q0 = wasm_f32x4_nearest(wasm_f32x4_mul(x0, scale_vec));
            v128_t q1 = wasm_f32x4_nearest(wasm_f32x4_mul(x1, scale_vec));
            v128_t q2 = wasm_f32x4_nearest(wasm_f32x4_mul(x2, scale_vec));
            v128_t q3 = wasm_f32x4_nearest(wasm_f32x4_mul(x3, scale_vec));

            // Convert to i32 with saturation
            v128_t i0 = wasm_i32x4_trunc_sat_f32x4(q0);
            v128_t i1 = wasm_i32x4_trunc_sat_f32x4(q1);
            v128_t i2 = wasm_i32x4_trunc_sat_f32x4(q2);
            v128_t i3 = wasm_i32x4_trunc_sat_f32x4(q3);

            // Pack into 16 i8 values
            v128_t i8 = wasm_i8x16_narrow_i16x8(
                wasm_i16x8_narrow_i32x4(i0, i1),
                wasm_i16x8_narrow_i32x4(i2, i3)
            );
            wasm_v128_store(yc[i].qs + j, i8);

            // Calculate bsums using SIMD
            v128_t sum16 = wasm_i16x8_add(
                wasm_i16x8_extend_low_i8x16(i8),
                wasm_i16x8_extend_high_i8x16(i8)
            );
            v128_t sum32 = wasm_i32x4_add(
                wasm_i32x4_extend_low_i16x8(sum16),
                wasm_i32x4_extend_high_i16x8(sum16)
            );
            sum32 = wasm_i32x4_add(sum32, wasm_i32x4_shuffle(sum32, sum32, 2, 3, 0, 1));
            sum32 = wasm_i32x4_add(sum32, wasm_i32x4_shuffle(sum32, sum32, 1, 0, 3, 2));
            yc[i].bsums[jb] = wasm_i32x4_extract_lane(sum32, 0);
        }

        yc[i].d = 1.0f / iscale;
    }
#else
    quantize_row_q8_K_ref(x, y, k);
#endif
}
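// Added note: amax keeps its sign above (it is min when |min| > max) and
// iscale = -127/amax, so the largest-magnitude input always maps to -127.
// bsums holds 16 partial sums of 16 quants each; the k-quant dot products
// below read them to apply per-sub-block minimums without rescanning qs.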


//===================================== Dot products =================================

void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    const int qk = QK8_0;
    const int nb = n / qk;

    assert(n % qk == 0);
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_q4_0 * GGML_RESTRICT x = vx;
    const block_q8_0 * GGML_RESTRICT y = vy;

    int ib = 0;
    float sumf = 0;

#if defined __wasm_simd128__
    v128_t sumv = wasm_f32x4_splat(0.0f);

    const v128_t m4b = wasm_i8x16_splat(0x0F);
    const v128_t s8b = wasm_i8x16_splat(0x8);

    for (; ib + 1 < nb; ib += 2) {
        const block_q4_0 * GGML_RESTRICT x0 = &x[ib];
        const block_q4_0 * GGML_RESTRICT x1 = &x[ib + 1];
        const block_q8_0 * GGML_RESTRICT y0 = &y[ib];
        const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1];

        // Load and process x0
        v128_t v0_0 = wasm_v128_load(x0->qs);
        v128_t v0_0l = wasm_v128_and(v0_0, m4b);
        v128_t v0_0h = wasm_u8x16_shr(v0_0, 4);
        v128_t v0_0ls = wasm_i8x16_sub(v0_0l, s8b);
        v128_t v0_0hs = wasm_i8x16_sub(v0_0h, s8b);

        // Load y0 vectors
        v128_t y0_l = wasm_v128_load(y0->qs);
        v128_t y0_h = wasm_v128_load(y0->qs + 16);

        // Extend to i16x8 and compute dot products
        v128_t dx0l = wasm_i16x8_extend_low_i8x16(v0_0ls);
        v128_t dx0h = wasm_i16x8_extend_high_i8x16(v0_0ls);
        v128_t dx0hl = wasm_i16x8_extend_low_i8x16(v0_0hs);
        v128_t dx0hh = wasm_i16x8_extend_high_i8x16(v0_0hs);

        v128_t dy0ll = wasm_i16x8_extend_low_i8x16(y0_l);
        v128_t dy0lh = wasm_i16x8_extend_high_i8x16(y0_l);
        v128_t dy0hl = wasm_i16x8_extend_low_i8x16(y0_h);
        v128_t dy0hh = wasm_i16x8_extend_high_i8x16(y0_h);

        v128_t dp0 = wasm_i32x4_add(
            wasm_i32x4_add(
                wasm_i32x4_dot_i16x8(dx0l, dy0ll),
                wasm_i32x4_dot_i16x8(dx0h, dy0lh)
            ),
            wasm_i32x4_add(
                wasm_i32x4_dot_i16x8(dx0hl, dy0hl),
                wasm_i32x4_dot_i16x8(dx0hh, dy0hh)
            )
        );

        // Load and process x1
        v128_t v0_1 = wasm_v128_load(x1->qs);
        v128_t v0_1l = wasm_v128_and(v0_1, m4b);
        v128_t v0_1h = wasm_u8x16_shr(v0_1, 4);
        v128_t v0_1ls = wasm_i8x16_sub(v0_1l, s8b);
        v128_t v0_1hs = wasm_i8x16_sub(v0_1h, s8b);

        // Load y1 vectors
        v128_t y1_l = wasm_v128_load(y1->qs);
        v128_t y1_h = wasm_v128_load(y1->qs + 16);

        // Extend to i16x8 and compute dot products
        v128_t dx1l = wasm_i16x8_extend_low_i8x16(v0_1ls);
        v128_t dx1h = wasm_i16x8_extend_high_i8x16(v0_1ls);
        v128_t dx1hl = wasm_i16x8_extend_low_i8x16(v0_1hs);
        v128_t dx1hh = wasm_i16x8_extend_high_i8x16(v0_1hs);

        v128_t dy1ll = wasm_i16x8_extend_low_i8x16(y1_l);
        v128_t dy1lh = wasm_i16x8_extend_high_i8x16(y1_l);
        v128_t dy1hl = wasm_i16x8_extend_low_i8x16(y1_h);
        v128_t dy1hh = wasm_i16x8_extend_high_i8x16(y1_h);

        v128_t dp1 = wasm_i32x4_add(
            wasm_i32x4_add(
                wasm_i32x4_dot_i16x8(dx1l, dy1ll),
                wasm_i32x4_dot_i16x8(dx1h, dy1lh)
            ),
            wasm_i32x4_add(
                wasm_i32x4_dot_i16x8(dx1hl, dy1hl),
                wasm_i32x4_dot_i16x8(dx1hh, dy1hh)
            )
        );

        // Accumulate results with scaling
        float scale0 = GGML_CPU_FP16_TO_FP32(x0->d) * GGML_CPU_FP16_TO_FP32(y0->d);
        float scale1 = GGML_CPU_FP16_TO_FP32(x1->d) * GGML_CPU_FP16_TO_FP32(y1->d);

        sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(dp0), wasm_f32x4_splat(scale0)));
        sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(dp1), wasm_f32x4_splat(scale1)));
    }

    sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
           wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3);

#endif
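    // Scalar tail: with SIMD enabled this handles at most one leftover block
    // (the loop above consumes blocks in pairs, so nb odd leaves one);
    // without SIMD it computes the entire dot product.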
    for (; ib < nb; ++ib) {
        int sumi0 = 0;
        int sumi1 = 0;

        for (int j = 0; j < qk/2; ++j) {
            const int v0 = (x[ib].qs[j] & 0x0F) - 8;
            const int v1 = (x[ib].qs[j] >>   4) - 8;

            sumi0 += (v0 * y[ib].qs[j]);
            sumi1 += (v1 * y[ib].qs[j + qk/2]);
        }

        int sumi = sumi0 + sumi1;
        sumf += sumi*GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d);
    }

    *s = sumf;
}
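// Worked example (added for illustration): a q4_0 byte packs two 4-bit quants
// with an implicit -8 offset, low nibble first. For x[ib].qs[j] == 0xA3:
//   low  nibble: (0xA3 & 0x0F) - 8 =  3 - 8 = -5, paired with y[ib].qs[j]
//   high nibble: (0xA3 >> 4)   - 8 = 10 - 8 =  2, paired with y[ib].qs[j + qk/2]
// which is exactly what the s8b subtraction does 16 lanes at a time above.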

void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    const int qk = QK8_0;
    const int nb = n / qk;

    int ib = 0;
    float sumf = 0;

    assert(n % qk == 0);
    assert(qk == QK5_0);
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_q5_0 * GGML_RESTRICT x = vx;
    const block_q8_0 * GGML_RESTRICT y = vy;

#if defined __wasm_simd128__
    v128_t sumv = wasm_f32x4_splat(0.0f);

    uint32_t qh_;
    uint64_t tmp[4];

    // TODO: check if unrolling this is better
    for (; ib < nb; ++ib) {
        const block_q5_0 * GGML_RESTRICT x0 = &x[ib];
        const block_q8_0 * GGML_RESTRICT y0 = &y[ib];

        const v128_t m4b = wasm_i8x16_splat(0x0F);

        // extract the 5th bit
        memcpy(&qh_, x0->qh, sizeof(qh_));

        tmp[0] = table_b2b_1[(qh_ >>  0) & 0xFF];
        tmp[1] = table_b2b_1[(qh_ >>  8) & 0xFF];
        tmp[2] = table_b2b_1[(qh_ >> 16) & 0xFF];
        tmp[3] = table_b2b_1[(qh_ >> 24)       ];

        const v128_t qhl = wasm_v128_load(tmp + 0);
        const v128_t qhh = wasm_v128_load(tmp + 2);

        const v128_t v0 = wasm_v128_load(x0->qs);

        // 4-bit -> 8-bit
        const v128_t v0l = wasm_v128_and (v0, m4b);
        const v128_t v0h = wasm_u8x16_shr(v0, 4);

        // add high bit and sub 16 (equivalent to sub 0x10 when bit is zero)
        const v128_t v0lf = wasm_i8x16_sub(v0l, qhl);
        const v128_t v0hf = wasm_i8x16_sub(v0h, qhh);
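        // Added note: table_b2b_1 yields 0x10 exactly where the fifth bit is
        // clear, so nibble - qh_byte == (nibble | bit<<4) - 16: e.g. nibble 7
        // with the bit set stays 7, with the bit clear becomes 7 - 16 = -9.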

        // load y
        const v128_t v1l = wasm_v128_load(y0->qs);
        const v128_t v1h = wasm_v128_load(y0->qs + 16);

        // int8x16 -> int16x8
        const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf);
        const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf);
        const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf);
        const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf);

        const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l);
        const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l);
        const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h);
        const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h);

        // dot product
        sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(
                        wasm_i32x4_add(
                            wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll),
                                           wasm_i32x4_dot_i16x8(v0lfh, v1lh)),
                            wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl),
                                           wasm_i32x4_dot_i16x8(v0hfh, v1hh)))),
                    wasm_f32x4_splat(GGML_CPU_FP16_TO_FP32(x0->d) * GGML_CPU_FP16_TO_FP32(y0->d))));
    }

    sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
           wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3);

    *s = sumf;
#else
    UNUSED(nb);
    UNUSED(ib);
    UNUSED(sumf);
    UNUSED(x);
    UNUSED(y);
    ggml_vec_dot_q5_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}

void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    const int qk = QK8_1;
    const int nb = n / qk;

    int ib = 0;
    float sumf = 0;

    assert(n % qk == 0);
    assert(qk == QK5_1);
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_q5_1 * GGML_RESTRICT x = vx;
    const block_q8_1 * GGML_RESTRICT y = vy;

#if defined __wasm_simd128__
    v128_t sumv = wasm_f32x4_splat(0.0f);

    float summs = 0.0f;

    uint32_t qh_;
    uint64_t tmp[4];

    // TODO: check if unrolling this is better
    for (; ib < nb; ++ib) {
        const block_q5_1 * GGML_RESTRICT x0 = &x[ib];
        const block_q8_1 * GGML_RESTRICT y0 = &y[ib];

        summs += GGML_CPU_FP16_TO_FP32(x0->m) * GGML_CPU_FP16_TO_FP32(y0->s);
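        // Added note: q5_1 quants are unsigned with a per-block minimum m, so
        // x ~= d*q + m and the offset term m * sum(y) reduces to m * y0->s,
        // since y0->s already stores d_y * sum(q_y) (see quantize_row_q8_1).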

        const v128_t m4b = wasm_i8x16_splat(0x0F);

        // extract the 5th bit
        memcpy(&qh_, x0->qh, sizeof(qh_));

        tmp[0] = table_b2b_0[(qh_ >>  0) & 0xFF];
        tmp[1] = table_b2b_0[(qh_ >>  8) & 0xFF];
        tmp[2] = table_b2b_0[(qh_ >> 16) & 0xFF];
        tmp[3] = table_b2b_0[(qh_ >> 24)       ];

        const v128_t qhl = wasm_v128_load(tmp + 0);
        const v128_t qhh = wasm_v128_load(tmp + 2);

        const v128_t v0 = wasm_v128_load(x0->qs);

        // 4-bit -> 8-bit
        const v128_t v0l = wasm_v128_and (v0, m4b);
        const v128_t v0h = wasm_u8x16_shr(v0, 4);

        // add high bit
        const v128_t v0lf = wasm_v128_or(v0l, qhl);
        const v128_t v0hf = wasm_v128_or(v0h, qhh);

        // load y
        const v128_t v1l = wasm_v128_load(y0->qs);
        const v128_t v1h = wasm_v128_load(y0->qs + 16);

        // int8x16 -> int16x8
        const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf);
        const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf);
        const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf);
        const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf);

        const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l);
        const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l);
        const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h);
        const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h);

        // dot product
        sumv = wasm_f32x4_add(sumv,
            wasm_f32x4_mul(wasm_f32x4_convert_i32x4(wasm_i32x4_add(
                    wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll),
                                   wasm_i32x4_dot_i16x8(v0lfh, v1lh)),
                    wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl),
                                   wasm_i32x4_dot_i16x8(v0hfh, v1hh)))),
                wasm_f32x4_splat(GGML_CPU_FP16_TO_FP32(x0->d) * GGML_CPU_FP16_TO_FP32(y0->d))));
    }

    sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
           wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3) + summs;

    *s = sumf;
#else
    UNUSED(nb);
    UNUSED(ib);
    UNUSED(sumf);
    UNUSED(x);
    UNUSED(y);
    ggml_vec_dot_q5_1_q8_1_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}

void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    const int qk = QK8_0;
    const int nb = n / qk;

    assert(n % qk == 0);
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_q8_0 * GGML_RESTRICT x = vx;
    const block_q8_0 * GGML_RESTRICT y = vy;

    int ib = 0;
    float sumf = 0;

#if defined __wasm_simd128__
    v128_t sumv = wasm_f32x4_splat(0.0f);

    for (; ib < nb; ++ib) {
        const block_q8_0 * GGML_RESTRICT x0 = &x[ib];
        const block_q8_0 * GGML_RESTRICT y0 = &y[ib];

        const v128_t x0_0 = wasm_v128_load(x0->qs);
        const v128_t x0_1 = wasm_v128_load(x0->qs + 16);
        const v128_t y0_0 = wasm_v128_load(y0->qs);
        const v128_t y0_1 = wasm_v128_load(y0->qs + 16);

        // Extend 8-bit to 16-bit
        const v128_t x0_0l = wasm_i16x8_extend_low_i8x16(x0_0);
        const v128_t x0_0h = wasm_i16x8_extend_high_i8x16(x0_0);
        const v128_t x0_1l = wasm_i16x8_extend_low_i8x16(x0_1);
        const v128_t x0_1h = wasm_i16x8_extend_high_i8x16(x0_1);

        const v128_t y0_0l = wasm_i16x8_extend_low_i8x16(y0_0);
        const v128_t y0_0h = wasm_i16x8_extend_high_i8x16(y0_0);
        const v128_t y0_1l = wasm_i16x8_extend_low_i8x16(y0_1);
        const v128_t y0_1h = wasm_i16x8_extend_high_i8x16(y0_1);

        // Compute dot products
        const v128_t dx0_0 = wasm_i32x4_dot_i16x8(x0_0l, y0_0l);
        const v128_t dx0_1 = wasm_i32x4_dot_i16x8(x0_0h, y0_0h);
        const v128_t dx1_0 = wasm_i32x4_dot_i16x8(x0_1l, y0_1l);
        const v128_t dx1_1 = wasm_i32x4_dot_i16x8(x0_1h, y0_1h);

        // Sum all dot products
        const v128_t sum_dots = wasm_i32x4_add(wasm_i32x4_add(dx0_0, dx0_1), wasm_i32x4_add(dx1_0, dx1_1));

        // Convert to float and accumulate
        const float scale = GGML_CPU_FP16_TO_FP32(x0->d) * GGML_CPU_FP16_TO_FP32(y0->d);
        sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(sum_dots), wasm_f32x4_splat(scale)));
    }

    sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
           wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3);

    *s = sumf;
#else
    UNUSED(nb);
    UNUSED(x);
    UNUSED(y);
    UNUSED(ib);
    UNUSED(sumf);
    ggml_vec_dot_q8_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}
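// Added note: per block this computes d_x * d_y * sum(qx[j]*qy[j]). The
// i8 -> i16 extends are needed because wasm_i32x4_dot_i16x8 is the only
// integer dot product in baseline WASM SIMD128 (relaxed SIMD adds an i8
// variant), so products are accumulated as i16 pairs into i32 lanes.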

void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_q2_K * GGML_RESTRICT x = vx;
    const block_q8_K * GGML_RESTRICT y = vy;

    const int nb = n / QK_K;

#if defined __wasm_simd128__
    float sumf = 0;

    for (int i = 0; i < nb; ++i) {
        const uint8_t * q2 = x[i].qs;
        const int8_t  * q8 = y[i].qs;
        const uint8_t * sc = x[i].scales;

        // Vectorized summs calculation
        v128_t summs_vec = wasm_i32x4_splat(0);
        {
            v128_t sc_vec = wasm_v128_load(sc);
            v128_t sc_upper = wasm_u8x16_shr(sc_vec, 4);

            v128_t sc_low = wasm_u16x8_extend_low_u8x16(sc_upper);
            v128_t sc_high = wasm_u16x8_extend_high_u8x16(sc_upper);

            v128_t bsums1 = wasm_v128_load(&y[i].bsums[0]);
            v128_t bsums2 = wasm_v128_load(&y[i].bsums[8]);

            summs_vec = wasm_i32x4_add(
                wasm_i32x4_add(wasm_i32x4_dot_i16x8(sc_low, bsums1),
                               wasm_i32x4_dot_i16x8(sc_high, bsums2)),
                summs_vec
            );

            summs_vec = wasm_i32x4_add(summs_vec, wasm_i32x4_shuffle(summs_vec, summs_vec, 2, 3, 0, 1));
            summs_vec = wasm_i32x4_add(summs_vec, wasm_i32x4_shuffle(summs_vec, summs_vec, 1, 0, 3, 2));
        }
        int32_t summs = wasm_i32x4_extract_lane(summs_vec, 0);

        // Vectorized isum calculation
        int32_t isum = 0;
        const uint8_t * sc_ptr = sc;
        const int k_iters = QK_K/128;

        for (int k = 0; k < k_iters; ++k) {
            v128_t isum_vec = wasm_i32x4_splat(0);
            int shift = 0;

            for (int j = 0; j < 4; ++j) {
                const int d0 = (sc_ptr[0] & 0xF);
                const int d1 = (sc_ptr[1] & 0xF);
                sc_ptr += 2;

                // Process first 16 elements
                v128_t q2_0 = wasm_v128_load(q2);
                v128_t q8_0 = wasm_v128_load(q8);
                v128_t q2_shift_0 = wasm_u8x16_shr(q2_0, shift);
                v128_t q2_bits_0 = wasm_v128_and(q2_shift_0, wasm_i8x16_splat(0x03));

                // Process next 16 elements
                v128_t q2_1 = wasm_v128_load(q2 + 16);
                v128_t q8_1 = wasm_v128_load(q8 + 16);
                v128_t q2_shift_1 = wasm_u8x16_shr(q2_1, shift);
                v128_t q2_bits_1 = wasm_v128_and(q2_shift_1, wasm_i8x16_splat(0x03));

                // Calculate dot products
                v128_t p0 = wasm_i32x4_dot_i16x8(
                    wasm_i16x8_extend_low_i8x16(q8_0),
                    wasm_i16x8_extend_low_i8x16(q2_bits_0)
                );
                v128_t p1 = wasm_i32x4_dot_i16x8(
                    wasm_i16x8_extend_high_i8x16(q8_0),
                    wasm_i16x8_extend_high_i8x16(q2_bits_0)
                );
                v128_t p2 = wasm_i32x4_dot_i16x8(
                    wasm_i16x8_extend_low_i8x16(q8_1),
                    wasm_i16x8_extend_low_i8x16(q2_bits_1)
                );
                v128_t p3 = wasm_i32x4_dot_i16x8(
                    wasm_i16x8_extend_high_i8x16(q8_1),
                    wasm_i16x8_extend_high_i8x16(q2_bits_1)
                );

                // Accumulate scaled results
                v128_t scaled = wasm_i32x4_add(
                    wasm_i32x4_mul(wasm_i32x4_add(p0, p1), wasm_i32x4_splat(d0)),
                    wasm_i32x4_mul(wasm_i32x4_add(p2, p3), wasm_i32x4_splat(d1))
                );

                isum_vec = wasm_i32x4_add(isum_vec, scaled);
                q8 += 32;
                shift += 2;
            }
            q2 += 32;

            // Horizontal sum of isum_vec
            isum_vec = wasm_i32x4_add(isum_vec, wasm_i32x4_shuffle(isum_vec, isum_vec, 2, 3, 0, 1));
            isum_vec = wasm_i32x4_add(isum_vec, wasm_i32x4_shuffle(isum_vec, isum_vec, 1, 0, 3, 2));
            isum += wasm_i32x4_extract_lane(isum_vec, 0);
        }

        const float dall = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
        const float dmin = GGML_CPU_FP16_TO_FP32(x[i].dmin) * y[i].d;
        sumf += dall * isum - dmin * summs;
    }

    *s = sumf;

#else
    UNUSED(x);
    UNUSED(y);
    UNUSED(nb);
    ggml_vec_dot_q2_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}
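// Added note: each q2_K scale byte packs a 4-bit sub-block scale in its low
// nibble and a 4-bit minimum in its high nibble. The summs block above dots
// the high nibbles against y's bsums, so the minimums are applied per
// 16-element sub-block without touching qs, giving dall*isum - dmin*summs.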

void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(n % QK_K == 0);
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const uint32_t kmask1 = 0x03030303;
    const uint32_t kmask2 = 0x0f0f0f0f;

    const block_q3_K * GGML_RESTRICT x = vx;
    const block_q8_K * GGML_RESTRICT y = vy;

    const int nb = n / QK_K;

#if defined __wasm_simd128__
    int8_t aux8[QK_K];
    float sums[8] = {0};
    uint32_t auxs[4];

    float sumf = 0;
    for (int i = 0; i < nb; ++i) {
        const uint8_t * GGML_RESTRICT q3 = x[i].qs;
        const uint8_t * GGML_RESTRICT hm = x[i].hmask;
        const  int8_t * GGML_RESTRICT q8 = y[i].qs;

        // Process blocks with SIMD
        int8_t * a = aux8;
        uint8_t m = 1;
        for (int j = 0; j < QK_K; j += 128) {
            for (int shift = 0; shift <= 6; shift += 2) {
                v128_t v_m = wasm_i8x16_splat(m);
                for (int l = 0; l < 32; l += 16) {
                    v128_t v_q3 = wasm_v128_load(q3 + l);
                    v128_t v_shift = wasm_i8x16_shr(v_q3, shift);
                    v128_t v_low2 = wasm_v128_and(v_shift, wasm_i8x16_splat(0x03));

                    v128_t v_hm = wasm_v128_load(hm + l);
                    v128_t v_mask = wasm_v128_and(v_hm, v_m);
                    v_mask = wasm_i8x16_ne(v_mask, wasm_i8x16_splat(0));

                    v_low2 = wasm_i8x16_sub(v_low2, wasm_v128_and(wasm_i8x16_splat(4), wasm_v128_not(v_mask)));
                    wasm_v128_store(a + l, v_low2);
                }
                a += 32;
                m <<= 1;
            }
            q3 += 32;
        }

        // Extract scales
        memcpy(auxs, x[i].scales, 12);
        uint32_t tmp = auxs[2];
        auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
        auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
        auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);
        auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
        const int8_t * scales = (const int8_t *)auxs;

        // SIMD dot product with register accumulators
        v128_t v_acc0 = wasm_i32x4_splat(0);
        v128_t v_acc1 = wasm_i32x4_splat(0);
        a = aux8;
        for (int j = 0; j < QK_K/16; ++j) {
            const v128_t v_scale = wasm_i16x8_splat(scales[j] - 32);

            // Process 16 elements per iteration
            for (int k = 0; k < 2; ++k) {
                const v128_t v_q8 = wasm_i16x8_load8x8(q8);
                const v128_t v_a = wasm_i16x8_load8x8(a);

                v128_t v_prod = wasm_i16x8_mul(v_q8, v_a);
                v_prod = wasm_i16x8_mul(v_prod, v_scale);

                v_acc0 = wasm_i32x4_add(v_acc0, wasm_i32x4_extend_low_i16x8(v_prod));
                v_acc1 = wasm_i32x4_add(v_acc1, wasm_i32x4_extend_high_i16x8(v_prod));

                q8 += 8;
                a += 8;
            }
        }

        // Accumulate results
        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
        const v128_t v_d = wasm_f32x4_splat(d);
        v128_t v_sum = wasm_f32x4_add(
            wasm_f32x4_mul(wasm_f32x4_convert_i32x4(v_acc0), v_d),
            wasm_f32x4_mul(wasm_f32x4_convert_i32x4(v_acc1), v_d)
        );

        // Accumulate into sums vector
        wasm_v128_store(sums, wasm_f32x4_add(wasm_v128_load(sums), v_sum));
    }

    // Horizontal sum
    v128_t v_sum = wasm_f32x4_add(wasm_v128_load(sums), wasm_v128_load(sums + 4));
    sumf = wasm_f32x4_extract_lane(v_sum, 0) +
           wasm_f32x4_extract_lane(v_sum, 1) +
           wasm_f32x4_extract_lane(v_sum, 2) +
           wasm_f32x4_extract_lane(v_sum, 3);

    *s = sumf;

#else
    UNUSED(kmask1);
    UNUSED(kmask2);
    UNUSED(x);
    UNUSED(y);
    UNUSED(nb);
    ggml_vec_dot_q3_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif

}
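// Added note: the kmask shuffle above unpacks the 12 bytes of x[i].scales into
// sixteen 6-bit values; subtracting 32 re-centers them as signed scales, which
// is why the inner loop multiplies by (scales[j] - 32).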

void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(n % QK_K == 0);
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_q4_K * GGML_RESTRICT x = vx;
    const block_q8_K * GGML_RESTRICT y = vy;

    const int nb = n / QK_K;

    static const uint32_t kmask1 = 0x3f3f3f3f;
    static const uint32_t kmask2 = 0x0f0f0f0f;
    static const uint32_t kmask3 = 0x03030303;

    uint32_t utmp[4];

#if defined __wasm_simd128__
    const uint8_t * scales = (const uint8_t*)&utmp[0];
    float sumf = 0;

    for (int i = 0; i < nb; ++i) {
        const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
        const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin); // Corrected sign

        const uint8_t * GGML_RESTRICT q4 = x[i].qs;
        const int8_t  * GGML_RESTRICT q8 = y[i].qs;

        // Process scales and mins
        memcpy(utmp, x[i].scales, 12);
        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
        const uint32_t uaux = utmp[1] & kmask1;
        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
        utmp[2] = uaux;
        utmp[0] &= kmask1;

        // Sum mins * q8sums
        int32_t sumi = 0;
        const int16_t * GGML_RESTRICT q8sums = y[i].bsums;
        const uint8_t * m = (const uint8_t *)&utmp[2];
        for (int j = 0; j < 16; j += 2) {
            sumi += (q8sums[j] + q8sums[j+1]) * m[j/2];
        }
        sumf -= dmin * sumi;

        int32_t sumi1 = 0;
        int32_t sumi2 = 0;

        for (int j = 0; j < QK_K/64; ++j) {
            // Load 64 4-bit weights (32 bytes)
            const v128_t q4x0 = wasm_v128_load(q4);
            const v128_t q4x1 = wasm_v128_load(q4 + 16);
            q4 += 32;

            // Split into low/high nibbles
            const v128_t q4l0 = wasm_v128_and(q4x0, wasm_i8x16_splat(0x0F));
            const v128_t q4h0 = wasm_u8x16_shr(q4x0, 4);
            const v128_t q4l1 = wasm_v128_and(q4x1, wasm_i8x16_splat(0x0F));
            const v128_t q4h1 = wasm_u8x16_shr(q4x1, 4);

            // Load 64 8-bit values (64 bytes)
            const v128_t q8x0 = wasm_v128_load(q8);
            const v128_t q8x1 = wasm_v128_load(q8 + 16);
            const v128_t q8x2 = wasm_v128_load(q8 + 32);
            const v128_t q8x3 = wasm_v128_load(q8 + 48);
            q8 += 64;

            // Low nibble products
            v128_t vacc1 = wasm_i32x4_dot_i16x8(
                wasm_i16x8_extend_low_i8x16(q4l0),
                wasm_i16x8_extend_low_i8x16(q8x0)
            );
            vacc1 = wasm_i32x4_add(vacc1, wasm_i32x4_dot_i16x8(
                wasm_i16x8_extend_high_i8x16(q4l0),
                wasm_i16x8_extend_high_i8x16(q8x0)
            ));
            vacc1 = wasm_i32x4_add(vacc1, wasm_i32x4_dot_i16x8(
                wasm_i16x8_extend_low_i8x16(q4l1),
                wasm_i16x8_extend_low_i8x16(q8x1)
            ));
            vacc1 = wasm_i32x4_add(vacc1, wasm_i32x4_dot_i16x8(
                wasm_i16x8_extend_high_i8x16(q4l1),
                wasm_i16x8_extend_high_i8x16(q8x1)
            ));

            // High nibble products
            v128_t vacc2 = wasm_i32x4_dot_i16x8(
                wasm_i16x8_extend_low_i8x16(q4h0),
                wasm_i16x8_extend_low_i8x16(q8x2)
            );
            vacc2 = wasm_i32x4_add(vacc2, wasm_i32x4_dot_i16x8(
                wasm_i16x8_extend_high_i8x16(q4h0),
                wasm_i16x8_extend_high_i8x16(q8x2)
            ));
            vacc2 = wasm_i32x4_add(vacc2, wasm_i32x4_dot_i16x8(
                wasm_i16x8_extend_low_i8x16(q4h1),
                wasm_i16x8_extend_low_i8x16(q8x3)
            ));
            vacc2 = wasm_i32x4_add(vacc2, wasm_i32x4_dot_i16x8(
                wasm_i16x8_extend_high_i8x16(q4h1),
                wasm_i16x8_extend_high_i8x16(q8x3)
            ));

            // Accumulate scaled results
            int32_t vacc1_sum = wasm_i32x4_extract_lane(vacc1, 0) + wasm_i32x4_extract_lane(vacc1, 1) +
                                wasm_i32x4_extract_lane(vacc1, 2) + wasm_i32x4_extract_lane(vacc1, 3);
            sumi1 += vacc1_sum * scales[2*j];

            int32_t vacc2_sum = wasm_i32x4_extract_lane(vacc2, 0) + wasm_i32x4_extract_lane(vacc2, 1) +
                                wasm_i32x4_extract_lane(vacc2, 2) + wasm_i32x4_extract_lane(vacc2, 3);
            sumi2 += vacc2_sum * scales[2*j+1];
        }

        sumf += d * (sumi1 + sumi2);
    }

    *s = sumf;

#else
    UNUSED(x);
    UNUSED(y);
    UNUSED(nb);
    UNUSED(kmask1);
    UNUSED(kmask2);
    UNUSED(kmask3);
    UNUSED(utmp);
    ggml_vec_dot_q4_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}
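// Added note: after the kmask rearrangement, utmp[0..1] hold eight 6-bit
// scales and utmp[2..3] eight 6-bit minimums, one pair per 32-element
// sub-block; sumf -= dmin * sumi applies the minimums through bsums exactly
// as in the q2_K kernel above.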

void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(n % QK_K == 0);
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_q5_K * GGML_RESTRICT x = vx;
    const block_q8_K * GGML_RESTRICT y = vy;

    const int nb = n / QK_K;

    static const uint32_t kmask1 = 0x3f3f3f3f;
    static const uint32_t kmask2 = 0x0f0f0f0f;
    static const uint32_t kmask3 = 0x03030303;

    uint32_t utmp[4];

#if defined __wasm_simd128__
    //const uint8_t * scales = (const uint8_t*)&utmp[0];
    float sumf = 0;

    for (int i = 0; i < nb; ++i) {
        const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
        const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin); // Fixed sign

        const uint8_t * GGML_RESTRICT q5 = x[i].qs;
        const uint8_t * GGML_RESTRICT qh = x[i].qh;
        const int8_t  * GGML_RESTRICT q8 = y[i].qs;

        // Process scales and mins
        memcpy(utmp, x[i].scales, 12);
        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
        const uint32_t uaux = utmp[1] & kmask1;
        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
        utmp[2] = uaux;
        utmp[0] &= kmask1;

        // Sum mins * q8sums
        int32_t sumi_mins = 0;
        const int16_t * GGML_RESTRICT q8sums = y[i].bsums;
        const uint8_t * m = (const uint8_t *)&utmp[2];
        for (int j = 0; j < 16; j += 2) {
            sumi_mins += (q8sums[j] + q8sums[j+1]) * m[j/2];
        }
        sumf -= dmin * sumi_mins; // Correct subtraction

        v128_t qh0 = wasm_v128_load(qh);
        v128_t qh1 = wasm_v128_load(qh + 16);
        const uint8_t * sc = (const uint8_t *)utmp;

        int32_t sumi = 0;

        for (int j = 0; j < QK_K/64; ++j) {
            const int shift = j * 2;
            v128_t qh_shift0 = wasm_u8x16_shr(qh0, shift);
            v128_t qh_shift1 = wasm_u8x16_shr(qh1, shift);

            v128_t qh_low0 = wasm_i8x16_shl(wasm_v128_and(qh_shift0, wasm_i8x16_splat(0x01)), 4);
            v128_t qh_high0 = wasm_i8x16_shl(wasm_v128_and(qh_shift0, wasm_i8x16_splat(0x02)), 3);
            v128_t qh_low1 = wasm_i8x16_shl(wasm_v128_and(qh_shift1, wasm_i8x16_splat(0x01)), 4);
            v128_t qh_high1 = wasm_i8x16_shl(wasm_v128_and(qh_shift1, wasm_i8x16_splat(0x02)), 3);
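            // Added note: this reconstructs the fifth bit for sub-block j:
            // bit 2j of each qh byte extends the low nibbles ((x & 0x01) << 4
            // lands at bit 4) and bit 2j+1 extends the high nibbles
            // ((x & 0x02) << 3 also lands at bit 4).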

            v128_t q5_0 = wasm_v128_load(q5);
            v128_t q5_1 = wasm_v128_load(q5 + 16);
            q5 += 32;

            v128_t q5l_0 = wasm_v128_or(wasm_v128_and(q5_0, wasm_i8x16_splat(0x0F)), qh_low0);
            v128_t q5h_0 = wasm_v128_or(wasm_u8x16_shr(q5_0, 4), qh_high0);
            v128_t q5l_1 = wasm_v128_or(wasm_v128_and(q5_1, wasm_i8x16_splat(0x0F)), qh_low1);
            v128_t q5h_1 = wasm_v128_or(wasm_u8x16_shr(q5_1, 4), qh_high1);

            v128_t q8_0 = wasm_v128_load(q8);
            v128_t q8_1 = wasm_v128_load(q8 + 16);
            v128_t q8_2 = wasm_v128_load(q8 + 32);
            v128_t q8_3 = wasm_v128_load(q8 + 48);
            q8 += 64;

            // Process low quants
            v128_t pl0 = wasm_i32x4_dot_i16x8(
                wasm_i16x8_extend_low_i8x16(q5l_0),
                wasm_i16x8_extend_low_i8x16(q8_0)
            );
            pl0 = wasm_i32x4_add(pl0, wasm_i32x4_dot_i16x8(
                wasm_i16x8_extend_high_i8x16(q5l_0),
                wasm_i16x8_extend_high_i8x16(q8_0)
            ));
            v128_t pl1 = wasm_i32x4_dot_i16x8(
                wasm_i16x8_extend_low_i8x16(q5l_1),
                wasm_i16x8_extend_low_i8x16(q8_1)
            );
            pl1 = wasm_i32x4_add(pl1, wasm_i32x4_dot_i16x8(
                wasm_i16x8_extend_high_i8x16(q5l_1),
                wasm_i16x8_extend_high_i8x16(q8_1)
            ));
            v128_t sum_low = wasm_i32x4_add(pl0, pl1);

            // Process high quants
            v128_t ph0 = wasm_i32x4_dot_i16x8(
                wasm_i16x8_extend_low_i8x16(q5h_0),
                wasm_i16x8_extend_low_i8x16(q8_2)
            );
            ph0 = wasm_i32x4_add(ph0, wasm_i32x4_dot_i16x8(
                wasm_i16x8_extend_high_i8x16(q5h_0),
                wasm_i16x8_extend_high_i8x16(q8_2)
            ));
            v128_t ph1 = wasm_i32x4_dot_i16x8(
                wasm_i16x8_extend_low_i8x16(q5h_1),
                wasm_i16x8_extend_low_i8x16(q8_3)
            );
            ph1 = wasm_i32x4_add(ph1, wasm_i32x4_dot_i16x8(
                wasm_i16x8_extend_high_i8x16(q5h_1),
                wasm_i16x8_extend_high_i8x16(q8_3)
            ));
            v128_t sum_high = wasm_i32x4_add(ph0, ph1);

            // Accumulate with scale factors
            int32_t sl = wasm_i32x4_extract_lane(sum_low, 0) + wasm_i32x4_extract_lane(sum_low, 1) +
                         wasm_i32x4_extract_lane(sum_low, 2) + wasm_i32x4_extract_lane(sum_low, 3);
            int32_t sh = wasm_i32x4_extract_lane(sum_high, 0) + wasm_i32x4_extract_lane(sum_high, 1) +
                         wasm_i32x4_extract_lane(sum_high, 2) + wasm_i32x4_extract_lane(sum_high, 3);

            sumi += sl * sc[2*j] + sh * sc[2*j+1];
        }

        sumf += d * sumi;
    }

    *s = sumf;

#else
    UNUSED(x);
    UNUSED(y);
    UNUSED(nb);
    UNUSED(kmask1);
    UNUSED(kmask2);
    UNUSED(kmask3);
    UNUSED(utmp);
    ggml_vec_dot_q5_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}

void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(n % QK_K == 0);
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_q6_K * GGML_RESTRICT x = vx;
    const block_q8_K * GGML_RESTRICT y = vy;

    const int nb = n / QK_K;

#if defined __wasm_simd128__
    int8_t aux8[QK_K] __attribute__((aligned(16)));
    int32_t aux32[8] __attribute__((aligned(16))) = {0};
    float sums[8] __attribute__((aligned(16))) = {0};

    for (int i = 0; i < nb; ++i) {
        // Unpack 6-bit quantized data into aux8 (unchanged)
        const uint8_t * GGML_RESTRICT q4 = x[i].ql;
        const uint8_t * GGML_RESTRICT qh = x[i].qh;
        int8_t * a = aux8;
        for (int j = 0; j < QK_K; j += 128) {
            for (int l = 0; l < 32; ++l) {
                a[l +  0] = (int8_t)((q4[l +  0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
                a[l + 32] = (int8_t)((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
                a[l + 64] = (int8_t)((q4[l +  0] >>  4) | (((qh[l] >> 4) & 3) << 4)) - 32;
                a[l + 96] = (int8_t)((q4[l + 32] >>  4) | (((qh[l] >> 6) & 3) << 4)) - 32;
            }
            a += 128;
            q4 += 64;
            qh += 32;
        }
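        // Added note: each 6-bit quant is (4 bits from ql) | (2 bits from qh
        // << 4), offset by -32: e.g. ql nibble 0xF with qh bits 0b11 gives
        // (0xF | 0x30) - 32 = 63 - 32 = 31, the maximum representable value.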

        const int8_t * GGML_RESTRICT a_ptr = aux8;
        const int8_t * GGML_RESTRICT q8 = y[i].qs;
        v128_t acc0 = wasm_i32x4_splat(0);
        v128_t acc1 = wasm_i32x4_splat(0);

        for (int j = 0; j < QK_K/16; ++j) {
            const int scale = x[i].scales[j];
            const v128_t vscale = wasm_i32x4_splat(scale);

            // Load 16 elements from a and q8
            const v128_t a_vec = wasm_v128_load(a_ptr);
            const v128_t q8_vec = wasm_v128_load(q8);

            // Process low 8 elements
            v128_t a_low = wasm_i16x8_extend_low_i8x16(a_vec);
            v128_t q8_low = wasm_i16x8_extend_low_i8x16(q8_vec);
            v128_t prod_low = wasm_i16x8_mul(a_low, q8_low);
            v128_t prod_lo_lo = wasm_i32x4_extend_low_i16x8(prod_low);
            v128_t prod_lo_hi = wasm_i32x4_extend_high_i16x8(prod_low);

            // Process high 8 elements
            v128_t a_high = wasm_i16x8_extend_high_i8x16(a_vec);
            v128_t q8_high = wasm_i16x8_extend_high_i8x16(q8_vec);
            v128_t prod_high = wasm_i16x8_mul(a_high, q8_high);
            v128_t prod_hi_lo = wasm_i32x4_extend_low_i16x8(prod_high);
            v128_t prod_hi_hi = wasm_i32x4_extend_high_i16x8(prod_high);

            // Scale and accumulate
            prod_lo_lo = wasm_i32x4_mul(prod_lo_lo, vscale);
            prod_lo_hi = wasm_i32x4_mul(prod_lo_hi, vscale);
            prod_hi_lo = wasm_i32x4_mul(prod_hi_lo, vscale);
            prod_hi_hi = wasm_i32x4_mul(prod_hi_hi, vscale);

            acc0 = wasm_i32x4_add(acc0, wasm_i32x4_add(prod_lo_lo, prod_hi_lo));
            acc1 = wasm_i32x4_add(acc1, wasm_i32x4_add(prod_lo_hi, prod_hi_hi));

            a_ptr += 16;
            q8 += 16;
        }

        // Store accumulated results
        wasm_v128_store(&aux32[0], acc0);
        wasm_v128_store(&aux32[4], acc1);

        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
        for (int l = 0; l < 8; ++l) {
            sums[l] += d * aux32[l];
        }
    }

    // Sum final results
    float sumf = 0;
    for (int l = 0; l < 8; ++l) {
        sumf += sums[l];
    }
    *s = sumf;

#else
    UNUSED(x);
    UNUSED(y);
    UNUSED(nb);
    ggml_vec_dot_q6_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}