#define GGML_COMMON_IMPL_C
#include "ggml-common.h"
#include "ggml-quants.h"
#include "ggml-impl.h"
#include "ggml-cpu.h"
#include "simd-mappings.h"

#include "../../quants.h"
#include "../../ggml-cpu-impl.h"

#include <math.h>
#include <string.h>
#include <assert.h>
#include <float.h>
#include <stdlib.h> // for qsort
#include <stdio.h>  // for GGML_ASSERT

#define GROUP_MAX_EPS 1e-15f
#define GROUP_MAX_EPS_IQ3_XXS 1e-8f
#define GROUP_MAX_EPS_IQ2_S 1e-8f
#define GROUP_MAX_EPS_IQ1_M 1e-7f
#define GROUP_MAX_EPS_IQ1_S 1e-12f

#define UNUSED GGML_UNUSED

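// quantize_row_q8_0: quantize a row of floats to q8_0, i.e. blocks of
// QK8_0 = 32 int8 values plus one fp16 scale per block. The scale comes
// from the block's absolute maximum, d = amax/127, and each element is
// stored as x/d rounded to the nearest integer (the narrowing convert
// below uses the active FP rounding mode, round-to-nearest-even by
// default). A scalar sketch of the same computation, for reference:
//
//   for (int i = 0; i < nb; i++) {
//       float amax = 0.0f;
//       for (int j = 0; j < QK8_0; j++) amax = fmaxf(amax, fabsf(x[i*QK8_0 + j]));
//       const float d  = amax / 127.0f;
//       const float id = d ? 1.0f/d : 0.0f;
//       y[i].d = GGML_CPU_FP32_TO_FP16(d);
//       for (int j = 0; j < QK8_0; j++) y[i].qs[j] = (int8_t)roundf(x[i*QK8_0 + j]*id);
//   }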
void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
    assert(QK8_0 == 32);
    assert(k % QK8_0 == 0);
    const int nb = k / QK8_0;

    block_q8_0 * GGML_RESTRICT y = vy;

#if defined(__riscv_v)

    size_t vl = QK8_0;

    for (int i = 0; i < nb; i++) {
        // load elements
        vfloat32m8_t v_x   = __riscv_vle32_v_f32m8(x+i*QK8_0, vl);

        vfloat32m8_t vfabs = __riscv_vfabs_v_f32m8(v_x, vl);
        vfloat32m1_t tmp   = __riscv_vfmv_v_f_f32m1(0.0f, vl);
        vfloat32m1_t vmax  = __riscv_vfredmax_vs_f32m8_f32m1(vfabs, tmp, vl);
        float amax = __riscv_vfmv_f_s_f32m1_f32(vmax);

        const float d = amax / ((1 << 7) - 1);
        const float id = d ? 1.0f/d : 0.0f;

        y[i].d = GGML_CPU_FP32_TO_FP16(d);

        vfloat32m8_t x0 = __riscv_vfmul_vf_f32m8(v_x, id, vl);

        // convert to integer
        vint16m4_t   vi = __riscv_vfncvt_x_f_w_i16m4(x0, vl);
        vint8m2_t    vs = __riscv_vncvt_x_x_w_i8m2(vi, vl);

        // store result
        __riscv_vse8_v_i8m2(y[i].qs, vs, vl);
    }
#else
    GGML_UNUSED(nb);
    // scalar
    quantize_row_q8_0_ref(x, y, k);
#endif
}

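// quantize_row_q8_1: same scheme as q8_0, but each block additionally
// stores s = d * sum(q_j) as fp16. The q4_1/q5_1 dot products below use
// this precomputed sum to fold in their per-block minimum without
// revisiting the int8 data.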
void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
    assert(k % QK8_1 == 0);
    const int nb = k / QK8_1;

    block_q8_1 * GGML_RESTRICT y = vy;

#if defined(__riscv_v)

    size_t vl = QK8_1;

    for (int i = 0; i < nb; i++) {
        // load elements
        vfloat32m8_t v_x   = __riscv_vle32_v_f32m8(x+i*QK8_1, vl);

        vfloat32m8_t vfabs = __riscv_vfabs_v_f32m8(v_x, vl);
        vfloat32m1_t tmp   = __riscv_vfmv_v_f_f32m1(0.0f, vl);
        vfloat32m1_t vmax  = __riscv_vfredmax_vs_f32m8_f32m1(vfabs, tmp, vl);
        float amax = __riscv_vfmv_f_s_f32m1_f32(vmax);

        const float d  = amax / ((1 << 7) - 1);
        const float id = d ? 1.0f/d : 0.0f;

        y[i].d = GGML_CPU_FP32_TO_FP16(d);

        vfloat32m8_t x0 = __riscv_vfmul_vf_f32m8(v_x, id, vl);

        // convert to integer
        vint16m4_t   vi = __riscv_vfncvt_x_f_w_i16m4(x0, vl);
        vint8m2_t    vs = __riscv_vncvt_x_x_w_i8m2(vi, vl);

        // store result
        __riscv_vse8_v_i8m2(y[i].qs, vs, vl);

        // compute sum for y[i].s
        vint16m1_t tmp2 = __riscv_vmv_v_x_i16m1(0, vl);
        vint16m1_t vwrs = __riscv_vwredsum_vs_i8m2_i16m1(vs, tmp2, vl);

        // set y[i].s
        int sum = __riscv_vmv_x_s_i16m1_i16(vwrs);
        y[i].s = GGML_CPU_FP32_TO_FP16(sum*d);
    }

#else
    GGML_UNUSED(nb);
    // scalar
    quantize_row_q8_1_ref(x, y, k);
#endif
}

//===================================== Dot products =================================

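// ggml_vec_dot_q4_0_q8_0: dot product of a q4_0 row with a q8_0 row.
// A q4_0 block packs 32 4-bit quants two per byte (low nibbles first)
// with an implicit offset of 8, e.g. the byte 0xB4 decodes to 4 - 8 = -4
// (low nibble) and 0xB - 8 = 3 (high nibble). Each block contributes
// d_x * d_y * sum((q4_j - 8) * q8_j). A scalar sketch of the same
// computation (illustrative only; the low nibbles pair with the first
// 16 q8 values, the high nibbles with the next 16):
//
//   for (int ib = 0; ib < nb; ++ib) {
//       int sumi = 0;
//       for (int j = 0; j < qk/2; ++j) {
//           const int v0 = (x[ib].qs[j] & 0x0F) - 8;
//           const int v1 = (x[ib].qs[j] >>   4) - 8;
//           sumi += v0*y[ib].qs[j] + v1*y[ib].qs[j + qk/2];
//       }
//       sumf += sumi*GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d);
//   }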
void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
#if defined(__riscv_v)
    const int qk = QK8_0;
    const int nb = n / qk;

    assert(n % qk == 0);
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_q4_0 * GGML_RESTRICT x = vx;
    const block_q8_0 * GGML_RESTRICT y = vy;

    int ib = 0;
    float sumf = 0;

    size_t vl = qk / 2;

    for (; ib < nb; ++ib) {
        // load elements
        vuint8m1_t tx = __riscv_vle8_v_u8m1(x[ib].qs, vl);

        vint8m1_t y0 = __riscv_vle8_v_i8m1(y[ib].qs, vl);
        vint8m1_t y1 = __riscv_vle8_v_i8m1(y[ib].qs+16, vl);

        // unpack the lower nibbles of x, then the upper nibbles
        vuint8m1_t x_a = __riscv_vand_vx_u8m1(tx, 0x0F, vl);
        vuint8m1_t x_l = __riscv_vsrl_vx_u8m1(tx, 0x04, vl);

        vint8m1_t x_ai = __riscv_vreinterpret_v_u8m1_i8m1(x_a);
        vint8m1_t x_li = __riscv_vreinterpret_v_u8m1_i8m1(x_l);

        // subtract offset
        vint8m1_t v0 = __riscv_vsub_vx_i8m1(x_ai, 8, vl);
        vint8m1_t v1 = __riscv_vsub_vx_i8m1(x_li, 8, vl);

        vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl);
        vint16m2_t vec_mul2 = __riscv_vwmacc_vv_i16m2(vec_mul1, v1, y1, vl);

        vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
        vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl);

        int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);

        sumf += sumi*GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d);
    }

    *s = sumf;
#else
    ggml_vec_dot_q4_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}

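// ggml_vec_dot_q4_1_q8_1: like q4_0, but the nibbles stay unsigned and
// each block carries an explicit minimum m, i.e. x_j ~= d*q_j + m. Each
// block contributes d_x*d_y*sum(q4_j*q8_j) + m_x*s_y, where s_y is the
// d_y*sum(q8_j) value precomputed by quantize_row_q8_1.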
void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
#if defined(__riscv_v)
    const int qk = QK8_1;
    const int nb = n / qk;

    assert(n % qk == 0);
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_q4_1 * GGML_RESTRICT x = vx;
    const block_q8_1 * GGML_RESTRICT y = vy;

    int ib = 0;
    float sumf = 0;

    size_t vl = qk / 2;

    for (; ib < nb; ++ib) {
        // load elements
        vuint8m1_t tx = __riscv_vle8_v_u8m1(x[ib].qs, vl);

        vint8m1_t y0 = __riscv_vle8_v_i8m1(y[ib].qs, vl);
        vint8m1_t y1 = __riscv_vle8_v_i8m1(y[ib].qs+16, vl);

        // unpack the lower nibbles of x, then the upper nibbles
        vuint8m1_t x_a = __riscv_vand_vx_u8m1(tx, 0x0F, vl);
        vuint8m1_t x_l = __riscv_vsrl_vx_u8m1(tx, 0x04, vl);

        vint8m1_t v0 = __riscv_vreinterpret_v_u8m1_i8m1(x_a);
        vint8m1_t v1 = __riscv_vreinterpret_v_u8m1_i8m1(x_l);

        vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl);
        vint16m2_t vec_mul2 = __riscv_vwmacc_vv_i16m2(vec_mul1, v1, y1, vl);

        vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
        vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl);

        int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);

        sumf += (GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d))*sumi + GGML_CPU_FP16_TO_FP32(x[ib].m)*GGML_CPU_FP16_TO_FP32(y[ib].s);
    }

    *s = sumf;
#else
    ggml_vec_dot_q4_1_q8_1_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}

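// ggml_vec_dot_q5_0_q8_0: q5_0 keeps the low 4 bits of each quant in qs
// and the fifth bits of all 32 elements as a packed 32-bit field in qh;
// a value decodes to (q & 0x1F) - 16. The loop below loads qh as a lane
// mask, inverts it, and subtracts 16 only in lanes whose fifth bit is
// clear, which is equivalent to adding bit4 << 4 and then subtracting 16
// everywhere.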
void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
#if defined(__riscv_v)
    const int qk = QK8_0;
    const int nb = n / qk;

    int ib = 0;
    float sumf = 0;

    assert(n % qk == 0);
    assert(qk == QK5_0);
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_q5_0 * GGML_RESTRICT x = vx;
    const block_q8_0 * GGML_RESTRICT y = vy;

    size_t vl;
    size_t vlenb = __riscv_vlenb();

    for (; ib < nb; ++ib) {
        vl = qk / 2;
        vuint8m1_t v0 = __riscv_vle8_v_u8m1(x[ib].qs, vl);
        vint8m1_t v0l = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(v0, 0x0F, vl));
        vint8m1_t v0h = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(v0, 4, vl));
        vint8m2_t v0c;
        if (vlenb == 16) {
            // VLEN == 128: two m1 registers hold exactly 16 bytes each
            v0c = __riscv_vcreate_v_i8m1_i8m2(v0l, v0h);
        } else {
            // wider VLEN: pack both 16-byte halves into one register, then extend
            v0l = __riscv_vslideup_vx_i8m1(v0l, v0h, 16, 32);
            v0c = __riscv_vlmul_ext_v_i8m1_i8m2(v0l);
        }

        vl = qk;
        vbool4_t qh = __riscv_vlm_v_b4(x[ib].qh, vl);
        qh = __riscv_vmnand_mm_b4(qh, qh, vl);
        vint8m2_t v0f = __riscv_vsub_vx_i8m2_mu(qh, v0c, v0c, 0x10, vl);
        vint8m2_t v1 = __riscv_vle8_v_i8m2(y[ib].qs, vl);
        vint16m4_t mul = __riscv_vwmul_vv_i16m4(v0f, v1, vl);
        vint32m1_t zero = __riscv_vmv_v_x_i32m1(0, vl);
        vint32m1_t sum = __riscv_vwredsum_vs_i16m4_i32m1(mul, zero, vl);
        int32_t sumi = __riscv_vmv_x_s_i32m1_i32(sum);

        sumf += (GGML_CPU_FP16_TO_FP32(x[ib].d) * GGML_CPU_FP16_TO_FP32(y[ib].d)) * sumi;
    }

    *s = sumf;
#else
    ggml_vec_dot_q5_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}

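// ggml_vec_dot_q5_1_q8_1: like q5_0, but unsigned with an explicit
// minimum, so the fifth bit is OR-ed in (q | bit4 << 4) rather than
// recentered, and the m_x*s_y correction is added as in q4_1.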
void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
#if defined(__riscv_v)
    const int qk = QK8_1;
    const int nb = n / qk;

    int ib = 0;
    float sumf = 0;

    assert(n % qk == 0);
    assert(qk == QK5_1);
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_q5_1 * GGML_RESTRICT x = vx;
    const block_q8_1 * GGML_RESTRICT y = vy;

    size_t vl;
    size_t vlenb = __riscv_vlenb();

    for (; ib < nb; ++ib) {
        vl = qk / 2;
        vuint8m1_t v0 = __riscv_vle8_v_u8m1(x[ib].qs, vl);
        vint8m1_t v0l = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(v0, 0x0F, vl));
        vint8m1_t v0h = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(v0, 4, vl));
        vint8m2_t v0c;
        if (vlenb == 16) {
            // VLEN == 128: two m1 registers hold exactly 16 bytes each
            v0c = __riscv_vcreate_v_i8m1_i8m2(v0l, v0h);
        } else {
            // wider VLEN: pack both 16-byte halves into one register, then extend
            v0l = __riscv_vslideup_vx_i8m1(v0l, v0h, 16, 32);
            v0c = __riscv_vlmul_ext_v_i8m1_i8m2(v0l);
        }

        vl = qk;
        vbool4_t qh = __riscv_vlm_v_b4(x[ib].qh, vl);
        vint8m2_t v0f = __riscv_vor_vx_i8m2_mu(qh, v0c, v0c, 0x10, vl);
        vint8m2_t v1 = __riscv_vle8_v_i8m2(y[ib].qs, vl);
        vint16m4_t mul = __riscv_vwmul_vv_i16m4(v0f, v1, vl);
        vint32m1_t zero = __riscv_vmv_v_x_i32m1(0, vl);
        vint32m1_t sum = __riscv_vwredsum_vs_i16m4_i32m1(mul, zero, vl);
        int32_t sumi = __riscv_vmv_x_s_i32m1_i32(sum);

        sumf += (GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d))*sumi + GGML_CPU_FP16_TO_FP32(x[ib].m)*GGML_CPU_FP16_TO_FP32(y[ib].s);
    }

    *s = sumf;
#else
    ggml_vec_dot_q5_1_q8_1_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}

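// ggml_vec_dot_q8_0_q8_0: both operands are already int8, so each block
// is a single widening multiply plus reduction, scaled by d_x*d_y.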
void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    const int qk = QK8_0;
    const int nb = n / qk;

    assert(n % qk == 0);
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_q8_0 * GGML_RESTRICT x = vx;
    const block_q8_0 * GGML_RESTRICT y = vy;

    int ib = 0;
    float sumf = 0;

#if defined(__riscv_v)
    size_t vl = qk;

    for (; ib < nb; ++ib) {
        // load elements
        vint8m2_t bx_0 = __riscv_vle8_v_i8m2(x[ib].qs, vl);
        vint8m2_t by_0 = __riscv_vle8_v_i8m2(y[ib].qs, vl);

        vint16m4_t vw_mul = __riscv_vwmul_vv_i16m4(bx_0, by_0, vl);

        vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, vl);
        vint32m1_t v_sum = __riscv_vwredsum_vs_i16m4_i32m1(vw_mul, v_zero, vl);

        int sumi = __riscv_vmv_x_s_i32m1_i32(v_sum);

        sumf += sumi*(GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d));
    }

    *s = sumf;
#else

    UNUSED(nb);
    UNUSED(x);
    UNUSED(y);
    UNUSED(ib);
    UNUSED(sumf);

    ggml_vec_dot_q8_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}

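// ggml_vec_dot_q2_K_q8_K: 2-bit K-quants in superblocks of QK_K = 256
// elements. Each of the 16 sub-blocks has a 4-bit scale and a 4-bit min
// packed into one byte of x[i].scales; a superblock contributes
//     dall * sum_j(scale_j * dot_j) - dmin * sum_j(min_j * bsum_j)
// where bsum_j are y's per-16-element sums (q8_K bsums). Three paths
// follow: inline assembly for the T-Head 0.7.1 vector extension
// (__riscv_xtheadvector), RVV intrinsics for VLEN = 256, and inline RVV
// assembly tuned for VLEN = 128.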
void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_q2_K * GGML_RESTRICT x = vx;
    const block_q8_K * GGML_RESTRICT y = vy;

    const int nb = n / QK_K;

#if defined __riscv_xtheadvector

    float sumf = 0;
    uint8_t atmp[16];

    for (int i = 0; i < nb; ++i) {
        const uint8_t * q2 = x[i].qs;
        const  int8_t * q8 = y[i].qs;
        const uint8_t * sc = x[i].scales;
        const float dall = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
        const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);
        uint8_t *patmp = atmp;
        int vsums;
        int tmp;
        __asm__ __volatile__(
            "th.vsetvli zero, %[vl16], e8, m1\n\t"
            "th.vmv.v.x v8, zero\n\t"
            "th.vlb.v v1, (%[sc])\n\t"
            "th.vand.vi v0, v1, 0xF\n\t"
            "th.vsrl.vi v1, v1, 4\n\t"
            "th.vsb.v v0, (%[scale])\n\t"
            "th.vwaddu.vx v16, v1, zero\n\t"
            "th.vsetvli zero, %[vl16], e16, m2\n\t"
            "th.vlh.v v2, (%[bsums])\n\t"
            "th.vwmul.vv v4, v16, v2\n\t"
            "th.vsetvli zero, %[vl16], e32, m4\n\t"
            "th.vredsum.vs v8, v4, v8\n\t"
            "th.vmv.x.s %[vsums], v8"
            : [tmp] "=&r" (tmp), [vsums] "=&r" (vsums)
            : [sc] "r" (sc), [scale] "r" (atmp), [bsums] "r" (y[i].bsums)
            , [vl16] "r" (16)
            : "memory"
            , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
            , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
            , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
            , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
        );
        sumf += dmin * vsums;
        int isum = 0;

        for (int j = 0; j < QK_K/128; ++j) {
            __asm__ __volatile__(
                "th.vsetvli zero, %[vl32], e8, m2\n\t"
                "th.vlb.v v0, (%[q2])\n\t"
                "th.vsrl.vi v2, v0, 2\n\t"
                "th.vsrl.vi v4, v0, 4\n\t"
                "th.vsrl.vi v6, v0, 6\n\t"
                "th.vand.vi v0, v0, 0x3\n\t"
                "th.vand.vi v2, v2, 0x3\n\t"
                "th.vand.vi v4, v4, 0x3\n\t"
                "th.vsetvli zero, %[vl128], e8, m8\n\t"
                "th.vlb.v v8, (%[q8])\n\t"
                "th.vsetvli zero, %[vl64], e8, m4\n\t"
                "th.vwmul.vv v16, v0, v8\n\t"
                "th.vwmul.vv v24, v4, v12\n\t"
                "th.vsetvli zero, %[vl16], e16, m2\n\t"
                "th.vmv.v.x v0, zero\n\t"
                "th.vwredsum.vs v10, v16, v0\n\t"
                "th.vwredsum.vs v9, v18, v0\n\t"
                "th.vwredsum.vs v8, v20, v0\n\t"
                "th.vwredsum.vs v7, v22, v0\n\t"
                "th.vwredsum.vs v11, v24, v0\n\t"
                "th.vwredsum.vs v12, v26, v0\n\t"
                "th.vwredsum.vs v13, v28, v0\n\t"
                "th.vwredsum.vs v14, v30, v0\n\t"
                "li %[tmp], 4\n\t"
                "th.vsetvli zero, %[tmp], e32, m1\n\t"
                "th.vslideup.vi v10, v9, 1\n\t"
                "th.vslideup.vi v8, v7, 1\n\t"
                "th.vslideup.vi v11, v12, 1\n\t"
                "th.vslideup.vi v13, v14, 1\n\t"
                "th.vslideup.vi v10, v8, 2\n\t"
                "th.vslideup.vi v11, v13, 2\n\t"
                "li %[tmp], 8\n\t"
                "th.vsetvli zero, %[tmp], e32, m2\n\t"
                "th.vlbu.v v12, (%[scale])\n\t"
                "th.vmul.vv v10, v10, v12\n\t"
                "th.vredsum.vs v0, v10, v0\n\t"
                "th.vmv.x.s %[tmp], v0\n\t"
                "add %[isum], %[isum], %[tmp]"
                : [tmp] "=&r" (tmp), [isum] "+&r" (isum)
                : [q2] "r" (q2), [scale] "r" (patmp), [q8] "r" (q8)
                , [vl16] "r" (16), [vl32] "r" (32), [vl64] "r" (64), [vl128] "r" (128)
                : "memory"
                , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
                , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
                , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
                , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
            );
            q2 += 32; q8 += 128; patmp += 8;
        }

        sumf += dall * isum;
    }

    *s = sumf;

#elif defined __riscv_v

    float sumf = 0;
    uint8_t atmp[16];

    const int vector_length = __riscv_vlenb() * 8;
    uint8_t temp_01[32] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                            1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 };

    switch (vector_length) {
    case 256:
        for (int i = 0; i < nb; ++i) {
            const uint8_t * q2 = x[i].qs;
            const int8_t *  q8 = y[i].qs;
            const uint8_t * sc = x[i].scales;

            const float dall = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
            const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);

            size_t vl = 16;

            vuint8m1_t scales = __riscv_vle8_v_u8m1(sc, vl);
            vuint8m1_t aux    = __riscv_vand_vx_u8m1(scales, 0x0F, vl);

            vint16m1_t q8sums = __riscv_vle16_v_i16m1(y[i].bsums, vl);

            vuint8mf2_t scales_2 = __riscv_vle8_v_u8mf2(sc, vl);
            vuint8mf2_t mins8    = __riscv_vsrl_vx_u8mf2(scales_2, 0x4, vl);
            vint16m1_t  mins     = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(mins8, vl));
            vint32m2_t  prod     = __riscv_vwmul_vv_i32m2(q8sums, mins, vl);
            vint32m1_t  vsums    = __riscv_vredsum_vs_i32m2_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl);

            sumf += dmin * __riscv_vmv_x_s_i32m1_i32(vsums);

            vl = 32;

            vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
            vuint8m1_t v_b   = __riscv_vle8_v_u8m1(temp_01, vl);

            uint8_t is   = 0;
            int     isum = 0;

            for (int j = 0; j < QK_K / 128; ++j) {
                // load Q2
                vuint8m1_t q2_x = __riscv_vle8_v_u8m1(q2, vl);

                vuint8m1_t q2_0 = __riscv_vand_vx_u8m1(q2_x, 0x03, vl);
                vuint8m1_t q2_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x2, vl), 0x03, vl);
                vuint8m1_t q2_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x4, vl), 0x03, vl);
                vuint8m1_t q2_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x6, vl), 0x03, vl);

                // duplicate scale elements for product
                vuint8m1_t sc0 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 0 + is, vl), vl);
                vuint8m1_t sc1 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 2 + is, vl), vl);
                vuint8m1_t sc2 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 4 + is, vl), vl);
                vuint8m1_t sc3 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 6 + is, vl), vl);

                vint16m2_t p0 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_0, sc0, vl));
                vint16m2_t p1 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_1, sc1, vl));
                vint16m2_t p2 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_2, sc2, vl));
                vint16m2_t p3 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_3, sc3, vl));

                // load Q8
                vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl);
                vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8 + 32, vl);
                vint8m1_t q8_2 = __riscv_vle8_v_i8m1(q8 + 64, vl);
                vint8m1_t q8_3 = __riscv_vle8_v_i8m1(q8 + 96, vl);

                vint32m4_t s0 = __riscv_vwmul_vv_i32m4(p0, __riscv_vwcvt_x_x_v_i16m2(q8_0, vl), vl);
                vint32m4_t s1 = __riscv_vwmul_vv_i32m4(p1, __riscv_vwcvt_x_x_v_i16m2(q8_1, vl), vl);
                vint32m4_t s2 = __riscv_vwmul_vv_i32m4(p2, __riscv_vwcvt_x_x_v_i16m2(q8_2, vl), vl);
                vint32m4_t s3 = __riscv_vwmul_vv_i32m4(p3, __riscv_vwcvt_x_x_v_i16m2(q8_3, vl), vl);

                vint32m1_t isum0 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s0, s1, vl), vzero, vl);
                vint32m1_t isum1 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s2, s3, vl), isum0, vl);

                isum += __riscv_vmv_x_s_i32m1_i32(isum1);

                q2 += 32;
                q8 += 128;
                is = 8;
            }

            sumf += dall * isum;
        }
        break;
    case 128:
        for (int i = 0; i < nb; ++i) {
            const uint8_t * q2 = x[i].qs;
            const  int8_t * q8 = y[i].qs;
            const uint8_t * sc = x[i].scales;
            const float dall = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
            const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);
            uint8_t *patmp = atmp;
            int vsums;
            int tmp, t1, t2, t3, t4, t5, t6, t7;
            __asm__ __volatile__(
                "vsetivli zero, 16, e8, m1\n\t"
                "vmv.v.x v8, zero\n\t"
                "lb zero, 15(%[sc])\n\t"
                "vle8.v v1, (%[sc])\n\t"
                "vle8.v v2, (%[bsums])\n\t"
                "addi %[tmp], %[bsums], 16\n\t"
                "vand.vi v0, v1, 0xF\n\t"
                "vsrl.vi v1, v1, 4\n\t"
                "vle8.v v3, (%[tmp])\n\t"
                "vse8.v v0, (%[scale])\n\t"
                "vsetivli zero, 16, e16, m2\n\t"
                "vzext.vf2 v0, v1\n\t"
                "vwmul.vv v4, v0, v2\n\t"
                "vsetivli zero, 16, e32, m4\n\t"
                "vredsum.vs v8, v4, v8\n\t"
                "vmv.x.s %[vsums], v8"
                : [tmp] "=&r" (tmp), [vsums] "=&r" (vsums)
                : [sc] "r" (sc), [scale] "r" (atmp), [bsums] "r" (y[i].bsums)
                : "memory"
                , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
                , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
                , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
                , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
            );
            sumf += dmin * vsums;
            int isum = 0;

            for (int j = 0; j < QK_K/128; ++j) {
                __asm__ __volatile__(
                    "lb zero, 31(%[q2])\n\t"
                    "addi %[tmp], %[q2], 16\n\t"
                    "addi %[t1], %[q8], 16\n\t"
                    "vsetivli zero, 16, e8, m1\n\t"
                    "vle8.v v0, (%[q2])\n\t"
                    "vle8.v v1, (%[tmp])\n\t"
                    "vsrl.vi v2, v0, 2\n\t"
                    "vsrl.vi v3, v1, 2\n\t"
                    "vsrl.vi v4, v0, 4\n\t"
                    "addi %[tmp], %[q8], 32\n\t"
                    "vle8.v v8, (%[q8])\n\t"
                    "vle8.v v9, (%[t1])\n\t"
                    "addi %[t1], %[t1], 32\n\t"
                    "vsrl.vi v5, v1, 4\n\t"
                    "vsrl.vi v6, v0, 6\n\t"
                    "vsrl.vi v7, v1, 6\n\t"
                    "vle8.v v10, (%[tmp])\n\t"
                    "vle8.v v11, (%[t1])\n\t"
                    "addi %[tmp], %[tmp], 32\n\t"
                    "addi %[t1], %[t1], 32\n\t"
                    "vand.vi v0, v0, 0x3\n\t"
                    "vand.vi v1, v1, 0x3\n\t"
                    "vand.vi v2, v2, 0x3\n\t"
                    "vle8.v v12, (%[tmp])\n\t"
                    "vle8.v v13, (%[t1])\n\t"
                    "addi %[tmp], %[tmp], 32\n\t"
                    "addi %[t1], %[t1], 32\n\t"
                    "vand.vi v3, v3, 0x3\n\t"
                    "vand.vi v4, v4, 0x3\n\t"
                    "vand.vi v5, v5, 0x3\n\t"
                    "vle8.v v14, (%[tmp])\n\t"
                    "vle8.v v15, (%[t1])\n\t"
                    "vwmul.vv v16, v0, v8\n\t"
                    "vwmul.vv v18, v1, v9\n\t"
                    "vwmul.vv v20, v2, v10\n\t"
                    "vwmul.vv v22, v3, v11\n\t"
                    "vwmul.vv v24, v4, v12\n\t"
                    "vwmul.vv v26, v5, v13\n\t"
                    "vwmul.vv v28, v6, v14\n\t"
                    "vwmul.vv v30, v7, v15\n\t"
                    "vsetivli zero, 8, e16, m1\n\t"
                    "vmv.v.x v0, zero\n\t"
                    "lbu %[tmp], 0(%[scale])\n\t"
                    "vwredsum.vs v8, v16, v0\n\t"
                    "vwredsum.vs v9, v18, v0\n\t"
                    "lbu %[t1], 1(%[scale])\n\t"
                    "vwredsum.vs v10, v20, v0\n\t"
                    "vwredsum.vs v11, v22, v0\n\t"
                    "lbu %[t2], 2(%[scale])\n\t"
                    "vwredsum.vs v12, v24, v0\n\t"
                    "vwredsum.vs v13, v26, v0\n\t"
                    "lbu %[t3], 3(%[scale])\n\t"
                    "vwredsum.vs v14, v28, v0\n\t"
                    "vwredsum.vs v15, v30, v0\n\t"
                    "lbu %[t4], 4(%[scale])\n\t"
                    "vwredsum.vs v8, v17, v8\n\t"
                    "vwredsum.vs v9, v19, v9\n\t"
                    "lbu %[t5], 5(%[scale])\n\t"
                    "vwredsum.vs v10, v21, v10\n\t"
                    "vwredsum.vs v11, v23, v11\n\t"
                    "lbu %[t6], 6(%[scale])\n\t"
                    "vwredsum.vs v12, v25, v12\n\t"
                    "vwredsum.vs v13, v27, v13\n\t"
                    "lbu %[t7], 7(%[scale])\n\t"
                    "vwredsum.vs v14, v29, v14\n\t"
                    "vwredsum.vs v15, v31, v15\n\t"
                    "vsetivli zero, 4, e32, m1\n\t"
                    "vmul.vx v0, v8, %[tmp]\n\t"
                    "vmul.vx v1, v9, %[t1]\n\t"
                    "vmacc.vx v0, %[t2], v10\n\t"
                    "vmacc.vx v1, %[t3], v11\n\t"
                    "vmacc.vx v0, %[t4], v12\n\t"
                    "vmacc.vx v1, %[t5], v13\n\t"
                    "vmacc.vx v0, %[t6], v14\n\t"
                    "vmacc.vx v1, %[t7], v15\n\t"
                    "vmv.x.s %[tmp], v0\n\t"
                    "vmv.x.s %[t1], v1\n\t"
                    "add %[isum], %[isum], %[tmp]\n\t"
                    "add %[isum], %[isum], %[t1]"
                    : [tmp] "=&r" (tmp), [t1] "=&r" (t1), [t2] "=&r" (t2), [t3] "=&r" (t3)
                    , [t4] "=&r" (t4), [t5] "=&r" (t5), [t6] "=&r" (t6), [t7] "=&r" (t7)
                    , [isum] "+&r" (isum)
                    : [q2] "r" (q2), [scale] "r" (patmp), [q8] "r" (q8)
                    : "memory"
                    , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
                    , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
                    , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
                    , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
                );
                q2 += 32; q8 += 128; patmp += 8;
            }

            sumf += dall * isum;
        }
        break;
    default:
        assert(false && "Unsupported vector length");
        break;
    }

    *s = sumf;

#else

    UNUSED(x);
    UNUSED(y);
    UNUSED(nb);

    ggml_vec_dot_q2_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}

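// ggml_vec_dot_q3_K_q8_K: 3-bit K-quants. The low 2 bits of each value
// are in qs and the third bit in hmask; a clear hmask bit subtracts 4,
// giving values in [-4, 3]. The 16 6-bit sub-block scales are unpacked
// from 12 packed bytes with kmask1/kmask2 and recentered by -32. The
// same three code paths as in q2_K are provided.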
void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(n % QK_K == 0);
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const uint32_t kmask1 = 0x03030303;
    const uint32_t kmask2 = 0x0f0f0f0f;

    const block_q3_K * GGML_RESTRICT x = vx;
    const block_q8_K * GGML_RESTRICT y = vy;

    const int nb = n / QK_K;

#if defined __riscv_xtheadvector

    uint32_t utmp[4];
    float sumf = 0;

    for (int i = 0; i < nb; ++i) {
        const uint8_t * restrict q3 = x[i].qs;
        const uint8_t * restrict qh = x[i].hmask;
        const  int8_t * restrict q8 = y[i].qs;

        int8_t * scale = (int8_t *)utmp;
        int tmp;
        __asm__ __volatile__(
            "li %[tmp], 12\n\t"
            "th.vsetvli zero, %[tmp], e8, m1\n\t"
            "th.vlb.v v0, (%[s6b])\n\t"
            "th.vmv.v.v v2, v0\n\t"
            "li %[tmp], 2\n\t"
            "th.vsetvli zero, %[tmp], e64, m1\n\t"
            "th.vmv.v.x v9, %[sh]\n\t"
            "th.vslidedown.vi v1, v0, 1\n\t"
            "th.vslide1up.vx v8, v9, zero\n\t" // {0, 0, 4, 4}
            "th.vslideup.vi v0, v2, 1\n\t" // {aux[0], aux[1], aux[0], aux[1]}
            "li %[tmp], 4\n\t"
            "th.vsetvli zero, %[tmp], e32, m1\n\t"
            "th.vid.v v9\n\t"
            "th.vmv.x.s %[tmp], v1\n\t"
            "th.vsll.vi v9, v9, 1\n\t" // {0, 2, 4, 6}
            "th.vmv.v.x v1, %[tmp]\n\t" // {aux[2], aux[2], aux[2], aux[2]}
            "th.vsrl.vv v4, v1, v9\n\t"
            "th.vsrl.vv v2, v0, v8\n\t"
            "th.vand.vx v5, v4, %[kmask1]\n\t"
            "th.vand.vx v3, v2, %[kmask2]\n\t"
            "th.vsll.vi v6, v5, 4\n\t"
            "th.vor.vv v7, v6, v3\n\t"
            "li %[tmp], 16\n\t"
            "th.vsetvli zero, %[tmp], e8, m1\n\t"
            "th.vsub.vx v0, v7, %[c]\n\t"
            "th.vsb.v v0, (%[scale])"
            : [tmp] "=&r" (tmp)
            : [sh] "r" (0x0000000400000004), [s6b] "r" (x[i].scales), [c] "r" (32)
            , [scale] "r" (scale), [kmask1] "r" (kmask1), [kmask2] "r" (kmask2)
            : "memory"
            , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
            , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
            , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
            , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
        );

        uint8_t m = 1;
        int isum = 0;
        for (int j = 0; j < QK_K; j += 128) {
            __asm__ __volatile__(
                // fixme: use v0p7 mask layout directly
                "th.vsetvli zero, %[vl32], e8, m2\n\t"
                "th.vlb.v v8, (%[q3])\n\t"
                "th.vsrl.vi v10, v8, 2\n\t"
                "th.vsrl.vi v12, v8, 4\n\t"
                "th.vsrl.vi v14, v8, 6\n\t"
                "th.vand.vi v8, v8, 3\n\t"
                "th.vand.vi v10, v10, 3\n\t"
                "th.vand.vi v12, v12, 3\n\t"
                "th.vlb.v v2, (%[qh])\n\t"
                "th.vand.vx v4, v2, %[m]\n\t"
                "slli %[m], %[m], 1\n\t"
                "th.vmseq.vx v0, v4, zero\n\t"
                "th.vadd.vi v8, v8, -4, v0.t\n\t"
                "th.vand.vx v4, v2, %[m]\n\t"
                "slli %[m], %[m], 1\n\t"
                "th.vmseq.vx v0, v4, zero\n\t"
                "th.vadd.vi v10, v10, -4, v0.t\n\t"
                "th.vand.vx v4, v2, %[m]\n\t"
                "slli %[m], %[m], 1\n\t"
                "th.vmseq.vx v0, v4, zero\n\t"
                "th.vadd.vi v12, v12, -4, v0.t\n\t"
                "th.vand.vx v4, v2, %[m]\n\t"
                "slli %[m], %[m], 1\n\t"
                "th.vmseq.vx v0, v4, zero\n\t"
                "th.vadd.vi v14, v14, -4, v0.t\n\t"
                "th.vsetvli zero, %[vl128], e8, m8\n\t"
                "th.vlb.v v0, (%[q8])\n\t"
                "th.vsetvli zero, %[vl64], e8, m4\n\t"
                "th.vwmul.vv v16, v0, v8\n\t"
                "th.vwmul.vv v24, v4, v12\n\t"
                "li %[tmp], 16\n\t"
                "th.vsetvli zero, %[tmp], e16, m2\n\t"
                "th.vmv.v.x v0, zero\n\t"
                "th.vwredsum.vs v10, v16, v0\n\t"
                "th.vwredsum.vs v9, v18, v0\n\t"
                "th.vwredsum.vs v8, v20, v0\n\t"
                "th.vwredsum.vs v7, v22, v0\n\t"
                "th.vwredsum.vs v11, v24, v0\n\t"
                "th.vwredsum.vs v12, v26, v0\n\t"
                "th.vwredsum.vs v13, v28, v0\n\t"
                "th.vwredsum.vs v14, v30, v0\n\t"
                "li %[tmp], 4\n\t"
                "th.vsetvli zero, %[tmp], e32, m1\n\t"
                "th.vslideup.vi v10, v9, 1\n\t"
                "th.vslideup.vi v8, v7, 1\n\t"
                "th.vslideup.vi v11, v12, 1\n\t"
                "th.vslideup.vi v13, v14, 1\n\t"
                "th.vslideup.vi v10, v8, 2\n\t"
                "th.vslideup.vi v11, v13, 2\n\t"
                "li %[tmp], 8\n\t"
                "th.vsetvli zero, %[tmp], e32, m2\n\t"
                "th.vlb.v v12, (%[scale])\n\t"
                "th.vmul.vv v10, v10, v12\n\t"
                "th.vredsum.vs v0, v10, v0\n\t"
                "th.vmv.x.s %[tmp], v0\n\t"
                "add %[isum], %[isum], %[tmp]"
                : [tmp] "=&r" (tmp), [m] "+&r" (m), [isum] "+&r" (isum)
                : [vl128] "r" (128), [vl64] "r" (64), [vl32] "r" (32)
                , [q3] "r" (q3), [qh] "r" (qh), [scale] "r" (scale), [q8] "r" (q8)
                : "memory"
                , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
                , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
                , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
                , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
            );
            q3 += 32;    q8 += 128;   scale += 8;
        }

        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
        sumf += d * isum;
    }

    *s = sumf;

#elif defined __riscv_v

    uint32_t utmp[4];
    float sumf = 0;
    uint32_t aux[3];
    const int vector_length = __riscv_vlenb() * 8;

    switch (vector_length) {
    case 256:
        for (int i = 0; i < nb; ++i) {

            const uint8_t * GGML_RESTRICT q3 = x[i].qs;
            const uint8_t * GGML_RESTRICT qh = x[i].hmask;
            const  int8_t * GGML_RESTRICT q8 = y[i].qs;

            memcpy(aux, x[i].scales, 12);
            utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
            utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
            utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
            utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);

            int8_t * scale = (int8_t *)utmp;
            for (int j = 0; j < 16; ++j) scale[j] -= 32;

            size_t vl = 32;
            uint8_t m = 1;

            vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
            vuint8m1_t vqh = __riscv_vle8_v_u8m1(qh, vl);

            int sum_t = 0;

            for (int j = 0; j < QK_K; j += 128) {

                vl = 32;

                // load Q3
                vuint8m1_t q3_x = __riscv_vle8_v_u8m1(q3, vl);

                vint8m1_t q3_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q3_x, 0x03, vl));
                vint8m1_t q3_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x2, vl), 0x03, vl));
                vint8m1_t q3_2 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x4, vl), 0x03, vl));
                vint8m1_t q3_3 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x6, vl), 0x03, vl));

                // compute mask for subtraction
                vuint8m1_t qh_m0 = __riscv_vand_vx_u8m1(vqh, m, vl);
                vbool8_t vmask_0 = __riscv_vmseq_vx_u8m1_b8(qh_m0, 0, vl);
                vint8m1_t q3_m0 = __riscv_vsub_vx_i8m1_mu(vmask_0, q3_0, q3_0, 0x4, vl);
                m <<= 1;

                vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl);
                vbool8_t vmask_1 = __riscv_vmseq_vx_u8m1_b8(qh_m1, 0, vl);
                vint8m1_t q3_m1 = __riscv_vsub_vx_i8m1_mu(vmask_1, q3_1, q3_1, 0x4, vl);
                m <<= 1;

                vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl);
                vbool8_t vmask_2 = __riscv_vmseq_vx_u8m1_b8(qh_m2, 0, vl);
                vint8m1_t q3_m2 = __riscv_vsub_vx_i8m1_mu(vmask_2, q3_2, q3_2, 0x4, vl);
                m <<= 1;

                vuint8m1_t qh_m3 = __riscv_vand_vx_u8m1(vqh, m, vl);
                vbool8_t vmask_3 = __riscv_vmseq_vx_u8m1_b8(qh_m3, 0, vl);
                vint8m1_t q3_m3 = __riscv_vsub_vx_i8m1_mu(vmask_3, q3_3, q3_3, 0x4, vl);
                m <<= 1;

                // load Q8 and take product with Q3
                vint16m2_t a0 = __riscv_vwmul_vv_i16m2(q3_m0, __riscv_vle8_v_i8m1(q8, vl), vl);
                vint16m2_t a1 = __riscv_vwmul_vv_i16m2(q3_m1, __riscv_vle8_v_i8m1(q8+32, vl), vl);
                vint16m2_t a2 = __riscv_vwmul_vv_i16m2(q3_m2, __riscv_vle8_v_i8m1(q8+64, vl), vl);
                vint16m2_t a3 = __riscv_vwmul_vv_i16m2(q3_m3, __riscv_vle8_v_i8m1(q8+96, vl), vl);

                vl = 16;

                // retrieve lane to multiply with scale
                vint32m2_t aux0_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 0), (scale[0]), vl);
                vint32m2_t aux0_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 1), (scale[1]), vl);
                vint32m2_t aux1_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 0), (scale[2]), vl);
                vint32m2_t aux1_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 1), (scale[3]), vl);
                vint32m2_t aux2_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 0), (scale[4]), vl);
                vint32m2_t aux2_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 1), (scale[5]), vl);
                vint32m2_t aux3_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 0), (scale[6]), vl);
                vint32m2_t aux3_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 1), (scale[7]), vl);

                vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux0_0, aux0_1, vl), vzero, vl);
                vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux1_0, aux1_1, vl), isum0, vl);
                vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux2_0, aux2_1, vl), isum1, vl);
                vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux3_0, aux3_1, vl), isum2, vl);

                sum_t += __riscv_vmv_x_s_i32m1_i32(isum3);

                q3 += 32;    q8 += 128;   scale += 8;

            }

            const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;

            sumf += d*sum_t;

        }
        break;
    case 128:
        for (int i = 0; i < nb; ++i) {
            const uint8_t * restrict q3 = x[i].qs;
            const uint8_t * restrict qh = x[i].hmask;
            const  int8_t * restrict q8 = y[i].qs;

            int8_t * scale = (int8_t *)utmp;
            int tmp, t1, t2, t3, t4, t5, t6, t7;
            __asm__ __volatile__(
                "vsetivli zero, 12, e8, m1\n\t"
                "vle8.v v0, (%[s6b])\n\t"
                "vmv1r.v v2, v0\n\t"
                "vsetivli zero, 2, e64, m1\n\t"
                "vmv.v.x v9, %[sh]\n\t"
                "vslidedown.vi v1, v0, 1\n\t"
                "vslide1up.vx v8, v9, zero\n\t" // {0, 0, 4, 4}
                "vslideup.vi v0, v2, 1\n\t" // {aux[0], aux[1], aux[0], aux[1]}
                "vsetivli zero, 4, e32, m1\n\t"
                "vid.v v9\n\t"
                "vmv.x.s %[tmp], v1\n\t"
                "vsll.vi v9, v9, 1\n\t" // {0, 2, 4, 6}
                "vmv.v.x v1, %[tmp]\n\t" // {aux[2], aux[2], aux[2], aux[2]}
                "vsrl.vv v4, v1, v9\n\t"
                "vsrl.vv v2, v0, v8\n\t"
                "vand.vx v5, v4, %[kmask1]\n\t"
                "vand.vx v3, v2, %[kmask2]\n\t"
                "vsll.vi v6, v5, 4\n\t"
                "vor.vv v7, v6, v3\n\t"
                "vsetivli zero, 16, e8, m1\n\t"
                "vsub.vx v0, v7, %[c]\n\t"
                "vse8.v v0, (%[scale])"
                : [tmp] "=&r" (tmp)
                : [sh] "r" (0x0000000400000004), [s6b] "r" (x[i].scales), [c] "r" (32)
                , [scale] "r" (scale), [kmask1] "r" (kmask1), [kmask2] "r" (kmask2)
                : "memory"
                , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
                , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
                , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
                , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
            );

            uint8_t m = 1;
            int isum = 0;
            for (int j = 0; j < QK_K; j += 128) {
                __asm__ __volatile__(
                    "lb zero, 31(%[q3])\n\t"
                    "vsetvli zero, %[vl32], e8, m2, ta, mu\n\t"
                    "vle8.v v8, (%[q3])\n\t"
                    "vsrl.vi v10, v8, 2\n\t"
                    "vsrl.vi v12, v8, 4\n\t"
                    "vsrl.vi v14, v8, 6\n\t"
                    "lb zero, 64(%[q8])\n\t"
                    "vand.vi v8, v8, 3\n\t"
                    "vand.vi v10, v10, 3\n\t"
                    "vand.vi v12, v12, 3\n\t"
                    "vle8.v v2, (%[qh])\n\t"
                    "lb zero, 127(%[q8])\n\t"
                    "vand.vx v4, v2, %[m]\n\t"
                    "slli %[m], %[m], 1\n\t"
                    "vmseq.vx v0, v4, zero\n\t"
                    "vadd.vi v8, v8, -4, v0.t\n\t"
                    "lb zero, 0(%[q8])\n\t"
                    "vand.vx v4, v2, %[m]\n\t"
                    "slli %[m], %[m], 1\n\t"
                    "vmseq.vx v0, v4, zero\n\t"
                    "vadd.vi v10, v10, -4, v0.t\n\t"
                    "vand.vx v4, v2, %[m]\n\t"
                    "slli %[m], %[m], 1\n\t"
                    "vmseq.vx v0, v4, zero\n\t"
                    "vadd.vi v12, v12, -4, v0.t\n\t"
                    "vand.vx v4, v2, %[m]\n\t"
                    "slli %[m], %[m], 1\n\t"
                    "vmseq.vx v0, v4, zero\n\t"
                    "vadd.vi v14, v14, -4, v0.t\n\t"
                    "vsetvli zero, %[vl128], e8, m8\n\t"
                    "vle8.v v0, (%[q8])\n\t"
                    "lb %[tmp], 0(%[scale])\n\t"
                    "lb %[t1], 1(%[scale])\n\t"
                    "lb %[t2], 2(%[scale])\n\t"
                    "lb %[t3], 3(%[scale])\n\t"
                    "vsetvli zero, %[vl64], e8, m4\n\t"
                    "vwmul.vv v16, v0, v8\n\t"
                    "vwmul.vv v24, v4, v12\n\t"
                    "vsetivli zero, 16, e16, m2\n\t"
                    "vmv.v.x v0, zero\n\t"
                    "vwredsum.vs v8, v16, v0\n\t"
                    "lb %[t4], 4(%[scale])\n\t"
                    "lb %[t5], 5(%[scale])\n\t"
                    "vwredsum.vs v9, v18, v0\n\t"
                    "vwredsum.vs v10, v20, v0\n\t"
                    "vwredsum.vs v11, v22, v0\n\t"
                    "vwredsum.vs v12, v24, v0\n\t"
                    "lb %[t6], 6(%[scale])\n\t"
                    "lb %[t7], 7(%[scale])\n\t"
                    "vwredsum.vs v13, v26, v0\n\t"
                    "vwredsum.vs v14, v28, v0\n\t"
                    "vwredsum.vs v15, v30, v0\n\t"
                    "vsetivli zero, 4, e32, m1\n\t"
                    "vmul.vx v0, v8, %[tmp]\n\t"
                    "vmul.vx v1, v9, %[t1]\n\t"
                    "vmacc.vx v0, %[t2], v10\n\t"
                    "vmacc.vx v1, %[t3], v11\n\t"
                    "vmacc.vx v0, %[t4], v12\n\t"
                    "vmacc.vx v1, %[t5], v13\n\t"
                    "vmacc.vx v0, %[t6], v14\n\t"
                    "vmacc.vx v1, %[t7], v15\n\t"
                    "vmv.x.s %[tmp], v0\n\t"
                    "vmv.x.s %[t1], v1\n\t"
                    "add %[isum], %[isum], %[tmp]\n\t"
                    "add %[isum], %[isum], %[t1]"
                    : [tmp] "=&r" (tmp), [t1] "=&r" (t1), [t2] "=&r" (t2), [t3] "=&r" (t3)
                    , [t4] "=&r" (t4), [t5] "=&r" (t5), [t6] "=&r" (t6), [t7] "=&r" (t7)
                    , [m] "+&r" (m), [isum] "+&r" (isum)
                    : [vl128] "r" (128), [vl64] "r" (64), [vl32] "r" (32)
                    , [q3] "r" (q3), [qh] "r" (qh), [scale] "r" (scale), [q8] "r" (q8)
                    : "memory"
                    , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
                    , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
                    , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
                    , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
                );
                q3 += 32;    q8 += 128;   scale += 8;
            }

            const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
            sumf += d * isum;
        }
        break;
    default:
        assert(false && "Unsupported vector length");
        break;
    }

    *s = sumf;

#else

    UNUSED(kmask1);
    UNUSED(kmask2);
    UNUSED(x);
    UNUSED(y);
    UNUSED(nb);

    ggml_vec_dot_q3_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif

}

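// ggml_vec_dot_q4_K_q8_K: 4-bit K-quants with 8 sub-blocks of 32 per
// superblock. Eight 6-bit scales and eight 6-bit mins are packed into
// the 12-byte x[i].scales field and unpacked with kmask1..kmask3; each
// superblock contributes
//     d * sum_j(scale_j * dot_j) - dmin * sum_j(min_j * bsum_j).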
1119void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
1120    assert(n % QK_K == 0);
1121    assert(nrc == 1);
1122    UNUSED(nrc);
1123    UNUSED(bx);
1124    UNUSED(by);
1125    UNUSED(bs);
1126
1127    const block_q4_K * GGML_RESTRICT x = vx;
1128    const block_q8_K * GGML_RESTRICT y = vy;
1129
1130    const int nb = n / QK_K;
1131
1132    static const uint32_t kmask1 = 0x3f3f3f3f;
1133    static const uint32_t kmask2 = 0x0f0f0f0f;
1134    static const uint32_t kmask3 = 0x03030303;
1135
1136    uint32_t utmp[4];
1137
1138#if defined __riscv_xtheadvector
1139
1140    const uint8_t * scales = (const uint8_t*)&utmp[0];
1141    const uint8_t * mins   = (const uint8_t*)&utmp[2];
1142
1143    float sumf = 0;
1144
1145    for (int i = 0; i < nb; ++i) {
1146        const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
1147        const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);
1148
1149        int tmp, tmp2, sumi;
1150        __asm__ __volatile__(
1151            "li %[t1], 12\n\t"
1152            "th.vsetvli zero, %[t1], e8, m1\n\t"
1153            "th.vlb.v v1, (%[s6b])\n\t" // {aux[0], aux[1], aux[2]}
1154            "li %[t1], 4\n\t"
1155            "th.vsetvli zero, %[t1], e32, m1\n\t"
1156            "th.vslidedown.vi v2, v1, 2\n\t"
1157            "th.vmv.v.v v3, v2\n\t"
1158            "th.vslideup.vi v2, v3, 1\n\t" // {aux[2], aux[2]}
1159            "li %[t1], 2\n\t"
1160            "th.vsetvli zero, %[t1], e32, m1\n\t"
1161            "th.vmv.v.i v4, 4\n\t"
1162            "th.vand.vx v8, v1, %[kmask1]\n\t"
1163            "th.vslide1up.vx v5, v4, zero\n\t" // {0, 4}
1164            "th.vsrl.vi v6, v1, 6\n\t"
1165            "th.vsrl.vv v7, v2, v5\n\t"
1166            "th.vand.vx v0, v6, %[kmask3]\n\t"
1167            "th.vand.vx v2, v7, %[kmask2]\n\t"
1168            "th.vsll.vi v6, v0, 4\n\t"
1169            "li %[t2], 8\n\t"
1170            "addi %[t1], %[utmp], 4\n\t"
1171            "th.vor.vv v1, v6, v2\n\t"
1172            "th.vssw.v v8, (%[utmp]), %[t2]\n\t"
1173            "th.vssw.v v1, (%[t1]), %[t2]\n\t"
1174            "th.vsetvli zero, zero, e32, m2\n\t" // vl == 8
1175            "th.vlw.v v2, (%[bsums])\n\t"
1176            "th.vsetvli zero, %[t2], e16, m1\n\t"
1177            "th.vnsrl.vi v0, v2, 0\n\t"
1178            "th.vnsrl.vi v1, v2, 16\n\t"
1179            "th.vadd.vv v2, v0, v1\n\t"
1180            "th.vlbu.v v4, (%[mins])\n\t"
1181            "th.vwmul.vv v6, v4, v2\n\t"
1182            "th.vmv.v.x v0, zero\n\t"
1183            "th.vsetvli zero, %[t2], e32, m2\n\t"
1184            "th.vredsum.vs v0, v6, v0\n\t"
1185            "th.vmv.x.s %[sumi], v0"
1186            : [t1] "=&r" (tmp), [t2] "=&r" (tmp2), [sumi] "=&r" (sumi)
1187            : [bsums] "r" (y[i].bsums), [mins] "r" (mins), [utmp] "r" (utmp)
1188            , [s6b] "r" (x[i].scales), [kmask1] "r" (kmask1)
1189            , [kmask2] "r" (kmask2), [kmask3] "r" (kmask3)
1190            : "memory"
1191            , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
1192            , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
1193            , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
1194            , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
1195        );
1196        sumf -= dmin * sumi;
1197
1198        const uint8_t * restrict q4 = x[i].qs;
1199        const int8_t  * restrict q8 = y[i].qs;
1200
1201        sumi = 0;
1202        const uint8_t * scale = scales;
1203
1204        for (int j = 0; j < QK_K/128; ++j) {
1205            int vl128 = 128, vl64 = 64, vl32 = 32;
1206            __asm__ __volatile__(
1207                "th.vsetvli zero, %[vl128], e8, m8\n\t"
1208                "th.vlb.v v8, (%[q8])\n\t"
1209                "th.vsetvli zero, %[vl64], e8, m4\n\t"
1210                "th.vlb.v v0, (%[q4])\n\t"
1211                "th.vsrl.vi v4, v0, 4\n\t"
1212                "th.vand.vi v0, v0, 0xF\n\t"
1213                "th.vsetvli zero, %[vl32], e8, m2\n\t"
1214                "th.vwmul.vv v28, v6, v14\n\t"
1215                "th.vwmul.vv v20, v4, v10\n\t"
1216                "th.vwmul.vv v24, v2, v12\n\t"
1217                "th.vwmul.vv v16, v0, v8\n\t"
1218                "li %[tmp], 4\n\t"
1219                "th.vsetvli zero, %[tmp], e32, m1\n\t"
1220                "th.vlbu.v v1, (%[scale])\n\t"
1221                "th.vmv.v.x v0, zero\n\t"
1222                "th.vsetvli zero, %[vl32], e16, m4\n\t"
1223                "th.vwredsum.vs v6, v24, v0\n\t"
1224                "th.vwredsum.vs v7, v28, v0\n\t"
1225                "th.vwredsum.vs v4, v16, v0\n\t"
1226                "th.vwredsum.vs v5, v20, v0\n\t"
1227                "th.vsetvli zero, %[tmp], e32, m1\n\t"
1228                "th.vslideup.vi v6, v7, 1\n\t"
1229                "th.vslideup.vi v4, v5, 1\n\t"
1230                "th.vslideup.vi v4, v6, 2\n\t"
1231                "th.vmul.vv v8, v4, v1\n\t"
1232                "th.vredsum.vs v0, v8, v0\n\t"
1233                "th.vmv.x.s %[tmp], v0\n\t"
1234                "add %[sumi], %[sumi], %[tmp]"
1235                : [tmp] "=&r" (tmp), [sumi] "+&r" (sumi)
1236                : [vl128] "r" (vl128), [vl64] "r" (vl64), [vl32] "r" (vl32)
1237                , [q4] "r" (q4), [q8] "r" (q8), [scale] "r" (scale)
1238                : "memory"
1239                , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
1240                , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
1241                , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
1242                , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
1243            );
1244
1245            q4 += 64;    q8 += 128;    scale += 4;
1246        }
1247
1248        sumf += d * sumi;
1249
1250    }
1251
1252    *s = sumf;
1253
1254#elif defined __riscv_v
1255
1256    const uint8_t * scales = (const uint8_t*)&utmp[0];
1257    const uint8_t * mins   = (const uint8_t*)&utmp[2];
1258
1259    float sumf = 0;
1260    const int vector_length = __riscv_vlenb() * 8;
1261
1262    switch (vector_length) {
1263    case 256:
1264        for (int i = 0; i < nb; ++i) {
1265
1266            size_t vl = 8;
1267
1268            const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
1269            const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);
1270
1271            vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl);
1272            vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl);
1273            vint16mf2_t q8sums   = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl);
1274
1275            memcpy(utmp, x[i].scales, 12);
1276            utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
1277            const uint32_t uaux = utmp[1] & kmask1;
1278            utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
1279            utmp[2] = uaux;
1280            utmp[0] &= kmask1;
1281
1282            vuint8mf4_t mins8  = __riscv_vle8_v_u8mf4(mins, vl);
1283            vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl));
1284            vint32m1_t  prod   = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl);
1285
1286            vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl);
1287            sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi);
1288
1289            const uint8_t * GGML_RESTRICT q4 = x[i].qs;
1290            const int8_t  * GGML_RESTRICT q8 = y[i].qs;
1291
1292            vl = 32;
1293
1294            int32_t sum_1 = 0;
1295            int32_t sum_2 = 0;
1296
1297            vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1);
1298
            for (int j = 0; j < QK_K/64; ++j) {
                // load Q4
                vuint8m1_t q4_x = __riscv_vle8_v_u8m1(q4, vl);

                // load Q8 and multiply it by the lower Q4 nibble
                vint8m1_t  q8_0 = __riscv_vle8_v_i8m1(q8, vl);
                vint8m1_t  q4_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q4_x, 0x0F, vl));
                vint16m2_t qv_0 = __riscv_vwmul_vv_i16m2(q4_0, q8_0, vl);
                vint16m1_t vs_0 = __riscv_vredsum_vs_i16m2_i16m1(qv_0, vzero, vl);

                sum_1 += __riscv_vmv_x_s_i16m1_i16(vs_0) * scales[2*j+0];

                // load Q8 and multiply it by the upper Q4 nibble
                vint8m1_t  q8_1 = __riscv_vle8_v_i8m1(q8+32, vl);
                vint8m1_t  q4_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q4_x, 0x04, vl));
                vint16m2_t qv_1 = __riscv_vwmul_vv_i16m2(q4_1, q8_1, vl);
                vint16m1_t vs_1 = __riscv_vredsum_vs_i16m2_i16m1(qv_1, vzero, vl);

                sum_2 += __riscv_vmv_x_s_i16m1_i16(vs_1) * scales[2*j+1];

                q4 += 32;    q8 += 64;

            }

            sumf += d*(sum_1 + sum_2);

        }
        break;
    case 128:
        for (int i = 0; i < nb; ++i) {
            const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
            const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);

            float ftmp, ft2;
            const uint8_t * GGML_RESTRICT q40;
            const uint8_t * GGML_RESTRICT q41;
            const uint8_t * GGML_RESTRICT q42;
            const uint8_t * GGML_RESTRICT q43;
            const int8_t  * GGML_RESTRICT q80;
            const int8_t  * GGML_RESTRICT q81;
            const int8_t  * GGML_RESTRICT q82;
            const int8_t  * GGML_RESTRICT q83;
            int s0, s1, s2, s3;

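            // one inline-asm block per superblock: unpack the 6-bit scales/mins
            // into utmp, fold in the -dmin * dot(mins, bsums) term, then
            // accumulate the eight 32-element group products weighted by their
            // scales; the q40..q43 pointer temporaries are reused near the end
            // as plain scratch registers for the scale bytes loaded with lbu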
            __asm__ __volatile__(
                "li %[s1], 8\n\t"
                "vsetivli zero, 4, e32, m1, ta, ma\n\t"
                "vle32.v v1, (%[s6b])\n\t"
                "vslide1down.vx v1, v1, zero\n\t"
                "vmv.v.x v16, zero\n\t"
                "vslidedown.vi v2, v1, 2\n\t"
                "vmv1r.v v3, v2\n\t"
                "vslideup.vi v2, v3, 1\n\t" // {aux[2], aux[2]}
                "vsetivli zero, 2, e32, m1, ta, ma\n\t"
                "vmv.v.i v4, 4\n\t"
                "vand.vx v8, v1, %[kmask1]\n\t"
                "vslide1up.vx v5, v4, zero\n\t" // {0, 4}
                "vsrl.vi v6, v1, 6\n\t"
                "vsrl.vv v7, v2, v5\n\t"
                "vsse32.v v8, (%[utmp]), %[s1]\n\t"
                "vand.vx v0, v6, %[kmask3]\n\t"
                "vand.vx v2, v7, %[kmask2]\n\t"
                "vsll.vi v6, v0, 4\n\t"
                "addi %[s0], %[utmp], 4\n\t"
                "vor.vv v1, v6, v2\n\t"
                "vsse32.v v1, (%[s0]), %[s1]\n\t"
                "vsetivli zero, 8, e16, m1, ta, ma\n\t"
                "vle32.v v2, (%[bsums])\n\t"
                "vnsrl.wi v0, v2, 0\n\t"
                "vnsrl.wi v1, v2, 16\n\t"
                "vadd.vv v2, v0, v1\n\t"
                "vle8.v v3, (%[mins])\n\t"
                "vzext.vf2 v4, v3\n\t"
                "vwmul.vv v6, v4, v2\n\t"
                "vsetivli zero, 4, e32, m1, ta, ma\n\t"
                "vredsum.vs v0, v6, v16\n\t"
                "vredsum.vs v0, v7, v0\n\t"
                "vfcvt.f.x.v v0, v0\n\t"
                "vfmv.f.s %[ftmp], v0\n\t"
                "vsetivli zero, 16, e8, m1, ta, ma\n\t"
                "vle8.v v0, (%[xs])\n\t"
                "fnmsub.s %[sumf], %[dmin], %[ftmp], %[sumf]\n\t"
                "addi %[q40], %[xs], 64\n\t"
                "addi %[q41], %[xs], 16\n\t"
                "addi %[q42], %[xs], 32\n\t"
                "addi %[q43], %[xs], 48\n\t"
                "addi %[q80], %[ys], 64\n\t"
                "vle8.v v1, (%[q41])\n\t"
                "vle8.v v2, (%[q42])\n\t"
                "addi %[q81], %[ys], 16\n\t"
                "addi %[q41], %[q41], 64\n\t"
                "addi %[q82], %[ys], 32\n\t"
                "vle8.v v3, (%[q43])\n\t"
                "vle8.v v8, (%[ys])\n\t"
                "addi %[q42], %[q42], 64\n\t"
                "addi %[q83], %[ys], 48\n\t"
                "addi %[q43], %[q43], 64\n\t"
                "vsrl.vi v4, v0, 4\n\t"
                "vle8.v v9, (%[q81])\n\t"
                "vle8.v v10, (%[q82])\n\t"
                "vand.vi v0, v0, 0xF\n\t"
                "addi %[q81], %[q81], 64\n\t"
                "vsrl.vi v5, v1, 4\n\t"
                "addi %[q82], %[q82], 64\n\t"
                "vle8.v v11, (%[q83])\n\t"
                "vle8.v v12, (%[q80])\n\t"
                "vand.vi v1, v1, 0xF\n\t"
                "addi %[q83], %[q83], 64\n\t"
                "vsrl.vi v6, v2, 4\n\t"
                "addi %[q80], %[q80], 64\n\t"
                "vle8.v v13, (%[q81])\n\t"
                "vle8.v v14, (%[q82])\n\t"
                "vand.vi v2, v2, 0xF\n\t"
                "addi %[q81], %[q81], 64\n\t"
                "vsrl.vi v7, v3, 4\n\t"
                "addi %[q82], %[q82], 64\n\t"
                "vwmul.vv v16, v0, v8\n\t"
                "vle8.v v15, (%[q83])\n\t"
                "vle8.v v0, (%[q40])\n\t"
                "vand.vi v3, v3, 0xF\n\t"
                "addi %[q83], %[q83], 64\n\t"
                "vwmul.vv v24, v2, v12\n\t"
                "vwmul.vv v20, v4, v10\n\t"
                "vwmul.vv v28, v6, v14\n\t"
                "vwmacc.vv v16, v1, v9\n\t"
                "vle8.v v1, (%[q41])\n\t"
                "vle8.v v2, (%[q42])\n\t"
                "vwmacc.vv v24, v3, v13\n\t"
                "vwmacc.vv v20, v5, v11\n\t"
                "vwmacc.vv v28, v7, v15\n\t"
                "addi %[q40], %[q80], 64\n\t"
                "addi %[q41], %[q81], 64\n\t"
                "vle8.v v3, (%[q43])\n\t"
                "vle8.v v8, (%[q80])\n\t"
                "addi %[q42], %[q82], 64\n\t"
                "addi %[q43], %[q83], 64\n\t"
                "vsrl.vi v4, v0, 4\n\t"
                "vle8.v v9, (%[q81])\n\t"
                "vle8.v v10, (%[q82])\n\t"
                "vand.vi v0, v0, 0xF\n\t"
                "vsrl.vi v5, v1, 4\n\t"
                "vsrl.vi v7, v3, 4\n\t"
                "vand.vi v3, v3, 0xF\n\t"
                "vle8.v v11, (%[q83])\n\t"
                "vle8.v v12, (%[q40])\n\t"
                "vand.vi v1, v1, 0xF\n\t"
                "vsrl.vi v6, v2, 4\n\t"
                "vand.vi v2, v2, 0xF\n\t"
                "vwmul.vv v18, v0, v8\n\t"
                "vle8.v v13, (%[q41])\n\t"
                "vle8.v v14, (%[q42])\n\t"
                "vwmul.vv v26, v2, v12\n\t"
                "vwmul.vv v22, v4, v10\n\t"
                "vwmul.vv v30, v6, v14\n\t"
                "vwmacc.vv v18, v1, v9\n\t"
                "vle8.v v15, (%[q43])\n\t"
                "vwmacc.vv v26, v3, v13\n\t"
                "vwmacc.vv v22, v5, v11\n\t"
                "vwmacc.vv v30, v7, v15\n\t"
                "vmv.v.x v0, zero\n\t"
                "vsetivli zero, 16, e16, m2, ta, ma\n\t"
                "vwredsum.vs v4, v16, v0\n\t"
                "lbu %[s0], 0(%[scale])\n\t"
                "vwredsum.vs v5, v20, v0\n\t"
                "lbu %[s1], 1(%[scale])\n\t"
                "vwredsum.vs v6, v24, v0\n\t"
                "lbu %[s2], 2(%[scale])\n\t"
                "vwredsum.vs v7, v28, v0\n\t"
                "lbu %[s3], 3(%[scale])\n\t"
                "vwredsum.vs v8, v18, v0\n\t"
                "lbu %[q40], 4(%[scale])\n\t"
                "vwredsum.vs v9, v22, v0\n\t"
                "lbu %[q41], 5(%[scale])\n\t"
                "vwredsum.vs v10, v26, v0\n\t"
                "lbu %[q42], 6(%[scale])\n\t"
                "vwredsum.vs v11, v30, v0\n\t"
                "lbu %[q43], 7(%[scale])\n\t"
                "vsetivli zero, 4, e32, m1, ta, ma\n\t"
                "vmul.vx v0, v4, %[s0]\n\t"
                "vmul.vx v1, v8, %[q40]\n\t"
                "vmacc.vx v0, %[s1], v5\n\t"
                "vmacc.vx v1, %[q41], v9\n\t"
                "vmacc.vx v0, %[s2], v6\n\t"
                "vmacc.vx v1, %[q42], v10\n\t"
                "vmacc.vx v0, %[s3], v7\n\t"
                "vmacc.vx v1, %[q43], v11\n\t"
                "vfcvt.f.x.v v0, v0\n\t"
                "vfcvt.f.x.v v1, v1\n\t"
                "vfmv.f.s %[ft2], v0\n\t"
                "vfmv.f.s %[ftmp], v1\n\t"
                "fadd.s %[ft2], %[ft2], %[ftmp]\n\t"
                "fmadd.s %[sumf], %[d], %[ft2], %[sumf]"
                : [ftmp] "=&f" (ftmp), [sumf] "+&f" (sumf), [ft2] "=&f" (ft2)
                , [s0] "=&r" (s0), [s1] "=&r" (s1), [s2] "=&r" (s2), [s3] "=&r" (s3)
                , [q40] "=&r" (q40), [q41] "=&r" (q41), [q42] "=&r" (q42), [q43] "=&r" (q43)
                , [q80] "=&r" (q80), [q81] "=&r" (q81), [q82] "=&r" (q82), [q83] "=&r" (q83)
                : [d] "f" (d), [ys] "r" (y[i].qs), [xs] "r" (x[i].qs), [scale] "r" (scales)
                , [bsums] "r" (y[i].bsums), [mins] "r" (mins), [utmp] "r" (utmp)
                , [s6b] "r" (&x[i]), [kmask1] "r" (kmask1), [dmin] "f" (dmin)
                , [kmask2] "r" (kmask2), [kmask3] "r" (kmask3)
                : "memory"
                , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
                , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
                , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
                , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
            );
        }
        break;
    default:
        assert(false && "Unsupported vector length");
        break;
    }

    *s = sumf;

#else

    UNUSED(x);
    UNUSED(y);
    UNUSED(kmask1);
    UNUSED(kmask2);
    UNUSED(kmask3);
    UNUSED(nb);
    UNUSED(utmp);

    ggml_vec_dot_q4_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}

void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(n % QK_K == 0);
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_q5_K * GGML_RESTRICT x = vx;
    const block_q8_K * GGML_RESTRICT y = vy;

    const int nb = n / QK_K;

    static const uint32_t kmask1 = 0x3f3f3f3f;
    static const uint32_t kmask2 = 0x0f0f0f0f;
    static const uint32_t kmask3 = 0x03030303;

    uint32_t utmp[4];

#if defined __riscv_v

    const uint8_t * scales = (const uint8_t*)&utmp[0];
    const uint8_t * mins   = (const uint8_t*)&utmp[2];

    float sumf = 0;
    float sums = 0.0f;

    size_t vl;

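    // sumf accumulates the -dmin * dot(mins, bsums) terms, sums the scaled
    // quant products; the result is their sum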
    for (int i = 0; i < nb; ++i) {

        vl = 8;

        const uint8_t * GGML_RESTRICT q5 = x[i].qs;
        const uint8_t * GGML_RESTRICT hm = x[i].qh;
        const  int8_t * GGML_RESTRICT q8 = y[i].qs;

        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
        const float dmin = GGML_CPU_FP16_TO_FP32(x[i].dmin) * y[i].d;

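        // pairwise-sum the 16 bsums into 8 per-64-element sums via strided
        // loads, then unpack the 6-bit scales/mins as in the q4_K kernel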
        vint16m1_t q8sums_0 = __riscv_vlse16_v_i16m1(y[i].bsums, 4, vl);
        vint16m1_t q8sums_1 = __riscv_vlse16_v_i16m1(y[i].bsums+1, 4, vl);
        vint16m1_t q8sums = __riscv_vadd_vv_i16m1(q8sums_0, q8sums_1, vl);

        memcpy(utmp, x[i].scales, 12);
        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
        const uint32_t uaux = utmp[1] & kmask1;
        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
        utmp[2] = uaux;
        utmp[0] &= kmask1;

        vuint8mf2_t mins8 = __riscv_vle8_v_u8mf2(mins, vl);
        vint16m1_t v_mins = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(mins8, vl));
        vint32m2_t prod = __riscv_vwmul_vv_i32m2(q8sums, v_mins, vl);

        vint32m1_t sumi = __riscv_vredsum_vs_i32m2_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl);
        sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi);

        vl = 32;
        int32_t aux32 = 0;
        int is = 0;

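        // m walks through the qh bits: bit j of hm[] supplies the 5th (high)
        // bit for the j-th group of 32 quants, adding 16 where it is set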
        uint8_t m = 1;
        vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
        vuint8m2_t vqh = __riscv_vle8_v_u8m2(hm, vl);

        for (int j = 0; j < QK_K/64; ++j) {
            // load Q5 and Q8
            vuint8m2_t q5_x = __riscv_vle8_v_u8m2(q5, vl);
            vint8m2_t  q8_y1 = __riscv_vle8_v_i8m2(q8, vl);
            vint8m2_t  q8_y2 = __riscv_vle8_v_i8m2(q8+32, vl);

            // compute the mask of elements whose high bit is set, then add 16 to those
            vint8m2_t q5_a = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vand_vx_u8m2(q5_x, 0x0F, vl));
            vuint8m2_t qh_m1 = __riscv_vand_vx_u8m2(vqh, m, vl);
            vbool4_t vmask_1 = __riscv_vmsne_vx_u8m2_b4(qh_m1, 0, vl);
            vint8m2_t q5_m1 = __riscv_vadd_vx_i8m2_mu(vmask_1, q5_a, q5_a, 16, vl);
            m <<= 1;

            vint8m2_t q5_l = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vsrl_vx_u8m2(q5_x, 0x04, vl));
            vuint8m2_t qh_m2 = __riscv_vand_vx_u8m2(vqh, m, vl);
            vbool4_t vmask_2 = __riscv_vmsne_vx_u8m2_b4(qh_m2, 0, vl);
            vint8m2_t q5_m2 = __riscv_vadd_vx_i8m2_mu(vmask_2, q5_l, q5_l, 16, vl);
            m <<= 1;

            vint16m4_t v0 = __riscv_vwmul_vv_i16m4(q5_m1, q8_y1, vl);
            vint16m4_t v1 = __riscv_vwmul_vv_i16m4(q5_m2, q8_y2, vl);

            vint32m8_t vs1 = __riscv_vwmul_vx_i32m8(v0, scales[is++], vl);
            vint32m8_t vs2 = __riscv_vwmul_vx_i32m8(v1, scales[is++], vl);

            vint32m1_t vacc1 = __riscv_vredsum_vs_i32m8_i32m1(vs1, vzero, vl);
            vint32m1_t vacc2 = __riscv_vredsum_vs_i32m8_i32m1(vs2, vacc1, vl);

            aux32 += __riscv_vmv_x_s_i32m1_i32(vacc2);
            q5 += 32;    q8 += 64;

        }

        sums += aux32 * d;

    }

    *s = sumf + sums;

#else

    UNUSED(x);
    UNUSED(y);
    UNUSED(kmask1);
    UNUSED(kmask2);
    UNUSED(kmask3);
    UNUSED(nb);
    UNUSED(utmp);

    ggml_vec_dot_q5_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}

void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(n % QK_K == 0);
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_q6_K * GGML_RESTRICT x = vx;
    const block_q8_K * GGML_RESTRICT y = vy;

    const int nb = n / QK_K;

#if defined __riscv_xtheadvector

    float sumf = 0;

    for (int i = 0; i < nb; ++i) {

        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;

        const uint8_t * GGML_RESTRICT q6 = x[i].ql;
        const uint8_t * GGML_RESTRICT qh = x[i].qh;
        const  int8_t * GGML_RESTRICT q8 = y[i].qs;

        const int8_t * GGML_RESTRICT scale = x[i].scales;

        int sum_t = 0;
        int t0;

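        // XTheadVector block: each iteration reconstructs 128 signed 6-bit
        // values (low nibbles from ql, top two bits from qh), multiplies them
        // with 128 q8 values, and reduces the products in 16-element groups
        // before weighting by the int8 scales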
        for (int j = 0; j < QK_K/128; ++j) {
            __asm__ __volatile__(
                "th.vsetvli zero, %[vl32], e8, m2\n\t" // vl == 32
                "th.vlb.v v4, (%[qh])\n\t"
                "th.vsll.vi v0, v4, 4\n\t"
                "th.vsll.vi v2, v4, 2\n\t"
                "th.vsrl.vi v6, v4, 2\n\t"
                "th.vsetvli zero, %[vl64], e8, m4\n\t" // vl == 64
                "th.vlb.v v8, (%[q6])\n\t"
                "th.vsrl.vi v12, v8, 4\n\t"
                "th.vand.vi v8, v8, 0xF\n\t"
                "th.vsetvli zero, %[vl128], e8, m8\n\t" // vl == 128
                "th.vand.vx v0, v0, %[mask]\n\t"
                "th.vor.vv v8, v8, v0\n\t"
                "th.vlb.v v0, (%[q8])\n\t"
                "th.vsub.vx v8, v8, %[vl32]\n\t"
                "th.vsetvli zero, %[vl64], e8, m4\n\t" // vl == 64
                "th.vwmul.vv v16, v0, v8\n\t"
                "th.vwmul.vv v24, v4, v12\n\t"
                "li %[t0], 16\n\t"
                "th.vsetvli zero, %[t0], e16, m2\n\t" // vl == 16
                "th.vmv.v.x v0, zero\n\t"
                "th.vwredsum.vs v10, v16, v0\n\t"
                "th.vwredsum.vs v9, v18, v0\n\t"
                "th.vwredsum.vs v8, v20, v0\n\t"
                "th.vwredsum.vs v7, v22, v0\n\t"
                "th.vwredsum.vs v11, v24, v0\n\t"
                "th.vwredsum.vs v12, v26, v0\n\t"
                "th.vwredsum.vs v13, v28, v0\n\t"
                "th.vwredsum.vs v14, v30, v0\n\t"
                "li %[t0], 4\n\t"
                "th.vsetvli zero, %[t0], e32, m1\n\t" // vl == 4
                "th.vslideup.vi v10, v9, 1\n\t"
                "th.vslideup.vi v8, v7, 1\n\t"
                "th.vslideup.vi v11, v12, 1\n\t"
                "th.vslideup.vi v13, v14, 1\n\t"
                "th.vslideup.vi v10, v8, 2\n\t"
                "th.vslideup.vi v11, v13, 2\n\t"
                "li %[t0], 8\n\t"
                "th.vsetvli zero, %[t0], e32, m2\n\t" // vl == 8
                "th.vlb.v v4, (%[scale])\n\t"
                "th.vmul.vv v2, v4, v10\n\t"
                "th.vredsum.vs v0, v2, v0\n\t"
                "th.vmv.x.s %[t0], v0\n\t"
                "add %[sumi], %[sumi], %[t0]"
                : [sumi] "+&r" (sum_t), [t0] "=&r" (t0)
                : [qh] "r" (qh), [q6] "r" (q6), [q8] "r" (q8), [scale] "r" (scale)
                , [vl32] "r" (32), [vl64] "r" (64), [vl128] "r" (128)
                , [mask] "r" (0x30)
                : "memory"
                , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
                , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
                , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
                , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
            );
            q6 += 64;   qh += 32;   q8 += 128;   scale += 8;
        }

        sumf += d * sum_t;

    }

    *s = sumf;

#elif defined __riscv_v

    float sumf = 0;
    const int vector_length = __riscv_vlenb() * 8;

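    // dispatch on the runtime vector width, as in ggml_vec_dot_q4_K_q8_K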
    switch (vector_length) {
    case 256:
        for (int i = 0; i < nb; ++i) {

            const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;

            const uint8_t * GGML_RESTRICT q6 = x[i].ql;
            const uint8_t * GGML_RESTRICT qh = x[i].qh;
            const  int8_t * GGML_RESTRICT q8 = y[i].qs;

            const int8_t * GGML_RESTRICT scale = x[i].scales;

            size_t vl;

            vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);

            int sum_t = 0;
            int is = 0;

            for (int j = 0; j < QK_K/128; ++j) {

                vl = 32;

                // load qh
                vuint8m1_t qh_x = __riscv_vle8_v_u8m1(qh, vl);

                // load Q6
                vuint8m1_t q6_0 = __riscv_vle8_v_u8m1(q6, vl);
                vuint8m1_t q6_1 = __riscv_vle8_v_u8m1(q6+32, vl);

                vuint8m1_t q6a_0 = __riscv_vand_vx_u8m1(q6_0, 0x0F, vl);
                vuint8m1_t q6a_1 = __riscv_vand_vx_u8m1(q6_1, 0x0F, vl);
                vuint8m1_t q6s_0 = __riscv_vsrl_vx_u8m1(q6_0, 0x04, vl);
                vuint8m1_t q6s_1 = __riscv_vsrl_vx_u8m1(q6_1, 0x04, vl);

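                // reassemble the 6-bit quants: each 2-bit field of qh becomes
                // bits 4..5, OR-ed onto a 4-bit field of ql, then recenter by -32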
                vuint8m1_t qh_0 = __riscv_vand_vx_u8m1(qh_x, 0x03, vl);
                vuint8m1_t qh_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x2, vl), 0x03 , vl);
                vuint8m1_t qh_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x4, vl), 0x03 , vl);
                vuint8m1_t qh_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x6, vl), 0x03 , vl);

                vuint8m1_t qhi_0 = __riscv_vor_vv_u8m1(q6a_0, __riscv_vsll_vx_u8m1(qh_0, 0x04, vl), vl);
                vuint8m1_t qhi_1 = __riscv_vor_vv_u8m1(q6a_1, __riscv_vsll_vx_u8m1(qh_1, 0x04, vl), vl);
                vuint8m1_t qhi_2 = __riscv_vor_vv_u8m1(q6s_0, __riscv_vsll_vx_u8m1(qh_2, 0x04, vl), vl);
                vuint8m1_t qhi_3 = __riscv_vor_vv_u8m1(q6s_1, __riscv_vsll_vx_u8m1(qh_3, 0x04, vl), vl);

                vint8m1_t a_0 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_0), 32, vl);
                vint8m1_t a_1 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_1), 32, vl);
                vint8m1_t a_2 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_2), 32, vl);
                vint8m1_t a_3 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_3), 32, vl);

                // load Q8 and take the product with the recentered Q6 values
                vint16m2_t va_q_0 = __riscv_vwmul_vv_i16m2(a_0, __riscv_vle8_v_i8m1(q8, vl), vl);
                vint16m2_t va_q_1 = __riscv_vwmul_vv_i16m2(a_1, __riscv_vle8_v_i8m1(q8+32, vl), vl);
                vint16m2_t va_q_2 = __riscv_vwmul_vv_i16m2(a_2, __riscv_vle8_v_i8m1(q8+64, vl), vl);
                vint16m2_t va_q_3 = __riscv_vwmul_vv_i16m2(a_3, __riscv_vle8_v_i8m1(q8+96, vl), vl);

                vl = 16;

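                // each 16-element half of a 32-element product group carries
                // its own int8 scale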
                vint32m2_t vaux_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 0), scale[is+0], vl);
                vint32m2_t vaux_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 1), scale[is+1], vl);
                vint32m2_t vaux_2 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 0), scale[is+2], vl);
                vint32m2_t vaux_3 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 1), scale[is+3], vl);
                vint32m2_t vaux_4 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 0), scale[is+4], vl);
                vint32m2_t vaux_5 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 1), scale[is+5], vl);
                vint32m2_t vaux_6 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 0), scale[is+6], vl);
                vint32m2_t vaux_7 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 1), scale[is+7], vl);

                vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_0, vaux_1, vl), vzero, vl);
                vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_2, vaux_3, vl), isum0, vl);
                vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_4, vaux_5, vl), isum1, vl);
                vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_6, vaux_7, vl), isum2, vl);

                sum_t += __riscv_vmv_x_s_i32m1_i32(isum3);

                q6 += 64;   qh += 32;   q8 += 128;   is += 8;

            }

            sumf += d * sum_t;

        }
        break;
    case 128:
        for (int i = 0; i < nb; ++i) {

            __builtin_prefetch(&x[i + 1].d, 0, 1);

            const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;

            const uint8_t * GGML_RESTRICT q6 = x[i].ql;
            const uint8_t * GGML_RESTRICT qh = x[i].qh;
            const  int8_t * GGML_RESTRICT q8 = y[i].qs;

            const int8_t * GGML_RESTRICT scale = x[i].scales;

            const uint8_t * GGML_RESTRICT q6h; // scratch: holds q6 + 32 inside the asm
            float ftmp;

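            // 128-bit inline-asm path; the "lb zero, ..." loads discard their
            // result into x0 and effectively act as software prefetch of the
            // lines about to be read, while the slli/srai pairs sign-extend
            // the eight packed scale bytes out of the single 64-bit load from
            // x[i].scales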
            for (int j = 0; j < QK_K/128; ++j) {
                __asm__ __volatile__(
                    "addi %[q6h], %[q6], 32\n\t"
                    "ld t0, 0(%[scale])\n\t"
                    "addi %[scale], %[scale], 8\n\t"
                    "slli t6, t0, 1 * 8\n\t"
                    "lb zero, 0(%[q6])\n\t"
                    "slli t5, t0, 2 * 8\n\t"
                    "slli t4, t0, 3 * 8\n\t"
                    "lb zero, 0(%[q6h])\n\t"
                    "slli t3, t0, 4 * 8\n\t"
                    "slli t2, t0, 5 * 8\n\t"
                    "lb zero, 0(%[qh])\n\t"
                    "lb zero, 31(%[q6h])\n\t"
                    "slli t1, t0, 6 * 8\n\t"
                    "srai a7, t0, 56\n\t"
                    "vsetvli zero, %[vl32], e8, m2\n\t"
                    "vle8.v v8, (%[q6])\n\t"
                    "srai t6, t6, 56\n\t"
                    "srai t5, t5, 56\n\t"
                    "srai t4, t4, 56\n\t"
                    "srai t3, t3, 56\n\t"
                    "vle8.v v10, (%[q6h])\n\t"
                    "addi %[q6], %[q6], 64\n\t"
                    "slli t0, t0, 7 * 8\n\t"
                    "srai t2, t2, 56\n\t"
                    "srai t1, t1, 56\n\t"
                    "srai t0, t0, 56\n\t"
                    "vle8.v v4, (%[qh])\n\t"
                    "vsrl.vi v12, v8, 4\n\t"
                    "vsrl.vi v14, v10, 4\n\t"
                    "lb zero, 0(%[q8])\n\t"
                    "vand.vi v8, v8, 0xF\n\t"
                    "vand.vi v10, v10, 0xF\n\t"
                    "lb zero, 32(%[q8])\n\t"
                    "vsll.vi v0, v4, 4\n\t"
                    "vsll.vi v2, v4, 2\n\t"
                    "lb zero, 64(%[q8])\n\t"
                    "vsrl.vi v6, v4, 2\n\t"
                    "vand.vx v0, v0, %[mask]\n\t"
                    "lb zero, 96(%[q8])\n\t"
                    "vand.vx v2, v2, %[mask]\n\t"
                    "vand.vx v4, v4, %[mask]\n\t"
                    "vand.vx v6, v6, %[mask]\n\t"
                    "vor.vv v8, v8, v0\n\t"
                    "lb zero, 127(%[q8])\n\t"
                    "vor.vv v10, v10, v2\n\t"
                    "vor.vv v12, v12, v4\n\t"
                    "vor.vv v14, v14, v6\n\t"
                    "vsetvli zero, %[vl128], e8, m8\n\t"
                    "vle8.v v0, (%[q8])\n\t"
                    "vsub.vx v8, v8, %[vl32]\n\t"
                    "vsetvli zero, %[vl64], e8, m4\n\t"
                    "vwmul.vv v16, v0, v8\n\t"
                    "vwmul.vv v24, v4, v12\n\t"
                    "vsetivli zero, 16, e16, m2\n\t"
                    "vmv.v.x v0, zero\n\t"
                    "vwredsum.vs v10, v16, v0\n\t"
                    "vwredsum.vs v9, v18, v0\n\t"
                    "vwredsum.vs v8, v20, v0\n\t"
                    "vwredsum.vs v7, v22, v0\n\t"
                    "vwredsum.vs v11, v24, v0\n\t"
                    "vwredsum.vs v12, v26, v0\n\t"
                    "vwredsum.vs v13, v28, v0\n\t"
                    "vwredsum.vs v14, v30, v0\n\t"
                    "vsetivli zero, 4, e32, m1\n\t"
                    "vmul.vx v0, v10, t0\n\t"
                    "vmul.vx v1, v9, t1\n\t"
                    "vmacc.vx v0, t2, v8\n\t"
                    "vmacc.vx v1, t3, v7\n\t"
                    "vmacc.vx v0, t4, v11\n\t"
                    "vmacc.vx v1, t5, v12\n\t"
                    "vmacc.vx v0, t6, v13\n\t"
                    "vmacc.vx v1, a7, v14\n\t"
                    "vadd.vv v0, v0, v1\n\t"
                    "vfcvt.f.x.v v0, v0\n\t"
                    "vfmv.f.s %[ftmp], v0\n\t"
                    "fmadd.s %[sumf], %[d], %[ftmp], %[sumf]"
                    : [q6] "+&r" (q6), [q6h] "=&r" (q6h)
                    , [scale] "+&r" (scale)
                    , [sumf] "+&f" (sumf), [ftmp] "=&f" (ftmp)
                    : [qh] "r" (qh), [q8] "r" (q8)
                    , [vl32] "r" (32), [vl64] "r" (64), [vl128] "r" (128)
                    , [mask] "r" (0x30), [d] "f" (d)
                    : "memory"
                    , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
                    , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
                    , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
                    , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
                    , "t0", "t1", "t2", "t3", "t4", "t5", "t6", "a7"
                    , "a6", "a5", "a4", "a3"
                );
                qh += 32;   q8 += 128;
            }
        }
        break;
    default:
        assert(false && "Unsupported vector length");
        break;
    }

    *s = sumf;

#else

    UNUSED(x);
    UNUSED(y);
    UNUSED(nb);

    ggml_vec_dot_q6_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}
