#define GGML_COMMON_IMPL_C
#include "ggml-common.h"
#include "ggml-quants.h"
#include "ggml-impl.h"
#include "ggml-cpu.h"
#include "simd-mappings.h"

#include "../../quants.h"
#include "../../ggml-cpu-impl.h"

#include <math.h>
#include <string.h>
#include <assert.h>
#include <stdlib.h> // for qsort
#include <stdio.h>  // for GGML_ASSERT

#define GROUP_MAX_EPS 1e-15f
#define GROUP_MAX_EPS_IQ3_XXS 1e-8f
#define GROUP_MAX_EPS_IQ2_S 1e-8f
#define GROUP_MAX_EPS_IQ1_M 1e-7f
#define GROUP_MAX_EPS_IQ1_S 1e-12f

#define UNUSED GGML_UNUSED

// some compilers don't provide _mm256_set_m128i, e.g. gcc 7
#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)

#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
// multiply int8_t, add results pairwise twice
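// note: _mm_maddubs_epi16 multiplies unsigned bytes by signed bytes, so x is made non-negative
// first and its sign is folded into y (|x| * sign(x)*y == x*y for every byte); the 16-bit pair
// sums are then widened to 32-bit with _mm_madd_epi16 against a vector of ones.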
static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
    // Get absolute values of x vectors
    const __m128i ax = _mm_sign_epi8(x, x);
    // Sign the values of the y vectors
    const __m128i sy = _mm_sign_epi8(y, x);
    // Perform multiplication and create 16-bit values
    const __m128i dot = _mm_maddubs_epi16(ax, sy);
    const __m128i ones = _mm_set1_epi16(1);
    return _mm_madd_epi16(ones, dot);
}

#if __AVX__ || __AVX2__ || __AVX512F__
// horizontally add 8 floats
static inline float hsum_float_8(const __m256 x) {
    __m128 res = _mm256_extractf128_ps(x, 1);
    res = _mm_add_ps(res, _mm256_castps256_ps128(x));
    res = _mm_add_ps(res, _mm_movehl_ps(res, res));
    res = _mm_add_ss(res, _mm_movehdup_ps(res));
    return _mm_cvtss_f32(res);
}

// horizontally add 8 int32_t
static inline int hsum_i32_8(const __m256i a) {
    const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1));
    const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);
    const __m128i sum64 = _mm_add_epi32(hi64, sum128);
    const __m128i hi32  = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
    return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
}

// horizontally add 4 int32_t
static inline int hsum_i32_4(const __m128i a) {
    const __m128i hi64 = _mm_unpackhi_epi64(a, a);
    const __m128i sum64 = _mm_add_epi32(hi64, a);
    const __m128i hi32  = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
    return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
}

#if defined(__AVX2__) || defined(__AVX512F__)
static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
    const __m256i ax = _mm256_sign_epi8(x, x);
    const __m256i sy = _mm256_sign_epi8(y, x);
    return _mm256_maddubs_epi16(ax, sy);
}

// spread 32 bits to 32 bytes { 0x00, 0xFF }
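// note: each of the 32 source bits selects one output byte. The shuffle broadcasts byte k of the
// 32-bit word into output bytes 8*k .. 8*k+7, the OR mask sets every bit except the one being
// tested, and the compare against all-ones yields 0xFF where that bit was set and 0x00 otherwise.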
static inline __m256i bytes_from_bits_32(const uint8_t * x) {
    uint32_t x32;
    memcpy(&x32, x, sizeof(uint32_t));
    const __m256i shuf_mask = _mm256_set_epi64x(
            0x0303030303030303, 0x0202020202020202,
            0x0101010101010101, 0x0000000000000000);
    __m256i bytes = _mm256_shuffle_epi8(_mm256_set1_epi32(x32), shuf_mask);
    const __m256i bit_mask = _mm256_set1_epi64x(0x7fbfdfeff7fbfdfe);
    bytes = _mm256_or_si256(bytes, bit_mask);
    return _mm256_cmpeq_epi8(bytes, _mm256_set1_epi64x(-1));
}

// Unpack 32 4-bit fields into 32 bytes
// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
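// note: source byte i contributes its low nibble to output byte i and its high nibble to output
// byte i + 16, which matches the q4_0/q4_1 block layout (all low nibbles first, then all high).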
static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
{
    const __m128i tmp = _mm_loadu_si128((const __m128i *)rsi);
    const __m256i bytes = MM256_SET_M128I(_mm_srli_epi16(tmp, 4), tmp);
    const __m256i lowMask = _mm256_set1_epi8( 0xF );
    return _mm256_and_si256(lowMask, bytes);
}

// add int16_t pairwise and return as float vector
static inline __m256 sum_i16_pairs_float(const __m256i x) {
    const __m256i ones = _mm256_set1_epi16(1);
    const __m256i summed_pairs = _mm256_madd_epi16(ones, x);
    return _mm256_cvtepi32_ps(summed_pairs);
}

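// multiply unsigned bytes (ax) by signed bytes (sy), add results pairwise twice and return as a
// float vector; with AVX-VNNI or AVX512-VNNI the multiply and the pairwise accumulation are fused
// into a single dpbusd instruction.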
static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
#if defined(__AVX512VNNI__) && defined(__AVX512VL__)
    const __m256i zero = _mm256_setzero_si256();
    const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy);
    return _mm256_cvtepi32_ps(summed_pairs);
#elif defined(__AVXVNNI__)
    const __m256i zero = _mm256_setzero_si256();
    const __m256i summed_pairs = _mm256_dpbusd_avx_epi32(zero, ax, sy);
    return _mm256_cvtepi32_ps(summed_pairs);
#else
    // Perform multiplication and create 16-bit values
    const __m256i dot = _mm256_maddubs_epi16(ax, sy);
    return sum_i16_pairs_float(dot);
#endif
}

// multiply int8_t, add results pairwise twice and return as float vector
static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
#if __AVXVNNIINT8__
    const __m256i zero = _mm256_setzero_si256();
    const __m256i summed_pairs = _mm256_dpbssd_epi32(zero, x, y);
    return _mm256_cvtepi32_ps(summed_pairs);
#else
    // Get absolute values of x vectors
    const __m256i ax = _mm256_sign_epi8(x, x);
    // Sign the values of the y vectors
    const __m256i sy = _mm256_sign_epi8(y, x);
    return mul_sum_us8_pairs_float(ax, sy);
#endif
}

static inline __m128i packNibbles( __m256i bytes )
{
    // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
#if __AVX512F__
    const __m256i bytes_srli_4 = _mm256_srli_epi16(bytes, 4);   // 0000_0000_abcd_0000
    bytes = _mm256_or_si256(bytes, bytes_srli_4);               // 0000_abcd_abcd_efgh
    return _mm256_cvtepi16_epi8(bytes);                         // abcd_efgh
#else
    const __m256i lowByte = _mm256_set1_epi16( 0xFF );
    __m256i high = _mm256_andnot_si256( lowByte, bytes );
    __m256i low = _mm256_and_si256( lowByte, bytes );
    high = _mm256_srli_epi16( high, 4 );
    bytes = _mm256_or_si256( low, high );

    // Compress uint16_t lanes into bytes
    __m128i r0 = _mm256_castsi256_si128( bytes );
    __m128i r1 = _mm256_extracti128_si256( bytes, 1 );
    return _mm_packus_epi16( r0, r1 );
#endif
}
#elif defined(__AVX__)
static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
{
    // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
    const __m128i lowByte = _mm_set1_epi16( 0xFF );
    __m128i high = _mm_andnot_si128( lowByte, bytes1 );
    __m128i low = _mm_and_si128( lowByte, bytes1 );
    high = _mm_srli_epi16( high, 4 );
    bytes1 = _mm_or_si128( low, high );
    high = _mm_andnot_si128( lowByte, bytes2 );
    low = _mm_and_si128( lowByte, bytes2 );
    high = _mm_srli_epi16( high, 4 );
    bytes2 = _mm_or_si128( low, high );

    return _mm_packus_epi16( bytes1, bytes2);
}

static inline __m128i mul_add_epi8_sse(const __m128i x, const __m128i y) {
    const __m128i ax = _mm_sign_epi8(x, x);
    const __m128i sy = _mm_sign_epi8(y, x);
    return _mm_maddubs_epi16(ax, sy);
}

// spread 32 bits to 32 bytes { 0x00, 0xFF }
static inline __m256i bytes_from_bits_32(const uint8_t * x) {
    uint32_t x32;
    memcpy(&x32, x, sizeof(uint32_t));
    const __m128i shuf_maskl = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000);
    const __m128i shuf_maskh = _mm_set_epi64x(0x0303030303030303, 0x0202020202020202);
    __m128i bytesl = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskl);
    __m128i bytesh = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskh);
    const __m128i bit_mask = _mm_set1_epi64x(0x7fbfdfeff7fbfdfe);
    bytesl = _mm_or_si128(bytesl, bit_mask);
    bytesh = _mm_or_si128(bytesh, bit_mask);
    bytesl = _mm_cmpeq_epi8(bytesl, _mm_set1_epi64x(-1));
    bytesh = _mm_cmpeq_epi8(bytesh, _mm_set1_epi64x(-1));
    return MM256_SET_M128I(bytesh, bytesl);
}

// Unpack 32 4-bit fields into 32 bytes
// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
{
    // Load 16 bytes from memory
    __m128i tmpl = _mm_loadu_si128((const __m128i *)rsi);
    __m128i tmph = _mm_srli_epi16(tmpl, 4);
    const __m128i lowMask = _mm_set1_epi8(0xF);
    tmpl = _mm_and_si128(lowMask, tmpl);
    tmph = _mm_and_si128(lowMask, tmph);
    return MM256_SET_M128I(tmph, tmpl);
}

// add int16_t pairwise and return as float vector
static inline __m256 sum_i16_pairs_float(const __m128i xh, const __m128i xl) {
    const __m128i ones = _mm_set1_epi16(1);
    const __m128i summed_pairsl = _mm_madd_epi16(ones, xl);
    const __m128i summed_pairsh = _mm_madd_epi16(ones, xh);
    const __m256i summed_pairs = MM256_SET_M128I(summed_pairsh, summed_pairsl);
    return _mm256_cvtepi32_ps(summed_pairs);
}

static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
    const __m128i axl = _mm256_castsi256_si128(ax);
    const __m128i axh = _mm256_extractf128_si256(ax, 1);
    const __m128i syl = _mm256_castsi256_si128(sy);
    const __m128i syh = _mm256_extractf128_si256(sy, 1);
    // Perform multiplication and create 16-bit values
    const __m128i dotl = _mm_maddubs_epi16(axl, syl);
    const __m128i doth = _mm_maddubs_epi16(axh, syh);
    return sum_i16_pairs_float(doth, dotl);
}

// multiply int8_t, add results pairwise twice and return as float vector
static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
    const __m128i xl = _mm256_castsi256_si128(x);
    const __m128i xh = _mm256_extractf128_si256(x, 1);
    const __m128i yl = _mm256_castsi256_si128(y);
    const __m128i yh = _mm256_extractf128_si256(y, 1);
    // Get absolute values of x vectors
    const __m128i axl = _mm_sign_epi8(xl, xl);
    const __m128i axh = _mm_sign_epi8(xh, xh);
    // Sign the values of the y vectors
    const __m128i syl = _mm_sign_epi8(yl, xl);
    const __m128i syh = _mm_sign_epi8(yh, xh);
    // Perform multiplication and create 16-bit values
    const __m128i dotl = _mm_maddubs_epi16(axl, syl);
    const __m128i doth = _mm_maddubs_epi16(axh, syh);
    return sum_i16_pairs_float(doth, dotl);
}

// larger version of mul_sum_i8_pairs_float where x and y are each represented by four 128-bit vectors
static inline __m256 mul_sum_i8_quad_float(const __m128i x_1_0, const __m128i x_1_1, const __m128i x_2_0, const __m128i x_2_1,
                                           const __m128i y_1_0, const __m128i y_1_1, const __m128i y_2_0, const __m128i y_2_1) {
    const __m128i mone = _mm_set1_epi16(1);

    const __m128i p16_1_0 = mul_add_epi8_sse(x_1_0, y_1_0);
    const __m128i p16_1_1 = mul_add_epi8_sse(x_1_1, y_1_1);
    const __m128i p16_2_0 = mul_add_epi8_sse(x_2_0, y_2_0);
    const __m128i p16_2_1 = mul_add_epi8_sse(x_2_1, y_2_1);
    const __m128i p_1_0 = _mm_madd_epi16(p16_1_0, mone);
    const __m128i p_1_1 = _mm_madd_epi16(p16_1_1, mone);
    const __m128i p_2_0 = _mm_madd_epi16(p16_2_0, mone);
    const __m128i p_2_1 = _mm_madd_epi16(p16_2_1, mone);
    const __m128i p_1 = _mm_add_epi32(p_1_0, p_1_1);
    const __m128i p_2 = _mm_add_epi32(p_2_0, p_2_1);
    return _mm256_cvtepi32_ps(MM256_SET_M128I(p_2, p_1));
}

// quad fp16 delta calculation
static inline __m256 quad_fp16_delta_float(const float x0, const float y0, const float x1, const float y1) {
    // GGML_CPU_FP16_TO_FP32 is faster than Intel F16C
    return _mm256_set_m128(_mm_set1_ps(GGML_CPU_FP16_TO_FP32(x1) * GGML_CPU_FP16_TO_FP32(y1)),
                           _mm_set1_ps(GGML_CPU_FP16_TO_FP32(x0) * GGML_CPU_FP16_TO_FP32(y0)));
}

static inline __m256 quad_mx_delta_float(const uint8_t x0, const float y0, const uint8_t x1, const float y1) {
    return _mm256_set_m128(_mm_set1_ps(GGML_CPU_E8M0_TO_FP32_HALF(x1) * GGML_CPU_FP16_TO_FP32(y1)),
                           _mm_set1_ps(GGML_CPU_E8M0_TO_FP32_HALF(x0) * GGML_CPU_FP16_TO_FP32(y0)));
}
#endif
#elif defined(__SSSE3__)
// horizontally add 4x4 floats
static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 c, const __m128 d) {
    __m128 res_0 =_mm_hadd_ps(a, b);
    __m128 res_1 =_mm_hadd_ps(c, d);
    __m128 res =_mm_hadd_ps(res_0, res_1);
    res =_mm_hadd_ps(res, res);
    res =_mm_hadd_ps(res, res);

    return _mm_cvtss_f32(res);
}
#endif // __AVX__ || __AVX2__ || __AVX512F__
#endif // defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)

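// quantize a row of floats into q8_0 blocks: for each group of 32 values the scale is
// d = max(|x|) / 127 and each element is stored as round(x / d) in int8; d is kept as fp16.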
void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
    assert(QK8_0 == 32);
    assert(k % QK8_0 == 0);
    const int nb = k / QK8_0;

    block_q8_0 * GGML_RESTRICT y = vy;

#if defined(__AVX2__) || defined(__AVX__)
    for (int i = 0; i < nb; i++) {
        // Load elements into 4 AVX vectors
        __m256 v0 = _mm256_loadu_ps( x );
        __m256 v1 = _mm256_loadu_ps( x + 8 );
        __m256 v2 = _mm256_loadu_ps( x + 16 );
        __m256 v3 = _mm256_loadu_ps( x + 24 );
        x += 32;

        // Compute max(abs(e)) for the block
        const __m256 signBit = _mm256_set1_ps( -0.0f );
        __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );

        __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
        max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
        max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
        const float maxScalar = _mm_cvtss_f32( max4 );

        // Quantize these floats
        const float d = maxScalar / 127.f;
        y[i].d = GGML_CPU_FP32_TO_FP16(d);
        const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
        const __m256 mul = _mm256_set1_ps( id );

        // Apply the multiplier
        v0 = _mm256_mul_ps( v0, mul );
        v1 = _mm256_mul_ps( v1, mul );
        v2 = _mm256_mul_ps( v2, mul );
        v3 = _mm256_mul_ps( v3, mul );

        // Round to nearest integer
        v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
        v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
        v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
        v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );

        // Convert floats to integers
        __m256i i0 = _mm256_cvtps_epi32( v0 );
        __m256i i1 = _mm256_cvtps_epi32( v1 );
        __m256i i2 = _mm256_cvtps_epi32( v2 );
        __m256i i3 = _mm256_cvtps_epi32( v3 );

#if defined(__AVX2__)
        // Convert int32 to int16
        i0 = _mm256_packs_epi32( i0, i1 );  // 0, 1, 2, 3,  8, 9, 10, 11,  4, 5, 6, 7, 12, 13, 14, 15
        i2 = _mm256_packs_epi32( i2, i3 );  // 16, 17, 18, 19,  24, 25, 26, 27,  20, 21, 22, 23, 28, 29, 30, 31
                                            // Convert int16 to int8
        i0 = _mm256_packs_epi16( i0, i2 );  // 0, 1, 2, 3,  8, 9, 10, 11,  16, 17, 18, 19,  24, 25, 26, 27,  4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31

        // We got our precious signed bytes, but the order is now wrong
        // These AVX2 pack instructions process 16-byte pieces independently
        // The following instruction fixes the order
        const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
        i0 = _mm256_permutevar8x32_epi32( i0, perm );

        _mm256_storeu_si256((__m256i *)y[i].qs, i0);
#else
        // Since AVX lacks some of the necessary instructions,
        // we split the registers in half and use the SSE equivalents of the AVX2 ones
        __m128i ni0 = _mm256_castsi256_si128( i0 );
        __m128i ni1 = _mm256_extractf128_si256( i0, 1);
        __m128i ni2 = _mm256_castsi256_si128( i1 );
        __m128i ni3 = _mm256_extractf128_si256( i1, 1);
        __m128i ni4 = _mm256_castsi256_si128( i2 );
        __m128i ni5 = _mm256_extractf128_si256( i2, 1);
        __m128i ni6 = _mm256_castsi256_si128( i3 );
        __m128i ni7 = _mm256_extractf128_si256( i3, 1);

        // Convert int32 to int16
        ni0 = _mm_packs_epi32( ni0, ni1 );
        ni2 = _mm_packs_epi32( ni2, ni3 );
        ni4 = _mm_packs_epi32( ni4, ni5 );
        ni6 = _mm_packs_epi32( ni6, ni7 );
        // Convert int16 to int8
        ni0 = _mm_packs_epi16( ni0, ni2 );
        ni4 = _mm_packs_epi16( ni4, ni6 );

        _mm_storeu_si128((__m128i *)(y[i].qs +  0), ni0);
        _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
#endif
    }
#else
    GGML_UNUSED(nb);
    // scalar
    quantize_row_q8_0_ref(x, y, k);
#endif
}

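// same as quantize_row_q8_0, but q8_1 blocks additionally store s = d * sum(q), which the
// q4_1/q5_1 dot products use to fold in the block minimum without revisiting the quants.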
void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
    assert(k % QK8_1 == 0);
    const int nb = k / QK8_1;

    block_q8_1 * GGML_RESTRICT y = vy;
#if defined(__AVX2__) || defined(__AVX__)
    for (int i = 0; i < nb; i++) {
        // Load elements into 4 AVX vectors
        __m256 v0 = _mm256_loadu_ps( x );
        __m256 v1 = _mm256_loadu_ps( x + 8 );
        __m256 v2 = _mm256_loadu_ps( x + 16 );
        __m256 v3 = _mm256_loadu_ps( x + 24 );
        x += 32;

        // Compute max(abs(e)) for the block
        const __m256 signBit = _mm256_set1_ps( -0.0f );
        __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );

        __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
        max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
        max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
        const float max_scalar = _mm_cvtss_f32( max4 );

        // Quantize these floats
        const float d = max_scalar / 127.f;
        y[i].d = GGML_CPU_FP32_TO_FP16(d);
        const float id = ( max_scalar != 0.0f ) ? 127.f / max_scalar : 0.0f;
        const __m256 mul = _mm256_set1_ps( id );

        // Apply the multiplier
        v0 = _mm256_mul_ps( v0, mul );
        v1 = _mm256_mul_ps( v1, mul );
        v2 = _mm256_mul_ps( v2, mul );
        v3 = _mm256_mul_ps( v3, mul );

        // Round to nearest integer
        v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
        v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
        v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
        v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );

        // Convert floats to integers
        __m256i i0 = _mm256_cvtps_epi32( v0 );
        __m256i i1 = _mm256_cvtps_epi32( v1 );
        __m256i i2 = _mm256_cvtps_epi32( v2 );
        __m256i i3 = _mm256_cvtps_epi32( v3 );

#if defined(__AVX2__)
        // Compute the sum of the quants and set y[i].s
        y[i].s = GGML_CPU_FP32_TO_FP16(d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3))));

        // Convert int32 to int16
        i0 = _mm256_packs_epi32( i0, i1 );  // 0, 1, 2, 3,  8, 9, 10, 11,  4, 5, 6, 7, 12, 13, 14, 15
        i2 = _mm256_packs_epi32( i2, i3 );  // 16, 17, 18, 19,  24, 25, 26, 27,  20, 21, 22, 23, 28, 29, 30, 31
                                            // Convert int16 to int8
        i0 = _mm256_packs_epi16( i0, i2 );  // 0, 1, 2, 3,  8, 9, 10, 11,  16, 17, 18, 19,  24, 25, 26, 27,  4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31

        // We got our precious signed bytes, but the order is now wrong
        // These AVX2 pack instructions process 16-byte pieces independently
        // The following instruction fixes the order
        const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
        i0 = _mm256_permutevar8x32_epi32( i0, perm );

        _mm256_storeu_si256((__m256i *)y[i].qs, i0);
#else
        // Since AVX lacks some of the necessary instructions,
        // we split the registers in half and use the SSE equivalents of the AVX2 ones
        __m128i ni0 = _mm256_castsi256_si128( i0 );
        __m128i ni1 = _mm256_extractf128_si256( i0, 1);
        __m128i ni2 = _mm256_castsi256_si128( i1 );
        __m128i ni3 = _mm256_extractf128_si256( i1, 1);
        __m128i ni4 = _mm256_castsi256_si128( i2 );
        __m128i ni5 = _mm256_extractf128_si256( i2, 1);
        __m128i ni6 = _mm256_castsi256_si128( i3 );
        __m128i ni7 = _mm256_extractf128_si256( i3, 1);

        // Compute the sum of the quants and set y[i].s
        const __m128i s0 = _mm_add_epi32(_mm_add_epi32(ni0, ni1), _mm_add_epi32(ni2, ni3));
        const __m128i s1 = _mm_add_epi32(_mm_add_epi32(ni4, ni5), _mm_add_epi32(ni6, ni7));
        y[i].s = GGML_CPU_FP32_TO_FP16(d * hsum_i32_4(_mm_add_epi32(s0, s1)));

        // Convert int32 to int16
        ni0 = _mm_packs_epi32( ni0, ni1 );
        ni2 = _mm_packs_epi32( ni2, ni3 );
        ni4 = _mm_packs_epi32( ni4, ni5 );
        ni6 = _mm_packs_epi32( ni6, ni7 );
        // Convert int16 to int8
        ni0 = _mm_packs_epi16( ni0, ni2 );
        ni4 = _mm_packs_epi16( ni4, ni6 );

        _mm_storeu_si128((__m128i *)(y[i].qs +  0), ni0);
        _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
#endif
    }
#else
    GGML_UNUSED(nb);
    // scalar
    quantize_row_q8_1_ref(x, y, k);
#endif
}

// placeholder - fall back to the reference (scalar) implementation
void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {
    quantize_row_q8_K_ref(x, y, k);
}

//===================================== Dot products =================================

//
// Helper functions
//

#if __AVX__ || __AVX2__ || __AVX512F__

// shuffles to pick the required scales in dot products
static inline __m256i get_scale_shuffle_q3k(int i) {
    static const uint8_t k_shuffle[128] = {
         0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,     2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
         4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5,     6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
         8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9,    10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,
        12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,    14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,
    };
    return _mm256_loadu_si256((const __m256i*)k_shuffle + i);
}
static inline __m256i get_scale_shuffle_k4(int i) {
    static const uint8_t k_shuffle[256] = {
         0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,
         2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
         4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5,
         6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
         8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9,
        10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,
        12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,
        14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15
    };
    return _mm256_loadu_si256((const __m256i*)k_shuffle + i);
}
static inline __m128i get_scale_shuffle(int i) {
    static const uint8_t k_shuffle[128] = {
         0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1,
         2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3,
         4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5,
         6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7,
         8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9,
        10,10,10,10,10,10,10,10, 11,11,11,11,11,11,11,11,
        12,12,12,12,12,12,12,12, 13,13,13,13,13,13,13,13,
        14,14,14,14,14,14,14,14, 15,15,15,15,15,15,15,15
    };
    return _mm_loadu_si128((const __m128i*)k_shuffle + i);
}
#endif

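// dot product of a q4_0 row with a q8_0 row: per 32-element block this computes
// d_x * d_y * sum_j (q4[j] - 8) * q8[j], accumulating the result in single precision.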
void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    const int qk = QK8_0;
    const int nb = n / qk;

    assert(n % qk == 0);
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_q4_0 * GGML_RESTRICT x = vx;
    const block_q8_0 * GGML_RESTRICT y = vy;

    int ib = 0;
    float sumf = 0;

#if defined(__AVX2__)
    // Initialize accumulator with zeros
    __m256 acc = _mm256_setzero_ps();

    // Main loop
    for (; ib < nb; ++ib) {
        /* Compute combined scale for the block */
        const __m256 d = _mm256_set1_ps( GGML_CPU_FP16_TO_FP32(x[ib].d) * GGML_CPU_FP16_TO_FP32(y[ib].d) );

        __m256i qx = bytes_from_nibbles_32(x[ib].qs);

        // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
        const __m256i off = _mm256_set1_epi8( 8 );
        qx = _mm256_sub_epi8( qx, off );

        __m256i qy = _mm256_loadu_si256((const __m256i *)y[ib].qs);

        const __m256 q = mul_sum_i8_pairs_float(qx, qy);

        /* Multiply q with scale and accumulate */
        acc = _mm256_fmadd_ps( d, q, acc );
    }

    sumf = hsum_float_8(acc);
#elif defined(__AVX__)
    __m256 accum = _mm256_setzero_ps();
    for (; ib + 1 < nb; ib += 2) {
        const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)x[ib + 0].qs);
        const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs);
        const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs);
        const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs + 1);
        const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs);
        const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs + 1);

        const __m128i q4b_1_0 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), q4bits_1), _mm_set1_epi8(8));
        const __m128i q4b_1_1 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(q4bits_1, 4)), _mm_set1_epi8(8));
        const __m128i q4b_2_0 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), q4bits_2), _mm_set1_epi8(8));
        const __m128i q4b_2_1 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(q4bits_2, 4)), _mm_set1_epi8(8));

        const __m128i p16_1_0 = mul_add_epi8_sse(q4b_1_0, q8b_1_0);
        const __m128i p16_1_1 = mul_add_epi8_sse(q4b_1_1, q8b_1_1);
        const __m128i p16_2_0 = mul_add_epi8_sse(q4b_2_0, q8b_2_0);
        const __m128i p16_2_1 = mul_add_epi8_sse(q4b_2_1, q8b_2_1);
        const __m128i p_1 = _mm_add_epi16(p16_1_0, p16_1_1);
        const __m128i p_2 = _mm_add_epi16(p16_2_0, p16_2_1);
        const __m256 p = sum_i16_pairs_float(p_2, p_1);

        const __m256 deltas = quad_fp16_delta_float(x[ib].d, y[ib].d, x[ib + 1].d, y[ib + 1].d);
        accum = _mm256_add_ps(_mm256_mul_ps(deltas, p), accum);
    }

    sumf = hsum_float_8(accum);
#elif defined(__SSSE3__)
    // set constants
    const __m128i lowMask = _mm_set1_epi8(0xF);
    const __m128i off = _mm_set1_epi8(8);

    // Initialize accumulator with zeros
    __m128 acc_0 = _mm_setzero_ps();
    __m128 acc_1 = _mm_setzero_ps();
    __m128 acc_2 = _mm_setzero_ps();
    __m128 acc_3 = _mm_setzero_ps();

    for (; ib + 1 < nb; ib += 2) {
        _mm_prefetch(&x[ib] + sizeof(block_q4_0), _MM_HINT_T0);
        _mm_prefetch(&y[ib] + sizeof(block_q8_0), _MM_HINT_T0);

        // Compute combined scale for the block 0 and 1
        const __m128 d_0_1 = _mm_set1_ps( GGML_CPU_FP16_TO_FP32(x[ib].d) * GGML_CPU_FP16_TO_FP32(y[ib].d) );

        const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[ib].qs);

        __m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1);
        __m128i by_0 = _mm_loadu_si128((const __m128i *)y[ib].qs);
        bx_0 = _mm_sub_epi8(bx_0, off);
        const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0);

        __m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4));
        __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[ib].qs + 16));
        bx_1 = _mm_sub_epi8(bx_1, off);
        const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1);

        _mm_prefetch(&x[ib] + 2 * sizeof(block_q4_0), _MM_HINT_T0);
        _mm_prefetch(&y[ib] + 2 * sizeof(block_q8_0), _MM_HINT_T0);

        // Compute combined scale for the block 2 and 3
        const __m128 d_2_3 = _mm_set1_ps( GGML_CPU_FP16_TO_FP32(x[ib + 1].d) * GGML_CPU_FP16_TO_FP32(y[ib + 1].d) );

        const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs);

        __m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3);
        __m128i by_2 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs);
        bx_2 = _mm_sub_epi8(bx_2, off);
        const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2);

        __m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4));
        __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[ib + 1].qs + 16));
        bx_3 = _mm_sub_epi8(bx_3, off);
        const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3);

        // Convert int32_t to float
        __m128 p0 = _mm_cvtepi32_ps(i32_0);
        __m128 p1 = _mm_cvtepi32_ps(i32_1);
        __m128 p2 = _mm_cvtepi32_ps(i32_2);
        __m128 p3 = _mm_cvtepi32_ps(i32_3);

        // Apply the scale
        __m128 p0_d = _mm_mul_ps( d_0_1, p0 );
        __m128 p1_d = _mm_mul_ps( d_0_1, p1 );
        __m128 p2_d = _mm_mul_ps( d_2_3, p2 );
        __m128 p3_d = _mm_mul_ps( d_2_3, p3 );

        // Accumulate
        acc_0 = _mm_add_ps(p0_d, acc_0);
        acc_1 = _mm_add_ps(p1_d, acc_1);
        acc_2 = _mm_add_ps(p2_d, acc_2);
        acc_3 = _mm_add_ps(p3_d, acc_3);
    }

    sumf = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3);

#endif
    for (; ib < nb; ++ib) {
        int sumi0 = 0;
        int sumi1 = 0;

        for (int j = 0; j < qk/2; ++j) {
            const int v0 = (x[ib].qs[j] & 0x0F) - 8;
            const int v1 = (x[ib].qs[j] >>   4) - 8;

            sumi0 += (v0 * y[ib].qs[j]);
            sumi1 += (v1 * y[ib].qs[j + qk/2]);
        }

        int sumi = sumi0 + sumi1;
        sumf += sumi*GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d);
    }

    *s = sumf;
}

void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    const int qk = QK8_1;
    const int nb = n / qk;

    assert(n % qk == 0);
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_q4_1 * GGML_RESTRICT x = vx;
    const block_q8_1 * GGML_RESTRICT y = vy;

    int ib = 0;

#if defined(__AVX2__) || defined(__AVX__)
    // Initialize accumulator with zeros
    __m256 acc = _mm256_setzero_ps();

    float summs = 0;

    // Main loop
    for (; ib < nb; ++ib) {
        const float d0 = GGML_CPU_FP16_TO_FP32(x[ib].d);
        const float d1 = GGML_CPU_FP16_TO_FP32(y[ib].d);

        summs += GGML_CPU_FP16_TO_FP32(x[ib].m) * GGML_CPU_FP16_TO_FP32(y[ib].s);

        const __m256 d0v = _mm256_set1_ps( d0 );
        const __m256 d1v = _mm256_set1_ps( d1 );

        // Compute combined scales
        const __m256 d0d1 = _mm256_mul_ps( d0v, d1v );

        // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
        const __m256i qx = bytes_from_nibbles_32(x[ib].qs);
        const __m256i qy = _mm256_loadu_si256( (const __m256i *)y[ib].qs );

        const __m256 xy = mul_sum_us8_pairs_float(qx, qy);

        // Accumulate d0*d1*x*y
#if defined(__AVX2__)
        acc = _mm256_fmadd_ps( d0d1, xy, acc );
#else
        acc = _mm256_add_ps( _mm256_mul_ps( d0d1, xy ), acc );
#endif
    }

    *s = hsum_float_8(acc) + summs;
#else
    UNUSED(nb);
    UNUSED(x);
    UNUSED(y);
    UNUSED(ib);
    ggml_vec_dot_q4_1_q8_1_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}

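// dot product of an mxfp4 row with a q8_0 row: the 4-bit codes index the kvalues_mxfp4 lookup
// table via pshufb, and each 32-element block is scaled by
// GGML_CPU_E8M0_TO_FP32_HALF(x[ib].e) * GGML_CPU_FP16_TO_FP32(y[ib].d).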
void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);
    assert(n % QK_MXFP4 == 0);
    static_assert(QK_MXFP4 == QK8_0, "QK_MXFP4 and QK8_0 must be the same");

    const block_mxfp4 * GGML_RESTRICT x = vx;
    const block_q8_0 * GGML_RESTRICT y = vy;

    const int nb = n / QK_MXFP4;

    int ib = 0;
    float sumf = 0;

#if defined __AVX2__

    const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_mxfp4);
    const __m128i m4b  = _mm_set1_epi8(0x0f);
    const __m256i mone = _mm256_set1_epi16(1);

    __m256 accum1 = _mm256_setzero_ps();
    __m256 accum2 = _mm256_setzero_ps();

    for (; ib + 1 < nb; ib += 2) {
        const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)x[ib + 0].qs);
        const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)x[ib + 1].qs);
        const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)y[ib + 0].qs);
        const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)y[ib + 1].qs);
        const __m256i q4b_1 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)),
                                              _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b)));
        const __m256i q4b_2 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)),
                                              _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b)));
        const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1);
        const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
        const __m256i p_1 = _mm256_madd_epi16(p16_1, mone);
        const __m256i p_2 = _mm256_madd_epi16(p16_2, mone);
        const __m256 scale0 = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 0].d)*GGML_CPU_E8M0_TO_FP32_HALF(x[ib + 0].e));
        const __m256 scale1 = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 1].d)*GGML_CPU_E8M0_TO_FP32_HALF(x[ib + 1].e));
        accum1 = _mm256_fmadd_ps(scale0, _mm256_cvtepi32_ps(p_1), accum1);
        accum2 = _mm256_fmadd_ps(scale1, _mm256_cvtepi32_ps(p_2), accum2);
    }

    sumf = hsum_float_8(_mm256_add_ps(accum1, accum2));

#elif defined __AVX__
    const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_mxfp4);
    const __m128i m4b  = _mm_set1_epi8(0x0f);

    __m256 accum = _mm256_setzero_ps();
    for (; ib + 1 < nb; ib += 2) {
        const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)x[ib + 0].qs);
        const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs);
        const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs);
        const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs + 1);
        const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs);
        const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs + 1);

        const __m128i q4b_1_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b));
        const __m128i q4b_1_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b));
        const __m128i q4b_2_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b));
        const __m128i q4b_2_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b));

        const __m256 p = mul_sum_i8_quad_float(q4b_1_0, q4b_1_1, q4b_2_0, q4b_2_1, q8b_1_0, q8b_1_1, q8b_2_0, q8b_2_1);
        const __m256 deltas = quad_mx_delta_float(x[ib].e, y[ib].d, x[ib + 1].e, y[ib + 1].d);
        accum = _mm256_add_ps(_mm256_mul_ps(deltas, p), accum);
    }

    sumf = hsum_float_8(accum);

#endif
    for (; ib < nb; ++ib) {
        const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_CPU_E8M0_TO_FP32_HALF(x[ib].e);
        int sumi1 = 0;
        int sumi2 = 0;
        for (int j = 0; j < QK_MXFP4/2; ++j) {
            sumi1 += y[ib].qs[j +          0] * kvalues_mxfp4[x[ib].qs[j] & 0xf];
            sumi2 += y[ib].qs[j + QK_MXFP4/2] * kvalues_mxfp4[x[ib].qs[j] >>  4];
        }
        sumf += d * (sumi1 + sumi2);
    }
    *s = sumf;
}

void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    const int qk = QK8_0;
    const int nb = n / qk;

    int ib = 0;

    assert(n % qk == 0);
    assert(qk == QK5_0);
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_q5_0 * GGML_RESTRICT x = vx;
    const block_q8_0 * GGML_RESTRICT y = vy;

#if defined(__AVX2__)
    // Initialize accumulator with zeros
    __m256 acc = _mm256_setzero_ps();

    // Main loop
    for (; ib < nb; ++ib) {
        /* Compute combined scale for the block */
        const __m256 d = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(x[ib].d) * GGML_CPU_FP16_TO_FP32(y[ib].d));

        __m256i qx = bytes_from_nibbles_32(x[ib].qs);
        __m256i bxhi = bytes_from_bits_32(x[ib].qh);
        bxhi = _mm256_andnot_si256(bxhi, _mm256_set1_epi8((char)0xF0));
        qx = _mm256_or_si256(qx, bxhi);

        __m256i qy = _mm256_loadu_si256((const __m256i *)y[ib].qs);

        const __m256 q = mul_sum_i8_pairs_float(qx, qy);

        /* Multiply q with scale and accumulate */
        acc = _mm256_fmadd_ps(d, q, acc);
    }

    *s = hsum_float_8(acc);
#elif defined(__AVX__)
    // Initialize accumulator with zeros
    __m256 acc = _mm256_setzero_ps();
    __m128i mask = _mm_set1_epi8((char)0xF0);

    // Main loop
    for (; ib < nb; ++ib) {
        /* Compute combined scale for the block */
        const __m256 d = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(x[ib].d) * GGML_CPU_FP16_TO_FP32(y[ib].d));

        __m256i bx_0 = bytes_from_nibbles_32(x[ib].qs);
        const __m256i bxhi = bytes_from_bits_32(x[ib].qh);
        __m128i bxhil = _mm256_castsi256_si128(bxhi);
        __m128i bxhih = _mm256_extractf128_si256(bxhi, 1);
        bxhil = _mm_andnot_si128(bxhil, mask);
        bxhih = _mm_andnot_si128(bxhih, mask);
        __m128i bxl = _mm256_castsi256_si128(bx_0);
        __m128i bxh = _mm256_extractf128_si256(bx_0, 1);
        bxl = _mm_or_si128(bxl, bxhil);
        bxh = _mm_or_si128(bxh, bxhih);
        bx_0 = MM256_SET_M128I(bxh, bxl);

        const __m256i by_0 = _mm256_loadu_si256((const __m256i *)y[ib].qs);

        const __m256 q = mul_sum_i8_pairs_float(bx_0, by_0);

        /* Multiply q with scale and accumulate */
        acc = _mm256_add_ps(_mm256_mul_ps(d, q), acc);
    }

    *s = hsum_float_8(acc);
#else
    UNUSED(nb);
    UNUSED(ib);
    UNUSED(x);
    UNUSED(y);
    ggml_vec_dot_q5_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}

void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    const int qk = QK8_1;
    const int nb = n / qk;

    int ib = 0;

    assert(n % qk == 0);
    assert(qk == QK5_1);
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_q5_1 * GGML_RESTRICT x = vx;
    const block_q8_1 * GGML_RESTRICT y = vy;

#if defined(__AVX2__)
    // Initialize accumulator with zeros
    __m256 acc = _mm256_setzero_ps();

    float summs = 0.0f;

    // Main loop
    for (; ib < nb; ++ib) {
        const __m256 dx = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(x[ib].d));

        summs += GGML_CPU_FP16_TO_FP32(x[ib].m) * GGML_CPU_FP16_TO_FP32(y[ib].s);

        __m256i qx = bytes_from_nibbles_32(x[ib].qs);
        __m256i bxhi = bytes_from_bits_32(x[ib].qh);
        bxhi = _mm256_and_si256(bxhi, _mm256_set1_epi8(0x10));
        qx = _mm256_or_si256(qx, bxhi);

        const __m256 dy = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib].d));
        const __m256i qy = _mm256_loadu_si256((const __m256i *)y[ib].qs);

        const __m256 q = mul_sum_us8_pairs_float(qx, qy);

        acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc);
    }

    *s = hsum_float_8(acc) + summs;
#elif defined(__AVX__)
    // Initialize accumulator with zeros
    __m256 acc = _mm256_setzero_ps();
    __m128i mask = _mm_set1_epi8(0x10);

    float summs = 0.0f;

    // Main loop
    for (; ib < nb; ++ib) {
        const __m256 dx = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(x[ib].d));

        summs += GGML_CPU_FP16_TO_FP32(x[ib].m) * GGML_CPU_FP16_TO_FP32(y[ib].s);

        __m256i bx_0 = bytes_from_nibbles_32(x[ib].qs);
        const __m256i bxhi = bytes_from_bits_32(x[ib].qh);
        __m128i bxhil = _mm256_castsi256_si128(bxhi);
        __m128i bxhih = _mm256_extractf128_si256(bxhi, 1);
        bxhil = _mm_and_si128(bxhil, mask);
        bxhih = _mm_and_si128(bxhih, mask);
        __m128i bxl = _mm256_castsi256_si128(bx_0);
        __m128i bxh = _mm256_extractf128_si256(bx_0, 1);
        bxl = _mm_or_si128(bxl, bxhil);
        bxh = _mm_or_si128(bxh, bxhih);
        bx_0 = MM256_SET_M128I(bxh, bxl);

        const __m256 dy = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib].d));
        const __m256i by_0 = _mm256_loadu_si256((const __m256i *)y[ib].qs);

        const __m256 q = mul_sum_us8_pairs_float(bx_0, by_0);

        acc = _mm256_add_ps(_mm256_mul_ps(q, _mm256_mul_ps(dx, dy)), acc);
    }

    *s = hsum_float_8(acc) + summs;
#else
    UNUSED(nb);
    UNUSED(ib);
    UNUSED(x);
    UNUSED(y);
    ggml_vec_dot_q5_1_q8_1_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}

void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    const int qk = QK8_0;
    const int nb = n / qk;

    assert(n % qk == 0);
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_q8_0 * GGML_RESTRICT x = vx;
    const block_q8_0 * GGML_RESTRICT y = vy;

    int ib = 0;
    float sumf = 0;

#if defined(__AVX2__)
    // Initialize accumulator with zeros
    __m256 acc = _mm256_setzero_ps();

    // Main loop
    for (; ib < nb; ++ib) {
        // Compute combined scale for the block
        const __m256 d = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(x[ib].d) * GGML_CPU_FP16_TO_FP32(y[ib].d));
        __m256i qx = _mm256_loadu_si256((const __m256i *)x[ib].qs);
        __m256i qy = _mm256_loadu_si256((const __m256i *)y[ib].qs);

        const __m256 q = mul_sum_i8_pairs_float(qx, qy);

        // Multiply q with scale and accumulate
        acc = _mm256_fmadd_ps( d, q, acc );
    }

    sumf = hsum_float_8(acc);
#elif defined(__AVX__)
    __m256 accum = _mm256_setzero_ps();

    for (; ib + 1 < nb; ib += 2) {
        const __m128i qx_1_0 = _mm_loadu_si128((const __m128i *)x[ib].qs);
        const __m128i qx_1_1 = _mm_loadu_si128((const __m128i *)x[ib].qs + 1);
        const __m128i qx_2_0 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs);
        const __m128i qx_2_1 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs + 1);
        const __m128i qy_1_0 = _mm_loadu_si128((const __m128i *)y[ib].qs);
        const __m128i qy_1_1 = _mm_loadu_si128((const __m128i *)y[ib].qs + 1);
        const __m128i qy_2_0 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs);
        const __m128i qy_2_1 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs + 1);

        const __m256 p = mul_sum_i8_quad_float(qx_1_0, qx_1_1, qx_2_0, qx_2_1, qy_1_0, qy_1_1, qy_2_0, qy_2_1);
        const __m256 deltas = quad_fp16_delta_float(x[ib].d, y[ib].d, x[ib + 1].d, y[ib + 1].d);
        accum = _mm256_add_ps(_mm256_mul_ps(deltas, p), accum);
    }

    sumf = hsum_float_8(accum);
#endif
    for (; ib < nb; ++ib) {
        int sumi = 0;

        for (int j = 0; j < qk; j++) {
            sumi += x[ib].qs[j]*y[ib].qs[j];
        }

        sumf += sumi*(GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d));
    }

    *s = sumf;
}

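// dot product of a tq1_0 (ternary, ~1.6 bpw) row with a q8_K row. Each byte of x packs several
// base-3 digits; multiplying the byte by 1, 3, 9, 27 (and 81) brings the digit of interest towards
// the top, and the paired _mm256_avg_epu8 calls approximate a multiply-by-3 so that bits 6..7 hold
// that digit as a value in {0, 1, 2}. The stored digits carry a +1 bias, which the subtraction of
// y[i].bsums below removes, effectively mapping them back to {-1, 0, +1}.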
void ggml_vec_dot_tq1_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_tq1_0 * GGML_RESTRICT x = vx;
    const block_q8_K  * GGML_RESTRICT y = vy;

    const int nb = n / QK_K;

#if defined(__AVX2__)
    __m256 sumf = _mm256_setzero_ps();

    for (int i = 0; i < nb; ++i) {
        // 16-bit sums
        __m256i sumi0 = _mm256_setzero_si256();
        __m256i sumi1 = _mm256_setzero_si256();
        __m256i sumi2 = _mm256_setzero_si256();

        // first 32 bytes, each packing 5 elements
        {
            __m256i qx0 = _mm256_loadu_si256((const __m256i *) (x[i].qs));
            // 8-bit multiplies with shifts, masks and adds
            __m256i qx1 = _mm256_add_epi8(qx0, _mm256_add_epi8(qx0, qx0)); // 1 * 3
            __m256i qx2 = _mm256_add_epi8(_mm256_and_si256(_mm256_slli_epi16(qx0, 3), _mm256_set1_epi8(-8)), qx0); // 1 * 9
            __m256i qx3 = _mm256_add_epi8(_mm256_and_si256(_mm256_slli_epi16(qx1, 3), _mm256_set1_epi8(-8)), qx1); // 3 * 9
            __m256i qx4 = _mm256_add_epi8(_mm256_and_si256(_mm256_slli_epi16(qx2, 3), _mm256_set1_epi8(-8)), qx2); // 9 * 9

            // TODO: could _mm256_mulhi_epu16 be faster even though it is 16-bit?

            // Cancel the +1 from avg so that it behaves like a halving add
            qx0 = _mm256_subs_epu8(qx0, _mm256_set1_epi8(1));
            qx1 = _mm256_subs_epu8(qx1, _mm256_set1_epi8(1));
            qx2 = _mm256_subs_epu8(qx2, _mm256_set1_epi8(1));
            qx3 = _mm256_subs_epu8(qx3, _mm256_set1_epi8(1));
            qx4 = _mm256_subs_epu8(qx4, _mm256_set1_epi8(1));
            // Multiply by 3 and get the top 2 bits
            qx0 = _mm256_avg_epu8(qx0, _mm256_avg_epu8(qx0, _mm256_setzero_si256()));
            qx1 = _mm256_avg_epu8(qx1, _mm256_avg_epu8(qx1, _mm256_setzero_si256()));
            qx2 = _mm256_avg_epu8(qx2, _mm256_avg_epu8(qx2, _mm256_setzero_si256()));
            qx3 = _mm256_avg_epu8(qx3, _mm256_avg_epu8(qx3, _mm256_setzero_si256()));
            qx4 = _mm256_avg_epu8(qx4, _mm256_avg_epu8(qx4, _mm256_setzero_si256()));
            qx0 = _mm256_and_si256(_mm256_srli_epi16(qx0, 6), _mm256_set1_epi8(3));
            qx1 = _mm256_and_si256(_mm256_srli_epi16(qx1, 6), _mm256_set1_epi8(3));
            qx2 = _mm256_and_si256(_mm256_srli_epi16(qx2, 6), _mm256_set1_epi8(3));
            qx3 = _mm256_and_si256(_mm256_srli_epi16(qx3, 6), _mm256_set1_epi8(3));
            qx4 = _mm256_and_si256(_mm256_srli_epi16(qx4, 6), _mm256_set1_epi8(3));

            const __m256i qy0 = _mm256_loadu_si256((const __m256i *) (y[i].qs +   0));
            const __m256i qy1 = _mm256_loadu_si256((const __m256i *) (y[i].qs +  32));
            const __m256i qy2 = _mm256_loadu_si256((const __m256i *) (y[i].qs +  64));
            const __m256i qy3 = _mm256_loadu_si256((const __m256i *) (y[i].qs +  96));
            const __m256i qy4 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 128));

            qx0 = _mm256_maddubs_epi16(qx0, qy0);
            qx1 = _mm256_maddubs_epi16(qx1, qy1);
            qx2 = _mm256_maddubs_epi16(qx2, qy2);
            qx3 = _mm256_maddubs_epi16(qx3, qy3);
            qx4 = _mm256_maddubs_epi16(qx4, qy4);

            sumi0 = _mm256_add_epi16(sumi0, _mm256_add_epi16(qx0, qx1));
            sumi1 = _mm256_add_epi16(sumi1, _mm256_add_epi16(qx2, qx3));
            sumi2 = _mm256_add_epi16(sumi2, qx4);
        }

        // last 16 bytes of the 5-elements-per-byte packing, along with the 4 qh bytes that pack 4 elements each
        {
            __m128i qx0 = _mm_loadu_si128((const __m128i *) (x[i].qs + 32));
            uint32_t qh;
            memcpy(&qh, x[i].qh, sizeof(qh)); // potentially unaligned
            __m256i qx5_l = _mm256_cvtepu8_epi16(_mm_set1_epi32(qh));
            __m128i qx1 = _mm_add_epi8(qx0, _mm_add_epi8(qx0, qx0)); // 1 * 3
            __m128i qx2 = _mm_add_epi8(_mm_and_si128(_mm_slli_epi16(qx0, 3), _mm_set1_epi8(-8)), qx0); // 1 * 9
            __m128i qx3 = _mm_add_epi8(_mm_and_si128(_mm_slli_epi16(qx1, 3), _mm_set1_epi8(-8)), qx1); // 3 * 9
            __m128i qx4 = _mm_add_epi8(_mm_and_si128(_mm_slli_epi16(qx2, 3), _mm_set1_epi8(-8)), qx2); // 9 * 9
            __m256i qx01 = MM256_SET_M128I(qx1, qx0);
            __m256i qx23 = MM256_SET_M128I(qx3, qx2);

            // avx2 does not have 8-bit multiplies, so 16-bit it is.
            qx5_l = _mm256_mullo_epi16(qx5_l, _mm256_set_epi16(27, 27, 27, 27, 9, 9, 9, 9, 3, 3, 3, 3, 1, 1, 1, 1));
            qx5_l = _mm256_and_si256(qx5_l, _mm256_set1_epi16(0xFF));
            __m128i qx5 = _mm_packus_epi16(_mm256_castsi256_si128(qx5_l), _mm256_extracti128_si256(qx5_l, 1));

            __m256i qx45 = MM256_SET_M128I(qx5, qx4);

            // Cancel the +1 from avg so that it behaves like a halving add
            qx01 = _mm256_subs_epu8(qx01, _mm256_set1_epi8(1));
            qx23 = _mm256_subs_epu8(qx23, _mm256_set1_epi8(1));
            qx45 = _mm256_subs_epu8(qx45, _mm256_set1_epi8(1));
            // Multiply by 3 and get the top 2 bits
            qx01 = _mm256_avg_epu8(qx01, _mm256_avg_epu8(qx01, _mm256_setzero_si256()));
            qx23 = _mm256_avg_epu8(qx23, _mm256_avg_epu8(qx23, _mm256_setzero_si256()));
            qx45 = _mm256_avg_epu8(qx45, _mm256_avg_epu8(qx45, _mm256_setzero_si256()));
            qx01 = _mm256_and_si256(_mm256_srli_epi16(qx01, 6), _mm256_set1_epi8(3));
            qx23 = _mm256_and_si256(_mm256_srli_epi16(qx23, 6), _mm256_set1_epi8(3));
            qx45 = _mm256_and_si256(_mm256_srli_epi16(qx45, 6), _mm256_set1_epi8(3));

            const __m256i qy01 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 160));
            const __m256i qy23 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 192));
            const __m256i qy45 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 224));

            qx01 = _mm256_maddubs_epi16(qx01, qy01);
            qx23 = _mm256_maddubs_epi16(qx23, qy23);
            qx45 = _mm256_maddubs_epi16(qx45, qy45);

            sumi0 = _mm256_add_epi16(sumi0, qx01);
            sumi1 = _mm256_add_epi16(sumi1, qx23);
            sumi2 = _mm256_add_epi16(sumi2, qx45);
        }

        const __m256i ysum = _mm256_loadu_si256((const __m256i *) y[i].bsums);
        const __m256 d = _mm256_set1_ps(y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d));

        sumi0 = _mm256_sub_epi16(sumi0, ysum);
        sumi0 = _mm256_add_epi16(sumi0, _mm256_add_epi16(sumi1, sumi2));
        sumi0 = _mm256_madd_epi16(sumi0, _mm256_set1_epi16(1));

        sumf = _mm256_add_ps(_mm256_mul_ps(_mm256_cvtepi32_ps(sumi0), d), sumf);
    }

    *s = hsum_float_8(sumf);

#else
    UNUSED(x);
    UNUSED(y);
    UNUSED(nb);
    ggml_vec_dot_tq1_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}

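// dot product of a tq2_0 (ternary, ~2 bpw) row with a q8_K row: each byte of x packs four 2-bit
// values in {0, 1, 2} (again biased by +1), so extraction needs only shifts and masks; the bias is
// removed by subtracting y[i].bsums as in the tq1_0 case above.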
void ggml_vec_dot_tq2_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_tq2_0 * GGML_RESTRICT x = vx;
    const block_q8_K  * GGML_RESTRICT y = vy;

    const int nb = n / QK_K;

#if defined(__AVX2__)
    __m256 sumf = _mm256_setzero_ps();

    for (int i = 0; i < nb; ++i) {
        // 16-bit sums, because 256*127 still fits
        __m256i sumi0 = _mm256_setzero_si256();
        __m256i sumi1 = _mm256_setzero_si256();

        for (size_t j = 0; j < sizeof(x->qs); j += 32) {
            __m256i qx0 = _mm256_loadu_si256((const __m256i *) (x[i].qs + j));
            __m256i qx1 = _mm256_srli_epi16(qx0, 2);
            __m256i qx2 = _mm256_srli_epi16(qx0, 4);
            __m256i qx3 = _mm256_srli_epi16(qx0, 6);

            // 0, 1, 2 (should not be 3)
            qx0 = _mm256_and_si256(qx0, _mm256_set1_epi8(3));
            qx1 = _mm256_and_si256(qx1, _mm256_set1_epi8(3));
            qx2 = _mm256_and_si256(qx2, _mm256_set1_epi8(3));
            qx3 = _mm256_and_si256(qx3, _mm256_set1_epi8(3));

            const __m256i qy0 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 +  0));
            const __m256i qy1 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 32));
            const __m256i qy2 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 64));
            const __m256i qy3 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 96));

            qx0 = _mm256_maddubs_epi16(qx0, qy0);
            qx1 = _mm256_maddubs_epi16(qx1, qy1);
            qx2 = _mm256_maddubs_epi16(qx2, qy2);
            qx3 = _mm256_maddubs_epi16(qx3, qy3);

            sumi0 = _mm256_add_epi16(sumi0, _mm256_add_epi16(qx0, qx1));
            sumi1 = _mm256_add_epi16(sumi1, _mm256_add_epi16(qx2, qx3));
        }

        const __m256i ysum = _mm256_loadu_si256((const __m256i *) y[i].bsums);
        const __m256 d = _mm256_set1_ps(y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d));

        sumi0 = _mm256_add_epi16(sumi0, sumi1);
        sumi0 = _mm256_sub_epi16(sumi0, ysum);
        sumi0 = _mm256_madd_epi16(sumi0, _mm256_set1_epi16(1));

        sumf = _mm256_add_ps(_mm256_mul_ps(_mm256_cvtepi32_ps(sumi0), d), sumf);
    }

    *s = hsum_float_8(sumf);

#else
    UNUSED(x);
    UNUSED(y);
    UNUSED(nb);
    ggml_vec_dot_tq2_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}

1278void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
1279    assert(nrc == 1);
1280    UNUSED(nrc);
1281    UNUSED(bx);
1282    UNUSED(by);
1283    UNUSED(bs);
1284
1285    const block_q2_K * GGML_RESTRICT x = vx;
1286    const block_q8_K * GGML_RESTRICT y = vy;
1287
1288    const int nb = n / QK_K;
1289
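    // Each 256-value super-block holds 16 sub-blocks of 16 values, with a 4-bit scale and a
    // 4-bit min packed per byte of x[i].scales. Both SIMD paths below evaluate, per block,
    //
    //   y.d * x.d    * sum_j sc_j * sum_l q_{j,l} * y_{j,l}
    // - y.d * x.dmin * sum_j  m_j * bsums_j
    //
    // where bsums_j is the precomputed sum of the 16 q8 values of sub-block j, so the mins
    // never have to be subtracted per element.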
1290#if defined __AVX2__
1291
1292    const __m256i m3 = _mm256_set1_epi8(3);
1293    const __m128i m4 = _mm_set1_epi8(0xF);
1294
1295    __m256 acc = _mm256_setzero_ps();
1296
1297    for (int i = 0; i < nb; ++i) {
1298
1299        const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
1300        const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);
1301
1302        const uint8_t * GGML_RESTRICT q2 = x[i].qs;
1303        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
1304
1305        const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales);
1306        const __m128i scales8 = _mm_and_si128(mins_and_scales, m4);
1307        const __m128i mins8 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4);
1308        const __m256i mins = _mm256_cvtepi8_epi16(mins8);
1309        const __m256i prod = _mm256_madd_epi16(mins, _mm256_loadu_si256((const __m256i*)y[i].bsums));
1310
1311        acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(prod), acc);
1312
1313        const __m256i all_scales = _mm256_cvtepi8_epi16(scales8);
1314        const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0);
1315        const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1);
1316        const __m256i scales[2] = {MM256_SET_M128I(l_scales, l_scales), MM256_SET_M128I(h_scales, h_scales)};
1317
1318        __m256i sumi = _mm256_setzero_si256();
1319
1320        for (int j = 0; j < QK_K/128; ++j) {
1321
1322            const __m256i q2bits = _mm256_loadu_si256((const __m256i*)q2); q2 += 32;
1323
1324            const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
1325            const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
1326            const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
1327            const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
1328
1329            const __m256i q2_0 = _mm256_and_si256(q2bits, m3);
1330            const __m256i q2_1 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 2), m3);
1331            const __m256i q2_2 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 4), m3);
1332            const __m256i q2_3 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 6), m3);
1333
1334            __m256i p0 = _mm256_maddubs_epi16(q2_0, q8_0);
1335            __m256i p1 = _mm256_maddubs_epi16(q2_1, q8_1);
1336            __m256i p2 = _mm256_maddubs_epi16(q2_2, q8_2);
1337            __m256i p3 = _mm256_maddubs_epi16(q2_3, q8_3);
1338
1339            p0 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(0)), p0);
1340            p1 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(1)), p1);
1341            p2 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(2)), p2);
1342            p3 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(3)), p3);
1343
1344            p0 = _mm256_add_epi32(p0, p1);
1345            p2 = _mm256_add_epi32(p2, p3);
1346
1347            sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p0, p2));
1348        }
1349
1350        acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
1351
1352    }
1353
1354    *s = hsum_float_8(acc);
1355
1356#elif defined __AVX__
1357
1358    const __m128i m3 = _mm_set1_epi8(0x3);
1359    const __m128i m4 = _mm_set1_epi8(0xF);
1360    const __m128i m2 = _mm_set1_epi8(0x2);
1361
1362    __m256 acc = _mm256_setzero_ps();
1363
1364    for (int i = 0; i < nb; ++i) {
1365
1366        const float dall = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
1367        const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);
1368
1369        const uint8_t * GGML_RESTRICT q2 = x[i].qs;
1370        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
1371
1372        // load mins and scales from block_q2_K.scales[QK_K/16]
1373        const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales);
1374        const __m128i scales16 = _mm_and_si128(mins_and_scales, m4);
1375        const __m128i mins16 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4);
1376        const __m128i mins_0 = _mm_cvtepi8_epi16(mins16);
1377        const __m128i mins_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(mins16, mins16));
1378
1379        // summs = y[i].bsums * (x[i].scales >> 4) in 16bits*8*2 to 32bits*4*2
1380        const __m128i summs_0 = _mm_madd_epi16(mins_0, _mm_loadu_si128((const __m128i*)&y[i].bsums[0]));
1381        const __m128i summs_1 = _mm_madd_epi16(mins_1, _mm_loadu_si128((const __m128i*)&y[i].bsums[8]));
1382
1383        // sumf += -dmin * summs in 32bits*8
1384        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(MM256_SET_M128I(summs_1, summs_0))), acc);
1385
1386        const __m128i scales_0 = _mm_cvtepi8_epi16(scales16);
1387        const __m128i scales_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(scales16, scales16));
1388        const __m128i scales[2] = { scales_0, scales_1 };
1389
1390        __m128i sumi_0 = _mm_setzero_si128();
1391        __m128i sumi_1 = _mm_setzero_si128();
1392
1393        for (int j = 0; j < QK_K/128; ++j) {
1394
1395            // load Q8 quants int8*16*8 from block_q8_K.qs[QK_K]
1396            const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1397            const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1398            const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1399            const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1400            const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1401            const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1402            const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1403            const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1404
1405            // load 2bits*16*8 from block_q2_K.qs[QK_K/4]
1406            __m128i q2bits = _mm_loadu_si128((const __m128i*)q2); q2 += 16;
1407            const __m128i q2_0 = _mm_and_si128(q2bits, m3);
1408            const __m128i q2_2 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3);
1409            const __m128i q2_4 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3);
1410            const __m128i q2_6 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3);
1411            q2bits = _mm_loadu_si128((const __m128i*)q2); q2 += 16;
1412            const __m128i q2_1 = _mm_and_si128(q2bits, m3);
1413            const __m128i q2_3 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3);
1414            const __m128i q2_5 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3);
1415            const __m128i q2_7 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3);
1416
1417            // isuml = q8[l] * ((q2[l] >> shift) & 3) in 8bits*16*8 to 16bits*8*8
1418            __m128i p0 = _mm_maddubs_epi16(q2_0, q8_0);
1419            __m128i p1 = _mm_maddubs_epi16(q2_1, q8_1);
1420            __m128i p2 = _mm_maddubs_epi16(q2_2, q8_2);
1421            __m128i p3 = _mm_maddubs_epi16(q2_3, q8_3);
1422            __m128i p4 = _mm_maddubs_epi16(q2_4, q8_4);
1423            __m128i p5 = _mm_maddubs_epi16(q2_5, q8_5);
1424            __m128i p6 = _mm_maddubs_epi16(q2_6, q8_6);
1425            __m128i p7 = _mm_maddubs_epi16(q2_7, q8_7);
1426
1427            // isum += (x[i].scales[is++] & 0xF) * isuml in 16bits*8*8 to 32bits*4*8
1428            __m128i shuffle = _mm_set1_epi16(0x0100);
1429            p0 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p0);
1430            shuffle = _mm_add_epi16(shuffle, m2);
1431            p1 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p1);
1432            shuffle = _mm_add_epi16(shuffle, m2);
1433            p2 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p2);
1434            shuffle = _mm_add_epi16(shuffle, m2);
1435            p3 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p3);
1436            shuffle = _mm_add_epi16(shuffle, m2);
1437            p4 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p4);
1438            shuffle = _mm_add_epi16(shuffle, m2);
1439            p5 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p5);
1440            shuffle = _mm_add_epi16(shuffle, m2);
1441            p6 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p6);
1442            shuffle = _mm_add_epi16(shuffle, m2);
1443            p7 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p7);
1444
1445            p0 = _mm_add_epi32(p0, p1);
1446            p2 = _mm_add_epi32(p2, p3);
1447            p4 = _mm_add_epi32(p4, p5);
1448            p6 = _mm_add_epi32(p6, p7);
1449
1450            // isum in 32bits*4*2
1451            sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p0, p2));
1452            sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p4, p6));
1453        }
1454
1455        // sumf += dall * isum - dmin * summs in 32bits
1456        __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
1457        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&dall), _mm256_cvtepi32_ps(sumi)), acc);
1458    }
1459
1460    *s = hsum_float_8(acc);
1461
1462#else
1463    UNUSED(x);
1464    UNUSED(y);
1465    UNUSED(nb);
1466    ggml_vec_dot_q2_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
1467#endif
1468}
1469
1470void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
1471    assert(n % QK_K == 0);
1472    assert(nrc == 1);
1473    UNUSED(nrc);
1474    UNUSED(bx);
1475    UNUSED(by);
1476    UNUSED(bs);
1477
1478    const uint32_t kmask1 = 0x03030303;
1479    const uint32_t kmask2 = 0x0f0f0f0f;
1480
1481    const block_q3_K * GGML_RESTRICT x = vx;
1482    const block_q8_K * GGML_RESTRICT y = vy;
1483
1484    const int nb = n / QK_K;
1485
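    // q3_K stores each value as 2 low bits in qs plus 1 high bit in hmask; the effective value
    // is (low2 - 4) when the hmask bit is clear and low2 when it is set, i.e. it lies in [-4, 3].
    // The 16 signed 6-bit sub-block scales are packed into the 12 bytes of x[i].scales and are
    // unpacked with kmask1/kmask2 below; the result is y.d * x.d * sum_j sc_j * sum_l q*y.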
1486#if defined __AVX2__
1487
1488    const __m256i m3 = _mm256_set1_epi8(3);
1489    const __m256i mone = _mm256_set1_epi8(1);
1490    const __m128i m32 = _mm_set1_epi8(32);
1491
1492    __m256 acc = _mm256_setzero_ps();
1493
1494    uint32_t aux[3];
1495
1496    for (int i = 0; i < nb; ++i) {
1497
1498        const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
1499
1500        const uint8_t * GGML_RESTRICT q3 = x[i].qs;
1501        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
1502
1503        // Set up scales
1504        memcpy(aux, x[i].scales, 12);
1505        __m128i scales128 = _mm_set_epi32(
1506                ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4),
1507                ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4),
1508                (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4),
1509                (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4));
1510        scales128 = _mm_sub_epi8(scales128, m32);
1511        const __m256i all_scales = _mm256_cvtepi8_epi16(scales128);
1512        const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0);
1513        const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1);
1514        const __m256i scales[2] = {MM256_SET_M128I(l_scales, l_scales), MM256_SET_M128I(h_scales, h_scales)};
1515
1516        // high bit
1517        const __m256i hbits = _mm256_loadu_si256((const __m256i*)x[i].hmask);
1518
1519        // integer accumulator
1520        __m256i sumi = _mm256_setzero_si256();
1521
1522        int bit = 0;
1523        int is  = 0;
1524
1525        for (int j = 0; j < QK_K/128; ++j) {
1526            // load low 2 bits
1527            const __m256i q3bits = _mm256_loadu_si256((const __m256i*)q3); q3 += 32;
1528
1529            // prepare low and high bits
1530            const __m256i q3l_0 = _mm256_and_si256(q3bits, m3);
1531            const __m256i q3h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2);
1532            ++bit;
1533
1534            const __m256i q3l_1 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 2), m3);
1535            const __m256i q3h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2);
1536            ++bit;
1537
1538            const __m256i q3l_2 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 4), m3);
1539            const __m256i q3h_2 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2);
1540            ++bit;
1541
1542            const __m256i q3l_3 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 6), m3);
1543            const __m256i q3h_3 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2);
1544            ++bit;
1545
1546            // load Q8 quants
1547            const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
1548            const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
1549            const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
1550            const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
1551
1552            // Dot product: we multiply the 2 low bits and the 1 high bit parts separately, so we can use _mm256_maddubs_epi16,
1553            // and then subtract. The high-bit term already carries the -4 offset of q3_K: it is 4 where the high bit was not
1554            // set and 0 where it was set, so the subtraction yields the signed value (low 2 bits minus 4 or minus 0).
1555            __m256i q8s_0 = _mm256_maddubs_epi16(q3h_0, q8_0);
1556            __m256i q8s_1 = _mm256_maddubs_epi16(q3h_1, q8_1);
1557            __m256i q8s_2 = _mm256_maddubs_epi16(q3h_2, q8_2);
1558            __m256i q8s_3 = _mm256_maddubs_epi16(q3h_3, q8_3);
1559
1560            __m256i p16_0 = _mm256_maddubs_epi16(q3l_0, q8_0);
1561            __m256i p16_1 = _mm256_maddubs_epi16(q3l_1, q8_1);
1562            __m256i p16_2 = _mm256_maddubs_epi16(q3l_2, q8_2);
1563            __m256i p16_3 = _mm256_maddubs_epi16(q3l_3, q8_3);
1564
1565            p16_0 = _mm256_sub_epi16(p16_0, q8s_0);
1566            p16_1 = _mm256_sub_epi16(p16_1, q8s_1);
1567            p16_2 = _mm256_sub_epi16(p16_2, q8s_2);
1568            p16_3 = _mm256_sub_epi16(p16_3, q8s_3);
1569
1570            // multiply with scales
1571            p16_0 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 0)), p16_0);
1572            p16_1 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 1)), p16_1);
1573            p16_2 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 2)), p16_2);
1574            p16_3 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 3)), p16_3);
1575
1576            // accumulate
1577            p16_0 = _mm256_add_epi32(p16_0, p16_1);
1578            p16_2 = _mm256_add_epi32(p16_2, p16_3);
1579            sumi  = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_2));
1580
1581        }
1582
1583        // multiply with block scale and accumulate
1584        acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
1585
1586    }
1587
1588    *s = hsum_float_8(acc);
1589
1590#elif defined __AVX__
1591
1592    const __m128i m3 = _mm_set1_epi8(3);
1593    const __m128i mone = _mm_set1_epi8(1);
1594    const __m128i m32 = _mm_set1_epi8(32);
1595    const __m128i m2 = _mm_set1_epi8(2);
1596
1597    __m256 acc = _mm256_setzero_ps();
1598
1599    const uint32_t *aux;
1600
1601    for (int i = 0; i < nb; ++i) {
1602
1603        const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
1604
1605        const uint8_t * GGML_RESTRICT q3 = x[i].qs;
1606        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
1607
1608        // Set up scales
1609        aux = (const uint32_t *)x[i].scales;
1610        __m128i scales128 = _mm_set_epi32(
1611                ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4),
1612                ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4),
1613                (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4),
1614                (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4));
1615        scales128 = _mm_sub_epi8(scales128, m32);
1616        const __m128i scales_0 = _mm_cvtepi8_epi16(scales128);
1617        const __m128i scales_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(scales128, scales128));
1618        const __m128i scales[2] = { scales_0, scales_1 };
1619
1620        // high bit *128*2 from block_q3_K.hmask[QK_K/8]
1621        const __m128i hbits_0 = _mm_loadu_si128((const __m128i*)&x[i].hmask[0]);
1622        const __m128i hbits_1 = _mm_loadu_si128((const __m128i*)&x[i].hmask[16]);
1623
1624        // integer accumulator
1625        __m128i sumi_0 = _mm_setzero_si128();
1626        __m128i sumi_1 = _mm_setzero_si128();
1627
1628        for (int j = 0; j < QK_K/128; ++j) {
1629            // load low 2 bits *64*2 from block_q3_K.qs[QK_K/4]
1630            const __m128i q3bits_0 = _mm_loadu_si128((const __m128i*)q3); q3 += 16;
1631            const __m128i q3bits_1 = _mm_loadu_si128((const __m128i*)q3); q3 += 16;
1632
1633            // prepare low and high bits
1634            const int bit = j << 2;
1635
1636            const __m128i q3l_0 = _mm_and_si128(q3bits_0, m3);
1637            const __m128i q3l_1 = _mm_and_si128(q3bits_1, m3);
1638            const __m128i q3h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit)), bit), 2);
1639            const __m128i q3h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit)), bit), 2);
1640
1641            const __m128i q3l_2 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 2), m3);
1642            const __m128i q3l_3 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 2), m3);
1643            const __m128i q3h_2 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+1)), bit+1), 2);
1644            const __m128i q3h_3 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+1)), bit+1), 2);
1645
1646            const __m128i q3l_4 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 4), m3);
1647            const __m128i q3l_5 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 4), m3);
1648            const __m128i q3h_4 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+2)), bit+2), 2);
1649            const __m128i q3h_5 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+2)), bit+2), 2);
1650
1651            const __m128i q3l_6 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 6), m3);
1652            const __m128i q3l_7 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 6), m3);
1653            const __m128i q3h_6 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+3)), bit+3), 2);
1654            const __m128i q3h_7 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+3)), bit+3), 2);
1655
1656            // load Q8 quants from block_q8_K.qs[QK_K]
1657            const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1658            const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1659            const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1660            const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1661            const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1662            const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1663            const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1664            const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1665
1666            // Dot product: we multiply the 2 low bits and the 1 high bit parts separately, so we can use _mm_maddubs_epi16,
1667            // and then subtract. The high-bit term already carries the -4 offset of q3_K: it is 4 where the high bit was not
1668            // set and 0 where it was set, so the subtraction yields the signed value (low 2 bits minus 4 or minus 0).
1669            __m128i q8s_0 = _mm_maddubs_epi16(q3h_0, q8_0);
1670            __m128i q8s_1 = _mm_maddubs_epi16(q3h_1, q8_1);
1671            __m128i q8s_2 = _mm_maddubs_epi16(q3h_2, q8_2);
1672            __m128i q8s_3 = _mm_maddubs_epi16(q3h_3, q8_3);
1673            __m128i q8s_4 = _mm_maddubs_epi16(q3h_4, q8_4);
1674            __m128i q8s_5 = _mm_maddubs_epi16(q3h_5, q8_5);
1675            __m128i q8s_6 = _mm_maddubs_epi16(q3h_6, q8_6);
1676            __m128i q8s_7 = _mm_maddubs_epi16(q3h_7, q8_7);
1677
1678            __m128i p16_0 = _mm_maddubs_epi16(q3l_0, q8_0);
1679            __m128i p16_1 = _mm_maddubs_epi16(q3l_1, q8_1);
1680            __m128i p16_2 = _mm_maddubs_epi16(q3l_2, q8_2);
1681            __m128i p16_3 = _mm_maddubs_epi16(q3l_3, q8_3);
1682            __m128i p16_4 = _mm_maddubs_epi16(q3l_4, q8_4);
1683            __m128i p16_5 = _mm_maddubs_epi16(q3l_5, q8_5);
1684            __m128i p16_6 = _mm_maddubs_epi16(q3l_6, q8_6);
1685            __m128i p16_7 = _mm_maddubs_epi16(q3l_7, q8_7);
1686
1687            p16_0 = _mm_sub_epi16(p16_0, q8s_0);
1688            p16_1 = _mm_sub_epi16(p16_1, q8s_1);
1689            p16_2 = _mm_sub_epi16(p16_2, q8s_2);
1690            p16_3 = _mm_sub_epi16(p16_3, q8s_3);
1691            p16_4 = _mm_sub_epi16(p16_4, q8s_4);
1692            p16_5 = _mm_sub_epi16(p16_5, q8s_5);
1693            p16_6 = _mm_sub_epi16(p16_6, q8s_6);
1694            p16_7 = _mm_sub_epi16(p16_7, q8s_7);
1695
1696            // multiply with scales
1697            __m128i shuffle = _mm_set1_epi16(0x0100);
1698            p16_0 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_0);
1699            shuffle = _mm_add_epi16(shuffle, m2);
1700            p16_1 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_1);
1701            shuffle = _mm_add_epi16(shuffle, m2);
1702            p16_2 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_2);
1703            shuffle = _mm_add_epi16(shuffle, m2);
1704            p16_3 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_3);
1705            shuffle = _mm_add_epi16(shuffle, m2);
1706            p16_4 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_4);
1707            shuffle = _mm_add_epi16(shuffle, m2);
1708            p16_5 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_5);
1709            shuffle = _mm_add_epi16(shuffle, m2);
1710            p16_6 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_6);
1711            shuffle = _mm_add_epi16(shuffle, m2);
1712            p16_7 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_7);
1713
1714            // accumulate
1715            p16_0 = _mm_add_epi32(p16_0, p16_1);
1716            p16_2 = _mm_add_epi32(p16_2, p16_3);
1717            p16_4 = _mm_add_epi32(p16_4, p16_5);
1718            p16_6 = _mm_add_epi32(p16_6, p16_7);
1719            sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
1720            sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_4, p16_6));
1721
1722        }
1723
1724        // multiply with block scale and accumulate
1725        __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
1726        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi)), acc);
1727
1728    }
1729
1730    *s = hsum_float_8(acc);
1731
1732#else
1733    UNUSED(kmask1);
1734    UNUSED(kmask2);
1735    UNUSED(x);
1736    UNUSED(y);
1737    UNUSED(nb);
1738    ggml_vec_dot_q3_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
1739#endif
1740}
1741
1742void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
1743    assert(n % QK_K == 0);
1744    assert(nrc == 1);
1745    UNUSED(nrc);
1746    UNUSED(bx);
1747    UNUSED(by);
1748    UNUSED(bs);
1749
1750    const block_q4_K * GGML_RESTRICT x = vx;
1751    const block_q8_K * GGML_RESTRICT y = vy;
1752
1753    const int nb = n / QK_K;
1754
1755    static const uint32_t kmask1 = 0x3f3f3f3f;
1756    static const uint32_t kmask2 = 0x0f0f0f0f;
1757    static const uint32_t kmask3 = 0x03030303;
1758
1759    uint32_t utmp[4];
1760
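    // q4_K: 8 sub-blocks of 32 values, each with a 6-bit scale and a 6-bit min packed into the
    // 12 bytes of x[i].scales (unpacked via kmask1/2/3 into utmp). Per block this computes
    //
    //   y.d * x.d * sum_j sc_j * sum_l q*y  -  y.d * x.dmin * sum_j m_j * (bsums_{2j} + bsums_{2j+1})
    //
    // with adjacent bsums combined by a horizontal add, since bsums cover 16 values each.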
1761#if defined __AVX2__
1762
1763    const __m256i m4 = _mm256_set1_epi8(0xF);
1764
1765    __m256 acc = _mm256_setzero_ps();
1766    __m128 acc_m = _mm_setzero_ps();
1767
1768    for (int i = 0; i < nb; ++i) {
1769
1770        const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
1771        const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);
1772
1773        memcpy(utmp, x[i].scales, 12);
1774        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
1775        const uint32_t uaux = utmp[1] & kmask1;
1776        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
1777        utmp[2] = uaux;
1778        utmp[0] &= kmask1;
1779
1780        const uint8_t * GGML_RESTRICT q4 = x[i].qs;
1781        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
1782
1783        const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]));
1784
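        // bsums are per 16 values while q4_K sub-blocks span 32, so adjacent bsums are paired
        // with a horizontal add before being multiplied by the 8 mins (upper half of mins_and_scales).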
1785        const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums);
1786        const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1));
1787        const __m128i prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s);
1788        acc_m = _mm_fmadd_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod), acc_m);
1789
1790        const __m128i sc128  = _mm256_extracti128_si256(mins_and_scales, 0);
1791        const __m256i scales = MM256_SET_M128I(sc128, sc128);
1792
1793        __m256i sumi = _mm256_setzero_si256();
1794
1795        for (int j = 0; j < QK_K/64; ++j) {
1796
1797            const __m256i scale_l = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+0));
1798            const __m256i scale_h = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+1));
1799
1800            const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4); q4 += 32;
1801            const __m256i q4l = _mm256_and_si256(q4bits, m4);
1802            const __m256i q4h = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), m4);
1803
1804            const __m256i q8l = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
1805            __m256i p16l = _mm256_maddubs_epi16(q4l, q8l);
1806            p16l = _mm256_madd_epi16(scale_l, p16l);
1807
1808            const __m256i q8h = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
1809            __m256i p16h = _mm256_maddubs_epi16(q4h, q8h);
1810            p16h = _mm256_madd_epi16(scale_h, p16h);
1811            const __m256i sumj = _mm256_add_epi32(p16l, p16h);
1812
1813            sumi = _mm256_add_epi32(sumi, sumj);
1814        }
1815
1816        __m256 vd = _mm256_set1_ps(d);
1817        acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc);
1818
1819    }
1820
1821    acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m));
1822    acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m));
1823
1824    *s = hsum_float_8(acc) + _mm_cvtss_f32(acc_m);
1825
1826#elif defined __AVX__
1827
1828    const __m128i m4 = _mm_set1_epi8(0xF);
1829    const __m128i m2 = _mm_set1_epi8(0x2);
1830
1831    __m256 acc = _mm256_setzero_ps();
1832    __m128 acc_m = _mm_setzero_ps();
1833
1834    for (int i = 0; i < nb; ++i) {
1835
1836        const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
1837        const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);
1838
1839        const uint8_t * GGML_RESTRICT q4 = x[i].qs;
1840        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
1841
1842        memcpy(utmp, x[i].scales, 12);
1843        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
1844        const uint32_t uaux = utmp[1] & kmask1;
1845        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
1846        utmp[2] = uaux;
1847        utmp[0] &= kmask1;
1848
1849        const __m128i utmps = _mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]);
1850        const __m128i scales = _mm_cvtepu8_epi16(utmps);
1851        const __m128i mins = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(utmps, utmps));
1852
1853        const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)&y[i].bsums[0]);
1854        const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)&y[i].bsums[8]);
1855        const __m128i q8s = _mm_hadd_epi16(q8sums_0, q8sums_1);
1856        const __m128i prod = _mm_madd_epi16(mins, q8s);
1857        acc_m = _mm_add_ps(_mm_mul_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod)), acc_m);
1858
1859        __m128i sumi_0 = _mm_setzero_si128();
1860        __m128i sumi_1 = _mm_setzero_si128();
1861
1862        __m128i shuffle = _mm_set1_epi16(0x0100);
1863        for (int j = 0; j < QK_K/64; ++j) {
1864
1865            const __m128i scale_l = _mm_shuffle_epi8(scales, shuffle);
1866            shuffle = _mm_add_epi16(shuffle, m2);
1867            const __m128i scale_h = _mm_shuffle_epi8(scales, shuffle);
1868            shuffle = _mm_add_epi16(shuffle, m2);
1869
1870            __m128i q4bits = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
1871            const __m128i q4l_0 = _mm_and_si128(q4bits, m4);
1872            const __m128i q4h_0 = _mm_and_si128(_mm_srli_epi16(q4bits, 4), m4);
1873            q4bits = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
1874            const __m128i q4l_1 = _mm_and_si128(q4bits, m4);
1875            const __m128i q4h_1 = _mm_and_si128(_mm_srli_epi16(q4bits, 4), m4);
1876
1877            const __m128i q8l_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1878            __m128i p16l = _mm_maddubs_epi16(q4l_0, q8l_0);
1879            p16l = _mm_madd_epi16(scale_l, p16l);
1880            sumi_0 = _mm_add_epi32(sumi_0, p16l);
1881            const __m128i q8l_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1882            p16l = _mm_maddubs_epi16(q4l_1, q8l_1);
1883            p16l = _mm_madd_epi16(scale_l, p16l);
1884            sumi_1 = _mm_add_epi32(sumi_1, p16l);
1885
1886            const __m128i q8h_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1887            __m128i p16h = _mm_maddubs_epi16(q4h_0, q8h_0);
1888            p16h = _mm_madd_epi16(scale_h, p16h);
1889            sumi_0 = _mm_add_epi32(sumi_0, p16h);
1890            const __m128i q8h_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1891            p16h = _mm_maddubs_epi16(q4h_1, q8h_1);
1892            p16h = _mm_madd_epi16(scale_h, p16h);
1893            sumi_1 = _mm_add_epi32(sumi_1, p16h);
1894
1895        }
1896
1897        __m256 vd = _mm256_set1_ps(d);
1898        __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
1899        acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(sumi)), acc);
1900
1901    }
1902
1903    acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m));
1904    acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m));
1905
1906    *s = hsum_float_8(acc) + _mm_cvtss_f32(acc_m);
1907
1908#else
1909    UNUSED(x);
1910    UNUSED(y);
1911    UNUSED(nb);
1912    UNUSED(kmask1);
1913    UNUSED(kmask2);
1914    UNUSED(kmask3);
1915    UNUSED(utmp);
1916    ggml_vec_dot_q4_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
1917#endif
1918}
1919
1920void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy,  size_t by, int nrc) {
1921    assert(n % QK_K == 0);
1922    assert(nrc == 1);
1923    UNUSED(nrc);
1924    UNUSED(bx);
1925    UNUSED(by);
1926    UNUSED(bs);
1927
1928    const block_q5_K * GGML_RESTRICT x = vx;
1929    const block_q8_K * GGML_RESTRICT y = vy;
1930
1931    const int nb = n / QK_K;
1932
1933    static const uint32_t kmask1 = 0x3f3f3f3f;
1934    static const uint32_t kmask2 = 0x0f0f0f0f;
1935    static const uint32_t kmask3 = 0x03030303;
1936
1937    uint32_t utmp[4];
1938
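    // q5_K follows the same scale/min layout as q4_K, but each value gets a fifth bit from
    // x[i].qh, giving quants in 0..31: q = (4 low bits from qs) | (high bit << 4). The min
    // term is again folded through bsums (accumulated into summs), and the main term is
    // y.d * x.d * sum_j sc_j * sum_l q*y.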
1939#if defined __AVX2__
1940
1941    const __m256i m4 = _mm256_set1_epi8(0xF);
1942    const __m128i mzero = _mm_setzero_si128();
1943    const __m256i mone  = _mm256_set1_epi8(1);
1944
1945    __m256 acc = _mm256_setzero_ps();
1946
1947    float summs = 0.f;
1948
1949    for (int i = 0; i < nb; ++i) {
1950        const uint8_t * GGML_RESTRICT q5 = x[i].qs;
1951        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
1952
1953        const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
1954        const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);
1955
1956        memcpy(utmp, x[i].scales, 12);
1957        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
1958        const uint32_t uaux = utmp[1] & kmask1;
1959        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
1960        utmp[2] = uaux;
1961        utmp[0] &= kmask1;
1962
1963        const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]));
1964
1965        const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums);
1966        const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1));
1967        const __m128i prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s);
1968        const __m128i hsum = _mm_hadd_epi32(_mm_hadd_epi32(prod, mzero), mzero);
1969        summs += dmin * _mm_extract_epi32(hsum, 0);
1970
1971        const __m128i sc128  = _mm256_extracti128_si256(mins_and_scales, 0);
1972        const __m256i scales = MM256_SET_M128I(sc128, sc128);
1973
1974        const __m256i hbits = _mm256_loadu_si256((const __m256i*)x[i].qh);
1975        __m256i hmask = mone;
1976
1977        __m256i sumi = _mm256_setzero_si256();
1978
1979        int bit = 0;
1980
1981        for (int j = 0; j < QK_K/64; ++j) {
1982
1983            const __m256i scale_0 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+0));
1984            const __m256i scale_1 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+1));
1985
1986            const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5); q5 += 32;
1987
1988            const __m256i q5l_0 = _mm256_and_si256(q5bits, m4);
1989            const __m256i q5h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), bit++), 4);
1990            const __m256i q5_0  = _mm256_add_epi8(q5l_0, q5h_0);
1991            hmask = _mm256_slli_epi16(hmask, 1);
1992
1993            const __m256i q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), m4);
1994            const __m256i q5h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), bit++), 4);
1995            const __m256i q5_1  = _mm256_add_epi8(q5l_1, q5h_1);
1996            hmask = _mm256_slli_epi16(hmask, 1);
1997
1998            const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
1999            const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
2000
2001            __m256i p16_0 = _mm256_maddubs_epi16(q5_0, q8_0);
2002            __m256i p16_1 = _mm256_maddubs_epi16(q5_1, q8_1);
2003
2004            p16_0 = _mm256_madd_epi16(scale_0, p16_0);
2005            p16_1 = _mm256_madd_epi16(scale_1, p16_1);
2006
2007            sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1));
2008
2009        }
2010
2011        __m256 vd = _mm256_set1_ps(d);
2012        acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc);
2013
2014    }
2015
2016    *s = hsum_float_8(acc) + summs;
2017
2018#elif defined __AVX__
2019
2020    const __m128i m4 = _mm_set1_epi8(0xF);
2021    const __m128i mzero = _mm_setzero_si128();
2022    const __m128i mone  = _mm_set1_epi8(1);
2023    const __m128i m2 = _mm_set1_epi8(2);
2024
2025    __m256 acc = _mm256_setzero_ps();
2026
2027    float summs = 0.f;
2028
2029    for (int i = 0; i < nb; ++i) {
2030
2031        const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
2032        const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);
2033
2034        const uint8_t * GGML_RESTRICT q5 = x[i].qs;
2035        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
2036
2037        memcpy(utmp, x[i].scales, 12);
2038        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
2039        const uint32_t uaux = utmp[1] & kmask1;
2040        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
2041        utmp[2] = uaux;
2042        utmp[0] &= kmask1;
2043
2044        const __m128i utmps = _mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]);
2045        const __m128i scales = _mm_cvtepu8_epi16(utmps);
2046        const __m128i mins = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(utmps, utmps));
2047
2048        const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)&y[i].bsums[0]);
2049        const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)&y[i].bsums[8]);
2050        const __m128i q8s = _mm_hadd_epi16(q8sums_0, q8sums_1);
2051        const __m128i prod = _mm_madd_epi16(mins, q8s);
2052        const __m128i hsum = _mm_hadd_epi32(_mm_hadd_epi32(prod, mzero), mzero);
2053        summs += dmin * _mm_extract_epi32(hsum, 0);
2054
2055        const __m128i hbits_0 = _mm_loadu_si128((const __m128i*)&x[i].qh[0]);
2056        const __m128i hbits_1 = _mm_loadu_si128((const __m128i*)&x[i].qh[16]);
2057        __m128i hmask = mone;
2058
2059        __m128i sumi_0 = _mm_setzero_si128();
2060        __m128i sumi_1 = _mm_setzero_si128();
2061
2062        int bit = 0;
2063
2064        __m128i shuffle = _mm_set1_epi16(0x0100);
2065        for (int j = 0; j < QK_K/64; ++j) {
2066
2067            const __m128i scale_0 = _mm_shuffle_epi8(scales, shuffle);
2068            shuffle = _mm_add_epi16(shuffle, m2);
2069            const __m128i scale_1 = _mm_shuffle_epi8(scales, shuffle);
2070            shuffle = _mm_add_epi16(shuffle, m2);
2071
2072            const __m128i q5bits_0 = _mm_loadu_si128((const __m128i*)q5); q5 += 16;
2073            const __m128i q5bits_1 = _mm_loadu_si128((const __m128i*)q5); q5 += 16;
2074
2075            __m128i q5l_0 = _mm_and_si128(q5bits_0, m4);
2076            __m128i q5l_1 = _mm_and_si128(q5bits_1, m4);
2077            __m128i q5h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_0, hmask), bit), 4);
2078            __m128i q5h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_1, hmask), bit++), 4);
2079            __m128i q5_0  = _mm_add_epi8(q5l_0, q5h_0);
2080            __m128i q5_1  = _mm_add_epi8(q5l_1, q5h_1);
2081            hmask = _mm_slli_epi16(hmask, 1);
2082
2083            __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
2084            __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
2085            __m128i p16_0 = _mm_maddubs_epi16(q5_0, q8_0);
2086            __m128i p16_1 = _mm_maddubs_epi16(q5_1, q8_1);
2087            p16_0 = _mm_madd_epi16(scale_0, p16_0);
2088            p16_1 = _mm_madd_epi16(scale_0, p16_1);
2089
2090            q5l_0 = _mm_and_si128(_mm_srli_epi16(q5bits_0, 4), m4);
2091            q5l_1 = _mm_and_si128(_mm_srli_epi16(q5bits_1, 4), m4);
2092            q5h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_0, hmask), bit), 4);
2093            q5h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_1, hmask), bit++), 4);
2094            q5_0  = _mm_add_epi8(q5l_0, q5h_0);
2095            q5_1  = _mm_add_epi8(q5l_1, q5h_1);
2096            hmask = _mm_slli_epi16(hmask, 1);
2097
2098            q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
2099            q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
2100            __m128i p16_2 = _mm_maddubs_epi16(q5_0, q8_0);
2101            __m128i p16_3 = _mm_maddubs_epi16(q5_1, q8_1);
2102            p16_2 = _mm_madd_epi16(scale_1, p16_2);
2103            p16_3 = _mm_madd_epi16(scale_1, p16_3);
2104
2105            sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
2106            sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3));
2107
2108        }
2109
2110        __m256 vd = _mm256_set1_ps(d);
2111        __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
2112        acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(sumi)), acc);
2113
2114    }
2115
2116    *s = hsum_float_8(acc) + summs;
2117
2118#else
2119    UNUSED(x);
2120    UNUSED(y);
2121    UNUSED(nb);
2122    UNUSED(kmask1);
2123    UNUSED(kmask2);
2124    UNUSED(kmask3);
2125    UNUSED(utmp);
2126    ggml_vec_dot_q5_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
2127#endif
2128}
2129
2130void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
2131    assert(n % QK_K == 0);
2132    assert(nrc == 1);
2133    UNUSED(nrc);
2134    UNUSED(bx);
2135    UNUSED(by);
2136    UNUSED(bs);
2137
2138    const block_q6_K * GGML_RESTRICT x = vx;
2139    const block_q8_K * GGML_RESTRICT y = vy;
2140
2141    const int nb = n / QK_K;
2142
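    // q6_K: 16 sub-blocks of 16 values with signed int8 scales; each quant is 6 bits,
    // q = (4 low bits from ql) | (2 high bits from qh << 4), dequantized as sc_j * (q - 32).
    // Both paths compute y.d * x.d * sum_j sc_j * sum_l (q - 32) * y; the -32 offset is applied
    // either per product (AVX2, via maddubs with 32) or once per block through bsums
    // (AVX, the q8sclsub terms).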
2143#if defined __AVX2__
2144
2145    const __m256i m4 = _mm256_set1_epi8(0xF);
2146    const __m256i m2 = _mm256_set1_epi8(3); // 0x03 mask for extracting the 2-bit qh fields
2147    const __m256i m32s = _mm256_set1_epi8(32);
2148
2149    __m256 acc = _mm256_setzero_ps();
2150
2151    for (int i = 0; i < nb; ++i) {
2152
2153        const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
2154
2155        const uint8_t * GGML_RESTRICT q4 = x[i].ql;
2156        const uint8_t * GGML_RESTRICT qh = x[i].qh;
2157        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
2158
2159        const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales);
2160
2161        __m256i sumi = _mm256_setzero_si256();
2162
2163        int is = 0;
2164
2165        for (int j = 0; j < QK_K/128; ++j) {
2166
2167            const __m128i scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 0));
2168            const __m128i scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 1));
2169            const __m128i scale_2 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 2));
2170            const __m128i scale_3 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 3));
2171            is += 4;
2172
2173            const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4); q4 += 32;
2174            const __m256i q4bits2 = _mm256_loadu_si256((const __m256i*)q4); q4 += 32;
2175            const __m256i q4bitsH = _mm256_loadu_si256((const __m256i*)qh); qh += 32;
2176
2177            const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(q4bitsH, m2), 4);
2178            const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 2), m2), 4);
2179            const __m256i q4h_2 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 4), m2), 4);
2180            const __m256i q4h_3 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 6), m2), 4);
2181
2182            const __m256i q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0);
2183            const __m256i q4_1 = _mm256_or_si256(_mm256_and_si256(q4bits2, m4), q4h_1);
2184            const __m256i q4_2 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_2);
2185            const __m256i q4_3 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits2, 4), m4), q4h_3);
2186
2187            const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
2188            const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
2189            const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
2190            const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
2191
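            // 32 * (pairwise q8 sums): subtracted from the matching products below so the -32
            // offset of q6_K values is applied before the per-sub-block scales.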
2192            __m256i q8s_0 = _mm256_maddubs_epi16(m32s, q8_0);
2193            __m256i q8s_1 = _mm256_maddubs_epi16(m32s, q8_1);
2194            __m256i q8s_2 = _mm256_maddubs_epi16(m32s, q8_2);
2195            __m256i q8s_3 = _mm256_maddubs_epi16(m32s, q8_3);
2196
2197            __m256i p16_0 = _mm256_maddubs_epi16(q4_0, q8_0);
2198            __m256i p16_1 = _mm256_maddubs_epi16(q4_1, q8_1);
2199            __m256i p16_2 = _mm256_maddubs_epi16(q4_2, q8_2);
2200            __m256i p16_3 = _mm256_maddubs_epi16(q4_3, q8_3);
2201
2202            p16_0 = _mm256_sub_epi16(p16_0, q8s_0);
2203            p16_1 = _mm256_sub_epi16(p16_1, q8s_1);
2204            p16_2 = _mm256_sub_epi16(p16_2, q8s_2);
2205            p16_3 = _mm256_sub_epi16(p16_3, q8s_3);
2206
2207            p16_0 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_0), p16_0);
2208            p16_1 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_1), p16_1);
2209            p16_2 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_2), p16_2);
2210            p16_3 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_3), p16_3);
2211
2212            sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1));
2213            sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_2, p16_3));
2214
2215        }
2216
2217        acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
2218    }
2219
2220    *s = hsum_float_8(acc);
2221
2222#elif defined __AVX__
2223
2224    const __m128i m3 = _mm_set1_epi8(3);
2225    const __m128i m15 = _mm_set1_epi8(15);
2226
2227    __m256 acc = _mm256_setzero_ps();
2228
2229    for (int i = 0; i < nb; ++i) {
2230
2231        const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
2232
2233        const uint8_t * GGML_RESTRICT q4 = x[i].ql;
2234        const uint8_t * GGML_RESTRICT qh = x[i].qh;
2235        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
2236
2237        // handle the q6_k -32 offset separately using bsums
2238        const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)y[i].bsums);
2239        const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)y[i].bsums + 1);
2240        const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales);
2241        const __m128i scales_16_0 = _mm_cvtepi8_epi16(scales);
2242        const __m128i scales_16_1 = _mm_cvtepi8_epi16(_mm_bsrli_si128(scales, 8));
2243        const __m128i q8sclsub_0 = _mm_slli_epi32(_mm_madd_epi16(q8sums_0, scales_16_0), 5);
2244        const __m128i q8sclsub_1 = _mm_slli_epi32(_mm_madd_epi16(q8sums_1, scales_16_1), 5);
2245
2246        __m128i sumi_0 = _mm_setzero_si128();
2247        __m128i sumi_1 = _mm_setzero_si128();
2248
2249        int is = 0;
2250
2251        for (int j = 0; j < QK_K/128; ++j) {
2252
2253            const __m128i q4bitsH_0 = _mm_loadu_si128((const __m128i*)qh); qh += 16;
2254            const __m128i q4bitsH_1 = _mm_loadu_si128((const __m128i*)qh); qh += 16;
2255
2256            const __m128i q4h_0 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, m3), 4);
2257            const __m128i q4h_1 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, m3), 4);
2258            const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, _mm_set1_epi8(12)), 2);
2259            const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, _mm_set1_epi8(12)), 2);
2260            const __m128i q4h_4 = _mm_and_si128(q4bitsH_0, _mm_set1_epi8(48));
2261            const __m128i q4h_5 = _mm_and_si128(q4bitsH_1, _mm_set1_epi8(48));
2262            const __m128i q4h_6 = _mm_srli_epi16(_mm_and_si128(q4bitsH_0, _mm_set1_epi8(-64)), 2);
2263            const __m128i q4h_7 = _mm_srli_epi16(_mm_and_si128(q4bitsH_1, _mm_set1_epi8(-64)), 2);
2264
2265            const __m128i q4bits1_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
2266            const __m128i q4bits1_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
2267            const __m128i q4bits2_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
2268            const __m128i q4bits2_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
2269
2270            const __m128i q4_0 = _mm_or_si128(_mm_and_si128(q4bits1_0, m15), q4h_0);
2271            const __m128i q4_1 = _mm_or_si128(_mm_and_si128(q4bits1_1, m15), q4h_1);
2272            const __m128i q4_2 = _mm_or_si128(_mm_and_si128(q4bits2_0, m15), q4h_2);
2273            const __m128i q4_3 = _mm_or_si128(_mm_and_si128(q4bits2_1, m15), q4h_3);
2274            const __m128i q4_4 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_0, 4), m15), q4h_4);
2275            const __m128i q4_5 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_1, 4), m15), q4h_5);
2276            const __m128i q4_6 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_0, 4), m15), q4h_6);
2277            const __m128i q4_7 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_1, 4), m15), q4h_7);
2278
2279            const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
2280            const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
2281            const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
2282            const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
2283            const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
2284            const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
2285            const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
2286            const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
2287
2288            __m128i p16_0 = _mm_maddubs_epi16(q4_0, q8_0);
2289            __m128i p16_1 = _mm_maddubs_epi16(q4_1, q8_1);
2290            __m128i p16_2 = _mm_maddubs_epi16(q4_2, q8_2);
2291            __m128i p16_3 = _mm_maddubs_epi16(q4_3, q8_3);
2292            __m128i p16_4 = _mm_maddubs_epi16(q4_4, q8_4);
2293            __m128i p16_5 = _mm_maddubs_epi16(q4_5, q8_5);
2294            __m128i p16_6 = _mm_maddubs_epi16(q4_6, q8_6);
2295            __m128i p16_7 = _mm_maddubs_epi16(q4_7, q8_7);
2296
2297            const __m128i scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 0));
2298            const __m128i scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 1));
2299            const __m128i scale_2 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 2));
2300            const __m128i scale_3 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 3));
2301            is += 4;
2302
2303            p16_0 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0);
2304            p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_0, 8)), p16_1);
2305            p16_2 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_2);
2306            p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_1, 8)), p16_3);
2307            p16_4 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_2), p16_4);
2308            p16_5 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_2, 8)), p16_5);
2309            p16_6 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_3), p16_6);
2310            p16_7 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_3, 8)), p16_7);
2311
2312            sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
2313            sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3));
2314            sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_4, p16_6));
2315            sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_5, p16_7));
2316
2317        }
2318
2319        sumi_0 = _mm_sub_epi32(sumi_0, q8sclsub_0);
2320        sumi_1 = _mm_sub_epi32(sumi_1, q8sclsub_1);
2321        const __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
2322        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(sumi)), acc);
2323    }
2324
2325    *s = hsum_float_8(acc);
2326
2327#else
2328    UNUSED(x);
2329    UNUSED(y);
2330    UNUSED(nb);
2331    ggml_vec_dot_q6_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
2332#endif
2333}
2334
2335#if defined (__AVX__) || defined (__AVX2__)
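// Sign patterns for the iq2 family: each 8-byte row holds eight +/-1 factors. The 7-bit index
// supplies the first seven signs and the eighth is chosen so that the number of -1s is even
// (hence "keven"), which is how the quantized sign bits are encoded.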
2336static const int8_t keven_signs_q2xs[1024] = {
2337     1,  1,  1,  1,  1,  1,  1,  1, -1,  1,  1,  1,  1,  1,  1, -1,  1, -1,  1,  1,  1,  1,  1, -1, -1, -1,  1,  1,  1,  1,  1,  1,
2338     1,  1, -1,  1,  1,  1,  1, -1, -1,  1, -1,  1,  1,  1,  1,  1,  1, -1, -1,  1,  1,  1,  1,  1, -1, -1, -1,  1,  1,  1,  1, -1,
2339     1,  1,  1, -1,  1,  1,  1, -1, -1,  1,  1, -1,  1,  1,  1,  1,  1, -1,  1, -1,  1,  1,  1,  1, -1, -1,  1, -1,  1,  1,  1, -1,
2340     1,  1, -1, -1,  1,  1,  1,  1, -1,  1, -1, -1,  1,  1,  1, -1,  1, -1, -1, -1,  1,  1,  1, -1, -1, -1, -1, -1,  1,  1,  1,  1,
2341     1,  1,  1,  1, -1,  1,  1, -1, -1,  1,  1,  1, -1,  1,  1,  1,  1, -1,  1,  1, -1,  1,  1,  1, -1, -1,  1,  1, -1,  1,  1, -1,
2342     1,  1, -1,  1, -1,  1,  1,  1, -1,  1, -1,  1, -1,  1,  1, -1,  1, -1, -1,  1, -1,  1,  1, -1, -1, -1, -1,  1, -1,  1,  1,  1,
2343     1,  1,  1, -1, -1,  1,  1,  1, -1,  1,  1, -1, -1,  1,  1, -1,  1, -1,  1, -1, -1,  1,  1, -1, -1, -1,  1, -1, -1,  1,  1,  1,
2344     1,  1, -1, -1, -1,  1,  1, -1, -1,  1, -1, -1, -1,  1,  1,  1,  1, -1, -1, -1, -1,  1,  1,  1, -1, -1, -1, -1, -1,  1,  1, -1,
2345     1,  1,  1,  1,  1, -1,  1, -1, -1,  1,  1,  1,  1, -1,  1,  1,  1, -1,  1,  1,  1, -1,  1,  1, -1, -1,  1,  1,  1, -1,  1, -1,
2346     1,  1, -1,  1,  1, -1,  1,  1, -1,  1, -1,  1,  1, -1,  1, -1,  1, -1, -1,  1,  1, -1,  1, -1, -1, -1, -1,  1,  1, -1,  1,  1,
2347     1,  1,  1, -1,  1, -1,  1,  1, -1,  1,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1, -1, -1,  1, -1,  1, -1,  1,  1,
2348     1,  1, -1, -1,  1, -1,  1, -1, -1,  1, -1, -1,  1, -1,  1,  1,  1, -1, -1, -1,  1, -1,  1,  1, -1, -1, -1, -1,  1, -1,  1, -1,
2349     1,  1,  1,  1, -1, -1,  1,  1, -1,  1,  1,  1, -1, -1,  1, -1,  1, -1,  1,  1, -1, -1,  1, -1, -1, -1,  1,  1, -1, -1,  1,  1,
2350     1,  1, -1,  1, -1, -1,  1, -1, -1,  1, -1,  1, -1, -1,  1,  1,  1, -1, -1,  1, -1, -1,  1,  1, -1, -1, -1,  1, -1, -1,  1, -1,
2351     1,  1,  1, -1, -1, -1,  1, -1, -1,  1,  1, -1, -1, -1,  1,  1,  1, -1,  1, -1, -1, -1,  1,  1, -1, -1,  1, -1, -1, -1,  1, -1,
2352     1,  1, -1, -1, -1, -1,  1,  1, -1,  1, -1, -1, -1, -1,  1, -1,  1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1,  1,  1,
2353     1,  1,  1,  1,  1,  1, -1, -1, -1,  1,  1,  1,  1,  1, -1,  1,  1, -1,  1,  1,  1,  1, -1,  1, -1, -1,  1,  1,  1,  1, -1, -1,
2354     1,  1, -1,  1,  1,  1, -1,  1, -1,  1, -1,  1,  1,  1, -1, -1,  1, -1, -1,  1,  1,  1, -1, -1, -1, -1, -1,  1,  1,  1, -1,  1,
2355     1,  1,  1, -1,  1,  1, -1,  1, -1,  1,  1, -1,  1,  1, -1, -1,  1, -1,  1, -1,  1,  1, -1, -1, -1, -1,  1, -1,  1,  1, -1,  1,
2356     1,  1, -1, -1,  1,  1, -1, -1, -1,  1, -1, -1,  1,  1, -1,  1,  1, -1, -1, -1,  1,  1, -1,  1, -1, -1, -1, -1,  1,  1, -1, -1,
2357     1,  1,  1,  1, -1,  1, -1,  1, -1,  1,  1,  1, -1,  1, -1, -1,  1, -1,  1,  1, -1,  1, -1, -1, -1, -1,  1,  1, -1,  1, -1,  1,
2358     1,  1, -1,  1, -1,  1, -1, -1, -1,  1, -1,  1, -1,  1, -1,  1,  1, -1, -1,  1, -1,  1, -1,  1, -1, -1, -1,  1, -1,  1, -1, -1,
2359     1,  1,  1, -1, -1,  1, -1, -1, -1,  1,  1, -1, -1,  1, -1,  1,  1, -1,  1, -1, -1,  1, -1,  1, -1, -1,  1, -1, -1,  1, -1, -1,
2360     1,  1, -1, -1, -1,  1, -1,  1, -1,  1, -1, -1, -1,  1, -1, -1,  1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1,  1, -1,  1,
2361     1,  1,  1,  1,  1, -1, -1,  1, -1,  1,  1,  1,  1, -1, -1, -1,  1, -1,  1,  1,  1, -1, -1, -1, -1, -1,  1,  1,  1, -1, -1,  1,
2362     1,  1, -1,  1,  1, -1, -1, -1, -1,  1, -1,  1,  1, -1, -1,  1,  1, -1, -1,  1,  1, -1, -1,  1, -1, -1, -1,  1,  1, -1, -1, -1,
2363     1,  1,  1, -1,  1, -1, -1, -1, -1,  1,  1, -1,  1, -1, -1,  1,  1, -1,  1, -1,  1, -1, -1,  1, -1, -1,  1, -1,  1, -1, -1, -1,
2364     1,  1, -1, -1,  1, -1, -1,  1, -1,  1, -1, -1,  1, -1, -1, -1,  1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1,  1,
2365     1,  1,  1,  1, -1, -1, -1, -1, -1,  1,  1,  1, -1, -1, -1,  1,  1, -1,  1,  1, -1, -1, -1,  1, -1, -1,  1,  1, -1, -1, -1, -1,
2366     1,  1, -1,  1, -1, -1, -1,  1, -1,  1, -1,  1, -1, -1, -1, -1,  1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1,  1,
2367     1,  1,  1, -1, -1, -1, -1,  1, -1,  1,  1, -1, -1, -1, -1, -1,  1, -1,  1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1,  1,
2368     1,  1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1,  1,  1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1, -1,
2369};
2370#endif
2371
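// Dot product of an iq2_xxs row with a q8_K row.
// Layout as consumed below: per 32-value sub-block, four bytes of x[i].qs are 8-bit
// indices into iq2xxs_grid (8 values each) and the following 32-bit word packs four
// 7-bit indices into keven_signs_q2xs plus a 4-bit sub-block scale in its top bits.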
2372void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
2373    assert(n % QK_K == 0);
2374    assert(nrc == 1);
2375    UNUSED(nrc);
2376    UNUSED(bx);
2377    UNUSED(by);
2378    UNUSED(bs);
2379
2380    const block_iq2_xxs * GGML_RESTRICT x = vx;
2381    const block_q8_K    * GGML_RESTRICT y = vy;
2382
2383    const int nb = n / QK_K;
2384
2385#if defined(__AVX2__)
2386
2387    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
2388
2389    uint32_t aux32[4];
2390    const uint8_t * aux8 = (const uint8_t *)aux32;
2391
2392    __m256 accumf = _mm256_setzero_ps();
2393    for (int i = 0; i < nb; ++i) {
2394        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
2395        const uint16_t * GGML_RESTRICT q2 = x[i].qs;
2396        const int8_t   * GGML_RESTRICT q8 = y[i].qs;
2397        __m256i sumi1 = _mm256_setzero_si256();
2398        __m256i sumi2 = _mm256_setzero_si256();
2399        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
2400            const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
2401            const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
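            // aux32[0]/aux32[2] hold four 8-bit grid indices each; aux32[1]/aux32[3] hold
            // four 7-bit sign-table indices plus a 4-bit sub-block scale in the top bits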
2402            memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8;
2403            const __m256i q2_1 = _mm256_set_epi64x(iq2xxs_grid[aux8[ 3]], iq2xxs_grid[aux8[ 2]], iq2xxs_grid[aux8[1]], iq2xxs_grid[aux8[0]]);
2404            const __m256i q2_2 = _mm256_set_epi64x(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]], iq2xxs_grid[aux8[9]], iq2xxs_grid[aux8[8]]);
2405            const __m256i s2_1 = _mm256_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127],
2406                                                   signs64[(aux32[1] >>  7) & 127], signs64[(aux32[1] >>  0) & 127]);
2407            const __m256i s2_2 = _mm256_set_epi64x(signs64[(aux32[3] >> 21) & 127], signs64[(aux32[3] >> 14) & 127],
2408                                                   signs64[(aux32[3] >>  7) & 127], signs64[(aux32[3] >>  0) & 127]);
2409            const __m256i q8s_1 = _mm256_sign_epi8(q8_1, s2_1);
2410            const __m256i q8s_2 = _mm256_sign_epi8(q8_2, s2_2);
2411            const __m256i dot1  = _mm256_maddubs_epi16(q2_1, q8s_1);
2412            const __m256i dot2  = _mm256_maddubs_epi16(q2_2, q8s_2);
2413            const uint16_t ls1 = aux32[1] >> 28;
2414            const uint16_t ls2 = aux32[3] >> 28;
2415            const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*ls1+1));
2416            const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*ls2+1));
2417            sumi1 = _mm256_add_epi32(sumi1, p1);
2418            sumi2 = _mm256_add_epi32(sumi2, p2);
2419        }
2420
2421        accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);
2422
2423    }
2424
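    // the sub-block scales were applied as the odd integers 2*ls+1; the 0.125f factor
    // turns this into the effective per-sub-block scale (ls + 0.5) * 0.25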
2425    *s = 0.125f * hsum_float_8(accumf);
2426
2427#elif defined(__AVX__)
2428    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
2429
2430    uint32_t aux32[4];
2431    const uint8_t * aux8 = (const uint8_t *)aux32;
2432
2433    __m256 accumf = _mm256_setzero_ps();
2434    for (int i = 0; i < nb; ++i) {
2435        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
2436        const uint16_t * GGML_RESTRICT q2 = x[i].qs;
2437        const int8_t   * GGML_RESTRICT q8 = y[i].qs;
2438        __m128i sumi1_0 = _mm_setzero_si128();
2439        __m128i sumi1_1 = _mm_setzero_si128();
2440        __m128i sumi2_0 = _mm_setzero_si128();
2441        __m128i sumi2_1 = _mm_setzero_si128();
2442        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
2443            const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
2444            const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
2445            const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
2446            const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
2447            memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8;
2448            const __m128i q2_1_0 = _mm_set_epi64x(iq2xxs_grid[aux8[1]], iq2xxs_grid[aux8[0]]);
2449            const __m128i q2_1_1 = _mm_set_epi64x(iq2xxs_grid[aux8[3]], iq2xxs_grid[aux8[2]]);
2450            const __m128i q2_2_0 = _mm_set_epi64x(iq2xxs_grid[aux8[9]], iq2xxs_grid[aux8[8]]);
2451            const __m128i q2_2_1 = _mm_set_epi64x(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]]);
2452            const __m128i s2_1_0 = _mm_set_epi64x(signs64[(aux32[1] >>  7) & 127], signs64[(aux32[1] >>  0) & 127]);
2453            const __m128i s2_1_1 = _mm_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127]);
2454            const __m128i s2_2_0 = _mm_set_epi64x(signs64[(aux32[3] >>  7) & 127], signs64[(aux32[3] >>  0) & 127]);
2455            const __m128i s2_2_1 = _mm_set_epi64x(signs64[(aux32[3] >> 21) & 127], signs64[(aux32[3] >> 14) & 127]);
2456            const __m128i q8s_1_0 = _mm_sign_epi8(q8_1_0, s2_1_0);
2457            const __m128i q8s_1_1 = _mm_sign_epi8(q8_1_1, s2_1_1);
2458            const __m128i q8s_2_0 = _mm_sign_epi8(q8_2_0, s2_2_0);
2459            const __m128i q8s_2_1 = _mm_sign_epi8(q8_2_1, s2_2_1);
2460            const __m128i dot1_0  = _mm_maddubs_epi16(q2_1_0, q8s_1_0);
2461            const __m128i dot1_1  = _mm_maddubs_epi16(q2_1_1, q8s_1_1);
2462            const __m128i dot2_0  = _mm_maddubs_epi16(q2_2_0, q8s_2_0);
2463            const __m128i dot2_1  = _mm_maddubs_epi16(q2_2_1, q8s_2_1);
2464            const uint16_t ls1 = aux32[1] >> 28;
2465            const uint16_t ls2 = aux32[3] >> 28;
2466            const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(2*ls1+1));
2467            const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(2*ls1+1));
2468            const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(2*ls2+1));
2469            const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(2*ls2+1));
2470            sumi1_0 = _mm_add_epi32(sumi1_0, p1_0);
2471            sumi1_1 = _mm_add_epi32(sumi1_1, p1_1);
2472            sumi2_0 = _mm_add_epi32(sumi2_0, p2_0);
2473            sumi2_1 = _mm_add_epi32(sumi2_1, p2_1);
2474        }
2475
2476        accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf);
2477
2478    }
2479
2480    *s = 0.125f * hsum_float_8(accumf);
2481
2482#else
2483    UNUSED(x);
2484    UNUSED(y);
2485    UNUSED(nb);
2486    ggml_vec_dot_iq2_xxs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
2487#endif
2488}
2489
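// Dot product of an iq2_xs row with a q8_K row.
// Each 16-bit entry of x[i].qs packs a 9-bit index into iq2xs_grid (bits 0..8) and
// 7 explicit sign bits (bits 9..15); the 8th sign of each group of 8 values follows
// from even parity and is reconstructed in the SIMD paths below. x[i].scales holds
// 16 4-bit scales (one per group of 16 values), two per byte.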
2490void ggml_vec_dot_iq2_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
2491    assert(n % QK_K == 0);
2492    assert(nrc == 1);
2493    UNUSED(nrc);
2494    UNUSED(bx);
2495    UNUSED(by);
2496    UNUSED(bs);
2497
2498    const block_iq2_xs * GGML_RESTRICT x = vx;
2499    const block_q8_K   * GGML_RESTRICT y = vy;
2500
2501    const int nb = n / QK_K;
2502
2503#if defined(__AVX2__)
2504
2505    const __m256i mone = _mm256_set1_epi8(1);
2506    static const char block_sign_shuffle_mask_1[32] = {
2507        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02,
2508        0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06,
2509    };
2510    static const char block_sign_shuffle_mask_2[32] = {
2511        0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a,
2512        0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e,
2513    };
2514    static const uint8_t bit_selector_mask_bytes[32] = {
2515        0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
2516        0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
2517    };
2518
2519    const __m256i bit_selector_mask = _mm256_loadu_si256((const __m256i*)bit_selector_mask_bytes);
2520    const __m256i block_sign_shuffle_1 = _mm256_loadu_si256((const __m256i*)block_sign_shuffle_mask_1);
2521    const __m256i block_sign_shuffle_2 = _mm256_loadu_si256((const __m256i*)block_sign_shuffle_mask_2);
2522
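    // k_bit_helper[i] is 0x80 when the 4-bit index i has odd parity; shuffling it with the
    // xor-folded sign bits recovers the implicit 8th sign bit of each group of 8 values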
2523    static const uint8_t k_bit_helper[32] = {
2524        0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
2525        0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
2526    };
2527    const __m256i bit_helper = _mm256_loadu_si256((const __m256i*)k_bit_helper);
2528    const __m256i m511 = _mm256_set1_epi16(511);
2529    const __m128i m4 = _mm_set1_epi8(0xf);
2530    const __m128i m1 = _mm_set1_epi8(1);
2531
2532    uint64_t aux64;
2533
2534    // somewhat hacky: the grid indices are stored to aux_gindex as a vector and read back through the scalar uint16_t alias gindex, but this gives a significant boost in performance
2535    __m256i aux_gindex;
2536    const uint16_t * gindex = (const uint16_t *)&aux_gindex;
2537
2538    __m256 accumf = _mm256_setzero_ps();
2539    for (int i = 0; i < nb; ++i) {
2540        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
2541        const uint16_t * GGML_RESTRICT q2 = x[i].qs;
2542        const int8_t   * GGML_RESTRICT q8 = y[i].qs;
2543
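        // unpack the 16 4-bit sub-block scales (two per byte) and map each one to 2*scale + 1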
2544        memcpy(&aux64, x[i].scales, 8);
2545        __m128i stmp = _mm_set1_epi64x(aux64);
2546        stmp = _mm_unpacklo_epi8(_mm_and_si128(stmp, m4), _mm_and_si128(_mm_srli_epi16(stmp, 4), m4));
2547        const __m128i scales = _mm_add_epi8(_mm_slli_epi16(stmp, 1), m1);
2548
2549        __m256i sumi1 = _mm256_setzero_si256();
2550        __m256i sumi2 = _mm256_setzero_si256();
2551        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 4) {
2552
2553            const __m256i q2_data = _mm256_loadu_si256((const __m256i*)q2);  q2 += 16;
2554            aux_gindex = _mm256_and_si256(q2_data, m511);
2555
2556            const __m256i partial_sign_bits = _mm256_srli_epi16(q2_data, 9);
2557            const __m256i partial_sign_bits_upper = _mm256_srli_epi16(q2_data, 13);
2558            const __m256i partial_sign_bits_for_counting = _mm256_xor_si256(partial_sign_bits, partial_sign_bits_upper);
2559
2560            const __m256i odd_bits = _mm256_shuffle_epi8(bit_helper, partial_sign_bits_for_counting);
2561            const __m256i full_sign_bits = _mm256_or_si256(partial_sign_bits, odd_bits);
2562
2563            const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
2564            const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
2565            const __m256i q8_3 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
2566            const __m256i q8_4 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
2567
2568            const __m256i q2_1 = _mm256_set_epi64x(iq2xs_grid[gindex[ 3]], iq2xs_grid[gindex[ 2]],
2569                                                   iq2xs_grid[gindex[ 1]], iq2xs_grid[gindex[ 0]]);
2570            const __m256i q2_2 = _mm256_set_epi64x(iq2xs_grid[gindex[ 7]], iq2xs_grid[gindex[ 6]],
2571                                                   iq2xs_grid[gindex[ 5]], iq2xs_grid[gindex[ 4]]);
2572            const __m256i q2_3 = _mm256_set_epi64x(iq2xs_grid[gindex[11]], iq2xs_grid[gindex[10]],
2573                                                   iq2xs_grid[gindex[ 9]], iq2xs_grid[gindex[ 8]]);
2574            const __m256i q2_4 = _mm256_set_epi64x(iq2xs_grid[gindex[15]], iq2xs_grid[gindex[14]],
2575                                                   iq2xs_grid[gindex[13]], iq2xs_grid[gindex[12]]);
2576
2577            const __m128i full_signs_l = _mm256_castsi256_si128(full_sign_bits);
2578            const __m128i full_signs_h = _mm256_extractf128_si256(full_sign_bits, 1);
2579            const __m256i full_signs_1 = MM256_SET_M128I(full_signs_l, full_signs_l);
2580            const __m256i full_signs_2 = MM256_SET_M128I(full_signs_h, full_signs_h);
2581
2582            __m256i signs;
2583            signs = _mm256_shuffle_epi8(full_signs_1, block_sign_shuffle_1);
2584            signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
2585            const __m256i q8s_1 = _mm256_sign_epi8(q8_1, _mm256_or_si256(signs, mone));
2586
2587            signs = _mm256_shuffle_epi8(full_signs_1, block_sign_shuffle_2);
2588            signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
2589            const __m256i q8s_2 = _mm256_sign_epi8(q8_2, _mm256_or_si256(signs, mone));
2590
2591            signs = _mm256_shuffle_epi8(full_signs_2, block_sign_shuffle_1);
2592            signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
2593            const __m256i q8s_3 = _mm256_sign_epi8(q8_3, _mm256_or_si256(signs, mone));
2594
2595            signs = _mm256_shuffle_epi8(full_signs_2, block_sign_shuffle_2);
2596            signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
2597            const __m256i q8s_4 = _mm256_sign_epi8(q8_4, _mm256_or_si256(signs, mone));
2598
2599            const __m256i dot1  = _mm256_maddubs_epi16(q2_1, q8s_1);
2600            const __m256i dot2  = _mm256_maddubs_epi16(q2_2, q8s_2);
2601            const __m256i dot3  = _mm256_maddubs_epi16(q2_3, q8s_3);
2602            const __m256i dot4  = _mm256_maddubs_epi16(q2_4, q8s_4);
2603
2604            const __m256i sc1 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+0)));
2605            const __m256i sc2 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+1)));
2606            const __m256i sc3 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+2)));
2607            const __m256i sc4 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+3)));
2608
2609            sumi1 = _mm256_add_epi32(sumi1, _mm256_madd_epi16(dot1, sc1));
2610            sumi2 = _mm256_add_epi32(sumi2, _mm256_madd_epi16(dot2, sc2));
2611            sumi1 = _mm256_add_epi32(sumi1, _mm256_madd_epi16(dot3, sc3));
2612            sumi2 = _mm256_add_epi32(sumi2, _mm256_madd_epi16(dot4, sc4));
2613        }
2614
2615        accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);
2616
2617    }
2618
2619    *s = 0.125f * hsum_float_8(accumf);
2620
2621#elif defined(__AVX__)
2622    const __m128i mone = _mm_set1_epi8(1);
2623    static const char block_sign_shuffle_mask_1[32] = {
2624        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02,
2625        0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06,
2626    };
2627    static const char block_sign_shuffle_mask_2[32] = {
2628        0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a,
2629        0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e,
2630    };
2631    static const uint8_t bit_selector_mask_bytes[32] = {
2632        0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
2633        0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
2634    };
2635
2636    const __m128i bit_selector_mask_0 = _mm_loadu_si128((const __m128i*)bit_selector_mask_bytes);
2637    const __m128i bit_selector_mask_1 = _mm_loadu_si128((const __m128i*)bit_selector_mask_bytes + 1);
2638    const __m128i block_sign_shuffle_1_0 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_1);
2639    const __m128i block_sign_shuffle_1_1 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_1 + 1);
2640    const __m128i block_sign_shuffle_2_0 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_2);
2641    const __m128i block_sign_shuffle_2_1 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_2 + 1);
2642
2643    static const uint8_t k_bit_helper[32] = {
2644        0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
2645        0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
2646    };
2647    const __m128i bit_helper_0 = _mm_loadu_si128((const __m128i*)k_bit_helper);
2648    const __m128i bit_helper_1 = _mm_loadu_si128((const __m128i*)k_bit_helper + 1);
2649    const __m128i m511 = _mm_set1_epi16(511);
2650    const __m128i m4 = _mm_set1_epi8(0xf);
2651    const __m128i m1 = _mm_set1_epi8(1);
2652
2653    uint64_t aux64;
2654
2655    // somewhat hacky: the grid indices are stored to aux_gindex as a vector and read back through the scalar uint16_t alias gindex, but this gives a significant boost in performance
2656    __m256i aux_gindex;
2657    const uint16_t * gindex = (const uint16_t *)&aux_gindex;
2658
2659    __m256 accumf = _mm256_setzero_ps();
2660    for (int i = 0; i < nb; ++i) {
2661        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
2662        const uint16_t * GGML_RESTRICT q2 = x[i].qs;
2663        const int8_t   * GGML_RESTRICT q8 = y[i].qs;
2664
2665        memcpy(&aux64, x[i].scales, 8);
2666        __m128i stmp = _mm_set1_epi64x(aux64);
2667        stmp = _mm_unpacklo_epi8(_mm_and_si128(stmp, m4), _mm_and_si128(_mm_srli_epi16(stmp, 4), m4));
2668        const __m128i scales = _mm_add_epi8(_mm_slli_epi16(stmp, 1), m1);
2669
2670        __m128i sumi1_0 = _mm_setzero_si128();
2671        __m128i sumi1_1 = _mm_setzero_si128();
2672        __m128i sumi2_0 = _mm_setzero_si128();
2673        __m128i sumi2_1 = _mm_setzero_si128();
2674        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 4) {
2675
2676            const __m128i q2_data_0 = _mm_loadu_si128((const __m128i*)q2);
2677            const __m128i q2_data_1 = _mm_loadu_si128((const __m128i*)q2 + 1);  q2 += 16;
2678            aux_gindex = MM256_SET_M128I(_mm_and_si128(q2_data_1, m511), _mm_and_si128(q2_data_0, m511));
2679
2680            const __m128i partial_sign_bits_0 = _mm_srli_epi16(q2_data_0, 9);
2681            const __m128i partial_sign_bits_1 = _mm_srli_epi16(q2_data_1, 9);
2682            const __m128i partial_sign_bits_upper_0 = _mm_srli_epi16(q2_data_0, 13);
2683            const __m128i partial_sign_bits_upper_1 = _mm_srli_epi16(q2_data_1, 13);
2684            const __m128i partial_sign_bits_for_counting_0 = _mm_xor_si128(partial_sign_bits_0, partial_sign_bits_upper_0);
2685            const __m128i partial_sign_bits_for_counting_1 = _mm_xor_si128(partial_sign_bits_1, partial_sign_bits_upper_1);
2686
2687            const __m128i odd_bits_0 = _mm_shuffle_epi8(bit_helper_0, partial_sign_bits_for_counting_0);
2688            const __m128i odd_bits_1 = _mm_shuffle_epi8(bit_helper_1, partial_sign_bits_for_counting_1);
2689            const __m128i full_sign_bits_0 = _mm_or_si128(partial_sign_bits_0, odd_bits_0);
2690            const __m128i full_sign_bits_1 = _mm_or_si128(partial_sign_bits_1, odd_bits_1);
2691
2692            const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
2693            const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
2694            const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
2695            const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
2696            const __m128i q8_3_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
2697            const __m128i q8_3_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
2698            const __m128i q8_4_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
2699            const __m128i q8_4_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
2700
2701            const __m128i q2_1_0 = _mm_set_epi64x(iq2xs_grid[gindex[1]], iq2xs_grid[gindex[0]]);
2702            const __m128i q2_1_1 = _mm_set_epi64x(iq2xs_grid[gindex[3]], iq2xs_grid[gindex[2]]);
2703            const __m128i q2_2_0 = _mm_set_epi64x(iq2xs_grid[gindex[5]], iq2xs_grid[gindex[4]]);
2704            const __m128i q2_2_1 = _mm_set_epi64x(iq2xs_grid[gindex[7]], iq2xs_grid[gindex[6]]);
2705            const __m128i q2_3_0 = _mm_set_epi64x(iq2xs_grid[gindex[9]], iq2xs_grid[gindex[8]]);
2706            const __m128i q2_3_1 = _mm_set_epi64x(iq2xs_grid[gindex[11]], iq2xs_grid[gindex[10]]);
2707            const __m128i q2_4_0 = _mm_set_epi64x(iq2xs_grid[gindex[13]], iq2xs_grid[gindex[12]]);
2708            const __m128i q2_4_1 = _mm_set_epi64x(iq2xs_grid[gindex[15]], iq2xs_grid[gindex[14]]);
2709
2710            // AVX2 full_signs_1 is full_sign_bits_0 here
2711            // AVX2 full_signs_2 is full_sign_bits_1 here
2712            __m128i signs_0, signs_1;
2713            signs_0 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_1_0);
2714            signs_1 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_1_1);
2715            signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0);
2716            signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1);
2717            const __m128i q8s_1_0 = _mm_sign_epi8(q8_1_0, _mm_or_si128(signs_0, mone));
2718            const __m128i q8s_1_1 = _mm_sign_epi8(q8_1_1, _mm_or_si128(signs_1, mone));
2719
2720            signs_0 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_2_0);
2721            signs_1 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_2_1);
2722            signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0);
2723            signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1);
2724            const __m128i q8s_2_0 = _mm_sign_epi8(q8_2_0, _mm_or_si128(signs_0, mone));
2725            const __m128i q8s_2_1 = _mm_sign_epi8(q8_2_1, _mm_or_si128(signs_1, mone));
2726
2727            signs_0 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_1_0);
2728            signs_1 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_1_1);
2729            signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0);
2730            signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1);
2731            const __m128i q8s_3_0 = _mm_sign_epi8(q8_3_0, _mm_or_si128(signs_0, mone));
2732            const __m128i q8s_3_1 = _mm_sign_epi8(q8_3_1, _mm_or_si128(signs_1, mone));
2733
2734            signs_0 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_2_0);
2735            signs_1 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_2_1);
2736            signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0);
2737            signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1);
2738            const __m128i q8s_4_0 = _mm_sign_epi8(q8_4_0, _mm_or_si128(signs_0, mone));
2739            const __m128i q8s_4_1 = _mm_sign_epi8(q8_4_1, _mm_or_si128(signs_1, mone));
2740
2741            const __m128i dot1_0  = _mm_maddubs_epi16(q2_1_0, q8s_1_0);
2742            const __m128i dot1_1  = _mm_maddubs_epi16(q2_1_1, q8s_1_1);
2743            const __m128i dot2_0  = _mm_maddubs_epi16(q2_2_0, q8s_2_0);
2744            const __m128i dot2_1  = _mm_maddubs_epi16(q2_2_1, q8s_2_1);
2745            const __m128i dot3_0  = _mm_maddubs_epi16(q2_3_0, q8s_3_0);
2746            const __m128i dot3_1  = _mm_maddubs_epi16(q2_3_1, q8s_3_1);
2747            const __m128i dot4_0  = _mm_maddubs_epi16(q2_4_0, q8s_4_0);
2748            const __m128i dot4_1  = _mm_maddubs_epi16(q2_4_1, q8s_4_1);
2749
2750            __m128i sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+0));
2751            const __m128i sc1_0 = _mm_cvtepi8_epi16(sc_tmp);
2752            const __m128i sc1_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8));
2753            sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+1));
2754            const __m128i sc2_0 = _mm_cvtepi8_epi16(sc_tmp);
2755            const __m128i sc2_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8));
2756            sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+2));
2757            const __m128i sc3_0 = _mm_cvtepi8_epi16(sc_tmp);
2758            const __m128i sc3_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8));
2759            sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+3));
2760            const __m128i sc4_0 = _mm_cvtepi8_epi16(sc_tmp);
2761            const __m128i sc4_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8));
2762
2763            sumi1_0 = _mm_add_epi32(sumi1_0, _mm_madd_epi16(dot1_0, sc1_0));
2764            sumi1_1 = _mm_add_epi32(sumi1_1, _mm_madd_epi16(dot1_1, sc1_1));
2765            sumi2_0 = _mm_add_epi32(sumi2_0, _mm_madd_epi16(dot2_0, sc2_0));
2766            sumi2_1 = _mm_add_epi32(sumi2_1, _mm_madd_epi16(dot2_1, sc2_1));
2767            sumi1_0 = _mm_add_epi32(sumi1_0, _mm_madd_epi16(dot3_0, sc3_0));
2768            sumi1_1 = _mm_add_epi32(sumi1_1, _mm_madd_epi16(dot3_1, sc3_1));
2769            sumi2_0 = _mm_add_epi32(sumi2_0, _mm_madd_epi16(dot4_0, sc4_0));
2770            sumi2_1 = _mm_add_epi32(sumi2_1, _mm_madd_epi16(dot4_1, sc4_1));
2771        }
2772
2773        accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf);
2774
2775    }
2776
2777    *s = 0.125f * hsum_float_8(accumf);
2778
2779#else
2780    UNUSED(x);
2781    UNUSED(y);
2782    UNUSED(nb);
2783    ggml_vec_dot_iq2_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
2784#endif
2785}
2786
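// Dot product of an iq2_s row with a q8_K row.
// Grid indices are 10 bits: one byte from x[i].qs plus two high bits from x[i].qh,
// indexing iq2s_grid. The per-value sign bits are stored explicitly right after the
// indices (at x[i].qs + QK_K/8); the 16 4-bit scales (one per group of 16 values),
// two per byte, are in x[i].scales and enter the sum as 2*ls + 1.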
2787void ggml_vec_dot_iq2_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
2788    assert(n % QK_K == 0);
2789    assert(nrc == 1);
2790    UNUSED(nrc);
2791    UNUSED(bx);
2792    UNUSED(by);
2793    UNUSED(bs);
2794
2795    const block_iq2_s * GGML_RESTRICT x = vx;
2796    const block_q8_K  * GGML_RESTRICT y = vy;
2797
2798    const int nb = n / QK_K;
2799
2800#if defined(__AVX2__)
2801
2802    static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
2803                                        0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03,
2804    };
2805
2806    static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
2807                                        0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
2808    };
2809
2810    const __m128i m4 = _mm_set1_epi8(0xf);
2811    const __m128i m1 = _mm_set1_epi8(1);
2812
2813    const __m256i mask1 = _mm256_loadu_si256((const __m256i*)k_mask1);
2814    const __m256i mask2 = _mm256_loadu_si256((const __m256i*)k_mask2);
2815
2816    uint64_t aux64;
2817
2818    __m256 accumf = _mm256_setzero_ps();
2819    for (int i = 0; i < nb; ++i) {
2820        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
2821        const uint8_t * GGML_RESTRICT qs = x[i].qs;
2822        const uint8_t * GGML_RESTRICT qh = x[i].qh;
2823        const uint16_t * GGML_RESTRICT signs = (const uint16_t *)(x[i].qs + QK_K/8);
2824        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
2825
2826        memcpy(&aux64, x[i].scales, 8);
2827        const __m128i scales8 = _mm_add_epi8(_mm_slli_epi16(_mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), m4), 1), m1);
2828        const __m256i scales16 = _mm256_cvtepi8_epi16(scales8); // 0 2 4 6 8 10 12 14 1 3 5 7 9 11 13 15
2829
2830        __m256i sumi1 = _mm256_setzero_si256();
2831        __m256i sumi2 = _mm256_setzero_si256();
2832        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
2833            const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
2834            const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
2835            const __m256i q2_1 = _mm256_set_epi64x(iq2s_grid[qs[3] | ((qh[ib32+0] << 2) & 0x300)],
2836                                                   iq2s_grid[qs[2] | ((qh[ib32+0] << 4) & 0x300)],
2837                                                   iq2s_grid[qs[1] | ((qh[ib32+0] << 6) & 0x300)],
2838                                                   iq2s_grid[qs[0] | ((qh[ib32+0] << 8) & 0x300)]);
2839            const __m256i q2_2 = _mm256_set_epi64x(iq2s_grid[qs[7] | ((qh[ib32+1] << 2) & 0x300)],
2840                                                   iq2s_grid[qs[6] | ((qh[ib32+1] << 4) & 0x300)],
2841                                                   iq2s_grid[qs[5] | ((qh[ib32+1] << 6) & 0x300)],
2842                                                   iq2s_grid[qs[4] | ((qh[ib32+1] << 8) & 0x300)]);
2843            qs += 8;
2844
2845            __m256i aux256 = _mm256_set1_epi32(signs[0] | ((uint32_t) signs[1] << 16));
2846            aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2);
2847            const __m256i s2_1 = _mm256_cmpeq_epi8(aux256, mask2);
2848            const __m256i q8s_1 = _mm256_sub_epi8(_mm256_xor_si256(s2_1, q8_1), s2_1);
2849
2850            aux256 = _mm256_set1_epi32(signs[2] | ((uint32_t) signs[3] << 16));
2851            aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2);
2852            const __m256i s2_2 = _mm256_cmpeq_epi8(aux256, mask2);
2853            const __m256i q8s_2 = _mm256_sub_epi8(_mm256_xor_si256(s2_2, q8_2), s2_2);
2854
2855            signs += 4;
2856
2857            const __m256i dot1  = _mm256_maddubs_epi16(q2_1, q8s_1); // blocks 2*ib32+0, 2*ib32+1
2858            const __m256i dot2  = _mm256_maddubs_epi16(q2_2, q8s_2); // blocks 2*ib32+2, 2*ib32+3
2859
2860            const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_shuffle_epi8(scales16, get_scale_shuffle_k4(ib32+0)));
2861            const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_shuffle_epi8(scales16, get_scale_shuffle_k4(ib32+1)));
2862            sumi1 = _mm256_add_epi32(sumi1, p1);
2863            sumi2 = _mm256_add_epi32(sumi2, p2);
2864        }
2865
2866        accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);
2867
2868    }
2869
2870    *s = 0.125f * hsum_float_8(accumf);
2871
2872#elif defined(__AVX__)
2873    static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
2874                                        0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03,
2875    };
2876
2877    static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
2878                                        0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
2879    };
2880
2881    const __m128i m4 = _mm_set1_epi8(0xf);
2882    const __m128i m1 = _mm_set1_epi8(1);
2883
2884    const __m128i mask1_0 = _mm_loadu_si128((const __m128i*)k_mask1);
2885    const __m128i mask1_1 = _mm_loadu_si128((const __m128i*)k_mask1 + 1);
2886    const __m128i mask2_0 = _mm_loadu_si128((const __m128i*)k_mask2);
2887    const __m128i mask2_1 = _mm_loadu_si128((const __m128i*)k_mask2 + 1);
2888
2889    uint64_t aux64;
2890
2891    __m256 accumf = _mm256_setzero_ps();
2892    for (int i = 0; i < nb; ++i) {
2893        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
2894        const uint8_t * GGML_RESTRICT qs = x[i].qs;
2895        const uint8_t * GGML_RESTRICT qh = x[i].qh;
2896        const uint16_t * GGML_RESTRICT signs = (const uint16_t *)(x[i].qs + QK_K/8);
2897        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
2898
2899        memcpy(&aux64, x[i].scales, 8);
2900        const __m128i scales8 = _mm_add_epi8(_mm_slli_epi16(_mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), m4), 1), m1);
2901        const __m128i scales16_0 = _mm_cvtepi8_epi16(scales8);
2902        const __m128i scales16_1 = _mm_cvtepi8_epi16(_mm_srli_si128(scales8, 8));
2903
2904        __m128i sumi1_0 = _mm_setzero_si128();
2905        __m128i sumi1_1 = _mm_setzero_si128();
2906        __m128i sumi2_0 = _mm_setzero_si128();
2907        __m128i sumi2_1 = _mm_setzero_si128();
2908        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
2909            const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
2910            const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
2911            const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
2912            const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
2913            const __m128i q2_1_0 = _mm_set_epi64x(iq2s_grid[qs[1] | ((qh[ib32+0] << 6) & 0x300)],
2914                                                  iq2s_grid[qs[0] | ((qh[ib32+0] << 8) & 0x300)]);
2915            const __m128i q2_1_1 = _mm_set_epi64x(iq2s_grid[qs[3] | ((qh[ib32+0] << 2) & 0x300)],
2916                                                  iq2s_grid[qs[2] | ((qh[ib32+0] << 4) & 0x300)]);
2917            const __m128i q2_2_0 = _mm_set_epi64x(iq2s_grid[qs[5] | ((qh[ib32+1] << 6) & 0x300)],
2918                                                  iq2s_grid[qs[4] | ((qh[ib32+1] << 8) & 0x300)]);
2919            const __m128i q2_2_1 = _mm_set_epi64x(iq2s_grid[qs[7] | ((qh[ib32+1] << 2) & 0x300)],
2920                                                  iq2s_grid[qs[6] | ((qh[ib32+1] << 4) & 0x300)]);
2921            qs += 8;
2922
2923            __m128i aux128_0 = _mm_set1_epi32(signs[0] | ((uint32_t) signs[1] << 16));
2924            __m128i aux128_1 = aux128_0;
2925            aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0);
2926            aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1);
2927            const __m128i s2_1_0 = _mm_cmpeq_epi8(aux128_0, mask2_0);
2928            const __m128i s2_1_1 = _mm_cmpeq_epi8(aux128_1, mask2_1);
2929            const __m128i q8s_1_0 = _mm_sub_epi8(_mm_xor_si128(s2_1_0, q8_1_0), s2_1_0);
2930            const __m128i q8s_1_1 = _mm_sub_epi8(_mm_xor_si128(s2_1_1, q8_1_1), s2_1_1);
2931
2932            aux128_0 = _mm_set1_epi32(signs[2] | ((uint32_t) signs[3] << 16));
2933            aux128_1 = aux128_0;
2934            aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0);
2935            aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1);
2936            const __m128i s2_2_0 = _mm_cmpeq_epi8(aux128_0, mask2_0);
2937            const __m128i s2_2_1 = _mm_cmpeq_epi8(aux128_1, mask2_1);
2938            const __m128i q8s_2_0 = _mm_sub_epi8(_mm_xor_si128(s2_2_0, q8_2_0), s2_2_0);
2939            const __m128i q8s_2_1 = _mm_sub_epi8(_mm_xor_si128(s2_2_1, q8_2_1), s2_2_1);
2940
2941            signs += 4;
2942
2943            const __m128i dot1_0  = _mm_maddubs_epi16(q2_1_0, q8s_1_0);
2944            const __m128i dot1_1  = _mm_maddubs_epi16(q2_1_1, q8s_1_1);
2945            const __m128i dot2_0  = _mm_maddubs_epi16(q2_2_0, q8s_2_0);
2946            const __m128i dot2_1  = _mm_maddubs_epi16(q2_2_1, q8s_2_1);
2947
2948            const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_shuffle_epi8(scales16_0, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+0), 0)));
2949            const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_shuffle_epi8(scales16_1, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+0), 1)));
2950            const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_shuffle_epi8(scales16_0, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+1), 0)));
2951            const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_shuffle_epi8(scales16_1, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+1), 1)));
2952            sumi1_0 = _mm_add_epi32(sumi1_0, p1_0);
2953            sumi1_1 = _mm_add_epi32(sumi1_1, p1_1);
2954            sumi2_0 = _mm_add_epi32(sumi2_0, p2_0);
2955            sumi2_1 = _mm_add_epi32(sumi2_1, p2_1);
2956        }
2957
2958        accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf);
2959
2960    }
2961
2962    *s = 0.125f * hsum_float_8(accumf);
2963
2964#else
2965    UNUSED(x);
2966    UNUSED(y);
2967    UNUSED(nb);
2968    ggml_vec_dot_iq2_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
2969#endif
2970}
2971
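// Dot product of an iq3_xxs row with a q8_K row.
// Each byte of x[i].qs indexes iq3xxs_grid and covers 4 values; the QK_K/4 bytes that
// follow (gas) pack, per 32 values, four 7-bit indices into keven_signs_q2xs plus a
// 4-bit sub-block scale. The final 0.25f turns the integer 2*ls+1 scale into the
// effective (ls + 0.5) * 0.5.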
2972void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
2973    assert(n % QK_K == 0);
2974    assert(nrc == 1);
2975    UNUSED(nrc);
2976    UNUSED(bx);
2977    UNUSED(by);
2978    UNUSED(bs);
2979
2980    const block_iq3_xxs * GGML_RESTRICT x = vx;
2981    const block_q8_K    * GGML_RESTRICT y = vy;
2982
2983    const int nb = n / QK_K;
2984
2985#if defined(__AVX2__)
2986
2987    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
2988
2989    uint32_t aux32[2];
2990
2991    __m256 accumf = _mm256_setzero_ps();
2992    for (int i = 0; i < nb; ++i) {
2993        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
2994        const uint8_t * GGML_RESTRICT q3 = x[i].qs;
2995        const uint8_t * GGML_RESTRICT gas = x[i].qs + QK_K/4;
2996        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
2997        __m256i sumi1 = _mm256_setzero_si256();
2998        __m256i sumi2 = _mm256_setzero_si256();
2999        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
3000            const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
3001            const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
3002            const __m256i q2_1 = _mm256_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]],
3003                                                  iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]);
3004            q3 += 8;
3005            const __m256i q2_2 = _mm256_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]],
3006                                                  iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]);
3007            q3 += 8;
3008            memcpy(aux32, gas, 8); gas += 8;
3009            const __m256i s2_1 = _mm256_set_epi64x(signs64[(aux32[0] >> 21) & 127], signs64[(aux32[0] >> 14) & 127],
3010                                                   signs64[(aux32[0] >>  7) & 127], signs64[(aux32[0] >>  0) & 127]);
3011            const __m256i s2_2 = _mm256_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127],
3012                                                   signs64[(aux32[1] >>  7) & 127], signs64[(aux32[1] >>  0) & 127]);
3013            const __m256i q8s_1 = _mm256_sign_epi8(q8_1, s2_1);
3014            const __m256i q8s_2 = _mm256_sign_epi8(q8_2, s2_2);
3015            const __m256i dot1  = _mm256_maddubs_epi16(q2_1, q8s_1);
3016            const __m256i dot2  = _mm256_maddubs_epi16(q2_2, q8s_2);
3017            const uint16_t ls1 = aux32[0] >> 28;
3018            const uint16_t ls2 = aux32[1] >> 28;
3019            const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*ls1+1));
3020            const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*ls2+1));
3021            sumi1 = _mm256_add_epi32(sumi1, p1);
3022            sumi2 = _mm256_add_epi32(sumi2, p2);
3023        }
3024
3025        accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);
3026
3027    }
3028
3029    *s = 0.25f * hsum_float_8(accumf);
3030
3031#elif defined(__AVX__)
3032    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
3033
3034    uint32_t aux32[2];
3035
3036    __m256 accumf = _mm256_setzero_ps();
3037    for (int i = 0; i < nb; ++i) {
3038        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
3039        const uint8_t * GGML_RESTRICT q3 = x[i].qs;
3040        const uint8_t * GGML_RESTRICT gas = x[i].qs + QK_K/4;
3041        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
3042        __m128i sumi1_0 = _mm_setzero_si128();
3043        __m128i sumi1_1 = _mm_setzero_si128();
3044        __m128i sumi2_0 = _mm_setzero_si128();
3045        __m128i sumi2_1 = _mm_setzero_si128();
3046        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
3047            const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
3048            const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
3049            const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
3050            const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
3051            const __m128i q2_1_0 = _mm_set_epi32(iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]);
3052            const __m128i q2_1_1 = _mm_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]]);
3053            q3 += 8;
3054            const __m128i q2_2_0 = _mm_set_epi32(iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]);
3055            const __m128i q2_2_1 = _mm_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]]);
3056            q3 += 8;
3057            memcpy(aux32, gas, 8); gas += 8;
3058            const __m128i s2_1_0 = _mm_set_epi64x(signs64[(aux32[0] >>  7) & 127], signs64[(aux32[0] >>  0) & 127]);
3059            const __m128i s2_1_1 = _mm_set_epi64x(signs64[(aux32[0] >> 21) & 127], signs64[(aux32[0] >> 14) & 127]);
3060            const __m128i s2_2_0 = _mm_set_epi64x(signs64[(aux32[1] >>  7) & 127], signs64[(aux32[1] >>  0) & 127]);
3061            const __m128i s2_2_1 = _mm_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127]);
3062            const __m128i q8s_1_0 = _mm_sign_epi8(q8_1_0, s2_1_0);
3063            const __m128i q8s_1_1 = _mm_sign_epi8(q8_1_1, s2_1_1);
3064            const __m128i q8s_2_0 = _mm_sign_epi8(q8_2_0, s2_2_0);
3065            const __m128i q8s_2_1 = _mm_sign_epi8(q8_2_1, s2_2_1);
3066            const __m128i dot1_0  = _mm_maddubs_epi16(q2_1_0, q8s_1_0);
3067            const __m128i dot1_1  = _mm_maddubs_epi16(q2_1_1, q8s_1_1);
3068            const __m128i dot2_0  = _mm_maddubs_epi16(q2_2_0, q8s_2_0);
3069            const __m128i dot2_1  = _mm_maddubs_epi16(q2_2_1, q8s_2_1);
3070            const uint16_t ls1 = aux32[0] >> 28;
3071            const uint16_t ls2 = aux32[1] >> 28;
3072            const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(2*ls1+1));
3073            const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(2*ls1+1));
3074            const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(2*ls2+1));
3075            const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(2*ls2+1));
3076            sumi1_0 = _mm_add_epi32(sumi1_0, p1_0);
3077            sumi1_1 = _mm_add_epi32(sumi1_1, p1_1);
3078            sumi2_0 = _mm_add_epi32(sumi2_0, p2_0);
3079            sumi2_1 = _mm_add_epi32(sumi2_1, p2_1);
3080        }
3081
3082        accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf);
3083
3084    }
3085
3086    *s = 0.25f * hsum_float_8(accumf);
3087
3088#else
3089    UNUSED(x);
3090    UNUSED(y);
3091    UNUSED(nb);
3092    ggml_vec_dot_iq3_xxs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
3093#endif
3094}
3095
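// Dot product of an iq3_s row with a q8_K row.
// Grid indices are 9 bits: one byte from x[i].qs plus one bit per index from x[i].qh,
// indexing iq3s_grid (4 values per entry). Signs are stored explicitly in x[i].signs;
// the 4-bit scales (one per group of 32 values), two per byte in x[i].scales, enter
// the sum as 2*ls + 1 with no extra constant factor at the end.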
3096void ggml_vec_dot_iq3_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
3097    assert(n % QK_K == 0);
3098    assert(nrc == 1);
3099    UNUSED(nrc);
3100    UNUSED(bx);
3101    UNUSED(by);
3102    UNUSED(bs);
3103
3104    const block_iq3_s * GGML_RESTRICT x = vx;
3105    const block_q8_K  * GGML_RESTRICT y = vy;
3106
3107    const int nb = n / QK_K;
3108
3109#if defined(__AVX2__)
3110
3111    static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
3112                                        0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03,
3113    };
3114
3115    static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
3116                                        0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
3117    };
3118
3119    const __m256i mask1 = _mm256_loadu_si256((const __m256i*)k_mask1);
3120    const __m256i mask2 = _mm256_loadu_si256((const __m256i*)k_mask2);
3121
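    // idx_shift/idx_mask move bit j of the qh byte into bit 8 of lane j, extending each
    // 8-bit qs index to the 9-bit grid index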
3122    const __m256i idx_shift = _mm256_set_epi32(1, 2, 3, 4, 5, 6, 7, 8);
3123    const __m256i idx_mask  = _mm256_set1_epi32(256);
3124
3125    typedef union {
3126        __m256i  vec[2];
3127        uint32_t index[16];
3128    } index_t;
3129
3130    index_t idx;
3131
3132    __m256 accumf = _mm256_setzero_ps();
3133    for (int i = 0; i < nb; ++i) {
3134        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
3135        const uint8_t * GGML_RESTRICT qs = x[i].qs;
3136        const uint8_t * GGML_RESTRICT qh = x[i].qh;
3137        const uint16_t * GGML_RESTRICT signs = (const uint16_t *)x[i].signs;
3138        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
3139        __m256i sumi1 = _mm256_setzero_si256();
3140        __m256i sumi2 = _mm256_setzero_si256();
3141        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
3142            const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
3143            const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
3144            const __m256i idx_l = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)qs)); qs += 16;
3145            idx.vec[0] = _mm256_set1_epi32(qh[ib32+0]);
3146            idx.vec[1] = _mm256_set1_epi32(qh[ib32+1]);
3147            idx.vec[0] = _mm256_and_si256(_mm256_sllv_epi32(idx.vec[0], idx_shift), idx_mask);
3148            idx.vec[1] = _mm256_and_si256(_mm256_sllv_epi32(idx.vec[1], idx_shift), idx_mask);
3149            idx.vec[0] = _mm256_or_si256(idx.vec[0], _mm256_cvtepi16_epi32(_mm256_castsi256_si128(idx_l)));
3150            idx.vec[1] = _mm256_or_si256(idx.vec[1], _mm256_cvtepi16_epi32(_mm256_extractf128_si256(idx_l, 1)));
3151
3152            // At least on my CPU (Ryzen 7950X), using _mm256_i32gather_epi32 is slower than _mm256_set_epi32. Strange.
3153            //const __m256i q2_1 = _mm256_i32gather_epi32((const int *)iq3s_grid, idx.vec[0], 4);
3154            //const __m256i q2_2 = _mm256_i32gather_epi32((const int *)iq3s_grid, idx.vec[1], 4);
3155            const __m256i q2_1 = _mm256_set_epi32(
3156                    iq3s_grid[idx.index[7]], iq3s_grid[idx.index[6]], iq3s_grid[idx.index[5]], iq3s_grid[idx.index[4]],
3157                    iq3s_grid[idx.index[3]], iq3s_grid[idx.index[2]], iq3s_grid[idx.index[1]], iq3s_grid[idx.index[0]]
3158            );
3159            const __m256i q2_2 = _mm256_set_epi32(
3160                    iq3s_grid[idx.index[15]], iq3s_grid[idx.index[14]], iq3s_grid[idx.index[13]], iq3s_grid[idx.index[12]],
3161                    iq3s_grid[idx.index[11]], iq3s_grid[idx.index[10]], iq3s_grid[idx.index[ 9]], iq3s_grid[idx.index[ 8]]
3162            );
3163
3164            __m256i aux256 = _mm256_set1_epi32(signs[0] | ((uint32_t) signs[1] << 16));
3165            aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2);
3166            const __m256i s2_1 = _mm256_cmpeq_epi8(aux256, mask2);
3167            const __m256i q8s_1 = _mm256_sub_epi8(_mm256_xor_si256(s2_1, q8_1), s2_1);
3168
3169            aux256 = _mm256_set1_epi32(signs[2] | ((uint32_t) signs[3] << 16));
3170            aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2);
3171            const __m256i s2_2 = _mm256_cmpeq_epi8(aux256, mask2);
3172            const __m256i q8s_2 = _mm256_sub_epi8(_mm256_xor_si256(s2_2, q8_2), s2_2);
3173
3174            signs += 4;
3175
3176            const __m256i dot1  = _mm256_maddubs_epi16(q2_1, q8s_1);
3177            const __m256i dot2  = _mm256_maddubs_epi16(q2_2, q8s_2);
3178            const uint16_t ls1 = x[i].scales[ib32/2] & 0xf;
3179            const uint16_t ls2 = x[i].scales[ib32/2] >>  4;
3180            const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*ls1+1));
3181            const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*ls2+1));
3182            sumi1 = _mm256_add_epi32(sumi1, p1);
3183            sumi2 = _mm256_add_epi32(sumi2, p2);
3184        }
3185
3186        accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);
3187
3188    }
3189
3190    *s = hsum_float_8(accumf);
3191
3192#elif defined(__AVX__)
3193    static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
3194                                        0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03,
3195    };
3196
3197    static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
3198                                        0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
3199    };
3200
3201    const __m128i mask1_0 = _mm_loadu_si128((const __m128i*)k_mask1);
3202    const __m128i mask1_1 = _mm_loadu_si128((const __m128i*)k_mask1 + 1);
3203    const __m128i mask2_0 = _mm_loadu_si128((const __m128i*)k_mask2);
3204    const __m128i mask2_1 = _mm_loadu_si128((const __m128i*)k_mask2 + 1);
3205
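    // idx_mul_*/idx_mask pick bit j of the qh byte for lane j (as bit 8), extending each
    // 8-bit qs index to the 9-bit grid index, mirroring the AVX2 path above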
3206    const __m128i idx_mul_0 = _mm_set_epi32(32, 64, 128, 256);
3207    const __m128i idx_mul_1 = _mm_set_epi32(2, 4, 8, 16);
3208    const __m128i idx_mask  = _mm_set1_epi32(256);
3209
3210    typedef union {
3211        __m128i  vec[4];
3212        uint32_t index[16];
3213    } index_t;
3214
3215    index_t idx;
3216
3217    __m256 accumf = _mm256_setzero_ps();
3218    for (int i = 0; i < nb; ++i) {
3219        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
3220        const uint8_t * GGML_RESTRICT qs = x[i].qs;
3221        const uint8_t * GGML_RESTRICT qh = x[i].qh;
3222        const uint16_t * GGML_RESTRICT signs = (const uint16_t *)x[i].signs;
3223        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
3224        __m128i sumi1_0 = _mm_setzero_si128();
3225        __m128i sumi1_1 = _mm_setzero_si128();
3226        __m128i sumi2_0 = _mm_setzero_si128();
3227        __m128i sumi2_1 = _mm_setzero_si128();
3228        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
3229            const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
3230            const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
3231            const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
3232            const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
3233            const __m128i qs_tmp = _mm_loadu_si128((const __m128i *)qs);
3234            const __m128i idx_l_0 = _mm_cvtepu8_epi16(qs_tmp);
3235            const __m128i idx_l_1 = _mm_cvtepu8_epi16(_mm_srli_si128(qs_tmp, 8)); qs += 16;
3236            idx.vec[0] = _mm_set1_epi32(qh[ib32+0]);
3237            idx.vec[1] = idx.vec[0];
3238            idx.vec[2] = _mm_set1_epi32(qh[ib32+1]);
3239            idx.vec[3] = idx.vec[2];
3240
3241            idx.vec[0] = _mm_and_si128(_mm_mullo_epi32(idx.vec[0], idx_mul_0), idx_mask);
3242            idx.vec[1] = _mm_and_si128(_mm_mullo_epi32(idx.vec[1], idx_mul_1), idx_mask);
3243            idx.vec[2] = _mm_and_si128(_mm_mullo_epi32(idx.vec[2], idx_mul_0), idx_mask);
3244            idx.vec[3] = _mm_and_si128(_mm_mullo_epi32(idx.vec[3], idx_mul_1), idx_mask);
3245
3246            idx.vec[0] = _mm_or_si128(idx.vec[0], _mm_cvtepi16_epi32(idx_l_0));
3247            idx.vec[1] = _mm_or_si128(idx.vec[1], _mm_cvtepi16_epi32(_mm_srli_si128(idx_l_0, 8)));
3248            idx.vec[2] = _mm_or_si128(idx.vec[2], _mm_cvtepi16_epi32(idx_l_1));
3249            idx.vec[3] = _mm_or_si128(idx.vec[3], _mm_cvtepi16_epi32(_mm_srli_si128(idx_l_1, 8)));
3250
3251            const __m128i q2_1_0 = _mm_set_epi32(iq3s_grid[idx.index[3]], iq3s_grid[idx.index[2]], iq3s_grid[idx.index[1]], iq3s_grid[idx.index[0]]);
3252            const __m128i q2_1_1 = _mm_set_epi32(iq3s_grid[idx.index[7]], iq3s_grid[idx.index[6]], iq3s_grid[idx.index[5]], iq3s_grid[idx.index[4]]);
3253            const __m128i q2_2_0 = _mm_set_epi32(iq3s_grid[idx.index[11]], iq3s_grid[idx.index[10]], iq3s_grid[idx.index[9]], iq3s_grid[idx.index[8]]);
3254            const __m128i q2_2_1 = _mm_set_epi32(iq3s_grid[idx.index[15]], iq3s_grid[idx.index[14]], iq3s_grid[idx.index[13]], iq3s_grid[idx.index[12]]);
3255
3256            __m128i aux128_0 = _mm_set1_epi32(signs[0] | ((uint32_t) signs[1] << 16));
3257            __m128i aux128_1 = aux128_0;
3258            aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0);
3259            aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1);
3260            const __m128i s2_1_0 = _mm_cmpeq_epi8(aux128_0, mask2_0);
3261            const __m128i s2_1_1 = _mm_cmpeq_epi8(aux128_1, mask2_1);
3262            const __m128i q8s_1_0 = _mm_sub_epi8(_mm_xor_si128(s2_1_0, q8_1_0), s2_1_0);
3263            const __m128i q8s_1_1 = _mm_sub_epi8(_mm_xor_si128(s2_1_1, q8_1_1), s2_1_1);
3264
3265            aux128_0 = _mm_set1_epi32(signs[2] | ((uint32_t) signs[3] << 16));
3266            aux128_1 = aux128_0;
3267            aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0);
3268            aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1);
3269            const __m128i s2_2_0 = _mm_cmpeq_epi8(aux128_0, mask2_0);
3270            const __m128i s2_2_1 = _mm_cmpeq_epi8(aux128_1, mask2_1);
3271            const __m128i q8s_2_0 = _mm_sub_epi8(_mm_xor_si128(s2_2_0, q8_2_0), s2_2_0);
3272            const __m128i q8s_2_1 = _mm_sub_epi8(_mm_xor_si128(s2_2_1, q8_2_1), s2_2_1);
3273
3274            signs += 4;
3275
3276            const __m128i dot1_0  = _mm_maddubs_epi16(q2_1_0, q8s_1_0);
3277            const __m128i dot1_1  = _mm_maddubs_epi16(q2_1_1, q8s_1_1);
3278            const __m128i dot2_0  = _mm_maddubs_epi16(q2_2_0, q8s_2_0);
3279            const __m128i dot2_1  = _mm_maddubs_epi16(q2_2_1, q8s_2_1);
3280            const uint16_t ls1 = x[i].scales[ib32/2] & 0xf;
3281            const uint16_t ls2 = x[i].scales[ib32/2] >>  4;
3282            const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(2*ls1+1));
3283            const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(2*ls1+1));
3284            const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(2*ls2+1));
3285            const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(2*ls2+1));
3286            sumi1_0 = _mm_add_epi32(sumi1_0, p1_0);
3287            sumi1_1 = _mm_add_epi32(sumi1_1, p1_1);
3288            sumi2_0 = _mm_add_epi32(sumi2_0, p2_0);
3289            sumi2_1 = _mm_add_epi32(sumi2_1, p2_1);
3290        }
3291
3292        accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf);
3293
3294    }
3295
3296    *s = hsum_float_8(accumf);
3297
3298#else
3299    UNUSED(x);
3300    UNUSED(y);
3301    UNUSED(nb);
3302    ggml_vec_dot_iq3_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
3303#endif
3304}
3305
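// Dot product of an iq1_s row with a q8_K row.
// Grid indices are 11 bits: one byte from x[i].qs plus three high bits from x[i].qh,
// indexing iq1s_grid. Bits 12..14 of each qh word hold the 3-bit sub-block scale
// (applied as 2*ls + 1) and bit 15 selects the sign of the IQ1S_DELTA shift, which is
// folded in through the q8_K block sums (bsums).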
3306void ggml_vec_dot_iq1_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
3307    assert(n % QK_K == 0);
3308    assert(nrc == 1);
3309    UNUSED(nrc);
3310    UNUSED(bx);
3311    UNUSED(by);
3312    UNUSED(bs);
3313
3314    const block_iq1_s * GGML_RESTRICT x = vx;
3315    const block_q8_K  * GGML_RESTRICT y = vy;
3316
3317    const int nb = n / QK_K;
3318
3319#if defined __AVX2__
3320
3321    __m256 accum = _mm256_setzero_ps();
3322    float accum1 = 0;
3323    for (int i = 0; i < nb; ++i) {
3324
3325        const int8_t   * q8 = y[i].qs;
3326        const uint8_t  * qs = x[i].qs;
3327        const uint16_t * qh = x[i].qh;
3328
3329        __m256i sumi = _mm256_setzero_si256();
3330        int sumi1 = 0;
3331        for (int ib = 0; ib < QK_K/32; ib += 2) {
3332#ifdef __BMI2__
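            // _pdep_u64 spreads the four qs bytes into the low byte of each 16-bit lane and the
            // four 3-bit qh fields into bits 8..10, giving four 11-bit grid indices per 64-bit word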
3333            const uint64_t packed_idx1 = _pdep_u64(*(const uint32_t *)qs, 0x00ff00ff00ff00ffULL) | _pdep_u64(qh[ib], 0x700070007000700ULL);
3334            const uint64_t packed_idx2 = _pdep_u64(*(const uint32_t *)(qs + 4), 0x00ff00ff00ff00ffULL) | _pdep_u64(qh[ib + 1], 0x700070007000700ULL);
3335            const uint16_t *idx1 = (const uint16_t *)(&packed_idx1);
3336            const uint16_t *idx2 = (const uint16_t *)(&packed_idx2);
3337            const __m256i q1b_1 = _mm256_set_epi64x(iq1s_grid[idx1[3]], iq1s_grid[idx1[2]], iq1s_grid[idx1[1]], iq1s_grid[idx1[0]]);
3338            const __m256i q1b_2 = _mm256_set_epi64x(iq1s_grid[idx2[3]], iq1s_grid[idx2[2]], iq1s_grid[idx2[1]], iq1s_grid[idx2[0]]);
3339#else
3340            const __m256i q1b_1 = _mm256_set_epi64x(iq1s_grid[qs[3] | ((qh[ib+0] >> 1) & 0x700)], iq1s_grid[qs[2] | ((qh[ib+0] << 2) & 0x700)],
3341                                                    iq1s_grid[qs[1] | ((qh[ib+0] << 5) & 0x700)], iq1s_grid[qs[0] | ((qh[ib+0] << 8) & 0x700)]);
3342            const __m256i q1b_2 = _mm256_set_epi64x(iq1s_grid[qs[7] | ((qh[ib+1] >> 1) & 0x700)], iq1s_grid[qs[6] | ((qh[ib+1] << 2) & 0x700)],
3343                                                    iq1s_grid[qs[5] | ((qh[ib+1] << 5) & 0x700)], iq1s_grid[qs[4] | ((qh[ib+1] << 8) & 0x700)]);
3344#endif
3345            qs += 8;
3346            const __m256i q8b_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
3347            const __m256i q8b_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
3348
3349            const __m256i dot1 = mul_add_epi8(q1b_1, q8b_1);
3350            const __m256i dot2 = mul_add_epi8(q1b_2, q8b_2);
3351            const int16_t ls1 = 2*((qh[ib+0] >> 12) & 7) + 1;
3352            const int16_t ls2 = 2*((qh[ib+1] >> 12) & 7) + 1;
3353            const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(ls1));
3354            const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(ls2));
3355
3356            sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p1, p2));
3357            sumi1 += (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]) * (qh[ib+0] & 0x8000 ? -1 : 1) * ls1
3358                   + (y[i].bsums[2*ib+2] + y[i].bsums[2*ib+3]) * (qh[ib+1] & 0x8000 ? -1 : 1) * ls2;
3359        }
3360
3361        const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
3362        accum = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(sumi), accum);
3363        accum1 += d * sumi1;
3364
3365    }
3366
3367    *s = hsum_float_8(accum) + IQ1S_DELTA * accum1;
3368
3369#elif defined __AVX__
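    // AVX (no AVX2) path: the same computation on 128-bit halves, merged into a single
    // 256-bit float accumulator once per block.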
    __m256 accum = _mm256_setzero_ps();
    float accum1 = 0;
    for (int i = 0; i < nb; ++i) {

        const int8_t   * q8 = y[i].qs;
        const uint8_t  * qs = x[i].qs;
        const uint16_t * qh = x[i].qh;

        __m128i sumi1_0 = _mm_setzero_si128();
        __m128i sumi1_1 = _mm_setzero_si128();
        int sumi1 = 0;
        for (int ib = 0; ib < QK_K/32; ib += 2) {
            const __m128i q1b_1_0 = _mm_set_epi64x(iq1s_grid[qs[1] | ((qh[ib+0] << 5) & 0x700)], iq1s_grid[qs[0] | ((qh[ib+0] << 8) & 0x700)]);
            const __m128i q1b_1_1 = _mm_set_epi64x(iq1s_grid[qs[3] | ((qh[ib+0] >> 1) & 0x700)], iq1s_grid[qs[2] | ((qh[ib+0] << 2) & 0x700)]);
            const __m128i q1b_2_0 = _mm_set_epi64x(iq1s_grid[qs[5] | ((qh[ib+1] << 5) & 0x700)], iq1s_grid[qs[4] | ((qh[ib+1] << 8) & 0x700)]);
            const __m128i q1b_2_1 = _mm_set_epi64x(iq1s_grid[qs[7] | ((qh[ib+1] >> 1) & 0x700)], iq1s_grid[qs[6] | ((qh[ib+1] << 2) & 0x700)]);
            qs += 8;
            const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
            const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
            const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
            const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;

            const __m128i dot1_0 = mul_add_epi8_sse(q1b_1_0, q8b_1_0);
            const __m128i dot1_1 = mul_add_epi8_sse(q1b_1_1, q8b_1_1);
            const __m128i dot2_0 = mul_add_epi8_sse(q1b_2_0, q8b_2_0);
            const __m128i dot2_1 = mul_add_epi8_sse(q1b_2_1, q8b_2_1);
            const int16_t ls1 = 2*((qh[ib+0] >> 12) & 7) + 1;
            const int16_t ls2 = 2*((qh[ib+1] >> 12) & 7) + 1;
            const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(ls1));
            const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(ls1));
            const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(ls2));
            const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(ls2));

            sumi1_0 = _mm_add_epi32(sumi1_0, _mm_add_epi32(p1_0, p2_0));
            sumi1_1 = _mm_add_epi32(sumi1_1, _mm_add_epi32(p1_1, p2_1));
            sumi1 += (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]) * (qh[ib+0] & 0x8000 ? -1 : 1) * ls1
                   + (y[i].bsums[2*ib+2] + y[i].bsums[2*ib+3]) * (qh[ib+1] & 0x8000 ? -1 : 1) * ls2;
        }

        const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
        accum = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(sumi1_1, sumi1_0))), accum);
        accum1 += d * sumi1;

    }

    *s = hsum_float_8(accum) + IQ1S_DELTA * accum1;

#else
    UNUSED(x);
    UNUSED(y);
    UNUSED(nb);
    ggml_vec_dot_iq1_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}

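// Dot product of IQ1_M blocks with Q8_K blocks.
// Same grid lookup as IQ1_S, but here each nibble of x[i].qh carries the 3 high index bits
// and a delta sign bit (0x08 / 0x80) for one run of 8 weights, the 3-bit group scales are
// packed into x[i].scales, and the block's fp16 scale is reassembled from the top four bits
// of each 16-bit scale word. The IQ1M_DELTA contribution is accumulated separately
// (sumi2/accum2) rather than through bsums.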
void ggml_vec_dot_iq1_m_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(n % QK_K == 0);
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_iq1_m * GGML_RESTRICT x = vx;
    const block_q8_K  * GGML_RESTRICT y = vy;

    const int nb = n / QK_K;

    iq1m_scale_t scale;

#if defined __AVX2__

    const __m256i mask = _mm256_set1_epi16(0x7);
    const __m256i mone = _mm256_set1_epi16(1);
    const __m256i mone8 = _mm256_set1_epi8(1);
    const __m256i mtwo8 = _mm256_set1_epi8(2);
    // VPSHUFB cannot cross 128-bit lanes so odd shifts go to upper half.
    const __m256i scales_shift = _mm256_set_epi64x(9, 3, 6, 0);

    __m256 accum1 = _mm256_setzero_ps();
    __m256 accum2 = _mm256_setzero_ps();
    for (int i = 0; i < nb; ++i) {

        const int8_t   * q8 = y[i].qs;
        const uint8_t  * qs = x[i].qs;
        const uint8_t  * qh = x[i].qh;
        const uint16_t * sc = (const uint16_t *)x[i].scales;

        scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
        // Extract 3-bit scales (16 values)
        __m256i scales = _mm256_set1_epi64x(*(const uint64_t*)sc);
        scales = _mm256_srlv_epi64(scales, scales_shift);
        scales = _mm256_add_epi16(_mm256_slli_epi16(_mm256_and_si256(scales, mask), 1), mone);

        // Indices to repeat each scale 8 times.
        __m256i scales_idx1 = _mm256_set1_epi16(0x0100);
        __m256i scales_idx2 = _mm256_add_epi8(scales_idx1, _mm256_set1_epi8(8));

        __m256i sumi1 = _mm256_setzero_si256();
        __m256i sumi2 = _mm256_setzero_si256();
        for (int ib = 0; ib < QK_K/32; ib += 2) {
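            // Same index construction as in iq1_s above; in addition, bit 3 and bit 7 of each
            // qh byte give the delta sign for the corresponding run of 8 weights, expanded into
            // per-byte sign factors (delta1/delta2) below.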
#ifdef __BMI2__
            const uint64_t packed_idx1 = _pdep_u64(*(const uint32_t *)qs, 0x00ff00ff00ff00ffULL)
                                       | _pdep_u64(*(const uint16_t*)(qh) & 0x7777, 0xf000f000f000f00ULL);
            const uint64_t packed_idx2 = _pdep_u64(*(const uint32_t *)(qs + 4), 0x00ff00ff00ff00ffULL)
                                       | _pdep_u64(*(const uint16_t*)(qh + 2) & 0x7777, 0xf000f000f000f00ULL);
            const uint16_t *idx1 = (const uint16_t *)(&packed_idx1);
            const uint16_t *idx2 = (const uint16_t *)(&packed_idx2);
            const __m256i q1b_1 = _mm256_set_epi64x(iq1s_grid[idx1[3]], iq1s_grid[idx1[2]], iq1s_grid[idx1[1]], iq1s_grid[idx1[0]]);
            const __m256i q1b_2 = _mm256_set_epi64x(iq1s_grid[idx2[3]], iq1s_grid[idx2[2]], iq1s_grid[idx2[1]], iq1s_grid[idx2[0]]);

            // Convert signs to bytes 0x81 (negative) or 0x01 (positive)
            const uint64_t delta_sign = _pdep_u64(*(const uint32_t*)(qh) & 0x88888888, 0xf0f0f0f0f0f0f0f0ULL);
            const __m256i delta1 = _mm256_or_si256(mone8, _mm256_cvtepi8_epi64(_mm_set1_epi32(delta_sign)));
            const __m256i delta2 = _mm256_or_si256(mone8, _mm256_cvtepi8_epi64(_mm_set1_epi32(delta_sign >> 32)));
#else
            const __m256i q1b_1 = _mm256_set_epi64x(
                    iq1s_grid[qs[3] | (((uint16_t)qh[1] << 4) & 0x700)], iq1s_grid[qs[2] | (((uint16_t)qh[1] << 8) & 0x700)],
                    iq1s_grid[qs[1] | (((uint16_t)qh[0] << 4) & 0x700)], iq1s_grid[qs[0] | (((uint16_t)qh[0] << 8) & 0x700)]
            );
            const __m256i q1b_2 = _mm256_set_epi64x(
                    iq1s_grid[qs[7] | (((uint16_t)qh[3] << 4) & 0x700)], iq1s_grid[qs[6] | (((uint16_t)qh[3] << 8) & 0x700)],
                    iq1s_grid[qs[5] | (((uint16_t)qh[2] << 4) & 0x700)], iq1s_grid[qs[4] | (((uint16_t)qh[2] << 8) & 0x700)]
            );

            const __m256i delta1 = _mm256_set_epi64x(qh[1] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
                                                     qh[1] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101,
                                                     qh[0] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
                                                     qh[0] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101);
            const __m256i delta2 = _mm256_set_epi64x(qh[3] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
                                                     qh[3] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101,
                                                     qh[2] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
                                                     qh[2] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101);
#endif
            const __m256i q8b_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
            const __m256i q8b_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;

            const __m256i dot1 = mul_add_epi8(q1b_1, q8b_1);
            const __m256i dot2 = mul_add_epi8(q1b_2, q8b_2);
            const __m256i dot3 = _mm256_maddubs_epi16(mone8, _mm256_sign_epi8(q8b_1, delta1));
            const __m256i dot4 = _mm256_maddubs_epi16(mone8, _mm256_sign_epi8(q8b_2, delta2));

            __m256i scale1 = _mm256_shuffle_epi8(scales, scales_idx1);
            __m256i scale2 = _mm256_shuffle_epi8(scales, scales_idx2);

            scales_idx1 = _mm256_add_epi8(scales_idx1, mtwo8);
            scales_idx2 = _mm256_add_epi8(scales_idx2, mtwo8);

            const __m256i p1 = _mm256_madd_epi16(dot1, scale1);
            const __m256i p2 = _mm256_madd_epi16(dot2, scale2);
            const __m256i p3 = _mm256_madd_epi16(dot3, scale1);
            const __m256i p4 = _mm256_madd_epi16(dot4, scale2);

            sumi1 = _mm256_add_epi32(sumi1, _mm256_add_epi32(p1, p2));
            sumi2 = _mm256_add_epi32(sumi2, _mm256_add_epi32(p3, p4));

            qs += 8; qh += 4;
        }

        const __m256 d = _mm256_set1_ps(y[i].d * GGML_CPU_FP16_TO_FP32(scale.f16));

        accum1 = _mm256_fmadd_ps(d, _mm256_cvtepi32_ps(sumi1), accum1);
        accum2 = _mm256_fmadd_ps(d, _mm256_cvtepi32_ps(sumi2), accum2);
    }

    *s = hsum_float_8(accum1) + IQ1M_DELTA * hsum_float_8(accum2);

#elif defined __AVX__
    const __m128i mask = _mm_set1_epi16(0x7);
    const __m128i mone = _mm_set1_epi16(1);

    __m256 accum1 = _mm256_setzero_ps();
    __m256 accum2 = _mm256_setzero_ps();
    for (int i = 0; i < nb; ++i) {

        const int8_t   * q8 = y[i].qs;
        const uint8_t  * qs = x[i].qs;
        const uint8_t  * qh = x[i].qh;
        const uint16_t * sc = (const uint16_t *)x[i].scales;

        scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);

        __m128i sumi1_0 = _mm_setzero_si128();
        __m128i sumi1_1 = _mm_setzero_si128();
        __m128i sumi2_0 = _mm_setzero_si128();
        __m128i sumi2_1 = _mm_setzero_si128();
        for (int ib = 0; ib < QK_K/32; ib += 2) {
            const __m128i q1b_1_0 = _mm_set_epi64x(
                    iq1s_grid[qs[1] | (((uint16_t)qh[0] << 4) & 0x700)], iq1s_grid[qs[0] | (((uint16_t)qh[0] << 8) & 0x700)]);
            const __m128i q1b_1_1 = _mm_set_epi64x(
                    iq1s_grid[qs[3] | (((uint16_t)qh[1] << 4) & 0x700)], iq1s_grid[qs[2] | (((uint16_t)qh[1] << 8) & 0x700)]);
            const __m128i q1b_2_0 = _mm_set_epi64x(
                    iq1s_grid[qs[5] | (((uint16_t)qh[2] << 4) & 0x700)], iq1s_grid[qs[4] | (((uint16_t)qh[2] << 8) & 0x700)]);
            const __m128i q1b_2_1 = _mm_set_epi64x(
                    iq1s_grid[qs[7] | (((uint16_t)qh[3] << 4) & 0x700)], iq1s_grid[qs[6] | (((uint16_t)qh[3] << 8) & 0x700)]);
            const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
            const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
            const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
            const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;

            const __m128i dot1_0 = mul_add_epi8_sse(q1b_1_0, q8b_1_0);
            const __m128i dot1_1 = mul_add_epi8_sse(q1b_1_1, q8b_1_1);
            const __m128i dot2_0 = mul_add_epi8_sse(q1b_2_0, q8b_2_0);
            const __m128i dot2_1 = mul_add_epi8_sse(q1b_2_1, q8b_2_1);

            const __m128i delta1_0 = _mm_set_epi64x(qh[0] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
                                                    qh[0] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101);
            const __m128i delta1_1 = _mm_set_epi64x(qh[1] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
                                                    qh[1] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101);
            const __m128i delta2_0 = _mm_set_epi64x(qh[2] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
                                                    qh[2] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101);
            const __m128i delta2_1 = _mm_set_epi64x(qh[3] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
                                                    qh[3] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101);

            const __m128i dot3_0 = mul_add_epi8_sse(delta1_0, q8b_1_0);
            const __m128i dot3_1 = mul_add_epi8_sse(delta1_1, q8b_1_1);
            const __m128i dot4_0 = mul_add_epi8_sse(delta2_0, q8b_2_0);
            const __m128i dot4_1 = mul_add_epi8_sse(delta2_1, q8b_2_1);

            __m128i scale1_0 = _mm_set1_epi16(sc[ib/2] >> 0);
            __m128i scale1_1 = _mm_set1_epi16(sc[ib/2] >> 3);
            __m128i scale2_0 = _mm_set1_epi16(sc[ib/2] >> 6);
            __m128i scale2_1 = _mm_set1_epi16(sc[ib/2] >> 9);

            scale1_0 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale1_0, mask), 1), mone);
            scale1_1 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale1_1, mask), 1), mone);
            scale2_0 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale2_0, mask), 1), mone);
            scale2_1 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale2_1, mask), 1), mone);
            const __m128i p1_0 = _mm_madd_epi16(dot1_0, scale1_0);
            const __m128i p1_1 = _mm_madd_epi16(dot1_1, scale1_1);
            const __m128i p2_0 = _mm_madd_epi16(dot2_0, scale2_0);
            const __m128i p2_1 = _mm_madd_epi16(dot2_1, scale2_1);
            const __m128i p3_0 = _mm_madd_epi16(dot3_0, scale1_0);
            const __m128i p3_1 = _mm_madd_epi16(dot3_1, scale1_1);
            const __m128i p4_0 = _mm_madd_epi16(dot4_0, scale2_0);
            const __m128i p4_1 = _mm_madd_epi16(dot4_1, scale2_1);

            sumi1_0 = _mm_add_epi32(sumi1_0, _mm_add_epi32(p1_0, p2_0));
            sumi1_1 = _mm_add_epi32(sumi1_1, _mm_add_epi32(p1_1, p2_1));
            sumi2_0 = _mm_add_epi32(sumi2_0, _mm_add_epi32(p3_0, p4_0));
            sumi2_1 = _mm_add_epi32(sumi2_1, _mm_add_epi32(p3_1, p4_1));

            qs += 8; qh += 4;
        }

        const __m256 d = _mm256_set1_ps(y[i].d * GGML_CPU_FP16_TO_FP32(scale.f16));

        accum1 = _mm256_add_ps(_mm256_mul_ps(d, _mm256_cvtepi32_ps(MM256_SET_M128I(sumi1_1, sumi1_0))), accum1);
        accum2 = _mm256_add_ps(_mm256_mul_ps(d, _mm256_cvtepi32_ps(MM256_SET_M128I(sumi2_1, sumi2_0))), accum2);
    }

    *s = hsum_float_8(accum1) + IQ1M_DELTA * hsum_float_8(accum2);

#else
    UNUSED(x);
    UNUSED(y);
    UNUSED(nb);
    UNUSED(scale);
    ggml_vec_dot_iq1_m_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}

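// Dot product of IQ4_NL blocks with Q8_0 blocks.
// Each 4-bit quant is mapped to its int8 value through the kvalues_iq4nl lookup table with
// a single pshufb; the vector paths process two blocks per iteration and leave any remainder
// to the scalar tail loop at the end of the function.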
void ggml_vec_dot_iq4_nl_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);
    assert(n % QK4_NL == 0);
    static_assert(QK4_NL == QK8_0, "QK4_NL and QK8_0 must be the same");

    const block_iq4_nl * GGML_RESTRICT x = vx;
    const block_q8_0   * GGML_RESTRICT y = vy;

    const int nb = n / QK4_NL;

    int ib = 0;
    float sumf = 0;

#if defined __AVX2__

    const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl);
    const __m128i m4b  = _mm_set1_epi8(0x0f);
    const __m256i mone = _mm256_set1_epi16(1);

    __m256 accum1 = _mm256_setzero_ps();
    __m256 accum2 = _mm256_setzero_ps();
    for (; ib + 1 < nb; ib += 2) {
        const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)x[ib + 0].qs);
        const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)x[ib + 1].qs);
        const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)y[ib + 0].qs);
        const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)y[ib + 1].qs);
        const __m256i q4b_1 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)),
                                              _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b)));
        const __m256i q4b_2 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)),
                                              _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b)));
        const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1);
        const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
        const __m256i p_1 = _mm256_madd_epi16(p16_1, mone);
        const __m256i p_2 = _mm256_madd_epi16(p16_2, mone);
        accum1 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 0].d)*GGML_CPU_FP16_TO_FP32(x[ib + 0].d)),
                _mm256_cvtepi32_ps(p_1), accum1);
        accum2 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 1].d)*GGML_CPU_FP16_TO_FP32(x[ib + 1].d)),
                _mm256_cvtepi32_ps(p_2), accum2);
    }

    sumf = hsum_float_8(_mm256_add_ps(accum1, accum2));

#elif defined __AVX__
    const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl);
    const __m128i m4b  = _mm_set1_epi8(0x0f);

    __m256 accum = _mm256_setzero_ps();
    for (; ib + 1 < nb; ib += 2) {
        const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)x[ib + 0].qs);
        const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs);
        const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs);
        const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs + 1);
        const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs);
        const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs + 1);

        const __m128i q4b_1_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b));
        const __m128i q4b_1_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b));
        const __m128i q4b_2_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b));
        const __m128i q4b_2_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b));

        const __m256 p = mul_sum_i8_quad_float(q4b_1_0, q4b_1_1, q4b_2_0, q4b_2_1, q8b_1_0, q8b_1_1, q8b_2_0, q8b_2_1);
        const __m256 deltas = quad_fp16_delta_float(x[ib].d, y[ib].d, x[ib + 1].d, y[ib + 1].d);
        accum = _mm256_add_ps(_mm256_mul_ps(deltas, p), accum);
    }

    sumf = hsum_float_8(accum);

#endif
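    // Scalar tail: handles any blocks left over by the vector loops, and the whole range
    // when neither AVX2 nor AVX is available.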
    for (; ib < nb; ++ib) {
        const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_CPU_FP16_TO_FP32(x[ib].d);
        int sumi1 = 0, sumi2 = 0;
        for (int j = 0; j < QK4_NL/2; ++j) {
            sumi1 += y[ib].qs[j+       0] * kvalues_iq4nl[x[ib].qs[j] & 0xf];
            sumi2 += y[ib].qs[j+QK4_NL/2] * kvalues_iq4nl[x[ib].qs[j] >>  4];
        }
        sumf += d * (sumi1 + sumi2);
    }
    *s = sumf;
}

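// Dot product of IQ4_XS blocks with Q8_K blocks.
// Uses the same kvalues_iq4nl lookup as IQ4_NL; each group of 32 weights has a 6-bit scale
// assembled from a nibble of scales_l and two bits of scales_h, offset by -32.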
void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);
    assert(n % QK_K == 0);

    const block_iq4_xs * GGML_RESTRICT x = vx;
    const block_q8_K   * GGML_RESTRICT y = vy;

    const int nb = n / QK_K;

#if defined __AVX2__

    const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl);
    const __m128i m4b  = _mm_set1_epi8(0x0f);

    __m256 accum = _mm256_setzero_ps();
    for (int ibl = 0; ibl < nb; ++ibl) {
        const uint8_t * qs = x[ibl].qs;
        const int8_t  * q8 = y[ibl].qs;
        uint16_t sh = x[ibl].scales_h;
        __m256i sumi1 = _mm256_setzero_si256();
        __m256i sumi2 = _mm256_setzero_si256();
        for (int ib = 0; ib < QK_K/32; ib += 2) {
            const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)qs);  qs += 16;
            const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)qs);  qs += 16;
            const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
            const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
            const __m256i q4b_1 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)),
                                                  _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b)));
            const __m256i q4b_2 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)),
                                                  _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b)));
            const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1);
            const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
            const int16_t ls1 = ((x[ibl].scales_l[ib/2] & 0xf) | ((sh << 4) & 0x30)) - 32;
            const int16_t ls2 = ((x[ibl].scales_l[ib/2] >>  4) | ((sh << 2) & 0x30)) - 32;
            sh >>= 4;
            const __m256i p_1 = _mm256_madd_epi16(p16_1, _mm256_set1_epi16(ls1));
            const __m256i p_2 = _mm256_madd_epi16(p16_2, _mm256_set1_epi16(ls2));
            sumi1 = _mm256_add_epi32(p_1, sumi1);
            sumi2 = _mm256_add_epi32(p_2, sumi2);
        }
        accum = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(x[ibl].d)*y[ibl].d),
                _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accum);
    }

    *s = hsum_float_8(accum);

#elif defined __AVX__
    const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl);
    const __m128i m4b  = _mm_set1_epi8(0x0f);

    __m256 accum = _mm256_setzero_ps();
    for (int ibl = 0; ibl < nb; ++ibl) {
        const uint8_t * qs = x[ibl].qs;
        const int8_t  * q8 = y[ibl].qs;
        uint16_t sh = x[ibl].scales_h;
        __m128i sumi1_0 = _mm_setzero_si128();
        __m128i sumi1_1 = _mm_setzero_si128();
        __m128i sumi2_0 = _mm_setzero_si128();
        __m128i sumi2_1 = _mm_setzero_si128();
        for (int ib = 0; ib < QK_K/32; ib += 2) {
            const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)qs); qs += 16;
            const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)qs); qs += 16;
            const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
            const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
            const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
            const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
            const __m128i q4b_1_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b));
            const __m128i q4b_1_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b));
            const __m128i q4b_2_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b));
            const __m128i q4b_2_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b));
            const __m128i p16_1_0 = mul_add_epi8_sse(q4b_1_0, q8b_1_0);
            const __m128i p16_1_1 = mul_add_epi8_sse(q4b_1_1, q8b_1_1);
            const __m128i p16_2_0 = mul_add_epi8_sse(q4b_2_0, q8b_2_0);
            const __m128i p16_2_1 = mul_add_epi8_sse(q4b_2_1, q8b_2_1);
            const int16_t ls1 = ((x[ibl].scales_l[ib/2] & 0xf) | ((sh << 4) & 0x30)) - 32;
            const int16_t ls2 = ((x[ibl].scales_l[ib/2] >>  4) | ((sh << 2) & 0x30)) - 32;
            sh >>= 4;
            const __m128i p_1_0 = _mm_madd_epi16(p16_1_0, _mm_set1_epi16(ls1));
            const __m128i p_1_1 = _mm_madd_epi16(p16_1_1, _mm_set1_epi16(ls1));
            const __m128i p_2_0 = _mm_madd_epi16(p16_2_0, _mm_set1_epi16(ls2));
            const __m128i p_2_1 = _mm_madd_epi16(p16_2_1, _mm_set1_epi16(ls2));
            sumi1_0 = _mm_add_epi32(p_1_0, sumi1_0);
            sumi1_1 = _mm_add_epi32(p_1_1, sumi1_1);
            sumi2_0 = _mm_add_epi32(p_2_0, sumi2_0);
            sumi2_1 = _mm_add_epi32(p_2_1, sumi2_1);
        }
        __m128i sumi12_0 = _mm_add_epi32(sumi1_0, sumi2_0);
        __m128i sumi12_1 = _mm_add_epi32(sumi1_1, sumi2_1);
        accum = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(x[ibl].d)*y[ibl].d),
                _mm256_cvtepi32_ps(MM256_SET_M128I(sumi12_1, sumi12_0))), accum);
    }

    *s = hsum_float_8(accum);

#else
    UNUSED(x);
    UNUSED(y);
    UNUSED(nb);
    ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}