1#define GGML_COMMON_IMPL_C
   2#include "ggml-common.h"
   3#include "ggml-quants.h"
   4#include "ggml-impl.h"
   5#include "ggml-cpu.h"
   6#include "simd-mappings.h"
   7
   8#include "../../quants.h"
   9#include "../../ggml-cpu-impl.h"
  10
  11#include <math.h>
  12#include <string.h>
  13#include <assert.h>
  14#include <float.h>
  15#include <stdlib.h> // for qsort
  16#include <stdio.h>  // for GGML_ASSERT
  17
// Small per-format epsilon constants (the names suggest lower bounds on group maxima used by
// quantizers). NOTE(review): none are referenced in the visible part of this file — confirm usage.
#define GROUP_MAX_EPS 1e-15f
#define GROUP_MAX_EPS_IQ3_XXS 1e-8f
#define GROUP_MAX_EPS_IQ2_S 1e-8f
#define GROUP_MAX_EPS_IQ1_M 1e-7f
#define GROUP_MAX_EPS_IQ1_S 1e-12f

// Short alias used below to silence unused-parameter warnings.
#define UNUSED GGML_UNUSED

#if defined(__ARM_NEON)
// B1..B8 recursively token-paste two-hex-digit byte values into 64-bit literals.
// B8(c, s) expands to 256 constants; in entry i, byte k (k = 0 is the least significant byte)
// is `c` when bit k of i is 0 and `s` when bit k of i is 1.
#define B1(c,s,n)  0x ## n ## c ,  0x ## n ## s
#define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s)
#define B3(c,s,n) B2(c,s,n ## c), B2(c,s,n ## s)
#define B4(c,s,n) B3(c,s,n ## c), B3(c,s,n ## s)
#define B5(c,s,n) B4(c,s,n ## c), B4(c,s,n ## s)
#define B6(c,s,n) B5(c,s,n ## c), B5(c,s,n ## s)
#define B7(c,s,n) B6(c,s,n ## c), B6(c,s,n ## s)
#define B8(c,s  ) B7(c,s,     c), B7(c,s,     s)

// precomputed tables for expanding 8bits to 8 bytes:
// table_b2b_0[b] has byte k = 0x10 if bit k of b is set, else 0x00.
// table_b2b_1[b] has byte k = 0x00 if bit k of b is set, else 0x10 (used by the q5_0 kernel
// to subtract 16 from values whose 5th bit is clear).
static const uint64_t table_b2b_0[1 << 8] = { B8(00, 10) }; // ( b) << 4
static const uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4
#endif
  40
  41void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
  42    assert(QK8_0 == 32);
  43    assert(k % QK8_0 == 0);
  44    const int nb = k / QK8_0;
  45
  46    block_q8_0 * GGML_RESTRICT y = vy;
  47
  48#if defined(__ARM_NEON)
  49    for (int i = 0; i < nb; i++) {
  50        float32x4_t srcv [8];
  51        float32x4_t asrcv[8];
  52        float32x4_t amaxv[8];
  53
  54        for (int j = 0; j < 8; j++) srcv[j]  = vld1q_f32(x + i*32 + 4*j);
  55        for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]);
  56
  57        for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]);
  58        for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]);
  59        for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]);
  60
  61        const float amax = vmaxvq_f32(amaxv[0]);
  62
  63        const float d = amax / ((1 << 7) - 1);
  64        const float id = d ? 1.0f/d : 0.0f;
  65
  66        y[i].d = GGML_CPU_FP32_TO_FP16(d);
  67
  68        for (int j = 0; j < 8; j++) {
  69            const float32x4_t v  = vmulq_n_f32(srcv[j], id);
  70            const int32x4_t   vi = vcvtnq_s32_f32(v);
  71
  72            y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0);
  73            y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1);
  74            y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2);
  75            y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3);
  76        }
  77    }
  78#else
  79    GGML_UNUSED(nb);
  80    // scalar
  81    quantize_row_q8_0_ref(x, y, k);
  82#endif
  83}
  84
  85void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
  86    assert(k % QK8_1 == 0);
  87    const int nb = k / QK8_1;
  88
  89    block_q8_1 * GGML_RESTRICT y = vy;
  90#if defined(__ARM_NEON)
  91    for (int i = 0; i < nb; i++) {
  92        float32x4_t srcv [8];
  93        float32x4_t asrcv[8];
  94        float32x4_t amaxv[8];
  95
  96        for (int j = 0; j < 8; j++) srcv[j]  = vld1q_f32(x + i*32 + 4*j);
  97        for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]);
  98
  99        for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]);
 100        for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]);
 101        for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]);
 102
 103        const float amax = vmaxvq_f32(amaxv[0]);
 104
 105        const float d = amax / ((1 << 7) - 1);
 106        const float id = d ? 1.0f/d : 0.0f;
 107
 108        y[i].d = GGML_CPU_FP32_TO_FP16(d);
 109
 110        int32x4_t accv = vdupq_n_s32(0);
 111
 112        for (int j = 0; j < 8; j++) {
 113            const float32x4_t v  = vmulq_n_f32(srcv[j], id);
 114            const int32x4_t   vi = vcvtnq_s32_f32(v);
 115
 116            y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0);
 117            y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1);
 118            y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2);
 119            y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3);
 120
 121            accv = vaddq_s32(accv, vi);
 122        }
 123
 124        y[i].s = GGML_CPU_FP32_TO_FP16(d * vaddvq_s32(accv));
 125    }
 126#else
 127    GGML_UNUSED(nb);
 128    // scalar
 129    quantize_row_q8_1_ref(x, y, k);
 130#endif
 131}
 132
 133// placeholder implementation for Apple targets
 134void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {
 135    quantize_row_q8_K_ref(x, y, k);
 136}
 137
 138//===================================== Dot products =================================
 139
// Dot product of a row of q4_0 blocks (vx) with a row of q8_0 blocks (vy), n values each.
// Each q4_0 byte packs two 4-bit codes: the low nibble is element j and the high nibble is
// element j + qk/2. Both are offset by 8, so each block contributes
// d_x * d_y * sum((q_x - 8) * q_y).
//
// With __ARM_FEATURE_MATMUL_INT8 and nrc == 2, a 2x2 tile is computed instead. Rows vx and
// vx + bx are dotted with rows vy and vy + by (bx and by are byte strides). Results are stored
// with bs counted in floats:
//   s[0]  = x0.y0,   s[1]      = x1.y0,
//   s[bs] = x0.y1,   s[bs + 1] = x1.y1.
// Otherwise nrc must be 1 and the single result is written to *s.
void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    const int qk = QK8_0;
    const int nb = n / qk;

    assert(n % qk == 0);
#if defined(__ARM_FEATURE_MATMUL_INT8)
    assert((nrc == 2) || (nrc == 1));
#else
    assert(nrc == 1);
#endif
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_q4_0 * GGML_RESTRICT x = vx;
    const block_q8_0 * GGML_RESTRICT y = vy;

#if defined(__ARM_FEATURE_MATMUL_INT8)
    if (nrc == 2) {
        const block_q4_0 * GGML_RESTRICT vx0 = vx;
        const block_q4_0 * GGML_RESTRICT vx1 = (const block_q4_0 *) ((const uint8_t*)vx + bx);
        const block_q8_0 * GGML_RESTRICT vy0 = vy;
        const block_q8_0 * GGML_RESTRICT vy1 = (const block_q8_0 *) ((const uint8_t*)vy + by);

        float32x4_t sumv0 = vdupq_n_f32(0.0f);

        for (int i = 0; i < nb; i++) {
            const block_q4_0 * GGML_RESTRICT b_x0 = &vx0[i];
            const block_q4_0 * GGML_RESTRICT b_x1 = &vx1[i];
            const block_q8_0 * GGML_RESTRICT b_y0 = &vy0[i];
            const block_q8_0 * GGML_RESTRICT b_y1 = &vy1[i];

            const uint8x16_t m4b = vdupq_n_u8(0x0F);
            const int8x16_t  s8b = vdupq_n_s8(0x8);

            const uint8x16_t v0_0 = vld1q_u8(b_x0->qs);
            const uint8x16_t v0_1 = vld1q_u8(b_x1->qs);

            // 4-bit -> 8-bit
            const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8  (v0_0, m4b));
            const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
            const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8  (v0_1, m4b));
            const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));

            // sub 8
            const int8x16_t x0_l = vsubq_s8(v0_0l, s8b);
            const int8x16_t x0_h = vsubq_s8(v0_0h, s8b);
            const int8x16_t x1_l = vsubq_s8(v0_1l, s8b);
            const int8x16_t x1_h = vsubq_s8(v0_1h, s8b);

            // load y
            const int8x16_t y0_l = vld1q_s8(b_y0->qs);
            const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16);
            const int8x16_t y1_l = vld1q_s8(b_y1->qs);
            const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);

            // per-lane scales in the vmmlaq_s32 result order [x0.y0, x0.y1, x1.y0, x1.y1]
            float32_t _scale[4] = {
                GGML_CPU_FP16_TO_FP32(b_x0->d)*GGML_CPU_FP16_TO_FP32(b_y0->d),
                GGML_CPU_FP16_TO_FP32(b_x0->d)*GGML_CPU_FP16_TO_FP32(b_y1->d),
                GGML_CPU_FP16_TO_FP32(b_x1->d)*GGML_CPU_FP16_TO_FP32(b_y0->d),
                GGML_CPU_FP16_TO_FP32(b_x1->d)*GGML_CPU_FP16_TO_FP32(b_y1->d)
            };
            float32x4_t scale = vld1q_f32(_scale);

            // Interleave the 8-byte halves so that each 16-byte operand holds one 8-byte row per
            // x (or y) block. Each vmmlaq_s32 then accumulates a 2x2 tile of 8-element dot products.
            int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
            int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));

            int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
            int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));

            int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
            int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));

            int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
            int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));

            sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)),
                                                l1, r1)), l2, r2)), l3, r3))), scale);
        }

        // Reorder [x0.y0, x0.y1, x1.y0, x1.y1] -> [x0.y0, x1.y0, x0.y1, x1.y1]:
        // the low half goes to s, the high half to s + bs.
        float32x4_t sumv1 = vextq_f32 (sumv0, sumv0, 2);
        float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);

        vst1_f32(s,      vget_low_f32 (sumv2));
        vst1_f32(s + bs, vget_high_f32(sumv2));

        return;
    }
#endif

    // ib counts blocks already consumed by a SIMD path; the scalar loop at the end handles the rest.
    int ib = 0;
    float sumf = 0;

#if defined(__ARM_FEATURE_SVE)
    svfloat32_t sumv0 = svdup_n_f32(0.0f);
    svfloat32_t sumv1 = svdup_n_f32(0.0f);

    const int vector_length = ggml_cpu_get_sve_cnt()*8;

    // VLA Implementation using switch case
    switch (vector_length) {
        case 128:
            {
                // predicate activating the first 4 float32 lanes (all lanes at 128-bit VL)
                const svbool_t ph4 = svptrue_pat_b32(SV_VL4);

                for (; ib + 1 < nb; ib += 2) {
                    const block_q4_0 * GGML_RESTRICT x0 = &x[ib + 0];
                    const block_q4_0 * GGML_RESTRICT x1 = &x[ib + 1];
                    const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0];
                    const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1];

                    // load x
                    const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs);
                    const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs);

                    // 4-bit -> 8-bit
                    const svint8_t qx0l = svreinterpret_s8_u8(svand_n_u8_m(svptrue_b8(), qx0r, 0x0F));
                    const svint8_t qx0h = svreinterpret_s8_u8(svlsr_n_u8_m(svptrue_b8(), qx0r, 0x04));
                    const svint8_t qx1l = svreinterpret_s8_u8(svand_n_u8_m(svptrue_b8(), qx1r, 0x0F));
                    const svint8_t qx1h = svreinterpret_s8_u8(svlsr_n_u8_m(svptrue_b8(), qx1r, 0x04));

                    // sub 8
                    // NOTE: the l/h suffixes below are swapped relative to the nibble names.
                    // qx*ls holds the high nibbles and is paired with qy*l = qs[16..31].
                    // qx*hs holds the low nibbles and is paired with qy*h = qs[0..15].
                    const svint8_t qx0ls = svsub_n_s8_x(svptrue_b8(), qx0h, 8);
                    const svint8_t qx0hs = svsub_n_s8_x(svptrue_b8(), qx0l, 8);
                    const svint8_t qx1ls = svsub_n_s8_x(svptrue_b8(), qx1h, 8);
                    const svint8_t qx1hs = svsub_n_s8_x(svptrue_b8(), qx1l, 8);

                    // load y
                    const svint8_t qy0h = svld1_s8(svptrue_b8(), y0->qs);
                    const svint8_t qy0l = svld1_s8(svptrue_b8(), y0->qs + 16);
                    const svint8_t qy1h = svld1_s8(svptrue_b8(), y1->qs);
                    const svint8_t qy1l = svld1_s8(svptrue_b8(), y1->qs + 16);

                    // dot product
                    sumv0 = svmla_n_f32_x(ph4, sumv0, svcvt_f32_s32_x(ph4, svadd_x(ph4,
                                    svdot_s32(svdup_n_s32(0), qx0ls, qy0l),
                                    svdot_s32(svdup_n_s32(0), qx0hs, qy0h))), GGML_CPU_FP16_TO_FP32(x0->d)*GGML_CPU_FP16_TO_FP32(y0->d));
                    sumv1 = svmla_n_f32_x(ph4, sumv1, svcvt_f32_s32_x(ph4, svadd_x(ph4,
                                    svdot_s32(svdup_n_s32(0), qx1ls, qy1l),
                                    svdot_s32(svdup_n_s32(0), qx1hs, qy1h))), GGML_CPU_FP16_TO_FP32(x1->d)*GGML_CPU_FP16_TO_FP32(y1->d));
                }

                sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
            } break;
        case 256:
            {
                // predicate activating the first 16 int8 lanes (0..15)
                const svbool_t ph16 = svptrue_pat_b8(SV_VL16);
                // predicate activating the remaining int8 lanes (16..31 at 256-bit VL)
                const svbool_t pl16 = svnot_b_z(svptrue_b8(), ph16);

                for (; ib + 1 < nb; ib += 2) {
                    const block_q4_0 * GGML_RESTRICT x0 = &x[ib + 0];
                    const block_q4_0 * GGML_RESTRICT x1 = &x[ib + 1];
                    const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0];
                    const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1];

                    // load x (the 16 packed bytes are replicated into both 128-bit halves)
                    const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs);
                    const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs);

                    // 4-bit -> 8-bit: lanes 0..15 keep the low nibbles, lanes 16..31 get the high
                    // nibbles, matching y->qs[0..31]
                    const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx0r, 0x0F), 0x04));
                    const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx1r, 0x0F), 0x04));

                    // sub 8
                    const svint8_t qx0s = svsub_n_s8_x(svptrue_b8(), qx0, 8);
                    const svint8_t qx1s = svsub_n_s8_x(svptrue_b8(), qx1, 8);

                    // load y
                    const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
                    const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);

                    // dot product
                    sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(),
                                svdot_s32(svdup_n_s32(0), qx0s, qy0)), GGML_CPU_FP16_TO_FP32(x0->d)*GGML_CPU_FP16_TO_FP32(y0->d));
                    sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(),
                                svdot_s32(svdup_n_s32(0), qx1s, qy1)), GGML_CPU_FP16_TO_FP32(x1->d)*GGML_CPU_FP16_TO_FP32(y1->d));
                }

                sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
            } break;
        case 512:
            {
                // predicate activating the first 32 int8 lanes (0..31)
                const svbool_t ph32 = svptrue_pat_b8(SV_VL32);

                // predicate activating the first 16 int8 lanes (0..15)
                const svbool_t ph16 = svptrue_pat_b8(SV_VL16);
                // predicate activating int8 lanes 16..31 (inside ph32 but not ph16)
                const svbool_t pl16 = svnot_b_z(ph32, ph16);

                for (; ib + 1 < nb; ib += 2) {
                    const block_q4_0 * GGML_RESTRICT x0 = &x[ib + 0];
                    const block_q4_0 * GGML_RESTRICT x1 = &x[ib + 1];
                    const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0];
                    const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1];

                    // load x
                    const svuint8_t qx0r = svld1rq_u8(ph32, x0->qs);
                    const svuint8_t qx1r = svld1rq_u8(ph32, x1->qs);

                    // 4-bit -> 8-bit
                    const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx0r, 0x0F), 0x04));
                    const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx1r, 0x0F), 0x04));

                    // sub 8
                    const svint8_t qx0s = svsub_n_s8_x(ph32, qx0, 8);
                    const svint8_t qx1s = svsub_n_s8_x(ph32, qx1, 8);

                    // load y
                    const svint8_t qy0 = svld1_s8(ph32, y0->qs);
                    const svint8_t qy1 = svld1_s8(ph32, y1->qs);

                    // dot product
                    sumv0 = svmla_n_f32_x(ph32, sumv0, svcvt_f32_s32_x(ph32,
                                svdot_s32(svdup_n_s32(0), qx0s, qy0)), GGML_CPU_FP16_TO_FP32(x0->d)*GGML_CPU_FP16_TO_FP32(y0->d));
                    sumv1 = svmla_n_f32_x(ph32, sumv1, svcvt_f32_s32_x(ph32,
                                svdot_s32(svdup_n_s32(0), qx1s, qy1)), GGML_CPU_FP16_TO_FP32(x1->d)*GGML_CPU_FP16_TO_FP32(y1->d));
                }

                sumf = svaddv_f32(ph32, svadd_f32_x(ph32, sumv0, sumv1));
            } break;
        default:
            // No block is consumed here (ib stays 0). With assertions disabled, the scalar
            // loop below therefore computes the whole dot product.
            assert(false && "Unsupported vector length");
            break;
    }

#elif defined(__ARM_NEON)
    float32x4_t sumv0 = vdupq_n_f32(0.0f);
    float32x4_t sumv1 = vdupq_n_f32(0.0f);

    for (; ib + 1 < nb; ib += 2) {
        const block_q4_0 * GGML_RESTRICT x0 = &x[ib + 0];
        const block_q4_0 * GGML_RESTRICT x1 = &x[ib + 1];
        const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0];
        const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1];

        const uint8x16_t m4b = vdupq_n_u8(0x0F);
        const int8x16_t  s8b = vdupq_n_s8(0x8);

        const uint8x16_t v0_0 = vld1q_u8(x0->qs);
        const uint8x16_t v0_1 = vld1q_u8(x1->qs);

        // 4-bit -> 8-bit
        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8  (v0_0, m4b));
        const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
        const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8  (v0_1, m4b));
        const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));

        // sub 8
        const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
        const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
        const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
        const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);

        // load y
        const int8x16_t v1_0l = vld1q_s8(y0->qs);
        const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
        const int8x16_t v1_1l = vld1q_s8(y1->qs);
        const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);

        // dot product into int32x4_t
        const int32x4_t p_0 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l), v0_0hs, v1_0h);
        const int32x4_t p_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l), v0_1hs, v1_1h);

        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_CPU_FP16_TO_FP32(x0->d)*GGML_CPU_FP16_TO_FP32(y0->d));
        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_CPU_FP16_TO_FP32(x1->d)*GGML_CPU_FP16_TO_FP32(y1->d));
    }

    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
#endif
    // Scalar tail: any blocks not consumed above (an odd trailing block, or every block when
    // no SIMD path applies).
    for (; ib < nb; ++ib) {
        int sumi0 = 0;
        int sumi1 = 0;

        for (int j = 0; j < qk/2; ++j) {
            const int v0 = (x[ib].qs[j] & 0x0F) - 8;
            const int v1 = (x[ib].qs[j] >>   4) - 8;

            sumi0 += (v0 * y[ib].qs[j]);
            sumi1 += (v1 * y[ib].qs[j + qk/2]);
        }

        int sumi = sumi0 + sumi1;
        sumf += sumi*GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d);
    }

    *s = sumf;
}
 432
// Dot product of a row of q4_1 blocks (vx) with a row of q8_1 blocks (vy), n values each.
// q4_1 nibbles are unsigned codes with no offset (low nibble -> element j, high nibble ->
// element j + qk/2). Each block carries a scale d and a min m.
// Each block contributes d_x * d_y * sum(q_x * q_y) + m_x * s_y, where s_y is the q8_1 block's
// `s` field; quantize_row_q8_1 in this file stores d_y * sum(q_y) there.
//
// With __ARM_FEATURE_MATMUL_INT8 and nrc == 2, a 2x2 tile is computed instead. Rows vx and
// vx + bx are dotted with rows vy and vy + by (bx and by are byte strides). Results are stored
// with bs counted in floats:
//   s[0]  = x0.y0,   s[1]      = x1.y0,
//   s[bs] = x0.y1,   s[bs + 1] = x1.y1.
// Otherwise nrc must be 1 and the single result is written to *s.
void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    const int qk = QK8_1;
    const int nb = n / qk;

    assert(n % qk == 0);
#if defined(__ARM_FEATURE_MATMUL_INT8)
    assert((nrc == 2) || (nrc == 1));
#else
    assert(nrc == 1);
#endif
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_q4_1 * GGML_RESTRICT x = vx;
    const block_q8_1 * GGML_RESTRICT y = vy;

#if defined(__ARM_FEATURE_MATMUL_INT8)
    if (nrc == 2) {
        const block_q4_1 * GGML_RESTRICT vx0 = vx;
        const block_q4_1 * GGML_RESTRICT vx1 = (const block_q4_1 *) ((const uint8_t*)vx + bx);
        const block_q8_1 * GGML_RESTRICT vy0 = vy;
        const block_q8_1 * GGML_RESTRICT vy1 = (const block_q8_1 *) ((const uint8_t*)vy + by);

        float32x4_t sumv0 = vdupq_n_f32(0.0f);
        float32x4_t summs0 = vdupq_n_f32(0.0f);

        for (int i = 0; i < nb; i++) {
            const block_q4_1 * GGML_RESTRICT b_x0 = &vx0[i];
            const block_q4_1 * GGML_RESTRICT b_x1 = &vx1[i];
            const block_q8_1 * GGML_RESTRICT b_y0 = &vy0[i];
            const block_q8_1 * GGML_RESTRICT b_y1 = &vy1[i];

            // m_x * s_y terms, ordered like the final output lanes [x0.y0, x1.y0, x0.y1, x1.y1]
            // (they are added after the lane reorder below)
            float32_t summs_t[4] = {
                GGML_CPU_FP16_TO_FP32(b_x0->m) * GGML_CPU_FP16_TO_FP32(b_y0->s),
                GGML_CPU_FP16_TO_FP32(b_x1->m) * GGML_CPU_FP16_TO_FP32(b_y0->s),
                GGML_CPU_FP16_TO_FP32(b_x0->m) * GGML_CPU_FP16_TO_FP32(b_y1->s),
                GGML_CPU_FP16_TO_FP32(b_x1->m) * GGML_CPU_FP16_TO_FP32(b_y1->s)
            };
            summs0 = vaddq_f32(summs0, vld1q_f32(summs_t));

            const uint8x16_t m4b = vdupq_n_u8(0x0F);

            const uint8x16_t v0_0 = vld1q_u8(b_x0->qs);
            const uint8x16_t v0_1 = vld1q_u8(b_x1->qs);

            // 4-bit -> 8-bit
            const int8x16_t x0_l = vreinterpretq_s8_u8(vandq_u8  (v0_0, m4b));
            const int8x16_t x0_h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
            const int8x16_t x1_l = vreinterpretq_s8_u8(vandq_u8  (v0_1, m4b));
            const int8x16_t x1_h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));

            // load y
            const int8x16_t y0_l = vld1q_s8(b_y0->qs);
            const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16);
            const int8x16_t y1_l = vld1q_s8(b_y1->qs);
            const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);

            // mmla into int32x4_t
            // per-lane scales in the vmmlaq_s32 result order [x0.y0, x0.y1, x1.y0, x1.y1]
            float32_t _scale[4] = {
                GGML_CPU_FP16_TO_FP32(b_x0->d)*GGML_CPU_FP16_TO_FP32(b_y0->d),
                GGML_CPU_FP16_TO_FP32(b_x0->d)*GGML_CPU_FP16_TO_FP32(b_y1->d),
                GGML_CPU_FP16_TO_FP32(b_x1->d)*GGML_CPU_FP16_TO_FP32(b_y0->d),
                GGML_CPU_FP16_TO_FP32(b_x1->d)*GGML_CPU_FP16_TO_FP32(b_y1->d)
            };
            float32x4_t scale = vld1q_f32(_scale);

            // Interleave 8-byte halves so that each vmmlaq_s32 sees one 8-byte row per block
            int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
            int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));

            int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
            int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));

            int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
            int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));

            int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
            int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
            sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)),
                                                l1, r1)), l2, r2)), l3, r3))), scale);
        }

        // Reorder [x0.y0, x0.y1, x1.y0, x1.y1] -> [x0.y0, x1.y0, x0.y1, x1.y1]
        float32x4_t sumv1 = vextq_f32 (sumv0, sumv0, 2);
        float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);

        sumv2 = vaddq_f32(sumv2, summs0);

        vst1_f32(s,      vget_low_f32 (sumv2));
        vst1_f32(s + bs, vget_high_f32(sumv2));

        return;
    }
#endif

    // ib counts blocks already consumed by the SIMD path; the scalar loop at the end handles the rest.
    int ib = 0;
    float sumf = 0;

#if defined(__ARM_NEON)
    float32x4_t sumv0 = vdupq_n_f32(0.0f);
    float32x4_t sumv1 = vdupq_n_f32(0.0f);

    // accumulated m_x * s_y correction terms
    float summs = 0;

    for (; ib + 1 < nb; ib += 2) {
        const block_q4_1 * GGML_RESTRICT x0 = &x[ib + 0];
        const block_q4_1 * GGML_RESTRICT x1 = &x[ib + 1];
        const block_q8_1 * GGML_RESTRICT y0 = &y[ib + 0];
        const block_q8_1 * GGML_RESTRICT y1 = &y[ib + 1];

        summs += GGML_CPU_FP16_TO_FP32(x0->m) * GGML_CPU_FP16_TO_FP32(y0->s) + GGML_CPU_FP16_TO_FP32(x1->m) * GGML_CPU_FP16_TO_FP32(y1->s);

        const uint8x16_t m4b = vdupq_n_u8(0x0F);

        const uint8x16_t v0_0 = vld1q_u8(x0->qs);
        const uint8x16_t v0_1 = vld1q_u8(x1->qs);

        // 4-bit -> 8-bit
        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8  (v0_0, m4b));
        const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
        const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8  (v0_1, m4b));
        const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));

        // load y
        const int8x16_t v1_0l = vld1q_s8(y0->qs);
        const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
        const int8x16_t v1_1l = vld1q_s8(y1->qs);
        const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);

        // dot product into int32x4_t
        const int32x4_t p_0 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l), v0_0h, v1_0h);
        const int32x4_t p_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l), v0_1h, v1_1h);

        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_CPU_FP16_TO_FP32(x0->d)*GGML_CPU_FP16_TO_FP32(y0->d));
        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_CPU_FP16_TO_FP32(x1->d)*GGML_CPU_FP16_TO_FP32(y1->d));
    }

    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs;

#endif
    // Scalar tail: any blocks not consumed above (an odd trailing block, or every block when
    // NEON is unavailable).
    for (; ib < nb; ++ib) {
        int sumi0 = 0;
        int sumi1 = 0;

        for (int j = 0; j < qk/2; ++j) {
            const int v0 = (x[ib].qs[j] & 0x0F);
            const int v1 = (x[ib].qs[j] >>   4);

            sumi0 += (v0 * y[ib].qs[j]);
            sumi1 += (v1 * y[ib].qs[j + qk/2]);
        }

        int sumi = sumi0 + sumi1;
        sumf += (GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d))*sumi + GGML_CPU_FP16_TO_FP32(x[ib].m)*GGML_CPU_FP16_TO_FP32(y[ib].s);
    }

    *s = sumf;
}
 591
 592void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
 593    assert(nrc == 1);
 594    UNUSED(nrc);
 595    UNUSED(bx);
 596    UNUSED(by);
 597    UNUSED(bs);
 598    assert(n % QK_MXFP4 == 0);
 599    static_assert(QK_MXFP4 == QK8_0, "QK_MXFP4 and QK8_0 must be the same");
 600
 601    const block_mxfp4 * GGML_RESTRICT x = vx;
 602    const block_q8_0 * GGML_RESTRICT y = vy;
 603
 604    const int nb = n / QK_MXFP4;
 605
 606    int ib = 0;
 607    float sumf = 0;
 608
 609#if defined __ARM_NEON
 610    const int8x16_t values = vld1q_s8(kvalues_mxfp4);
 611    const uint8x16_t m4b = vdupq_n_u8(0x0f);
 612    uint8x16x2_t q4bits;
 613    int8x16x4_t q4b;
 614    int8x16x4_t q8b;
 615    int32x4_t prod_1;
 616    int32x4_t prod_2;
 617
 618    for (; ib + 1 < nb; ib += 2) {
 619        q4bits.val[0] = vld1q_u8(x[ib + 0].qs);
 620        q4bits.val[1] = vld1q_u8(x[ib + 1].qs);
 621        q8b.val[0]    = vld1q_s8(y[ib + 0].qs);
 622        q8b.val[1]    = vld1q_s8(y[ib + 0].qs + 16);
 623        q8b.val[2]    = vld1q_s8(y[ib + 1].qs);
 624        q8b.val[3]    = vld1q_s8(y[ib + 1].qs + 16);
 625
 626        q4b.val[0] = ggml_vqtbl1q_s8(values, vandq_u8  (q4bits.val[0], m4b));
 627        q4b.val[1] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[0], 4));
 628        q4b.val[2] = ggml_vqtbl1q_s8(values, vandq_u8  (q4bits.val[1], m4b));
 629        q4b.val[3] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[1], 4));
 630
 631        prod_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[0], q8b.val[0]), q4b.val[1], q8b.val[1]);
 632        prod_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[2], q8b.val[2]), q4b.val[3], q8b.val[3]);
 633
 634        sumf +=
 635            GGML_E8M0_TO_FP32_HALF(x[ib + 0].e) * GGML_CPU_FP16_TO_FP32(y[ib + 0].d) * vaddvq_s32(prod_1) +
 636            GGML_E8M0_TO_FP32_HALF(x[ib + 1].e) * GGML_CPU_FP16_TO_FP32(y[ib + 1].d) * vaddvq_s32(prod_2);
 637    }
 638
 639#endif
 640    for (; ib < nb; ++ib) {
 641        const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_E8M0_TO_FP32_HALF(x[ib].e);
 642        int sumi1 = 0;
 643        int sumi2 = 0;
 644        for (int j = 0; j < QK_MXFP4/2; ++j) {
 645            sumi1 += y[ib].qs[j +          0] * kvalues_mxfp4[x[ib].qs[j] & 0xf];
 646            sumi2 += y[ib].qs[j + QK_MXFP4/2] * kvalues_mxfp4[x[ib].qs[j] >>  4];
 647        }
 648        sumf += d * (sumi1 + sumi2);
 649    }
 650    *s = sumf;
 651}
 652
 653void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
 654    const int qk = QK8_0;
 655    const int nb = n / qk;
 656
 657    int ib = 0;
 658    float sumf = 0;
 659
 660    assert(n % qk == 0);
 661    assert(qk == QK5_0);
 662    assert(nrc == 1);
 663    UNUSED(nrc);
 664    UNUSED(bx);
 665    UNUSED(by);
 666    UNUSED(bs);
 667
 668    const block_q5_0 * GGML_RESTRICT x = vx;
 669    const block_q8_0 * GGML_RESTRICT y = vy;
 670
 671#if defined(__ARM_NEON)
 672    float32x4_t sumv0 = vdupq_n_f32(0.0f);
 673    float32x4_t sumv1 = vdupq_n_f32(0.0f);
 674
 675    uint32_t qh0;
 676    uint32_t qh1;
 677
 678    uint64_t tmp0[4];
 679    uint64_t tmp1[4];
 680
 681    for (; ib + 1 < nb; ib += 2) {
 682        const block_q5_0 * GGML_RESTRICT x0 = &x[ib];
 683        const block_q5_0 * GGML_RESTRICT x1 = &x[ib + 1];
 684        const block_q8_0 * GGML_RESTRICT y0 = &y[ib];
 685        const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1];
 686
 687        const uint8x16_t m4b = vdupq_n_u8(0x0F);
 688
 689        // extract the 5th bit via lookup table ((!b) << 4)
 690        memcpy(&qh0, x0->qh, sizeof(qh0));
 691        memcpy(&qh1, x1->qh, sizeof(qh1));
 692
 693        tmp0[0] = table_b2b_1[(qh0 >>  0) & 0xFF];
 694        tmp0[1] = table_b2b_1[(qh0 >>  8) & 0xFF];
 695        tmp0[2] = table_b2b_1[(qh0 >> 16) & 0xFF];
 696        tmp0[3] = table_b2b_1[(qh0 >> 24)       ];
 697
 698        tmp1[0] = table_b2b_1[(qh1 >>  0) & 0xFF];
 699        tmp1[1] = table_b2b_1[(qh1 >>  8) & 0xFF];
 700        tmp1[2] = table_b2b_1[(qh1 >> 16) & 0xFF];
 701        tmp1[3] = table_b2b_1[(qh1 >> 24)       ];
 702
 703        const int8x16_t qhl0 = vld1q_s8((const int8_t *)(tmp0 + 0));
 704        const int8x16_t qhh0 = vld1q_s8((const int8_t *)(tmp0 + 2));
 705        const int8x16_t qhl1 = vld1q_s8((const int8_t *)(tmp1 + 0));
 706        const int8x16_t qhh1 = vld1q_s8((const int8_t *)(tmp1 + 2));
 707
 708        const uint8x16_t v0_0 = vld1q_u8(x0->qs);
 709        const uint8x16_t v0_1 = vld1q_u8(x1->qs);
 710
 711        // 4-bit -> 8-bit
 712        int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8  (v0_0, m4b));
 713        int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
 714        int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8  (v0_1, m4b));
 715        int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
 716
 717        // add high bit and sub 16 (equivalent to sub 0x10 when bit is zero)
 718        const int8x16_t v0_0lf = vsubq_s8(v0_0l, qhl0);
 719        const int8x16_t v0_0hf = vsubq_s8(v0_0h, qhh0);
 720        const int8x16_t v0_1lf = vsubq_s8(v0_1l, qhl1);
 721        const int8x16_t v0_1hf = vsubq_s8(v0_1h, qhh1);
 722
 723        // load y
 724        const int8x16_t v1_0l = vld1q_s8(y0->qs);
 725        const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
 726        const int8x16_t v1_1l = vld1q_s8(y1->qs);
 727        const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
 728
 729        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
 730                        ggml_vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
 731                        ggml_vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_CPU_FP16_TO_FP32(x0->d)*GGML_CPU_FP16_TO_FP32(y0->d));
 732        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
 733                        ggml_vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
 734                        ggml_vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_CPU_FP16_TO_FP32(x1->d)*GGML_CPU_FP16_TO_FP32(y1->d));
 735    }
 736
 737    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
 738
 739#endif
 740    for (; ib < nb; ++ib) {
 741        uint32_t qh;
 742        memcpy(&qh, x[ib].qh, sizeof(qh));
 743
 744        int sumi0 = 0;
 745        int sumi1 = 0;
 746
 747        for (int j = 0; j < qk/2; ++j) {
 748            const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
 749            const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12));
 750
 751            const int32_t x0 = (int8_t)(((x[ib].qs[j] & 0x0F) | xh_0) - 16);
 752            const int32_t x1 = (int8_t)(((x[ib].qs[j] >>   4) | xh_1) - 16);
 753
 754            sumi0 += (x0 * y[ib].qs[j]);
 755            sumi1 += (x1 * y[ib].qs[j + qk/2]);
 756        }
 757
 758        int sumi = sumi0 + sumi1;
 759        sumf += (GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d)) * sumi;
 760    }
 761
 762    *s = sumf;
 763}
 764
// Dot product of one row of q5_1 blocks (vx) with one row of q8_1 blocks (vy); the
// scalar result is written to *s.
//
// A q5_1 element is an unsigned 5-bit quant: its low 4 bits are a nibble of qs and its
// 5th bit comes from the 32-bit qh mask (bit j for element j, bit j + 16 for element
// j + qk/2). Each block contributes
//     d_x * d_y * sum(q5[j] * q8[j])  +  m_x * s_y
// where m is the q5_1 offset and s is the per-block term stored in the q8_1 block
// (presumably d_y * sum(q8[j]), filled in by the quantizer; not visible here).
//
// n   - elements in the row; must be a multiple of qk (asserted; QK8_1 == QK5_1).
// nrc - only 1 is supported (asserted). bs, bx and by are unused.
void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    const int qk = QK8_1;
    const int nb = n / qk;

    int ib = 0;
    float sumf = 0;

    assert(n % qk == 0);
    assert(qk == QK5_1);
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_q5_1 * GGML_RESTRICT x = vx;
    const block_q8_1 * GGML_RESTRICT y = vy;

#if defined(__ARM_NEON)
    // NEON path: two blocks per iteration with independent accumulators; an odd
    // trailing block is left for the scalar loop below.
    float32x4_t sumv0 = vdupq_n_f32(0.0f);
    float32x4_t sumv1 = vdupq_n_f32(0.0f);

    // running sums of the m_x * s_y offset terms, one per accumulator lane
    float summs0 = 0.0f;
    float summs1 = 0.0f;

    uint32_t qh0;
    uint32_t qh1;

    // 4 x 8 bytes per block: one output byte per bit of qh
    uint64_t tmp0[4];
    uint64_t tmp1[4];

    for (; ib + 1 < nb; ib += 2) {
        const block_q5_1 * GGML_RESTRICT x0 = &x[ib];
        const block_q5_1 * GGML_RESTRICT x1 = &x[ib + 1];
        const block_q8_1 * GGML_RESTRICT y0 = &y[ib];
        const block_q8_1 * GGML_RESTRICT y1 = &y[ib + 1];

        const uint8x16_t m4b = vdupq_n_u8(0x0F);

        summs0 += GGML_CPU_FP16_TO_FP32(x0->m) * GGML_CPU_FP16_TO_FP32(y0->s);
        summs1 += GGML_CPU_FP16_TO_FP32(x1->m) * GGML_CPU_FP16_TO_FP32(y1->s);

        // extract the 5th bit via lookup table ((b) << 4)
        // qh is copied with memcpy instead of being read through a cast pointer, which
        // sidesteps alignment and strict-aliasing problems.
        memcpy(&qh0, x0->qh, sizeof(qh0));
        memcpy(&qh1, x1->qh, sizeof(qh1));

        // Each qh byte expands to 8 bytes, one per bit (0x10 when the bit is set).
        // tmp[0..1] cover bits 0..15 (the low-nibble elements), tmp[2..3] cover
        // bits 16..31 (the high-nibble elements). The last lookup needs no mask
        // because >> 24 already leaves only 8 bits of the uint32_t.
        tmp0[0] = table_b2b_0[(qh0 >>  0) & 0xFF];
        tmp0[1] = table_b2b_0[(qh0 >>  8) & 0xFF];
        tmp0[2] = table_b2b_0[(qh0 >> 16) & 0xFF];
        tmp0[3] = table_b2b_0[(qh0 >> 24)       ];

        tmp1[0] = table_b2b_0[(qh1 >>  0) & 0xFF];
        tmp1[1] = table_b2b_0[(qh1 >>  8) & 0xFF];
        tmp1[2] = table_b2b_0[(qh1 >> 16) & 0xFF];
        tmp1[3] = table_b2b_0[(qh1 >> 24)       ];

        const int8x16_t qhl0 = vld1q_s8((const int8_t *)(tmp0 + 0));
        const int8x16_t qhh0 = vld1q_s8((const int8_t *)(tmp0 + 2));
        const int8x16_t qhl1 = vld1q_s8((const int8_t *)(tmp1 + 0));
        const int8x16_t qhh1 = vld1q_s8((const int8_t *)(tmp1 + 2));

        const uint8x16_t v0_0 = vld1q_u8(x0->qs);
        const uint8x16_t v0_1 = vld1q_u8(x1->qs);

        // 4-bit -> 8-bit
        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8  (v0_0, m4b));
        const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
        const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8  (v0_1, m4b));
        const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));

        // add high bit (an OR suffices: the 0x10 bit never overlaps the 4-bit nibble)
        const int8x16_t v0_0lf = vorrq_s8(v0_0l, qhl0);
        const int8x16_t v0_0hf = vorrq_s8(v0_0h, qhh0);
        const int8x16_t v0_1lf = vorrq_s8(v0_1l, qhl1);
        const int8x16_t v0_1hf = vorrq_s8(v0_1h, qhh1);

        // load y
        const int8x16_t v1_0l = vld1q_s8(y0->qs);
        const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
        const int8x16_t v1_1l = vld1q_s8(y1->qs);
        const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);

        // per-block integer dot product, converted to float and scaled by d_x * d_y
        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
                        ggml_vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
                        ggml_vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_CPU_FP16_TO_FP32(x0->d)*GGML_CPU_FP16_TO_FP32(y0->d));
        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
                        ggml_vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
                        ggml_vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_CPU_FP16_TO_FP32(x1->d)*GGML_CPU_FP16_TO_FP32(y1->d));
    }

    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs0 + summs1;

#endif
    // Scalar loop: the remaining blocks, or every block when NEON is unavailable.
    for (; ib < nb; ++ib) {
        uint32_t qh;
        memcpy(&qh, x[ib].qh, sizeof(qh));

        int sumi0 = 0;
        int sumi1 = 0;

        for (int j = 0; j < qk/2; ++j) {
            // 5th bit of element j (qh bit j) and of element j + qk/2 (qh bit j + 16),
            // each moved to bit position 4 (value 0x10)
            const uint8_t xh_0 = ((qh >> (j +  0)) << 4) & 0x10;
            const uint8_t xh_1 = ((qh >> (j + 12))     ) & 0x10;

            const int32_t x0 = (x[ib].qs[j] & 0xF) | xh_0;
            const int32_t x1 = (x[ib].qs[j] >>  4) | xh_1;

            sumi0 += (x0 * y[ib].qs[j]);
            sumi1 += (x1 * y[ib].qs[j + qk/2]);
        }

        int sumi = sumi0 + sumi1;
        sumf += (GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d))*sumi + GGML_CPU_FP16_TO_FP32(x[ib].m)*GGML_CPU_FP16_TO_FP32(y[ib].s);
    }

    *s = sumf;
}
 882
// Dot product of q8_0 rows: *s = sum over blocks of d_x * d_y * sum_j(x.qs[j] * y.qs[j]).
//
// n      - elements per row; must be a multiple of QK8_0 (asserted).
// nrc    - 1 is always supported. 2 is supported only when built with
//          __ARM_FEATURE_MATMUL_INT8; it computes a 2x2 tile (see that branch).
// bs     - output stride (in floats, applied as s + bs) used only by the nrc == 2 branch.
// bx, by - byte offsets from vx / vy to the second x / y row, used only for nrc == 2.
void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    const int qk = QK8_0;
    const int nb = n / qk;

    assert(n % qk == 0);
#if defined(__ARM_FEATURE_MATMUL_INT8)
    assert((nrc == 2) || (nrc == 1));
#else
    assert(nrc == 1);
#endif
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_q8_0 * GGML_RESTRICT x = vx;
    const block_q8_0 * GGML_RESTRICT y = vy;

#if defined(__ARM_FEATURE_MATMUL_INT8)
    // 2x2 tile via the int8 matrix-multiply instruction (vmmlaq_s32): the four dot
    // products of rows {x0, x1} against rows {y0, y1} are accumulated together.
    if (nrc == 2) {
        const block_q8_0 * GGML_RESTRICT vx0 = vx;
        const block_q8_0 * GGML_RESTRICT vx1 = (const block_q8_0 *) ((const uint8_t*)vx + bx);
        const block_q8_0 * GGML_RESTRICT vy0 = vy;
        const block_q8_0 * GGML_RESTRICT vy1 = (const block_q8_0 *) ((const uint8_t*)vy + by);

        float32x4_t sumv0 = vdupq_n_f32(0.0f);

        for (int i = 0; i < nb; i++) {
            const block_q8_0 * GGML_RESTRICT b_x0 = &vx0[i];
            const block_q8_0 * GGML_RESTRICT b_y0 = &vy0[i];

            const block_q8_0 * GGML_RESTRICT b_x1 = &vx1[i];
            const block_q8_0 * GGML_RESTRICT b_y1 = &vy1[i];

            const int8x16_t x0_l = vld1q_s8(b_x0->qs);
            const int8x16_t x0_h = vld1q_s8(b_x0->qs + 16);
            const int8x16_t x1_l = vld1q_s8(b_x1->qs);
            const int8x16_t x1_h = vld1q_s8(b_x1->qs + 16);

            // load y
            const int8x16_t y0_l = vld1q_s8(b_y0->qs);
            const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16);
            const int8x16_t y1_l = vld1q_s8(b_y1->qs);
            const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);

            // Block scales in the lane order of the accumulated tile:
            // [x0*y0, x0*y1, x1*y0, x1*y1]
            float32_t _scale[4] = {
                GGML_CPU_FP16_TO_FP32(b_x0->d)*GGML_CPU_FP16_TO_FP32(b_y0->d),
                GGML_CPU_FP16_TO_FP32(b_x0->d)*GGML_CPU_FP16_TO_FP32(b_y1->d),
                GGML_CPU_FP16_TO_FP32(b_x1->d)*GGML_CPU_FP16_TO_FP32(b_y0->d),
                GGML_CPU_FP16_TO_FP32(b_x1->d)*GGML_CPU_FP16_TO_FP32(b_y1->d)
            };
            float32x4_t scale = vld1q_f32(_scale);

            // Interleave 64-bit halves so each operand holds 8 bytes of row 0 followed by
            // the matching 8 bytes of row 1, the 2x8 layout vmmlaq_s32 consumes.
            int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
            int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));

            int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
            int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));

            int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
            int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));

            int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
            int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));

            // four chained 2x8 * 8x2 products cover all 32 bytes of the block
            sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)),
                                                l1, r1)), l2, r2)), l3, r3))), scale);
        }

        // sumv0 = [x0y0, x0y1, x1y0, x1y1]; reorder to [x0y0, x1y0, x0y1, x1y1] so that
        // s[0], s[1] receive the products with y0 and s[bs], s[bs + 1] those with y1.
        float32x4_t sumv1 = vextq_f32 (sumv0, sumv0, 2);
        float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);

        vst1_f32(s,      vget_low_f32 (sumv2));
        vst1_f32(s + bs, vget_high_f32(sumv2));

        return;
    }
#endif

    int ib = 0;
    float sumf = 0;

#if defined(__ARM_FEATURE_SVE)
    svfloat32_t sumv0 = svdup_n_f32(0.0f);
    svfloat32_t sumv1 = svdup_n_f32(0.0f);

    // Assumes ggml_cpu_get_sve_cnt() returns the SVE register size in bytes, so
    // vector_length is in bits (TODO confirm).
    const int vector_length = ggml_cpu_get_sve_cnt()*8;

    // VLA implementation for SVE: one specialized loop per supported vector length,
    // each handling two blocks per iteration and leaving an odd block to the scalar tail.
    switch (vector_length) {
        case 128:
            {
                // predicate for activating lanes for 16 Int8 elements
                const svbool_t ph16 = svptrue_pat_b8 (SV_VL16);
                // predicate for 4 int32 / float32 lanes
                const svbool_t pl16 = svptrue_pat_b32(SV_VL4);

                // each 32-byte block is processed as two 16-byte halves
                for (; ib + 1 < nb; ib += 2) {
                    const block_q8_0 * GGML_RESTRICT x0 = &x[ib + 0];
                    const block_q8_0 * GGML_RESTRICT x1 = &x[ib + 1];
                    const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0];
                    const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1];

                    // load x
                    const svint8_t qx0_0 = svld1_s8(ph16, x0->qs);
                    const svint8_t qx0_1 = svld1_s8(ph16, x0->qs+16);
                    const svint8_t qx1_0 = svld1_s8(ph16, x1->qs);
                    const svint8_t qx1_1 = svld1_s8(ph16, x1->qs+16);

                    // load y
                    const svint8_t qy0_0 = svld1_s8(ph16, y0->qs);
                    const svint8_t qy0_1 = svld1_s8(ph16, y0->qs+16);
                    const svint8_t qy1_0 = svld1_s8(ph16, y1->qs);
                    const svint8_t qy1_1 = svld1_s8(ph16, y1->qs+16);

                    sumv0 = svmla_n_f32_x(pl16, sumv0, svcvt_f32_s32_x(pl16, svadd_x(pl16,
                                    svdot_s32(svdup_n_s32(0), qx0_0, qy0_0),
                                    svdot_s32(svdup_n_s32(0), qx0_1, qy0_1))), GGML_CPU_FP16_TO_FP32(x0->d)*GGML_CPU_FP16_TO_FP32(y0->d));
                    sumv1 = svmla_n_f32_x(pl16, sumv1, svcvt_f32_s32_x(pl16, svadd_x(pl16,
                                    svdot_s32(svdup_n_s32(0), qx1_0, qy1_0),
                                    svdot_s32(svdup_n_s32(0), qx1_1, qy1_1))), GGML_CPU_FP16_TO_FP32(x1->d)*GGML_CPU_FP16_TO_FP32(y1->d));
                }

                sumf = svaddv_f32(pl16, svadd_f32_x(pl16, sumv0, sumv1));
            } break;
        case 256:
            {
                // one full 32-byte block per 256-bit vector
                for (; ib + 1 < nb; ib += 2) {
                    const block_q8_0 * GGML_RESTRICT x0 = &x[ib + 0];
                    const block_q8_0 * GGML_RESTRICT x1 = &x[ib + 1];
                    const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0];
                    const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1];

                    // load x
                    const svint8_t qx0 = svld1_s8(svptrue_b8(), x0->qs);
                    const svint8_t qx1 = svld1_s8(svptrue_b8(), x1->qs);

                    // load y
                    const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
                    const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);

                    sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(),
                                svdot_s32(svdup_n_s32(0), qx0, qy0)), GGML_CPU_FP16_TO_FP32(x0->d)*GGML_CPU_FP16_TO_FP32(y0->d));
                    sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(),
                                svdot_s32(svdup_n_s32(0), qx1, qy1)), GGML_CPU_FP16_TO_FP32(x1->d)*GGML_CPU_FP16_TO_FP32(y1->d));
                }

                sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
            } break;
        case 512:
            {
                // predicate for the first 32 int8 lanes (low 256 bits)
                const svbool_t ph32 = svptrue_pat_b8(SV_VL32);
                // predicate for the remaining int8 lanes (high 256 bits)
                const svbool_t pl32 = svnot_b_z(svptrue_b8(), ph32);

                // predicate for the first 8 float32 lanes (low 256 bits)
                const svbool_t ph8 = svptrue_pat_b32(SV_VL8);
                // predicate for the remaining float32 lanes (high 256 bits)
                const svbool_t pl8 = svnot_b_z(svptrue_b32(), ph8);

                svfloat32_t sumv00 = svdup_n_f32(0.0f);

                for (; ib + 1 < nb; ib += 2) {
                    const block_q8_0 * GGML_RESTRICT x0 = &x[ib + 0];
                    const block_q8_0 * GGML_RESTRICT x1 = &x[ib + 1];
                    const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0];
                    const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1];

                    // Pack two consecutive blocks into one 64-byte vector. Lanes 0..31
                    // (ph32) load x0->qs[0..31]. Lanes 32..63 (pl32) load from
                    // x0->qs + 2 + lane, i.e. starting at x0->qs + 34. That is x1->qs
                    // only if blocks are contiguous 34-byte structs (2-byte d followed
                    // by 32 qs). This layout is assumed here; confirm in ggml-common.h.
                    // Predicated loads zero inactive lanes, so the add merges both halves.
                    // load x
                    const svint8_t qx_32 = svld1_s8(ph32, x0->qs);
                          svint8_t qx_64 = svld1_s8(pl32, x0->qs + 2);

                    qx_64 = svadd_s8_x(svptrue_b8(), qx_32, qx_64);

                    // load y (same packing)
                    const svint8_t qy_32 = svld1_s8(ph32, y0->qs);
                          svint8_t qy_64 = svld1_s8(pl32, y0->qs + 2);

                    qy_64 = svadd_s8_x(svptrue_b8(), qy_32, qy_64);

                    // scale creation
                    const float32_t deq1 = GGML_CPU_FP16_TO_FP32(x0->d)*GGML_CPU_FP16_TO_FP32(y0->d);
                    const float32_t deq2 = GGML_CPU_FP16_TO_FP32(x1->d)*GGML_CPU_FP16_TO_FP32(y1->d);

                    // deq1 in float lanes 0..7 (block ib), deq2 in lanes 8..15 (block ib + 1),
                    // matching the int32 lanes svdot produces from bytes 0..31 / 32..63
                    const svfloat32_t temp = svdup_f32_m(svdup_f32_z(ph8, deq1), pl8, deq2);

                    const svfloat32_t sumvt = svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx_64, qy_64));

                    sumv00 = svmla_f32_m(svptrue_b32(), sumv00, sumvt, temp);
                }

                sumf = svaddv_f32(svptrue_b32(), sumv00);
                break;
            }
        default:
            // Other vector lengths: asserts in debug builds. With NDEBUG, ib stays 0 and
            // sumf stays 0, so the scalar loop below computes every block.
            assert(false && "Unsupported vector length");
            break;
    }
#elif defined(__ARM_NEON)
    // NEON path: two blocks per iteration with independent accumulators.
    float32x4_t sumv0 = vdupq_n_f32(0.0f);
    float32x4_t sumv1 = vdupq_n_f32(0.0f);

    for (; ib + 1 < nb; ib += 2) {
        const block_q8_0 * GGML_RESTRICT x0 = &x[ib + 0];
        const block_q8_0 * GGML_RESTRICT x1 = &x[ib + 1];
        const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0];
        const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1];

        const int8x16_t x0_0 = vld1q_s8(x0->qs);
        const int8x16_t x0_1 = vld1q_s8(x0->qs + 16);
        const int8x16_t x1_0 = vld1q_s8(x1->qs);
        const int8x16_t x1_1 = vld1q_s8(x1->qs + 16);

        // load y
        const int8x16_t y0_0 = vld1q_s8(y0->qs);
        const int8x16_t y0_1 = vld1q_s8(y0->qs + 16);
        const int8x16_t y1_0 = vld1q_s8(y1->qs);
        const int8x16_t y1_1 = vld1q_s8(y1->qs + 16);

        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
                        ggml_vdotq_s32(vdupq_n_s32(0), x0_0, y0_0),
                        ggml_vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), GGML_CPU_FP16_TO_FP32(x0->d)*GGML_CPU_FP16_TO_FP32(y0->d));

        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
                        ggml_vdotq_s32(vdupq_n_s32(0), x1_0, y1_0),
                        ggml_vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), GGML_CPU_FP16_TO_FP32(x1->d)*GGML_CPU_FP16_TO_FP32(y1->d));
    }

    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
#endif
    // Scalar loop: the remaining blocks, or every block when no SIMD loop above ran.
    for (; ib < nb; ++ib) {
        int sumi = 0;

        for (int j = 0; j < qk; j++) {
            sumi += x[ib].qs[j]*y[ib].qs[j];
        }

        sumf += sumi*(GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d));
    }

    *s = sumf;
}
1129
1130void ggml_vec_dot_tq1_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
1131    assert(nrc == 1);
1132    UNUSED(nrc);
1133    UNUSED(bx);
1134    UNUSED(by);
1135    UNUSED(bs);
1136
1137    const block_tq1_0 * GGML_RESTRICT x = vx;
1138    const block_q8_K  * GGML_RESTRICT y = vy;
1139
1140    const int nb = n / QK_K;
1141
1142#if defined(__ARM_NEON)
1143    float sumf = 0.0f;
1144
1145    uint8_t k_shift[16] = {1, 1, 1, 1, 3, 3, 3, 3, 9, 9, 9, 9, 27, 27, 27, 27};
1146
1147    const uint8x16_t shift = vld1q_u8(k_shift);
1148
1149    for (int i = 0; i < nb; ++i) {
1150#if defined(__ARM_FEATURE_DOTPROD)
1151        int32x4_t sumi0 = vdupq_n_s32(0);
1152        int32x4_t sumi1 = vdupq_n_s32(0);
1153#else
1154        int16x8_t sumi0 = vdupq_n_s16(0);
1155        int16x8_t sumi1 = vdupq_n_s16(0);
1156#endif
1157
1158        // first 32 bytes of 5 elements
1159        {
1160            uint8x16_t qx0 = vld1q_u8(x[i].qs + 0);
1161            uint8x16_t qx1 = vld1q_u8(x[i].qs + 16);
1162            uint8x16_t qx2 = vmulq_u8(qx0, vdupq_n_u8(3));
1163            uint8x16_t qx3 = vmulq_u8(qx1, vdupq_n_u8(3));
1164            uint8x16_t qx4 = vmulq_u8(qx0, vdupq_n_u8(9));
1165            uint8x16_t qx5 = vmulq_u8(qx1, vdupq_n_u8(9));
1166            uint8x16_t qx6 = vmulq_u8(qx0, vdupq_n_u8(27));
1167            uint8x16_t qx7 = vmulq_u8(qx1, vdupq_n_u8(27));
1168            uint8x16_t qx8 = vmulq_u8(qx0, vdupq_n_u8(81));
1169            uint8x16_t qx9 = vmulq_u8(qx1, vdupq_n_u8(81));
1170
1171            // multiply by 3 and keep the 2 bits above 8 bits
1172            int8x16_t sqx0 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx0, vshrq_n_u8(qx0, 1)), 6));
1173            int8x16_t sqx1 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx1, vshrq_n_u8(qx1, 1)), 6));
1174            int8x16_t sqx2 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx2, vshrq_n_u8(qx2, 1)), 6));
1175            int8x16_t sqx3 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx3, vshrq_n_u8(qx3, 1)), 6));
1176            int8x16_t sqx4 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx4, vshrq_n_u8(qx4, 1)), 6));
1177            int8x16_t sqx5 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx5, vshrq_n_u8(qx5, 1)), 6));
1178            int8x16_t sqx6 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx6, vshrq_n_u8(qx6, 1)), 6));
1179            int8x16_t sqx7 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx7, vshrq_n_u8(qx7, 1)), 6));
1180            int8x16_t sqx8 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx8, vshrq_n_u8(qx8, 1)), 6));
1181            int8x16_t sqx9 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx9, vshrq_n_u8(qx9, 1)), 6));
1182
1183            const int8x16_t qy0 = vld1q_s8(y[i].qs +   0);
1184            const int8x16_t qy1 = vld1q_s8(y[i].qs +  16);
1185            const int8x16_t qy2 = vld1q_s8(y[i].qs +  32);
1186            const int8x16_t qy3 = vld1q_s8(y[i].qs +  48);
1187            const int8x16_t qy4 = vld1q_s8(y[i].qs +  64);
1188            const int8x16_t qy5 = vld1q_s8(y[i].qs +  80);
1189            const int8x16_t qy6 = vld1q_s8(y[i].qs +  96);
1190            const int8x16_t qy7 = vld1q_s8(y[i].qs + 112);
1191            const int8x16_t qy8 = vld1q_s8(y[i].qs + 128);
1192            const int8x16_t qy9 = vld1q_s8(y[i].qs + 144);
1193
1194#if defined(__ARM_FEATURE_DOTPROD)
1195            sumi0 = vdotq_s32(sumi0, sqx0, qy0);
1196            sumi1 = vdotq_s32(sumi1, sqx1, qy1);
1197            sumi0 = vdotq_s32(sumi0, sqx2, qy2);
1198            sumi1 = vdotq_s32(sumi1, sqx3, qy3);
1199            sumi0 = vdotq_s32(sumi0, sqx4, qy4);
1200            sumi1 = vdotq_s32(sumi1, sqx5, qy5);
1201            sumi0 = vdotq_s32(sumi0, sqx6, qy6);
1202            sumi1 = vdotq_s32(sumi1, sqx7, qy7);
1203            sumi0 = vdotq_s32(sumi0, sqx8, qy8);
1204            sumi1 = vdotq_s32(sumi1, sqx9, qy9);
1205#else
1206            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx0), vget_low_s8(qy0));
1207            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx0), vget_high_s8(qy0));
1208            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx1), vget_low_s8(qy1));
1209            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx1), vget_high_s8(qy1));
1210            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx2), vget_low_s8(qy2));
1211            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx2), vget_high_s8(qy2));
1212            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx3), vget_low_s8(qy3));
1213            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx3), vget_high_s8(qy3));
1214            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx4), vget_low_s8(qy4));
1215            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx4), vget_high_s8(qy4));
1216            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx5), vget_low_s8(qy5));
1217            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx5), vget_high_s8(qy5));
1218            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx6), vget_low_s8(qy6));
1219            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx6), vget_high_s8(qy6));
1220            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx7), vget_low_s8(qy7));
1221            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx7), vget_high_s8(qy7));
1222            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx8), vget_low_s8(qy8));
1223            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx8), vget_high_s8(qy8));
1224            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx9), vget_low_s8(qy9));
1225            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx9), vget_high_s8(qy9));
1226#endif
1227        }
1228
1229        // last 16 bytes of 5-element, along with the 4 bytes of 4 elements
1230        {
1231            uint8x16_t qx0 = vld1q_u8(x[i].qs + 32);
1232            uint8x16_t qx1 = vmulq_u8(qx0, vdupq_n_u8(3));
1233            uint8x16_t qx2 = vmulq_u8(qx0, vdupq_n_u8(9));
1234            uint8x16_t qx3 = vmulq_u8(qx0, vdupq_n_u8(27));
1235            uint8x16_t qx4 = vmulq_u8(qx0, vdupq_n_u8(81));
1236            uint32_t qh;
1237            memcpy(&qh, x[i].qh, sizeof(qh)); // potentially unaligned
1238            uint8x16_t qx5 = vreinterpretq_u8_u32(vdupq_n_u32(qh));
1239            qx5 = vmulq_u8(qx5, shift);
1240
1241            // multiply by 3 and keep the 2 bits above 8 bits
1242            int8x16_t sqx0 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx0, vshrq_n_u8(qx0, 1)), 6));
1243            int8x16_t sqx1 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx1, vshrq_n_u8(qx1, 1)), 6));
1244            int8x16_t sqx2 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx2, vshrq_n_u8(qx2, 1)), 6));
1245            int8x16_t sqx3 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx3, vshrq_n_u8(qx3, 1)), 6));
1246            int8x16_t sqx4 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx4, vshrq_n_u8(qx4, 1)), 6));
1247            int8x16_t sqx5 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx5, vshrq_n_u8(qx5, 1)), 6));
1248
1249            const int8x16_t qy0 = vld1q_s8(y[i].qs + 160);
1250            const int8x16_t qy1 = vld1q_s8(y[i].qs + 176);
1251            const int8x16_t qy2 = vld1q_s8(y[i].qs + 192);
1252            const int8x16_t qy3 = vld1q_s8(y[i].qs + 208);
1253            const int8x16_t qy4 = vld1q_s8(y[i].qs + 224);
1254            const int8x16_t qy5 = vld1q_s8(y[i].qs + 240);
1255
1256#if defined(__ARM_FEATURE_DOTPROD)
1257            sumi0 = vdotq_s32(sumi0, sqx0, qy0);
1258            sumi1 = vdotq_s32(sumi1, sqx1, qy1);
1259            sumi0 = vdotq_s32(sumi0, sqx2, qy2);
1260            sumi1 = vdotq_s32(sumi1, sqx3, qy3);
1261            sumi0 = vdotq_s32(sumi0, sqx4, qy4);
1262            sumi1 = vdotq_s32(sumi1, sqx5, qy5);
1263#else
1264            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx0), vget_low_s8(qy0));
1265            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx0), vget_high_s8(qy0));
1266            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx1), vget_low_s8(qy1));
1267            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx1), vget_high_s8(qy1));
1268            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx2), vget_low_s8(qy2));
1269            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx2), vget_high_s8(qy2));
1270            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx3), vget_low_s8(qy3));
1271            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx3), vget_high_s8(qy3));
1272            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx4), vget_low_s8(qy4));
1273            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx4), vget_high_s8(qy4));
1274            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx5), vget_low_s8(qy5));
1275            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx5), vget_high_s8(qy5));
1276#endif
1277        }
1278
1279        const int16x8_t ysum0 = vld1q_s16(y[i].bsums);
1280        const int16x8_t ysum1 = vld1q_s16(y[i].bsums + 8);
1281
1282        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
1283
1284#if defined(__ARM_FEATURE_DOTPROD)
1285        sumi0 = vaddq_s32(sumi0, sumi1);
1286        sumi0 = vsubq_s32(sumi0, vpaddlq_s16(vaddq_s16(ysum0, ysum1)));
1287
1288        sumf += d * (float) vaddvq_s32(sumi0);
1289#else
1290        sumi0 = vaddq_s16(sumi0, sumi1);
1291        sumi0 = vsubq_s16(sumi0, vaddq_s16(ysum0, ysum1));
1292
1293        sumf += d * (float) vaddlvq_s16(sumi0);
1294#endif
1295    }
1296
1297    *s = sumf;
1298
1299#else
1300    UNUSED(x);
1301    UNUSED(y);
1302    UNUSED(nb);
1303    ggml_vec_dot_tq1_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
1304#endif
1305}
1306
// Dot product of one row of tq2_0 blocks (vx) with one row of q8_K blocks (vy); the
// result is written to *s.
//
// Each qs byte holds four 2-bit fields at bit offsets 0, 2, 4 and 6. For every 32-byte
// chunk j of qs, the NEON path pairs:
//   - field 0 of bytes 0..15 / 16..31 with y.qs[4j +  0..15] / [4j + 16..31]
//   - field 1 (>> 2)          with y.qs[4j + 32..63]
//   - field 2 (>> 4)          with y.qs[4j + 64..95]
//   - field 3 (>> 6)          with y.qs[4j + 96..127]
// The fields are presumably the ternary weights {-1, 0, 1} stored as {0, 1, 2}. The
// total of y.bsums is subtracted, which is assumed to hold partial sums of y.qs. That
// subtraction removes the +1 offset.
//
// nrc - only 1 is supported (asserted). bs, bx and by are unused.
// Without NEON this forwards to ggml_vec_dot_tq2_0_q8_K_generic.
void ggml_vec_dot_tq2_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_tq2_0 * GGML_RESTRICT x = vx;
    const block_q8_K  * GGML_RESTRICT y = vy;

    const int nb = n / QK_K;

#if defined(__ARM_NEON)
    float sumf = 0.0f;

    // mask for one 2-bit field
    const uint8x16_t m3 = vdupq_n_u8(3);

    for (int i = 0; i < nb; ++i) {
#if defined(__ARM_FEATURE_DOTPROD)
        int32x4_t sumi0 = vdupq_n_s32(0);
        int32x4_t sumi1 = vdupq_n_s32(0);
#else
        int16x8_t sumi0 = vdupq_n_s16(0);
        int16x8_t sumi1 = vdupq_n_s16(0);
#endif

        // 32 packed bytes per iteration -> 128 decoded values
        for (size_t j = 0; j < sizeof(x->qs); j += 32) {
            uint8x16_t qx0 = vld1q_u8(x[i].qs + j);
            uint8x16_t qx1 = vld1q_u8(x[i].qs + j + 16);
            uint8x16_t qx2 = vshrq_n_u8(qx0, 2);
            uint8x16_t qx3 = vshrq_n_u8(qx1, 2);
            uint8x16_t qx4 = vshrq_n_u8(qx0, 4);
            uint8x16_t qx5 = vshrq_n_u8(qx1, 4);
            uint8x16_t qx6 = vshrq_n_u8(qx0, 6);
            uint8x16_t qx7 = vshrq_n_u8(qx1, 6);

            int8x16_t sqx0 = vreinterpretq_s8_u8(vandq_u8(qx0, m3));
            int8x16_t sqx1 = vreinterpretq_s8_u8(vandq_u8(qx1, m3));
            int8x16_t sqx2 = vreinterpretq_s8_u8(vandq_u8(qx2, m3));
            int8x16_t sqx3 = vreinterpretq_s8_u8(vandq_u8(qx3, m3));
            int8x16_t sqx4 = vreinterpretq_s8_u8(vandq_u8(qx4, m3));
            int8x16_t sqx5 = vreinterpretq_s8_u8(vandq_u8(qx5, m3));
            int8x16_t sqx6 = vreinterpretq_s8_u8(vandq_u8(qx6, m3));
            int8x16_t sqx7 = vreinterpretq_s8_u8(vandq_u8(qx7, m3));

            const int8x16_t qy0 = vld1q_s8(y[i].qs + j*4 +   0);
            const int8x16_t qy1 = vld1q_s8(y[i].qs + j*4 +  16);
            const int8x16_t qy2 = vld1q_s8(y[i].qs + j*4 +  32);
            const int8x16_t qy3 = vld1q_s8(y[i].qs + j*4 +  48);
            const int8x16_t qy4 = vld1q_s8(y[i].qs + j*4 +  64);
            const int8x16_t qy5 = vld1q_s8(y[i].qs + j*4 +  80);
            const int8x16_t qy6 = vld1q_s8(y[i].qs + j*4 +  96);
            const int8x16_t qy7 = vld1q_s8(y[i].qs + j*4 + 112);

#if defined(__ARM_FEATURE_DOTPROD)
            sumi0 = vdotq_s32(sumi0, sqx0, qy0);
            sumi1 = vdotq_s32(sumi1, sqx1, qy1);
            sumi0 = vdotq_s32(sumi0, sqx2, qy2);
            sumi1 = vdotq_s32(sumi1, sqx3, qy3);
            sumi0 = vdotq_s32(sumi0, sqx4, qy4);
            sumi1 = vdotq_s32(sumi1, sqx5, qy5);
            sumi0 = vdotq_s32(sumi0, sqx6, qy6);
            sumi1 = vdotq_s32(sumi1, sqx7, qy7);
#else
            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx0), vget_low_s8(qy0));
            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx0), vget_high_s8(qy0));
            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx1), vget_low_s8(qy1));
            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx1), vget_high_s8(qy1));
            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx2), vget_low_s8(qy2));
            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx2), vget_high_s8(qy2));
            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx3), vget_low_s8(qy3));
            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx3), vget_high_s8(qy3));
            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx4), vget_low_s8(qy4));
            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx4), vget_high_s8(qy4));
            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx5), vget_low_s8(qy5));
            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx5), vget_high_s8(qy5));
            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx6), vget_low_s8(qy6));
            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx6), vget_high_s8(qy6));
            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx7), vget_low_s8(qy7));
            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx7), vget_high_s8(qy7));
#endif
        }

        const int16x8_t ysum0 = vld1q_s16(y[i].bsums);
        const int16x8_t ysum1 = vld1q_s16(y[i].bsums + 8);

        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;

        // subtract the total of y.bsums to undo the +1 offset of the stored fields
#if defined(__ARM_FEATURE_DOTPROD)
        sumi0 = vaddq_s32(sumi0, sumi1);
        sumi0 = vsubq_s32(sumi0, vpaddlq_s16(vaddq_s16(ysum0, ysum1)));

        sumf += d * (float) vaddvq_s32(sumi0);
#else
        sumi0 = vaddq_s16(sumi0, sumi1);
        sumi0 = vsubq_s16(sumi0, vaddq_s16(ysum0, ysum1));

        sumf += d * (float) vaddlvq_s16(sumi0);
#endif
    }

    *s = sumf;

#else
    UNUSED(x);
    UNUSED(y);
    UNUSED(nb);
    ggml_vec_dot_tq2_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}
1417
1418void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
1419    assert(nrc == 1);
1420    UNUSED(nrc);
1421    UNUSED(bx);
1422    UNUSED(by);
1423    UNUSED(bs);
1424
1425    const block_q2_K * GGML_RESTRICT x = vx;
1426    const block_q8_K * GGML_RESTRICT y = vy;
1427
1428    const int nb = n / QK_K;
1429
1430#ifdef __ARM_FEATURE_SVE
1431    const int vector_length = svcntb()*8;
1432    const svuint8_t m3s = svdup_n_u8(0x3);
1433    const svuint32_t m4s = svdup_n_u32(0xF);
1434    const svint32_t vzero_sv = svdup_n_s32(0);
1435    svfloat32_t acc_sum = svdup_n_f32(0);
1436    svbool_t pred_s32 = svptrue_pat_b32(SV_VL4);
1437
1438    switch (vector_length) {
1439        case 128:
1440            for (int i = 0; i < nb; ++i) {
1441                const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
1442                svfloat32_t d_broad = svdup_n_f32((float32_t)d);
1443                const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);
1444                svfloat32_t dmin_broad = svdup_n_f32((float32_t)dmin);
1445
1446                const uint8_t * GGML_RESTRICT q2 = x[i].qs;
1447                const int8_t  * GGML_RESTRICT q8_sv = y[i].qs;
1448                const uint8_t * GGML_RESTRICT sc = x[i].scales;
1449
1450                svuint32_t mins_and_scales_sve = svld1ub_u32(svptrue_b32(), sc);
1451                const svint32_t mins_sv_1 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_b32(), mins_and_scales_sve, 4));
1452
1453                mins_and_scales_sve = svld1ub_u32(svptrue_b32(), sc+4);
1454                const svint32_t mins_sv_2 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_b32(), mins_and_scales_sve, 4));
1455
1456                svint32_t q8sums_sv_1 = svld1sh_s32(svptrue_b32(), y[i].bsums);
1457                svint32_t q8sums_sv_2 = svld1sh_s32(svptrue_b32(), y[i].bsums+4);
1458
1459                const svint32_t s0 = svadd_s32_x(svptrue_b32(), svmul_s32_x(svptrue_b32(), mins_sv_1, q8sums_sv_1), svmul_s32_x(svptrue_b32(), mins_sv_2, q8sums_sv_2));
1460
1461                mins_and_scales_sve = svld1ub_u32(svptrue_b32(), sc+8);
1462                const svint32_t mins_sv_3 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_b32(), mins_and_scales_sve, 4));
1463
1464                mins_and_scales_sve = svld1ub_u32(svptrue_b32(), sc+12);
1465                const svint32_t mins_sv_4 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_b32(), mins_and_scales_sve, 4));
1466
1467                q8sums_sv_1 = svld1sh_s32(svptrue_b32(), y[i].bsums+8);
1468                q8sums_sv_2 = svld1sh_s32(svptrue_b32(), y[i].bsums+12);
1469
1470                svint32_t s1 = svadd_s32_x(svptrue_b32(), svmul_s32_x(svptrue_b32(), mins_sv_3, q8sums_sv_1), svmul_s32_x(svptrue_b32(), mins_sv_4, q8sums_sv_2));
1471
1472                svfloat32_t temp = svcvt_f32_s32_x(svptrue_b32(), svadd_s32_x(svptrue_b32(), s0, s1));
1473
1474                acc_sum = svmla_f32_m(svptrue_b32(), acc_sum, temp, dmin_broad);
1475
1476                svint32_t sumi1 = svdup_n_s32(0);
1477
1478                {
1479                    const svuint8_t q2bits_1 = svld1_u8(svptrue_b8(), q2);
1480                    svint8_t q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), q2bits_1, m3s));
1481                    svint8_t q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
1482                    const svint32_t scales_sv = svreinterpret_s32_u32(svand_u32_m(svptrue_b32(), svld1ub_u32(svptrue_b32(), sc), m4s));
1483
1484                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv, 0));
1485
1486                    const svuint8_t q2bits_3 = svld1_u8(svptrue_b8(), q2+16);
1487                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), q2bits_3, m3s));
1488                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
1489
1490                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv, 1));
1491
1492                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_1, 2), m3s));
1493                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
1494
1495                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv, 2));
1496
1497                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_3, 2), m3s));
1498                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
1499
1500                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv, 3));
1501
1502
1503                    const svint32_t scales_sv_1 = svreinterpret_s32_u32(svand_u32_m(svptrue_b32(), svld1ub_u32(svptrue_b32(), sc+4), m4s));
1504
1505                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_1, 4), m3s));
1506                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
1507
1508                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_1, 0));
1509
1510                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_3, 4), m3s));
1511                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
1512
1513                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_1, 1));
1514
1515                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_1, 6), m3s));
1516                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
1517
1518                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_1, 2));
1519
1520                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_3, 6), m3s));
1521                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
1522
1523                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_1, 3));
1524
1525                    //-------------------------------
1526
1527                    q2 += 32;
1528                    const svint32_t scales_sv_2 = svreinterpret_s32_u32(svand_u32_m(svptrue_b32(), svld1ub_u32(svptrue_b32(), sc+8), m4s));
1529                    const svuint8_t q2bits_2 = svld1_u8(svptrue_b8(), q2);
1530
1531                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), q2bits_2, m3s));
1532                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
1533
1534                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_2, 0));
1535
1536                    const svuint8_t q2bits_4 = svld1_u8(svptrue_b8(), q2+16);
1537                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), q2bits_4, m3s));
1538                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
1539
1540                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_2, 1));
1541
1542
1543                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_2, 2), m3s));
1544                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
1545
1546                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_2, 2));
1547
1548                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_4, 2), m3s));
1549                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
1550
1551                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_2, 3));
1552
1553
1554                    const svint32_t scales_sv_3 = svreinterpret_s32_u32(svand_u32_m(svptrue_b32(), svld1ub_u32(svptrue_b32(), sc+12), m4s));
1555
1556                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_2, 4), m3s));
1557                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
1558
1559                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_3, 0));
1560
1561                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_4, 4), m3s));
1562                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
1563
1564                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_3, 1));
1565
1566
1567
1568                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_2, 6), m3s));
1569                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
1570
1571                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_3, 2));
1572
1573                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_4, 6), m3s));
1574                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
1575
1576                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_3, 3));
1577                }
1578                acc_sum = svmla_f32_m(svptrue_b32(), acc_sum, svcvt_f32_s32_x(svptrue_b32(), sumi1), d_broad);
1579            }
1580            *s = svaddv_f32(svptrue_b32(), acc_sum);
1581            break;
1582
1583        case 256:
1584        case 512:
1585            for (int i = 0; i < nb; ++i) {
1586                const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
1587                svfloat32_t d_broad = svdup_n_f32((float32_t)d);
1588                const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);
1589                svfloat32_t dmin_broad = svdup_n_f32((float32_t)dmin);
1590
1591                const uint8_t * GGML_RESTRICT q2 = x[i].qs;
1592                const int8_t  * GGML_RESTRICT q8_sv = y[i].qs;
1593                const uint8_t * GGML_RESTRICT sc = x[i].scales;
1594
1595                const svuint32_t mins_and_scales_sve = svld1ub_u32(svptrue_pat_b32(SV_VL8), sc); sc += 8;
1596                const svint32_t scales_sv = svreinterpret_s32_u32(svand_u32_m(svptrue_pat_b32(SV_VL8), mins_and_scales_sve, m4s));
1597                const svint32_t mins_sv_1 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_pat_b32(SV_VL8), mins_and_scales_sve, 4));
1598                svint32_t q8sums_sv_1 = svld1sh_s32(svptrue_pat_b32(SV_VL8), y[i].bsums);
1599
1600                const svuint32_t mins_and_scales_sve_1 = svld1ub_u32(svptrue_pat_b32(SV_VL8), sc);
1601                const svint32_t scales_sv_1 = svreinterpret_s32_u32(svand_u32_m(svptrue_pat_b32(SV_VL8), mins_and_scales_sve_1, m4s));
1602                const svint32_t mins_sv_2 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_pat_b32(SV_VL8), mins_and_scales_sve_1, 4));
1603
1604                svint32_t q8sums_sv_2 = svld1sh_s32(svptrue_pat_b32(SV_VL8), y[i].bsums+8);
1605
1606                svfloat32_t temp = svcvt_f32_s32_x(svptrue_pat_b32(SV_VL8), svadd_s32_x(svptrue_pat_b32(SV_VL8), svmul_s32_x(svptrue_pat_b32(SV_VL8), mins_sv_1, q8sums_sv_1), svmul_s32_x(svptrue_pat_b32(SV_VL8), mins_sv_2, q8sums_sv_2)));
1607
1608                acc_sum = svmla_f32_m(svptrue_pat_b32(SV_VL8), acc_sum, temp, dmin_broad);
1609
1610                svint32_t sumi1 = svdup_n_s32(0);
1611
1612                {
1613                    const svuint8_t q2bits_1 = svld1_u8(svptrue_pat_b8(SV_VL32), q2);
1614                    svint8_t q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), q2bits_1, m3s));
1615                    svint8_t q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
1616
1617                    svint32_t scale_1 = svsel(pred_s32, svdup_lane_s32(scales_sv, 0), svdup_lane_s32(scales_sv, 1));
1618                    sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_1);
1619
1620                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_1, 2), m3s));
1621                    q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
1622
1623                    svint32_t scale_2 = svsel(pred_s32, svdup_lane_s32(scales_sv, 2), svdup_lane_s32(scales_sv, 3));
1624                    sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(svdup_n_s32(0), q2bytes_sv, q8bytes_sv), scale_2);
1625
1626                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_1, 4), m3s));
1627                    q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
1628
1629                    scale_1 = svsel(pred_s32, svdup_lane_s32(scales_sv, 4), svdup_lane_s32(scales_sv, 5));
1630                    sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_1);
1631
1632                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_1, 6), m3s));
1633                    q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
1634
1635                    scale_2 = svsel(pred_s32, svdup_lane_s32(scales_sv, 6), svdup_lane_s32(scales_sv, 7));
1636                    sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_2);
1637
1638                    q2 += 32;
1639
1640                    const svuint8_t q2bits_2 = svld1_u8(svptrue_pat_b8(SV_VL32), q2);
1641                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), q2bits_2, m3s));
1642                    q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
1643
1644                    scale_1 = svsel(pred_s32, svdup_lane_s32(scales_sv_1, 0), svdup_lane_s32(scales_sv_1, 1));
1645                    sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_1);
1646
1647                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_2, 2), m3s));
1648                    q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
1649
1650                    scale_2 = svsel(pred_s32, svdup_lane_s32(scales_sv_1, 2), svdup_lane_s32(scales_sv_1, 3));
1651                    sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_2);
1652
1653                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_2, 4), m3s));
1654                    q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
1655
1656                    scale_1 = svsel(pred_s32, svdup_lane_s32(scales_sv_1, 4), svdup_lane_s32(scales_sv_1, 5));
1657                    sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_1);
1658
1659                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_2, 6), m3s));
1660                    q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
1661
1662                    scale_2 = svsel(pred_s32, svdup_lane_s32(scales_sv_1, 6), svdup_lane_s32(scales_sv_1, 7));
1663                    sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_2);
1664                }
1665                acc_sum = svmla_f32_m(svptrue_pat_b32(SV_VL8), acc_sum, svcvt_f32_s32_x(svptrue_pat_b32(SV_VL8), sumi1), d_broad);
1666            }
1667            *s = svaddv_f32(svptrue_pat_b32(SV_VL8), acc_sum);
1668            break;
1669
1670        default:
1671            assert(false && "Unsupported vector length");
1672            break;
1673    }
1674
1675#elif __ARM_NEON
1676    const uint8x16_t m3 = vdupq_n_u8(0x3);
1677    const uint8x16_t m4 = vdupq_n_u8(0xF);
1678
1679    const int32x4_t vzero = vdupq_n_s32(0);
1680
1681    ggml_int8x16x2_t q2bytes;
1682    uint8_t aux[16];
1683
1684    float sum = 0;
1685
1686    for (int i = 0; i < nb; ++i) {
1687        const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
1688        const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);
1689
1690        const uint8_t * GGML_RESTRICT q2 = x[i].qs;
1691        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
1692        const uint8_t * GGML_RESTRICT sc = x[i].scales;
1693
1694        const uint8x16_t mins_and_scales = vld1q_u8(sc);
1695        const uint8x16_t scales = vandq_u8(mins_and_scales, m4);
1696        vst1q_u8(aux, scales);
1697
1698        const uint8x16_t mins = vshrq_n_u8(mins_and_scales, 4);
1699        const ggml_int16x8x2_t q8sums = ggml_vld1q_s16_x2(y[i].bsums);
1700        const ggml_int16x8x2_t mins16 = {{vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))), vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins)))}};
1701        const int32x4_t s0 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[0]), vget_low_s16 (q8sums.val[0])),
1702                                       vmull_s16(vget_high_s16(mins16.val[0]), vget_high_s16(q8sums.val[0])));
1703        const int32x4_t s1 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[1]), vget_low_s16 (q8sums.val[1])),
1704                                       vmull_s16(vget_high_s16(mins16.val[1]), vget_high_s16(q8sums.val[1])));
1705        sum += dmin * vaddvq_s32(vaddq_s32(s0, s1));
1706
1707        int isum = 0;
1708        int is = 0;
1709
1710// We use this macro instead of a function call because for some reason
1711// the code runs 2-3% slower, even if the function is declared inline
1712#define MULTIPLY_ACCUM_WITH_SCALE(index)\
1713        isum += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * aux[is+(index)];\
1714        isum += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * aux[is+1+(index)];
1715
1716#define SHIFT_MULTIPLY_ACCUM_WITH_SCALE(shift, index)\
1717        q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;\
1718        q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[0], (shift)), m3));\
1719        q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[1], (shift)), m3));\
1720        MULTIPLY_ACCUM_WITH_SCALE((index));
1721
1722        for (int j = 0; j < QK_K/128; ++j) {
1723            const ggml_uint8x16x2_t q2bits = ggml_vld1q_u8_x2(q2); q2 += 32;
1724
1725            ggml_int8x16x2_t q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
1726            q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[0], m3));
1727            q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[1], m3));
1728
1729            MULTIPLY_ACCUM_WITH_SCALE(0);
1730
1731            SHIFT_MULTIPLY_ACCUM_WITH_SCALE(2, 2);
1732            SHIFT_MULTIPLY_ACCUM_WITH_SCALE(4, 4);
1733            SHIFT_MULTIPLY_ACCUM_WITH_SCALE(6, 6);
1734
1735            is += 8;
1736        }
1737
1738        sum += d * isum;
1739    }
1740
1741    *s = sum;
1742
1743#else
1744    UNUSED(x);
1745    UNUSED(y);
1746    UNUSED(nb);
1747    ggml_vec_dot_q2_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
1748#endif
1749}
1750
1751void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
1752    assert(n % QK_K == 0);
1753    assert(nrc == 1);
1754    UNUSED(nrc);
1755    UNUSED(bx);
1756    UNUSED(by);
1757    UNUSED(bs);
1758
1759    const uint32_t kmask1 = 0x03030303;
1760    const uint32_t kmask2 = 0x0f0f0f0f;
1761
1762    const block_q3_K * GGML_RESTRICT x = vx;
1763    const block_q8_K * GGML_RESTRICT y = vy;
1764
1765    const int nb = n / QK_K;
1766
1767#if defined(__ARM_FEATURE_SVE)
1768
1769    uint32_t aux[3];
1770    uint32_t utmp[4];
1771
1772    const int8_t m32 = 32;
1773    const int vector_length = svcntb()*8;
1774    const svuint8_t m3b_sv = svdup_n_u8(0x3);
1775    const svint32_t vzero_sv = svdup_n_s32(0);
1776
1777    const svuint8_t m0_sv = svdup_n_u8(1);
1778    const svuint8_t m1_sv = svlsl_n_u8_x(svptrue_b8(), m0_sv, 1);
1779    const svuint8_t m2_sv = svlsl_n_u8_x(svptrue_b8(), m0_sv, 2);
1780    const svuint8_t m3_sv = svlsl_n_u8_x(svptrue_b8(), m0_sv, 3);
1781
1782    float sum = 0;
1783
1784    for (int i = 0; i < nb; ++i) {
1785
1786        const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
1787
1788        const uint8_t * GGML_RESTRICT q3_sv = x[i].qs;
1789        const uint8_t * GGML_RESTRICT qh_sv = x[i].hmask;
1790        const int8_t  * GGML_RESTRICT q8_sv = y[i].qs;
1791
1792        // Set up scales
1793        memcpy(aux, x[i].scales, 12);
1794        utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
1795        utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
1796        utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
1797        utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);
1798
1799        int8_t * scale = (int8_t *)utmp;
1800
1801        for (int j = 0; j < 16; ++j) scale[j] -= m32;
1802
1803        switch (vector_length) {
1804            case 128:
1805                {
1806                    svuint8_t qhbits_sv_1 = svld1_u8(svptrue_b8(), qh_sv);
1807                    svuint8_t qhbits_sv_2 = svld1_u8(svptrue_b8(), qh_sv+16);
1808                    svuint8_t q3h_sv;
1809
1810                    svint32_t sumi1_1 = svdup_n_s32(0);
1811                    svint8_t q3bytes_sv;
1812
1813                    for (int j = 0; j < QK_K/128; ++j) {
1814
1815                        const svuint8_t q3bits_sv = svld1_u8(svptrue_b8(), q3_sv); q3_sv += 16;
1816                        const svuint8_t q3bits_sv_1 = svld1_u8(svptrue_b8(), q3_sv); q3_sv += 16;
1817                        svint8_t q8bytes_1_sv_1 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
1818                        svint8_t q8bytes_1_sv_2 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
1819
1820                        q3h_sv = svlsl_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m0_sv, qhbits_sv_1), 2);
1821                        q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), q3bits_sv, m3b_sv)), svreinterpret_s8_u8(q3h_sv));
1822
1823                        sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), svdup_n_s32((int32_t)scale[0]));
1824
1825                        q3h_sv = svlsl_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m0_sv, qhbits_sv_2), 2);
1826                        q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), q3bits_sv_1, m3b_sv)), svreinterpret_s8_u8(q3h_sv));
1827
1828                        sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), svdup_n_s32((int32_t)scale[1]));
1829
1830                        q8bytes_1_sv_1 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
1831                        q8bytes_1_sv_2 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
1832
1833                        q3h_sv = svlsl_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m1_sv, qhbits_sv_1), 1);
1834                        q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv, 2), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
1835
1836                        sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), svdup_n_s32((int32_t)scale[2]));
1837
1838                        q3h_sv = svlsl_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m1_sv, qhbits_sv_2), 1);
1839                        q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv_1, 2), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
1840
1841                        sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), svdup_n_s32((int32_t)scale[3]));
1842
1843
1844                        scale += 4;
1845                        q8bytes_1_sv_1 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
1846                        q8bytes_1_sv_2 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
1847
1848                        q3h_sv = svbic_u8_x(svptrue_b8(), m2_sv, qhbits_sv_1);
1849                        q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv, 4), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
1850
1851                        sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), svdup_n_s32((int32_t)scale[0]));
1852
1853                        q3h_sv = svbic_u8_x(svptrue_b8(), m2_sv, qhbits_sv_2);
1854                        q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv_1, 4), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
1855
1856                        sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), svdup_n_s32((int32_t)scale[1]));
1857
1858
1859                        q8bytes_1_sv_1 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
1860                        q8bytes_1_sv_2 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
1861
1862                        q3h_sv = svlsr_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m3_sv, qhbits_sv_1), 1);
1863                        q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv, 6), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
1864
1865                        sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), svdup_n_s32((int32_t)scale[2]));
1866
1867                        q3h_sv = svlsr_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m3_sv, qhbits_sv_2), 1);
1868                        q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv_1, 6), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
1869
1870                        sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), svdup_n_s32((int32_t)scale[3]));
1871
1872                        if (j == 0) {
1873                            qhbits_sv_1 = svlsr_n_u8_x(svptrue_b8(), qhbits_sv_1, 4);
1874                            qhbits_sv_2 = svlsr_n_u8_x(svptrue_b8(), qhbits_sv_2, 4);
1875                        }
1876
1877                        scale += 4;
1878                    }
1879
1880                    sum += d * (svaddv_s32(svptrue_b32(), sumi1_1));
1881                } break;
1882            case 256:
1883            case 512:
1884                {
1885                    svuint8_t qhbits_sv = svld1_u8(svptrue_pat_b8(SV_VL32), qh_sv);
1886                    svuint8_t q3h_sv;
1887
1888                    svint32_t sumi1_1 = svdup_n_s32(0);
1889                    svint8_t q3bytes_sv;
1890
1891                    for (int j = 0; j < QK_K/128; ++j) {
1892
1893                        const svuint8_t q3bits_sv = svld1_u8(svptrue_pat_b8(SV_VL32), q3_sv); q3_sv += 32;
1894                        svint8_t q8bytes_1_sv_1 = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
1895                        svint8_t q8bytes_1_sv_2 = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
1896
1897                        q3h_sv = svlsl_n_u8_x(svptrue_pat_b8(SV_VL32), svbic_u8_x(svptrue_pat_b8(SV_VL32), m0_sv, qhbits_sv), 2);
1898                        q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32), svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), q3bits_sv, m3b_sv)), svreinterpret_s8_u8(q3h_sv));
1899
1900
1901                        svint32_t scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4), svdup_n_s32((int32_t)scale[0]), svdup_n_s32((int32_t)scale[1]));
1902                        sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), scale_1);
1903
1904                        q3h_sv = svlsl_n_u8_x(svptrue_pat_b8(SV_VL32), svbic_u8_x(svptrue_pat_b8(SV_VL32), m1_sv, qhbits_sv), 1);
1905                        q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32), svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q3bits_sv, 2), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
1906
1907                        scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4), svdup_n_s32((int32_t)scale[2]), svdup_n_s32((int32_t)scale[3]));
1908                        sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), scale_1);
1909
1910                        scale += 4;
1911                        q8bytes_1_sv_1 = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
1912                        q8bytes_1_sv_2 = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
1913
1914                        q3h_sv = svbic_u8_x(svptrue_pat_b8(SV_VL32), m2_sv, qhbits_sv);
1915                        q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32), svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q3bits_sv, 4), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
1916
1917                        scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4), svdup_n_s32((int32_t)scale[0]), svdup_n_s32((int32_t)scale[1]));
1918                        sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), scale_1);
1919
1920                        q3h_sv = svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), svbic_u8_x(svptrue_pat_b8(SV_VL32), m3_sv, qhbits_sv), 1);
1921                        q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32), svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q3bits_sv, 6), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
1922
1923                        scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4), svdup_n_s32((int32_t)scale[2]), svdup_n_s32((int32_t)scale[3]));
1924                        sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), scale_1);
1925
1926                        if (j == 0) {
1927                            qhbits_sv = svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), qhbits_sv, 4);
1928                        }
1929
1930                        scale += 4;
1931                    }
1932
1933                    sum += d * (svaddv_s32(svptrue_pat_b32(SV_VL8), sumi1_1));
1934                } break;
1935            default:
1936                assert(false && "Unsupported vector length");
1937                break;
1938        }
1939    }
1940    *s = sum;
1941
1942#elif __ARM_NEON
1943
1944    uint32_t aux[3];
1945    uint32_t utmp[4];
1946
1947    const uint8x16_t m3b = vdupq_n_u8(0x3);
1948    const int32x4_t  vzero = vdupq_n_s32(0);
1949
1950    const uint8x16_t m0 = vdupq_n_u8(1);
1951    const uint8x16_t m1 = vshlq_n_u8(m0, 1);
1952    const uint8x16_t m2 = vshlq_n_u8(m0, 2);
1953    const uint8x16_t m3 = vshlq_n_u8(m0, 3);
1954    const int8_t m32 = 32;
1955
1956    ggml_int8x16x4_t q3bytes;
1957
1958    float sum = 0;
1959
1960    for (int i = 0; i < nb; ++i) {
1961
1962        const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
1963
1964        const uint8_t * GGML_RESTRICT q3 = x[i].qs;
1965        const uint8_t * GGML_RESTRICT qh = x[i].hmask;
1966        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
1967
1968        ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh);
1969
1970        ggml_uint8x16x4_t q3h;
1971
1972        int32_t isum = 0;
1973
1974        // Set up scales
1975        memcpy(aux, x[i].scales, 12);
1976        utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
1977        utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
1978        utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
1979        utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);
1980
1981        int8_t * scale = (int8_t *)utmp;
1982        for (int j = 0; j < 16; ++j) scale[j] -= m32;
1983
1984        for (int j = 0; j < QK_K/128; ++j) {
1985
1986            const ggml_uint8x16x2_t q3bits = ggml_vld1q_u8_x2(q3); q3 += 32;
1987            const ggml_int8x16x4_t q8bytes_1 = ggml_vld1q_s8_x4(q8); q8 += 64;
1988            const ggml_int8x16x4_t q8bytes_2 = ggml_vld1q_s8_x4(q8); q8 += 64;
1989
1990            q3h.val[0] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[0]), 2);
1991            q3h.val[1] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[1]), 2);
1992            q3h.val[2] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[0]), 1);
1993            q3h.val[3] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[1]), 1);
1994
1995            q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[0], m3b)), vreinterpretq_s8_u8(q3h.val[0]));
1996            q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[1], m3b)), vreinterpretq_s8_u8(q3h.val[1]));
1997            q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 2), m3b)), vreinterpretq_s8_u8(q3h.val[2]));
1998            q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 2), m3b)), vreinterpretq_s8_u8(q3h.val[3]));
1999
2000            isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[0], q8bytes_1.val[0])) * scale[0];
2001            isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[1], q8bytes_1.val[1])) * scale[1];
2002            isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[2], q8bytes_1.val[2])) * scale[2];
2003            isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[3], q8bytes_1.val[3])) * scale[3];
2004
2005            scale += 4;
2006
2007            q3h.val[0] = vbicq_u8(m2, qhbits.val[0]);
2008            q3h.val[1] = vbicq_u8(m2, qhbits.val[1]);
2009            q3h.val[2] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[0]), 1);
2010            q3h.val[3] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[1]), 1);
2011
2012            q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 4), m3b)), vreinterpretq_s8_u8(q3h.val[0]));
2013            q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 4), m3b)), vreinterpretq_s8_u8(q3h.val[1]));
2014            q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 6), m3b)), vreinterpretq_s8_u8(q3h.val[2]));
2015            q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 6), m3b)), vreinterpretq_s8_u8(q3h.val[3]));
2016
2017            isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[0], q8bytes_2.val[0])) * scale[0];
2018            isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[1], q8bytes_2.val[1])) * scale[1];
2019            isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[2], q8bytes_2.val[2])) * scale[2];
2020            isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[3], q8bytes_2.val[3])) * scale[3];
2021
2022            scale += 4;
2023
2024            if (j == 0) {
2025                qhbits.val[0] = vshrq_n_u8(qhbits.val[0], 4);
2026                qhbits.val[1] = vshrq_n_u8(qhbits.val[1], 4);
2027            }
2028
2029        }
2030        sum += d * isum;
2031
2032    }
2033
2034    *s = sum;
2035
2036#else
2037    UNUSED(kmask1);
2038    UNUSED(kmask2);
2039    UNUSED(x);
2040    UNUSED(y);
2041    UNUSED(nb);
2042    ggml_vec_dot_q3_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
2043#endif
2044
2045}
2046
#ifdef __ARM_FEATURE_SVE
// Decode the 12 packed scale/min bytes of a q4_K block into four 32-bit lanes,
// each holding four 6-bit values stored one per byte:
//   lane 0: scales 0..3   lane 1: scales 4..7
//   lane 2: mins   0..3   lane 3: mins   4..7
// This is the vector form of the scalar decode (aux[] = the 12 bytes as 3 words):
//   utmp[0] =  aux[0] & kmask1
//   utmp[1] = (aux[2] & kmask2)        | (((aux[0] >> 6) & kmask3) << 4)
//   utmp[2] =  aux[1] & kmask1
//   utmp[3] = ((aux[2] >> 4) & kmask2) | (((aux[1] >> 6) & kmask3) << 4)
// Lanes at index >= 4 (on vectors wider than 128 bits) are zeroed.
//
// vx_scales: the block's 12 scale bytes. Callers obtain this pointer by casting a
// uint8_t array, so it may not be 4-byte aligned. The bytes are therefore copied
// out with memcpy instead of being dereferenced as uint32_t.
static inline svuint32_t ggml_decode_q4scales_and_mins_for_mmla(const uint32_t * vx_scales) {
    const svbool_t pg_all   = svptrue_pat_b32(SV_VL4);         // 32-bit lanes 0..3
    const svbool_t pg_false = svpfalse_b();                    // 0x0000
    const svbool_t pg_lo_8  = svwhilelt_b8_s32(0,  8);         // 0x00ff
    const svbool_t pg_odd   = svzip1_b32(pg_false, pg_lo_8);   // 32-bit lanes 1 and 3

    // Bytewise copy: avoids a type-punned, possibly misaligned uint32_t load
    // (same approach as memcpy(aux, x[i].scales, 12) in the NEON kernels).
    uint32_t aux[3];
    memcpy(aux, vx_scales, sizeof(aux));

    // High part: lanes {aux0, aux0 >> 2, aux1, aux1 >> 2}. Each 64-bit pair is then
    // masked with 0x3f3f3f3f (even lane: low 6 bits per byte) and 0x30303030
    // (odd lane: bits 6..7 of each byte moved down to bits 4..5).
    svuint32_t vx01 = svld1_u32(pg_lo_8, aux);   // lanes 0,1 = aux[0], aux[1]; others 0
    svuint32_t vutmp_hi = svzip1_u32(vx01, vx01);
    vutmp_hi = svlsr_n_u32_m(pg_odd, vutmp_hi, 2);
    vutmp_hi = svreinterpret_u32_u64(svand_n_u64_x(pg_all, svreinterpret_u64_u32(vutmp_hi), UINT64_C(0x303030303f3f3f3f)));

    // Low part: lane 1 = aux2 >> 0, lane 3 = aux2 >> 4, both masked to low nibbles.
    // The shift vector is {-2, 0, 2, 4} as u32. The even lanes' results are
    // discarded by the zeroing AND below.
    const svuint32_t vx2 = svdup_u32(aux[2]);
    svuint32_t vutmp_lo = svlsr_u32_x(pg_all, vx2, svreinterpret_u32_s32(svindex_s32(-2, 2)));
    vutmp_lo = svand_n_u32_z(pg_odd, vutmp_lo, UINT32_C(0x0f0f0f0f));

    // Combine; svorr_u32_z zeroes any lane outside 0..3.
    return svorr_u32_z(pg_all, vutmp_hi, vutmp_lo);
}
#endif
2066
2067void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
2068    assert(n % QK_K == 0);
2069#ifdef __ARM_FEATURE_MATMUL_INT8
2070    assert((nrc == 2) || (nrc == 1));
2071#else
2072    assert(nrc == 1);
2073#endif
2074    UNUSED(nrc);
2075    UNUSED(bx);
2076    UNUSED(by);
2077    UNUSED(bs);
2078
2079    const block_q4_K * GGML_RESTRICT x = vx;
2080    const block_q8_K * GGML_RESTRICT y = vy;
2081
2082    const int nb = n / QK_K;
2083
2084    static const uint32_t kmask1 = 0x3f3f3f3f;
2085    static const uint32_t kmask2 = 0x0f0f0f0f;
2086    static const uint32_t kmask3 = 0x03030303;
2087
2088    uint32_t utmp[4];
2089#ifdef __ARM_FEATURE_SVE
2090    const int vector_length = ggml_cpu_get_sve_cnt()*8;
2091#endif
2092
2093#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
2094    if (nrc == 2) {
2095        svbool_t pg32_2 = svptrue_pat_b32(SV_VL2);
2096
2097        const block_q4_K * GGML_RESTRICT vx0 = vx;
2098        const block_q8_K * GGML_RESTRICT vy0 = vy;
2099        const block_q4_K * GGML_RESTRICT vx1 = (const block_q4_K *) ((const uint8_t*)vx + bx);
2100        const block_q8_K * GGML_RESTRICT vy1 = (const block_q8_K *) ((const uint8_t*)vy + by);
2101
2102        union {
2103            uint32_t u32[8];
2104            uint64_t u64[4];
2105        } new_utmp;
2106
2107        svfloat32_t sumf1 = svdup_n_f32(0);
2108
2109        switch (vector_length) {
2110            case 128:
2111                {
2112                    svbool_t pg_false = svpfalse_b();
2113                    svbool_t pg_lo_8  = svwhilelt_b8_s32(0,  8);
2114                    svbool_t vmins_mask1= svzip1_b32(pg_lo_8, pg_false);
2115                    svbool_t vmins_mask2 = svzip1_b32(pg_false, pg_lo_8);
2116                    svbool_t pg128_all  = svptrue_pat_b8(SV_VL16);
2117                    for (int i = 0; i < nb; ++i) {
2118                        svfloat32_t vy_d = svuzp1_f32(svdup_n_f32(vy0[i].d), svdup_n_f32(vy1[i].d));
2119                        svfloat32_t vx_d = svzip1_f32(svdup_n_f32(GGML_FP16_TO_FP32(vx0[i].d)), svdup_n_f32(GGML_FP16_TO_FP32(vx1[i].d)));
2120                        svfloat32_t svsuper_block_scales = svmul_f32_x(pg128_all, vy_d, vx_d);
2121                        svfloat32_t vx_dmins = svzip1_f32(svdup_n_f32(GGML_FP16_TO_FP32(vx0[i].dmin)), svdup_n_f32(GGML_FP16_TO_FP32(vx1[i].dmin)));
2122                        svfloat32_t vy_dmins = svuzp1_f32(svdup_n_f32(vy0[i].d), svdup_n_f32(vy1[i].d));
2123                        svfloat32_t svdmins = svmul_n_f32_x(pg128_all, svmul_f32_x(pg128_all, vy_dmins, vx_dmins), -1);
2124                        const uint8_t * GGML_RESTRICT q4_0 = vx0[i].qs;
2125                        const int8_t  * GGML_RESTRICT q8_0 = vy0[i].qs;
2126                        const uint8_t * GGML_RESTRICT q4_1 = vx1[i].qs;
2127                        const int8_t  * GGML_RESTRICT q8_1 = vy1[i].qs;
2128                        svint16_t lo = svld1_s16(pg128_all, vy0[i].bsums + 0);
2129                        svint16_t hi = svld1_s16(pg128_all, vy0[i].bsums + 8);
2130                        svint16_t sum_tmp1 = svuzp1_s16(lo, hi);
2131                        svint16_t sum_tmp2 = svuzp2_s16(lo, hi);
2132                        svint16_t svq8sums_0 = svadd_s16_x(pg128_all, sum_tmp1, sum_tmp2);
2133                        lo = svld1_s16(pg128_all, vy1[i].bsums + 0);
2134                        hi = svld1_s16(pg128_all, vy1[i].bsums + 8);
2135                        sum_tmp1 = svuzp1(lo, hi);
2136                        sum_tmp2 = svuzp2(lo, hi);
2137                        svint16_t svq8sums_1 = svadd_s16_x(pg128_all, sum_tmp1, sum_tmp2);
2138                        svuint32_t decoded_scales0 = ggml_decode_q4scales_and_mins_for_mmla((const uint32_t *)vx0[i].scales);
2139                        svuint32_t decoded_scales1 = ggml_decode_q4scales_and_mins_for_mmla((const uint32_t *)vx1[i].scales);
2140                        svuint32x2_t decoded_scales = svcreate2_u32(decoded_scales0, decoded_scales1);
2141                        svst2_u32(pg128_all, new_utmp.u32, decoded_scales);
2142                        svint16_t svmins8_0 = svreinterpret_s16_u16(svunpklo_u16(svreinterpret_u8_u32(svuzp1_u32(svld1_u32(vmins_mask1, new_utmp.u32+4), svdup_n_u32(0)))));
2143                        svint16_t svmins8_1 = svreinterpret_s16_u16(svunpklo_u16(svreinterpret_u8_u32(svuzp2_u32(svld1_u32(vmins_mask2, new_utmp.u32+4), svdup_n_u32(0)))));
2144                        svint32_t svsumfs_tmp1 = svreinterpret_s32_s64(svdot_s64(svdup_n_s64(0), svq8sums_0, svmins8_0));
2145                        svint32_t svsumfs_tmp2 = svreinterpret_s32_s64(svdot_s64(svdup_n_s64(0), svq8sums_0, svmins8_1));
2146                        svint32_t svsumfs_tmp3 = svtrn1_s32(svsumfs_tmp1, svsumfs_tmp2);
2147                        svint32_t svsumfs_tmp4 = svreinterpret_s32_s64(svdot_s64(svdup_n_s64(0), svq8sums_1, svmins8_0));
2148                        svint32_t svsumfs_tmp5 = svreinterpret_s32_s64(svdot_s64(svdup_n_s64(0), svq8sums_1, svmins8_1));
2149                        svint32_t svsumfs_tmp6 = svtrn1_s32(svsumfs_tmp4, svsumfs_tmp5);
2150                        svint32_t svsumfs_tmp7 = svreinterpret_s32_s64(svtrn2_s64(svreinterpret_s64_s32(svsumfs_tmp3), svreinterpret_s64_s32(svsumfs_tmp6)));
2151                        svint32_t svsumfs_tmp8 = svreinterpret_s32_s64(svtrn1_s64(svreinterpret_s64_s32(svsumfs_tmp3), svreinterpret_s64_s32(svsumfs_tmp6)));
2152                        svint32_t svsumfs_tmp = svadd_s32_x(pg128_all, svsumfs_tmp7, svsumfs_tmp8);
2153                        svint32_t svscales, sumi1, sumi2;
2154                        svint32_t acc_sumif1 = svdup_n_s32(0);
2155                        svint32_t acc_sumif2 = svdup_n_s32(0);
2156                        svint8_t q4bytes_0_l, q4bytes_0_h, q4bytes_1_l, q4bytes_1_h, l0, l1, l2, l3,
2157                                 q8bytes_0_h, q8bytes_0_l, q8bytes_1_h, q8bytes_1_l, r0, r1, r2, r3;
2158#pragma GCC unroll 1
2159                        for (int j = 0; j < QK_K/64; ++j) {
2160                            q4bytes_0_l = svreinterpret_s8_u8(svand_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_0), 0xf));
2161                            q4bytes_1_l = svreinterpret_s8_u8(svand_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_1), 0xf));
2162                            q4bytes_0_h = svreinterpret_s8_u8(svand_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_0+16), 0xf));
2163                            q4bytes_1_h = svreinterpret_s8_u8(svand_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_1+16), 0xf));
2164                            l0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q4bytes_0_l), svreinterpret_s64_s8(q4bytes_1_l)));
2165                            l1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q4bytes_0_l), svreinterpret_s64_s8(q4bytes_1_l)));
2166                            l2 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q4bytes_0_h), svreinterpret_s64_s8(q4bytes_1_h)));
2167                            l3 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q4bytes_0_h), svreinterpret_s64_s8(q4bytes_1_h)));
2168                            q8bytes_0_h = svld1_s8(pg128_all, q8_0);
2169                            q8bytes_1_h = svld1_s8(pg128_all, q8_1);
2170                            q8bytes_0_l = svld1_s8(pg128_all, q8_0+16);
2171                            q8bytes_1_l = svld1_s8(pg128_all, q8_1+16);
2172                            r0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0_h), svreinterpret_s64_s8(q8bytes_1_h)));
2173                            r1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0_h), svreinterpret_s64_s8(q8bytes_1_h)));
2174                            r2 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0_l), svreinterpret_s64_s8(q8bytes_1_l)));
2175                            r3 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0_l), svreinterpret_s64_s8(q8bytes_1_l)));
2176                            sumi1 = svmmla_s32(svmmla_s32(svmmla_s32(svmmla_s32(svdup_n_s32(0), r0, l0), r1, l1), r2, l2), r3, l3);
2177                            svscales = svreinterpret_s32_u32(svlsr_n_u32_x(pg128_all, svlsl_n_u32_x(pg128_all, svreinterpret_u32_u64(svdup_n_u64(new_utmp.u64[j/2])), 8*(4-2*(j%2)-1)), 24));
2178                            acc_sumif1 = svmla_s32_x(pg128_all, acc_sumif1, svscales, sumi1);
2179
2180                            q4bytes_0_l = svreinterpret_s8_u8(svlsr_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_0), 4));
2181                            q4bytes_1_l = svreinterpret_s8_u8(svlsr_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_1), 4));
2182                            q4bytes_0_h = svreinterpret_s8_u8(svlsr_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_0+16), 4));
2183                            q4bytes_1_h = svreinterpret_s8_u8(svlsr_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_1+16), 4));
2184                            l0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q4bytes_0_l), svreinterpret_s64_s8(q4bytes_1_l)));
2185                            l1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q4bytes_0_l), svreinterpret_s64_s8(q4bytes_1_l)));
2186                            l2 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q4bytes_0_h), svreinterpret_s64_s8(q4bytes_1_h)));
2187                            l3 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q4bytes_0_h), svreinterpret_s64_s8(q4bytes_1_h)));
2188                            q8bytes_0_h = svld1_s8(pg128_all, q8_0+32);
2189                            q8bytes_1_h = svld1_s8(pg128_all, q8_1+32);
2190                            q8bytes_0_l = svld1_s8(pg128_all, q8_0+48);
2191                            q8bytes_1_l = svld1_s8(pg128_all, q8_1+48);
2192                            r0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0_h), svreinterpret_s64_s8(q8bytes_1_h)));
2193                            r1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0_h), svreinterpret_s64_s8(q8bytes_1_h)));
2194                            r2 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0_l), svreinterpret_s64_s8(q8bytes_1_l)));
2195                            r3 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0_l), svreinterpret_s64_s8(q8bytes_1_l)));
2196                            sumi2 = svmmla_s32(svmmla_s32(svmmla_s32(svmmla_s32(svdup_n_s32(0), r0, l0), r1, l1), r2, l2), r3, l3);
2197                            svscales = svreinterpret_s32_u32(svlsr_n_u32_x(pg128_all, svlsl_n_u32_x(pg128_all, svreinterpret_u32_u64(svdup_n_u64(new_utmp.u64[j/2])), 8*(4-2*(j%2)-2)), 24));
2198                            acc_sumif2 = svmla_s32_x(pg128_all, acc_sumif2, svscales, sumi2);
2199                            q4_0 += 32; q4_1 += 32; q8_0 += 64; q8_1 += 64;
2200                        }
2201                        sumf1 = svmla_f32_x(pg128_all,
2202                                svmla_f32_x(pg128_all,
2203                                    sumf1,
2204                                    svcvt_f32_x(pg128_all,
2205                                        svadd_s32_x(pg128_all, acc_sumif1, acc_sumif2)),
2206                                    svsuper_block_scales),
2207                                svdmins,
2208                                svcvt_f32_s32_x(pg128_all, svsumfs_tmp));
2209                    }  //end of for nb
2210                } // end of case 128
2211                break;
2212            case 256:
2213            case 512:
2214                {
2215                    const svbool_t pg32_4 = svptrue_pat_b32(SV_VL4);
2216                    const svbool_t pg8_16 = svptrue_pat_b8(SV_VL16);
2217                    const svbool_t pg256_all = svptrue_pat_b8(SV_ALL);
2218                    for (int i = 0; i < nb; ++i) {
2219                        const uint8_t * GGML_RESTRICT q4_0 = vx0[i].qs;
2220                        const int8_t  * GGML_RESTRICT q8_0 = vy0[i].qs;
2221                        const uint8_t * GGML_RESTRICT q4_1 = vx1[i].qs;
2222                        const int8_t  * GGML_RESTRICT q8_1 = vy1[i].qs;
2223                        svint32_t svscales, sumi1, sumi2;
2224                        svint32_t acc_sumif1 = svdup_n_s32(0);
2225                        svint32_t acc_sumif2 = svdup_n_s32(0);
2226                        svint8_t l0, l1, l2, l3, r0, r1, r2, r3;
2227                        svfloat32_t vx_d = svzip1_f32(svdup_n_f32(GGML_FP16_TO_FP32(vx0[i].d)), svdup_n_f32(GGML_FP16_TO_FP32(vx1[i].d)));
2228                        svfloat64_t vy_d_tmp = svreinterpret_f64_f32(svuzp1_f32(svdup_n_f32(vy0[i].d), svdup_n_f32(vy1[i].d)));
2229                        svfloat32_t vy_d = svreinterpret_f32_f64(svuzp1_f64(vy_d_tmp, vy_d_tmp));
2230                        svfloat32_t svsuper_block_scales = svmul_f32_z(pg32_4, vy_d, vx_d);
2231                        svfloat32_t vx_dmins = svzip1_f32(svdup_n_f32(GGML_FP16_TO_FP32(vx0[i].dmin)), svdup_n_f32(GGML_FP16_TO_FP32(vx1[i].dmin)));
2232                        svfloat64_t vy_dmins_tmp = svreinterpret_f64_f32(svuzp1_f32(svdup_n_f32(vy0[i].d), svdup_n_f32(vy1[i].d)));
2233                        svfloat32_t vy_dmins = svreinterpret_f32_f64(svuzp1_f64(vy_dmins_tmp, vy_dmins_tmp));
2234                        svfloat32_t svdmins = svmul_n_f32_x(pg32_4, svmul_f32_x(pg32_4, vx_dmins, vy_dmins), -1);
2235                        svint16_t rc1 = svuzp1_s16(svld1_s16(pg256_all, vy0[i].bsums), svld1_s16(pg256_all, vy1[i].bsums));
2236                        svint16_t rc2 = svuzp2_s16(svld1_s16(pg256_all, vy0[i].bsums), svld1_s16(pg256_all, vy1[i].bsums));
2237                        svint16_t svq8sums = svadd_s16_x(pg256_all, rc1, rc2);
2238                        svuint32_t decoded_scales0 = ggml_decode_q4scales_and_mins_for_mmla((const uint32_t *)vx0[i].scales);
2239                        svuint32_t decoded_scales1 = ggml_decode_q4scales_and_mins_for_mmla((const uint32_t *)vx1[i].scales);
2240                        svuint32x2_t decoded_scales = svcreate2_u32(decoded_scales0, decoded_scales1);
2241                        svst2_u32(pg8_16, new_utmp.u32, decoded_scales);
2242                        svint16_t new_svq8sums_0 = svreinterpret_s16_u64(svtrn1_u64(svreinterpret_u64_s16(svq8sums), svreinterpret_u64_s16(svq8sums)));
2243                        svint16_t new_svq8sums_1 = svreinterpret_s16_u64(svtrn2_u64(svreinterpret_u64_s16(svq8sums), svreinterpret_u64_s16(svq8sums)));
2244                        svuint64_t new_mins_0 = svdup_u64(new_utmp.u64[2]);
2245                        svuint64_t new_mins_1 = svdup_u64(new_utmp.u64[3]);
2246                        svint16_t new_svmins8_0 = svreinterpret_s16_u16(svunpklo_u16(svreinterpret_u8_u64(new_mins_0)));
2247                        svint16_t new_svmins8_1 = svreinterpret_s16_u16(svunpklo_u16(svreinterpret_u8_u64(new_mins_1)));
2248                        svint64_t dot_prod_0 = svdot_s64(svdup_s64(0), new_svmins8_0, new_svq8sums_0);
2249                        svint64_t dot_prod_1 = svdot_s64(dot_prod_0, new_svmins8_1, new_svq8sums_1);
2250                        svfloat32_t converted_dot_prod_1 = svcvt_f32_s64_x(pg256_all, dot_prod_1);
2251                        svfloat32_t svsumfs_tmp = svuzp1_f32(converted_dot_prod_1, converted_dot_prod_1);
2252
2253#pragma GCC unroll 1
2254                        for (int j = 0; j < QK_K/64; ++j) {
2255                            svuint8_t q4bytes_0 = svand_n_u8_x(pg256_all, svld1_u8(pg256_all, q4_0), 0xf);
2256                            svuint8_t q4bytes_1 = svand_n_u8_x(pg256_all, svld1_u8(pg256_all, q4_1), 0xf);
2257                            svuint8_t q4bytes_2 = svlsr_n_u8_x(pg256_all, svld1_u8(pg256_all, q4_0), 4);
2258                            svuint8_t q4bytes_3 = svlsr_n_u8_x(pg256_all, svld1_u8(pg256_all, q4_1), 4);
2259                            l0 = svreinterpret_s8_u64(svzip1_u64(svreinterpret_u64_u8(q4bytes_0), svreinterpret_u64_u8(q4bytes_1)));
2260                            l1 = svreinterpret_s8_u64(svzip2_u64(svreinterpret_u64_u8(q4bytes_0), svreinterpret_u64_u8(q4bytes_1)));
2261                            l2 = svreinterpret_s8_u64(svzip1_u64(svreinterpret_u64_u8(q4bytes_2), svreinterpret_u64_u8(q4bytes_3)));
2262                            l3 = svreinterpret_s8_u64(svzip2_u64(svreinterpret_u64_u8(q4bytes_2), svreinterpret_u64_u8(q4bytes_3)));
2263                            svint8_t q8bytes_0 = svld1_s8(pg256_all, q8_0);
2264                            svint8_t q8bytes_1 = svld1_s8(pg256_all, q8_1);
2265                            svint8_t q8bytes_2 = svld1_s8(pg256_all, q8_0+32);
2266                            svint8_t q8bytes_3 = svld1_s8(pg256_all, q8_1+32);
2267                            r0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0), svreinterpret_s64_s8(q8bytes_1)));
2268                            r1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0), svreinterpret_s64_s8(q8bytes_1)));
2269                            r2 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_2), svreinterpret_s64_s8(q8bytes_3)));
2270                            r3 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_2), svreinterpret_s64_s8(q8bytes_3)));
2271                            sumi1 = svmmla(svmmla(svdup_n_s32(0), r0, l0), r1, l1);
2272                            svscales = svreinterpret_s32_u32(svlsr_n_u32_x(pg256_all, svlsl_n_u32_x(pg256_all, svreinterpret_u32_u64(svdup_n_u64(new_utmp.u64[j/2])), 8*(4-2*(j%2)-1)), 24));
2273                            acc_sumif1 = svmla_s32_x(pg256_all, acc_sumif1, svscales, sumi1);
2274                            sumi2 = svmmla(svmmla(svdup_n_s32(0), r2, l2), r3, l3);
2275                            svscales = svreinterpret_s32_u32(svlsr_n_u32_x(pg256_all, svlsl_n_u32_x(pg256_all, svreinterpret_u32_u64(svdup_n_u64(new_utmp.u64[j/2])), 8*(4-2*(j%2)-2)), 24));
2276                            acc_sumif2 = svmla_s32_x(pg256_all, acc_sumif2, svscales, sumi2);
2277                            q4_0 += 32; q4_1 += 32; q8_0 += 64; q8_1 += 64;
2278                        }
2279                        svint32_t acc_sumif = svadd_s32_x(pg256_all, acc_sumif1, acc_sumif2);
2280                        svint32_t swap_acc_sumif = svext_s32(acc_sumif, acc_sumif, 4);
2281                        acc_sumif = svadd_s32_x(pg32_4, acc_sumif, swap_acc_sumif);
2282                        sumf1 = svmla_f32_x(pg32_4,
2283                                svmla_f32_x(pg32_4,
2284                                    sumf1,
2285                                    svcvt_f32_x(pg32_4, acc_sumif),
2286                                    svsuper_block_scales),
2287                                svdmins,
2288                                svsumfs_tmp);
2289                    } // end of for nb
2290                } // end of case 256-512
2291                break;
2292            default:
2293                assert(false && "Unsupported vector length");
2294                break;
2295        }
2296
2297        svst1_f32(pg32_2, s, sumf1);
2298        svst1_f32(pg32_2, s + bs, svreinterpret_f32_u8(svext_u8(svreinterpret_u8_f32(sumf1), svdup_n_u8(0), 8)));
2299
2300        return;
2301    }
2302#elif defined(__ARM_FEATURE_MATMUL_INT8)
2303    if (nrc == 2) {
2304        const block_q4_K * GGML_RESTRICT x0 = x;
2305        const block_q4_K * GGML_RESTRICT x1 = (const block_q4_K *) ((const uint8_t *)vx + bx);
2306        const block_q8_K * GGML_RESTRICT y0 = y;
2307        const block_q8_K * GGML_RESTRICT y1 = (const block_q8_K *) ((const uint8_t *)vy + by);
2308
2309        const uint8x16_t m4b = vdupq_n_u8(0x0f);
2310
2311        float32x4_t vfsum = vdupq_n_f32(0.0f);
2312
2313        for (int i = 0; i < nb; ++i, ++x0, ++x1, ++y0, ++y1) {
2314            const uint8_t * GGML_RESTRICT qx0 = x0->qs;
2315            const uint8_t * GGML_RESTRICT qx1 = x1->qs;
2316            const  int8_t * GGML_RESTRICT qy0 = y0->qs;
2317            const  int8_t * GGML_RESTRICT qy1 = y1->qs;
2318
2319            // decode scales and mins
2320            int8_t x0_scales[8], x1_scales[8];
2321            int16x8_t x0_mins, x1_mins;
2322            {
2323                uint32_t scales_mins[3];
2324                memcpy(scales_mins, x0->scales, 12);
2325                const uint32_t mins_0_3 = scales_mins[1] & kmask1;
2326                const uint32_t mins_4_7 = ((scales_mins[2] >> 4) & kmask2) | (((scales_mins[1] >> 6) & kmask3) << 4);
2327                const uint32x2_t mins = {mins_0_3, mins_4_7};
2328                x0_mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins)));
2329                uint32_t scales[2];
2330                scales[0] = scales_mins[0] & kmask1; // scales 0~3
2331                scales[1] = (scales_mins[2] & kmask2) | (((scales_mins[0] >> 6) & kmask3) << 4); // scales 4~7
2332                memcpy(x0_scales, scales, 8);
2333            }
2334            {
2335                uint32_t scales_mins[3];
2336                memcpy(scales_mins, x1->scales, 12);
2337                const uint32_t mins_0_3 = scales_mins[1] & kmask1;
2338                const uint32_t mins_4_7 = ((scales_mins[2] >> 4) & kmask2) | (((scales_mins[1] >> 6) & kmask3) << 4);
2339                const uint32x2_t mins = {mins_0_3, mins_4_7};
2340                x1_mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins)));
2341                uint32_t scales[2];
2342                scales[0] = scales_mins[0] & kmask1; // scales 0~3
2343                scales[1] = (scales_mins[2] & kmask2) | (((scales_mins[0] >> 6) & kmask3) << 4); // scales 4~7
2344                memcpy(x1_scales, scales, 8);
2345            }
2346
2347            int32x4_t visum = {0};
2348
2349            // process 64 data points per iteration, totally 256 data points
2350            for (int j = 0; j < QK_K / 64; ++j, qx0 += 32, qx1 += 32, qy0 += 64, qy1 += 64) {
2351                const int8x16x4_t vy0 = vld1q_s8_x4(qy0);
2352                const int8x16x4_t vy1 = vld1q_s8_x4(qy1);
2353
2354                int8x16_t vx0[4], vx1[4];
2355                {
2356                    const uint8x16x2_t vv = vld1q_u8_x2(qx0);
2357                    vx0[0] = vreinterpretq_s8_u8(vandq_u8(vv.val[0], m4b));
2358                    vx0[1] = vreinterpretq_s8_u8(vandq_u8(vv.val[1], m4b));
2359                    vx0[2] = vreinterpretq_s8_u8(vshrq_n_u8(vv.val[0], 4));
2360                    vx0[3] = vreinterpretq_s8_u8(vshrq_n_u8(vv.val[1], 4));
2361                }
2362                {
2363                    const uint8x16x2_t vv = vld1q_u8_x2(qx1);
2364                    vx1[0] = vreinterpretq_s8_u8(vandq_u8(vv.val[0], m4b));
2365                    vx1[1] = vreinterpretq_s8_u8(vandq_u8(vv.val[1], m4b));
2366                    vx1[2] = vreinterpretq_s8_u8(vshrq_n_u8(vv.val[0], 4));
2367                    vx1[3] = vreinterpretq_s8_u8(vshrq_n_u8(vv.val[1], 4));
2368                }
2369
2370                // process 32 data points (share same block scale) per iteration
2371                for (int k = 0; k < 2; ++k) {
2372                    const int blk = j * 2 + k;
2373                    const int32x4_t block_scale = {
2374                        x0_scales[blk],
2375                        x0_scales[blk],
2376                        x1_scales[blk],
2377                        x1_scales[blk],
2378                    };
2379
2380                    int32x4_t vr = {0};
2381                    for (int l = 0; l < 2; ++l) {
2382                        const int idx = k * 2 + l;
2383                        const int64x2_t vx0_s64 = vreinterpretq_s64_s8(vx0[idx]);
2384                        const int64x2_t vx1_s64 = vreinterpretq_s64_s8(vx1[idx]);
2385                        const int64x2_t vy0_s64 = vreinterpretq_s64_s8(vy0.val[idx]);
2386                        const int64x2_t vy1_s64 = vreinterpretq_s64_s8(vy1.val[idx]);
2387                        const int8x16_t vx_l = vreinterpretq_s8_s64(vzip1q_s64(vx0_s64, vx1_s64));
2388                        const int8x16_t vx_h = vreinterpretq_s8_s64(vzip2q_s64(vx0_s64, vx1_s64));
2389                        const int8x16_t vy_l = vreinterpretq_s8_s64(vzip1q_s64(vy0_s64, vy1_s64));
2390                        const int8x16_t vy_h = vreinterpretq_s8_s64(vzip2q_s64(vy0_s64, vy1_s64));
2391                        vr = vmmlaq_s32(vr, vx_l, vy_l);
2392                        vr = vmmlaq_s32(vr, vx_h, vy_h);
2393                    }
2394                    // apply block scale, will NOT overflow
2395                    // block_scale * sum_256(int4*int8) <= 2^(8+8+4+8) = 28 bits
2396                    visum = vmlaq_s32(visum, vr, block_scale);
2397                }
2398            }
2399
2400            // adjust bias, apply superblock scale
2401            {
2402                int32_t bias[4];
2403                // no obvious uplift from sve sdot-16, just use neon mul add
2404                const int16x8_t y0_sums = vpaddq_s16(vld1q_s16(y0->bsums), vld1q_s16(y0->bsums+8));
2405                const int16x8_t y1_sums = vpaddq_s16(vld1q_s16(y1->bsums), vld1q_s16(y1->bsums+8));
2406                bias[0] = vaddvq_s32(vaddq_s32(vmull_s16(vget_low_s16(y0_sums), vget_low_s16(x0_mins)),
2407                                               vmull_s16(vget_high_s16(y0_sums), vget_high_s16(x0_mins))));
2408                bias[1] = vaddvq_s32(vaddq_s32(vmull_s16(vget_low_s16(y1_sums), vget_low_s16(x0_mins)),
2409                                               vmull_s16(vget_high_s16(y1_sums), vget_high_s16(x0_mins))));
2410                bias[2] = vaddvq_s32(vaddq_s32(vmull_s16(vget_low_s16(y0_sums), vget_low_s16(x1_mins)),
2411                                               vmull_s16(vget_high_s16(y0_sums), vget_high_s16(x1_mins))));
2412                bias[3] = vaddvq_s32(vaddq_s32(vmull_s16(vget_low_s16(y1_sums), vget_low_s16(x1_mins)),
2413                                               vmull_s16(vget_high_s16(y1_sums), vget_high_s16(x1_mins))));
2414                const float32x4_t dmins = {
2415                    GGML_CPU_FP16_TO_FP32(x0->dmin) * y0->d,
2416                    GGML_CPU_FP16_TO_FP32(x0->dmin) * y1->d,
2417                    GGML_CPU_FP16_TO_FP32(x1->dmin) * y0->d,
2418                    GGML_CPU_FP16_TO_FP32(x1->dmin) * y1->d,
2419                };
2420                vfsum = vmlsq_f32(vfsum, vcvtq_f32_s32(vld1q_s32(bias)), dmins);
2421
2422                const float32x4_t superblock_scale = {
2423                    GGML_CPU_FP16_TO_FP32(x0->d) * y0->d,
2424                    GGML_CPU_FP16_TO_FP32(x0->d) * y1->d,
2425                    GGML_CPU_FP16_TO_FP32(x1->d) * y0->d,
2426                    GGML_CPU_FP16_TO_FP32(x1->d) * y1->d,
2427                };
2428                vfsum = vmlaq_f32(vfsum, vcvtq_f32_s32(visum), superblock_scale);
2429            }
2430        }
2431
2432        // vfsum = ABCD -> ACBD
2433        // AC -> s, BD -> (s+bs)
2434        vfsum = vzip1q_f32(vfsum, vextq_f32(vfsum, vfsum, 2));
2435        vst1_f32(s,      vget_low_f32 (vfsum));
2436        vst1_f32(s + bs, vget_high_f32(vfsum));
2437
2438        return;
2439    }
2440#endif
2441
2442#ifdef __ARM_FEATURE_SVE
2443    float sumf = 0;
2444    for (int i = 0; i < nb; ++i) {
2445
2446        const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
2447        const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);
2448
2449        const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8));
2450
2451        memcpy(utmp, x[i].scales, K_SCALE_SIZE);
2452
2453        uint32x2_t mins8 = { 0 };
2454        mins8 = vset_lane_u32(utmp[1] & kmask1, mins8, 0);
2455        mins8 = vset_lane_u32(((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4), mins8, 1);
2456
2457        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
2458        utmp[0] &= kmask1;
2459
2460        const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8)));
2461        const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)),
2462                                         vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)));
2463        sumf -= dmin * vaddvq_s32(prod);
2464
2465        const uint8_t * scales = (const uint8_t *)utmp;
2466
2467        const uint8_t * GGML_RESTRICT q4 = x[i].qs;
2468        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
2469
2470        const svuint8_t m4b = svdup_n_u8(0xf);
2471        const svint32_t mzero = svdup_n_s32(0);
2472        svint32_t sumi1 = svdup_n_s32(0);
2473        svint32_t sumi1_1 = svdup_n_s32(0);
2474        svint32_t sumi1_2 = svdup_n_s32(0);
2475        svint32_t sumi2 = svdup_n_s32(0);
2476        svint32_t sumi2_1 = svdup_n_s32(0);
2477        svint32_t sumi2_2 = svdup_n_s32(0);
2478        switch (vector_length) {
2479            case 128:
2480                {
2481                    for (int j = 0; j < QK_K/64; ++j) {
2482                        svint8_t q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4), m4b));
2483                        svint8_t q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
2484                        sumi1_1 = svmla_n_s32_x(svptrue_b32(), sumi1_1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]);
2485                        q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4+16), m4b));
2486                        q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
2487                        sumi1_2 = svmla_n_s32_x(svptrue_b32(), sumi1_2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]);
2488
2489                        q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4), 4));
2490                        q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
2491                        sumi2_1 = svmla_n_s32_x(svptrue_b32(), sumi2_1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]);
2492                        q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4+16), 4));
2493                        q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
2494                        sumi2_2 = svmla_n_s32_x(svptrue_b32(), sumi2_2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]);
2495                        q4 += 32;
2496                    }
2497                    sumi1 = svadd_s32_x(svptrue_b32(), sumi1_1, sumi1_2);
2498                    sumi2 = svadd_s32_x(svptrue_b32(), sumi2_1, sumi2_2);
2499                    sumf += d * (svaddv_s32(svptrue_b32(), svadd_s32_x(svptrue_b32(), sumi1, sumi2)));
2500                } break;
2501            case 256:
2502            case 512:
2503                {
2504                    for (int j = 0; j < QK_K/64; ++j) {
2505                        const svuint8_t q4bits  = svld1_u8(svptrue_pat_b8(SV_VL32), q4); q4 += 32;
2506                        svint8_t q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_pat_b8(SV_VL32), q4bits, m4b));
2507                        svint8_t q8bytes = svld1_s8(svptrue_pat_b8(SV_VL32), q8); q8 += 32;
2508                        sumi1 = svmla_n_s32_x(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]);
2509
2510                        q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q4bits, 4));
2511                        q8bytes = svld1_s8(svptrue_pat_b8(SV_VL32), q8); q8 += 32;
2512                        sumi2 = svmla_n_s32_x(svptrue_pat_b32(SV_VL8), sumi2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]);
2513                    }
2514                    sumf += d * (svaddv_s32(svptrue_pat_b32(SV_VL8), svadd_s32_x(svptrue_pat_b32(SV_VL8), sumi1, sumi2)));
2515                } break;
2516            default:
2517                assert(false && "Unsupported vector length");
2518                break;
2519        }
2520    }
2521    *s = sumf;
2522#elif defined __ARM_NEON
2523    const uint8x16_t m4b = vdupq_n_u8(0xf);
2524    const int32x4_t mzero = vdupq_n_s32(0);
2525
2526    ggml_int8x16x2_t q4bytes;
2527    ggml_int8x16x2_t q8bytes;
2528
2529    float sumf = 0;
2530
2531    for (int i = 0; i < nb; ++i) {
2532
2533        const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
2534        const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);
2535
2536        const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8));
2537
2538        memcpy(utmp, x[i].scales, 12);
2539
2540        uint32x2_t mins8 = { 0 };
2541        mins8 = vset_lane_u32(utmp[1] & kmask1, mins8, 0);
2542        mins8 = vset_lane_u32(((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4), mins8, 1);
2543
2544        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
2545        utmp[0] &= kmask1;
2546
2547        const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8)));
2548        const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)),
2549                                         vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)));
2550        sumf -= dmin * vaddvq_s32(prod);
2551
2552        const uint8_t * scales = (const uint8_t *)utmp;
2553
2554        const uint8_t * GGML_RESTRICT q4 = x[i].qs;
2555        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
2556
2557        int32_t sumi1 = 0;
2558        int32_t sumi2 = 0;
2559
2560        for (int j = 0; j < QK_K/64; ++j) {
2561            const ggml_uint8x16x2_t q4bits = ggml_vld1q_u8_x2(q4); q4 += 32;
2562
2563            q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
2564            q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8  (q4bits.val[0], m4b));
2565            q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8  (q4bits.val[1], m4b));
2566
2567            const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
2568            sumi1 += vaddvq_s32(p1) * scales[2*j+0];
2569
2570            q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
2571            q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
2572            q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
2573
2574            const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
2575
2576            sumi2 += vaddvq_s32(p2) * scales[2*j+1];
2577        }
2578
2579        sumf += d * (sumi1 + sumi2);
2580
2581    }
2582
2583    *s = sumf;
2584
2585#else
2586    UNUSED(x);
2587    UNUSED(y);
2588    UNUSED(nb);
2589    UNUSED(kmask1);
2590    UNUSED(kmask2);
2591    UNUSED(kmask3);
2592    UNUSED(utmp);
2593    ggml_vec_dot_q4_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
2594#endif
2595}
2596
2597void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy,  size_t by, int nrc) {
2598    assert(n % QK_K == 0);
2599    assert(nrc == 1);
2600    UNUSED(nrc);
2601    UNUSED(bx);
2602    UNUSED(by);
2603    UNUSED(bs);
2604
2605    const block_q5_K * GGML_RESTRICT x = vx;
2606    const block_q8_K * GGML_RESTRICT y = vy;
2607
2608    const int nb = n / QK_K;
2609
2610    static const uint32_t kmask1 = 0x3f3f3f3f;
2611    static const uint32_t kmask2 = 0x0f0f0f0f;
2612    static const uint32_t kmask3 = 0x03030303;
2613
2614    uint32_t utmp[4];
2615
2616
2617#ifdef __ARM_NEON
2618    const uint8x16_t m4b = vdupq_n_u8(0xf);
2619    const uint8x16_t mone = vdupq_n_u8(1);
2620    const uint8x16_t mtwo = vdupq_n_u8(2);
2621    const int32x4_t mzero = vdupq_n_s32(0);
2622
2623    ggml_int8x16x4_t q5bytes;
2624
2625    float sumf = 0;
2626
2627    for (int i = 0; i < nb; ++i) {
2628
2629        const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
2630        const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);
2631
2632        const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8));
2633
2634        memcpy(utmp, x[i].scales, 12);
2635        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
2636        const uint32_t uaux = utmp[1] & kmask1;
2637        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
2638        utmp[2] = uaux;
2639        utmp[0] &= kmask1;
2640
2641        const uint8x8_t mins8 = vld1_u8((const uint8_t*)utmp + 8);
2642        const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(mins8));
2643        const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)),
2644                                         vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)));
2645        int32_t sumi_mins = vaddvq_s32(prod);
2646
2647        const uint8_t * scales = (const uint8_t *)utmp;
2648
2649        const uint8_t * GGML_RESTRICT q5 = x[i].qs;
2650        const uint8_t * GGML_RESTRICT qh = x[i].qh;
2651        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
2652
2653        ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh);
2654
2655        ggml_uint8x16x4_t q5h;
2656
2657        int32_t sumi = 0;
2658
2659        for (int j = 0; j < QK_K/64; ++j) {
2660
2661            const ggml_uint8x16x2_t q5bits = ggml_vld1q_u8_x2(q5); q5 += 32;
2662            const ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64;
2663
2664            q5h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4);
2665            q5h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4);
2666            q5h.val[2] = vshlq_n_u8(vandq_u8(mtwo, qhbits.val[0]), 3);
2667            q5h.val[3] = vshlq_n_u8(vandq_u8(mtwo, qhbits.val[1]), 3);
2668            qhbits.val[0] = vshrq_n_u8(qhbits.val[0], 2);
2669            qhbits.val[1] = vshrq_n_u8(qhbits.val[1], 2);
2670
2671            q5bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[0], m4b), q5h.val[0]));
2672            q5bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[1], m4b), q5h.val[1]));
2673            q5bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[0], 4), q5h.val[2]));
2674            q5bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[1], 4), q5h.val[3]));
2675
2676            sumi += vaddvq_s32(ggml_vdotq_s32(ggml_vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]), q5bytes.val[1], q8bytes.val[1])) * *scales++;
2677            sumi += vaddvq_s32(ggml_vdotq_s32(ggml_vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]), q5bytes.val[3], q8bytes.val[3])) * *scales++;
2678        }
2679
2680        sumf += d * sumi - dmin * sumi_mins;
2681    }
2682
2683    *s = sumf;
2684
2685#else
2686    UNUSED(x);
2687    UNUSED(y);
2688    UNUSED(nb);
2689    UNUSED(kmask1);
2690    UNUSED(kmask2);
2691    UNUSED(kmask3);
2692    UNUSED(utmp);
2693    ggml_vec_dot_q5_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
2694#endif
2695}
2696
2697void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
2698    assert(n % QK_K == 0);
2699#ifdef __ARM_FEATURE_MATMUL_INT8
2700    assert((nrc == 2) || (nrc == 1));
2701#else
2702    assert(nrc == 1);
2703#endif
2704    UNUSED(nrc);
2705    UNUSED(bx);
2706    UNUSED(by);
2707    UNUSED(bs);
2708
2709    const block_q6_K * GGML_RESTRICT x = vx;
2710    const block_q8_K * GGML_RESTRICT y = vy;
2711
2712    const int nb = n / QK_K;
2713
2714#ifdef __ARM_FEATURE_SVE
2715    const int vector_length = ggml_cpu_get_sve_cnt()*8;
2716#endif
2717#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
2718    if (nrc == 2) {
2719        const svbool_t pg32_2 = svptrue_pat_b32(SV_VL2);
2720
2721        svfloat32_t sum = svdup_n_f32(0);
2722
2723        const block_q6_K * GGML_RESTRICT vx0 = vx;
2724        const block_q8_K * GGML_RESTRICT vy0 = vy;
2725        const block_q6_K * GGML_RESTRICT vx1 = (const block_q6_K *) ((const uint8_t*)vx + bx);
2726        const block_q8_K * GGML_RESTRICT vy1 = (const block_q8_K *) ((const uint8_t*)vy + by);
2727
2728        switch (vector_length) {
2729            case 128:
2730                {
2731                    const svbool_t pg128_all = svptrue_pat_b8(SV_ALL);
2732                    for (int i = 0; i < nb; ++i) {
2733                        const uint8_t * GGML_RESTRICT ql0 = vx0[i].ql;
2734                        const uint8_t * GGML_RESTRICT qh0 = vx0[i].qh;
2735                        const uint8_t * GGML_RESTRICT ql1 = vx1[i].ql;
2736                        const uint8_t * GGML_RESTRICT qh1 = vx1[i].qh;
2737                        const int8_t  * GGML_RESTRICT q80 = vy0[i].qs;
2738                        const int8_t  * GGML_RESTRICT q81 = vy1[i].qs;
2739
2740                        const int8_t * GGML_RESTRICT scale0 = vx0[i].scales;
2741                        const int8_t * GGML_RESTRICT scale1 = vx1[i].scales;
2742
2743                        svfloat32_t vy_d = svuzp1_f32(svdup_n_f32(vy0[i].d), svdup_n_f32(vy1[i].d));
2744                        svfloat32_t vx_d = svzip1_f32(svdup_n_f32(GGML_FP16_TO_FP32(vx0[i].d)), svdup_n_f32(GGML_FP16_TO_FP32(vx1[i].d)));
2745                        svfloat32_t svsuper_block_scales = svmul_f32_x(pg128_all, vy_d, vx_d);
2746                        // process q8sum summation 128 bit route
2747                        const svint16_t q8sums_01 = svld1_s16(pg128_all, vy0[i].bsums);
2748                        const svint16_t q8sums_02 = svld1_s16(pg128_all, vy0[i].bsums + 8);
2749                        const svint16_t q8sums_11 = svld1_s16(pg128_all, vy1[i].bsums);
2750                        const svint16_t q8sums_12 = svld1_s16(pg128_all, vy1[i].bsums + 8);
2751                        const svint64x2_t q6scales_0_tmp = svld2_s64(pg128_all, (const int64_t *)scale0);
2752                        const svint16_t q6scales_01 = svunpklo_s16(svreinterpret_s8_s64(svget2_s64(q6scales_0_tmp, 0)));
2753                        const svint16_t q6scales_02 = svunpklo_s16(svreinterpret_s8_s64(svget2_s64(q6scales_0_tmp, 1)));
2754                        const svint64x2_t q6scales_1_tmp = svld2_s64(pg128_all, (const int64_t *)scale1);
2755                        const svint16_t q6scales_11 = svunpklo_s16(svreinterpret_s8_s64(svget2_s64(q6scales_1_tmp, 0)));
2756                        const svint16_t q6scales_12 = svunpklo_s16(svreinterpret_s8_s64(svget2_s64(q6scales_1_tmp, 1)));
2757                        const svint64_t prod = svdup_n_s64(0);
2758
2759                        svint32_t isum_tmp1 = svreinterpret_s32_s64(svdot_s64(svdot_s64(prod, q8sums_01, q6scales_01), q8sums_02, q6scales_02));
2760                        svint32_t isum_tmp2 = svreinterpret_s32_s64(svdot_s64(svdot_s64(prod, q8sums_01, q6scales_11), q8sums_02, q6scales_12));
2761                        svint32_t isum_tmp3 = svtrn1_s32(isum_tmp1, isum_tmp2);
2762                        svint32_t isum_tmp4 = svreinterpret_s32_s64(svdot_s64(svdot_s64(prod, q8sums_11, q6scales_01), q8sums_12, q6scales_02));
2763                        svint32_t isum_tmp5 = svreinterpret_s32_s64(svdot_s64(svdot_s64(prod, q8sums_11, q6scales_11), q8sums_12, q6scales_12));
2764                        svint32_t isum_tmp6 = svtrn1_s32(isum_tmp4, isum_tmp5);
2765                        svint32_t isum_tmp7 = svreinterpret_s32_s64(svtrn2_s64(svreinterpret_s64_s32(isum_tmp3), svreinterpret_s64_s32(isum_tmp6)));
2766                        svint32_t isum_tmp8 = svreinterpret_s32_s64(svtrn1_s64(svreinterpret_s64_s32(isum_tmp3), svreinterpret_s64_s32(isum_tmp6)));
2767                        svint32_t svisum_mins = svadd_s32_x(pg128_all, isum_tmp7, isum_tmp8);
2768
2769                        // process mmla
2770                        svint8_t  l0, l1, r0, r1;
2771                        svint32_t isum_tmp = svdup_n_s32(0);
2772                        for (int j = 0; j < QK_K/128; ++j) {
2773                            for (int k = 0; k < 8; ++k) {
2774                                svuint8_t qhbits_0 = svld1_u8(pg128_all, qh0+16*(k%2));
2775                                svuint8_t qhbits_1 = svld1_u8(pg128_all, qh1+16*(k%2));
2776                                svuint8_t q6bits_0 = svld1_u8(pg128_all, ql0+16*(k%4));
2777                                svuint8_t q6bits_1 = svld1_u8(pg128_all, ql1+16*(k%4));
2778                                const int ql_pos = (k/4)*4;
2779                                svuint8_t q6bytes_0_lo = (ql_pos < 4) ? svand_n_u8_x(pg128_all, q6bits_0, 0xf) : svlsr_n_u8_x(pg128_all, q6bits_0, 4);
2780                                svuint8_t q6bytes_1_lo = (ql_pos < 4) ? svand_n_u8_x(pg128_all, q6bits_1, 0xf) : svlsr_n_u8_x(pg128_all, q6bits_1, 4);
2781                                const int qh_pos = (k/2)*2;
2782                                svuint8_t q6bytes_0_hi = svand_n_u8_x(pg128_all, qhbits_0, 0x3 << qh_pos);
2783                                svuint8_t q6bytes_1_hi = svand_n_u8_x(pg128_all, qhbits_1, 0x3 << qh_pos);
2784                                svint8_t  q6bytes_0, q6bytes_1;
2785                                if (qh_pos <= 4) {
2786                                    q6bytes_0 = svreinterpret_s8_u8(svmla_n_u8_x(pg128_all, q6bytes_0_lo, q6bytes_0_hi, 1 << (4 - qh_pos)));
2787                                    q6bytes_1 = svreinterpret_s8_u8(svmla_n_u8_x(pg128_all, q6bytes_1_lo, q6bytes_1_hi, 1 << (4 - qh_pos)));
2788                                } else {
2789                                    q6bytes_0 = svreinterpret_s8_u8(svorr_u8_x(pg128_all, q6bytes_0_lo, svlsr_n_u8_x(pg128_all, q6bytes_0_hi, (qh_pos - 4))));
2790                                    q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg128_all, q6bytes_1_lo, svlsr_n_u8_x(pg128_all, q6bytes_1_hi, (qh_pos - 4))));
2791                                }
2792                                svint8_t  q8bytes_0 = svld1_s8(pg128_all, q80+16*(k%8));
2793                                svint8_t  q8bytes_1 = svld1_s8(pg128_all, q81+16*(k%8));
2794                                l0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q6bytes_0), svreinterpret_s64_s8(q6bytes_1)));
2795                                l1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q6bytes_0), svreinterpret_s64_s8(q6bytes_1)));
2796                                r0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0), svreinterpret_s64_s8(q8bytes_1)));
2797                                r1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0), svreinterpret_s64_s8(q8bytes_1)));
2798                                svint32_t svscale = svzip1_s32(svdup_n_s32(scale0[k]), svdup_n_s32(scale1[k]));
2799                                isum_tmp = svmla_s32_x(pg128_all, isum_tmp, svmmla_s32(svmmla_s32(svdup_n_s32(0), r0, l0), r1, l1), svscale);
2800                            }
2801                            qh0 += 32;  qh1 += 32;
2802                            ql0 += 64;  ql1 += 64;
2803                            q80 += 128; q81 += 128;
2804                            scale0 += 8; scale1 += 8;
2805                        }
2806                        sum = svmla_f32_x(pg128_all, sum,
2807                                svcvt_f32_x(pg128_all, svmla_s32_x(pg128_all, isum_tmp,
2808                                        svisum_mins, svdup_n_s32(-32))),
2809                                svsuper_block_scales);
2810                    }
2811                } // end of case 128
2812                break;
2813            case 256:
2814            case 512:
2815                {
2816                    const svbool_t pg256_all = svptrue_pat_b8(SV_ALL);
2817                    const svbool_t pg32_4 = svptrue_pat_b32(SV_VL4);
2818                    for (int i = 0; i < nb; ++i) {
2819                        const uint8_t * GGML_RESTRICT ql0 = vx0[i].ql;
2820                        const uint8_t * GGML_RESTRICT qh0 = vx0[i].qh;
2821                        const uint8_t * GGML_RESTRICT ql1 = vx1[i].ql;
2822                        const uint8_t * GGML_RESTRICT qh1 = vx1[i].qh;
2823                        const int8_t  * GGML_RESTRICT q80 = vy0[i].qs;
2824                        const int8_t  * GGML_RESTRICT q81 = vy1[i].qs;
2825
2826                        const int8_t * GGML_RESTRICT scale0 = vx0[i].scales;
2827                        const int8_t * GGML_RESTRICT scale1 = vx1[i].scales;
2828                        svfloat32_t vx_d = svzip1_f32(svdup_n_f32(GGML_FP16_TO_FP32(vx0[i].d)), svdup_n_f32(GGML_FP16_TO_FP32(vx1[i].d)));
2829                        svfloat64_t vy_d_tmp = svreinterpret_f64_f32(svuzp1_f32(svdup_n_f32(vy0[i].d), svdup_n_f32(vy1[i].d)));
2830                        svfloat32_t vy_d = svreinterpret_f32_f64(svuzp1_f64(vy_d_tmp, vy_d_tmp));
2831                        svfloat32_t svsuper_block_scales = svmul_f32_x(pg32_4, vy_d, vx_d);
2832                        // process q8sum summation 256 bit route
2833                        const svint16_t q8sums_0 = svld1_s16(pg256_all, vy0[i].bsums);
2834                        const svint16_t q8sums_1 = svld1_s16(pg256_all, vy1[i].bsums);
2835                        const svint16_t q6scales_0 = svunpklo_s16(svld1_s8(pg256_all, scale0));
2836                        const svint16_t q6scales_1 = svunpklo_s16(svld1_s8(pg256_all, scale1));
2837                        const svint64_t prod = svdup_n_s64(0);
2838                        svint32_t isum_tmp1  = svreinterpret_s32_s64(svdot_s64(prod, q8sums_0, q6scales_0));
2839                        svint32_t isum_tmp2  = svreinterpret_s32_s64(svdot_s64(prod, q8sums_0, q6scales_1));
2840                        svint32_t isum_tmp3  = svreinterpret_s32_s64(svdot_s64(prod, q8sums_1, q6scales_0));
2841                        svint32_t isum_tmp4  = svreinterpret_s32_s64(svdot_s64(prod, q8sums_1, q6scales_1));
2842                        svint32_t isum_tmp5  = svtrn1_s32(isum_tmp1, isum_tmp2);
2843                        svint32_t isum_tmp6  = svtrn1_s32(isum_tmp3, isum_tmp4);
2844                        svint32_t isum_tmp7  = svreinterpret_s32_s64(svtrn2_s64(svreinterpret_s64_s32(isum_tmp5), svreinterpret_s64_s32(isum_tmp6)));
2845                        svint32_t isum_tmp8  = svreinterpret_s32_s64(svtrn1_s64(svreinterpret_s64_s32(isum_tmp5), svreinterpret_s64_s32(isum_tmp6)));
2846                        svint32_t isum_tmp9  = svadd_s32_x(pg256_all, isum_tmp7, isum_tmp8);
2847                        svint32_t isum_tmp10 = svreinterpret_s32_u8(svext_u8(svreinterpret_u8_s32(isum_tmp9), svreinterpret_u8_s32(isum_tmp9), 16));
2848                        svint32_t svisum_mins = svadd_s32_z(pg32_4, isum_tmp9, isum_tmp10);
2849
2850                        // process mmla
2851                        svint8_t l0, l1, r0, r1;
2852                        svint32_t isum_tmp = svdup_n_s32(0);
2853                        for (int j = 0; j < QK_K/128; ++j) {
2854                            for (int k = 0; k < 8; k+=2) { // process 2 block
2855                                svuint8_t qhbits_0  = svld1_u8(pg256_all, qh0);
2856                                svuint8_t qhbits_1  = svld1_u8(pg256_all, qh1);
2857                                svuint8_t q6bits_0  = svld1_u8(pg256_all, ql0+32*((k%4)/2));
2858                                svuint8_t q6bits_1  = svld1_u8(pg256_all, ql1+32*((k%4)/2));
2859                                const int ql_pos = (k/4)*4;
2860                                svuint8_t q6bytes_0_lo = (ql_pos < 4) ? svand_n_u8_x(pg256_all, q6bits_0, 0xf) : svlsr_n_u8_x(pg256_all, q6bits_0, 4);
2861                                svuint8_t q6bytes_1_lo = (ql_pos < 4) ? svand_n_u8_x(pg256_all, q6bits_1, 0xf) : svlsr_n_u8_x(pg256_all, q6bits_1, 4);
2862                                const int qh_pos = (k/2)*2;
2863                                svuint8_t q6bytes_0_hi = svand_n_u8_x(pg256_all, qhbits_0, 0x3 << qh_pos);
2864                                svuint8_t q6bytes_1_hi = svand_n_u8_x(pg256_all, qhbits_1, 0x3 << qh_pos);
2865                                svint8_t  q6bytes_0, q6bytes_1;
2866                                if (qh_pos <= 4) {
2867                                    q6bytes_0 = svreinterpret_s8_u8(svmla_n_u8_x(pg256_all, q6bytes_0_lo, q6bytes_0_hi, 1 << (4 - qh_pos)));
2868                                    q6bytes_1 = svreinterpret_s8_u8(svmla_n_u8_x(pg256_all, q6bytes_1_lo, q6bytes_1_hi, 1 << (4 - qh_pos)));
2869                                } else {
2870                                    q6bytes_0 = svreinterpret_s8_u8(svorr_u8_x(pg256_all, q6bytes_0_lo, svlsr_n_u8_x(pg256_all, q6bytes_0_hi, (qh_pos - 4))));
2871                                    q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg256_all, q6bytes_1_lo, svlsr_n_u8_x(pg256_all, q6bytes_1_hi, (qh_pos - 4))));
2872                                }
2873                                svint8_t  q8bytes_0 = svld1_s8(pg256_all, q80+32*(k/2));
2874                                svint8_t  q8bytes_1 = svld1_s8(pg256_all, q81+32*(k/2));
2875                                l0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q6bytes_0), svreinterpret_s64_s8(q6bytes_1)));
2876                                l1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q6bytes_0), svreinterpret_s64_s8(q6bytes_1)));
2877                                r0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0), svreinterpret_s64_s8(q8bytes_1)));
2878                                r1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0), svreinterpret_s64_s8(q8bytes_1)));
2879                                svint32_t svscale0 = svzip1_s32(svdup_n_s32(scale0[k]), svdup_n_s32(scale1[k]));
2880                                svint32_t svscale1 = svzip1_s32(svdup_n_s32(scale0[k+1]), svdup_n_s32(scale1[k+1]));
2881                                isum_tmp = svmla_s32_x(pg256_all, isum_tmp, svmmla_s32(svdup_n_s32(0), r0, l0), svscale0);
2882                                isum_tmp = svmla_s32_x(pg256_all, isum_tmp, svmmla_s32(svdup_n_s32(0), r1, l1), svscale1);
2883                            }
2884                            qh0 += 32;  qh1 += 32;
2885                            ql0 += 64;  ql1 += 64;
2886                            q80 += 128; q81 += 128;
2887                            scale0 += 8; scale1 += 8;
2888                        } // end of for
2889                        svint32_t swap_isum_tmp = svext_s32(isum_tmp, isum_tmp, 4);
2890                        isum_tmp = svadd_s32_x(pg32_4, isum_tmp, swap_isum_tmp);
2891                        sum = svmla_f32_x(pg32_4, sum,
2892                                svcvt_f32_x(pg32_4, svmla_s32_x(pg32_4, isum_tmp,
2893                                        svisum_mins, svdup_n_s32(-32))),
2894                                svsuper_block_scales);
2895                    }
2896                } // end of case 256
2897                break;
2898            default:
2899                assert(false && "Unsupported vector length");
2900                break;
2901        } // end of switch
2902
2903        svst1_f32(pg32_2, s, sum);
2904        svst1_f32(pg32_2, s + bs, svreinterpret_f32_u8(svext_u8(svreinterpret_u8_f32(sum), svdup_n_u8(0), 8)));
2905
2906        return;
2907    }
2908#elif defined(__ARM_FEATURE_MATMUL_INT8)
2909    if (nrc == 2) {
2910        const block_q6_K * GGML_RESTRICT x0 = x;
2911        const block_q6_K * GGML_RESTRICT x1 = (const block_q6_K *) ((const uint8_t *)vx + bx);
2912        const block_q8_K * GGML_RESTRICT y0 = y;
2913        const block_q8_K * GGML_RESTRICT y1 = (const block_q8_K *) ((const uint8_t *)vy + by);
2914
2915        float32x4_t vfsum = vdupq_n_f32(0.0f);
2916
2917        for (int i = 0; i < nb; ++i, ++x0, ++x1, ++y0, ++y1) {
2918            const uint8_t * GGML_RESTRICT ql0 = x0->ql;
2919            const uint8_t * GGML_RESTRICT ql1 = x1->ql;
2920            const uint8_t * GGML_RESTRICT qh0 = x0->qh;
2921            const uint8_t * GGML_RESTRICT qh1 = x1->qh;
2922            const  int8_t * GGML_RESTRICT qy0 = y0->qs;
2923            const  int8_t * GGML_RESTRICT qy1 = y1->qs;
2924
2925            const uint8x16_t mone = vdupq_n_u8(0x30);
2926            const uint8x16_t  m4b = vdupq_n_u8(0x0f);
2927
2928            int32x4_t visum = vdupq_n_s32(0);
2929
2930            // process 8 blocks per iteration, totally 16 blocks
2931            for (int j = 0; j < 2; ++j, qh0 += 32, ql0 += 64, qh1 += 32, ql1 += 64) {
2932                int8x16_t vx0[8], vx1[8];
2933
2934                // de-quantize vx0[8]
2935                {
2936                    const uint8x16x2_t qh_bits = vld1q_u8_x2(qh0);
2937                    const uint8x16x4_t ql_bits = vld1q_u8_x4(ql0);
2938
2939                    uint8x16_t q6h_0 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 4));
2940                    uint8x16_t q6h_1 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 4));
2941                    uint8x16_t q6h_2 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 2));
2942                    uint8x16_t q6h_3 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 2));
2943
2944                    vx0[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[0], m4b), q6h_0));
2945                    vx0[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[1], m4b), q6h_1));
2946                    vx0[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[2], m4b), q6h_2));
2947                    vx0[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[3], m4b), q6h_3));
2948
2949                    q6h_0 = vandq_u8(mone, qh_bits.val[0]);
2950                    q6h_1 = vandq_u8(mone, qh_bits.val[1]);
2951                    q6h_2 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[0], 2));
2952                    q6h_3 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[1], 2));
2953
2954                    vx0[4] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[0], 4), q6h_0));
2955                    vx0[5] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[1], 4), q6h_1));
2956                    vx0[6] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[2], 4), q6h_2));
2957                    vx0[7] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[3], 4), q6h_3));
2958                }
2959
2960                // de-quantize vx1[8]
2961                {
2962                    const uint8x16x2_t qh_bits = vld1q_u8_x2(qh1);
2963                    const uint8x16x4_t ql_bits = vld1q_u8_x4(ql1);
2964
2965                    uint8x16_t q6h_0 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 4));
2966                    uint8x16_t q6h_1 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 4));
2967                    uint8x16_t q6h_2 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 2));
2968                    uint8x16_t q6h_3 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 2));
2969
2970                    vx1[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[0], m4b), q6h_0));
2971                    vx1[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[1], m4b), q6h_1));
2972                    vx1[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[2], m4b), q6h_2));
2973                    vx1[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[3], m4b), q6h_3));
2974
2975                    q6h_0 = vandq_u8(mone, qh_bits.val[0]);
2976                    q6h_1 = vandq_u8(mone, qh_bits.val[1]);
2977                    q6h_2 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[0], 2));
2978                    q6h_3 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[1], 2));
2979
2980                    vx1[4] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[0], 4), q6h_0));
2981                    vx1[5] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[1], 4), q6h_1));
2982                    vx1[6] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[2], 4), q6h_2));
2983                    vx1[7] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[3], 4), q6h_3));
2984                }
2985
2986                // process 16 elements (one block with same scale) per iteration
2987                // - vx = concat(ql, qh) - 32
2988                // - r1,r2,r3,r4 = smmla(vx, vy)
2989                for (int k = 0; k < 8; ++k) {
2990                    const int blk = j * 8 + k;
2991
2992                    const int8x16_t vy0 = vld1q_s8(qy0);
2993                    const int8x16_t vy1 = vld1q_s8(qy1);
2994                    qy0 += 16;
2995                    qy1 += 16;
2996
2997                    const int32x4_t block_scale = {
2998                        x0->scales[blk],
2999                        x0->scales[blk],
3000                        x1->scales[blk],
3001                        x1->scales[blk],
3002                    };
3003
3004                    // calculate four results at once with outer product
3005                    const int8x16_t vx_l = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(vx0[k]), vreinterpretq_s64_s8(vx1[k])));
3006                    const int8x16_t vx_h = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(vx0[k]), vreinterpretq_s64_s8(vx1[k])));
3007                    const int8x16_t vy_l = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(vy0), vreinterpretq_s64_s8(vy1)));
3008                    const int8x16_t vy_h = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(vy0), vreinterpretq_s64_s8(vy1)));
3009                    int32x4_t vr = vdupq_n_s32(0);
3010                    vr = vmmlaq_s32(vr, vx_l, vy_l);
3011                    vr = vmmlaq_s32(vr, vx_h, vy_h);
3012
3013                    // apply block scale, will NOT overflow
3014                    // block_scale * sum_256(int6*int8) <= 2^(8+8+6+8) = 30 bits
3015                    visum = vmlaq_s32(visum, vr, block_scale);
3016                }
3017            }
3018
3019            // adjust bias, apply superblock scale
3020            {
3021                int32_t bias[4];
3022                // NEON doesn't support int16 dot product, fallback to separated mul and add
3023                const int16x8x2_t q8sums0 = vld1q_s16_x2(y0->bsums);
3024                const int16x8x2_t q8sums1 = vld1q_s16_x2(y1->bsums);
3025
3026                int8x16_t scales_s8 = vld1q_s8(x0->scales);
3027                const int16x8x2_t q6scales0 = {{vmovl_s8(vget_low_s8(scales_s8)), vmovl_s8(vget_high_s8(scales_s8))}};
3028                scales_s8 = vld1q_s8(x1->scales);
3029                const int16x8x2_t q6scales1 = {{vmovl_s8(vget_low_s8(scales_s8)), vmovl_s8(vget_high_s8(scales_s8))}};
3030
3031                int32x4_t prod;
3032                prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[0]), vget_low_s16 (q6scales0.val[0])),
3033                                           vmull_s16(vget_high_s16(q8sums0.val[0]), vget_high_s16(q6scales0.val[0]))),
3034                                 vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[1]), vget_low_s16 (q6scales0.val[1])),
3035                                           vmull_s16(vget_high_s16(q8sums0.val[1]), vget_high_s16(q6scales0.val[1]))));
3036                bias[0] = vaddvq_s32(prod);
3037                prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[0]), vget_low_s16 (q6scales0.val[0])),
3038                                           vmull_s16(vget_high_s16(q8sums1.val[0]), vget_high_s16(q6scales0.val[0]))),
3039                                 vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[1]), vget_low_s16 (q6scales0.val[1])),
3040                                           vmull_s16(vget_high_s16(q8sums1.val[1]), vget_high_s16(q6scales0.val[1]))));
3041                bias[1] = vaddvq_s32(prod);
3042                prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[0]), vget_low_s16 (q6scales1.val[0])),
3043                                           vmull_s16(vget_high_s16(q8sums0.val[0]), vget_high_s16(q6scales1.val[0]))),
3044                                 vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[1]), vget_low_s16 (q6scales1.val[1])),
3045                                           vmull_s16(vget_high_s16(q8sums0.val[1]), vget_high_s16(q6scales1.val[1]))));
3046                bias[2] = vaddvq_s32(prod);
3047                prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[0]), vget_low_s16 (q6scales1.val[0])),
3048                                           vmull_s16(vget_high_s16(q8sums1.val[0]), vget_high_s16(q6scales1.val[0]))),
3049                                 vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[1]), vget_low_s16 (q6scales1.val[1])),
3050                                           vmull_s16(vget_high_s16(q8sums1.val[1]), vget_high_s16(q6scales1.val[1]))));
3051                bias[3] = vaddvq_s32(prod);
3052
3053                const int32x4_t vibias = vmulq_n_s32(vld1q_s32(bias), 32);
3054
3055                const float32x4_t superblock_scale = {
3056                    GGML_CPU_FP16_TO_FP32(x0->d) * y0->d,
3057                    GGML_CPU_FP16_TO_FP32(x0->d) * y1->d,
3058                    GGML_CPU_FP16_TO_FP32(x1->d) * y0->d,
3059                    GGML_CPU_FP16_TO_FP32(x1->d) * y1->d,
3060                };
3061
3062                visum = vsubq_s32(visum, vibias);
3063                vfsum = vmlaq_f32(vfsum, vcvtq_f32_s32(visum), superblock_scale);
3064            }
3065        }
3066
3067        // vfsum = ABCD -> ACBD
3068        // AC -> s, BD -> (s+bs)
3069        vfsum = vzip1q_f32(vfsum, vextq_f32(vfsum, vfsum, 2));
3070        vst1_f32(s,      vget_low_f32 (vfsum));
3071        vst1_f32(s + bs, vget_high_f32(vfsum));
3072
3073        return;
3074    }
3075#endif
3076
3077#ifdef __ARM_FEATURE_SVE
3078    float sum = 0;
3079    svuint8_t m4b = svdup_n_u8(0xf);
3080    svint32_t vzero = svdup_n_s32(0);
3081    svuint8_t mone = svdup_n_u8(0x30);
3082    svint8_t q6bytes_1, q6bytes_2, q6bytes_3, q6bytes_4;
3083    svuint8_t q6h_1, q6h_2, q6h_3, q6h_4;
3084
3085    for (int i = 0; i < nb; ++i) {
3086        const float d_all = GGML_CPU_FP16_TO_FP32(x[i].d);
3087
3088        const uint8_t * GGML_RESTRICT q6 = x[i].ql;
3089        const uint8_t * GGML_RESTRICT qh = x[i].qh;
3090        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
3091
3092        const int8_t * GGML_RESTRICT scale = x[i].scales;
3093
3094        const svbool_t pg16_8 = svptrue_pat_b16(SV_VL8);
3095        const svint16_t q8sums_1 = svld1_s16(pg16_8, y[i].bsums);
3096        const svint16_t q8sums_2 = svld1_s16(pg16_8, y[i].bsums + 8);
3097        const svint16_t q6scales_1 = svunpklo_s16(svld1_s8(svptrue_pat_b8(SV_VL8), scale));
3098        const svint16_t q6scales_2 = svunpklo_s16(svld1_s8(svptrue_pat_b8(SV_VL8), scale + 8));
3099        const svint64_t prod = svdup_n_s64(0);
3100        int32_t isum_mins = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(prod, q8sums_1, q6scales_1),
3101                                                                                 svdot_s64(prod, q8sums_2, q6scales_2)));
3102        int32_t isum = 0;
3103
3104        switch (vector_length) {
3105            case 128:
3106                {
3107                    const svbool_t pg32_4 = svptrue_pat_b32(SV_VL4);
3108                    const svbool_t pg8_16 = svptrue_pat_b8(SV_VL16);
3109                    svint32_t isum_tmp = svdup_n_s32(0);
3110                    for (int j = 0; j < QK_K/128; ++j) {
3111                        svuint8_t qhbits_1 = svld1_u8(pg8_16, qh);
3112                        svuint8_t qhbits_2 = svld1_u8(pg8_16, qh+16);
3113                        qh += 32;
3114                        svuint8_t q6bits_1 = svld1_u8(pg8_16, q6);
3115                        svuint8_t q6bits_2 = svld1_u8(pg8_16, q6+16);
3116                        svuint8_t q6bits_3 = svld1_u8(pg8_16, q6+32);
3117                        svuint8_t q6bits_4 = svld1_u8(pg8_16, q6+48);
3118                        q6 += 64;
3119                        svint8_t q8bytes_1 = svld1_s8(pg8_16, q8);
3120                        svint8_t q8bytes_2 = svld1_s8(pg8_16, q8+16);
3121                        svint8_t q8bytes_3 = svld1_s8(pg8_16, q8+32);
3122                        svint8_t q8bytes_4 = svld1_s8(pg8_16, q8+48);
3123                        q8 += 64;
3124
3125                        q6h_1 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_1, 4));
3126                        q6h_2 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_2, 4));
3127                        q6h_3 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_1, 2));
3128                        q6h_4 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_2, 2));
3129                        q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_1, m4b), q6h_1));
3130                        q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_2, m4b), q6h_2));
3131                        q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_3, m4b), q6h_3));
3132                        q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_4, m4b), q6h_4));
3133                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale[0]);
3134                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale[1]);
3135                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale[2]);
3136                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale[3]);
3137
3138                        scale += 4;
3139                        q8bytes_1 = svld1_s8(pg8_16, q8);
3140                        q8bytes_2 = svld1_s8(pg8_16, q8+16);
3141                        q8bytes_3 = svld1_s8(pg8_16, q8+32);
3142                        q8bytes_4 = svld1_s8(pg8_16, q8+48);
3143                        q8 += 64;
3144
3145                        q6h_1 = svand_u8_x(pg16_8, mone, qhbits_1);
3146                        q6h_2 = svand_u8_x(pg16_8, mone, qhbits_2);
3147                        q6h_3 = svand_u8_x(pg16_8, mone, svlsr_n_u8_x(pg16_8, qhbits_1, 2));
3148                        q6h_4 = svand_u8_x(pg16_8, mone, svlsr_n_u8_x(pg16_8, qhbits_2, 2));
3149                        q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_1, 4), q6h_1));
3150                        q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_2, 4), q6h_2));
3151                        q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_3, 4), q6h_3));
3152                        q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_4, 4), q6h_4));
3153                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale[0]);
3154                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale[1]);
3155                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale[2]);
3156                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale[3]);
3157                        scale += 4;
3158                    }
3159                    isum += svaddv_s32(pg32_4, isum_tmp);
3160                    sum += d_all * y[i].d * (isum - 32 * isum_mins);
3161                }
3162                break;
3163            case 256:
3164            case 512:
3165                {
3166                    const svbool_t pg8_2 = svptrue_pat_b8(SV_VL2);
3167                    const svbool_t pg32_8 = svptrue_pat_b32(SV_VL8);
3168                    const svbool_t pg8_32 = svptrue_pat_b8(SV_VL32);
3169                    svint32_t isum_tmp = svdup_n_s32(0);
3170                    for (int j = 0; j < QK_K/128; j++) {
3171                        svuint8_t qhbits_1 = svld1_u8(pg8_32, qh);
3172                        qh += 32;
3173                        svuint8_t q6bits_1 = svld1_u8(pg8_32, q6);
3174                        svuint8_t q6bits_2 = svld1_u8(pg8_32, q6+32);
3175                        q6 += 64;
3176                        svint8_t q8bytes_1 = svld1_s8(pg8_32, q8);
3177                        svint8_t q8bytes_2 = svld1_s8(pg8_32, q8+32);
3178                        svint8_t q8bytes_3 = svld1_s8(pg8_32, q8+64);
3179                        svint8_t q8bytes_4 = svld1_s8(pg8_32, q8+96);
3180                        q8 += 128;
3181                        q6h_1 = svand_u8_x(pg8_32, mone, svlsl_n_u8_x(pg8_32, qhbits_1, 4));
3182                        q6h_2 = svand_u8_x(pg8_32, mone, svlsl_n_u8_x(pg8_32, qhbits_1, 2));
3183                        q6h_3 = svand_u8_x(pg8_32, mone, qhbits_1);
3184                        q6h_4 = svand_u8_x(pg8_32, mone, svlsr_n_u8_x(pg8_32, qhbits_1, 2));
3185                        q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svand_u8_x(pg8_32, q6bits_1, m4b), q6h_1));
3186                        q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svand_u8_x(pg8_32, q6bits_2, m4b), q6h_2));
3187                        q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svlsr_n_u8_x(pg8_32, q6bits_1, 4), q6h_3));
3188                        q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svlsr_n_u8_x(pg8_32, q6bits_2, 4), q6h_4));
3189
3190                        svint8_t scale_lane_1_tmp = svld1_s8(pg8_2, scale);
3191                        scale_lane_1_tmp= svzip1_s8(scale_lane_1_tmp, scale_lane_1_tmp);
3192                        scale_lane_1_tmp= svzip1_s8(scale_lane_1_tmp, scale_lane_1_tmp);
3193                        svint8_t scale_lane_2_tmp = svld1_s8(pg8_2, scale+2);
3194                        scale_lane_2_tmp = svzip1_s8(scale_lane_2_tmp, scale_lane_2_tmp);
3195                        scale_lane_2_tmp = svzip1_s8(scale_lane_2_tmp, scale_lane_2_tmp);
3196                        svint8_t scale_lane_3_tmp = svld1_s8(pg8_2, scale+4);
3197                        scale_lane_3_tmp = svzip1_s8(scale_lane_3_tmp, scale_lane_3_tmp);
3198                        scale_lane_3_tmp = svzip1_s8(scale_lane_3_tmp, scale_lane_3_tmp);
3199                        svint8_t scale_lane_4_tmp = svld1_s8(pg8_2, scale+6);
3200                        scale_lane_4_tmp = svzip1_s8(scale_lane_4_tmp, scale_lane_4_tmp);
3201                        scale_lane_4_tmp = svzip1_s8(scale_lane_4_tmp, scale_lane_4_tmp);
3202                        svint32_t scale_lane_1 = svunpklo_s32(svunpklo_s16(scale_lane_1_tmp));
3203                        svint32_t scale_lane_2 = svunpklo_s32(svunpklo_s16(scale_lane_2_tmp));
3204                        svint32_t scale_lane_3 = svunpklo_s32(svunpklo_s16(scale_lane_3_tmp));
3205                        svint32_t scale_lane_4 = svunpklo_s32(svunpklo_s16(scale_lane_4_tmp));
3206
3207                        isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale_lane_1);
3208                        isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale_lane_2);
3209                        isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale_lane_3);
3210                        isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale_lane_4);
3211                        scale += 8;
3212                    }
3213                    isum += svaddv_s32(pg32_8, isum_tmp);
3214                    sum += d_all * y[i].d * (isum - 32 * isum_mins);
3215                }
3216                break;
3217            default:
3218                assert(false && "Unsupported vector length");
3219                break;
3220        }
3221    }
3222
3223    *s = sum;
3224
3225#elif __ARM_NEON
3226    float sum = 0;
3227
3228    const uint8x16_t m4b = vdupq_n_u8(0xF);
3229    const int32x4_t  vzero = vdupq_n_s32(0);
3230    //const int8x16_t  m32s = vdupq_n_s8(32);
3231
3232    const uint8x16_t mone = vdupq_n_u8(3);
3233
3234    ggml_int8x16x4_t q6bytes;
3235    ggml_uint8x16x4_t q6h;
3236
3237    for (int i = 0; i < nb; ++i) {
3238
3239        const float d_all = GGML_CPU_FP16_TO_FP32(x[i].d);
3240
3241        const uint8_t * GGML_RESTRICT q6 = x[i].ql;
3242        const uint8_t * GGML_RESTRICT qh = x[i].qh;
3243        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
3244
3245        const int8_t * GGML_RESTRICT scale = x[i].scales;
3246
3247        const ggml_int16x8x2_t q8sums = ggml_vld1q_s16_x2(y[i].bsums);
3248        const int8x16_t scales = vld1q_s8(scale);
3249        const ggml_int16x8x2_t q6scales = {{vmovl_s8(vget_low_s8(scales)), vmovl_s8(vget_high_s8(scales))}};
3250
3251        const int32x4_t prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[0]), vget_low_s16 (q6scales.val[0])),
3252                                                   vmull_s16(vget_high_s16(q8sums.val[0]), vget_high_s16(q6scales.val[0]))),
3253                                         vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[1]), vget_low_s16 (q6scales.val[1])),
3254                                                   vmull_s16(vget_high_s16(q8sums.val[1]), vget_high_s16(q6scales.val[1]))));
3255        int32_t isum_mins = vaddvq_s32(prod);
3256
3257        int32_t isum = 0;
3258
3259        for (int j = 0; j < QK_K/128; ++j) {
3260
3261            ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh); qh += 32;
3262            ggml_uint8x16x4_t q6bits = ggml_vld1q_u8_x4(q6); q6 += 64;
3263            ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64;
3264
3265            q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4);
3266            q6h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4);
3267            uint8x16_t shifted = vshrq_n_u8(qhbits.val[0], 2);
3268            q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
3269            shifted = vshrq_n_u8(qhbits.val[1], 2);
3270            q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
3271
3272            //q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0])), m32s);
3273            //q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1])), m32s);
3274            //q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2])), m32s);
3275            //q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3])), m32s);
3276            q6bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0]));
3277            q6bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1]));
3278            q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2]));
3279            q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3]));
3280
3281            isum += vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
3282                    vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
3283                    vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
3284                    vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
3285
3286            scale += 4;
3287
3288            q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64;
3289
3290            shifted = vshrq_n_u8(qhbits.val[0], 4);
3291            q6h.val[0] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
3292            shifted = vshrq_n_u8(qhbits.val[1], 4);
3293            q6h.val[1] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
3294            shifted = vshrq_n_u8(qhbits.val[0], 6);
3295            q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
3296            shifted = vshrq_n_u8(qhbits.val[1], 6);
3297            q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
3298
3299            //q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[0])), m32s);
3300            //q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[1])), m32s);
3301            //q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2])), m32s);
3302            //q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3])), m32s);
3303            q6bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[0]));
3304            q6bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[1]));
3305            q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2]));
3306            q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3]));
3307
3308            isum += vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
3309                    vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
3310                    vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
3311                    vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
3312            scale += 4;
3313        }
3314        //sum += isum * d_all * y[i].d;
3315        sum += d_all * y[i].d * (isum - 32 * isum_mins);
3316
3317    }
3318    *s = sum;
3319#else
3320    UNUSED(x);
3321    UNUSED(y);
3322    UNUSED(nb);
3323    ggml_vec_dot_q6_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
3324#endif
3325}
3326
#if defined (__ARM_NEON)
// Sign patterns used by the IQ2 NEON dot-product kernels: 128 rows of 8 int8
// values, each value +1 or -1 (1024 entries total).
// Row r is built as follows: for j = 0..6, element j is -1 iff bit j of r is set;
// element 7 is -1 iff r has an odd number of set bits, so every row contains an
// even number of -1 entries.
// The kernels view this table as a uint64_t array (one row per uint64_t) and
// select a row with a 7-bit index (value & 127), then multiply the 8 selected
// signs element-wise into 8 grid values.
static const int8_t keven_signs_q2xs[1024] = {
     1,  1,  1,  1,  1,  1,  1,  1, -1,  1,  1,  1,  1,  1,  1, -1,  1, -1,  1,  1,  1,  1,  1, -1, -1, -1,  1,  1,  1,  1,  1,  1,
     1,  1, -1,  1,  1,  1,  1, -1, -1,  1, -1,  1,  1,  1,  1,  1,  1, -1, -1,  1,  1,  1,  1,  1, -1, -1, -1,  1,  1,  1,  1, -1,
     1,  1,  1, -1,  1,  1,  1, -1, -1,  1,  1, -1,  1,  1,  1,  1,  1, -1,  1, -1,  1,  1,  1,  1, -1, -1,  1, -1,  1,  1,  1, -1,
     1,  1, -1, -1,  1,  1,  1,  1, -1,  1, -1, -1,  1,  1,  1, -1,  1, -1, -1, -1,  1,  1,  1, -1, -1, -1, -1, -1,  1,  1,  1,  1,
     1,  1,  1,  1, -1,  1,  1, -1, -1,  1,  1,  1, -1,  1,  1,  1,  1, -1,  1,  1, -1,  1,  1,  1, -1, -1,  1,  1, -1,  1,  1, -1,
     1,  1, -1,  1, -1,  1,  1,  1, -1,  1, -1,  1, -1,  1,  1, -1,  1, -1, -1,  1, -1,  1,  1, -1, -1, -1, -1,  1, -1,  1,  1,  1,
     1,  1,  1, -1, -1,  1,  1,  1, -1,  1,  1, -1, -1,  1,  1, -1,  1, -1,  1, -1, -1,  1,  1, -1, -1, -1,  1, -1, -1,  1,  1,  1,
     1,  1, -1, -1, -1,  1,  1, -1, -1,  1, -1, -1, -1,  1,  1,  1,  1, -1, -1, -1, -1,  1,  1,  1, -1, -1, -1, -1, -1,  1,  1, -1,
     1,  1,  1,  1,  1, -1,  1, -1, -1,  1,  1,  1,  1, -1,  1,  1,  1, -1,  1,  1,  1, -1,  1,  1, -1, -1,  1,  1,  1, -1,  1, -1,
     1,  1, -1,  1,  1, -1,  1,  1, -1,  1, -1,  1,  1, -1,  1, -1,  1, -1, -1,  1,  1, -1,  1, -1, -1, -1, -1,  1,  1, -1,  1,  1,
     1,  1,  1, -1,  1, -1,  1,  1, -1,  1,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1, -1, -1,  1, -1,  1, -1,  1,  1,
     1,  1, -1, -1,  1, -1,  1, -1, -1,  1, -1, -1,  1, -1,  1,  1,  1, -1, -1, -1,  1, -1,  1,  1, -1, -1, -1, -1,  1, -1,  1, -1,
     1,  1,  1,  1, -1, -1,  1,  1, -1,  1,  1,  1, -1, -1,  1, -1,  1, -1,  1,  1, -1, -1,  1, -1, -1, -1,  1,  1, -1, -1,  1,  1,
     1,  1, -1,  1, -1, -1,  1, -1, -1,  1, -1,  1, -1, -1,  1,  1,  1, -1, -1,  1, -1, -1,  1,  1, -1, -1, -1,  1, -1, -1,  1, -1,
     1,  1,  1, -1, -1, -1,  1, -1, -1,  1,  1, -1, -1, -1,  1,  1,  1, -1,  1, -1, -1, -1,  1,  1, -1, -1,  1, -1, -1, -1,  1, -1,
     1,  1, -1, -1, -1, -1,  1,  1, -1,  1, -1, -1, -1, -1,  1, -1,  1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1,  1,  1,
     1,  1,  1,  1,  1,  1, -1, -1, -1,  1,  1,  1,  1,  1, -1,  1,  1, -1,  1,  1,  1,  1, -1,  1, -1, -1,  1,  1,  1,  1, -1, -1,
     1,  1, -1,  1,  1,  1, -1,  1, -1,  1, -1,  1,  1,  1, -1, -1,  1, -1, -1,  1,  1,  1, -1, -1, -1, -1, -1,  1,  1,  1, -1,  1,
     1,  1,  1, -1,  1,  1, -1,  1, -1,  1,  1, -1,  1,  1, -1, -1,  1, -1,  1, -1,  1,  1, -1, -1, -1, -1,  1, -1,  1,  1, -1,  1,
     1,  1, -1, -1,  1,  1, -1, -1, -1,  1, -1, -1,  1,  1, -1,  1,  1, -1, -1, -1,  1,  1, -1,  1, -1, -1, -1, -1,  1,  1, -1, -1,
     1,  1,  1,  1, -1,  1, -1,  1, -1,  1,  1,  1, -1,  1, -1, -1,  1, -1,  1,  1, -1,  1, -1, -1, -1, -1,  1,  1, -1,  1, -1,  1,
     1,  1, -1,  1, -1,  1, -1, -1, -1,  1, -1,  1, -1,  1, -1,  1,  1, -1, -1,  1, -1,  1, -1,  1, -1, -1, -1,  1, -1,  1, -1, -1,
     1,  1,  1, -1, -1,  1, -1, -1, -1,  1,  1, -1, -1,  1, -1,  1,  1, -1,  1, -1, -1,  1, -1,  1, -1, -1,  1, -1, -1,  1, -1, -1,
     1,  1, -1, -1, -1,  1, -1,  1, -1,  1, -1, -1, -1,  1, -1, -1,  1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1,  1, -1,  1,
     1,  1,  1,  1,  1, -1, -1,  1, -1,  1,  1,  1,  1, -1, -1, -1,  1, -1,  1,  1,  1, -1, -1, -1, -1, -1,  1,  1,  1, -1, -1,  1,
     1,  1, -1,  1,  1, -1, -1, -1, -1,  1, -1,  1,  1, -1, -1,  1,  1, -1, -1,  1,  1, -1, -1,  1, -1, -1, -1,  1,  1, -1, -1, -1,
     1,  1,  1, -1,  1, -1, -1, -1, -1,  1,  1, -1,  1, -1, -1,  1,  1, -1,  1, -1,  1, -1, -1,  1, -1, -1,  1, -1,  1, -1, -1, -1,
     1,  1, -1, -1,  1, -1, -1,  1, -1,  1, -1, -1,  1, -1, -1, -1,  1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1,  1,
     1,  1,  1,  1, -1, -1, -1, -1, -1,  1,  1,  1, -1, -1, -1,  1,  1, -1,  1,  1, -1, -1, -1,  1, -1, -1,  1,  1, -1, -1, -1, -1,
     1,  1, -1,  1, -1, -1, -1,  1, -1,  1, -1,  1, -1, -1, -1, -1,  1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1,  1,
     1,  1,  1, -1, -1, -1, -1,  1, -1,  1,  1, -1, -1, -1, -1, -1,  1, -1,  1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1,  1,
     1,  1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1,  1,  1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1, -1,
};
#endif
3363
3364void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
3365    assert(n % QK_K == 0);
3366    assert(nrc == 1);
3367    UNUSED(nrc);
3368    UNUSED(bx);
3369    UNUSED(by);
3370    UNUSED(bs);
3371
3372    const block_iq2_xxs * GGML_RESTRICT x = vx;
3373    const block_q8_K    * GGML_RESTRICT y = vy;
3374
3375    const int nb = n / QK_K;
3376
3377#if defined(__ARM_NEON)
3378
3379    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
3380
3381    uint32_t aux32[4];
3382    const uint8_t * aux8 = (const uint8_t *)aux32;
3383
3384    ggml_int8x16x4_t q2u;
3385    ggml_int8x16x4_t q2s;
3386    ggml_int8x16x4_t q8b;
3387
3388    float sumf = 0;
3389    for (int i = 0; i < nb; ++i) {
3390        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
3391        const uint16_t * GGML_RESTRICT q2 = x[i].qs;
3392        const int8_t   * GGML_RESTRICT q8 = y[i].qs;
3393        float sumf1 = 0, sumf2 = 0;
3394        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
3395            q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
3396            memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8;
3397            q2u.val[0] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 0])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 1])));
3398            q2u.val[1] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 2])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 3])));
3399            q2u.val[2] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 8])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 9])));
3400            q2u.val[3] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[10])), vld1_s8((const void *)(iq2xxs_grid + aux8[11])));
3401            q2s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >>  0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >>  7) & 127))));
3402            q2s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 21) & 127))));
3403            q2s.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[3] >>  0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[3] >>  7) & 127))));
3404            q2s.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[3] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[3] >> 21) & 127))));
3405            q2u.val[0] = vmulq_s8(q2u.val[0], q2s.val[0]);
3406            q2u.val[1] = vmulq_s8(q2u.val[1], q2s.val[1]);
3407            q2u.val[2] = vmulq_s8(q2u.val[2], q2s.val[2]);
3408            q2u.val[3] = vmulq_s8(q2u.val[3], q2s.val[3]);
3409            const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[0], q8b.val[0]), q2u.val[1], q8b.val[1]);
3410            const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[2], q8b.val[2]), q2u.val[3], q8b.val[3]);
3411            sumf1 += vaddvq_s32(p1) * (0.5f + (aux32[1] >> 28));
3412            sumf2 += vaddvq_s32(p2) * (0.5f + (aux32[3] >> 28));
3413        }
3414        sumf += d*(sumf1 + sumf2);
3415    }
3416    *s = 0.25f * sumf;
3417
3418#else
3419    UNUSED(x);
3420    UNUSED(y);
3421    UNUSED(nb);
3422    ggml_vec_dot_iq2_xxs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
3423#endif
3424}
3425
// Dot product of n weights stored as block_iq2_xs (vx) with n activations
// stored as block_q8_K (vy); the scalar result is written to *s. Only
// nrc == 1 is supported and bs/bx/by are not used by the NEON path. With NEON
// the kernel below is used; otherwise the call is forwarded to
// ggml_vec_dot_iq2_xs_q8_K_generic.
void ggml_vec_dot_iq2_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(n % QK_K == 0);
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_iq2_xs * GGML_RESTRICT x = vx;
    const block_q8_K   * GGML_RESTRICT y = vy;

    const int nb = n / QK_K;

#if defined(__ARM_NEON)

    // Sign patterns are fetched 8 bytes at a time and multiplied lane-wise
    // into the grid values (presumably each byte is +1 or -1 -- confirm
    // against the definition of keven_signs_q2xs).
    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;

    ggml_int8x16x4_t q2u;
    ggml_int8x16x4_t q2s;
    ggml_int8x16x4_t q8b;

    int32x4x4_t scales32;

    float sumf = 0;
    for (int i = 0; i < nb; ++i) {
        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
        const uint16_t * GGML_RESTRICT q2 = x[i].qs;
        const int8_t   * GGML_RESTRICT q8 = y[i].qs;
        // Each scale byte packs two 4-bit scales. Interleaving the low and high
        // nibbles puts the scale of the k-th group of 16 values in lane k; each
        // scale s is then mapped to the odd integer 2*s + 1.
        const uint8x8_t scales8 = vld1_u8(x[i].scales);
        const uint8x8_t scales_l = vand_u8(scales8, vdup_n_u8(0xf));
        const uint8x8_t scales_h = vshr_n_u8(scales8, 4);
        uint8x16_t scales = vcombine_u8(vzip1_u8(scales_l, scales_h), vzip2_u8(scales_l, scales_h));
        scales = vaddq_u8(vshlq_n_u8(scales, 1), vdupq_n_u8(1));
        // Widen the 16 scales to int32: scales32.val[c] holds the four scales
        // used by 64-value chunk c.
        const uint16x8_t scales1 = vmovl_u8(vget_low_u8(scales));
        const uint16x8_t scales2 = vmovl_u8(vget_high_u8(scales));
        scales32.val[0] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(scales1)));
        scales32.val[1] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales1)));
        scales32.val[2] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(scales2)));
        scales32.val[3] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales2)));
        int32x4_t sumi = vdupq_n_s32(0);
        for (int ib64 = 0; ib64 < QK_K/64; ++ib64) {
            q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
            // Each q2 entry covers 8 values: the low 9 bits index iq2xs_grid,
            // the high 7 bits index the sign table.
            q2u.val[0] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[0] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[1] & 511))));
            q2u.val[1] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[2] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[3] & 511))));
            q2u.val[2] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[4] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[5] & 511))));
            q2u.val[3] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[6] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[7] & 511))));
            q2s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[0] >> 9))), vld1_s8((const void *)(signs64 + (q2[1] >> 9))));
            q2s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[2] >> 9))), vld1_s8((const void *)(signs64 + (q2[3] >> 9))));
            q2s.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[4] >> 9))), vld1_s8((const void *)(signs64 + (q2[5] >> 9))));
            q2s.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[6] >> 9))), vld1_s8((const void *)(signs64 + (q2[7] >> 9))));
            // apply the signs to the grid values
            q2u.val[0] = vmulq_s8(q2u.val[0], q2s.val[0]);
            q2u.val[1] = vmulq_s8(q2u.val[1], q2s.val[1]);
            q2u.val[2] = vmulq_s8(q2u.val[2], q2s.val[2]);
            q2u.val[3] = vmulq_s8(q2u.val[3], q2s.val[3]);
            // p1..p4 cover 16 values each; the two pairwise adds reduce them to
            // one sum per group of 16, matching the lanes of scales32.val[ib64].
            const int32x4_t p1 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[0], q8b.val[0]);
            const int32x4_t p2 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[1], q8b.val[1]);
            const int32x4_t p3 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[2], q8b.val[2]);
            const int32x4_t p4 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[3], q8b.val[3]);
            const int32x4_t p = vpaddq_s32(vpaddq_s32(p1, p2), vpaddq_s32(p3, p4));
            sumi = vmlaq_s32(sumi, p, scales32.val[ib64]);
            q2 += 8;
        }
        sumf += d*vaddvq_s32(sumi);
    }
    // The constant 1/8 factor is applied once here rather than per block.
    *s = 0.125f * sumf;

#else
    UNUSED(x);
    UNUSED(y);
    UNUSED(nb);
    ggml_vec_dot_iq2_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}
3499
3500void ggml_vec_dot_iq2_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
3501    assert(n % QK_K == 0);
3502    assert(nrc == 1);
3503    UNUSED(nrc);
3504    UNUSED(bx);
3505    UNUSED(by);
3506    UNUSED(bs);
3507
3508    const block_iq2_s * GGML_RESTRICT x = vx;
3509    const block_q8_K  * GGML_RESTRICT y = vy;
3510
3511    const int nb = n / QK_K;
3512
3513#if defined(__ARM_NEON)
3514
3515   static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
3516                                       0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
3517   };
3518
3519    static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,};
3520
3521    const ggml_uint8x16x2_t mask1 = ggml_vld1q_u8_x2(k_mask1);
3522    const uint8x16_t        mask2 = vld1q_u8(k_mask2);
3523    const uint8x16_t m1 = vdupq_n_u8(1);
3524    const int32x4_t vzero = vdupq_n_s32(0);
3525
3526    uint8x16x2_t vs;
3527    ggml_int8x16x4_t q2s;
3528    ggml_int8x16x4_t q8b;
3529
3530    float sumf = 0;
3531    for (int i = 0; i < nb; ++i) {
3532
3533        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
3534
3535        const uint8_t * GGML_RESTRICT qs = x[i].qs;
3536        const uint8_t * GGML_RESTRICT qh = x[i].qh;
3537        const uint16_t * GGML_RESTRICT signs = (const uint16_t *)(x[i].qs + QK_K/8);
3538        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
3539
3540        int sumi1 = 0, sumi2 = 0;
3541        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
3542            q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
3543            q2s.val[0] = vcombine_s8(vld1_s8((const int8_t *)(iq2s_grid + (qs[0] | ((qh[ib32+0] << 8) & 0x300)))),
3544                                     vld1_s8((const int8_t *)(iq2s_grid + (qs[1] | ((qh[ib32+0] << 6) & 0x300)))));
3545            q2s.val[1] = vcombine_s8(vld1_s8((const int8_t *)(iq2s_grid + (qs[2] | ((qh[ib32+0] << 4) & 0x300)))),
3546                                     vld1_s8((const int8_t *)(iq2s_grid + (qs[3] | ((qh[ib32+0] << 2) & 0x300)))));
3547            q2s.val[2] = vcombine_s8(vld1_s8((const int8_t *)(iq2s_grid + (qs[4] | ((qh[ib32+1] << 8) & 0x300)))),
3548                                     vld1_s8((const int8_t *)(iq2s_grid + (qs[5] | ((qh[ib32+1] << 6) & 0x300)))));
3549            q2s.val[3] = vcombine_s8(vld1_s8((const int8_t *)(iq2s_grid + (qs[6] | ((qh[ib32+1] << 4) & 0x300)))),
3550                                     vld1_s8((const int8_t *)(iq2s_grid + (qs[7] | ((qh[ib32+1] << 2) & 0x300)))));
3551            qs += 8;
3552
3553            vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | ((uint32_t) signs[1] << 16)));
3554            vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
3555            vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
3556            vs.val[0] = vceqq_u8(vs.val[0], mask2);
3557            vs.val[1] = vceqq_u8(vs.val[1], mask2);
3558
3559            q2s.val[0] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[0], m1)), q2s.val[0]);
3560            q2s.val[1] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[1], m1)), q2s.val[1]);
3561
3562            vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[2] | ((uint32_t) signs[3] << 16)));
3563            vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
3564            vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
3565            vs.val[0] = vceqq_u8(vs.val[0], mask2);
3566            vs.val[1] = vceqq_u8(vs.val[1], mask2);
3567
3568            signs += 4;
3569
3570            q2s.val[2] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[0], m1)), q2s.val[2]);
3571            q2s.val[3] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[1], m1)), q2s.val[3]);
3572
3573            const int32x4_t p1 = ggml_vdotq_s32(vzero, q2s.val[0], q8b.val[0]);
3574            const int32x4_t p2 = ggml_vdotq_s32(vzero, q2s.val[1], q8b.val[1]);
3575            const int32x4_t p3 = ggml_vdotq_s32(vzero, q2s.val[2], q8b.val[2]);
3576            const int32x4_t p4 = ggml_vdotq_s32(vzero, q2s.val[3], q8b.val[3]);
3577
3578            sumi1 += vaddvq_s32(p1) * (1 + 2*(x[i].scales[ib32+0] & 0xf));
3579            sumi2 += vaddvq_s32(p2) * (1 + 2*(x[i].scales[ib32+0] >>  4));
3580            sumi1 += vaddvq_s32(p3) * (1 + 2*(x[i].scales[ib32+1] & 0xf));
3581            sumi2 += vaddvq_s32(p4) * (1 + 2*(x[i].scales[ib32+1] >>  4));
3582        }
3583        sumf += d*(sumi1 + sumi2);
3584    }
3585
3586    *s = 0.125f * sumf;
3587
3588#else
3589    UNUSED(x);
3590    UNUSED(y);
3591    UNUSED(nb);
3592    ggml_vec_dot_iq2_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
3593#endif
3594
3595}
3596
// Dot product of n weights stored as block_iq3_xxs (vx) with n activations
// stored as block_q8_K (vy); the scalar result is written to *s. Only
// nrc == 1 is supported. With NEON the kernel below is used; otherwise the
// call is forwarded to ggml_vec_dot_iq3_xxs_q8_K_generic.
void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(n % QK_K == 0);
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_iq3_xxs * GGML_RESTRICT x = vx;
    const block_q8_K    * GGML_RESTRICT y = vy;

    const int nb = n / QK_K;

#if defined(__ARM_NEON)

    // Sign patterns are fetched 8 bytes at a time and multiplied lane-wise
    // into the grid values (presumably bytes of +1/-1 -- confirm against
    // keven_signs_q2xs).
    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;

    uint32_t aux32[2];

    ggml_int8x16x4_t q3s;
    ggml_int8x16x4_t q8b;

    float sumf = 0;
    for (int i = 0; i < nb; ++i) {
        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
        // first QK_K/4 bytes: 8-bit indices into iq3xxs_grid
        const uint8_t * GGML_RESTRICT q3 = x[i].qs;
        // after them: one 32-bit word per 32 values; bits 0-27 hold four 7-bit
        // sign indices and bits 28-31 hold the sub-block scale
        const uint8_t * GGML_RESTRICT gas = x[i].qs + QK_K/4;
        const int8_t   * GGML_RESTRICT q8 = y[i].qs;
        float sumf1 = 0, sumf2 = 0;
        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
            q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
            // sign/scale words for the two 32-value sub-blocks handled here
            memcpy(aux32, gas, 2*sizeof(uint32_t)); gas += 2*sizeof(uint32_t);
            // each 32-bit grid entry is reinterpreted as 4 int8 values
            const uint32x4_t aux32x4_0 = ggml_vld1q_u32(iq3xxs_grid[q3[ 0]], iq3xxs_grid[q3[ 1]], iq3xxs_grid[q3[ 2]], iq3xxs_grid[q3[ 3]]);
            const uint32x4_t aux32x4_1 = ggml_vld1q_u32(iq3xxs_grid[q3[ 4]], iq3xxs_grid[q3[ 5]], iq3xxs_grid[q3[ 6]], iq3xxs_grid[q3[ 7]]);
            const uint32x4_t aux32x4_2 = ggml_vld1q_u32(iq3xxs_grid[q3[ 8]], iq3xxs_grid[q3[ 9]], iq3xxs_grid[q3[10]], iq3xxs_grid[q3[11]]);
            const uint32x4_t aux32x4_3 = ggml_vld1q_u32(iq3xxs_grid[q3[12]], iq3xxs_grid[q3[13]], iq3xxs_grid[q3[14]], iq3xxs_grid[q3[15]]);
            q3 += 16;
            q3s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[0] >>  0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[0] >>  7) & 127))));
            q3s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[0] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[0] >> 21) & 127))));
            q3s.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >>  0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >>  7) & 127))));
            q3s.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 21) & 127))));
            q3s.val[0] = vmulq_s8(q3s.val[0], vreinterpretq_s8_u32(aux32x4_0));
            q3s.val[1] = vmulq_s8(q3s.val[1], vreinterpretq_s8_u32(aux32x4_1));
            q3s.val[2] = vmulq_s8(q3s.val[2], vreinterpretq_s8_u32(aux32x4_2));
            q3s.val[3] = vmulq_s8(q3s.val[3], vreinterpretq_s8_u32(aux32x4_3));
            const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[0], q8b.val[0]), q3s.val[1], q8b.val[1]);
            const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[2], q8b.val[2]), q3s.val[3], q8b.val[3]);
            // (0.5 + s) here combined with the final 0.5 factor gives (2*s + 1)/4
            sumf1 += vaddvq_s32(p1) * (0.5f + (aux32[0] >> 28));
            sumf2 += vaddvq_s32(p2) * (0.5f + (aux32[1] >> 28));
        }
        sumf += d*(sumf1 + sumf2);
    }
    *s = 0.5f * sumf;

#else
    UNUSED(x);
    UNUSED(y);
    UNUSED(nb);
    ggml_vec_dot_iq3_xxs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}
3658
3659void ggml_vec_dot_iq3_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
3660    assert(n % QK_K == 0);
3661    assert(nrc == 1);
3662    UNUSED(nrc);
3663    UNUSED(bx);
3664    UNUSED(by);
3665    UNUSED(bs);
3666
3667    const block_iq3_s * GGML_RESTRICT x = vx;
3668    const block_q8_K  * GGML_RESTRICT y = vy;
3669
3670    const int nb = n / QK_K;
3671
3672#if defined(__ARM_NEON)
3673
3674    typedef union {
3675        uint16x8_t vec_index;
3676        uint16_t   index[8];
3677    } vec_index_t;
3678
3679   static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
3680                                       0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
3681   };
3682
3683    static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,};
3684
3685    static const int16_t k_shift[8] = {8, 7, 6, 5, 4, 3, 2, 1};
3686
3687    const ggml_uint8x16x2_t mask1 = ggml_vld1q_u8_x2(k_mask1);
3688    const uint8x16_t        mask2 = vld1q_u8(k_mask2);
3689
3690    const int16x8_t  hshift = vld1q_s16(k_shift);
3691    const uint16x8_t m256   = vdupq_n_u16(256);
3692    const uint8x16_t m1     = vdupq_n_u8(1);
3693
3694    uint8x16x2_t vs;
3695    ggml_int8x16x4_t q3s;
3696    ggml_int8x16x4_t q8b;
3697    vec_index_t idx;
3698
3699    uint32_t scales32[2];
3700    const uint8_t * scales8 = (const uint8_t *)scales32;
3701
3702    float sumf = 0;
3703    for (int i = 0; i < nb; ++i) {
3704        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
3705        const uint8_t * GGML_RESTRICT qs = x[i].qs;
3706        const uint8_t * GGML_RESTRICT qh = x[i].qh;
3707        const uint16_t * GGML_RESTRICT signs = (const uint16_t *)x[i].signs;
3708        const int8_t   * GGML_RESTRICT q8 = y[i].qs;
3709
3710        memcpy(scales32, x[i].scales, 4);
3711        scales32[1] = (((scales32[0] >> 4) & 0x0f0f0f0f) << 1) | 0x01010101;
3712        scales32[0] = ((scales32[0] & 0x0f0f0f0f) << 1) | 0x01010101;
3713
3714        int sumi1 = 0, sumi2 = 0;
3715        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
3716            q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
3717
3718            const uint8x16_t idx_l = vld1q_u8(qs); qs += 16;
3719            idx.vec_index = vorrq_u16(vmovl_u8(vget_low_u8 (idx_l)), vandq_u16(vshlq_u16(vdupq_n_u16(qh[ib32+0]), hshift), m256));
3720            const uint32x4_t aux32x4_0 = ggml_vld1q_u32(iq3s_grid[idx.index[0]], iq3s_grid[idx.index[1]],
3721                                                        iq3s_grid[idx.index[2]], iq3s_grid[idx.index[3]]);
3722            const uint32x4_t aux32x4_1 = ggml_vld1q_u32(iq3s_grid[idx.index[4]], iq3s_grid[idx.index[5]],
3723                                                        iq3s_grid[idx.index[6]], iq3s_grid[idx.index[7]]);
3724            idx.vec_index = vorrq_u16(vmovl_u8(vget_high_u8(idx_l)), vandq_u16(vshlq_u16(vdupq_n_u16(qh[ib32+1]), hshift), m256));
3725            const uint32x4_t aux32x4_2 = ggml_vld1q_u32(iq3s_grid[idx.index[0]], iq3s_grid[idx.index[1]],
3726                                                        iq3s_grid[idx.index[2]], iq3s_grid[idx.index[3]]);
3727            const uint32x4_t aux32x4_3 = ggml_vld1q_u32(iq3s_grid[idx.index[4]], iq3s_grid[idx.index[5]],
3728                                                        iq3s_grid[idx.index[6]], iq3s_grid[idx.index[7]]);
3729
3730
3731            vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | ((uint32_t) signs[1] << 16)));
3732            vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
3733            vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
3734            vs.val[0] = vorrq_u8(vceqq_u8(vs.val[0], mask2), m1);
3735            vs.val[1] = vorrq_u8(vceqq_u8(vs.val[1], mask2), m1);
3736
3737            q3s.val[0] = vmulq_s8(vreinterpretq_s8_u8(vs.val[0]), vreinterpretq_s8_u32(aux32x4_0));
3738            q3s.val[1] = vmulq_s8(vreinterpretq_s8_u8(vs.val[1]), vreinterpretq_s8_u32(aux32x4_1));
3739
3740            vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[2] | ((uint32_t) signs[3] << 16)));
3741            vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
3742            vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
3743            vs.val[0] = vorrq_u8(vceqq_u8(vs.val[0], mask2), m1);
3744            vs.val[1] = vorrq_u8(vceqq_u8(vs.val[1], mask2), m1);
3745
3746            signs += 4;
3747
3748            q3s.val[2] = vmulq_s8(vreinterpretq_s8_u8(vs.val[0]), vreinterpretq_s8_u32(aux32x4_2));
3749            q3s.val[3] = vmulq_s8(vreinterpretq_s8_u8(vs.val[1]), vreinterpretq_s8_u32(aux32x4_3));
3750
3751            const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[0], q8b.val[0]), q3s.val[1], q8b.val[1]);
3752            const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[2], q8b.val[2]), q3s.val[3], q8b.val[3]);
3753
3754            sumi1 += vaddvq_s32(p1) * scales8[ib32/2+0];
3755            sumi2 += vaddvq_s32(p2) * scales8[ib32/2+4];
3756        }
3757        sumf += d*(sumi1 + sumi2);
3758    }
3759    *s = sumf;
3760
3761#else
3762    UNUSED(x);
3763    UNUSED(y);
3764    UNUSED(nb);
3765    ggml_vec_dot_iq3_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
3766#endif
3767}
3768
// Dot product of n weights stored as block_iq1_s (vx) with n activations
// stored as block_q8_K (vy); the scalar result is written to *s. Only
// nrc == 1 is supported. With NEON the kernel below is used; otherwise the
// call is forwarded to ggml_vec_dot_iq1_s_q8_K_generic.
void ggml_vec_dot_iq1_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(n % QK_K == 0);
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_iq1_s * GGML_RESTRICT x = vx;
    const block_q8_K  * GGML_RESTRICT y = vy;

    const int nb = n / QK_K;

#if defined __ARM_NEON

    ggml_int8x16x4_t q1b;
    ggml_int8x16x4_t q8b;

    float sumf = 0;
    for (int i = 0; i < nb; ++i) {

        const int8_t   * q8 = y[i].qs;
        const uint8_t  * qs = x[i].qs;
        const uint16_t * qh = x[i].qh;

        // sumi1/sumi2: scaled grid dot products for even/odd sub-blocks;
        // sumi3: accumulated delta (shift) term, multiplied by IQ1S_DELTA below
        int sumi1 = 0, sumi2 = 0, sumi3 = 0;

        for (int ib = 0; ib < QK_K/32; ib += 2) {

            // Grid index = 8 bits from qs plus 3 high bits from qh[ib]
            // (bits 0-2, 3-5, 6-8, 9-11 for the four groups of 8 values).
            q1b.val[0] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[0] | ((qh[ib+0] << 8) & 0x700)))),
                                     vld1_s8((const int8_t *)(iq1s_grid + (qs[1] | ((qh[ib+0] << 5) & 0x700)))));
            q1b.val[1] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[2] | ((qh[ib+0] << 2) & 0x700)))),
                                     vld1_s8((const int8_t *)(iq1s_grid + (qs[3] | ((qh[ib+0] >> 1) & 0x700)))));
            q1b.val[2] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[4] | ((qh[ib+1] << 8) & 0x700)))),
                                     vld1_s8((const int8_t *)(iq1s_grid + (qs[5] | ((qh[ib+1] << 5) & 0x700)))));
            q1b.val[3] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[6] | ((qh[ib+1] << 2) & 0x700)))),
                                     vld1_s8((const int8_t *)(iq1s_grid + (qs[7] | ((qh[ib+1] >> 1) & 0x700)))));
            qs += 8;

            q8b = ggml_vld1q_s8_x4(q8); q8 += 64;

            const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q1b.val[0], q8b.val[0]), q1b.val[1], q8b.val[1]);
            const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q1b.val[2], q8b.val[2]), q1b.val[3], q8b.val[3]);

            // bits 12-14 of qh hold the 3-bit sub-block scale, mapped to 2*s + 1
            const int ls1 = 2*((qh[ib+0] >> 12) & 7) + 1;
            const int ls2 = 2*((qh[ib+1] >> 12) & 7) + 1;
            sumi1 += vaddvq_s32(p1) * ls1;
            sumi2 += vaddvq_s32(p2) * ls2;
            // Delta term: bit 15 of qh selects its sign. It uses two y[i].bsums
            // entries per 32-value sub-block (presumably sums of q8 over 16
            // values each -- confirm against block_q8_K).
            sumi3 += (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]) * ls1 * (qh[ib+0] & 0x8000 ? -1 : 1)
                   + (y[i].bsums[2*ib+2] + y[i].bsums[2*ib+3]) * ls2 * (qh[ib+1] & 0x8000 ? -1 : 1);

        }

        sumf += y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d) * (sumi1 + sumi2 + IQ1S_DELTA * sumi3);
    }

    *s = sumf;

#else
    UNUSED(x);
    UNUSED(y);
    UNUSED(nb);
    ggml_vec_dot_iq1_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}
3834
3835void ggml_vec_dot_iq1_m_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
3836    assert(n % QK_K == 0);
3837    assert(nrc == 1);
3838    UNUSED(nrc);
3839    UNUSED(bx);
3840    UNUSED(by);
3841    UNUSED(bs);
3842
3843    const block_iq1_m * GGML_RESTRICT x = vx;
3844    const block_q8_K  * GGML_RESTRICT y = vy;
3845
3846    const int nb = n / QK_K;
3847
3848    iq1m_scale_t scale;
3849
3850#if defined __ARM_NEON
3851    const int32x4_t mask  = vdupq_n_s32(0x7);
3852    const int32x4_t mone  = vdupq_n_s32(1);
3853    const int32x4_t mzero = vdupq_n_s32(0);
3854
3855    ggml_int8x16x4_t deltas;
3856    deltas.val[0] = vcombine_s8(vdup_n_s8(+1), vdup_n_s8(+1));
3857    deltas.val[1] = vcombine_s8(vdup_n_s8(-1), vdup_n_s8(+1));
3858    deltas.val[2] = vcombine_s8(vdup_n_s8(+1), vdup_n_s8(-1));
3859    deltas.val[3] = vcombine_s8(vdup_n_s8(-1), vdup_n_s8(-1));
3860
3861    ggml_int8x16x4_t q1b;
3862    ggml_int8x16x4_t q8b;
3863
3864    uint32_t aux32;
3865    const uint8_t * aux8 = (const uint8_t *)&aux32;
3866
3867    float sumf = 0;
3868    for (int i = 0; i < nb; ++i) {
3869
3870        const int8_t   * q8 = y[i].qs;
3871        const uint8_t  * qs = x[i].qs;
3872        const uint8_t  * qh = x[i].qh;
3873        const uint16_t * sc = (const uint16_t *)x[i].scales;
3874
3875        scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
3876
3877        int32x4_t sumi1 = mzero;
3878        int32x4_t sumi2 = mzero;
3879
3880        for (int ib = 0; ib < QK_K/32; ib += 2) {
3881
3882            q1b.val[0] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[0] | ((qh[0] << 8) & 0x700)))),
3883                                     vld1_s8((const int8_t *)(iq1s_grid + (qs[1] | ((qh[0] << 4) & 0x700)))));
3884            q1b.val[1] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[2] | ((qh[1] << 8) & 0x700)))),
3885                                     vld1_s8((const int8_t *)(iq1s_grid + (qs[3] | ((qh[1] << 4) & 0x700)))));
3886            q1b.val[2] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[4] | ((qh[2] << 8) & 0x700)))),
3887                                     vld1_s8((const int8_t *)(iq1s_grid + (qs[5] | ((qh[2] << 4) & 0x700)))));
3888            q1b.val[3] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[6] | ((qh[3] << 8) & 0x700)))),
3889                                     vld1_s8((const int8_t *)(iq1s_grid + (qs[7] | ((qh[3] << 4) & 0x700)))));
3890
3891            q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
3892
3893            const int32x4_t p1 = vpaddq_s32(ggml_vdotq_s32(mzero, q1b.val[0], q8b.val[0]), ggml_vdotq_s32(mzero, q1b.val[1], q8b.val[1]));
3894            const int32x4_t p2 = vpaddq_s32(ggml_vdotq_s32(mzero, q1b.val[2], q8b.val[2]), ggml_vdotq_s32(mzero, q1b.val[3], q8b.val[3]));
3895            const int32x4_t p12 = vpaddq_s32(p1, p2);
3896
3897            const uint32_t * qh32 = (const uint32_t *)qh; // we are 4-byte aligned, so we can do that
3898            aux32 = ((qh32[0] >> 3) & 0x01010101) | ((qh32[0] >> 6) & 0x02020202);
3899
3900            const int32x4_t p3 = vpaddq_s32(ggml_vdotq_s32(mzero, deltas.val[aux8[0]], q8b.val[0]), ggml_vdotq_s32(mzero, deltas.val[aux8[1]], q8b.val[1]));
3901            const int32x4_t p4 = vpaddq_s32(ggml_vdotq_s32(mzero, deltas.val[aux8[2]], q8b.val[2]), ggml_vdotq_s32(mzero, deltas.val[aux8[3]], q8b.val[3]));
3902            const int32x4_t p34 = vpaddq_s32(p3, p4);
3903
3904            int32x4_t scales_4 = ggml_vld1q_u32(sc[ib/2] >> 0, sc[ib/2] >> 3, sc[ib/2] >> 6, sc[ib/2] >> 9);
3905
3906            scales_4 = vaddq_s32(vshlq_n_s32(vandq_s32(scales_4, mask), 1), mone);
3907
3908            sumi1 = vmlaq_s32(sumi1, scales_4, p12);
3909            sumi2 = vmlaq_s32(sumi2, scales_4, p34);
3910
3911            qs += 8; qh += 4;
3912
3913        }
3914
3915        sumf += y[i].d * GGML_CPU_FP16_TO_FP32(scale.f16) * (vaddvq_s32(sumi1) + IQ1M_DELTA * vaddvq_s32(sumi2));
3916    }
3917
3918    *s = sumf;
3919
3920#else
3921    UNUSED(x);
3922    UNUSED(y);
3923    UNUSED(nb);
3924    UNUSED(scale);
3925    ggml_vec_dot_iq1_m_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
3926#endif
3927}
3928
3929void ggml_vec_dot_iq4_nl_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
3930    assert(nrc == 1);
3931    UNUSED(nrc);
3932    UNUSED(bx);
3933    UNUSED(by);
3934    UNUSED(bs);
3935    assert(n % QK4_NL == 0);
3936    static_assert(QK4_NL == QK8_0, "QK4_NL and QK8_0 must be the same");
3937
3938    const block_iq4_nl * GGML_RESTRICT x = vx;
3939    const block_q8_0   * GGML_RESTRICT y = vy;
3940
3941    const int nb = n / QK4_NL;
3942
3943    int ib = 0;
3944    float sumf = 0;
3945
3946#if defined __ARM_NEON
3947    const int8x16_t values = vld1q_s8(kvalues_iq4nl);
3948    const uint8x16_t m4b = vdupq_n_u8(0x0f);
3949    uint8x16x2_t q4bits;
3950    int8x16x4_t q4b;
3951    int8x16x4_t q8b;
3952    int32x4_t prod_1, prod_2;
3953
3954    for (; ib + 1 < nb; ib += 2) {
3955
3956        q4bits.val[0] = vld1q_u8(x[ib + 0].qs);
3957        q4bits.val[1] = vld1q_u8(x[ib + 1].qs);
3958        q8b.val[0]    = vld1q_s8(y[ib + 0].qs);
3959        q8b.val[1]    = vld1q_s8(y[ib + 0].qs + 16);
3960        q8b.val[2]    = vld1q_s8(y[ib + 1].qs);
3961        q8b.val[3]    = vld1q_s8(y[ib + 1].qs + 16);
3962
3963        q4b.val[0] = ggml_vqtbl1q_s8(values, vandq_u8  (q4bits.val[0], m4b));
3964        q4b.val[1] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[0], 4));
3965        q4b.val[2] = ggml_vqtbl1q_s8(values, vandq_u8  (q4bits.val[1], m4b));
3966        q4b.val[3] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[1], 4));
3967
3968        prod_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[0], q8b.val[0]), q4b.val[1], q8b.val[1]);
3969        prod_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[2], q8b.val[2]), q4b.val[3], q8b.val[3]);
3970
3971        sumf +=
3972            GGML_CPU_FP16_TO_FP32(x[ib+0].d) * GGML_CPU_FP16_TO_FP32(y[ib + 0].d) * vaddvq_s32(prod_1) +
3973            GGML_CPU_FP16_TO_FP32(x[ib+1].d) * GGML_CPU_FP16_TO_FP32(y[ib + 1].d) * vaddvq_s32(prod_2);
3974    }
3975
3976#endif
3977    for (; ib < nb; ++ib) {
3978        const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_CPU_FP16_TO_FP32(x[ib].d);
3979        int sumi1 = 0, sumi2 = 0;
3980        for (int j = 0; j < QK4_NL/2; ++j) {
3981            sumi1 += y[ib].qs[j+       0] * kvalues_iq4nl[x[ib].qs[j] & 0xf];
3982            sumi2 += y[ib].qs[j+QK4_NL/2] * kvalues_iq4nl[x[ib].qs[j] >>  4];
3983        }
3984        sumf += d * (sumi1 + sumi2);
3985    }
3986    *s = sumf;
3987}
3988
3989void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
3990    assert(nrc == 1);
3991    UNUSED(nrc);
3992    UNUSED(bx);
3993    UNUSED(by);
3994    UNUSED(bs);
3995    assert(n % QK_K == 0);
3996
3997    const block_iq4_xs * GGML_RESTRICT x = vx;
3998    const block_q8_K   * GGML_RESTRICT y = vy;
3999
4000    const int nb = n / QK_K;
4001
4002#if defined __ARM_NEON
4003    const int8x16_t values = vld1q_s8(kvalues_iq4nl);
4004    const uint8x16_t m4b = vdupq_n_u8(0x0f);
4005    ggml_uint8x16x2_t q4bits;
4006    ggml_int8x16x4_t q4b;
4007    ggml_int8x16x4_t q8b;
4008    int32x4_t prod_1, prod_2;
4009
4010    float sumf = 0;
4011
4012    for (int ibl = 0; ibl < nb; ++ibl) {
4013
4014        const int8_t  * q8 = y[ibl].qs;
4015        const uint8_t * q4 = x[ibl].qs;
4016        uint16_t h = x[ibl].scales_h;
4017
4018        int sumi1 = 0, sumi2 = 0;
4019        for (int ib = 0; ib < QK_K/64; ++ib) {
4020
4021            q4bits = ggml_vld1q_u8_x2(q4); q4 += 32;
4022            q8b    = ggml_vld1q_s8_x4(q8); q8 += 64;
4023
4024            q4b.val[0] = ggml_vqtbl1q_s8(values, vandq_u8  (q4bits.val[0], m4b));
4025            q4b.val[1] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[0], 4));
4026            q4b.val[2] = ggml_vqtbl1q_s8(values, vandq_u8  (q4bits.val[1], m4b));
4027            q4b.val[3] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[1], 4));
4028
4029            prod_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[0], q8b.val[0]), q4b.val[1], q8b.val[1]);
4030            prod_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[2], q8b.val[2]), q4b.val[3], q8b.val[3]);
4031
4032            int ls1 = ((x[ibl].scales_l[ib] & 0xf) | ((h << 4) & 0x30)) - 32;
4033            int ls2 = ((x[ibl].scales_l[ib] >>  4) | ((h << 2) & 0x30)) - 32;
4034            h >>= 4;
4035            sumi1 += vaddvq_s32(prod_1) * ls1;
4036            sumi2 += vaddvq_s32(prod_2) * ls2;
4037
4038        }
4039
4040        sumf += GGML_CPU_FP16_TO_FP32(x[ibl].d) * y[ibl].d * (sumi1 + sumi2);
4041    }
4042
4043    *s = sumf;
4044
4045#else
4046    UNUSED(x);
4047    UNUSED(y);
4048    UNUSED(nb);
4049    ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
4050#endif
4051}
4052