1#define GGML_COMMON_IMPL_C
   2#include "ggml-common.h"
   3#include "ggml-quants.h"
   4#include "ggml-impl.h"
   5#include "ggml-cpu.h"
   6#include "simd-mappings.h"
   7
   8#include "../../quants.h"
   9#include "../../ggml-cpu-impl.h"
  10
  11#include <math.h>
  12#include <string.h>
  13#include <assert.h>
  14#include <float.h>
  15#include <stdlib.h> // for qsort
  16#include <stdio.h>  // for GGML_ASSERT
  17
// Per-quant-type epsilon thresholds.
// NOTE(review): presumably used by quantizers elsewhere in this file to treat a
// group's maximum as zero when it falls below the threshold — confirm at the
// use sites.
#define GROUP_MAX_EPS 1e-15f
#define GROUP_MAX_EPS_IQ3_XXS 1e-8f
#define GROUP_MAX_EPS_IQ2_S 1e-8f
#define GROUP_MAX_EPS_IQ1_M 1e-7f
#define GROUP_MAX_EPS_IQ1_S 1e-12f

// Local shorthand for ggml's unused-variable suppression macro.
#define UNUSED GGML_UNUSED
  25
  26#if defined(__loongarch_sx)
  27
  28static __m128i lsx_packs_w(__m128i a, __m128i b) {
  29    __m128i tmp, tmp1;
  30    tmp = __lsx_vsat_w(a, 15);
  31    tmp1 = __lsx_vsat_w(b, 15);
  32    return __lsx_vpickev_h(tmp1, tmp);
  33}
  34
  35static __m128i lsx_packs_h(__m128i a, __m128i b) {
  36    __m128i tmp, tmp1;
  37    tmp = __lsx_vsat_h(a, 7);
  38    tmp1 = __lsx_vsat_h(b, 7);
  39    return __lsx_vpickev_b(tmp1, tmp);
  40}
  41
  42static __m128i lsx_packus_h(__m128i a, __m128i b) {
  43    __m128i tmp, tmp1;
  44    tmp = __lsx_vsat_hu(a, 7);
  45    tmp1 = __lsx_vsat_hu(b, 7);
  46    return __lsx_vpickev_b(tmp1, tmp);
  47}
  48
  49static __m128i lsx_maddubs_h(__m128i a, __m128i b) {
  50    __m128i tmp1, tmp2;
  51    tmp1 = __lsx_vmulwev_h_b(a, b);
  52    tmp2 = __lsx_vmulwod_h_b(a, b);
  53    return __lsx_vsadd_h(tmp1, tmp2);
  54}
  55
  56static __m128i lsx_madd_h(__m128i a, __m128i b) {
  57    __m128i tmp1, tmp2;
  58    tmp1 = __lsx_vmulwev_w_h(a, b);
  59    tmp2 = __lsx_vmulwod_w_h(a, b);
  60    return __lsx_vadd_w(tmp1, tmp2);
  61}
  62
  63static __m128i lsx_set_w(int32_t a, int32_t b, int32_t c, int32_t d) {
  64    v4i32 __ret = {d, c, b, a};
  65    return (__m128i)__ret;
  66}
  67
  68static __m128i lsx_shuffle_b(__m128i a, __m128i b) {
  69    __m128i mask_f, zero, tmp0, tmp2, mask;
  70    int f = 0x8f;
  71    mask_f = __lsx_vreplgr2vr_b(f);
  72    zero = __lsx_vldi(0);
  73    tmp0 = __lsx_vand_v(b, mask_f); // get mask with low 4 bit and sign bits
  74    tmp0 = __lsx_vori_b(tmp0, 0x10); // make each mask or  with 0x10 prepare for positive
  75    mask = __lsx_vsle_b(zero, tmp0); // if mask >= 0, set mask
  76    tmp2 = __lsx_vand_v(tmp0, mask); // maskout the in2 < ones
  77    return __lsx_vshuf_b(a, zero, tmp2);
  78}
  79
  80static __m128i lsx_hadd_h(__m128i a, __m128i b) {
  81    __m128i tmp1 = __lsx_vpickev_h(b, a);
  82    __m128i tmp2 = __lsx_vpickod_h(b, a);
  83    return __lsx_vadd_h(tmp1, tmp2);
  84}
  85
  86static __m128i lsx_hadd_w(__m128i a, __m128i b) {
  87    __m128i tmp1 = __lsx_vpickev_w(b, a);
  88    __m128i tmp2 = __lsx_vpickod_w(b, a);
  89    return __lsx_vadd_w(tmp1, tmp2);
  90}
  91
  92static __m128 lsx_hadd_s(__m128 a, __m128 b) {
  93    __m128 tmp1 = (__m128)__lsx_vpickev_w((__m128i)b, (__m128i)a);
  94    __m128 tmp2 = (__m128)__lsx_vpickod_w((__m128i)b, (__m128i)a);
  95
  96    return __lsx_vfadd_s(tmp1, tmp2);
  97}
  98
  99static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 c, const __m128 d) {
 100    __m128 res_0 =lsx_hadd_s(a, b);
 101    __m128 res_1 =lsx_hadd_s(c, d);
 102    __m128 res =lsx_hadd_s(res_0, res_1);
 103    res =lsx_hadd_s(res, res);
 104    res =lsx_hadd_s(res, res);
 105
 106    return ((v4f32)res)[0];
 107}
 108
 109// multiply int8_t, add results pairwise twice
 110static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
 111    // Get absolute values of x vectors
 112    const __m128i ax = __lsx_vsigncov_b(x, x);
 113    // Sign the values of the y vectors
 114    const __m128i sy = __lsx_vsigncov_b(x, y);
 115    // Perform multiplication and create 16-bit values
 116    const __m128i dot = lsx_maddubs_h(ax, sy);
 117    const __m128i ones = __lsx_vreplgr2vr_h(1);
 118    return lsx_madd_h(ones, dot);
 119}
 120#endif
 121
 122#if defined(__loongarch_asx)
 123
// Register-name prefixes used by the inline-asm helpers below. Their .ifc
// directives compare the text the compiler substitutes for an "f"-constrained
// vector operand against these prefixes plus a register number.
// NOTE(review): presumably clang prints 128-bit operands as $vrN and 256-bit
// operands as $xrN, while GCC prints both as $fN — confirm per toolchain.
#ifdef __clang__
#define VREGS_PREFIX "$vr"
#define XREGS_PREFIX "$xr"
#else // GCC
#define VREGS_PREFIX "$f"
#define XREGS_PREFIX "$f"
#endif
// All 32 register numbers, iterated with .irp to find an operand's register.
#define __ALL_REGS "0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31"
// Convert __m128i to __m256i: `in` becomes the low 128 bits of the result.
//
// `out` starts as all zeros (xvldi 0). xvpermi.q with immediate 0x20 then
// writes `in`'s low 128 bits into the low half of `out`'s register and takes
// the high half from `out`'s own low half (zero).
// NOTE(review): this reading follows the LASX xvpermi.q immediate encoding;
// confirm against the ISA manual.
//
// The nested .irp/.ifc assembler loops compare the register names substituted
// for %[out] and %[in] against every register number 0..31, using the
// compiler-specific XREGS_PREFIX/VREGS_PREFIX spellings. The instruction is
// emitted only for the matching pair, spelled with $xr register names.
static inline __m256i ____m256i(__m128i in) {
    __m256i out = __lasx_xvldi(0);
    __asm__ volatile (
        ".irp i," __ALL_REGS                "\n\t"
        " .ifc %[out], " XREGS_PREFIX"\\i    \n\t"
        "  .irp j," __ALL_REGS              "\n\t"
        "   .ifc %[in], " VREGS_PREFIX "\\j  \n\t"
        "    xvpermi.q $xr\\i, $xr\\j, 0x20  \n\t"
        "   .endif                           \n\t"
        "  .endr                             \n\t"
        " .endif                             \n\t"
        ".endr                               \n\t"
        : [out] "+f" (out) : [in] "f" (in)
    );
    return out;
}
// Convert two __m128i to __m256i: `inhi` becomes the upper 128 bits and `inlo`
// the lower 128 bits of the result (lasx_insertf128 forwards its arguments in
// this order).
//
// Step 1: in the register holding `inhi` (an in/out "+f" operand, so it is
// clobbered), xvpermi.q ..., 0x20 places `inlo` in the low half and keeps
// `inhi` in the high half.
// Step 2: if the register names of `out` and `inhi` differ (.ifnc), copy the
// combined 256-bit value into `out` with xvori.b ..., 0 (bitwise OR with 0).
// NOTE(review): with clang the two names use different prefixes ($xr vs $vr),
// so the .ifnc test is always true there and the copy is always emitted.
//
// The .irp/.ifc loops find the register numbers the compiler chose, as in
// ____m256i above.
static inline __m256i lasx_set_q(__m128i inhi, __m128i inlo) {
    __m256i out;
    __asm__ volatile (
        ".irp i," __ALL_REGS                "\n\t"
        " .ifc %[hi], " VREGS_PREFIX "\\i    \n\t"
        "  .irp j," __ALL_REGS              "\n\t"
        "   .ifc %[lo], " VREGS_PREFIX "\\j  \n\t"
        "    xvpermi.q $xr\\i, $xr\\j, 0x20  \n\t"
        "   .endif                           \n\t"
        "  .endr                             \n\t"
        " .endif                             \n\t"
        ".endr                               \n\t"
        ".ifnc %[out], %[hi]                 \n\t"
        ".irp i," __ALL_REGS                "\n\t"
        " .ifc %[out], " XREGS_PREFIX "\\i   \n\t"
        "  .irp j," __ALL_REGS              "\n\t"
        "   .ifc %[hi], " VREGS_PREFIX "\\j  \n\t"
        "    xvori.b $xr\\i, $xr\\j, 0       \n\t"
        "   .endif                           \n\t"
        "  .endr                             \n\t"
        " .endif                             \n\t"
        ".endr                               \n\t"
        ".endif                              \n\t"
        : [out] "=f" (out), [hi] "+f" (inhi)
        : [lo] "f" (inlo)
    );
    return out;
}
// Convert __m256i low part to __m128i: return the low 128 bits of `in`.
//
// If `out` and `in` were given different register names (.ifnc), emit
// vori.b $vr<out>, $vr<in>, 0 (bitwise OR with 0, i.e. a 128-bit copy).
// Otherwise nothing is emitted.
// NOTE(review): the no-op case relies on the 128-bit $vrN register aliasing
// the low half of $xrN. Under clang the names use different prefixes, so the
// copy is presumably always emitted.
static inline __m128i lasx_extracti128_lo(__m256i in) {
    __m128i out;
    __asm__ volatile (
        ".ifnc %[out], %[in]                 \n\t"
        ".irp i," __ALL_REGS                "\n\t"
        " .ifc %[out], " VREGS_PREFIX "\\i   \n\t"
        "  .irp j," __ALL_REGS              "\n\t"
        "   .ifc %[in], " XREGS_PREFIX "\\j  \n\t"
        "    vori.b $vr\\i, $vr\\j, 0        \n\t"
        "   .endif                           \n\t"
        "  .endr                             \n\t"
        " .endif                             \n\t"
        ".endr                               \n\t"
        ".endif                              \n\t"
        : [out] "=f" (out) : [in] "f" (in)
    );
    return out;
}
// Convert __m256i high part to __m128i: return the upper 128 bits of `in`.
//
// xvpermi.q $xr<out>, $xr<in>, 0x11 copies `in`'s high half into both halves
// of `out`'s register; only its low 128 bits are used as the __m128i result.
// NOTE(review): immediate semantics per the LASX xvpermi.q encoding — confirm.
// The .irp/.ifc loops find the register numbers the compiler chose.
static inline __m128i lasx_extracti128_hi(__m256i in) {
    __m128i out;
    __asm__ volatile (
        ".irp i," __ALL_REGS                "\n\t"
        " .ifc %[out], " VREGS_PREFIX "\\i   \n\t"
        "  .irp j," __ALL_REGS              "\n\t"
        "   .ifc %[in], " XREGS_PREFIX "\\j  \n\t"
        "    xvpermi.q $xr\\i, $xr\\j, 0x11  \n\t"
        "   .endif                           \n\t"
        "  .endr                             \n\t"
        " .endif                             \n\t"
        ".endr                               \n\t"
        : [out] "=f" (out) : [in] "f" (in)
    );
    return out;
}
 214
 215static __m256i lasx_set_w(int e7, int e6, int e5, int e4, int e3, int e2, int e1, int e0) {
 216    v8i32 __ret = {e0, e1, e2, e3, e4, e5, e6, e7};
 217    return (__m256i)__ret;
 218}
 219
 220static __m256i lasx_set_d(int64_t a, int64_t b, int64_t c, int64_t d) {
 221    v4i64 __ret = {d, c, b, a};
 222    return (__m256i)__ret;
 223}
 224
 225static __m256i lasx_insertf128( __m128i x, __m128i y) {
 226    return lasx_set_q(x, y);
 227}
 228
 229static __m256i lasx_shuffle_b(__m256i a, __m256i b) {
 230    __m256i mask_f, zero, tmp0, tmp2, mask;
 231    int f = 0x8f;
 232    mask_f = __lasx_xvreplgr2vr_b(f);
 233    zero = __lasx_xvldi(0);
 234    tmp0 = __lasx_xvand_v(b, mask_f); // get mask with low 4 bit and sign bits
 235    tmp0 = __lasx_xvori_b(tmp0, 0x10); // make each mask or  with 0x10 prepare for positive
 236    mask = __lasx_xvsle_b(zero, tmp0); // if mask >= 0, set mask
 237    tmp2 = __lasx_xvand_v(tmp0, mask); // maskout the in2 < ones
 238    return __lasx_xvshuf_b(a, zero, tmp2);
 239}
 240
 241static __m256i lasx_extu8_16(__m128i a) {
 242    return __lasx_vext2xv_hu_bu(____m256i(a));
 243}
 244
 245static __m256i lasx_ext8_16(__m128i a) {
 246    return __lasx_vext2xv_h_b(____m256i(a));
 247}
 248
 249static __m256i lasx_ext16_32(__m128i a) {
 250    return __lasx_vext2xv_w_h(____m256i(a));
 251}
 252
 253static __m128i lasx_extracti128( __m256i a, int pos) {
 254    __m128i ret;
 255    if( pos == 0)
 256    {
 257       ret = lasx_extracti128_lo(a);
 258    } else {
 259       ret = lasx_extracti128_hi(a);
 260    }
 261    return ret;
 262}
 263
 264static __m128 lasx_extractf128( __m256 a, int pos) {
 265    __m128 ret;
 266    if( pos == 0)
 267    {
 268       ret = (__m128)lasx_extracti128_lo((__m256i)a);
 269    } else {
 270       ret = (__m128)lasx_extracti128_hi((__m256i)a);
 271    }
 272    return ret;
 273}
 274
 275static __m256i lasx_maddubs_h(__m256i a, __m256i b) {
 276    __m256i tmp1, tmp2;
 277    tmp1 = __lasx_xvmulwev_h_b(a, b);
 278    tmp2 = __lasx_xvmulwod_h_b(a, b);
 279    return __lasx_xvsadd_h(tmp1, tmp2);
 280}
 281
 282static __m256i lasx_madd_h(__m256i a, __m256i b) {
 283    __m256i tmp1, tmp2;
 284    tmp1 = __lasx_xvmulwev_w_h(a, b);
 285    tmp2 = __lasx_xvmulwod_w_h(a, b);
 286    return __lasx_xvadd_w(tmp1, tmp2);
 287}
 288
 289static __m256i lasx_packs_w(__m256i a, __m256i b) {
 290    __m256i tmp, tmp1;
 291    tmp = __lasx_xvsat_w(a, 15);
 292    tmp1 = __lasx_xvsat_w(b, 15);
 293    return __lasx_xvpickev_h(tmp1, tmp);
 294}
 295
 296static __m256i lasx_packs_h(__m256i a, __m256i b) {
 297    __m256i tmp, tmp1;
 298    tmp = __lasx_xvsat_h(a, 7);
 299    tmp1 = __lasx_xvsat_h(b, 7);
 300    return __lasx_xvpickev_b(tmp1, tmp);
 301}
 302
 303static inline __m256i lasx_madd_h_b(__m256i a, __m256i b) {
 304    __m256i tmp1, tmp2;
 305    tmp1 = __lasx_xvmulwev_h_b(a, b);
 306    tmp2 = __lasx_xvmulwod_h_b(a, b);
 307    return __lasx_xvadd_h(tmp1, tmp2);
 308}
 309
 310static inline __m256i lasx_xvrepl128vei_h(__m256i a, const unsigned int b) {
 311    switch (b) {
 312        case 0: return __lasx_xvrepl128vei_h(a, 0);
 313        case 1: return __lasx_xvrepl128vei_h(a, 1);
 314        case 2: return __lasx_xvrepl128vei_h(a, 2);
 315        case 3: return __lasx_xvrepl128vei_h(a, 3);
 316        case 4: return __lasx_xvrepl128vei_h(a, 4);
 317        case 5: return __lasx_xvrepl128vei_h(a, 5);
 318        case 6: return __lasx_xvrepl128vei_h(a, 6);
 319        case 7: return __lasx_xvrepl128vei_h(a, 7);
 320        default: __builtin_unreachable();
 321    }
 322}
 323
 324static inline __m256i lasx_xvandi_b_bit(__m256i a, const unsigned int b) {
 325    switch (b) {
 326        case 0: return __lasx_xvandi_b(a, 1 << 0);
 327        case 1: return __lasx_xvandi_b(a, 1 << 1);
 328        case 2: return __lasx_xvandi_b(a, 1 << 2);
 329        case 3: return __lasx_xvandi_b(a, 1 << 3);
 330        case 4: return __lasx_xvandi_b(a, 1 << 4);
 331        case 5: return __lasx_xvandi_b(a, 1 << 5);
 332        case 6: return __lasx_xvandi_b(a, 1 << 6);
 333        case 7: return __lasx_xvandi_b(a, 1 << 7);
 334        default: __builtin_unreachable();
 335    }
 336}
 337
 338// horizontally add 8 floats
 339static inline float hsum_float_8(const __m256 x) {
 340    __m128 res = lasx_extractf128(x, 1);
 341    res = __lsx_vfadd_s(res, lasx_extractf128(x, 0));
 342    res = __lsx_vfadd_s(res, (__m128)__lsx_vpickod_d((__m128i)res, (__m128i)res));
 343    res = __lsx_vfadd_s(res, (__m128)__lsx_vinsgr2vr_w(__lsx_vldi(0), __lsx_vpickve2gr_w(res, 1), 0));
 344    return ((v4f32)res)[0];
 345}
 346
 347// horizontally add 8 int32_t
 348static inline int hsum_i32_8(const __m256i a) {
 349
 350    __m256i tmp1 = __lasx_xvpermi_q(a, a, 0x11);
 351    __m256i tmp2 = __lasx_xvpermi_q(a, a, 0x00);
 352
 353    __m128i  tmp1_128 = lasx_extracti128_lo(tmp1);
 354    __m128i  tmp2_128 = lasx_extracti128_lo(tmp2);
 355
 356    __m128i sum128 = __lsx_vadd_w(tmp1_128, tmp2_128);
 357
 358    __m128i ev = __lsx_vpickev_w(sum128, sum128);
 359    __m128i od = __lsx_vpickod_w(sum128, sum128);
 360    __m128i sum64 = __lsx_vadd_w(ev, od);
 361
 362    int sum64_1, sum64_2;
 363    sum64_1 = __lsx_vpickve2gr_w(sum64, 0);
 364    sum64_2 = __lsx_vpickve2gr_w(sum64, 1);
 365
 366    return  sum64_1 + sum64_2;
 367}
 368
 369// horizontally add 4 int32_t
 370static inline int hsum_i32_4(const __m128i a) {
 371    __m128i ev = __lsx_vpickev_w(a, a);
 372    __m128i od = __lsx_vpickod_w(a, a);
 373    __m128i sum64 = __lsx_vadd_w(ev, od);
 374
 375    int sum64_1, sum64_2;
 376    sum64_1 = __lsx_vpickve2gr_w(sum64, 0);
 377    sum64_2 = __lsx_vpickve2gr_w(sum64, 1);
 378
 379    return  sum64_1 + sum64_2;
 380}
 381
 382// spread 32 bits to 32 bytes { 0x00, 0xFF }
 383static inline __m256i bytes_from_bits_32(const uint8_t * x) {
 384
 385    uint32_t x32;
 386    memcpy(&x32, x, sizeof(uint32_t));
 387    const __m256i shuf_mask = lasx_set_d(
 388            0x0303030303030303, 0x0202020202020202,
 389            0x0101010101010101, 0x0000000000000000);
 390
 391    __m256i bytes = lasx_shuffle_b(__lasx_xvreplgr2vr_w(x32), shuf_mask);
 392    const __m256i bit_mask = __lasx_xvreplgr2vr_d(0x7fbfdfeff7fbfdfe);
 393    bytes = __lasx_xvor_v(bytes, bit_mask);
 394    return __lasx_xvseq_b(bytes, __lasx_xvreplgr2vr_d(-1));
 395}
 396
 397// Unpack 32 4-bit fields into 32 bytes
 398// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
 399static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi) {
 400    const __m128i lo = __lsx_vld((const __m128i *)rsi, 0);
 401    __m128i hi = __lsx_vsrli_h(lo, 4);
 402    return __lasx_xvandi_b(lasx_insertf128(hi, lo), 0xf);
 403}
 404
 405// add int16_t pairwise and return as float vector
 406static inline __m256 sum_i16_pairs_float(const __m256i x) {
 407    __m256i v = __lasx_xvpackod_h(x, x);
 408    __m256i summed_pairs = __lasx_xvaddwev_w_h(x, v);
 409    return __lasx_xvffint_s_w(summed_pairs);
 410}
 411
 412static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
 413    // Perform multiplication and create 16-bit values
 414    const __m256i dot = lasx_maddubs_h(ax, sy);
 415    return sum_i16_pairs_float(dot);
 416}
 417
 418// multiply int8_t, add results pairwise twice and return as float vector
 419static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
 420    const __m256i dot = lasx_madd_h_b(x, y);
 421    return sum_i16_pairs_float(dot);
 422}
 423
 424static inline __m128i packNibbles( __m256i bytes ) {
 425    // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
 426    const __m256i lowByte = __lasx_xvreplgr2vr_h(0xFF);
 427     __m256i high = __lasx_xvandn_v(lowByte, bytes);
 428    __m256i low = __lasx_xvand_v(lowByte, bytes);
 429    high = __lasx_xvsrli_h(high, 4);
 430    bytes = __lasx_xvor_v(low, high);
 431    // Compress uint16_t lanes into bytes
 432    __m128i *r0 = (__m128i *)&bytes;
 433    __m256i tmp_h128 = __lasx_xvpermi_q(bytes, bytes, 0x11);
 434    __m128i *r1 = (__m128i *)&tmp_h128;
 435
 436    __m128i zero = __lsx_vldi(0);
 437    __m128i tmp, tmp2, tmp3;
 438
 439    tmp = __lsx_vmax_h(zero, *r0);
 440    tmp2 = __lsx_vsat_hu(tmp, 7);
 441
 442    tmp = __lsx_vmax_h(zero, *r1);
 443    tmp3 = __lsx_vsat_hu(tmp, 7);
 444    return  __lsx_vpickev_b(tmp3, tmp2);
 445}
 446#endif  //__loongarch_asx
 447
 448void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
 449    assert(QK8_0 == 32);
 450    assert(k % QK8_0 == 0);
 451    const int nb = k / QK8_0;
 452
 453    block_q8_0 * GGML_RESTRICT y = vy;
 454
 455#if defined(__loongarch_asx)
 456    for (int i = 0; i < nb; i++) {
 457        __m256 v0 = (__m256)__lasx_xvld( x , 0);
 458        __m256 v1 = (__m256)__lasx_xvld( x , 32);
 459        __m256 v2 = (__m256)__lasx_xvld( x , 64);
 460        __m256 v3 = (__m256)__lasx_xvld( x , 96);
 461        x += 32;
 462
 463        // Compute max(abs(e)) for the block
 464        const __m256 sign_bit = __lasx_xvreplfr2vr_s( -0.0f );
 465        __m256 max_abs = (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v0 );
 466        max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v1 ) );
 467        max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v2 ) );
 468        max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v3 ) );
 469
 470        __m128 max4 = __lsx_vfmax_s( lasx_extractf128( max_abs, 1 ), lasx_extractf128( max_abs , 0) );
 471        max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vpickod_d((__m128i) max4, (__m128i)max4 ) );
 472        __m128 tmp = max4;
 473        max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vinsgr2vr_w(tmp, __lsx_vpickve2gr_w( max4, 1 ), 0 ));
 474        const float max_scalar = ((v4f32)max4)[0];
 475
 476        // Quantize these floats
 477        const float d = max_scalar / 127.f;
 478        y[i].d = GGML_CPU_FP32_TO_FP16(d);
 479        const float id = ( max_scalar != 0.0f ) ? 127.f / max_scalar : 0.0f;
 480        const __m256 mul = (__m256)__lasx_xvreplfr2vr_s( id );
 481
 482        // Apply the multiplier
 483        v0 = __lasx_xvfmul_s( v0, mul );
 484        v1 = __lasx_xvfmul_s( v1, mul );
 485        v2 = __lasx_xvfmul_s( v2, mul );
 486        v3 = __lasx_xvfmul_s( v3, mul );
 487
 488        // Round to nearest integer
 489        __m256i i0 = __lasx_xvftintrne_w_s( v0 );
 490        __m256i i1 = __lasx_xvftintrne_w_s( v1 );
 491        __m256i i2 = __lasx_xvftintrne_w_s( v2 );
 492        __m256i i3 = __lasx_xvftintrne_w_s( v3 );
 493
 494        __m128i ni0 = lasx_extracti128( i0, 0 );
 495        __m128i ni1 = lasx_extracti128( i0, 1);
 496        __m128i ni2 = lasx_extracti128( i1, 0);
 497        __m128i ni3 = lasx_extracti128( i1, 1);
 498        __m128i ni4 = lasx_extracti128( i2, 0);
 499        __m128i ni5 = lasx_extracti128( i2, 1);
 500        __m128i ni6 = lasx_extracti128( i3, 0);
 501        __m128i ni7 = lasx_extracti128( i3, 1);
 502
 503        // Convert int32 to int16
 504        ni0 = lsx_packs_w( ni0, ni1 );
 505        ni2 = lsx_packs_w( ni2, ni3 );
 506        ni4 = lsx_packs_w( ni4, ni5 );
 507        ni6 = lsx_packs_w( ni6, ni7 );
 508        // Convert int16 to int8
 509        ni0 = lsx_packs_h( ni0, ni2 );
 510        ni4 = lsx_packs_h( ni4, ni6 );
 511
 512        __lsx_vst(ni0, (__m128i *)(y[i].qs +  0), 0);
 513        __lsx_vst(ni4, (__m128i *)(y[i].qs + 16), 0);
 514
 515    }
 516#else
 517    GGML_UNUSED(nb);
 518    // scalar
 519    quantize_row_q8_0_ref(x, y, k);
 520#endif
 521}
 522
 523void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
 524    assert(k % QK8_1 == 0);
 525    const int nb = k / QK8_1;
 526
 527    block_q8_1 * GGML_RESTRICT y = vy;
 528
 529#if defined(__loongarch_asx)
 530    for (int i = 0; i < nb; i++) {
 531        __m256 v0 = (__m256)__lasx_xvld( x , 0 );
 532        __m256 v1 = (__m256)__lasx_xvld( x , 32 );
 533        __m256 v2 = (__m256)__lasx_xvld( x , 64 );
 534        __m256 v3 = (__m256)__lasx_xvld( x , 96 );
 535        x += 32;
 536
 537        // Compute max(abs(e)) for the block
 538        const __m256 sign_bit = __lasx_xvreplfr2vr_s( -0.0f );
 539        __m256 max_abs = (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v0 );
 540        max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v1 ) );
 541        max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v2 ) );
 542        max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v3 ) );
 543
 544        __m128 max4 = __lsx_vfmax_s( lasx_extractf128( max_abs, 1 ), lasx_extractf128( max_abs, 0) );
 545        max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vpickod_d((__m128i) max4, (__m128i)max4 ) );
 546        __m128 tmp = max4;
 547        max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vextrins_w((__m128i)tmp, (__m128i)max4, 0x1 ));
 548        const float max_scalar = ((v4f32)max4)[0];
 549
 550        // Quantize these floats
 551        const float d = max_scalar / 127.f;
 552        y[i].d = GGML_CPU_FP32_TO_FP16(d);
 553        const float id = ( max_scalar != 0.0f ) ? 127.f / max_scalar : 0.0f;
 554        const __m256 mul = __lasx_xvreplfr2vr_s( id );
 555
 556        // Apply the multiplier
 557        v0 = __lasx_xvfmul_s( v0, mul );
 558        v1 = __lasx_xvfmul_s( v1, mul );
 559        v2 = __lasx_xvfmul_s( v2, mul );
 560        v3 = __lasx_xvfmul_s( v3, mul );
 561
 562        // Round to nearest integer
 563        __m256i i0 = __lasx_xvftintrne_w_s( v0 );
 564        __m256i i1 = __lasx_xvftintrne_w_s( v1 );
 565        __m256i i2 = __lasx_xvftintrne_w_s( v2 );
 566        __m256i i3 = __lasx_xvftintrne_w_s( v3 );
 567
 568        __m128i ni0 = lasx_extracti128(i0, 0);
 569        __m128i ni1 = lasx_extracti128( i0, 1);
 570        __m128i ni2 = lasx_extracti128( i1, 0);
 571        __m128i ni3 = lasx_extracti128( i1, 1);
 572        __m128i ni4 = lasx_extracti128( i2, 0 );
 573        __m128i ni5 = lasx_extracti128( i2, 1);
 574        __m128i ni6 = lasx_extracti128( i3, 0);
 575        __m128i ni7 = lasx_extracti128( i3, 1);
 576
 577        // Compute the sum of the quants and set y[i].s
 578        const __m128i s0 = __lsx_vadd_w(__lsx_vadd_w(ni0, ni1), __lsx_vadd_w(ni2, ni3));
 579        const __m128i s1 = __lsx_vadd_w(__lsx_vadd_w(ni4, ni5), __lsx_vadd_w(ni6, ni7));
 580        y[i].s = GGML_CPU_FP32_TO_FP16(d * hsum_i32_4(__lsx_vadd_w(s0, s1)));
 581
 582        // Convert int32 to int16
 583        ni0 = lsx_packs_w( ni0, ni1 );
 584        ni2 = lsx_packs_w( ni2, ni3 );
 585        ni4 = lsx_packs_w( ni4, ni5 );
 586        ni6 = lsx_packs_w( ni6, ni7 );
 587        // Convert int16 to int8
 588        ni0 = lsx_packs_h( ni0, ni2 );
 589        ni4 = lsx_packs_h( ni4, ni6 );
 590
 591        __lsx_vst(ni0, (__m128i *)(y[i].qs +  0), 0);
 592        __lsx_vst(ni4, (__m128i *)(y[i].qs + 16), 0);
 593    }
 594#else
 595    GGML_UNUSED(nb);
 596    // scalar
 597    quantize_row_q8_1_ref(x, y, k);
 598#endif
 599}
 600
 601
 602//===================================== Dot products =================================
 603
 604//
 605// Helper functions
 606//
 607
 608#if defined(__loongarch_asx)
 609// shuffles to pick the required scales in dot products
 610static inline __m256i get_scale_shuffle_q3k(int i) {
 611    static const uint8_t k_shuffle[128] = {
 612         0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,     2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
 613         4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5,     6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
 614         8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9,    10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,
 615        12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,    14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,
 616    };
 617    return __lasx_xvld((const __m256i*)k_shuffle + i, 0);
 618}
 619static inline __m256i get_scale_shuffle_k4(int i) {
 620    static const uint8_t k_shuffle[256] = {
 621         0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,
 622         2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
 623         4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5,
 624         6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
 625         8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9,
 626        10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,
 627        12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,
 628        14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15
 629    };
 630    return __lasx_xvld((const __m256i*)k_shuffle + i, 0);
 631}
 632static inline __m128i get_scale_shuffle(int i) {
 633    static const uint8_t k_shuffle[128] = {
 634         0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1,
 635         2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3,
 636         4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5,
 637         6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7,
 638         8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9,
 639        10,10,10,10,10,10,10,10, 11,11,11,11,11,11,11,11,
 640        12,12,12,12,12,12,12,12, 13,13,13,13,13,13,13,13,
 641        14,14,14,14,14,14,14,14, 15,15,15,15,15,15,15,15
 642    };
 643    return __lsx_vld((const __m128i*)k_shuffle + i, 0);
 644}
 645#endif
 646
// Dot product of n values stored as q4_0 blocks (vx) and q8_0 blocks (vy);
// the scalar result is written to *s. Only nrc == 1 is supported, and the
// stride arguments bs/bx/by are unused.
// Per block: sum_j (nibble_j - 8) * y.qs[j], scaled by d_x * d_y. Low nibbles
// pair with y.qs[0..15] and high nibbles with y.qs[16..31].
void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    const int qk = QK8_0;
    const int nb = n / qk;

    assert(n % qk == 0);
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_q4_0 * GGML_RESTRICT x = vx;
    const block_q8_0 * GGML_RESTRICT y = vy;

    // ib is shared with the scalar tail loop below, which finishes whatever
    // blocks the SIMD paths leave unprocessed.
    int ib = 0;
    float sumf = 0;

#if defined(__loongarch_asx)
    // Initialize accumulator with zeros
    __m256 acc = (__m256)__lasx_xvldi(0);

    // Main loop: one block (32 values) per iteration; leaves ib == nb.
    for (; ib < nb; ++ib) {
        /* Compute combined scale for the block */
        const __m256 d = __lasx_xvreplfr2vr_s( GGML_CPU_FP16_TO_FP32(x[ib].d) * GGML_CPU_FP16_TO_FP32(y[ib].d) );

        __m256i qx = bytes_from_nibbles_32(x[ib].qs);

        // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
        const __m256i off = __lasx_xvreplgr2vr_b( 8 );
        qx = __lasx_xvsub_b( qx, off );

        __m256i qy = __lasx_xvld((const __m256i *)y[ib].qs, 0);

        const __m256 q = mul_sum_i8_pairs_float(qx, qy);

        /* Multiply q with scale and accumulate: acc += d * q */
        acc = __lasx_xvfmadd_s( d, q, acc );
    }

    sumf = hsum_float_8(acc);

#elif defined(__loongarch_sx)
    // set constants
    const __m128i low_mask = __lsx_vreplgr2vr_b(0xF);
    const __m128i off = __lsx_vreplgr2vr_b(8);

    // Initialize accumulator with zeros
    __m128 acc_0 = (__m128)__lsx_vldi(0);
    __m128 acc_1 = (__m128)__lsx_vldi(0);
    __m128 acc_2 = (__m128)__lsx_vldi(0);
    __m128 acc_3 = (__m128)__lsx_vldi(0);

    // Two blocks (ib, ib + 1) per iteration; an odd final block is left to
    // the scalar tail loop.
    for (; ib + 1 < nb; ib += 2) {

        // Combined scale for block ib (shared by both of its 16-value halves)
        const float ft0 = GGML_CPU_FP16_TO_FP32(x[ib].d) * GGML_CPU_FP16_TO_FP32(y[ib].d);
        const __m128 d_0_1 = (__m128)(v4f32){ft0, ft0, ft0, ft0};

        const __m128i tmp_0_1 = __lsx_vld((const __m128i *)x[ib].qs, 0);

        // Block ib, low nibbles (offset to [-8, 7]) against y.qs[0..15]
        __m128i bx_0 = __lsx_vand_v(low_mask, tmp_0_1);
        __m128i by_0 = __lsx_vld((const __m128i *)y[ib].qs, 0);
        bx_0 = __lsx_vsub_b(bx_0, off);
        const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0);

        // Block ib, high nibbles (shift by 4, then mask) against y.qs[16..31]
        __m128i bx_1 = __lsx_vand_v(low_mask, __lsx_vsrli_d(tmp_0_1, 4));
        __m128i by_1 = __lsx_vld((const __m128i *)(y[ib].qs + 16), 0);
        bx_1 = __lsx_vsub_b(bx_1, off);
        const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1);

        // Combined scale for block ib + 1
        const float ft1 = GGML_CPU_FP16_TO_FP32(x[ib + 1].d) * GGML_CPU_FP16_TO_FP32(y[ib + 1].d);
        const __m128 d_2_3 = (__m128)(v4f32){ft1, ft1, ft1, ft1};

        const __m128i tmp_2_3 = __lsx_vld((const __m128i *)x[ib + 1].qs, 0);

        // Block ib + 1, low nibbles against y.qs[0..15]
        __m128i bx_2 = __lsx_vand_v(low_mask, tmp_2_3);
        __m128i by_2 = __lsx_vld((const __m128i *)y[ib + 1].qs, 0);
        bx_2 = __lsx_vsub_b(bx_2, off);
        const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2);

        // Block ib + 1, high nibbles against y.qs[16..31]
        __m128i bx_3 = __lsx_vand_v(low_mask, __lsx_vsrli_d(tmp_2_3, 4));
        __m128i by_3 = __lsx_vld((const __m128i *)(y[ib + 1].qs + 16), 0);
        bx_3 = __lsx_vsub_b(bx_3, off);
        const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3);

        // Convert int32_t to float
        __m128 p0 = __lsx_vffint_s_w(i32_0);
        __m128 p1 = __lsx_vffint_s_w(i32_1);
        __m128 p2 = __lsx_vffint_s_w(i32_2);
        __m128 p3 = __lsx_vffint_s_w(i32_3);

        // Apply the scale
        __m128 p0_d = __lsx_vfmul_s( d_0_1, p0 );
        __m128 p1_d = __lsx_vfmul_s( d_0_1, p1 );
        __m128 p2_d = __lsx_vfmul_s( d_2_3, p2 );
        __m128 p3_d = __lsx_vfmul_s( d_2_3, p3 );

        // Accumulate
        acc_0 = __lsx_vfadd_s(p0_d, acc_0);
        acc_1 = __lsx_vfadd_s(p1_d, acc_1);
        acc_2 = __lsx_vfadd_s(p2_d, acc_2);
        acc_3 = __lsx_vfadd_s(p3_d, acc_3);
    }

    sumf = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3);

#endif
    // Scalar tail: all blocks when neither LASX nor LSX is available, or the
    // last odd block left by the LSX path.
    for (; ib < nb; ++ib) {
        int sumi0 = 0;
        int sumi1 = 0;

        for (int j = 0; j < qk/2; ++j) {
            const int v0 = (x[ib].qs[j] & 0x0F) - 8;
            const int v1 = (x[ib].qs[j] >>   4) - 8;

            sumi0 += (v0 * y[ib].qs[j]);
            sumi1 += (v1 * y[ib].qs[j + qk/2]);
        }

        int sumi = sumi0 + sumi1;
        sumf += sumi*GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d);
    }

    *s = sumf;
}
 774
 775void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
 776    const int qk = QK8_1;
 777    const int nb = n / qk;
 778
 779    assert(n % qk == 0);
 780    assert(nrc == 1);
 781    UNUSED(nrc);
 782    UNUSED(bx);
 783    UNUSED(by);
 784    UNUSED(bs);
 785
 786    const block_q4_1 * GGML_RESTRICT x = vx;
 787    const block_q8_1 * GGML_RESTRICT y = vy;
 788
 789    int ib = 0;
 790    float sumf = 0;
 791
 792#if defined(__loongarch_asx)
 793    // Initialize accumulator with zeros
 794    __m256 acc = (__m256)__lasx_xvldi(0);
 795
 796    float summs = 0;
 797
 798    // Main loop
 799    for (; ib < nb; ++ib) {
 800        const float d0 = GGML_CPU_FP16_TO_FP32(x[ib].d);
 801        const float d1 = GGML_CPU_FP16_TO_FP32(y[ib].d);
 802
 803        summs += GGML_CPU_FP16_TO_FP32(x[ib].m) * GGML_CPU_FP16_TO_FP32(y[ib].s);
 804
 805        const __m256 d0v = __lasx_xvreplfr2vr_s( d0 );
 806        const __m256 d1v = __lasx_xvreplfr2vr_s( d1 );
 807
 808        // Compute combined scales
 809        const __m256 d0d1 = __lasx_xvfmul_s( d0v, d1v );
 810
 811        // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
 812        const __m256i qx = bytes_from_nibbles_32(x[ib].qs);
 813        const __m256i qy = __lasx_xvld( (const __m256i *)y[ib].qs, 0);
 814
 815        const __m256 xy = mul_sum_us8_pairs_float(qx, qy);
 816
 817        // Accumulate d0*d1*x*y
 818        acc = __lasx_xvfmadd_s( d0d1, xy, acc );
 819    }
 820
 821    sumf = hsum_float_8(acc) + summs;
 822
 823    *s = sumf;
 824#else
 825    UNUSED(nb);
 826    UNUSED(x);
 827    UNUSED(y);
 828    UNUSED(ib);
 829    UNUSED(sumf);
 830    ggml_vec_dot_q4_1_q8_1_generic(n, s, bs, vx, bx, vy, by, nrc);
 831#endif
 832}
 833
 834void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
 835    const int qk = QK8_0;
 836    const int nb = n / qk;
 837
 838    int ib = 0;
 839    float sumf = 0;
 840
 841    assert(n % qk == 0);
 842    assert(qk == QK5_0);
 843    assert(nrc == 1);
 844    UNUSED(nrc);
 845    UNUSED(bx);
 846    UNUSED(by);
 847    UNUSED(bs);
 848
 849    const block_q5_0 * GGML_RESTRICT x = vx;
 850    const block_q8_0 * GGML_RESTRICT y = vy;
 851
 852#if defined(__loongarch_asx)
 853    // Initialize accumulator with zeros
 854    __m256 acc = (__m256)__lasx_xvldi(0);
 855
 856    // Main loop
 857    for (; ib < nb; ++ib) {
 858        /* Compute combined scale for the block */
 859        const __m256 d = __lasx_xvreplfr2vr_s(GGML_CPU_FP16_TO_FP32(x[ib].d) * GGML_CPU_FP16_TO_FP32(y[ib].d)); //FIXME
 860
 861        __m256i qx = bytes_from_nibbles_32(x[ib].qs);
 862        __m256i bxhi = bytes_from_bits_32(x[ib].qh);
 863        bxhi = __lasx_xvandn_v(bxhi, __lasx_xvreplgr2vr_b((char)0xF0));
 864        qx = __lasx_xvor_v(qx, bxhi);
 865
 866        __m256i qy = __lasx_xvld((const __m256i *)y[ib].qs, 0);
 867
 868        const __m256 q = mul_sum_i8_pairs_float(qx, qy);
 869
 870        /* Multiply q with scale and accumulate */
 871        acc = __lasx_xvfmadd_s(d, q, acc);
 872    }
 873
 874    sumf = hsum_float_8(acc);
 875
 876    *s = sumf;
 877#else
 878    UNUSED(nb);
 879    UNUSED(ib);
 880    UNUSED(sumf);
 881    UNUSED(x);
 882    UNUSED(y);
 883    ggml_vec_dot_q5_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
 884#endif
 885}
 886
 887void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
 888    const int qk = QK8_1;
 889    const int nb = n / qk;
 890
 891    int ib = 0;
 892    float sumf = 0;
 893
 894    assert(n % qk == 0);
 895    assert(qk == QK5_1);
 896    assert(nrc == 1);
 897    UNUSED(nrc);
 898    UNUSED(bx);
 899    UNUSED(by);
 900    UNUSED(bs);
 901
 902    const block_q5_1 * GGML_RESTRICT x = vx;
 903    const block_q8_1 * GGML_RESTRICT y = vy;
 904
 905#if defined(__loongarch_asx)
 906    // Initialize accumulator with zeros
 907    __m256 acc = (__m256)__lasx_xvldi(0);
 908
 909    float summs = 0.0f;
 910
 911    // Main loop
 912    for (; ib < nb; ++ib) {
 913        const __m256 dx = __lasx_xvreplfr2vr_s(GGML_CPU_FP16_TO_FP32(x[ib].d));
 914
 915        summs += GGML_CPU_FP16_TO_FP32(x[ib].m) * GGML_CPU_FP16_TO_FP32(y[ib].s);
 916
 917        __m256i qx = bytes_from_nibbles_32(x[ib].qs);
 918        __m256i bxhi = bytes_from_bits_32(x[ib].qh);
 919        bxhi = __lasx_xvand_v(bxhi, __lasx_xvreplgr2vr_b(0x10));
 920        qx = __lasx_xvor_v(qx, bxhi);
 921
 922        const __m256 dy = __lasx_xvreplfr2vr_s(GGML_CPU_FP16_TO_FP32(y[ib].d));
 923        const __m256i qy = __lasx_xvld((const __m256i *)y[ib].qs, 0);
 924
 925        const __m256 q = mul_sum_us8_pairs_float(qx, qy);
 926
 927        acc = __lasx_xvfmadd_s(q, __lasx_xvfmul_s(dx, dy), acc);
 928    }
 929
 930    sumf = hsum_float_8(acc) + summs;
 931
 932    *s = sumf;
 933#else
 934    UNUSED(nb);
 935    UNUSED(ib);
 936    UNUSED(sumf);
 937    UNUSED(x);
 938    UNUSED(y);
 939    ggml_vec_dot_q5_1_q8_1_generic(n, s, bs, vx, bx, vy, by, nrc);
 940#endif
 941}
 942
 943void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
 944    const int qk = QK8_0;
 945    const int nb = n / qk;
 946
 947    assert(n % qk == 0);
 948    assert(nrc == 1);
 949    UNUSED(nrc);
 950    UNUSED(bx);
 951    UNUSED(by);
 952    UNUSED(bs);
 953
 954    const block_q8_0 * GGML_RESTRICT x = vx;
 955    const block_q8_0 * GGML_RESTRICT y = vy;
 956
 957    int ib = 0;
 958    float sumf = 0;
 959
 960#if defined(__loongarch_asx)
 961    // Initialize accumulator with zeros
 962    __m256 acc = (__m256)__lasx_xvldi(0);
 963
 964    // Main loop
 965    for (; ib < nb; ++ib) {
 966        // Compute combined scale for the block
 967        const __m256 d = __lasx_xvreplfr2vr_s(GGML_CPU_FP16_TO_FP32(x[ib].d) * GGML_CPU_FP16_TO_FP32(y[ib].d));
 968        __m256i qx = __lasx_xvld((const __m256i *)x[ib].qs, 0);
 969        __m256i qy = __lasx_xvld((const __m256i *)y[ib].qs, 0);
 970
 971        const __m256 q = mul_sum_i8_pairs_float(qx, qy);
 972
 973        // Multiply q with scale and accumulate
 974        acc = __lasx_xvfmadd_s( d, q, acc );
 975    }
 976
 977    sumf = hsum_float_8(acc);
 978
 979    *s = sumf;
 980#else
 981    UNUSED(nb);
 982    UNUSED(ib);
 983    UNUSED(sumf);
 984    UNUSED(x);
 985    UNUSED(y);
 986    ggml_vec_dot_q8_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
 987#endif
 988}
 989
 990void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
 991    assert(nrc == 1);
 992    UNUSED(nrc);
 993    UNUSED(bx);
 994    UNUSED(by);
 995    UNUSED(bs);
 996
 997    const block_q2_K * GGML_RESTRICT x = vx;
 998    const block_q8_K * GGML_RESTRICT y = vy;
 999
1000    const int nb = n / QK_K;
1001
1002#if defined __loongarch_asx
1003
1004    __m256 acc = (__m256)__lasx_xvldi(0);
1005
1006    for (int i = 0; i < nb; ++i) {
1007
1008        const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
1009        const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);
1010
1011        const uint8_t * GGML_RESTRICT q2 = x[i].qs;
1012        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
1013
1014        const __m128i mins_and_scales128 = __lsx_vld((const __m128i*)x[i].scales, 0);
1015        const __m128i scales128 = __lsx_vandi_b(mins_and_scales128, 0xf);
1016        const __m256i mins = lasx_ext8_16(__lsx_vsrli_b(mins_and_scales128, 4));
1017        const __m256i prod = lasx_madd_h(mins, __lasx_xvld((const __m256i*)y[i].bsums, 0));
1018
1019        acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(dmin), __lasx_xvffint_s_w(prod), acc);
1020
1021        const v16i8 shuffle_mask = {0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15};
1022        const __m256i scales_shuffled = lasx_ext8_16(__lsx_vshuf_b(scales128, scales128, (__m128i)shuffle_mask));
1023
1024        __m256i sumi = __lasx_xvldi(0);
1025
1026        for (int j = 0; j < QK_K/128; ++j) {
1027
1028            const __m256i q2bits = __lasx_xvld((const __m256i*)q2, 0); q2 += 32;
1029
1030            const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
1031            const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
1032            const __m256i q8_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
1033            const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
1034
1035            const __m256i q2_0 = __lasx_xvandi_b(q2bits, 3);
1036            const __m256i q2_1 = __lasx_xvandi_b(__lasx_xvsrli_b(q2bits, 2), 3);
1037            const __m256i q2_2 = __lasx_xvandi_b(__lasx_xvsrli_b(q2bits, 4), 3);
1038            const __m256i q2_3 = __lasx_xvsrli_b(q2bits, 6);
1039
1040            __m256i p0 = lasx_madd_h_b(q2_0, q8_0);
1041            __m256i p1 = lasx_madd_h_b(q2_1, q8_1);
1042            __m256i p2 = lasx_madd_h_b(q2_2, q8_2);
1043            __m256i p3 = lasx_madd_h_b(q2_3, q8_3);
1044
1045            p0 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 0), p0);
1046            p1 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 1), p1);
1047            p2 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 2), p2);
1048            p3 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 3), p3);
1049
1050            p0 = __lasx_xvadd_w(p0, p1);
1051            p2 = __lasx_xvadd_w(p2, p3);
1052
1053            sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p0, p2));
1054        }
1055
1056        acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), acc);
1057
1058    }
1059
1060    *s = hsum_float_8(acc);
1061
1062#else
1063    UNUSED(x);
1064    UNUSED(y);
1065    UNUSED(nb);
1066    ggml_vec_dot_q2_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
1067#endif
1068}
1069
// Dot product of one row of q3_K super-blocks (x) with one row of q8_K blocks (y),
// written to *s.  Only nrc == 1 is supported (asserted); bs/bx/by are unused.
// With LASX the work is vectorised below; otherwise it defers to
// ggml_vec_dot_q3_K_q8_K_generic.
//
// Each quant is rebuilt from 2 low bits (x[i].qs) and one high bit (x[i].hmask);
// 16 six-bit scales are packed into the 12 bytes of x[i].scales.
void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(n % QK_K == 0);
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    // kmask2 keeps the low nibble of each byte, kmask1 its low 2 bits; used to
    // reassemble the 6-bit scales.
    const uint32_t kmask1 = 0x03030303;
    const uint32_t kmask2 = 0x0f0f0f0f;

    const block_q3_K * GGML_RESTRICT x = vx;
    const block_q8_K * GGML_RESTRICT y = vy;

    const int nb = n / QK_K;

#if defined __loongarch_asx

    // Subtracted from every unpacked scale byte to make the scales signed.
    const __m128i m32 = __lsx_vreplgr2vr_b(32);

    __m256 acc = (__m256)__lasx_xvldi(0);

    uint32_t aux[3];

    for (int i = 0; i < nb; ++i) {

        const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
        const uint8_t * GGML_RESTRICT q3 = x[i].qs;
        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
        // Set up scales
        // Each byte = 4-bit nibble | (2 bits taken from aux[2]) << 4.  lsx_set_w puts
        // its LAST argument in lane 0, so lane 0/1 = low nibbles of bytes 0..3/4..7
        // and lane 2/3 = the corresponding high nibbles.
        memcpy(aux, x[i].scales, 12);
        __m128i scales128 = lsx_set_w(
                ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4),
                ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4),
                (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4),
                (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4));
        scales128 = __lsx_vsub_b(scales128, m32);

        // Even-indexed scales to the low 8 bytes, odd-indexed to the high 8.  After
        // widening, lasx_xvrepl128vei_h(.., k) presumably supplies scale 2k to the low
        // 128-bit half and 2k+1 to the high half of a 32-byte q8 load — TODO confirm
        // against the helper definitions.
        const v16i8 shuffle_mask = {0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15};
        const __m256i scales_shuffled = lasx_ext8_16(__lsx_vshuf_b(scales128, scales128, (__m128i)shuffle_mask));

        // high bit
        // 32 bytes of hmask; bit (4*j + k) of each byte belongs to plane k of step j.
        const __m256i hbits = __lasx_xvld((const __m256i*)x[i].hmask, 0);

        // integer accumulator
        __m256i sumi = __lasx_xvldi(0);

        for (int j = 0; j < QK_K/128; ++j) {
            // load low 2 bits
            const __m256i q3bits = __lasx_xvld((const __m256i*)q3, 0); q3 += 32;

            // prepare low and high bits
            // __lasx_xvseqi_b(v, 0) is 0xFF where the tested hmask bit is clear;
            // << 2 turns that into 0xFC (-4).  OR-ing with the 2 low bits therefore
            // yields q3l - 4 when the hmask bit is clear and q3l when it is set
            // (assumes lasx_xvandi_b_bit isolates bit 4*j+k of each byte).
            const __m256i q3l_0 = __lasx_xvandi_b(q3bits, 3);
            const __m256i q3l_1 = __lasx_xvandi_b(__lasx_xvsrli_b(q3bits, 2), 3);
            const __m256i q3l_2 = __lasx_xvandi_b(__lasx_xvsrli_b(q3bits, 4), 3);
            const __m256i q3l_3 = __lasx_xvsrli_b(q3bits, 6);
            const __m256i q3h_0 = __lasx_xvslli_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 4 * j + 0), 0), 2);
            const __m256i q3h_1 = __lasx_xvslli_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 4 * j + 1), 0), 2);
            const __m256i q3h_2 = __lasx_xvslli_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 4 * j + 2), 0), 2);
            const __m256i q3h_3 = __lasx_xvslli_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 4 * j + 3), 0), 2);
            const __m256i q3_0 = __lasx_xvor_v(q3h_0, q3l_0);
            const __m256i q3_1 = __lasx_xvor_v(q3h_1, q3l_1);
            const __m256i q3_2 = __lasx_xvor_v(q3h_2, q3l_2);
            const __m256i q3_3 = __lasx_xvor_v(q3h_3, q3l_3);

            // load Q8 quants
            const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
            const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
            const __m256i q8_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
            const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;

            // byte products reduced pairwise to int16 (per the lasx_madd_h_b helper, defined elsewhere)
            __m256i p16_0 = lasx_madd_h_b(q8_0, q3_0);
            __m256i p16_1 = lasx_madd_h_b(q8_1, q3_1);
            __m256i p16_2 = lasx_madd_h_b(q8_2, q3_2);
            __m256i p16_3 = lasx_madd_h_b(q8_3, q3_3);

            // multiply with scales
            p16_0 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 0), p16_0);
            p16_1 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 1), p16_1);
            p16_2 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 2), p16_2);
            p16_3 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 3), p16_3);

            // accumulate
            p16_0 = __lasx_xvadd_w(p16_0, p16_1);
            p16_2 = __lasx_xvadd_w(p16_2, p16_3);
            sumi  = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_2));
        }
        // multiply with block scale and accumulate
        acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), acc);
    }

    *s = hsum_float_8(acc);

#else
    UNUSED(kmask1);
    UNUSED(kmask2);
    UNUSED(x);
    UNUSED(y);
    UNUSED(nb);
    ggml_vec_dot_q3_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}
1172
1173void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
1174    assert(n % QK_K == 0);
1175    assert(nrc == 1);
1176    UNUSED(nrc);
1177    UNUSED(bx);
1178    UNUSED(by);
1179    UNUSED(bs);
1180
1181    const block_q4_K * GGML_RESTRICT x = vx;
1182    const block_q8_K * GGML_RESTRICT y = vy;
1183
1184    const int nb = n / QK_K;
1185
1186    static const uint32_t kmask1 = 0x3f3f3f3f;
1187    static const uint32_t kmask2 = 0x0f0f0f0f;
1188    static const uint32_t kmask3 = 0x03030303;
1189
1190    uint32_t utmp[4];
1191
1192#if defined __loongarch_asx
1193
1194    __m256 acc = (__m256)__lasx_xvldi(0);
1195    __m128 acc_m = (__m128)__lsx_vldi(0);
1196
1197   for (int i = 0; i < nb; ++i) {
1198
1199        const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
1200        const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);
1201
1202        memcpy(utmp, x[i].scales, 12);
1203        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
1204        const uint32_t uaux = utmp[1] & kmask1;
1205        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
1206        utmp[2] = uaux;
1207        utmp[0] &= kmask1;
1208
1209        const uint8_t * GGML_RESTRICT q4 = x[i].qs;
1210        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
1211
1212        const __m128i mins_and_scales128 = lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0]);
1213        const __m128i mins128 = __lsx_vexth_h_b(mins_and_scales128);
1214        const __m128i scales128 = __lsx_vsllwil_h_b(mins_and_scales128, 0);
1215
1216        const __m256i q8sums = __lasx_xvld((const __m256i*)y[i].bsums, 0);
1217        const __m128i q8s = lsx_hadd_h(lasx_extracti128(q8sums, 0), lasx_extracti128(q8sums, 1));
1218        const __m128i prod = lsx_madd_h(mins128, q8s);
1219        acc_m = __lsx_vfmadd_s(__lsx_vreplfr2vr_s(dmin), __lsx_vffint_s_w(prod), acc_m);
1220
1221        const __m256i scales = lasx_insertf128(scales128, scales128);
1222
1223        __m256i sumi = __lasx_xvldi(0);
1224
1225        for (int j = 0; j < QK_K/64; ++j) {
1226
1227            const __m256i scale_l = lasx_xvrepl128vei_h(scales, 2 * j + 0);
1228            const __m256i scale_h = lasx_xvrepl128vei_h(scales, 2 * j + 1);
1229
1230            const __m256i q4bits = __lasx_xvld((const __m256i*)q4, 0); q4 += 32;
1231            const __m256i q4l = __lasx_xvandi_b(q4bits, 0xf);
1232            const __m256i q4h = __lasx_xvsrli_b(q4bits, 4);
1233
1234            const __m256i q8l = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
1235            __m256i p16l = lasx_madd_h_b(q4l, q8l);
1236            p16l = lasx_madd_h(scale_l, p16l);
1237
1238            const __m256i q8h = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
1239            __m256i p16h = lasx_madd_h_b(q4h, q8h);
1240            p16h = lasx_madd_h(scale_h, p16h);
1241            const __m256i sumj = __lasx_xvadd_w(p16l, p16h);
1242
1243            sumi = __lasx_xvadd_w(sumi, sumj);
1244        }
1245
1246        __m256 vd = __lasx_xvreplfr2vr_s(d);
1247        acc = __lasx_xvfmadd_s(vd, __lasx_xvffint_s_w(sumi), acc);
1248
1249    }
1250
1251    acc_m = __lsx_vfadd_s(acc_m, (__m128)__lsx_vpermi_w((__m128i)acc_m, (__m128i)acc_m, 0xee));
1252    __m128i tmp1 = __lsx_vinsgr2vr_w(__lsx_vldi(0), __lsx_vpickve2gr_w((__m128i)acc_m, 1), 0);
1253    acc_m = __lsx_vfadd_s(acc_m, (__m128)tmp1);
1254
1255
1256    *s = hsum_float_8(acc) + ((v4f32)acc_m)[0];
1257
1258#else
1259    UNUSED(x);
1260    UNUSED(y);
1261    UNUSED(nb);
1262    UNUSED(kmask1);
1263    UNUSED(kmask2);
1264    UNUSED(kmask3);
1265    UNUSED(utmp);
1266    ggml_vec_dot_q4_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
1267#endif
1268}
1269
// Dot product of one row of q5_K super-blocks (x) with one row of q8_K blocks (y),
// written to *s.  Only nrc == 1 is supported (asserted); bs/bx/by are unused.
// With LASX the work is vectorised below; otherwise it defers to
// ggml_vec_dot_q5_K_q8_K_generic.
//
// Each quant is a nibble from x[i].qs plus a fifth bit from x[i].qh.  Scales and
// mins are unpacked exactly as for q4_K; the mins term is accumulated with the
// negated dmin in a separate 128-bit accumulator.
void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy,  size_t by, int nrc) {
    assert(n % QK_K == 0);
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_q5_K * GGML_RESTRICT x = vx;
    const block_q8_K * GGML_RESTRICT y = vy;

    const int nb = n / QK_K;

    // Masks for unpacking the 12-byte scales/mins field.
    static const uint32_t kmask1 = 0x3f3f3f3f;
    static const uint32_t kmask2 = 0x0f0f0f0f;
    static const uint32_t kmask3 = 0x03030303;

    uint32_t utmp[4];

#if defined __loongarch_asx

    __m256 acc = (__m256)__lasx_xvldi(0);
    __m128 acc_m = (__m128)__lsx_vldi(0);

    for (int i = 0; i < nb; ++i) {

        const uint8_t * GGML_RESTRICT q5 = x[i].qs;
        const int8_t  * GGML_RESTRICT q8 = y[i].qs;

        const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
        const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);

        // Unpack so utmp[0..1] hold the 8 scales and utmp[2..3] the 8 mins
        // (one byte each); statement order matters.
        memcpy(utmp, x[i].scales, 12);
        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
        const uint32_t uaux = utmp[1] & kmask1;
        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
        utmp[2] = uaux;
        utmp[0] &= kmask1;

        // lsx_set_w puts its last argument in lane 0: bytes 0..7 = scales, 8..15 = mins.
        const __m128i mins_and_scales128 = lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0]);
        const __m128i mins128 = __lsx_vexth_h_b(mins_and_scales128);
        const __m128i scales128 = __lsx_vsllwil_h_b(mins_and_scales128, 0);

        // Mins term: mins128 . (adjacent-pair sums of the 16 int16 bsums of y).
        const __m256i q8sums = __lasx_xvld((const __m256i*)y[i].bsums, 0);
        const __m128i q8s = lsx_hadd_h(lasx_extracti128(q8sums, 0), lasx_extracti128(q8sums, 1));
        const __m128i prod = lsx_madd_h(mins128, q8s);
        acc_m = __lsx_vfmadd_s(__lsx_vreplfr2vr_s(dmin), __lsx_vffint_s_w(prod), acc_m);

        const __m256i scales = lasx_insertf128(scales128, scales128);

        // 32 bytes of qh; bit (2*j + k) of each byte is the fifth bit for plane k of step j.
        const __m256i hbits = __lasx_xvld((const __m256i*)x[i].qh, 0);

        __m256i sumi = __lasx_xvldi(0);

        for (int j = 0; j < QK_K/64; ++j) {

            const __m256i scale_0 = lasx_xvrepl128vei_h(scales, 2 * j + 0);
            const __m256i scale_1 = lasx_xvrepl128vei_h(scales, 2 * j + 1);

            const __m256i q5bits = __lasx_xvld((const __m256i*)q5, 0); q5 += 32;

            // __lasx_xvseqi_b(v, 0) is 0xFF where the qh bit is clear, 0x00 where set;
            // __lasx_xvnori_b(v, 0xef) = ~(v | 0xef) maps 0xFF -> 0x00 and 0x00 -> 0x10.
            // So q5h is 0x10 exactly where the qh bit is set (assumes
            // lasx_xvandi_b_bit isolates that bit), giving 5-bit values 0..31 after the OR.
            const __m256i q5l_0 = __lasx_xvandi_b(q5bits, 0xf);
            const __m256i q5l_1 = __lasx_xvsrli_b(q5bits, 4);
            const __m256i q5h_0 = __lasx_xvnori_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 2 * j + 0), 0), 0xef);
            const __m256i q5h_1 = __lasx_xvnori_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 2 * j + 1), 0), 0xef);
            const __m256i q5_0  = __lasx_xvor_v(q5l_0, q5h_0);
            const __m256i q5_1  = __lasx_xvor_v(q5l_1, q5h_1);

            const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
            const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;

            __m256i p16_0 = lasx_madd_h_b(q5_0, q8_0);
            __m256i p16_1 = lasx_madd_h_b(q5_1, q8_1);

            p16_0 = lasx_madd_h(scale_0, p16_0);
            p16_1 = lasx_madd_h(scale_1, p16_1);

            sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_1));

        }

        __m256 vd = __lasx_xvreplfr2vr_s(d);
        acc = __lasx_xvfmadd_s(vd, __lasx_xvffint_s_w(sumi), acc);

    }

    // Fold the 4 lanes of acc_m into lane 0: (a0 + a2) + (a1 + a3).
    acc_m = __lsx_vfadd_s(acc_m, (__m128)__lsx_vbsrl_v(acc_m, 8));
    acc_m = __lsx_vfadd_s(acc_m, (__m128)__lsx_vbsrl_v(acc_m, 4));

    *s = hsum_float_8(acc) + ((v4f32)acc_m)[0];

#else
    UNUSED(x);
    UNUSED(y);
    UNUSED(nb);
    UNUSED(kmask1);
    UNUSED(kmask2);
    UNUSED(kmask3);
    UNUSED(utmp);
    ggml_vec_dot_q5_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}
1372
// Dot product of one row of q6_K super-blocks (x) with one row of q8_K blocks (y),
// written to *s.  Only nrc == 1 is supported (asserted); bs/bx/by are unused.
// With LASX the work is vectorised below; otherwise it defers to
// ggml_vec_dot_q6_K_q8_K_generic.
//
// Each quant is 4 low bits (x[i].ql) plus 2 high bits (x[i].qh), minus 32; the
// 16 int8 scales are read directly from x[i].scales.
void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(n % QK_K == 0);
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_q6_K * GGML_RESTRICT x = vx;
    const block_q8_K * GGML_RESTRICT y = vy;

    const int nb = n / QK_K;

#if defined __loongarch_asx

    // Subtracted from each 6-bit value (0..63) to centre it at zero.
    const __m256i m32s = __lasx_xvreplgr2vr_b(32);

    __m256 acc = (__m256)__lasx_xvldi(0);

    for (int i = 0; i < nb; ++i) {

        const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);

        const uint8_t * GGML_RESTRICT q4 = x[i].ql;
        const uint8_t * GGML_RESTRICT qh = x[i].qh;
        const int8_t  * GGML_RESTRICT q8 = y[i].qs;

        // Even-indexed scales to the low 8 bytes, odd-indexed to the high 8; after
        // widening, lasx_xvrepl128vei_h(.., k) presumably supplies scale 2k to the low
        // 128-bit half and 2k+1 to the high half of a 32-byte q8 load.
        const __m128i scales128 = __lsx_vld((const __m128i*)x[i].scales, 0);
        const v16i8 shuffle_mask = {0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15};
        const __m256i scales_shuffled = lasx_ext8_16(__lsx_vshuf_b(scales128, scales128, (__m128i)shuffle_mask));

        __m256i sumi = __lasx_xvldi(0);

        for (int j = 0; j < QK_K/128; ++j) {

            const __m256i q4bits1 = __lasx_xvld((const __m256i*)q4, 0); q4 += 32;
            const __m256i q4bits2 = __lasx_xvld((const __m256i*)q4, 0); q4 += 32;
            const __m256i q4bitsH = __lasx_xvld((const __m256i*)qh, 0); qh += 32;

            // Each qh byte holds four 2-bit fields; move field k into bits 4..5 so
            // OR-ing with a 4-bit nibble below gives a 6-bit value 0..63.
            const __m256i q4h_0 = __lasx_xvslli_b(__lasx_xvandi_b(q4bitsH, 3), 4);
            const __m256i q4h_1 = __lasx_xvslli_b(__lasx_xvandi_b(q4bitsH, 3 << 2), 2);
            const __m256i q4h_2 = __lasx_xvandi_b(q4bitsH, 3 << 4);
            const __m256i q4h_3 = __lasx_xvsrli_b(__lasx_xvandi_b(q4bitsH, 3 << 6), 2);

            // Planes 0/1 use the low nibbles of the two ql loads, planes 2/3 the high nibbles.
            const __m256i q4_0 = __lasx_xvor_v(__lasx_xvandi_b(q4bits1, 0xf), q4h_0);
            const __m256i q4_1 = __lasx_xvor_v(__lasx_xvandi_b(q4bits2, 0xf), q4h_1);
            const __m256i q4_2 = __lasx_xvor_v(__lasx_xvsrli_b(q4bits1, 4), q4h_2);
            const __m256i q4_3 = __lasx_xvor_v(__lasx_xvsrli_b(q4bits2, 4), q4h_3);

            const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
            const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
            const __m256i q8_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
            const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;

            __m256i p16_0 = lasx_madd_h_b(__lasx_xvsub_b(q4_0, m32s), q8_0);
            __m256i p16_1 = lasx_madd_h_b(__lasx_xvsub_b(q4_1, m32s), q8_1);
            __m256i p16_2 = lasx_madd_h_b(__lasx_xvsub_b(q4_2, m32s), q8_2);
            __m256i p16_3 = lasx_madd_h_b(__lasx_xvsub_b(q4_3, m32s), q8_3);

            p16_0 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 0), p16_0);
            p16_1 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 1), p16_1);
            p16_2 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 2), p16_2);
            p16_3 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 3), p16_3);

            sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_1));
            sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_2, p16_3));
        }

        acc = __lasx_xvfmadd_s((__m256)__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), acc);
    }

    *s = hsum_float_8(acc);

#else
    UNUSED(x);
    UNUSED(y);
    UNUSED(nb);
    ggml_vec_dot_q6_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}
1453
#if defined(__loongarch_asx)
// Sign patterns for the iq2 kernels, read as 128 rows of 8 int8 values (each row
// is fetched as one uint64_t, e.g. signs64 in ggml_vec_dot_iq2_xxs_q8_K).
// From the values: element b (b = 0..6) of row k is -1 when bit b of k is set,
// and element 7 is chosen so every row contains an even number of -1 entries.
// Only defined for LASX builds.
static const int8_t keven_signs_q2xs[1024] = {
     1,  1,  1,  1,  1,  1,  1,  1, -1,  1,  1,  1,  1,  1,  1, -1,  1, -1,  1,  1,  1,  1,  1, -1, -1, -1,  1,  1,  1,  1,  1,  1,
     1,  1, -1,  1,  1,  1,  1, -1, -1,  1, -1,  1,  1,  1,  1,  1,  1, -1, -1,  1,  1,  1,  1,  1, -1, -1, -1,  1,  1,  1,  1, -1,
     1,  1,  1, -1,  1,  1,  1, -1, -1,  1,  1, -1,  1,  1,  1,  1,  1, -1,  1, -1,  1,  1,  1,  1, -1, -1,  1, -1,  1,  1,  1, -1,
     1,  1, -1, -1,  1,  1,  1,  1, -1,  1, -1, -1,  1,  1,  1, -1,  1, -1, -1, -1,  1,  1,  1, -1, -1, -1, -1, -1,  1,  1,  1,  1,
     1,  1,  1,  1, -1,  1,  1, -1, -1,  1,  1,  1, -1,  1,  1,  1,  1, -1,  1,  1, -1,  1,  1,  1, -1, -1,  1,  1, -1,  1,  1, -1,
     1,  1, -1,  1, -1,  1,  1,  1, -1,  1, -1,  1, -1,  1,  1, -1,  1, -1, -1,  1, -1,  1,  1, -1, -1, -1, -1,  1, -1,  1,  1,  1,
     1,  1,  1, -1, -1,  1,  1,  1, -1,  1,  1, -1, -1,  1,  1, -1,  1, -1,  1, -1, -1,  1,  1, -1, -1, -1,  1, -1, -1,  1,  1,  1,
     1,  1, -1, -1, -1,  1,  1, -1, -1,  1, -1, -1, -1,  1,  1,  1,  1, -1, -1, -1, -1,  1,  1,  1, -1, -1, -1, -1, -1,  1,  1, -1,
     1,  1,  1,  1,  1, -1,  1, -1, -1,  1,  1,  1,  1, -1,  1,  1,  1, -1,  1,  1,  1, -1,  1,  1, -1, -1,  1,  1,  1, -1,  1, -1,
     1,  1, -1,  1,  1, -1,  1,  1, -1,  1, -1,  1,  1, -1,  1, -1,  1, -1, -1,  1,  1, -1,  1, -1, -1, -1, -1,  1,  1, -1,  1,  1,
     1,  1,  1, -1,  1, -1,  1,  1, -1,  1,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1, -1, -1,  1, -1,  1, -1,  1,  1,
     1,  1, -1, -1,  1, -1,  1, -1, -1,  1, -1, -1,  1, -1,  1,  1,  1, -1, -1, -1,  1, -1,  1,  1, -1, -1, -1, -1,  1, -1,  1, -1,
     1,  1,  1,  1, -1, -1,  1,  1, -1,  1,  1,  1, -1, -1,  1, -1,  1, -1,  1,  1, -1, -1,  1, -1, -1, -1,  1,  1, -1, -1,  1,  1,
     1,  1, -1,  1, -1, -1,  1, -1, -1,  1, -1,  1, -1, -1,  1,  1,  1, -1, -1,  1, -1, -1,  1,  1, -1, -1, -1,  1, -1, -1,  1, -1,
     1,  1,  1, -1, -1, -1,  1, -1, -1,  1,  1, -1, -1, -1,  1,  1,  1, -1,  1, -1, -1, -1,  1,  1, -1, -1,  1, -1, -1, -1,  1, -1,
     1,  1, -1, -1, -1, -1,  1,  1, -1,  1, -1, -1, -1, -1,  1, -1,  1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1,  1,  1,
     1,  1,  1,  1,  1,  1, -1, -1, -1,  1,  1,  1,  1,  1, -1,  1,  1, -1,  1,  1,  1,  1, -1,  1, -1, -1,  1,  1,  1,  1, -1, -1,
     1,  1, -1,  1,  1,  1, -1,  1, -1,  1, -1,  1,  1,  1, -1, -1,  1, -1, -1,  1,  1,  1, -1, -1, -1, -1, -1,  1,  1,  1, -1,  1,
     1,  1,  1, -1,  1,  1, -1,  1, -1,  1,  1, -1,  1,  1, -1, -1,  1, -1,  1, -1,  1,  1, -1, -1, -1, -1,  1, -1,  1,  1, -1,  1,
     1,  1, -1, -1,  1,  1, -1, -1, -1,  1, -1, -1,  1,  1, -1,  1,  1, -1, -1, -1,  1,  1, -1,  1, -1, -1, -1, -1,  1,  1, -1, -1,
     1,  1,  1,  1, -1,  1, -1,  1, -1,  1,  1,  1, -1,  1, -1, -1,  1, -1,  1,  1, -1,  1, -1, -1, -1, -1,  1,  1, -1,  1, -1,  1,
     1,  1, -1,  1, -1,  1, -1, -1, -1,  1, -1,  1, -1,  1, -1,  1,  1, -1, -1,  1, -1,  1, -1,  1, -1, -1, -1,  1, -1,  1, -1, -1,
     1,  1,  1, -1, -1,  1, -1, -1, -1,  1,  1, -1, -1,  1, -1,  1,  1, -1,  1, -1, -1,  1, -1,  1, -1, -1,  1, -1, -1,  1, -1, -1,
     1,  1, -1, -1, -1,  1, -1,  1, -1,  1, -1, -1, -1,  1, -1, -1,  1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1,  1, -1,  1,
     1,  1,  1,  1,  1, -1, -1,  1, -1,  1,  1,  1,  1, -1, -1, -1,  1, -1,  1,  1,  1, -1, -1, -1, -1, -1,  1,  1,  1, -1, -1,  1,
     1,  1, -1,  1,  1, -1, -1, -1, -1,  1, -1,  1,  1, -1, -1,  1,  1, -1, -1,  1,  1, -1, -1,  1, -1, -1, -1,  1,  1, -1, -1, -1,
     1,  1,  1, -1,  1, -1, -1, -1, -1,  1,  1, -1,  1, -1, -1,  1,  1, -1,  1, -1,  1, -1, -1,  1, -1, -1,  1, -1,  1, -1, -1, -1,
     1,  1, -1, -1,  1, -1, -1,  1, -1,  1, -1, -1,  1, -1, -1, -1,  1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1,  1,
     1,  1,  1,  1, -1, -1, -1, -1, -1,  1,  1,  1, -1, -1, -1,  1,  1, -1,  1,  1, -1, -1, -1,  1, -1, -1,  1,  1, -1, -1, -1, -1,
     1,  1, -1,  1, -1, -1, -1,  1, -1,  1, -1,  1, -1, -1, -1, -1,  1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1,  1,
     1,  1,  1, -1, -1, -1, -1,  1, -1,  1,  1, -1, -1, -1, -1, -1,  1, -1,  1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1,  1,
     1,  1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1,  1,  1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1, -1,
};
#endif
1490
// Dot product of one row of iq2_xxs super-blocks (x) with one row of q8_K blocks
// (y), written to *s.  Only nrc == 1 is supported (asserted); bs/bx/by are unused.
// With LASX the work is vectorised below; otherwise it defers to
// ggml_vec_dot_iq2_xxs_q8_K_generic.
//
// Each inner iteration consumes 16 bytes of x[i].qs (two groups of 32 values) and
// 64 q8 values.  Per group, one 32-bit word carries four 8-bit grid indices and
// the next word carries four 7-bit sign-pattern indices (bits 0..27) plus a
// 4-bit scale ls (bits 28..31); the group is weighted by (2*ls + 1) and the final
// sum by 0.125f.
void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(n % QK_K == 0);
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_iq2_xxs * GGML_RESTRICT x = vx;
    const block_q8_K    * GGML_RESTRICT y = vy;

    const int nb = n / QK_K;

#if defined(__loongarch_asx)

    // keven_signs_q2xs viewed as 128 rows of 8 signs (one uint64_t per row).
    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;

    // aux32[0]/aux32[2]: grid indices (read bytewise through aux8);
    // aux32[1]/aux32[3]: sign indices + scale for the first/second group.
    uint32_t aux32[4];
    const uint8_t * aux8 = (const uint8_t *)aux32;

    __m256 accumf = (__m256)__lasx_xvldi(0);
    for (int i = 0; i < nb; ++i) {
        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
        const uint16_t * GGML_RESTRICT q2 = x[i].qs;
        const int8_t   * GGML_RESTRICT q8 = y[i].qs;
        __m256i sumi1 = __lasx_xvldi(0);
        __m256i sumi2 = __lasx_xvldi(0);
        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
            const __m256i q8_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
            const __m256i q8_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
            memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8;

            // Four 64-bit grid entries per group (presumably 8 byte magnitudes each).
            const __m256i q2_1 = lasx_set_d(iq2xxs_grid[aux8[ 3]], iq2xxs_grid[aux8[ 2]], iq2xxs_grid[aux8[1]], iq2xxs_grid[aux8[0]]);
            const __m256i q2_2 = lasx_set_d(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]], iq2xxs_grid[aux8[9]], iq2xxs_grid[aux8[8]]);
            const __m256i s2_1 = lasx_set_d(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127],
                                                   signs64[(aux32[1] >>  7) & 127], signs64[(aux32[1] >>  0) & 127]);
            const __m256i s2_2 = lasx_set_d(signs64[(aux32[3] >> 21) & 127], signs64[(aux32[3] >> 14) & 127],
                                                   signs64[(aux32[3] >>  7) & 127], signs64[(aux32[3] >>  0) & 127]);
            // __lasx_xvsigncov_b(s, v): v where s > 0, -v where s < 0 (0 where s == 0),
            // i.e. applies the +/-1 sign pattern to the q8 values.
            const __m256i q8s_1 = __lasx_xvsigncov_b(s2_1, q8_1);
            const __m256i q8s_2 = __lasx_xvsigncov_b(s2_2, q8_2);
            const __m256i dot1  = lasx_maddubs_h(q2_1, q8s_1);
            const __m256i dot2  = lasx_maddubs_h(q2_2, q8s_2);
            // 4-bit group scale from the top bits of the sign word.
            const uint16_t ls1 = aux32[1] >> 28;
            const uint16_t ls2 = aux32[3] >> 28;
            const __m256i p1 = lasx_madd_h(dot1, __lasx_xvreplgr2vr_h(2*ls1+1));
            const __m256i p2 = lasx_madd_h(dot2, __lasx_xvreplgr2vr_h(2*ls2+1));
            sumi1 = __lasx_xvadd_w(sumi1, p1);
            sumi2 = __lasx_xvadd_w(sumi2, p2);
        }

        accumf = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accumf);
    }

    *s = 0.125f * hsum_float_8(accumf);

#else
    UNUSED(x);
    UNUSED(y);
    UNUSED(nb);
    ggml_vec_dot_iq2_xxs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}
1553
// Dot product of one IQ2_XS-quantized row (vx) with one Q8_K row (vy).
//
// n   : number of elements; asserted to be a multiple of QK_K.
// s   : receives the scalar result.
// vx  : n/QK_K consecutive block_iq2_xs.
// vy  : n/QK_K consecutive block_q8_K.
// bs, bx, by, nrc : unused here (nrc is asserted to be 1).
//
// Uses a LASX kernel when __loongarch_asx is defined, otherwise forwards to
// ggml_vec_dot_iq2_xs_q8_K_generic.
void ggml_vec_dot_iq2_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(n % QK_K == 0);
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_iq2_xs * GGML_RESTRICT x = vx;
    const block_q8_K   * GGML_RESTRICT y = vy;

    const int nb = n / QK_K;

#if defined(__loongarch_asx)

    // All-ones bytes: OR-ing a 0x00/0xff mask with 1 yields +1/-1 for xvsigncov_b.
    const __m256i mone = __lasx_xvreplgr2vr_b(1);
    // Byte shuffles that broadcast the sign byte (the low byte of each 16-bit entry)
    // of entries 0..3 (mask_1) or 4..7 (mask_2) across 8 consecutive lanes.
    static const char block_sign_shuffle_mask_1[32] = {
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02,
        0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06,
    };
    static const char block_sign_shuffle_mask_2[32] = {
        0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a,
        0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e,
    };
    // Lane j (mod 8) selects bit j of its broadcast sign byte.
    static const uint8_t bit_selector_mask_bytes[32] = {
        0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
        0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
    };

    const __m256i bit_selector_mask = __lasx_xvld((const __m256i*)bit_selector_mask_bytes, 0);
    const __m256i block_sign_shuffle_1 = __lasx_xvld((const __m256i*)block_sign_shuffle_mask_1, 0);
    const __m256i block_sign_shuffle_2 = __lasx_xvld((const __m256i*)block_sign_shuffle_mask_2, 0);

    // 4-bit parity table: entry k is 0x80 when k has an odd number of set bits.
    static const uint8_t k_bit_helper[32] = {
        0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
        0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
    };
    const __m256i bit_helper = __lasx_xvld((const __m256i*)k_bit_helper, 0);
    // The low 9 bits of each 16-bit q2 entry are the grid index.
    const __m256i m511 = __lasx_xvreplgr2vr_h(511);
    const __m128i m4 = __lsx_vreplgr2vr_b(0xf);
    const __m128i m1 = __lsx_vreplgr2vr_b(1);

    uint64_t aux64;

    // somewhat hacky, but gives a significant boost in performance
    // (the grid indices are computed as a vector, stored in aux_gindex and read
    // back as scalars through gindex instead of being extracted lane by lane).
    __m256i aux_gindex;
    const uint16_t * gindex = (const uint16_t *)&aux_gindex;

    __m256 accumf = (__m256)__lasx_xvldi(0);
    for (int i = 0; i < nb; ++i) {
        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
        const uint16_t * GGML_RESTRICT q2 = x[i].qs;
        const int8_t   * GGML_RESTRICT q8 = y[i].qs;

        // 8 scale bytes -> 16 nibbles, interleaved (low, high) per byte; the
        // effective scale is 2*s+1.
        memcpy(&aux64, x[i].scales, 8);
        __m128i stmp = __lsx_vreplgr2vr_d(aux64);
        stmp = __lsx_vilvl_b( __lsx_vand_v(__lsx_vsrli_h(stmp, 4), m4), __lsx_vand_v(stmp, m4));
        const __m128i scales = __lsx_vadd_b(__lsx_vslli_h(stmp, 1), m1);

        __m256i sumi1 = __lasx_xvldi(0);
        __m256i sumi2 = __lasx_xvldi(0);
        // Four 32-element sub-blocks per iteration (16 q2 entries of 8 values each).
        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 4) {

            const __m256i q2_data = __lasx_xvld((const __m256i*)q2, 0);  q2 += 16;
            aux_gindex = __lasx_xvand_v(q2_data, m511);

            // Bits 9..15 of each entry are 7 sign bits. The 8th sign bit is implied by
            // parity: folding bits 4..6 onto bits 0..2 gives a value whose low 4 bits
            // have the same parity as all 7 bits. Looking that up in bit_helper
            // (presumably lasx_shuffle_b uses only the low 4 index bits, like x86
            // pshufb) yields 0x80 when the parity is odd.
            const __m256i partial_sign_bits = __lasx_xvsrli_h(q2_data, 9);
            const __m256i partial_sign_bits_upper = __lasx_xvsrli_h(q2_data, 13);
            const __m256i partial_sign_bits_for_counting = __lasx_xvxor_v(partial_sign_bits, partial_sign_bits_upper);

            const __m256i odd_bits = lasx_shuffle_b(bit_helper, partial_sign_bits_for_counting);
            // Low byte of each 16-bit lane now holds the full 8 sign bits.
            const __m256i full_sign_bits = __lasx_xvor_v(partial_sign_bits, odd_bits);

            const __m256i q8_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
            const __m256i q8_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
            const __m256i q8_3 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
            const __m256i q8_4 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;

            const __m256i q2_1 = lasx_set_d(iq2xs_grid[gindex[ 3]], iq2xs_grid[gindex[ 2]],
                                                   iq2xs_grid[gindex[ 1]], iq2xs_grid[gindex[ 0]]);
            const __m256i q2_2 = lasx_set_d(iq2xs_grid[gindex[ 7]], iq2xs_grid[gindex[ 6]],
                                                   iq2xs_grid[gindex[ 5]], iq2xs_grid[gindex[ 4]]);
            const __m256i q2_3 = lasx_set_d(iq2xs_grid[gindex[11]], iq2xs_grid[gindex[10]],
                                                   iq2xs_grid[gindex[ 9]], iq2xs_grid[gindex[ 8]]);
            const __m256i q2_4 = lasx_set_d(iq2xs_grid[gindex[15]], iq2xs_grid[gindex[14]],
                                                   iq2xs_grid[gindex[13]], iq2xs_grid[gindex[12]]);

            // Duplicate each 128-bit half so the per-128-bit-lane shuffles below can
            // reach all 8 sign bytes of two sub-blocks.
            const __m128i full_signs_l = lasx_extracti128(full_sign_bits, 0);
            const __m128i full_signs_h = lasx_extracti128(full_sign_bits, 1);
            const __m256i full_signs_1 = lasx_insertf128(full_signs_l, full_signs_l);
            const __m256i full_signs_2 = lasx_insertf128(full_signs_h, full_signs_h);

            // For each sub-block: broadcast sign bytes, test the per-lane bit (0xff if
            // set), turn that into -1/+1 and apply it to the activations.
            __m256i signs;
            signs = lasx_shuffle_b(full_signs_1, block_sign_shuffle_1);
            signs = __lasx_xvseq_b(__lasx_xvand_v(signs, bit_selector_mask), bit_selector_mask);
            const __m256i q8s_1 = __lasx_xvsigncov_b(__lasx_xvor_v(signs, mone), q8_1);

            signs = lasx_shuffle_b(full_signs_1, block_sign_shuffle_2);
            signs = __lasx_xvseq_b(__lasx_xvand_v(signs, bit_selector_mask), bit_selector_mask);
            const __m256i q8s_2 = __lasx_xvsigncov_b(__lasx_xvor_v(signs, mone), q8_2);

            signs = lasx_shuffle_b(full_signs_2, block_sign_shuffle_1);
            signs = __lasx_xvseq_b(__lasx_xvand_v(signs, bit_selector_mask), bit_selector_mask);
            const __m256i q8s_3 = __lasx_xvsigncov_b(__lasx_xvor_v(signs, mone), q8_3);

            signs = lasx_shuffle_b(full_signs_2, block_sign_shuffle_2);
            signs = __lasx_xvseq_b(__lasx_xvand_v(signs, bit_selector_mask), bit_selector_mask);
            const __m256i q8s_4 = __lasx_xvsigncov_b(__lasx_xvor_v(signs, mone), q8_4);

            const __m256i dot1  = lasx_maddubs_h(q2_1, q8s_1);
            const __m256i dot2  = lasx_maddubs_h(q2_2, q8s_2);
            const __m256i dot3  = lasx_maddubs_h(q2_3, q8s_3);
            const __m256i dot4  = lasx_maddubs_h(q2_4, q8s_4);

            // get_scale_shuffle(k) presumably places sub-block k's two scales so that
            // each 16-value half gets its own scale (defined elsewhere; verify).
            const __m256i sc1 = lasx_ext8_16(lsx_shuffle_b(scales, get_scale_shuffle(ib32+0)));
            const __m256i sc2 = lasx_ext8_16(lsx_shuffle_b(scales, get_scale_shuffle(ib32+1)));
            const __m256i sc3 = lasx_ext8_16(lsx_shuffle_b(scales, get_scale_shuffle(ib32+2)));
            const __m256i sc4 = lasx_ext8_16(lsx_shuffle_b(scales, get_scale_shuffle(ib32+3)));

            sumi1 = __lasx_xvadd_w(sumi1, lasx_madd_h(dot1, sc1));
            sumi2 = __lasx_xvadd_w(sumi2, lasx_madd_h(dot2, sc2));
            sumi1 = __lasx_xvadd_w(sumi1, lasx_madd_h(dot3, sc3));
            sumi2 = __lasx_xvadd_w(sumi2, lasx_madd_h(dot4, sc4));
        }

        accumf = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accumf);

    }

    // The 0.125 factor presumably matches ggml_vec_dot_iq2_xs_q8_K_generic. TODO confirm.
    *s = 0.125f * hsum_float_8(accumf);

#else
    UNUSED(x);
    UNUSED(y);
    UNUSED(nb);
    ggml_vec_dot_iq2_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}
1692
1693void ggml_vec_dot_iq2_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
1694    assert(n % QK_K == 0);
1695    assert(nrc == 1);
1696    UNUSED(nrc);
1697    UNUSED(bx);
1698    UNUSED(by);
1699    UNUSED(bs);
1700
1701    const block_iq2_s * GGML_RESTRICT x = vx;
1702    const block_q8_K  * GGML_RESTRICT y = vy;
1703
1704    const int nb = n / QK_K;
1705
1706#if defined(__loongarch_asx)
1707
1708   static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
1709                                       0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
1710   };
1711
1712    static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
1713                                        0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
1714    };
1715
1716
1717    const __m128i m4 = __lsx_vreplgr2vr_b(0xf);
1718    const __m128i m1 = __lsx_vreplgr2vr_b(1);
1719
1720    const __m256i mask1 = __lasx_xvld((const __m256i*)k_mask1, 0);
1721    const __m256i mask2 = __lasx_xvld((const __m256i*)k_mask2, 0);
1722    uint64_t aux64;
1723
1724    __m256 accumf = (__m256)__lasx_xvldi(0);
1725    for (int i = 0; i < nb; ++i) {
1726        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
1727        const uint8_t * GGML_RESTRICT qs = x[i].qs;
1728        const uint8_t * GGML_RESTRICT qh = x[i].qh;
1729        const uint16_t * GGML_RESTRICT signs = (const uint16_t *)(x[i].qs + QK_K/8);
1730        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
1731
1732        __m128i tmp1;
1733        memcpy(&aux64, x[i].scales, 8);
1734        tmp1 = __lsx_vinsgr2vr_d(tmp1, aux64, 0);
1735        tmp1 = __lsx_vinsgr2vr_d(tmp1, aux64 >> 4, 1);
1736        const __m128i scales8 = __lsx_vadd_b(__lsx_vslli_h(__lsx_vand_v(tmp1, m4), 1), m1);
1737        const __m256i scales16 = lasx_ext8_16(scales8); // 0 2 4 6 8 10 12 14 1 3 5 7 9 11 13 15
1738
1739        __m256i sumi1 = __lasx_xvldi(0);
1740        __m256i sumi2 = __lasx_xvldi(0);
1741        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
1742            const __m256i q8_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
1743            const __m256i q8_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
1744            const __m256i q2_1 = lasx_set_d(iq2s_grid[qs[3] | ((qh[ib32+0] << 2) & 0x300)],
1745                                                   iq2s_grid[qs[2] | ((qh[ib32+0] << 4) & 0x300)],
1746                                                   iq2s_grid[qs[1] | ((qh[ib32+0] << 6) & 0x300)],
1747                                                   iq2s_grid[qs[0] | ((qh[ib32+0] << 8) & 0x300)]);
1748            const __m256i q2_2 = lasx_set_d(iq2s_grid[qs[7] | ((qh[ib32+1] << 2) & 0x300)],
1749                                                   iq2s_grid[qs[6] | ((qh[ib32+1] << 4) & 0x300)],
1750                                                   iq2s_grid[qs[5] | ((qh[ib32+1] << 6) & 0x300)],
1751                                                   iq2s_grid[qs[4] | ((qh[ib32+1] << 8) & 0x300)]);
1752            qs += 8;
1753
1754            __m256i aux256 = __lasx_xvreplgr2vr_w(signs[0] | ((uint32_t) signs[1] << 16));
1755            aux256 = __lasx_xvand_v(lasx_shuffle_b(aux256,mask1), mask2);
1756            const __m256i s2_1 = __lasx_xvseq_b(aux256, mask2);
1757            const __m256i q8s_1 = __lasx_xvsub_b(__lasx_xvxor_v(s2_1, q8_1), s2_1);
1758
1759            aux256 = __lasx_xvreplgr2vr_w(signs[2] | ((uint32_t) signs[3] << 16));
1760            aux256 = __lasx_xvand_v(lasx_shuffle_b(aux256,mask1), mask2);
1761            const __m256i s2_2 = __lasx_xvseq_b(aux256, mask2);
1762            const __m256i q8s_2 = __lasx_xvsub_b(__lasx_xvxor_v(s2_2, q8_2), s2_2);
1763
1764            signs += 4;
1765
1766            const __m256i dot1  = lasx_maddubs_h(q2_1, q8s_1); // blocks 2*ib32+0, 2*ib32+1
1767            const __m256i dot2  = lasx_maddubs_h(q2_2, q8s_2); // blocks 2*ib32+2, 2*ib32+3
1768
1769            const __m256i p1 = lasx_madd_h(dot1, lasx_shuffle_b(scales16, get_scale_shuffle_k4(ib32+0)));
1770            const __m256i p2 = lasx_madd_h(dot2, lasx_shuffle_b(scales16, get_scale_shuffle_k4(ib32+1)));
1771            sumi1 = __lasx_xvadd_w(sumi1, p1);
1772            sumi2 = __lasx_xvadd_w(sumi2, p2);
1773        }
1774
1775        accumf = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accumf);
1776    }
1777
1778    *s = 0.125f * hsum_float_8(accumf);
1779
1780#else
1781    UNUSED(x);
1782    UNUSED(y);
1783    UNUSED(nb);
1784    ggml_vec_dot_iq2_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
1785#endif
1786}
1787
// Dot product of one IQ3_XXS-quantized row (vx) with one Q8_K row (vy).
//
// n   : number of elements; asserted to be a multiple of QK_K.
// s   : receives the scalar result.
// vx  : n/QK_K consecutive block_iq3_xxs.
// vy  : n/QK_K consecutive block_q8_K.
// bs, bx, by, nrc : unused here (nrc is asserted to be 1).
//
// Uses a LASX kernel when __loongarch_asx is defined, otherwise forwards to
// ggml_vec_dot_iq3_xxs_q8_K_generic.
void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(n % QK_K == 0);
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_iq3_xxs * GGML_RESTRICT x = vx;
    const block_q8_K    * GGML_RESTRICT y = vy;

    const int nb = n / QK_K;

#if defined(__loongarch_asx)

    // 8-byte sign patterns indexed by a 7-bit code (shared with IQ2_XXS).
    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;

    // Sign/scale words for two sub-blocks.
    uint32_t aux32[2];

    __m256 accumf = (__m256)__lasx_xvldi(0);
    for (int i = 0; i < nb; ++i) {
        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
        // x[i].qs holds byte grid indices first; the per-sub-block sign/scale
        // words start at offset QK_K/4.
        const uint8_t * GGML_RESTRICT q3 = x[i].qs;
        const uint8_t * GGML_RESTRICT gas = x[i].qs + QK_K/4;
        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
        __m256i sumi1 = __lasx_xvldi(0);
        __m256i sumi2 = __lasx_xvldi(0);
        // Two 32-element sub-blocks per iteration.
        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
            const __m256i q8_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
            const __m256i q8_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
            // Eight byte indices per sub-block, each selecting a 32-bit (4-value)
            // iq3xxs_grid entry; lasx_set_w lists lanes from highest to lowest.
            const __m256i q2_1 = lasx_set_w(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]],
                                                iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]);
            q3 += 8;
            const __m256i q2_2 = lasx_set_w(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]],
                                                iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]);
            q3 += 8;
            memcpy(aux32, gas, 8); gas += 8;

            // Each aux32 word: four 7-bit sign codes in bits 0..27 (one per group of
            // 8 values) and a 4-bit scale in bits 28..31.
            const __m256i s2_1 = lasx_set_d(signs64[(aux32[0] >> 21) & 127], signs64[(aux32[0] >> 14) & 127],
                                                   signs64[(aux32[0] >>  7) & 127], signs64[(aux32[0] >>  0) & 127]);
            const __m256i s2_2 = lasx_set_d(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127],
                                                   signs64[(aux32[1] >>  7) & 127], signs64[(aux32[1] >>  0) & 127]);
            // Signs are applied to the activations, not to the grid values.
            const __m256i q8s_1 = __lasx_xvsigncov_b(s2_1, q8_1);
            const __m256i q8s_2 = __lasx_xvsigncov_b(s2_2, q8_2);
            const __m256i dot1  = lasx_maddubs_h(q2_1, q8s_1);
            const __m256i dot2  = lasx_maddubs_h(q2_2, q8s_2);
            const uint16_t ls1 = aux32[0] >> 28;
            const uint16_t ls2 = aux32[1] >> 28;

            // Effective sub-block scale is 2*ls+1; lasx_madd_h widens to int32.
            const __m256i p1 = lasx_madd_h(dot1, __lasx_xvreplgr2vr_h(2*ls1+1));
            const __m256i p2 = lasx_madd_h(dot2, __lasx_xvreplgr2vr_h(2*ls2+1));
            sumi1 = __lasx_xvadd_w(sumi1, p1);
            sumi2 = __lasx_xvadd_w(sumi2, p2);
        }

        accumf = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accumf);
    }

    // The 0.25 factor presumably matches ggml_vec_dot_iq3_xxs_q8_K_generic. TODO confirm.
    *s = 0.25f * hsum_float_8(accumf);

#else
    UNUSED(x);
    UNUSED(y);
    UNUSED(nb);
    ggml_vec_dot_iq3_xxs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}
1855
1856void ggml_vec_dot_iq3_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
1857    assert(n % QK_K == 0);
1858    assert(nrc == 1);
1859    UNUSED(nrc);
1860    UNUSED(bx);
1861    UNUSED(by);
1862    UNUSED(bs);
1863
1864    const block_iq3_s * GGML_RESTRICT x = vx;
1865    const block_q8_K  * GGML_RESTRICT y = vy;
1866
1867    const int nb = n / QK_K;
1868
1869#if defined(__loongarch_asx)
1870
1871   static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
1872                                       0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
1873   };
1874
1875    static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
1876                                        0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
1877    };
1878
1879    const __m256i mask1 = __lasx_xvld((const __m256i*)k_mask1, 0);
1880    const __m256i mask2 = __lasx_xvld((const __m256i*)k_mask2, 0);
1881
1882    __m256i idx_shift = lasx_set_w(1, 2, 3, 4, 5, 6, 7, 8);
1883    const __m256i idx_mask  = __lasx_xvreplgr2vr_w(256);
1884
1885    typedef union {
1886        __m256i  vec[2];
1887        uint32_t index[16];
1888    } index_t;
1889
1890    index_t idx;
1891
1892    __m256 accumf = (__m256)__lasx_xvldi(0);
1893    for (int i = 0; i < nb; ++i) {
1894        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
1895        const uint8_t * GGML_RESTRICT qs = x[i].qs;
1896        const uint8_t * GGML_RESTRICT qh = x[i].qh;
1897        const uint16_t * GGML_RESTRICT signs = (const uint16_t *)x[i].signs;
1898        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
1899        __m256i sumi1 = __lasx_xvldi(0);
1900        __m256i sumi2 = __lasx_xvldi(0);
1901        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
1902            const __m256i q8_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
1903            const __m256i q8_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
1904            const __m256i idx_l = lasx_extu8_16(__lsx_vld(qs, 0)); qs += 16;
1905            idx.vec[0] = __lasx_xvreplgr2vr_w(qh[ib32+0]);
1906            idx.vec[1] = __lasx_xvreplgr2vr_w(qh[ib32+1]);
1907            idx.vec[0] = __lasx_xvand_v(__lasx_xvsll_w(idx.vec[0], idx_shift), idx_mask);
1908            idx.vec[1] = __lasx_xvand_v(__lasx_xvsll_w(idx.vec[1], idx_shift), idx_mask);
1909            idx.vec[0] = __lasx_xvor_v(idx.vec[0], lasx_ext16_32(lasx_extracti128(idx_l, 0)));
1910            idx.vec[1] = __lasx_xvor_v(idx.vec[1], lasx_ext16_32(lasx_extracti128(idx_l, 1)));
1911
1912            // At leat on my CPU (Ryzen 7950X), using _mm256_i32gather_epi32 is slower than _mm256_set_epi32. Strange.
1913            //const __m256i q2_1 = _mm256_i32gather_epi32((const int *)iq3s_grid, idx.vec[0], 4);
1914            //const __m256i q2_2 = _mm256_i32gather_epi32((const int *)iq3s_grid, idx.vec[1], 4);
1915            const __m256i q2_1 = lasx_set_w(
1916                    iq3s_grid[idx.index[7]], iq3s_grid[idx.index[6]], iq3s_grid[idx.index[5]], iq3s_grid[idx.index[4]],
1917                    iq3s_grid[idx.index[3]], iq3s_grid[idx.index[2]], iq3s_grid[idx.index[1]], iq3s_grid[idx.index[0]]
1918            );
1919            const __m256i q2_2 = lasx_set_w(
1920                    iq3s_grid[idx.index[15]], iq3s_grid[idx.index[14]], iq3s_grid[idx.index[13]], iq3s_grid[idx.index[12]],
1921                    iq3s_grid[idx.index[11]], iq3s_grid[idx.index[10]], iq3s_grid[idx.index[ 9]], iq3s_grid[idx.index[ 8]]
1922            );
1923
1924            __m256i aux256 = __lasx_xvreplgr2vr_w(signs[0] | (signs[1] << 16));
1925            aux256 = __lasx_xvand_v(lasx_shuffle_b(aux256,mask1), mask2);
1926            const __m256i s2_1 = __lasx_xvseq_b(aux256, mask2);
1927            const __m256i q8s_1 = __lasx_xvsub_b(__lasx_xvxor_v(s2_1, q8_1), s2_1);
1928
1929            aux256 = __lasx_xvreplgr2vr_w(signs[2] | (signs[3] << 16));
1930            aux256 = __lasx_xvand_v(lasx_shuffle_b(aux256,mask1), mask2);
1931            const __m256i s2_2 = __lasx_xvseq_b(aux256, mask2);
1932            const __m256i q8s_2 = __lasx_xvsub_b(__lasx_xvxor_v(s2_2, q8_2), s2_2);
1933
1934            signs += 4;
1935
1936            const __m256i dot1 = lasx_maddubs_h(q2_1, q8s_1);
1937            const __m256i dot2  = lasx_maddubs_h(q2_2, q8s_2);
1938            const uint16_t ls1 = x[i].scales[ib32/2] & 0xf;
1939            const uint16_t ls2 = x[i].scales[ib32/2] >>  4;
1940            const __m256i p1 = lasx_madd_h(dot1, __lasx_xvreplgr2vr_h(2*ls1+1));
1941            const __m256i p2 = lasx_madd_h(dot2, __lasx_xvreplgr2vr_h(2*ls2+1));
1942            sumi1 = __lasx_xvadd_w(sumi1, p1);
1943            sumi2 = __lasx_xvadd_w(sumi2, p2);
1944        }
1945
1946        accumf = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accumf);
1947    }
1948
1949    *s = hsum_float_8(accumf);
1950
1951#else
1952    UNUSED(x);
1953    UNUSED(y);
1954    UNUSED(nb);
1955    ggml_vec_dot_iq3_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
1956#endif
1957}
1958
#if defined(__loongarch_asx)
// Signed byte "dot pairs": for every 16-bit output lane k,
//   out[k] = x[2k]*y[2k] + x[2k+1]*y[2k+1]
// with both operands read as int8. Even and odd products are widened to
// int16 first and then combined with a plain (wrapping, non-saturating) add.
static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
    const __m256i even_products = __lasx_xvmulwev_h_b(x, y);
    return __lasx_xvadd_h(even_products, __lasx_xvmulwod_h_b(x, y));
}
#endif
1966
1967void ggml_vec_dot_iq1_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
1968    assert(n % QK_K == 0);
1969    assert(nrc == 1);
1970    UNUSED(nrc);
1971    UNUSED(bx);
1972    UNUSED(by);
1973    UNUSED(bs);
1974
1975    const block_iq1_s * GGML_RESTRICT x = vx;
1976    const block_q8_K  * GGML_RESTRICT y = vy;
1977
1978    const int nb = n / QK_K;
1979
1980#if defined(__loongarch_asx)
1981
1982    __m256 accum = (__m256)__lasx_xvldi(0);
1983    float accum1 = 0;
1984    for (int i = 0; i < nb; ++i) {
1985
1986        const int8_t   * q8 = y[i].qs;
1987        const uint8_t  * qs = x[i].qs;
1988        const uint16_t * qh = x[i].qh;
1989
1990        __m256i sumi = __lasx_xvldi(0);
1991        int sumi1 = 0;
1992        for (int ib = 0; ib < QK_K/32; ib += 2) {
1993            __m256i q1b_1 = __lasx_xvinsgr2vr_d(q1b_1, iq1s_grid[qs[0] | ((qh[ib+0] << 8) & 0x700)], 0);
1994            q1b_1 = __lasx_xvinsgr2vr_d(q1b_1, iq1s_grid[qs[1] | ((qh[ib+0] << 5) & 0x700)], 1);
1995            q1b_1 = __lasx_xvinsgr2vr_d(q1b_1, iq1s_grid[qs[2] | ((qh[ib+0] << 2) & 0x700)], 2);
1996            q1b_1 = __lasx_xvinsgr2vr_d(q1b_1, iq1s_grid[qs[3] | ((qh[ib+0] >> 1) & 0x700)], 3);
1997
1998            __m256i q1b_2 = __lasx_xvinsgr2vr_d(q1b_2, iq1s_grid[qs[4] | ((qh[ib+1] << 8) & 0x700)], 0);
1999            q1b_2 = __lasx_xvinsgr2vr_d(q1b_2, iq1s_grid[qs[5] | ((qh[ib+1] << 5) & 0x700)], 1);
2000            q1b_2 = __lasx_xvinsgr2vr_d(q1b_2, iq1s_grid[qs[6] | ((qh[ib+1] << 2) & 0x700)], 2);
2001            q1b_2 = __lasx_xvinsgr2vr_d(q1b_2, iq1s_grid[qs[7] | ((qh[ib+1] >> 1) & 0x700)], 3);
2002
2003            qs += 8;
2004            const __m256i q8b_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
2005            const __m256i q8b_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
2006
2007            const __m256i dot1 = mul_add_epi8(q1b_1, q8b_1);
2008            const __m256i dot2 = mul_add_epi8(q1b_2, q8b_2);
2009            const int16_t ls1 = 2*((qh[ib+0] >> 12) & 7) + 1;
2010            const int16_t ls2 = 2*((qh[ib+1] >> 12) & 7) + 1;
2011
2012            __m256i tmp1, tmp5, tmp6;
2013            tmp1 = __lasx_xvreplgr2vr_h(ls1);
2014            tmp5 = __lasx_xvmulwev_w_h(dot1, tmp1);
2015            tmp6 = __lasx_xvmulwod_w_h(dot1, tmp1);
2016            const __m256i p1 = __lasx_xvadd_w(tmp5, tmp6);
2017
2018            tmp1 = __lasx_xvreplgr2vr_h(ls2);
2019            tmp5 = __lasx_xvmulwev_w_h(dot2, tmp1);
2020            tmp6 = __lasx_xvmulwod_w_h(dot2, tmp1);
2021            const __m256i p2 = __lasx_xvadd_w(tmp5, tmp6);
2022
2023            sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p1, p2));
2024            sumi1 += (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]) * (qh[ib+0] & 0x8000 ? -1 : 1) * ls1
2025                   + (y[i].bsums[2*ib+2] + y[i].bsums[2*ib+3]) * (qh[ib+1] & 0x8000 ? -1 : 1) * ls2;
2026        }
2027
2028        const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
2029        accum = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), accum);
2030        accum1 += d * sumi1;
2031    }
2032
2033    *s = hsum_float_8(accum) + IQ1S_DELTA * accum1;
2034
2035#else
2036    UNUSED(x);
2037    UNUSED(y);
2038    UNUSED(nb);
2039    ggml_vec_dot_iq1_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
2040#endif
2041}
2042
2043void ggml_vec_dot_iq4_nl_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
2044    assert(nrc == 1);
2045    UNUSED(nrc);
2046    UNUSED(bx);
2047    UNUSED(by);
2048    UNUSED(bs);
2049    assert(n % QK4_NL == 0);
2050    static_assert(QK4_NL == QK8_0, "QK4_NL and QK8_0 must be the same");
2051
2052    const block_iq4_nl * GGML_RESTRICT x = vx;
2053    const block_q8_0   * GGML_RESTRICT y = vy;
2054
2055    const int nb = n / QK4_NL;
2056
2057    int ib = 0;
2058    float sumf = 0;
2059
2060#if defined (__loongarch_asx)
2061
2062    const __m128i values128 = __lsx_vld((const __m128i*)kvalues_iq4nl, 0);
2063    const __m128i m4b  = __lsx_vreplgr2vr_b(0x0f);
2064    const __m256i mone = __lasx_xvreplgr2vr_h(1);
2065
2066    __m256 accum1 = (__m256)__lasx_xvldi(0);
2067    __m256 accum2 = (__m256)__lasx_xvldi(0);
2068    for (; ib + 1 < nb; ib += 2) {
2069        const __m128i q4bits_1 = __lsx_vld((const __m128i*)x[ib + 0].qs, 0);
2070        const __m128i q4bits_2 = __lsx_vld((const __m128i*)x[ib + 1].qs, 0);
2071        const __m256i q8b_1 = __lasx_xvld((const __m256i *)y[ib + 0].qs, 0);
2072        const __m256i q8b_2 = __lasx_xvld((const __m256i *)y[ib + 1].qs, 0);
2073        const __m256i q4b_1 = lasx_insertf128(lsx_shuffle_b(values128, __lsx_vand_v(__lsx_vsrli_h(q4bits_1, 4), m4b)),
2074                                              lsx_shuffle_b(values128, __lsx_vand_v(q4bits_1, m4b)));
2075        const __m256i q4b_2 = lasx_insertf128(lsx_shuffle_b(values128, __lsx_vand_v(__lsx_vsrli_h(q4bits_2, 4), m4b)),
2076                                              lsx_shuffle_b(values128, __lsx_vand_v(q4bits_2, m4b)));
2077        const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1);
2078        const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
2079        const __m256i p_1 = lasx_madd_h(p16_1, mone);
2080        const __m256i p_2 = lasx_madd_h(p16_2, mone);
2081        accum1 = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(GGML_CPU_FP16_TO_FP32(y[ib + 0].d)*GGML_CPU_FP16_TO_FP32(x[ib + 0].d)),
2082                __lasx_xvffint_s_w(p_1), accum1);
2083        accum2 = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(GGML_CPU_FP16_TO_FP32(y[ib + 1].d)*GGML_CPU_FP16_TO_FP32(x[ib + 1].d)),
2084                __lasx_xvffint_s_w(p_2), accum2);
2085    }
2086
2087    sumf = hsum_float_8(__lasx_xvfadd_s(accum1, accum2));
2088
2089#endif
2090    for (; ib < nb; ++ib) {
2091        const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_CPU_FP16_TO_FP32(x[ib].d);
2092        int sumi1 = 0, sumi2 = 0;
2093        for (int j = 0; j < QK4_NL/2; ++j) {
2094            sumi1 += y[ib].qs[j+       0] * kvalues_iq4nl[x[ib].qs[j] & 0xf];
2095            sumi2 += y[ib].qs[j+QK4_NL/2] * kvalues_iq4nl[x[ib].qs[j] >>  4];
2096        }
2097        sumf += d * (sumi1 + sumi2);
2098    }
2099    *s = sumf;
2100}
2101
2102void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
2103    assert(nrc == 1);
2104    UNUSED(nrc);
2105    UNUSED(bx);
2106    UNUSED(by);
2107    UNUSED(bs);
2108    assert(n % QK_K == 0);
2109
2110    const block_iq4_xs * GGML_RESTRICT x = vx;
2111    const block_q8_K   * GGML_RESTRICT y = vy;
2112
2113    const int nb = n / QK_K;
2114
2115#if defined(__loongarch_asx)
2116
2117    const __m128i values128 = __lsx_vld((const __m128i*)kvalues_iq4nl, 0);
2118
2119    __m256 accum = (__m256)__lasx_xvldi(0);
2120
2121    for (int ibl = 0; ibl < nb; ++ibl) {
2122        const uint8_t * qs = x[ibl].qs;
2123        const int8_t  * q8 = y[ibl].qs;
2124        uint16_t sh = x[ibl].scales_h;
2125        __m256i sumi1 = __lasx_xvldi(0);
2126        __m256i sumi2 = __lasx_xvldi(0);
2127        for (int ib = 0; ib < QK_K/32; ib += 2) {
2128            const __m128i q4bits_1 = __lsx_vld((const __m128i*)qs, 0); qs += 16;
2129            const __m128i q4bits_2 = __lsx_vld((const __m128i*)qs, 0); qs += 16;
2130            const __m256i q8b_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
2131            const __m256i q8b_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
2132            const __m256i q4b_1 = lasx_insertf128(__lsx_vshuf_b(values128, values128, __lsx_vsrli_b(q4bits_1, 4)),
2133                                                  __lsx_vshuf_b(values128, values128, __lsx_vandi_b(q4bits_1, 0xf)));
2134            const __m256i q4b_2 = lasx_insertf128(__lsx_vshuf_b(values128, values128, __lsx_vsrli_b(q4bits_2, 4)),
2135                                                  __lsx_vshuf_b(values128, values128, __lsx_vandi_b(q4bits_2, 0xf)));
2136            const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1);
2137            const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
2138            const int16_t ls1 = ((x[ibl].scales_l[ib/2] & 0xf) | ((sh << 4) & 0x30)) - 32;
2139            const int16_t ls2 = ((x[ibl].scales_l[ib/2] >>  4) | ((sh << 2) & 0x30)) - 32;
2140            sh >>= 4;
2141            const __m256i p_1 = lasx_madd_h(p16_1, __lasx_xvreplgr2vr_h(ls1));
2142            const __m256i p_2 = lasx_madd_h(p16_2, __lasx_xvreplgr2vr_h(ls2));
2143            sumi1 = __lasx_xvadd_w(p_1, sumi1);
2144            sumi2 = __lasx_xvadd_w(p_2, sumi2);
2145        }
2146        accum = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(GGML_CPU_FP16_TO_FP32(x[ibl].d)*y[ibl].d),
2147                __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accum);
2148    }
2149
2150    *s = hsum_float_8(accum);
2151
2152#else
2153    UNUSED(x);
2154    UNUSED(y);
2155    UNUSED(nb);
2156    ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
2157#endif
2158}
2159