1#define GGML_COMMON_IMPL_C
   2#include "ggml-common.h"
   3#include "ggml-quants.h"
   4#include "ggml-impl.h"
   5#include "ggml-cpu.h"
   6#include "simd-mappings.h"
   7
   8#include "../../quants.h"
   9#include "../../ggml-cpu-impl.h"
  10
  11#include <math.h>
  12#include <string.h>
  13#include <assert.h>
  14#include <float.h>
  15#include <stdlib.h> // for qsort
  16#include <stdio.h>  // for GGML_ASSERT
  17
// Small float epsilon constants; the suffix names the quantization variant each one belongs to.
// NOTE(review): none of these are referenced in the visible part of this file;
// presumably they are thresholds on a group's max magnitude in the quantizers — confirm.
#define GROUP_MAX_EPS 1e-15f
#define GROUP_MAX_EPS_IQ3_XXS 1e-8f
#define GROUP_MAX_EPS_IQ2_S 1e-8f
#define GROUP_MAX_EPS_IQ1_M 1e-7f
#define GROUP_MAX_EPS_IQ1_S 1e-12f

// Short alias for GGML_UNUSED; used below on parameters and locals that a given build path does not touch.
#define UNUSED GGML_UNUSED
  25
  26#if defined(__wasm_simd128__)
  27#define B1(c,s,n)  0x ## n ## c ,  0x ## n ## s
  28#define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s)
  29#define B3(c,s,n) B2(c,s,n ## c), B2(c,s,n ## s)
  30#define B4(c,s,n) B3(c,s,n ## c), B3(c,s,n ## s)
  31#define B5(c,s,n) B4(c,s,n ## c), B4(c,s,n ## s)
  32#define B6(c,s,n) B5(c,s,n ## c), B5(c,s,n ## s)
  33#define B7(c,s,n) B6(c,s,n ## c), B6(c,s,n ## s)
  34#define B8(c,s  ) B7(c,s,     c), B7(c,s,     s)
  35
  36// precomputed tables for expanding 8bits to 8 bytes:
  37static const uint64_t table_b2b_0[1 << 8] = { B8(00, 10) }; // ( b) << 4
  38static const uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4
  39#endif
  40
  41void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
  42    assert(QK8_0 == 32);
  43    assert(k % QK8_0 == 0);
  44    const int nb = k / QK8_0;
  45
  46    block_q8_0 * GGML_RESTRICT y = vy;
  47
  48#if defined __wasm_simd128__
  49    for (int i = 0; i < nb; i++) {
  50        v128_t srcv [8];
  51        v128_t asrcv[8];
  52        v128_t amaxv[8];
  53
  54        for (int j = 0; j < 8; j++) srcv[j]  = wasm_v128_load(x + i*32 + 4*j);
  55        for (int j = 0; j < 8; j++) asrcv[j] = wasm_f32x4_abs(srcv[j]);
  56
  57        for (int j = 0; j < 4; j++) amaxv[2*j] = wasm_f32x4_max(asrcv[2*j], asrcv[2*j+1]);
  58        for (int j = 0; j < 2; j++) amaxv[4*j] = wasm_f32x4_max(amaxv[4*j], amaxv[4*j+2]);
  59        for (int j = 0; j < 1; j++) amaxv[8*j] = wasm_f32x4_max(amaxv[8*j], amaxv[8*j+4]);
  60
  61        const float amax = MAX(MAX(wasm_f32x4_extract_lane(amaxv[0], 0),
  62                                   wasm_f32x4_extract_lane(amaxv[0], 1)),
  63                               MAX(wasm_f32x4_extract_lane(amaxv[0], 2),
  64                                   wasm_f32x4_extract_lane(amaxv[0], 3)));
  65
  66        const float d = amax / ((1 << 7) - 1);
  67        const float id = d ? 1.0f/d : 0.0f;
  68
  69        y[i].d = GGML_CPU_FP32_TO_FP16(d);
  70
  71        for (int j = 0; j < 8; j++) {
  72            const v128_t v  = wasm_f32x4_mul(srcv[j], wasm_f32x4_splat(id));
  73            const v128_t vi = wasm_i32x4_trunc_sat_f32x4(v);
  74
  75            y[i].qs[4*j + 0] = wasm_i32x4_extract_lane(vi, 0);
  76            y[i].qs[4*j + 1] = wasm_i32x4_extract_lane(vi, 1);
  77            y[i].qs[4*j + 2] = wasm_i32x4_extract_lane(vi, 2);
  78            y[i].qs[4*j + 3] = wasm_i32x4_extract_lane(vi, 3);
  79        }
  80    }
  81#else
  82    GGML_UNUSED(nb);
  83    // scalar
  84    quantize_row_q8_0_ref(x, y, k);
  85#endif
  86}
  87
  88void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
  89    assert(k % QK8_1 == 0);
  90    const int nb = k / QK8_1;
  91
  92    block_q8_1 * GGML_RESTRICT y = vy;
  93#if defined __wasm_simd128__
  94    for (int i = 0; i < nb; i++) {
  95        v128_t srcv [8];
  96        v128_t asrcv[8];
  97        v128_t amaxv[8];
  98
  99        for (int j = 0; j < 8; j++) srcv[j]  = wasm_v128_load(x + i*32 + 4*j);
 100        for (int j = 0; j < 8; j++) asrcv[j] = wasm_f32x4_abs(srcv[j]);
 101
 102        for (int j = 0; j < 4; j++) amaxv[2*j] = wasm_f32x4_max(asrcv[2*j], asrcv[2*j+1]);
 103        for (int j = 0; j < 2; j++) amaxv[4*j] = wasm_f32x4_max(amaxv[4*j], amaxv[4*j+2]);
 104        for (int j = 0; j < 1; j++) amaxv[8*j] = wasm_f32x4_max(amaxv[8*j], amaxv[8*j+4]);
 105
 106        const float amax = MAX(MAX(wasm_f32x4_extract_lane(amaxv[0], 0),
 107                                   wasm_f32x4_extract_lane(amaxv[0], 1)),
 108                               MAX(wasm_f32x4_extract_lane(amaxv[0], 2),
 109                                   wasm_f32x4_extract_lane(amaxv[0], 3)));
 110
 111        const float d = amax / ((1 << 7) - 1);
 112        const float id = d ? 1.0f/d : 0.0f;
 113
 114        y[i].d = GGML_CPU_FP32_TO_FP16(d);
 115
 116        v128_t accv = wasm_i32x4_splat(0);
 117
 118        for (int j = 0; j < 8; j++) {
 119            const v128_t v  = wasm_f32x4_mul(srcv[j], wasm_f32x4_splat(id));
 120            const v128_t vi = wasm_i32x4_trunc_sat_f32x4(v);
 121
 122            y[i].qs[4*j + 0] = wasm_i32x4_extract_lane(vi, 0);
 123            y[i].qs[4*j + 1] = wasm_i32x4_extract_lane(vi, 1);
 124            y[i].qs[4*j + 2] = wasm_i32x4_extract_lane(vi, 2);
 125            y[i].qs[4*j + 3] = wasm_i32x4_extract_lane(vi, 3);
 126
 127            accv = wasm_i32x4_add(accv, vi);
 128        }
 129
 130        y[i].s = GGML_CPU_FP32_TO_FP16(
 131                d * (wasm_i32x4_extract_lane(accv, 0) +
 132                     wasm_i32x4_extract_lane(accv, 1) +
 133                     wasm_i32x4_extract_lane(accv, 2) +
 134                     wasm_i32x4_extract_lane(accv, 3)));
 135    }
 136#else
 137    GGML_UNUSED(nb);
 138    // scalar
 139    quantize_row_q8_1_ref(x, y, k);
 140#endif
 141}
 142
 143//===================================== Q8_K ==============================================
 144
 145void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {
 146#ifdef __wasm_simd128__
 147    assert(k % QK_K == 0);
 148    const int64_t nb = k / QK_K;
 149    block_q8_K * GGML_RESTRICT yc = y; // Cast to proper type
 150
 151    for (int i = 0; i < nb; i++) {
 152        const float * x_block = x + i * QK_K;
 153
 154        v128_t min_vec = wasm_v128_load(x_block);
 155        v128_t max_vec = min_vec;
 156
 157        for (int j = 4; j < QK_K; j += 4) {
 158            v128_t x_vec = wasm_v128_load(x_block + j);
 159            max_vec = wasm_f32x4_pmax(max_vec, x_vec);
 160            min_vec = wasm_f32x4_pmin(min_vec, x_vec);
 161        }
 162        max_vec = wasm_f32x4_pmax(max_vec, wasm_i32x4_shuffle(max_vec, max_vec, 2, 3, 0, 1));
 163        max_vec = wasm_f32x4_pmax(max_vec, wasm_i32x4_shuffle(max_vec, max_vec, 1, 0, 3, 2));
 164        min_vec = wasm_f32x4_pmin(min_vec, wasm_i32x4_shuffle(min_vec, min_vec, 2, 3, 0, 1));
 165        min_vec = wasm_f32x4_pmin(min_vec, wasm_i32x4_shuffle(min_vec, min_vec, 1, 0, 3, 2));
 166        float max = wasm_f32x4_extract_lane(max_vec, 0);
 167        float min = wasm_f32x4_extract_lane(min_vec, 0);
 168        float amax = -min > max ? min : max;
 169
 170        if (amax == 0.0f) {
 171            yc[i].d = 0.0f;
 172            const v128_t zero = wasm_i8x16_splat(0);
 173            for (int j = 0; j < QK_K; j += 16) {
 174                wasm_v128_store(yc[i].qs + j, zero);
 175            }
 176            continue;
 177        }
 178
 179        const float iscale = -127.0f / amax;
 180        const v128_t scale_vec = wasm_f32x4_splat(iscale);
 181
 182        // Process 16 elements per iteration
 183        for (int j = 0, jb = 0; j < QK_K; j += 16, jb++) {
 184            // Load and quantize 16 floats
 185            v128_t x0 = wasm_v128_load(x_block + j);
 186            v128_t x1 = wasm_v128_load(x_block + j + 4);
 187            v128_t x2 = wasm_v128_load(x_block + j + 8);
 188            v128_t x3 = wasm_v128_load(x_block + j + 12);
 189
 190            v128_t q0 = wasm_f32x4_nearest(wasm_f32x4_mul(x0, scale_vec));
 191            v128_t q1 = wasm_f32x4_nearest(wasm_f32x4_mul(x1, scale_vec));
 192            v128_t q2 = wasm_f32x4_nearest(wasm_f32x4_mul(x2, scale_vec));
 193            v128_t q3 = wasm_f32x4_nearest(wasm_f32x4_mul(x3, scale_vec));
 194
 195            // Convert to i32 with saturation
 196            v128_t i0 = wasm_i32x4_trunc_sat_f32x4(q0);
 197            v128_t i1 = wasm_i32x4_trunc_sat_f32x4(q1);
 198            v128_t i2 = wasm_i32x4_trunc_sat_f32x4(q2);
 199            v128_t i3 = wasm_i32x4_trunc_sat_f32x4(q3);
 200
 201            // Pack into 16 i8 values
 202            v128_t i8 = wasm_i8x16_narrow_i16x8(
 203                wasm_i16x8_narrow_i32x4(i0, i1),
 204                wasm_i16x8_narrow_i32x4(i2, i3)
 205            );
 206            wasm_v128_store(yc[i].qs + j, i8);
 207
 208            // Calculate bsums using SIMD
 209            v128_t sum16 = wasm_i16x8_add(
 210                wasm_i16x8_extend_low_i8x16(i8),
 211                wasm_i16x8_extend_high_i8x16(i8)
 212            );
 213            v128_t sum32 = wasm_i32x4_add(
 214                wasm_i32x4_extend_low_i16x8(sum16),
 215                wasm_i32x4_extend_high_i16x8(sum16)
 216            );
 217            sum32 = wasm_i32x4_add(sum32, wasm_i32x4_shuffle(sum32, sum32, 2, 3, 0, 1));
 218            sum32 = wasm_i32x4_add(sum32, wasm_i32x4_shuffle(sum32, sum32, 1, 0, 3, 2));
 219            yc[i].bsums[jb] = wasm_i32x4_extract_lane(sum32, 0);
 220        }
 221
 222        yc[i].d = 1.0f / iscale;
 223    }
 224#else
 225    quantize_row_q8_K_ref(x, y, k);
 226#endif
 227}
 228
 229
 230//===================================== Dot products =================================
 231
 232void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
 233    const int qk = QK8_0;
 234    const int nb = n / qk;
 235
 236    assert(n % qk == 0);
 237    assert(nrc == 1);
 238    UNUSED(nrc);
 239    UNUSED(bx);
 240    UNUSED(by);
 241    UNUSED(bs);
 242
 243    const block_q4_0 * GGML_RESTRICT x = vx;
 244    const block_q8_0 * GGML_RESTRICT y = vy;
 245
 246    int ib = 0;
 247    float sumf = 0;
 248
 249#if defined __wasm_simd128__
 250    v128_t sumv = wasm_f32x4_splat(0.0f);
 251
 252    const v128_t m4b = wasm_i8x16_splat(0x0F);
 253    const v128_t s8b = wasm_i8x16_splat(0x8);
 254
 255    for (; ib + 1 < nb; ib += 2) {
 256        const block_q4_0 * GGML_RESTRICT x0 = &x[ib];
 257        const block_q4_0 * GGML_RESTRICT x1 = &x[ib + 1];
 258        const block_q8_0 * GGML_RESTRICT y0 = &y[ib];
 259        const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1];
 260
 261        // Load and process x0
 262        v128_t v0_0 = wasm_v128_load(x0->qs);
 263        v128_t v0_0l = wasm_v128_and(v0_0, m4b);
 264        v128_t v0_0h = wasm_u8x16_shr(v0_0, 4);
 265        v128_t v0_0ls = wasm_i8x16_sub(v0_0l, s8b);
 266        v128_t v0_0hs = wasm_i8x16_sub(v0_0h, s8b);
 267
 268        // Load y0 vectors
 269        v128_t y0_l = wasm_v128_load(y0->qs);
 270        v128_t y0_h = wasm_v128_load(y0->qs + 16);
 271
 272        // Extend to i16x8 and compute dot products
 273        v128_t dx0l = wasm_i16x8_extend_low_i8x16(v0_0ls);
 274        v128_t dx0h = wasm_i16x8_extend_high_i8x16(v0_0ls);
 275        v128_t dx0hl = wasm_i16x8_extend_low_i8x16(v0_0hs);
 276        v128_t dx0hh = wasm_i16x8_extend_high_i8x16(v0_0hs);
 277
 278        v128_t dy0ll = wasm_i16x8_extend_low_i8x16(y0_l);
 279        v128_t dy0lh = wasm_i16x8_extend_high_i8x16(y0_l);
 280        v128_t dy0hl = wasm_i16x8_extend_low_i8x16(y0_h);
 281        v128_t dy0hh = wasm_i16x8_extend_high_i8x16(y0_h);
 282
 283        v128_t dp0 = wasm_i32x4_add(
 284            wasm_i32x4_add(
 285                wasm_i32x4_dot_i16x8(dx0l, dy0ll),
 286                wasm_i32x4_dot_i16x8(dx0h, dy0lh)
 287            ),
 288            wasm_i32x4_add(
 289                wasm_i32x4_dot_i16x8(dx0hl, dy0hl),
 290                wasm_i32x4_dot_i16x8(dx0hh, dy0hh)
 291            )
 292        );
 293
 294        // Load and process x1
 295        v128_t v0_1 = wasm_v128_load(x1->qs);
 296        v128_t v0_1l = wasm_v128_and(v0_1, m4b);
 297        v128_t v0_1h = wasm_u8x16_shr(v0_1, 4);
 298        v128_t v0_1ls = wasm_i8x16_sub(v0_1l, s8b);
 299        v128_t v0_1hs = wasm_i8x16_sub(v0_1h, s8b);
 300
 301        // Load y1 vectors
 302        v128_t y1_l = wasm_v128_load(y1->qs);
 303        v128_t y1_h = wasm_v128_load(y1->qs + 16);
 304
 305        // Extend to i16x8 and compute dot products
 306        v128_t dx1l = wasm_i16x8_extend_low_i8x16(v0_1ls);
 307        v128_t dx1h = wasm_i16x8_extend_high_i8x16(v0_1ls);
 308        v128_t dx1hl = wasm_i16x8_extend_low_i8x16(v0_1hs);
 309        v128_t dx1hh = wasm_i16x8_extend_high_i8x16(v0_1hs);
 310
 311        v128_t dy1ll = wasm_i16x8_extend_low_i8x16(y1_l);
 312        v128_t dy1lh = wasm_i16x8_extend_high_i8x16(y1_l);
 313        v128_t dy1hl = wasm_i16x8_extend_low_i8x16(y1_h);
 314        v128_t dy1hh = wasm_i16x8_extend_high_i8x16(y1_h);
 315
 316        v128_t dp1 = wasm_i32x4_add(
 317            wasm_i32x4_add(
 318                wasm_i32x4_dot_i16x8(dx1l, dy1ll),
 319                wasm_i32x4_dot_i16x8(dx1h, dy1lh)
 320            ),
 321            wasm_i32x4_add(
 322                wasm_i32x4_dot_i16x8(dx1hl, dy1hl),
 323                wasm_i32x4_dot_i16x8(dx1hh, dy1hh)
 324            )
 325        );
 326
 327        // Accumulate results with scaling
 328        float scale0 = GGML_CPU_FP16_TO_FP32(x0->d) * GGML_CPU_FP16_TO_FP32(y0->d);
 329        float scale1 = GGML_CPU_FP16_TO_FP32(x1->d) * GGML_CPU_FP16_TO_FP32(y1->d);
 330
 331        sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(dp0), wasm_f32x4_splat(scale0)));
 332        sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(dp1), wasm_f32x4_splat(scale1)));
 333    }
 334
 335    sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
 336           wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3);
 337
 338#endif
 339    for (; ib < nb; ++ib) {
 340        int sumi0 = 0;
 341        int sumi1 = 0;
 342
 343        for (int j = 0; j < qk/2; ++j) {
 344            const int v0 = (x[ib].qs[j] & 0x0F) - 8;
 345            const int v1 = (x[ib].qs[j] >>   4) - 8;
 346
 347            sumi0 += (v0 * y[ib].qs[j]);
 348            sumi1 += (v1 * y[ib].qs[j + qk/2]);
 349        }
 350
 351        int sumi = sumi0 + sumi1;
 352        sumf += sumi*GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d);
 353    }
 354
 355    *s = sumf;
 356}
 357
 358void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
 359    const int qk = QK8_0;
 360    const int nb = n / qk;
 361
 362    int ib = 0;
 363    float sumf = 0;
 364
 365    assert(n % qk == 0);
 366    assert(qk == QK5_0);
 367    assert(nrc == 1);
 368    UNUSED(nrc);
 369    UNUSED(bx);
 370    UNUSED(by);
 371    UNUSED(bs);
 372
 373    const block_q5_0 * GGML_RESTRICT x = vx;
 374    const block_q8_0 * GGML_RESTRICT y = vy;
 375
 376#if defined __wasm_simd128__
 377    v128_t sumv = wasm_f32x4_splat(0.0f);
 378
 379    uint32_t qh_;
 380    uint64_t tmp[4];
 381
 382    // TODO: check if unrolling this is better
 383    for (; ib < nb; ++ib) {
 384        const block_q5_0 * GGML_RESTRICT x0 = &x[ib];
 385        const block_q8_0 * GGML_RESTRICT y0 = &y[ib];
 386
 387        const v128_t m4b  = wasm_i8x16_splat(0x0F);
 388
 389        // extract the 5th bit
 390        memcpy(&qh_, x0->qh, sizeof(qh_));
 391
 392        tmp[0] = table_b2b_1[(qh_ >>  0) & 0xFF];
 393        tmp[1] = table_b2b_1[(qh_ >>  8) & 0xFF];
 394        tmp[2] = table_b2b_1[(qh_ >> 16) & 0xFF];
 395        tmp[3] = table_b2b_1[(qh_ >> 24)       ];
 396
 397        const v128_t qhl = wasm_v128_load(tmp + 0);
 398        const v128_t qhh = wasm_v128_load(tmp + 2);
 399
 400        const v128_t v0 = wasm_v128_load(x0->qs);
 401
 402        // 4-bit -> 8-bit
 403        const v128_t v0l = wasm_v128_and (v0, m4b);
 404        const v128_t v0h = wasm_u8x16_shr(v0, 4);
 405
 406        // add high bit and sub 16 (equivalent to sub 0x10 when bit is zero)
 407        const v128_t v0lf = wasm_i8x16_sub(v0l, qhl);
 408        const v128_t v0hf = wasm_i8x16_sub(v0h, qhh);
 409
 410        // load y
 411        const v128_t v1l = wasm_v128_load(y0->qs);
 412        const v128_t v1h = wasm_v128_load(y0->qs + 16);
 413
 414        // int8x16 -> int16x8
 415        const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf);
 416        const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf);
 417        const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf);
 418        const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf);
 419
 420        const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l);
 421        const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l);
 422        const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h);
 423        const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h);
 424
 425        // dot product
 426        sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(
 427                        wasm_i32x4_add(
 428                            wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll),
 429                                           wasm_i32x4_dot_i16x8(v0lfh, v1lh)),
 430                            wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl),
 431                                           wasm_i32x4_dot_i16x8(v0hfh, v1hh)))),
 432                    wasm_f32x4_splat(GGML_CPU_FP16_TO_FP32(x0->d) * GGML_CPU_FP16_TO_FP32(y0->d))));
 433    }
 434
 435    sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
 436           wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3);
 437
 438    *s = sumf;
 439#else
 440    UNUSED(nb);
 441    UNUSED(ib);
 442    UNUSED(sumf);
 443    UNUSED(x);
 444    UNUSED(y);
 445    ggml_vec_dot_q5_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
 446#endif
 447}
 448
 449void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
 450    const int qk = QK8_1;
 451    const int nb = n / qk;
 452
 453    int ib = 0;
 454    float sumf = 0;
 455
 456    assert(n % qk == 0);
 457    assert(qk == QK5_1);
 458    assert(nrc == 1);
 459    UNUSED(nrc);
 460    UNUSED(bx);
 461    UNUSED(by);
 462    UNUSED(bs);
 463
 464    const block_q5_1 * GGML_RESTRICT x = vx;
 465    const block_q8_1 * GGML_RESTRICT y = vy;
 466
 467#if defined __wasm_simd128__
 468    v128_t sumv = wasm_f32x4_splat(0.0f);
 469
 470    float summs = 0.0f;
 471
 472    uint32_t qh_;
 473    uint64_t tmp[4];
 474
 475    // TODO: check if unrolling this is better
 476    for (; ib < nb; ++ib) {
 477        const block_q5_1 * GGML_RESTRICT x0 = &x[ib];
 478        const block_q8_1 * GGML_RESTRICT y0 = &y[ib];
 479
 480        summs += GGML_CPU_FP16_TO_FP32(x0->m) * GGML_CPU_FP16_TO_FP32(y0->s);
 481
 482        const v128_t m4b = wasm_i8x16_splat(0x0F);
 483
 484        // extract the 5th bit
 485        memcpy(&qh_, x0->qh, sizeof(qh_));
 486
 487        tmp[0] = table_b2b_0[(qh_ >>  0) & 0xFF];
 488        tmp[1] = table_b2b_0[(qh_ >>  8) & 0xFF];
 489        tmp[2] = table_b2b_0[(qh_ >> 16) & 0xFF];
 490        tmp[3] = table_b2b_0[(qh_ >> 24)       ];
 491
 492        const v128_t qhl = wasm_v128_load(tmp + 0);
 493        const v128_t qhh = wasm_v128_load(tmp + 2);
 494
 495        const v128_t v0 = wasm_v128_load(x0->qs);
 496
 497        // 4-bit -> 8-bit
 498        const v128_t v0l = wasm_v128_and (v0, m4b);
 499        const v128_t v0h = wasm_u8x16_shr(v0, 4);
 500
 501        // add high bit
 502        const v128_t v0lf = wasm_v128_or(v0l, qhl);
 503        const v128_t v0hf = wasm_v128_or(v0h, qhh);
 504
 505        // load y
 506        const v128_t v1l = wasm_v128_load(y0->qs);
 507        const v128_t v1h = wasm_v128_load(y0->qs + 16);
 508
 509        // int8x16 -> int16x8
 510        const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf);
 511        const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf);
 512        const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf);
 513        const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf);
 514
 515        const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l);
 516        const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l);
 517        const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h);
 518        const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h);
 519
 520        // dot product
 521        sumv = wasm_f32x4_add(sumv,
 522                wasm_f32x4_mul(wasm_f32x4_convert_i32x4(wasm_i32x4_add(
 523                            wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll),
 524                                           wasm_i32x4_dot_i16x8(v0lfh, v1lh)),
 525                            wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl),
 526                                           wasm_i32x4_dot_i16x8(v0hfh, v1hh)))),
 527                    wasm_f32x4_splat(GGML_CPU_FP16_TO_FP32(x0->d) * GGML_CPU_FP16_TO_FP32(y0->d))));
 528    }
 529
 530    sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
 531           wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3) + summs;
 532
 533    *s = sumf;
 534#else
 535    UNUSED(nb);
 536    UNUSED(ib);
 537    UNUSED(sumf);
 538    UNUSED(x);
 539    UNUSED(y);
 540    ggml_vec_dot_q5_1_q8_1_generic(n, s, bs, vx, bx, vy, by, nrc);
 541#endif
 542}
 543
 544void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
 545    const int qk = QK8_0;
 546    const int nb = n / qk;
 547
 548    assert(n % qk == 0);
 549    assert(nrc == 1);
 550    UNUSED(nrc);
 551    UNUSED(bx);
 552    UNUSED(by);
 553    UNUSED(bs);
 554
 555    const block_q8_0 * GGML_RESTRICT x = vx;
 556    const block_q8_0 * GGML_RESTRICT y = vy;
 557
 558    int ib = 0;
 559    float sumf = 0;
 560
 561#if defined __wasm_simd128__
 562    v128_t sumv = wasm_f32x4_splat(0.0f);
 563
 564    for (; ib < nb; ++ib) {
 565        const block_q8_0 * GGML_RESTRICT x0 = &x[ib];
 566        const block_q8_0 * GGML_RESTRICT y0 = &y[ib];
 567
 568        const v128_t x0_0 = wasm_v128_load(x0->qs);
 569        const v128_t x0_1 = wasm_v128_load(x0->qs + 16);
 570        const v128_t y0_0 = wasm_v128_load(y0->qs);
 571        const v128_t y0_1 = wasm_v128_load(y0->qs + 16);
 572
 573        // Extend 8-bit to 16-bit
 574        const v128_t x0_0l = wasm_i16x8_extend_low_i8x16(x0_0);
 575        const v128_t x0_0h = wasm_i16x8_extend_high_i8x16(x0_0);
 576        const v128_t x0_1l = wasm_i16x8_extend_low_i8x16(x0_1);
 577        const v128_t x0_1h = wasm_i16x8_extend_high_i8x16(x0_1);
 578
 579        const v128_t y0_0l = wasm_i16x8_extend_low_i8x16(y0_0);
 580        const v128_t y0_0h = wasm_i16x8_extend_high_i8x16(y0_0);
 581        const v128_t y0_1l = wasm_i16x8_extend_low_i8x16(y0_1);
 582        const v128_t y0_1h = wasm_i16x8_extend_high_i8x16(y0_1);
 583
 584        // Compute dot products
 585        const v128_t dx0_0 = wasm_i32x4_dot_i16x8(x0_0l, y0_0l);
 586        const v128_t dx0_1 = wasm_i32x4_dot_i16x8(x0_0h, y0_0h);
 587        const v128_t dx1_0 = wasm_i32x4_dot_i16x8(x0_1l, y0_1l);
 588        const v128_t dx1_1 = wasm_i32x4_dot_i16x8(x0_1h, y0_1h);
 589
 590        // Sum all dot products
 591        const v128_t sum_dots = wasm_i32x4_add(wasm_i32x4_add(dx0_0, dx0_1), wasm_i32x4_add(dx1_0, dx1_1));
 592
 593        // Convert to float and accumulate
 594        const float scale = GGML_CPU_FP16_TO_FP32(x0->d) * GGML_CPU_FP16_TO_FP32(y0->d);
 595        sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(sum_dots), wasm_f32x4_splat(scale)));
 596    }
 597
 598    sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
 599           wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3);
 600
 601    *s = sumf;
 602#else
 603    UNUSED(nb);
 604    UNUSED(x);
 605    UNUSED(y);
 606    UNUSED(ib);
 607    UNUSED(sumf);
 608    ggml_vec_dot_q8_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
 609#endif
 610}
 611
 612void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
 613    assert(nrc == 1);
 614    UNUSED(nrc);
 615    UNUSED(bx);
 616    UNUSED(by);
 617    UNUSED(bs);
 618
 619    const block_q2_K * GGML_RESTRICT x = vx;
 620    const block_q8_K * GGML_RESTRICT y = vy;
 621
 622    const int nb = n / QK_K;
 623
 624#if defined __wasm_simd128__
 625    float sumf = 0;
 626
 627    for (int i = 0; i < nb; ++i) {
 628        const uint8_t * q2 = x[i].qs;
 629        const int8_t * q8 = y[i].qs;
 630        const uint8_t * sc = x[i].scales;
 631
 632        // Vectorized summs calculation
 633        v128_t summs_vec = wasm_i32x4_splat(0);
 634        {
 635            v128_t sc_vec = wasm_v128_load(sc);
 636            v128_t sc_upper = wasm_u8x16_shr(sc_vec, 4);
 637
 638            v128_t sc_low = wasm_u16x8_extend_low_u8x16(sc_upper);
 639            v128_t sc_high = wasm_u16x8_extend_high_u8x16(sc_upper);
 640
 641            v128_t bsums1 = wasm_v128_load(&y[i].bsums[0]);
 642            v128_t bsums2 = wasm_v128_load(&y[i].bsums[8]);
 643
 644            summs_vec = wasm_i32x4_add(
 645                wasm_i32x4_add(wasm_i32x4_dot_i16x8(sc_low, bsums1),
 646                               wasm_i32x4_dot_i16x8(sc_high, bsums2)),
 647                summs_vec
 648            );
 649
 650            summs_vec = wasm_i32x4_add(summs_vec, wasm_i32x4_shuffle(summs_vec, summs_vec, 2, 3, 0, 1));
 651            summs_vec = wasm_i32x4_add(summs_vec, wasm_i32x4_shuffle(summs_vec, summs_vec, 1, 0, 3, 2));
 652        }
 653        int32_t summs = wasm_i32x4_extract_lane(summs_vec, 0);
 654
 655        // Vectorized isum calculation
 656        int32_t isum = 0;
 657        const uint8_t * sc_ptr = sc;
 658        const int k_iters = QK_K/128;
 659
 660        for (int k = 0; k < k_iters; ++k) {
 661            v128_t isum_vec = wasm_i32x4_splat(0);
 662            int shift = 0;
 663
 664            for (int j = 0; j < 4; ++j) {
 665                const int d0 = (sc_ptr[0] & 0xF);
 666                const int d1 = (sc_ptr[1] & 0xF);
 667                sc_ptr += 2;
 668
 669                // Process first 16 elements
 670                v128_t q2_0 = wasm_v128_load(q2);
 671                v128_t q8_0 = wasm_v128_load(q8);
 672                v128_t q2_shift_0 = wasm_u8x16_shr(q2_0, shift);
 673                v128_t q2_bits_0 = wasm_v128_and(q2_shift_0, wasm_i8x16_splat(0x03));
 674
 675                // Process next 16 elements
 676                v128_t q2_1 = wasm_v128_load(q2 + 16);
 677                v128_t q8_1 = wasm_v128_load(q8 + 16);
 678                v128_t q2_shift_1 = wasm_u8x16_shr(q2_1, shift);
 679                v128_t q2_bits_1 = wasm_v128_and(q2_shift_1, wasm_i8x16_splat(0x03));
 680
 681                // Calculate dot products
 682                v128_t p0 = wasm_i32x4_dot_i16x8(
 683                    wasm_i16x8_extend_low_i8x16(q8_0),
 684                    wasm_i16x8_extend_low_i8x16(q2_bits_0)
 685                );
 686                v128_t p1 = wasm_i32x4_dot_i16x8(
 687                    wasm_i16x8_extend_high_i8x16(q8_0),
 688                    wasm_i16x8_extend_high_i8x16(q2_bits_0)
 689                );
 690                v128_t p2 = wasm_i32x4_dot_i16x8(
 691                    wasm_i16x8_extend_low_i8x16(q8_1),
 692                    wasm_i16x8_extend_low_i8x16(q2_bits_1)
 693                );
 694                v128_t p3 = wasm_i32x4_dot_i16x8(
 695                    wasm_i16x8_extend_high_i8x16(q8_1),
 696                    wasm_i16x8_extend_high_i8x16(q2_bits_1)
 697                );
 698
 699                // Accumulate scaled results
 700                v128_t scaled = wasm_i32x4_add(
 701                    wasm_i32x4_mul(wasm_i32x4_add(p0, p1), wasm_i32x4_splat(d0)),
 702                    wasm_i32x4_mul(wasm_i32x4_add(p2, p3), wasm_i32x4_splat(d1))
 703                );
 704
 705                isum_vec = wasm_i32x4_add(isum_vec, scaled);
 706                q8 += 32;
 707                shift += 2;
 708            }
 709            q2 += 32;
 710
 711            // Horizontal sum of isum_vec
 712            isum_vec = wasm_i32x4_add(isum_vec, wasm_i32x4_shuffle(isum_vec, isum_vec, 2, 3, 0, 1));
 713            isum_vec = wasm_i32x4_add(isum_vec, wasm_i32x4_shuffle(isum_vec, isum_vec, 1, 0, 3, 2));
 714            isum += wasm_i32x4_extract_lane(isum_vec, 0);
 715        }
 716
 717        const float dall = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
 718        const float dmin = GGML_CPU_FP16_TO_FP32(x[i].dmin) * y[i].d;
 719        sumf += dall * isum - dmin * summs;
 720    }
 721
 722    *s = sumf;
 723
 724#else
 725    UNUSED(x);
 726    UNUSED(y);
 727    UNUSED(nb);
 728    ggml_vec_dot_q2_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
 729#endif
 730}
 731
 732void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
 733    assert(n % QK_K == 0);
 734    assert(nrc == 1);
 735    UNUSED(nrc);
 736    UNUSED(bx);
 737    UNUSED(by);
 738    UNUSED(bs);
 739
 740    const uint32_t kmask1 = 0x03030303;
 741    const uint32_t kmask2 = 0x0f0f0f0f;
 742
 743    const block_q3_K * GGML_RESTRICT x = vx;
 744    const block_q8_K * GGML_RESTRICT y = vy;
 745
 746    const int nb = n / QK_K;
 747
 748#if defined __wasm_simd128__
 749    int8_t  aux8[QK_K];
 750    float   sums[8] = {0};
 751    uint32_t auxs[4];
 752
 753    float sumf = 0;
 754    for (int i = 0; i < nb; ++i) {
 755        const uint8_t * GGML_RESTRICT q3 = x[i].qs;
 756        const uint8_t * GGML_RESTRICT hm = x[i].hmask;
 757        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
 758
 759        // Process blocks with SIMD
 760        int8_t * a = aux8;
 761        uint8_t m = 1;
 762        for (int j = 0; j < QK_K; j += 128) {
 763            for (int shift = 0; shift <= 6; shift += 2) {
 764                v128_t v_m = wasm_i8x16_splat(m);
 765                for (int l = 0; l < 32; l += 16) {
 766                    v128_t v_q3 = wasm_v128_load(q3 + l);
 767                    v128_t v_shift = wasm_i8x16_shr(v_q3, shift);
 768                    v128_t v_low2 = wasm_v128_and(v_shift, wasm_i8x16_splat(0x03));
 769
 770                    v128_t v_hm = wasm_v128_load(hm + l);
 771                    v128_t v_mask = wasm_v128_and(v_hm, v_m);
 772                    v_mask = wasm_i8x16_ne(v_mask, wasm_i8x16_splat(0));
 773
 774                    v_low2 = wasm_i8x16_sub(v_low2, wasm_v128_and(wasm_i8x16_splat(4), wasm_v128_not(v_mask)));
 775                    wasm_v128_store(a + l, v_low2);
 776                }
 777                a += 32;
 778                m <<= 1;
 779            }
 780            q3 += 32;
 781        }
 782
 783        // Extract scales
 784        memcpy(auxs, x[i].scales, 12);
 785        uint32_t tmp = auxs[2];
 786        auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
 787        auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
 788        auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);
 789        auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
 790        const int8_t * scales = (const int8_t *)auxs;
 791
 792        // SIMD dot product with register accumulators
 793        v128_t v_acc0 = wasm_i32x4_splat(0);
 794        v128_t v_acc1 = wasm_i32x4_splat(0);
 795        a = aux8;
 796        for (int j = 0; j < QK_K/16; ++j) {
 797            const v128_t v_scale = wasm_i16x8_splat(scales[j] - 32);
 798
 799            // Process 16 elements per iteration
 800            for (int k = 0; k < 2; ++k) {
 801                const v128_t v_q8 = wasm_i16x8_load8x8(q8);
 802                const v128_t v_a = wasm_i16x8_load8x8(a);
 803
 804                v128_t v_prod = wasm_i16x8_mul(v_q8, v_a);
 805                v_prod = wasm_i16x8_mul(v_prod, v_scale);
 806
 807                v_acc0 = wasm_i32x4_add(v_acc0, wasm_i32x4_extend_low_i16x8(v_prod));
 808                v_acc1 = wasm_i32x4_add(v_acc1, wasm_i32x4_extend_high_i16x8(v_prod));
 809
 810                q8 += 8;
 811                a += 8;
 812            }
 813        }
 814
 815        // Accumulate results
 816        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
 817        const v128_t v_d = wasm_f32x4_splat(d);
 818        v128_t v_sum = wasm_f32x4_add(
 819            wasm_f32x4_mul(wasm_f32x4_convert_i32x4(v_acc0), v_d),
 820            wasm_f32x4_mul(wasm_f32x4_convert_i32x4(v_acc1), v_d)
 821        );
 822
 823        // Accumulate into sums vector
 824        wasm_v128_store(sums, wasm_f32x4_add(wasm_v128_load(sums), v_sum));
 825    }
 826
 827    // Horizontal sum
 828    v128_t v_sum = wasm_f32x4_add(wasm_v128_load(sums), wasm_v128_load(sums + 4));
 829    sumf = wasm_f32x4_extract_lane(v_sum, 0) +
 830           wasm_f32x4_extract_lane(v_sum, 1) +
 831           wasm_f32x4_extract_lane(v_sum, 2) +
 832           wasm_f32x4_extract_lane(v_sum, 3);
 833
 834    *s = sumf;
 835
 836#else
 837    UNUSED(kmask1);
 838    UNUSED(kmask2);
 839    UNUSED(x);
 840    UNUSED(y);
 841    UNUSED(nb);
 842    ggml_vec_dot_q3_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
 843#endif
 844
 845}
 846
 847void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
 848    assert(n % QK_K == 0);
 849    assert(nrc == 1);
 850    UNUSED(nrc);
 851    UNUSED(bx);
 852    UNUSED(by);
 853    UNUSED(bs);
 854
 855    const block_q4_K * GGML_RESTRICT x = vx;
 856    const block_q8_K * GGML_RESTRICT y = vy;
 857
 858    const int nb = n / QK_K;
 859
 860    static const uint32_t kmask1 = 0x3f3f3f3f;
 861    static const uint32_t kmask2 = 0x0f0f0f0f;
 862    static const uint32_t kmask3 = 0x03030303;
 863
 864    uint32_t utmp[4];
 865
 866#if defined __wasm_simd128__
 867    const uint8_t * scales = (const uint8_t*)&utmp[0];
 868    float sumf = 0;
 869
 870    for (int i = 0; i < nb; ++i) {
 871        const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
 872        const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin); // Corrected sign
 873
 874        const uint8_t * GGML_RESTRICT q4 = x[i].qs;
 875        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
 876
 877        // Process scales and mins
 878        memcpy(utmp, x[i].scales, 12);
 879        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
 880        const uint32_t uaux = utmp[1] & kmask1;
 881        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
 882        utmp[2] = uaux;
 883        utmp[0] &= kmask1;
 884
 885        // Sum mins * q8sums
 886        int32_t sumi = 0;
 887        const int16_t * GGML_RESTRICT q8sums = y[i].bsums;
 888        const uint8_t * m = (const uint8_t *)&utmp[2];
 889        for (int j = 0; j < 16; j += 2) {
 890            sumi += (q8sums[j] + q8sums[j+1]) * m[j/2];
 891        }
 892        sumf -= dmin * sumi;
 893
 894        int32_t sumi1 = 0;
 895        int32_t sumi2 = 0;
 896
 897        for (int j = 0; j < QK_K/64; ++j) {
 898            // Load 64 4-bit weights (32 bytes)
 899            const v128_t q4x0 = wasm_v128_load(q4);
 900            const v128_t q4x1 = wasm_v128_load(q4 + 16);
 901            q4 += 32;
 902
 903            // Split into low/high nibbles
 904            const v128_t q4l0 = wasm_v128_and(q4x0, wasm_i8x16_splat(0x0F));
 905            const v128_t q4h0 = wasm_u8x16_shr(q4x0, 4);
 906            const v128_t q4l1 = wasm_v128_and(q4x1, wasm_i8x16_splat(0x0F));
 907            const v128_t q4h1 = wasm_u8x16_shr(q4x1, 4);
 908
 909            // Load 64 8-bit values (64 bytes)
 910            const v128_t q8x0 = wasm_v128_load(q8);
 911            const v128_t q8x1 = wasm_v128_load(q8 + 16);
 912            const v128_t q8x2 = wasm_v128_load(q8 + 32);
 913            const v128_t q8x3 = wasm_v128_load(q8 + 48);
 914            q8 += 64;
 915
 916            // Low nibble products
 917            v128_t vacc1 = wasm_i32x4_dot_i16x8(
 918                wasm_i16x8_extend_low_i8x16(q4l0),
 919                wasm_i16x8_extend_low_i8x16(q8x0)
 920            );
 921            vacc1 = wasm_i32x4_add(vacc1, wasm_i32x4_dot_i16x8(
 922                wasm_i16x8_extend_high_i8x16(q4l0),
 923                wasm_i16x8_extend_high_i8x16(q8x0)
 924            ));
 925            vacc1 = wasm_i32x4_add(vacc1, wasm_i32x4_dot_i16x8(
 926                wasm_i16x8_extend_low_i8x16(q4l1),
 927                wasm_i16x8_extend_low_i8x16(q8x1)
 928            ));
 929            vacc1 = wasm_i32x4_add(vacc1, wasm_i32x4_dot_i16x8(
 930                wasm_i16x8_extend_high_i8x16(q4l1),
 931                wasm_i16x8_extend_high_i8x16(q8x1)
 932            ));
 933
 934            // High nibble products
 935            v128_t vacc2 = wasm_i32x4_dot_i16x8(
 936                wasm_i16x8_extend_low_i8x16(q4h0),
 937                wasm_i16x8_extend_low_i8x16(q8x2)
 938            );
 939            vacc2 = wasm_i32x4_add(vacc2, wasm_i32x4_dot_i16x8(
 940                wasm_i16x8_extend_high_i8x16(q4h0),
 941                wasm_i16x8_extend_high_i8x16(q8x2)
 942            ));
 943            vacc2 = wasm_i32x4_add(vacc2, wasm_i32x4_dot_i16x8(
 944                wasm_i16x8_extend_low_i8x16(q4h1),
 945                wasm_i16x8_extend_low_i8x16(q8x3)
 946            ));
 947            vacc2 = wasm_i32x4_add(vacc2, wasm_i32x4_dot_i16x8(
 948                wasm_i16x8_extend_high_i8x16(q4h1),
 949                wasm_i16x8_extend_high_i8x16(q8x3)
 950            ));
 951
 952            // Accumulate scaled results
 953            int32_t vacc1_sum = wasm_i32x4_extract_lane(vacc1, 0) + wasm_i32x4_extract_lane(vacc1, 1) +
 954                                wasm_i32x4_extract_lane(vacc1, 2) + wasm_i32x4_extract_lane(vacc1, 3);
 955            sumi1 += vacc1_sum * scales[2*j];
 956
 957            int32_t vacc2_sum = wasm_i32x4_extract_lane(vacc2, 0) + wasm_i32x4_extract_lane(vacc2, 1) +
 958                                wasm_i32x4_extract_lane(vacc2, 2) + wasm_i32x4_extract_lane(vacc2, 3);
 959            sumi2 += vacc2_sum * scales[2*j+1];
 960        }
 961
 962        sumf += d * (sumi1 + sumi2);
 963    }
 964
 965    *s = sumf;
 966
 967#else
 968    UNUSED(x);
 969    UNUSED(y);
 970    UNUSED(nb);
 971    UNUSED(kmask1);
 972    UNUSED(kmask2);
 973    UNUSED(kmask3);
 974    UNUSED(utmp);
 975    ggml_vec_dot_q4_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
 976#endif
 977}
 978
 979void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy,  size_t by, int nrc) {
 980    assert(n % QK_K == 0);
 981    assert(nrc == 1);
 982    UNUSED(nrc);
 983    UNUSED(bx);
 984    UNUSED(by);
 985    UNUSED(bs);
 986
 987    const block_q5_K * GGML_RESTRICT x = vx;
 988    const block_q8_K * GGML_RESTRICT y = vy;
 989
 990    const int nb = n / QK_K;
 991
 992    static const uint32_t kmask1 = 0x3f3f3f3f;
 993    static const uint32_t kmask2 = 0x0f0f0f0f;
 994    static const uint32_t kmask3 = 0x03030303;
 995
 996    uint32_t utmp[4];
 997
 998#if defined __wasm_simd128__
 999    //const uint8_t * scales = (const uint8_t*)&utmp[0];
1000    float sumf = 0;
1001
1002    for (int i = 0; i < nb; ++i) {
1003        const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
1004        const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin); // Fixed sign
1005
1006        const uint8_t * GGML_RESTRICT q5 = x[i].qs;
1007        const uint8_t * GGML_RESTRICT qh = x[i].qh;
1008        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
1009
1010        // Process scales and mins
1011        memcpy(utmp, x[i].scales, 12);
1012        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
1013        const uint32_t uaux = utmp[1] & kmask1;
1014        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
1015        utmp[2] = uaux;
1016        utmp[0] &= kmask1;
1017
1018        // Sum mins * q8sums
1019        int32_t sumi_mins = 0;
1020        const int16_t * GGML_RESTRICT q8sums = y[i].bsums;
1021        const uint8_t * m = (const uint8_t *)&utmp[2];
1022        for (int j = 0; j < 16; j += 2) {
1023            sumi_mins += (q8sums[j] + q8sums[j+1]) * m[j/2];
1024        }
1025        sumf -= dmin * sumi_mins; // Correct subtraction
1026
1027        v128_t qh0 = wasm_v128_load(qh);
1028        v128_t qh1 = wasm_v128_load(qh + 16);
1029        const uint8_t * sc = (const uint8_t *)utmp;
1030
1031        int32_t sumi = 0;
1032
1033        for (int j = 0; j < QK_K/64; ++j) {
1034            const int shift = j * 2;
1035            v128_t qh_shift0 = wasm_u8x16_shr(qh0, shift);
1036            v128_t qh_shift1 = wasm_u8x16_shr(qh1, shift);
1037
1038            v128_t qh_low0 = wasm_i8x16_shl(wasm_v128_and(qh_shift0, wasm_i8x16_splat(0x01)), 4);
1039            v128_t qh_high0 = wasm_i8x16_shl(wasm_v128_and(qh_shift0, wasm_i8x16_splat(0x02)), 3);
1040            v128_t qh_low1 = wasm_i8x16_shl(wasm_v128_and(qh_shift1, wasm_i8x16_splat(0x01)), 4);
1041            v128_t qh_high1 = wasm_i8x16_shl(wasm_v128_and(qh_shift1, wasm_i8x16_splat(0x02)), 3);
1042
1043            v128_t q5_0 = wasm_v128_load(q5);
1044            v128_t q5_1 = wasm_v128_load(q5 + 16);
1045            q5 += 32;
1046
1047            v128_t q5l_0 = wasm_v128_or(wasm_v128_and(q5_0, wasm_i8x16_splat(0x0F)), qh_low0);
1048            v128_t q5h_0 = wasm_v128_or(wasm_u8x16_shr(q5_0, 4), qh_high0);
1049            v128_t q5l_1 = wasm_v128_or(wasm_v128_and(q5_1, wasm_i8x16_splat(0x0F)), qh_low1);
1050            v128_t q5h_1 = wasm_v128_or(wasm_u8x16_shr(q5_1, 4), qh_high1);
1051
1052            v128_t q8_0 = wasm_v128_load(q8);
1053            v128_t q8_1 = wasm_v128_load(q8 + 16);
1054            v128_t q8_2 = wasm_v128_load(q8 + 32);
1055            v128_t q8_3 = wasm_v128_load(q8 + 48);
1056            q8 += 64;
1057
1058            // Process low quants
1059            v128_t pl0 = wasm_i32x4_dot_i16x8(
1060                wasm_i16x8_extend_low_i8x16(q5l_0),
1061                wasm_i16x8_extend_low_i8x16(q8_0)
1062            );
1063            pl0 = wasm_i32x4_add(pl0, wasm_i32x4_dot_i16x8(
1064                wasm_i16x8_extend_high_i8x16(q5l_0),
1065                wasm_i16x8_extend_high_i8x16(q8_0)
1066            ));
1067            v128_t pl1 = wasm_i32x4_dot_i16x8(
1068                wasm_i16x8_extend_low_i8x16(q5l_1),
1069                wasm_i16x8_extend_low_i8x16(q8_1)
1070            );
1071            pl1 = wasm_i32x4_add(pl1, wasm_i32x4_dot_i16x8(
1072                wasm_i16x8_extend_high_i8x16(q5l_1),
1073                wasm_i16x8_extend_high_i8x16(q8_1)
1074            ));
1075            v128_t sum_low = wasm_i32x4_add(pl0, pl1);
1076
1077            // Process high quants
1078            v128_t ph0 = wasm_i32x4_dot_i16x8(
1079                wasm_i16x8_extend_low_i8x16(q5h_0),
1080                wasm_i16x8_extend_low_i8x16(q8_2)
1081            );
1082            ph0 = wasm_i32x4_add(ph0, wasm_i32x4_dot_i16x8(
1083                wasm_i16x8_extend_high_i8x16(q5h_0),
1084                wasm_i16x8_extend_high_i8x16(q8_2)
1085            ));
1086            v128_t ph1 = wasm_i32x4_dot_i16x8(
1087                wasm_i16x8_extend_low_i8x16(q5h_1),
1088                wasm_i16x8_extend_low_i8x16(q8_3)
1089            );
1090            ph1 = wasm_i32x4_add(ph1, wasm_i32x4_dot_i16x8(
1091                wasm_i16x8_extend_high_i8x16(q5h_1),
1092                wasm_i16x8_extend_high_i8x16(q8_3)
1093            ));
1094            v128_t sum_high = wasm_i32x4_add(ph0, ph1);
1095
1096            // Accumulate with scale factors
1097            int32_t sl = wasm_i32x4_extract_lane(sum_low, 0) + wasm_i32x4_extract_lane(sum_low, 1) +
1098                        wasm_i32x4_extract_lane(sum_low, 2) + wasm_i32x4_extract_lane(sum_low, 3);
1099            int32_t sh = wasm_i32x4_extract_lane(sum_high, 0) + wasm_i32x4_extract_lane(sum_high, 1) +
1100                        wasm_i32x4_extract_lane(sum_high, 2) + wasm_i32x4_extract_lane(sum_high, 3);
1101
1102            sumi += sl * sc[2*j] + sh * sc[2*j+1];
1103        }
1104
1105        sumf += d * sumi;
1106    }
1107
1108    *s = sumf;
1109
1110#else
1111    UNUSED(x);
1112    UNUSED(y);
1113    UNUSED(nb);
1114    UNUSED(kmask1);
1115    UNUSED(kmask2);
1116    UNUSED(kmask3);
1117    UNUSED(utmp);
1118    ggml_vec_dot_q5_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
1119#endif
1120}
1121
1122void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
1123    assert(n % QK_K == 0);
1124    assert(nrc == 1);
1125    UNUSED(nrc);
1126    UNUSED(bx);
1127    UNUSED(by);
1128    UNUSED(bs);
1129
1130    const block_q6_K * GGML_RESTRICT x = vx;
1131    const block_q8_K * GGML_RESTRICT y = vy;
1132
1133    const int nb = n / QK_K;
1134
1135#if defined __wasm_simd128__
1136    int8_t aux8[QK_K] __attribute__((aligned(16)));
1137    int32_t aux32[8] __attribute__((aligned(16))) = {0};
1138    float sums[8] __attribute__((aligned(16))) = {0};
1139
1140    for (int i = 0; i < nb; ++i) {
1141        // Unpack 6-bit quantized data into aux8 (unchanged)
1142        const uint8_t * GGML_RESTRICT q4 = x[i].ql;
1143        const uint8_t * GGML_RESTRICT qh = x[i].qh;
1144        int8_t * a = aux8;
1145        for (int j = 0; j < QK_K; j += 128) {
1146            for (int l = 0; l < 32; ++l) {
1147                a[l +  0] = (int8_t)((q4[l +  0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
1148                a[l + 32] = (int8_t)((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
1149                a[l + 64] = (int8_t)((q4[l +  0] >>  4) | (((qh[l] >> 4) & 3) << 4)) - 32;
1150                a[l + 96] = (int8_t)((q4[l + 32] >>  4) | (((qh[l] >> 6) & 3) << 4)) - 32;
1151            }
1152            a += 128;
1153            q4 += 64;
1154            qh += 32;
1155        }
1156
1157        const int8_t * GGML_RESTRICT a_ptr = aux8;
1158        const int8_t * GGML_RESTRICT q8 = y[i].qs;
1159        v128_t acc0 = wasm_i32x4_splat(0);
1160        v128_t acc1 = wasm_i32x4_splat(0);
1161
1162        for (int j = 0; j < QK_K/16; ++j) {
1163            const int scale = x[i].scales[j];
1164            const v128_t vscale = wasm_i32x4_splat(scale);
1165
1166            // Load 16 elements from a and q8
1167            const v128_t a_vec = wasm_v128_load(a_ptr);
1168            const v128_t q8_vec = wasm_v128_load(q8);
1169
1170            // Process low 8 elements
1171            v128_t a_low = wasm_i16x8_extend_low_i8x16(a_vec);
1172            v128_t q8_low = wasm_i16x8_extend_low_i8x16(q8_vec);
1173            v128_t prod_low = wasm_i16x8_mul(a_low, q8_low);
1174            v128_t prod_lo_lo = wasm_i32x4_extend_low_i16x8(prod_low);
1175            v128_t prod_lo_hi = wasm_i32x4_extend_high_i16x8(prod_low);
1176
1177            // Process high 8 elements
1178            v128_t a_high = wasm_i16x8_extend_high_i8x16(a_vec);
1179            v128_t q8_high = wasm_i16x8_extend_high_i8x16(q8_vec);
1180            v128_t prod_high = wasm_i16x8_mul(a_high, q8_high);
1181            v128_t prod_hi_lo = wasm_i32x4_extend_low_i16x8(prod_high);
1182            v128_t prod_hi_hi = wasm_i32x4_extend_high_i16x8(prod_high);
1183
1184            // Scale and accumulate
1185            prod_lo_lo = wasm_i32x4_mul(prod_lo_lo, vscale);
1186            prod_lo_hi = wasm_i32x4_mul(prod_lo_hi, vscale);
1187            prod_hi_lo = wasm_i32x4_mul(prod_hi_lo, vscale);
1188            prod_hi_hi = wasm_i32x4_mul(prod_hi_hi, vscale);
1189
1190            acc0 = wasm_i32x4_add(acc0, wasm_i32x4_add(prod_lo_lo, prod_hi_lo));
1191            acc1 = wasm_i32x4_add(acc1, wasm_i32x4_add(prod_lo_hi, prod_hi_hi));
1192
1193            a_ptr += 16;
1194            q8 += 16;
1195        }
1196
1197        // Store accumulated results
1198        wasm_v128_store(&aux32[0], acc0);
1199        wasm_v128_store(&aux32[4], acc1);
1200
1201        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
1202        for (int l = 0; l < 8; ++l) {
1203            sums[l] += d * aux32[l];
1204        }
1205    }
1206
1207    // Sum final results
1208    float sumf = 0;
1209    for (int l = 0; l < 8; ++l) {
1210        sumf += sums[l];
1211    }
1212    *s = sumf;
1213
1214#else
1215    UNUSED(x);
1216    UNUSED(y);
1217    UNUSED(nb);
1218    ggml_vec_dot_q6_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
1219#endif
1220}
1221