1#define GGML_COMMON_IMPL_C
2#include "ggml-common.h"
3#include "ggml-quants.h"
4#include "ggml-impl.h"
5#include "ggml-cpu.h"
6#include "simd-mappings.h"
7
8#include "../../quants.h"
9#include "../../ggml-cpu-impl.h"
10
11#include <math.h>
12#include <string.h>
13#include <assert.h>
14#include <stdlib.h> // for qsort
15#include <stdio.h> // for GGML_ASSERT
16
17#define GROUP_MAX_EPS 1e-15f
18#define GROUP_MAX_EPS_IQ3_XXS 1e-8f
19#define GROUP_MAX_EPS_IQ2_S 1e-8f
20#define GROUP_MAX_EPS_IQ1_M 1e-7f
21#define GROUP_MAX_EPS_IQ1_S 1e-12f
22
23#define UNUSED GGML_UNUSED
24
25// some compilers don't provide _mm256_set_m128i, e.g. gcc 7
26#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
27
28#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
29// multiply int8_t, add results pairwise twice
30static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
31 // Get absolute values of x vectors
32 const __m128i ax = _mm_sign_epi8(x, x);
33 // Sign the values of the y vectors
34 const __m128i sy = _mm_sign_epi8(y, x);
35 // Perform multiplication and create 16-bit values
36 const __m128i dot = _mm_maddubs_epi16(ax, sy);
37 const __m128i ones = _mm_set1_epi16(1);
38 return _mm_madd_epi16(ones, dot);
39}
40
41#if __AVX__ || __AVX2__ || __AVX512F__
42// horizontally add 8 floats
43static inline float hsum_float_8(const __m256 x) {
44 __m128 res = _mm256_extractf128_ps(x, 1);
45 res = _mm_add_ps(res, _mm256_castps256_ps128(x));
46 res = _mm_add_ps(res, _mm_movehl_ps(res, res));
47 res = _mm_add_ss(res, _mm_movehdup_ps(res));
48 return _mm_cvtss_f32(res);
49}
50
51// horizontally add 8 int32_t
52static inline int hsum_i32_8(const __m256i a) {
53 const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1));
54 const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);
55 const __m128i sum64 = _mm_add_epi32(hi64, sum128);
56 const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
57 return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
58}
59
60// horizontally add 4 int32_t
61static inline int hsum_i32_4(const __m128i a) {
62 const __m128i hi64 = _mm_unpackhi_epi64(a, a);
63 const __m128i sum64 = _mm_add_epi32(hi64, a);
64 const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
65 return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
66}
67
68#if defined(__AVX2__) || defined(__AVX512F__)
69static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
70 const __m256i ax = _mm256_sign_epi8(x, x);
71 const __m256i sy = _mm256_sign_epi8(y, x);
72 return _mm256_maddubs_epi16(ax, sy);
73}
74
75// spread 32 bits to 32 bytes { 0x00, 0xFF }
76static inline __m256i bytes_from_bits_32(const uint8_t * x) {
77 uint32_t x32;
78 memcpy(&x32, x, sizeof(uint32_t));
79 const __m256i shuf_mask = _mm256_set_epi64x(
80 0x0303030303030303, 0x0202020202020202,
81 0x0101010101010101, 0x0000000000000000);
82 __m256i bytes = _mm256_shuffle_epi8(_mm256_set1_epi32(x32), shuf_mask);
83 const __m256i bit_mask = _mm256_set1_epi64x(0x7fbfdfeff7fbfdfe);
84 bytes = _mm256_or_si256(bytes, bit_mask);
85 return _mm256_cmpeq_epi8(bytes, _mm256_set1_epi64x(-1));
86}
87
88// Unpack 32 4-bit fields into 32 bytes
89// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
90static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
91{
92 const __m128i tmp = _mm_loadu_si128((const __m128i *)rsi);
93 const __m256i bytes = MM256_SET_M128I(_mm_srli_epi16(tmp, 4), tmp);
94 const __m256i lowMask = _mm256_set1_epi8( 0xF );
95 return _mm256_and_si256(lowMask, bytes);
96}
97
98// add int16_t pairwise and return as float vector
99static inline __m256 sum_i16_pairs_float(const __m256i x) {
100 const __m256i ones = _mm256_set1_epi16(1);
101 const __m256i summed_pairs = _mm256_madd_epi16(ones, x);
102 return _mm256_cvtepi32_ps(summed_pairs);
103}
104
105static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
106#if defined(__AVX512VNNI__) && defined(__AVX512VL__)
107 const __m256i zero = _mm256_setzero_si256();
108 const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy);
109 return _mm256_cvtepi32_ps(summed_pairs);
110#elif defined(__AVXVNNI__)
111 const __m256i zero = _mm256_setzero_si256();
112 const __m256i summed_pairs = _mm256_dpbusd_avx_epi32(zero, ax, sy);
113 return _mm256_cvtepi32_ps(summed_pairs);
114#else
115 // Perform multiplication and create 16-bit values
116 const __m256i dot = _mm256_maddubs_epi16(ax, sy);
117 return sum_i16_pairs_float(dot);
118#endif
119}
120
121// multiply int8_t, add results pairwise twice and return as float vector
122static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
123#if __AVXVNNIINT8__
124 const __m256i zero = _mm256_setzero_si256();
125 const __m256i summed_pairs = _mm256_dpbssd_epi32(zero, x, y);
126 return _mm256_cvtepi32_ps(summed_pairs);
127#else
128 // Get absolute values of x vectors
129 const __m256i ax = _mm256_sign_epi8(x, x);
130 // Sign the values of the y vectors
131 const __m256i sy = _mm256_sign_epi8(y, x);
132 return mul_sum_us8_pairs_float(ax, sy);
133#endif
134}
135
136static inline __m128i packNibbles( __m256i bytes )
137{
138 // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
139#if __AVX512F__
140 const __m256i bytes_srli_4 = _mm256_srli_epi16(bytes, 4); // 0000_0000_abcd_0000
141 bytes = _mm256_or_si256(bytes, bytes_srli_4); // 0000_abcd_abcd_efgh
142 return _mm256_cvtepi16_epi8(bytes); // abcd_efgh
143#else
144 const __m256i lowByte = _mm256_set1_epi16( 0xFF );
145 __m256i high = _mm256_andnot_si256( lowByte, bytes );
146 __m256i low = _mm256_and_si256( lowByte, bytes );
147 high = _mm256_srli_epi16( high, 4 );
148 bytes = _mm256_or_si256( low, high );
149
150 // Compress uint16_t lanes into bytes
151 __m128i r0 = _mm256_castsi256_si128( bytes );
152 __m128i r1 = _mm256_extracti128_si256( bytes, 1 );
153 return _mm_packus_epi16( r0, r1 );
154#endif
155}
156#elif defined(__AVX__)
157static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
158{
159 // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
160 const __m128i lowByte = _mm_set1_epi16( 0xFF );
161 __m128i high = _mm_andnot_si128( lowByte, bytes1 );
162 __m128i low = _mm_and_si128( lowByte, bytes1 );
163 high = _mm_srli_epi16( high, 4 );
164 bytes1 = _mm_or_si128( low, high );
165 high = _mm_andnot_si128( lowByte, bytes2 );
166 low = _mm_and_si128( lowByte, bytes2 );
167 high = _mm_srli_epi16( high, 4 );
168 bytes2 = _mm_or_si128( low, high );
169
170 return _mm_packus_epi16( bytes1, bytes2);
171}
172
173static inline __m128i mul_add_epi8_sse(const __m128i x, const __m128i y) {
174 const __m128i ax = _mm_sign_epi8(x, x);
175 const __m128i sy = _mm_sign_epi8(y, x);
176 return _mm_maddubs_epi16(ax, sy);
177}
178
179// spread 32 bits to 32 bytes { 0x00, 0xFF }
180static inline __m256i bytes_from_bits_32(const uint8_t * x) {
181 uint32_t x32;
182 memcpy(&x32, x, sizeof(uint32_t));
183 const __m128i shuf_maskl = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000);
184 const __m128i shuf_maskh = _mm_set_epi64x(0x0303030303030303, 0x0202020202020202);
185 __m128i bytesl = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskl);
186 __m128i bytesh = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskh);
187 const __m128i bit_mask = _mm_set1_epi64x(0x7fbfdfeff7fbfdfe);
188 bytesl = _mm_or_si128(bytesl, bit_mask);
189 bytesh = _mm_or_si128(bytesh, bit_mask);
190 bytesl = _mm_cmpeq_epi8(bytesl, _mm_set1_epi64x(-1));
191 bytesh = _mm_cmpeq_epi8(bytesh, _mm_set1_epi64x(-1));
192 return MM256_SET_M128I(bytesh, bytesl);
193}
194
195// Unpack 32 4-bit fields into 32 bytes
196// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
197static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
198{
199 // Load 16 bytes from memory
200 __m128i tmpl = _mm_loadu_si128((const __m128i *)rsi);
201 __m128i tmph = _mm_srli_epi16(tmpl, 4);
202 const __m128i lowMask = _mm_set1_epi8(0xF);
203 tmpl = _mm_and_si128(lowMask, tmpl);
204 tmph = _mm_and_si128(lowMask, tmph);
205 return MM256_SET_M128I(tmph, tmpl);
206}
207
208// add int16_t pairwise and return as float vector
209static inline __m256 sum_i16_pairs_float(const __m128i xh, const __m128i xl) {
210 const __m128i ones = _mm_set1_epi16(1);
211 const __m128i summed_pairsl = _mm_madd_epi16(ones, xl);
212 const __m128i summed_pairsh = _mm_madd_epi16(ones, xh);
213 const __m256i summed_pairs = MM256_SET_M128I(summed_pairsh, summed_pairsl);
214 return _mm256_cvtepi32_ps(summed_pairs);
215}
216
217static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
218 const __m128i axl = _mm256_castsi256_si128(ax);
219 const __m128i axh = _mm256_extractf128_si256(ax, 1);
220 const __m128i syl = _mm256_castsi256_si128(sy);
221 const __m128i syh = _mm256_extractf128_si256(sy, 1);
222 // Perform multiplication and create 16-bit values
223 const __m128i dotl = _mm_maddubs_epi16(axl, syl);
224 const __m128i doth = _mm_maddubs_epi16(axh, syh);
225 return sum_i16_pairs_float(doth, dotl);
226}
227
228// multiply int8_t, add results pairwise twice and return as float vector
229static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
230 const __m128i xl = _mm256_castsi256_si128(x);
231 const __m128i xh = _mm256_extractf128_si256(x, 1);
232 const __m128i yl = _mm256_castsi256_si128(y);
233 const __m128i yh = _mm256_extractf128_si256(y, 1);
234 // Get absolute values of x vectors
235 const __m128i axl = _mm_sign_epi8(xl, xl);
236 const __m128i axh = _mm_sign_epi8(xh, xh);
237 // Sign the values of the y vectors
238 const __m128i syl = _mm_sign_epi8(yl, xl);
239 const __m128i syh = _mm_sign_epi8(yh, xh);
240 // Perform multiplication and create 16-bit values
241 const __m128i dotl = _mm_maddubs_epi16(axl, syl);
242 const __m128i doth = _mm_maddubs_epi16(axh, syh);
243 return sum_i16_pairs_float(doth, dotl);
244}
245
246// larger version of mul_sum_i8_pairs_float where x and y are each represented by four 128-bit vectors
247static inline __m256 mul_sum_i8_quad_float(const __m128i x_1_0, const __m128i x_1_1, const __m128i x_2_0, const __m128i x_2_1,
248 const __m128i y_1_0, const __m128i y_1_1, const __m128i y_2_0, const __m128i y_2_1) {
249 const __m128i mone = _mm_set1_epi16(1);
250
251 const __m128i p16_1_0 = mul_add_epi8_sse(x_1_0, y_1_0);
252 const __m128i p16_1_1 = mul_add_epi8_sse(x_1_1, y_1_1);
253 const __m128i p16_2_0 = mul_add_epi8_sse(x_2_0, y_2_0);
254 const __m128i p16_2_1 = mul_add_epi8_sse(x_2_1, y_2_1);
255 const __m128i p_1_0 = _mm_madd_epi16(p16_1_0, mone);
256 const __m128i p_1_1 = _mm_madd_epi16(p16_1_1, mone);
257 const __m128i p_2_0 = _mm_madd_epi16(p16_2_0, mone);
258 const __m128i p_2_1 = _mm_madd_epi16(p16_2_1, mone);
259 const __m128i p_1 = _mm_add_epi32(p_1_0, p_1_1);
260 const __m128i p_2 = _mm_add_epi32(p_2_0, p_2_1);
261 return _mm256_cvtepi32_ps(MM256_SET_M128I(p_2, p_1));
262}
263
264// quad fp16 delta calculation
265static inline __m256 quad_fp16_delta_float(const float x0, const float y0, const float x1, const float y1) {
266 // GGML_CPU_FP16_TO_FP32 is faster than Intel F16C
267 return _mm256_set_m128(_mm_set1_ps(GGML_CPU_FP16_TO_FP32(x1) * GGML_CPU_FP16_TO_FP32(y1)),
268 _mm_set1_ps(GGML_CPU_FP16_TO_FP32(x0) * GGML_CPU_FP16_TO_FP32(y0)));
269}
270
271static inline __m256 quad_mx_delta_float(const uint8_t x0, const float y0, const uint8_t x1, const float y1) {
272 return _mm256_set_m128(_mm_set1_ps(GGML_CPU_E8M0_TO_FP32_HALF(x1) * GGML_CPU_FP16_TO_FP32(y1)),
273 _mm_set1_ps(GGML_CPU_E8M0_TO_FP32_HALF(x0) * GGML_CPU_FP16_TO_FP32(y0)));
274}
275#endif
276#elif defined(__SSSE3__)
277// horizontally add 4x4 floats
278static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 c, const __m128 d) {
279 __m128 res_0 =_mm_hadd_ps(a, b);
280 __m128 res_1 =_mm_hadd_ps(c, d);
281 __m128 res =_mm_hadd_ps(res_0, res_1);
282 res =_mm_hadd_ps(res, res);
283 res =_mm_hadd_ps(res, res);
284
285 return _mm_cvtss_f32(res);
286}
287#endif // __AVX__ || __AVX2__ || __AVX512F__
288#endif // defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
289
290void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
291 assert(QK8_0 == 32);
292 assert(k % QK8_0 == 0);
293 const int nb = k / QK8_0;
294
295 block_q8_0 * GGML_RESTRICT y = vy;
296
297#if defined(__AVX2__) || defined(__AVX__)
298 for (int i = 0; i < nb; i++) {
299 // Load elements into 4 AVX vectors
300 __m256 v0 = _mm256_loadu_ps( x );
301 __m256 v1 = _mm256_loadu_ps( x + 8 );
302 __m256 v2 = _mm256_loadu_ps( x + 16 );
303 __m256 v3 = _mm256_loadu_ps( x + 24 );
304 x += 32;
305
306 // Compute max(abs(e)) for the block
307 const __m256 signBit = _mm256_set1_ps( -0.0f );
308 __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
309 maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
310 maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
311 maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
312
313 __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
314 max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
315 max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
316 const float maxScalar = _mm_cvtss_f32( max4 );
317
318 // Quantize these floats
319 const float d = maxScalar / 127.f;
320 y[i].d = GGML_CPU_FP32_TO_FP16(d);
321 const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
322 const __m256 mul = _mm256_set1_ps( id );
323
324 // Apply the multiplier
325 v0 = _mm256_mul_ps( v0, mul );
326 v1 = _mm256_mul_ps( v1, mul );
327 v2 = _mm256_mul_ps( v2, mul );
328 v3 = _mm256_mul_ps( v3, mul );
329
330 // Round to nearest integer
331 v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
332 v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
333 v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
334 v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
335
336 // Convert floats to integers
337 __m256i i0 = _mm256_cvtps_epi32( v0 );
338 __m256i i1 = _mm256_cvtps_epi32( v1 );
339 __m256i i2 = _mm256_cvtps_epi32( v2 );
340 __m256i i3 = _mm256_cvtps_epi32( v3 );
341
342#if defined(__AVX2__)
343 // Convert int32 to int16
344 i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
345 i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
346 // Convert int16 to int8
347 i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
348
349 // We got our precious signed bytes, but the order is now wrong
350 // These AVX2 pack instructions process 16-byte pieces independently
351 // The following instruction is fixing the order
352 const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
353 i0 = _mm256_permutevar8x32_epi32( i0, perm );
354
355 _mm256_storeu_si256((__m256i *)y[i].qs, i0);
356#else
357 // Since we don't have in AVX some necessary functions,
358 // we split the registers in half and call AVX2 analogs from SSE
359 __m128i ni0 = _mm256_castsi256_si128( i0 );
360 __m128i ni1 = _mm256_extractf128_si256( i0, 1);
361 __m128i ni2 = _mm256_castsi256_si128( i1 );
362 __m128i ni3 = _mm256_extractf128_si256( i1, 1);
363 __m128i ni4 = _mm256_castsi256_si128( i2 );
364 __m128i ni5 = _mm256_extractf128_si256( i2, 1);
365 __m128i ni6 = _mm256_castsi256_si128( i3 );
366 __m128i ni7 = _mm256_extractf128_si256( i3, 1);
367
368 // Convert int32 to int16
369 ni0 = _mm_packs_epi32( ni0, ni1 );
370 ni2 = _mm_packs_epi32( ni2, ni3 );
371 ni4 = _mm_packs_epi32( ni4, ni5 );
372 ni6 = _mm_packs_epi32( ni6, ni7 );
373 // Convert int16 to int8
374 ni0 = _mm_packs_epi16( ni0, ni2 );
375 ni4 = _mm_packs_epi16( ni4, ni6 );
376
377 _mm_storeu_si128((__m128i *)(y[i].qs + 0), ni0);
378 _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
379#endif
380 }
381#else
382 GGML_UNUSED(nb);
383 // scalar
384 quantize_row_q8_0_ref(x, y, k);
385#endif
386}
387
388void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
389 assert(k % QK8_1 == 0);
390 const int nb = k / QK8_1;
391
392 block_q8_1 * GGML_RESTRICT y = vy;
393#if defined(__AVX2__) || defined(__AVX__)
394 for (int i = 0; i < nb; i++) {
395 // Load elements into 4 AVX vectors
396 __m256 v0 = _mm256_loadu_ps( x );
397 __m256 v1 = _mm256_loadu_ps( x + 8 );
398 __m256 v2 = _mm256_loadu_ps( x + 16 );
399 __m256 v3 = _mm256_loadu_ps( x + 24 );
400 x += 32;
401
402 // Compute max(abs(e)) for the block
403 const __m256 signBit = _mm256_set1_ps( -0.0f );
404 __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
405 maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
406 maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
407 maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
408
409 __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
410 max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
411 max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
412 const float max_scalar = _mm_cvtss_f32( max4 );
413
414 // Quantize these floats
415 const float d = max_scalar / 127.f;
416 y[i].d = GGML_CPU_FP32_TO_FP16(d);
417 const float id = ( max_scalar != 0.0f ) ? 127.f / max_scalar : 0.0f;
418 const __m256 mul = _mm256_set1_ps( id );
419
420 // Apply the multiplier
421 v0 = _mm256_mul_ps( v0, mul );
422 v1 = _mm256_mul_ps( v1, mul );
423 v2 = _mm256_mul_ps( v2, mul );
424 v3 = _mm256_mul_ps( v3, mul );
425
426 // Round to nearest integer
427 v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
428 v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
429 v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
430 v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
431
432 // Convert floats to integers
433 __m256i i0 = _mm256_cvtps_epi32( v0 );
434 __m256i i1 = _mm256_cvtps_epi32( v1 );
435 __m256i i2 = _mm256_cvtps_epi32( v2 );
436 __m256i i3 = _mm256_cvtps_epi32( v3 );
437
438#if defined(__AVX2__)
439 // Compute the sum of the quants and set y[i].s
440 y[i].s = GGML_CPU_FP32_TO_FP16(d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3))));
441
442 // Convert int32 to int16
443 i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
444 i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
445 // Convert int16 to int8
446 i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
447
448 // We got our precious signed bytes, but the order is now wrong
449 // These AVX2 pack instructions process 16-byte pieces independently
450 // The following instruction is fixing the order
451 const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
452 i0 = _mm256_permutevar8x32_epi32( i0, perm );
453
454 _mm256_storeu_si256((__m256i *)y[i].qs, i0);
455#else
456 // Since we don't have in AVX some necessary functions,
457 // we split the registers in half and call AVX2 analogs from SSE
458 __m128i ni0 = _mm256_castsi256_si128( i0 );
459 __m128i ni1 = _mm256_extractf128_si256( i0, 1);
460 __m128i ni2 = _mm256_castsi256_si128( i1 );
461 __m128i ni3 = _mm256_extractf128_si256( i1, 1);
462 __m128i ni4 = _mm256_castsi256_si128( i2 );
463 __m128i ni5 = _mm256_extractf128_si256( i2, 1);
464 __m128i ni6 = _mm256_castsi256_si128( i3 );
465 __m128i ni7 = _mm256_extractf128_si256( i3, 1);
466
467 // Compute the sum of the quants and set y[i].s
468 const __m128i s0 = _mm_add_epi32(_mm_add_epi32(ni0, ni1), _mm_add_epi32(ni2, ni3));
469 const __m128i s1 = _mm_add_epi32(_mm_add_epi32(ni4, ni5), _mm_add_epi32(ni6, ni7));
470 y[i].s = GGML_CPU_FP32_TO_FP16(d * hsum_i32_4(_mm_add_epi32(s0, s1)));
471
472 // Convert int32 to int16
473 ni0 = _mm_packs_epi32( ni0, ni1 );
474 ni2 = _mm_packs_epi32( ni2, ni3 );
475 ni4 = _mm_packs_epi32( ni4, ni5 );
476 ni6 = _mm_packs_epi32( ni6, ni7 );
477 // Convert int16 to int8
478 ni0 = _mm_packs_epi16( ni0, ni2 );
479 ni4 = _mm_packs_epi16( ni4, ni6 );
480
481 _mm_storeu_si128((__m128i *)(y[i].qs + 0), ni0);
482 _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
483#endif
484 }
485#else
486 GGML_UNUSED(nb);
487 // scalar
488 quantize_row_q8_1_ref(x, y, k);
489#endif
490}
491
492// placeholder implementation for Apple targets
493void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {
494 quantize_row_q8_K_ref(x, y, k);
495}
496
497//===================================== Dot products =================================
498
499//
500// Helper functions
501//
502
503#if __AVX__ || __AVX2__ || __AVX512F__
504
505// shuffles to pick the required scales in dot products
506static inline __m256i get_scale_shuffle_q3k(int i) {
507 static const uint8_t k_shuffle[128] = {
508 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
509 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
510 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,
511 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13, 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,
512 };
513 return _mm256_loadu_si256((const __m256i*)k_shuffle + i);
514}
515static inline __m256i get_scale_shuffle_k4(int i) {
516 static const uint8_t k_shuffle[256] = {
517 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,
518 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
519 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5,
520 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
521 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9,
522 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,
523 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,
524 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15
525 };
526 return _mm256_loadu_si256((const __m256i*)k_shuffle + i);
527}
528static inline __m128i get_scale_shuffle(int i) {
529 static const uint8_t k_shuffle[128] = {
530 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1,
531 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3,
532 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5,
533 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7,
534 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9,
535 10,10,10,10,10,10,10,10, 11,11,11,11,11,11,11,11,
536 12,12,12,12,12,12,12,12, 13,13,13,13,13,13,13,13,
537 14,14,14,14,14,14,14,14, 15,15,15,15,15,15,15,15
538 };
539 return _mm_loadu_si128((const __m128i*)k_shuffle + i);
540}
541#endif
542
543void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
544 const int qk = QK8_0;
545 const int nb = n / qk;
546
547 assert(n % qk == 0);
548 assert(nrc == 1);
549 UNUSED(nrc);
550 UNUSED(bx);
551 UNUSED(by);
552 UNUSED(bs);
553
554 const block_q4_0 * GGML_RESTRICT x = vx;
555 const block_q8_0 * GGML_RESTRICT y = vy;
556
557 int ib = 0;
558 float sumf = 0;
559
560#if defined(__AVX2__)
561 // Initialize accumulator with zeros
562 __m256 acc = _mm256_setzero_ps();
563
564 // Main loop
565 for (; ib < nb; ++ib) {
566 /* Compute combined scale for the block */
567 const __m256 d = _mm256_set1_ps( GGML_CPU_FP16_TO_FP32(x[ib].d) * GGML_CPU_FP16_TO_FP32(y[ib].d) );
568
569 __m256i qx = bytes_from_nibbles_32(x[ib].qs);
570
571 // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
572 const __m256i off = _mm256_set1_epi8( 8 );
573 qx = _mm256_sub_epi8( qx, off );
574
575 __m256i qy = _mm256_loadu_si256((const __m256i *)y[ib].qs);
576
577 const __m256 q = mul_sum_i8_pairs_float(qx, qy);
578
579 /* Multiply q with scale and accumulate */
580 acc = _mm256_fmadd_ps( d, q, acc );
581 }
582
583 sumf = hsum_float_8(acc);
584#elif defined(__AVX__)
585 __m256 accum = _mm256_setzero_ps();
586 for (; ib + 1 < nb; ib += 2) {
587 const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)x[ib + 0].qs);
588 const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs);
589 const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs);
590 const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs + 1);
591 const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs);
592 const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs + 1);
593
594 const __m128i q4b_1_0 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), q4bits_1), _mm_set1_epi8(8));
595 const __m128i q4b_1_1 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(q4bits_1, 4)), _mm_set1_epi8(8));
596 const __m128i q4b_2_0 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), q4bits_2), _mm_set1_epi8(8));
597 const __m128i q4b_2_1 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(q4bits_2, 4)), _mm_set1_epi8(8));
598
599 const __m128i p16_1_0 = mul_add_epi8_sse(q4b_1_0, q8b_1_0);
600 const __m128i p16_1_1 = mul_add_epi8_sse(q4b_1_1, q8b_1_1);
601 const __m128i p16_2_0 = mul_add_epi8_sse(q4b_2_0, q8b_2_0);
602 const __m128i p16_2_1 = mul_add_epi8_sse(q4b_2_1, q8b_2_1);
603 const __m128i p_1 = _mm_add_epi16(p16_1_0, p16_1_1);
604 const __m128i p_2 = _mm_add_epi16(p16_2_0, p16_2_1);
605 const __m256 p = sum_i16_pairs_float(p_2, p_1);
606
607 const __m256 deltas = quad_fp16_delta_float(x[ib].d, y[ib].d, x[ib + 1].d, y[ib + 1].d);
608 accum = _mm256_add_ps(_mm256_mul_ps(deltas, p), accum);
609 }
610
611 sumf = hsum_float_8(accum);
612#elif defined(__SSSE3__)
613 // set constants
614 const __m128i lowMask = _mm_set1_epi8(0xF);
615 const __m128i off = _mm_set1_epi8(8);
616
617 // Initialize accumulator with zeros
618 __m128 acc_0 = _mm_setzero_ps();
619 __m128 acc_1 = _mm_setzero_ps();
620 __m128 acc_2 = _mm_setzero_ps();
621 __m128 acc_3 = _mm_setzero_ps();
622
623 for (; ib + 1 < nb; ib += 2) {
624 _mm_prefetch(&x[ib] + sizeof(block_q4_0), _MM_HINT_T0);
625 _mm_prefetch(&y[ib] + sizeof(block_q8_0), _MM_HINT_T0);
626
627 // Compute combined scale for the block 0 and 1
628 const __m128 d_0_1 = _mm_set1_ps( GGML_CPU_FP16_TO_FP32(x[ib].d) * GGML_CPU_FP16_TO_FP32(y[ib].d) );
629
630 const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[ib].qs);
631
632 __m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1);
633 __m128i by_0 = _mm_loadu_si128((const __m128i *)y[ib].qs);
634 bx_0 = _mm_sub_epi8(bx_0, off);
635 const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0);
636
637 __m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4));
638 __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[ib].qs + 16));
639 bx_1 = _mm_sub_epi8(bx_1, off);
640 const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1);
641
642 _mm_prefetch(&x[ib] + 2 * sizeof(block_q4_0), _MM_HINT_T0);
643 _mm_prefetch(&y[ib] + 2 * sizeof(block_q8_0), _MM_HINT_T0);
644
645 // Compute combined scale for the block 2 and 3
646 const __m128 d_2_3 = _mm_set1_ps( GGML_CPU_FP16_TO_FP32(x[ib + 1].d) * GGML_CPU_FP16_TO_FP32(y[ib + 1].d) );
647
648 const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs);
649
650 __m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3);
651 __m128i by_2 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs);
652 bx_2 = _mm_sub_epi8(bx_2, off);
653 const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2);
654
655 __m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4));
656 __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[ib + 1].qs + 16));
657 bx_3 = _mm_sub_epi8(bx_3, off);
658 const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3);
659
660 // Convert int32_t to float
661 __m128 p0 = _mm_cvtepi32_ps(i32_0);
662 __m128 p1 = _mm_cvtepi32_ps(i32_1);
663 __m128 p2 = _mm_cvtepi32_ps(i32_2);
664 __m128 p3 = _mm_cvtepi32_ps(i32_3);
665
666 // Apply the scale
667 __m128 p0_d = _mm_mul_ps( d_0_1, p0 );
668 __m128 p1_d = _mm_mul_ps( d_0_1, p1 );
669 __m128 p2_d = _mm_mul_ps( d_2_3, p2 );
670 __m128 p3_d = _mm_mul_ps( d_2_3, p3 );
671
672 // Acummulate
673 acc_0 = _mm_add_ps(p0_d, acc_0);
674 acc_1 = _mm_add_ps(p1_d, acc_1);
675 acc_2 = _mm_add_ps(p2_d, acc_2);
676 acc_3 = _mm_add_ps(p3_d, acc_3);
677 }
678
679 sumf = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3);
680
681#endif
682 for (; ib < nb; ++ib) {
683 int sumi0 = 0;
684 int sumi1 = 0;
685
686 for (int j = 0; j < qk/2; ++j) {
687 const int v0 = (x[ib].qs[j] & 0x0F) - 8;
688 const int v1 = (x[ib].qs[j] >> 4) - 8;
689
690 sumi0 += (v0 * y[ib].qs[j]);
691 sumi1 += (v1 * y[ib].qs[j + qk/2]);
692 }
693
694 int sumi = sumi0 + sumi1;
695 sumf += sumi*GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d);
696 }
697
698 *s = sumf;
699}
700
701void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
702 const int qk = QK8_1;
703 const int nb = n / qk;
704
705 assert(n % qk == 0);
706 assert(nrc == 1);
707 UNUSED(nrc);
708 UNUSED(bx);
709 UNUSED(by);
710 UNUSED(bs);
711
712 const block_q4_1 * GGML_RESTRICT x = vx;
713 const block_q8_1 * GGML_RESTRICT y = vy;
714
715 int ib = 0;
716
717#if defined(__AVX2__) || defined(__AVX__)
718 // Initialize accumulator with zeros
719 __m256 acc = _mm256_setzero_ps();
720
721 float summs = 0;
722
723 // Main loop
724 for (; ib < nb; ++ib) {
725 const float d0 = GGML_CPU_FP16_TO_FP32(x[ib].d);
726 const float d1 = GGML_CPU_FP16_TO_FP32(y[ib].d);
727
728 summs += GGML_CPU_FP16_TO_FP32(x[ib].m) * GGML_CPU_FP16_TO_FP32(y[ib].s);
729
730 const __m256 d0v = _mm256_set1_ps( d0 );
731 const __m256 d1v = _mm256_set1_ps( d1 );
732
733 // Compute combined scales
734 const __m256 d0d1 = _mm256_mul_ps( d0v, d1v );
735
736 // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
737 const __m256i qx = bytes_from_nibbles_32(x[ib].qs);
738 const __m256i qy = _mm256_loadu_si256( (const __m256i *)y[ib].qs );
739
740 const __m256 xy = mul_sum_us8_pairs_float(qx, qy);
741
742 // Accumulate d0*d1*x*y
743#if defined(__AVX2__)
744 acc = _mm256_fmadd_ps( d0d1, xy, acc );
745#else
746 acc = _mm256_add_ps( _mm256_mul_ps( d0d1, xy ), acc );
747#endif
748 }
749
750 *s = hsum_float_8(acc) + summs;
751#else
752 UNUSED(nb);
753 UNUSED(x);
754 UNUSED(y);
755 UNUSED(ib);
756 ggml_vec_dot_q4_1_q8_1_generic(n, s, bs, vx, bx, vy, by, nrc);
757#endif
758}
759
760void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
761 assert(nrc == 1);
762 UNUSED(nrc);
763 UNUSED(bx);
764 UNUSED(by);
765 UNUSED(bs);
766 assert(n % QK_MXFP4 == 0);
767 static_assert(QK_MXFP4 == QK8_0, "QK_MXFP4 and QK8_0 must be the same");
768
769 const block_mxfp4 * GGML_RESTRICT x = vx;
770 const block_q8_0 * GGML_RESTRICT y = vy;
771
772 const int nb = n / QK_MXFP4;
773
774 int ib = 0;
775 float sumf = 0;
776
777#if defined __AVX2__
778
779 const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_mxfp4);
780 const __m128i m4b = _mm_set1_epi8(0x0f);
781 const __m256i mone = _mm256_set1_epi16(1);
782
783 __m256 accum1 = _mm256_setzero_ps();
784 __m256 accum2 = _mm256_setzero_ps();
785
786 for (; ib + 1 < nb; ib += 2) {
787 const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)x[ib + 0].qs);
788 const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)x[ib + 1].qs);
789 const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)y[ib + 0].qs);
790 const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)y[ib + 1].qs);
791 const __m256i q4b_1 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)),
792 _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b)));
793 const __m256i q4b_2 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)),
794 _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b)));
795 const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1);
796 const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
797 const __m256i p_1 = _mm256_madd_epi16(p16_1, mone);
798 const __m256i p_2 = _mm256_madd_epi16(p16_2, mone);
799 const __m256 scale0 = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 0].d)*GGML_CPU_E8M0_TO_FP32_HALF(x[ib + 0].e));
800 const __m256 scale1 = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 1].d)*GGML_CPU_E8M0_TO_FP32_HALF(x[ib + 1].e));
801 accum1 = _mm256_fmadd_ps(scale0, _mm256_cvtepi32_ps(p_1), accum1);
802 accum2 = _mm256_fmadd_ps(scale1, _mm256_cvtepi32_ps(p_2), accum2);
803 }
804
805 sumf = hsum_float_8(_mm256_add_ps(accum1, accum2));
806
807#elif defined __AVX__
808 const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_mxfp4);
809 const __m128i m4b = _mm_set1_epi8(0x0f);
810
811 __m256 accum = _mm256_setzero_ps();
812 for (; ib + 1 < nb; ib += 2) {
813 const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)x[ib + 0].qs);
814 const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs);
815 const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs);
816 const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs + 1);
817 const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs);
818 const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs + 1);
819
820 const __m128i q4b_1_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b));
821 const __m128i q4b_1_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b));
822 const __m128i q4b_2_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b));
823 const __m128i q4b_2_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b));
824
825 const __m256 p = mul_sum_i8_quad_float(q4b_1_0, q4b_1_1, q4b_2_0, q4b_2_1, q8b_1_0, q8b_1_1, q8b_2_0, q8b_2_1);
826 const __m256 deltas = quad_mx_delta_float(x[ib].e, y[ib].d, x[ib + 1].e, y[ib + 1].d);
827 accum = _mm256_add_ps(_mm256_mul_ps(deltas, p), accum);
828 }
829
830 sumf = hsum_float_8(accum);
831
832#endif
833 for (; ib < nb; ++ib) {
834 const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_CPU_E8M0_TO_FP32_HALF(x[ib].e);
835 int sumi1 = 0;
836 int sumi2 = 0;
837 for (int j = 0; j < QK_MXFP4/2; ++j) {
838 sumi1 += y[ib].qs[j + 0] * kvalues_mxfp4[x[ib].qs[j] & 0xf];
839 sumi2 += y[ib].qs[j + QK_MXFP4/2] * kvalues_mxfp4[x[ib].qs[j] >> 4];
840 }
841 sumf += d * (sumi1 + sumi2);
842 }
843 *s = sumf;
844}
845
846void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
847 const int qk = QK8_0;
848 const int nb = n / qk;
849
850 int ib = 0;
851
852 assert(n % qk == 0);
853 assert(qk == QK5_0);
854 assert(nrc == 1);
855 UNUSED(nrc);
856 UNUSED(bx);
857 UNUSED(by);
858 UNUSED(bs);
859
860 const block_q5_0 * GGML_RESTRICT x = vx;
861 const block_q8_0 * GGML_RESTRICT y = vy;
862
863#if defined(__AVX2__)
864 // Initialize accumulator with zeros
865 __m256 acc = _mm256_setzero_ps();
866
867 // Main loop
868 for (; ib < nb; ++ib) {
869 /* Compute combined scale for the block */
870 const __m256 d = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(x[ib].d) * GGML_CPU_FP16_TO_FP32(y[ib].d));
871
872 __m256i qx = bytes_from_nibbles_32(x[ib].qs);
873 __m256i bxhi = bytes_from_bits_32(x[ib].qh);
874 bxhi = _mm256_andnot_si256(bxhi, _mm256_set1_epi8((char)0xF0));
875 qx = _mm256_or_si256(qx, bxhi);
876
877 __m256i qy = _mm256_loadu_si256((const __m256i *)y[ib].qs);
878
879 const __m256 q = mul_sum_i8_pairs_float(qx, qy);
880
881 /* Multiply q with scale and accumulate */
882 acc = _mm256_fmadd_ps(d, q, acc);
883 }
884
885 *s = hsum_float_8(acc);
886#elif defined(__AVX__)
887 // Initialize accumulator with zeros
888 __m256 acc = _mm256_setzero_ps();
889 __m128i mask = _mm_set1_epi8((char)0xF0);
890
891 // Main loop
892 for (; ib < nb; ++ib) {
893 /* Compute combined scale for the block */
894 const __m256 d = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(x[ib].d) * GGML_CPU_FP16_TO_FP32(y[ib].d));
895
896 __m256i bx_0 = bytes_from_nibbles_32(x[ib].qs);
897 const __m256i bxhi = bytes_from_bits_32(x[ib].qh);
898 __m128i bxhil = _mm256_castsi256_si128(bxhi);
899 __m128i bxhih = _mm256_extractf128_si256(bxhi, 1);
900 bxhil = _mm_andnot_si128(bxhil, mask);
901 bxhih = _mm_andnot_si128(bxhih, mask);
902 __m128i bxl = _mm256_castsi256_si128(bx_0);
903 __m128i bxh = _mm256_extractf128_si256(bx_0, 1);
904 bxl = _mm_or_si128(bxl, bxhil);
905 bxh = _mm_or_si128(bxh, bxhih);
906 bx_0 = MM256_SET_M128I(bxh, bxl);
907
908 const __m256i by_0 = _mm256_loadu_si256((const __m256i *)y[ib].qs);
909
910 const __m256 q = mul_sum_i8_pairs_float(bx_0, by_0);
911
912 /* Multiply q with scale and accumulate */
913 acc = _mm256_add_ps(_mm256_mul_ps(d, q), acc);
914 }
915
916 *s = hsum_float_8(acc);
917#else
918 UNUSED(nb);
919 UNUSED(ib);
920 UNUSED(x);
921 UNUSED(y);
922 ggml_vec_dot_q5_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
923#endif
924}
925
926void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
927 const int qk = QK8_1;
928 const int nb = n / qk;
929
930 int ib = 0;
931
932 assert(n % qk == 0);
933 assert(qk == QK5_1);
934 assert(nrc == 1);
935 UNUSED(nrc);
936 UNUSED(bx);
937 UNUSED(by);
938 UNUSED(bs);
939
940 const block_q5_1 * GGML_RESTRICT x = vx;
941 const block_q8_1 * GGML_RESTRICT y = vy;
942
943#if defined(__AVX2__)
944 // Initialize accumulator with zeros
945 __m256 acc = _mm256_setzero_ps();
946
947 float summs = 0.0f;
948
949 // Main loop
950 for (; ib < nb; ++ib) {
951 const __m256 dx = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(x[ib].d));
952
953 summs += GGML_CPU_FP16_TO_FP32(x[ib].m) * GGML_CPU_FP16_TO_FP32(y[ib].s);
954
955 __m256i qx = bytes_from_nibbles_32(x[ib].qs);
956 __m256i bxhi = bytes_from_bits_32(x[ib].qh);
957 bxhi = _mm256_and_si256(bxhi, _mm256_set1_epi8(0x10));
958 qx = _mm256_or_si256(qx, bxhi);
959
960 const __m256 dy = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib].d));
961 const __m256i qy = _mm256_loadu_si256((const __m256i *)y[ib].qs);
962
963 const __m256 q = mul_sum_us8_pairs_float(qx, qy);
964
965 acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc);
966 }
967
968 *s = hsum_float_8(acc) + summs;
969#elif defined(__AVX__)
970 // Initialize accumulator with zeros
971 __m256 acc = _mm256_setzero_ps();
972 __m128i mask = _mm_set1_epi8(0x10);
973
974 float summs = 0.0f;
975
976 // Main loop
977 for (; ib < nb; ++ib) {
978 const __m256 dx = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(x[ib].d));
979
980 summs += GGML_CPU_FP16_TO_FP32(x[ib].m) * GGML_CPU_FP16_TO_FP32(y[ib].s);
981
982 __m256i bx_0 = bytes_from_nibbles_32(x[ib].qs);
983 const __m256i bxhi = bytes_from_bits_32(x[ib].qh);
984 __m128i bxhil = _mm256_castsi256_si128(bxhi);
985 __m128i bxhih = _mm256_extractf128_si256(bxhi, 1);
986 bxhil = _mm_and_si128(bxhil, mask);
987 bxhih = _mm_and_si128(bxhih, mask);
988 __m128i bxl = _mm256_castsi256_si128(bx_0);
989 __m128i bxh = _mm256_extractf128_si256(bx_0, 1);
990 bxl = _mm_or_si128(bxl, bxhil);
991 bxh = _mm_or_si128(bxh, bxhih);
992 bx_0 = MM256_SET_M128I(bxh, bxl);
993
994 const __m256 dy = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib].d));
995 const __m256i by_0 = _mm256_loadu_si256((const __m256i *)y[ib].qs);
996
997 const __m256 q = mul_sum_us8_pairs_float(bx_0, by_0);
998
999 acc = _mm256_add_ps(_mm256_mul_ps(q, _mm256_mul_ps(dx, dy)), acc);
1000 }
1001
1002 *s = hsum_float_8(acc) + summs;
1003#else
1004 UNUSED(nb);
1005 UNUSED(ib);
1006 UNUSED(x);
1007 UNUSED(y);
1008 ggml_vec_dot_q5_1_q8_1_generic(n, s, bs, vx, bx, vy, by, nrc);
1009#endif
1010}
1011
1012void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
1013 const int qk = QK8_0;
1014 const int nb = n / qk;
1015
1016 assert(n % qk == 0);
1017 assert(nrc == 1);
1018 UNUSED(nrc);
1019 UNUSED(bx);
1020 UNUSED(by);
1021 UNUSED(bs);
1022
1023 const block_q8_0 * GGML_RESTRICT x = vx;
1024 const block_q8_0 * GGML_RESTRICT y = vy;
1025
1026 int ib = 0;
1027 float sumf = 0;
1028
1029#if defined(__AVX2__)
1030 // Initialize accumulator with zeros
1031 __m256 acc = _mm256_setzero_ps();
1032
1033 // Main loop
1034 for (; ib < nb; ++ib) {
1035 // Compute combined scale for the block
1036 const __m256 d = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(x[ib].d) * GGML_CPU_FP16_TO_FP32(y[ib].d));
1037 __m256i qx = _mm256_loadu_si256((const __m256i *)x[ib].qs);
1038 __m256i qy = _mm256_loadu_si256((const __m256i *)y[ib].qs);
1039
1040 const __m256 q = mul_sum_i8_pairs_float(qx, qy);
1041
1042 // Multiply q with scale and accumulate
1043 acc = _mm256_fmadd_ps( d, q, acc );
1044 }
1045
1046 sumf = hsum_float_8(acc);
1047#elif defined(__AVX__)
1048 __m256 accum = _mm256_setzero_ps();
1049
1050 for (; ib + 1 < nb; ib += 2) {
1051 const __m128i qx_1_0 = _mm_loadu_si128((const __m128i *)x[ib].qs);
1052 const __m128i qx_1_1 = _mm_loadu_si128((const __m128i *)x[ib].qs + 1);
1053 const __m128i qx_2_0 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs);
1054 const __m128i qx_2_1 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs + 1);
1055 const __m128i qy_1_0 = _mm_loadu_si128((const __m128i *)y[ib].qs);
1056 const __m128i qy_1_1 = _mm_loadu_si128((const __m128i *)y[ib].qs + 1);
1057 const __m128i qy_2_0 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs);
1058 const __m128i qy_2_1 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs + 1);
1059
1060 const __m256 p = mul_sum_i8_quad_float(qx_1_0, qx_1_1, qx_2_0, qx_2_1, qy_1_0, qy_1_1, qy_2_0, qy_2_1);
1061 const __m256 deltas = quad_fp16_delta_float(x[ib].d, y[ib].d, x[ib + 1].d, y[ib + 1].d);
1062 accum = _mm256_add_ps(_mm256_mul_ps(deltas, p), accum);
1063 }
1064
1065 sumf = hsum_float_8(accum);
1066#endif
1067 for (; ib < nb; ++ib) {
1068 int sumi = 0;
1069
1070 for (int j = 0; j < qk; j++) {
1071 sumi += x[ib].qs[j]*y[ib].qs[j];
1072 }
1073
1074 sumf += sumi*(GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d));
1075 }
1076
1077 *s = sumf;
1078}
1079
1080void ggml_vec_dot_tq1_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
1081 assert(nrc == 1);
1082 UNUSED(nrc);
1083 UNUSED(bx);
1084 UNUSED(by);
1085 UNUSED(bs);
1086
1087 const block_tq1_0 * GGML_RESTRICT x = vx;
1088 const block_q8_K * GGML_RESTRICT y = vy;
1089
1090 const int nb = n / QK_K;
1091
1092#if defined(__AVX2__)
1093 __m256 sumf = _mm256_setzero_ps();
1094
1095 for (int i = 0; i < nb; ++i) {
1096 // 16-bit sums
1097 __m256i sumi0 = _mm256_setzero_si256();
1098 __m256i sumi1 = _mm256_setzero_si256();
1099 __m256i sumi2 = _mm256_setzero_si256();
1100
1101 // first 32 bytes of 5 elements
1102 {
1103 __m256i qx0 = _mm256_loadu_si256((const __m256i *) (x[i].qs));
1104 // 8-bit multiplies with shifts, masks and adds
1105 __m256i qx1 = _mm256_add_epi8(qx0, _mm256_add_epi8(qx0, qx0)); // 1 * 3
1106 __m256i qx2 = _mm256_add_epi8(_mm256_and_si256(_mm256_slli_epi16(qx0, 3), _mm256_set1_epi8(-8)), qx0); // 1 * 9
1107 __m256i qx3 = _mm256_add_epi8(_mm256_and_si256(_mm256_slli_epi16(qx1, 3), _mm256_set1_epi8(-8)), qx1); // 3 * 9
1108 __m256i qx4 = _mm256_add_epi8(_mm256_and_si256(_mm256_slli_epi16(qx2, 3), _mm256_set1_epi8(-8)), qx2); // 9 * 9
1109
1110 // TODO: can _mm256_mulhi_epu16 be faster even if 16-bits?
1111
1112 // Cancel the +1 from avg so that it behaves like a halving add
1113 qx0 = _mm256_subs_epu8(qx0, _mm256_set1_epi8(1));
1114 qx1 = _mm256_subs_epu8(qx1, _mm256_set1_epi8(1));
1115 qx2 = _mm256_subs_epu8(qx2, _mm256_set1_epi8(1));
1116 qx3 = _mm256_subs_epu8(qx3, _mm256_set1_epi8(1));
1117 qx4 = _mm256_subs_epu8(qx4, _mm256_set1_epi8(1));
1118 // Multiply by 3 and get the top 2 bits
1119 qx0 = _mm256_avg_epu8(qx0, _mm256_avg_epu8(qx0, _mm256_setzero_si256()));
1120 qx1 = _mm256_avg_epu8(qx1, _mm256_avg_epu8(qx1, _mm256_setzero_si256()));
1121 qx2 = _mm256_avg_epu8(qx2, _mm256_avg_epu8(qx2, _mm256_setzero_si256()));
1122 qx3 = _mm256_avg_epu8(qx3, _mm256_avg_epu8(qx3, _mm256_setzero_si256()));
1123 qx4 = _mm256_avg_epu8(qx4, _mm256_avg_epu8(qx4, _mm256_setzero_si256()));
1124 qx0 = _mm256_and_si256(_mm256_srli_epi16(qx0, 6), _mm256_set1_epi8(3));
1125 qx1 = _mm256_and_si256(_mm256_srli_epi16(qx1, 6), _mm256_set1_epi8(3));
1126 qx2 = _mm256_and_si256(_mm256_srli_epi16(qx2, 6), _mm256_set1_epi8(3));
1127 qx3 = _mm256_and_si256(_mm256_srli_epi16(qx3, 6), _mm256_set1_epi8(3));
1128 qx4 = _mm256_and_si256(_mm256_srli_epi16(qx4, 6), _mm256_set1_epi8(3));
1129
1130 const __m256i qy0 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 0));
1131 const __m256i qy1 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 32));
1132 const __m256i qy2 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 64));
1133 const __m256i qy3 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 96));
1134 const __m256i qy4 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 128));
1135
1136 qx0 = _mm256_maddubs_epi16(qx0, qy0);
1137 qx1 = _mm256_maddubs_epi16(qx1, qy1);
1138 qx2 = _mm256_maddubs_epi16(qx2, qy2);
1139 qx3 = _mm256_maddubs_epi16(qx3, qy3);
1140 qx4 = _mm256_maddubs_epi16(qx4, qy4);
1141
1142 sumi0 = _mm256_add_epi16(sumi0, _mm256_add_epi16(qx0, qx1));
1143 sumi1 = _mm256_add_epi16(sumi1, _mm256_add_epi16(qx2, qx3));
1144 sumi2 = _mm256_add_epi16(sumi2, qx4);
1145 }
1146
1147 // last 16 bytes of 5-element, along with the 4 bytes of 4 elements
1148 {
1149 __m128i qx0 = _mm_loadu_si128((const __m128i *) (x[i].qs + 32));
1150 uint32_t qh;
1151 memcpy(&qh, x[i].qh, sizeof(qh)); // potentially unaligned
1152 __m256i qx5_l = _mm256_cvtepu8_epi16(_mm_set1_epi32(qh));
1153 __m128i qx1 = _mm_add_epi8(qx0, _mm_add_epi8(qx0, qx0)); // 1 * 3
1154 __m128i qx2 = _mm_add_epi8(_mm_and_si128(_mm_slli_epi16(qx0, 3), _mm_set1_epi8(-8)), qx0); // 1 * 9
1155 __m128i qx3 = _mm_add_epi8(_mm_and_si128(_mm_slli_epi16(qx1, 3), _mm_set1_epi8(-8)), qx1); // 3 * 9
1156 __m128i qx4 = _mm_add_epi8(_mm_and_si128(_mm_slli_epi16(qx2, 3), _mm_set1_epi8(-8)), qx2); // 9 * 9
1157 __m256i qx01 = MM256_SET_M128I(qx1, qx0);
1158 __m256i qx23 = MM256_SET_M128I(qx3, qx2);
1159
1160 // avx2 does not have 8-bit multiplies, so 16-bit it is.
1161 qx5_l = _mm256_mullo_epi16(qx5_l, _mm256_set_epi16(27, 27, 27, 27, 9, 9, 9, 9, 3, 3, 3, 3, 1, 1, 1, 1));
1162 qx5_l = _mm256_and_si256(qx5_l, _mm256_set1_epi16(0xFF));
1163 __m128i qx5 = _mm_packus_epi16(_mm256_castsi256_si128(qx5_l), _mm256_extracti128_si256(qx5_l, 1));
1164
1165 __m256i qx45 = MM256_SET_M128I(qx5, qx4);
1166
1167 // Cancel the +1 from avg so that it behaves like a halving add
1168 qx01 = _mm256_subs_epu8(qx01, _mm256_set1_epi8(1));
1169 qx23 = _mm256_subs_epu8(qx23, _mm256_set1_epi8(1));
1170 qx45 = _mm256_subs_epu8(qx45, _mm256_set1_epi8(1));
1171 // Multiply by 3 and get the top 2 bits
1172 qx01 = _mm256_avg_epu8(qx01, _mm256_avg_epu8(qx01, _mm256_setzero_si256()));
1173 qx23 = _mm256_avg_epu8(qx23, _mm256_avg_epu8(qx23, _mm256_setzero_si256()));
1174 qx45 = _mm256_avg_epu8(qx45, _mm256_avg_epu8(qx45, _mm256_setzero_si256()));
1175 qx01 = _mm256_and_si256(_mm256_srli_epi16(qx01, 6), _mm256_set1_epi8(3));
1176 qx23 = _mm256_and_si256(_mm256_srli_epi16(qx23, 6), _mm256_set1_epi8(3));
1177 qx45 = _mm256_and_si256(_mm256_srli_epi16(qx45, 6), _mm256_set1_epi8(3));
1178
1179 const __m256i qy01 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 160));
1180 const __m256i qy23 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 192));
1181 const __m256i qy45 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 224));
1182
1183 qx01 = _mm256_maddubs_epi16(qx01, qy01);
1184 qx23 = _mm256_maddubs_epi16(qx23, qy23);
1185 qx45 = _mm256_maddubs_epi16(qx45, qy45);
1186
1187 sumi0 = _mm256_add_epi16(sumi0, qx01);
1188 sumi1 = _mm256_add_epi16(sumi1, qx23);
1189 sumi2 = _mm256_add_epi16(sumi2, qx45);
1190 }
1191
1192 const __m256i ysum = _mm256_loadu_si256((const __m256i *) y[i].bsums);
1193 const __m256 d = _mm256_set1_ps(y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d));
1194
1195 sumi0 = _mm256_sub_epi16(sumi0, ysum);
1196 sumi0 = _mm256_add_epi16(sumi0, _mm256_add_epi16(sumi1, sumi2));
1197 sumi0 = _mm256_madd_epi16(sumi0, _mm256_set1_epi16(1));
1198
1199 sumf = _mm256_add_ps(_mm256_mul_ps(_mm256_cvtepi32_ps(sumi0), d), sumf);
1200 }
1201
1202 *s = hsum_float_8(sumf);
1203
1204#else
1205 UNUSED(x);
1206 UNUSED(y);
1207 UNUSED(nb);
1208 ggml_vec_dot_tq1_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
1209#endif
1210}
1211
1212void ggml_vec_dot_tq2_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
1213 assert(nrc == 1);
1214 UNUSED(nrc);
1215 UNUSED(bx);
1216 UNUSED(by);
1217 UNUSED(bs);
1218
1219 const block_tq2_0 * GGML_RESTRICT x = vx;
1220 const block_q8_K * GGML_RESTRICT y = vy;
1221
1222 const int nb = n / QK_K;
1223
1224#if defined(__AVX2__)
1225 __m256 sumf = _mm256_setzero_ps();
1226
1227 for (int i = 0; i < nb; ++i) {
1228 // 16-bit sums, because 256*127 still fits
1229 __m256i sumi0 = _mm256_setzero_si256();
1230 __m256i sumi1 = _mm256_setzero_si256();
1231
1232 for (size_t j = 0; j < sizeof(x->qs); j += 32) {
1233 __m256i qx0 = _mm256_loadu_si256((const __m256i *) (x[i].qs + j));
1234 __m256i qx1 = _mm256_srli_epi16(qx0, 2);
1235 __m256i qx2 = _mm256_srli_epi16(qx0, 4);
1236 __m256i qx3 = _mm256_srli_epi16(qx0, 6);
1237
1238 // 0, 1, 2 (should not be 3)
1239 qx0 = _mm256_and_si256(qx0, _mm256_set1_epi8(3));
1240 qx1 = _mm256_and_si256(qx1, _mm256_set1_epi8(3));
1241 qx2 = _mm256_and_si256(qx2, _mm256_set1_epi8(3));
1242 qx3 = _mm256_and_si256(qx3, _mm256_set1_epi8(3));
1243
1244 const __m256i qy0 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 0));
1245 const __m256i qy1 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 32));
1246 const __m256i qy2 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 64));
1247 const __m256i qy3 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 96));
1248
1249 qx0 = _mm256_maddubs_epi16(qx0, qy0);
1250 qx1 = _mm256_maddubs_epi16(qx1, qy1);
1251 qx2 = _mm256_maddubs_epi16(qx2, qy2);
1252 qx3 = _mm256_maddubs_epi16(qx3, qy3);
1253
1254 sumi0 = _mm256_add_epi16(sumi0, _mm256_add_epi16(qx0, qx1));
1255 sumi1 = _mm256_add_epi16(sumi1, _mm256_add_epi16(qx2, qx3));
1256 }
1257
1258 const __m256i ysum = _mm256_loadu_si256((const __m256i *) y[i].bsums);
1259 const __m256 d = _mm256_set1_ps(y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d));
1260
1261 sumi0 = _mm256_add_epi16(sumi0, sumi1);
1262 sumi0 = _mm256_sub_epi16(sumi0, ysum);
1263 sumi0 = _mm256_madd_epi16(sumi0, _mm256_set1_epi16(1));
1264
1265 sumf = _mm256_add_ps(_mm256_mul_ps(_mm256_cvtepi32_ps(sumi0), d), sumf);
1266 }
1267
1268 *s = hsum_float_8(sumf);
1269
1270#else
1271 UNUSED(x);
1272 UNUSED(y);
1273 UNUSED(nb);
1274 ggml_vec_dot_tq2_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
1275#endif
1276}
1277
1278void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
1279 assert(nrc == 1);
1280 UNUSED(nrc);
1281 UNUSED(bx);
1282 UNUSED(by);
1283 UNUSED(bs);
1284
1285 const block_q2_K * GGML_RESTRICT x = vx;
1286 const block_q8_K * GGML_RESTRICT y = vy;
1287
1288 const int nb = n / QK_K;
1289
1290#if defined __AVX2__
1291
1292 const __m256i m3 = _mm256_set1_epi8(3);
1293 const __m128i m4 = _mm_set1_epi8(0xF);
1294
1295 __m256 acc = _mm256_setzero_ps();
1296
1297 for (int i = 0; i < nb; ++i) {
1298
1299 const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
1300 const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);
1301
1302 const uint8_t * GGML_RESTRICT q2 = x[i].qs;
1303 const int8_t * GGML_RESTRICT q8 = y[i].qs;
1304
1305 const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales);
1306 const __m128i scales8 = _mm_and_si128(mins_and_scales, m4);
1307 const __m128i mins8 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4);
1308 const __m256i mins = _mm256_cvtepi8_epi16(mins8);
1309 const __m256i prod = _mm256_madd_epi16(mins, _mm256_loadu_si256((const __m256i*)y[i].bsums));
1310
1311 acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(prod), acc);
1312
1313 const __m256i all_scales = _mm256_cvtepi8_epi16(scales8);
1314 const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0);
1315 const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1);
1316 const __m256i scales[2] = {MM256_SET_M128I(l_scales, l_scales), MM256_SET_M128I(h_scales, h_scales)};
1317
1318 __m256i sumi = _mm256_setzero_si256();
1319
1320 for (int j = 0; j < QK_K/128; ++j) {
1321
1322 const __m256i q2bits = _mm256_loadu_si256((const __m256i*)q2); q2 += 32;
1323
1324 const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
1325 const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
1326 const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
1327 const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
1328
1329 const __m256i q2_0 = _mm256_and_si256(q2bits, m3);
1330 const __m256i q2_1 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 2), m3);
1331 const __m256i q2_2 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 4), m3);
1332 const __m256i q2_3 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 6), m3);
1333
1334 __m256i p0 = _mm256_maddubs_epi16(q2_0, q8_0);
1335 __m256i p1 = _mm256_maddubs_epi16(q2_1, q8_1);
1336 __m256i p2 = _mm256_maddubs_epi16(q2_2, q8_2);
1337 __m256i p3 = _mm256_maddubs_epi16(q2_3, q8_3);
1338
1339 p0 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(0)), p0);
1340 p1 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(1)), p1);
1341 p2 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(2)), p2);
1342 p3 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(3)), p3);
1343
1344 p0 = _mm256_add_epi32(p0, p1);
1345 p2 = _mm256_add_epi32(p2, p3);
1346
1347 sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p0, p2));
1348 }
1349
1350 acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
1351
1352 }
1353
1354 *s = hsum_float_8(acc);
1355
1356#elif defined __AVX__
1357
1358 const __m128i m3 = _mm_set1_epi8(0x3);
1359 const __m128i m4 = _mm_set1_epi8(0xF);
1360 const __m128i m2 = _mm_set1_epi8(0x2);
1361
1362 __m256 acc = _mm256_setzero_ps();
1363
1364 for (int i = 0; i < nb; ++i) {
1365
1366 const float dall = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
1367 const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);
1368
1369 const uint8_t * GGML_RESTRICT q2 = x[i].qs;
1370 const int8_t * GGML_RESTRICT q8 = y[i].qs;
1371
1372 // load mins and scales from block_q2_K.scales[QK_K/16]
1373 const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales);
1374 const __m128i scales16 = _mm_and_si128(mins_and_scales, m4);
1375 const __m128i mins16 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4);
1376 const __m128i mins_0 = _mm_cvtepi8_epi16(mins16);
1377 const __m128i mins_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(mins16, mins16));
1378
1379 // summs = y[i].bsums * (x[i].scales >> 4) in 16bits*8*2 to 32bits*4*2
1380 const __m128i summs_0 = _mm_madd_epi16(mins_0, _mm_loadu_si128((const __m128i*)&y[i].bsums[0]));
1381 const __m128i summs_1 = _mm_madd_epi16(mins_1, _mm_loadu_si128((const __m128i*)&y[i].bsums[8]));
1382
1383 // sumf += -dmin * summs in 32bits*8
1384 acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(MM256_SET_M128I(summs_1, summs_0))), acc);
1385
1386 const __m128i scales_0 = _mm_cvtepi8_epi16(scales16);
1387 const __m128i scales_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(scales16, scales16));
1388 const __m128i scales[2] = { scales_0, scales_1 };
1389
1390 __m128i sumi_0 = _mm_setzero_si128();
1391 __m128i sumi_1 = _mm_setzero_si128();
1392
1393 for (int j = 0; j < QK_K/128; ++j) {
1394
1395 // load Q8 quants int8*16*8 from block_q8_K.qs[QK_K]
1396 const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1397 const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1398 const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1399 const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1400 const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1401 const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1402 const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1403 const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1404
1405 // load 2bits*16*8 from block_q2_K.qs[QK_K/4]
1406 __m128i q2bits = _mm_loadu_si128((const __m128i*)q2); q2 += 16;
1407 const __m128i q2_0 = _mm_and_si128(q2bits, m3);
1408 const __m128i q2_2 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3);
1409 const __m128i q2_4 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3);
1410 const __m128i q2_6 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3);
1411 q2bits = _mm_loadu_si128((const __m128i*)q2); q2 += 16;
1412 const __m128i q2_1 = _mm_and_si128(q2bits, m3);
1413 const __m128i q2_3 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3);
1414 const __m128i q2_5 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3);
1415 const __m128i q2_7 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3);
1416
1417 // isuml = q8[l] * ((q2[l] >> shift) & 3) in 8bits*16*8 to 16bits*8*8
1418 __m128i p0 = _mm_maddubs_epi16(q2_0, q8_0);
1419 __m128i p1 = _mm_maddubs_epi16(q2_1, q8_1);
1420 __m128i p2 = _mm_maddubs_epi16(q2_2, q8_2);
1421 __m128i p3 = _mm_maddubs_epi16(q2_3, q8_3);
1422 __m128i p4 = _mm_maddubs_epi16(q2_4, q8_4);
1423 __m128i p5 = _mm_maddubs_epi16(q2_5, q8_5);
1424 __m128i p6 = _mm_maddubs_epi16(q2_6, q8_6);
1425 __m128i p7 = _mm_maddubs_epi16(q2_7, q8_7);
1426
1427 // isum += (x[i].scales[is++] & 0xF) * isuml in 16bits*8*8 to 32bits*4*8
1428 __m128i shuffle = _mm_set1_epi16(0x0100);
1429 p0 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p0);
1430 shuffle = _mm_add_epi16(shuffle, m2);
1431 p1 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p1);
1432 shuffle = _mm_add_epi16(shuffle, m2);
1433 p2 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p2);
1434 shuffle = _mm_add_epi16(shuffle, m2);
1435 p3 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p3);
1436 shuffle = _mm_add_epi16(shuffle, m2);
1437 p4 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p4);
1438 shuffle = _mm_add_epi16(shuffle, m2);
1439 p5 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p5);
1440 shuffle = _mm_add_epi16(shuffle, m2);
1441 p6 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p6);
1442 shuffle = _mm_add_epi16(shuffle, m2);
1443 p7 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p7);
1444
1445 p0 = _mm_add_epi32(p0, p1);
1446 p2 = _mm_add_epi32(p2, p3);
1447 p4 = _mm_add_epi32(p4, p5);
1448 p6 = _mm_add_epi32(p6, p7);
1449
1450 // isum in 32bits*4*2
1451 sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p0, p2));
1452 sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p4, p6));
1453 }
1454
1455 // sumf += dall * isum - dmin * summs in 32bits
1456 __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
1457 acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&dall), _mm256_cvtepi32_ps(sumi)), acc);
1458 }
1459
1460 *s = hsum_float_8(acc);
1461
1462#else
1463 UNUSED(x);
1464 UNUSED(y);
1465 UNUSED(nb);
1466 ggml_vec_dot_q2_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
1467#endif
1468}
1469
1470void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
1471 assert(n % QK_K == 0);
1472 assert(nrc == 1);
1473 UNUSED(nrc);
1474 UNUSED(bx);
1475 UNUSED(by);
1476 UNUSED(bs);
1477
1478 const uint32_t kmask1 = 0x03030303;
1479 const uint32_t kmask2 = 0x0f0f0f0f;
1480
1481 const block_q3_K * GGML_RESTRICT x = vx;
1482 const block_q8_K * GGML_RESTRICT y = vy;
1483
1484 const int nb = n / QK_K;
1485
1486#if defined __AVX2__
1487
1488 const __m256i m3 = _mm256_set1_epi8(3);
1489 const __m256i mone = _mm256_set1_epi8(1);
1490 const __m128i m32 = _mm_set1_epi8(32);
1491
1492 __m256 acc = _mm256_setzero_ps();
1493
1494 uint32_t aux[3];
1495
1496 for (int i = 0; i < nb; ++i) {
1497
1498 const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
1499
1500 const uint8_t * GGML_RESTRICT q3 = x[i].qs;
1501 const int8_t * GGML_RESTRICT q8 = y[i].qs;
1502
1503 // Set up scales
1504 memcpy(aux, x[i].scales, 12);
1505 __m128i scales128 = _mm_set_epi32(
1506 ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4),
1507 ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4),
1508 (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4),
1509 (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4));
1510 scales128 = _mm_sub_epi8(scales128, m32);
1511 const __m256i all_scales = _mm256_cvtepi8_epi16(scales128);
1512 const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0);
1513 const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1);
1514 const __m256i scales[2] = {MM256_SET_M128I(l_scales, l_scales), MM256_SET_M128I(h_scales, h_scales)};
1515
1516 // high bit
1517 const __m256i hbits = _mm256_loadu_si256((const __m256i*)x[i].hmask);
1518
1519 // integer accumulator
1520 __m256i sumi = _mm256_setzero_si256();
1521
1522 int bit = 0;
1523 int is = 0;
1524
1525 for (int j = 0; j < QK_K/128; ++j) {
1526 // load low 2 bits
1527 const __m256i q3bits = _mm256_loadu_si256((const __m256i*)q3); q3 += 32;
1528
1529 // prepare low and high bits
1530 const __m256i q3l_0 = _mm256_and_si256(q3bits, m3);
1531 const __m256i q3h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2);
1532 ++bit;
1533
1534 const __m256i q3l_1 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 2), m3);
1535 const __m256i q3h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2);
1536 ++bit;
1537
1538 const __m256i q3l_2 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 4), m3);
1539 const __m256i q3h_2 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2);
1540 ++bit;
1541
1542 const __m256i q3l_3 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 6), m3);
1543 const __m256i q3h_3 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2);
1544 ++bit;
1545
1546 // load Q8 quants
1547 const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
1548 const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
1549 const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
1550 const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
1551
1552 // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm256_maddubs_epi16,
1553 // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set,
1554 // and 2 if the high bit was set)
1555 __m256i q8s_0 = _mm256_maddubs_epi16(q3h_0, q8_0);
1556 __m256i q8s_1 = _mm256_maddubs_epi16(q3h_1, q8_1);
1557 __m256i q8s_2 = _mm256_maddubs_epi16(q3h_2, q8_2);
1558 __m256i q8s_3 = _mm256_maddubs_epi16(q3h_3, q8_3);
1559
1560 __m256i p16_0 = _mm256_maddubs_epi16(q3l_0, q8_0);
1561 __m256i p16_1 = _mm256_maddubs_epi16(q3l_1, q8_1);
1562 __m256i p16_2 = _mm256_maddubs_epi16(q3l_2, q8_2);
1563 __m256i p16_3 = _mm256_maddubs_epi16(q3l_3, q8_3);
1564
1565 p16_0 = _mm256_sub_epi16(p16_0, q8s_0);
1566 p16_1 = _mm256_sub_epi16(p16_1, q8s_1);
1567 p16_2 = _mm256_sub_epi16(p16_2, q8s_2);
1568 p16_3 = _mm256_sub_epi16(p16_3, q8s_3);
1569
1570 // multiply with scales
1571 p16_0 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 0)), p16_0);
1572 p16_1 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 1)), p16_1);
1573 p16_2 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 2)), p16_2);
1574 p16_3 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 3)), p16_3);
1575
1576 // accumulate
1577 p16_0 = _mm256_add_epi32(p16_0, p16_1);
1578 p16_2 = _mm256_add_epi32(p16_2, p16_3);
1579 sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_2));
1580
1581 }
1582
1583 // multiply with block scale and accumulate
1584 acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
1585
1586 }
1587
1588 *s = hsum_float_8(acc);
1589
1590#elif defined __AVX__
1591
1592 const __m128i m3 = _mm_set1_epi8(3);
1593 const __m128i mone = _mm_set1_epi8(1);
1594 const __m128i m32 = _mm_set1_epi8(32);
1595 const __m128i m2 = _mm_set1_epi8(2);
1596
1597 __m256 acc = _mm256_setzero_ps();
1598
1599 const uint32_t *aux;
1600
1601 for (int i = 0; i < nb; ++i) {
1602
1603 const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
1604
1605 const uint8_t * GGML_RESTRICT q3 = x[i].qs;
1606 const int8_t * GGML_RESTRICT q8 = y[i].qs;
1607
1608 // Set up scales
1609 aux = (const uint32_t *)x[i].scales;
1610 __m128i scales128 = _mm_set_epi32(
1611 ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4),
1612 ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4),
1613 (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4),
1614 (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4));
1615 scales128 = _mm_sub_epi8(scales128, m32);
1616 const __m128i scales_0 = _mm_cvtepi8_epi16(scales128);
1617 const __m128i scales_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(scales128, scales128));
1618 const __m128i scales[2] = { scales_0, scales_1 };
1619
1620 // high bit *128*2 from block_q3_K.hmask[QK_K/8]
1621 const __m128i hbits_0 = _mm_loadu_si128((const __m128i*)&x[i].hmask[0]);
1622 const __m128i hbits_1 = _mm_loadu_si128((const __m128i*)&x[i].hmask[16]);
1623
1624 // integer accumulator
1625 __m128i sumi_0 = _mm_setzero_si128();
1626 __m128i sumi_1 = _mm_setzero_si128();
1627
1628 for (int j = 0; j < QK_K/128; ++j) {
1629 // load low 2 bits *64*2 from block_q3_K.qs[QK_K/4]
1630 const __m128i q3bits_0 = _mm_loadu_si128((const __m128i*)q3); q3 += 16;
1631 const __m128i q3bits_1 = _mm_loadu_si128((const __m128i*)q3); q3 += 16;
1632
1633 // prepare low and high bits
1634 const int bit = j << 2;
1635
1636 const __m128i q3l_0 = _mm_and_si128(q3bits_0, m3);
1637 const __m128i q3l_1 = _mm_and_si128(q3bits_1, m3);
1638 const __m128i q3h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit)), bit), 2);
1639 const __m128i q3h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit)), bit), 2);
1640
1641 const __m128i q3l_2 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 2), m3);
1642 const __m128i q3l_3 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 2), m3);
1643 const __m128i q3h_2 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+1)), bit+1), 2);
1644 const __m128i q3h_3 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+1)), bit+1), 2);
1645
1646 const __m128i q3l_4 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 4), m3);
1647 const __m128i q3l_5 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 4), m3);
1648 const __m128i q3h_4 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+2)), bit+2), 2);
1649 const __m128i q3h_5 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+2)), bit+2), 2);
1650
1651 const __m128i q3l_6 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 6), m3);
1652 const __m128i q3l_7 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 6), m3);
1653 const __m128i q3h_6 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+3)), bit+3), 2);
1654 const __m128i q3h_7 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+3)), bit+3), 2);
1655
1656 // load Q8 quants from block_q8_K.qs[QK_K]
1657 const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1658 const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1659 const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1660 const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1661 const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1662 const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1663 const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1664 const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1665
1666 // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm256_maddubs_epi16,
1667 // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set,
1668 // and 2 if the high bit was set)
1669 __m128i q8s_0 = _mm_maddubs_epi16(q3h_0, q8_0);
1670 __m128i q8s_1 = _mm_maddubs_epi16(q3h_1, q8_1);
1671 __m128i q8s_2 = _mm_maddubs_epi16(q3h_2, q8_2);
1672 __m128i q8s_3 = _mm_maddubs_epi16(q3h_3, q8_3);
1673 __m128i q8s_4 = _mm_maddubs_epi16(q3h_4, q8_4);
1674 __m128i q8s_5 = _mm_maddubs_epi16(q3h_5, q8_5);
1675 __m128i q8s_6 = _mm_maddubs_epi16(q3h_6, q8_6);
1676 __m128i q8s_7 = _mm_maddubs_epi16(q3h_7, q8_7);
1677
1678 __m128i p16_0 = _mm_maddubs_epi16(q3l_0, q8_0);
1679 __m128i p16_1 = _mm_maddubs_epi16(q3l_1, q8_1);
1680 __m128i p16_2 = _mm_maddubs_epi16(q3l_2, q8_2);
1681 __m128i p16_3 = _mm_maddubs_epi16(q3l_3, q8_3);
1682 __m128i p16_4 = _mm_maddubs_epi16(q3l_4, q8_4);
1683 __m128i p16_5 = _mm_maddubs_epi16(q3l_5, q8_5);
1684 __m128i p16_6 = _mm_maddubs_epi16(q3l_6, q8_6);
1685 __m128i p16_7 = _mm_maddubs_epi16(q3l_7, q8_7);
1686
1687 p16_0 = _mm_sub_epi16(p16_0, q8s_0);
1688 p16_1 = _mm_sub_epi16(p16_1, q8s_1);
1689 p16_2 = _mm_sub_epi16(p16_2, q8s_2);
1690 p16_3 = _mm_sub_epi16(p16_3, q8s_3);
1691 p16_4 = _mm_sub_epi16(p16_4, q8s_4);
1692 p16_5 = _mm_sub_epi16(p16_5, q8s_5);
1693 p16_6 = _mm_sub_epi16(p16_6, q8s_6);
1694 p16_7 = _mm_sub_epi16(p16_7, q8s_7);
1695
1696 // multiply with scales
1697 __m128i shuffle = _mm_set1_epi16(0x0100);
1698 p16_0 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_0);
1699 shuffle = _mm_add_epi16(shuffle, m2);
1700 p16_1 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_1);
1701 shuffle = _mm_add_epi16(shuffle, m2);
1702 p16_2 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_2);
1703 shuffle = _mm_add_epi16(shuffle, m2);
1704 p16_3 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_3);
1705 shuffle = _mm_add_epi16(shuffle, m2);
1706 p16_4 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_4);
1707 shuffle = _mm_add_epi16(shuffle, m2);
1708 p16_5 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_5);
1709 shuffle = _mm_add_epi16(shuffle, m2);
1710 p16_6 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_6);
1711 shuffle = _mm_add_epi16(shuffle, m2);
1712 p16_7 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_7);
1713
1714 // accumulate
1715 p16_0 = _mm_add_epi32(p16_0, p16_1);
1716 p16_2 = _mm_add_epi32(p16_2, p16_3);
1717 p16_4 = _mm_add_epi32(p16_4, p16_5);
1718 p16_6 = _mm_add_epi32(p16_6, p16_7);
1719 sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
1720 sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_4, p16_6));
1721
1722 }
1723
1724 // multiply with block scale and accumulate
1725 __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
1726 acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi)), acc);
1727
1728 }
1729
1730 *s = hsum_float_8(acc);
1731
1732#else
1733 UNUSED(kmask1);
1734 UNUSED(kmask2);
1735 UNUSED(x);
1736 UNUSED(y);
1737 UNUSED(nb);
1738 ggml_vec_dot_q3_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
1739#endif
1740}
1741
1742void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
1743 assert(n % QK_K == 0);
1744 assert(nrc == 1);
1745 UNUSED(nrc);
1746 UNUSED(bx);
1747 UNUSED(by);
1748 UNUSED(bs);
1749
1750 const block_q4_K * GGML_RESTRICT x = vx;
1751 const block_q8_K * GGML_RESTRICT y = vy;
1752
1753 const int nb = n / QK_K;
1754
1755 static const uint32_t kmask1 = 0x3f3f3f3f;
1756 static const uint32_t kmask2 = 0x0f0f0f0f;
1757 static const uint32_t kmask3 = 0x03030303;
1758
1759 uint32_t utmp[4];
1760
1761#if defined __AVX2__
1762
1763 const __m256i m4 = _mm256_set1_epi8(0xF);
1764
1765 __m256 acc = _mm256_setzero_ps();
1766 __m128 acc_m = _mm_setzero_ps();
1767
1768 for (int i = 0; i < nb; ++i) {
1769
1770 const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
1771 const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);
1772
1773 memcpy(utmp, x[i].scales, 12);
1774 utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
1775 const uint32_t uaux = utmp[1] & kmask1;
1776 utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
1777 utmp[2] = uaux;
1778 utmp[0] &= kmask1;
1779
1780 const uint8_t * GGML_RESTRICT q4 = x[i].qs;
1781 const int8_t * GGML_RESTRICT q8 = y[i].qs;
1782
1783 const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]));
1784
1785 const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums);
1786 const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1));
1787 const __m128i prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s);
1788 acc_m = _mm_fmadd_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod), acc_m);
1789
1790 const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0);
1791 const __m256i scales = MM256_SET_M128I(sc128, sc128);
1792
1793 __m256i sumi = _mm256_setzero_si256();
1794
1795 for (int j = 0; j < QK_K/64; ++j) {
1796
1797 const __m256i scale_l = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+0));
1798 const __m256i scale_h = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+1));
1799
1800 const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4); q4 += 32;
1801 const __m256i q4l = _mm256_and_si256(q4bits, m4);
1802 const __m256i q4h = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), m4);
1803
1804 const __m256i q8l = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
1805 __m256i p16l = _mm256_maddubs_epi16(q4l, q8l);
1806 p16l = _mm256_madd_epi16(scale_l, p16l);
1807
1808 const __m256i q8h = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
1809 __m256i p16h = _mm256_maddubs_epi16(q4h, q8h);
1810 p16h = _mm256_madd_epi16(scale_h, p16h);
1811 const __m256i sumj = _mm256_add_epi32(p16l, p16h);
1812
1813 sumi = _mm256_add_epi32(sumi, sumj);
1814 }
1815
1816 __m256 vd = _mm256_set1_ps(d);
1817 acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc);
1818
1819 }
1820
1821 acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m));
1822 acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m));
1823
1824 *s = hsum_float_8(acc) + _mm_cvtss_f32(acc_m);
1825
1826#elif defined __AVX__
1827
1828 const __m128i m4 = _mm_set1_epi8(0xF);
1829 const __m128i m2 = _mm_set1_epi8(0x2);
1830
1831 __m256 acc = _mm256_setzero_ps();
1832 __m128 acc_m = _mm_setzero_ps();
1833
1834 for (int i = 0; i < nb; ++i) {
1835
1836 const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
1837 const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);
1838
1839 const uint8_t * GGML_RESTRICT q4 = x[i].qs;
1840 const int8_t * GGML_RESTRICT q8 = y[i].qs;
1841
1842 memcpy(utmp, x[i].scales, 12);
1843 utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
1844 const uint32_t uaux = utmp[1] & kmask1;
1845 utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
1846 utmp[2] = uaux;
1847 utmp[0] &= kmask1;
1848
1849 const __m128i utmps = _mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]);
1850 const __m128i scales = _mm_cvtepu8_epi16(utmps);
1851 const __m128i mins = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(utmps, utmps));
1852
1853 const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)&y[i].bsums[0]);
1854 const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)&y[i].bsums[8]);
1855 const __m128i q8s = _mm_hadd_epi16(q8sums_0, q8sums_1);
1856 const __m128i prod = _mm_madd_epi16(mins, q8s);
1857 acc_m = _mm_add_ps(_mm_mul_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod)), acc_m);
1858
1859 __m128i sumi_0 = _mm_setzero_si128();
1860 __m128i sumi_1 = _mm_setzero_si128();
1861
1862 __m128i shuffle = _mm_set1_epi16(0x0100);
1863 for (int j = 0; j < QK_K/64; ++j) {
1864
1865 const __m128i scale_l = _mm_shuffle_epi8(scales, shuffle);
1866 shuffle = _mm_add_epi16(shuffle, m2);
1867 const __m128i scale_h = _mm_shuffle_epi8(scales, shuffle);
1868 shuffle = _mm_add_epi16(shuffle, m2);
1869
1870 __m128i q4bits = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
1871 const __m128i q4l_0 = _mm_and_si128(q4bits, m4);
1872 const __m128i q4h_0 = _mm_and_si128(_mm_srli_epi16(q4bits, 4), m4);
1873 q4bits = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
1874 const __m128i q4l_1 = _mm_and_si128(q4bits, m4);
1875 const __m128i q4h_1 = _mm_and_si128(_mm_srli_epi16(q4bits, 4), m4);
1876
1877 const __m128i q8l_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1878 __m128i p16l = _mm_maddubs_epi16(q4l_0, q8l_0);
1879 p16l = _mm_madd_epi16(scale_l, p16l);
1880 sumi_0 = _mm_add_epi32(sumi_0, p16l);
1881 const __m128i q8l_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1882 p16l = _mm_maddubs_epi16(q4l_1, q8l_1);
1883 p16l = _mm_madd_epi16(scale_l, p16l);
1884 sumi_1 = _mm_add_epi32(sumi_1, p16l);
1885
1886 const __m128i q8h_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1887 __m128i p16h = _mm_maddubs_epi16(q4h_0, q8h_0);
1888 p16h = _mm_madd_epi16(scale_h, p16h);
1889 sumi_0 = _mm_add_epi32(sumi_0, p16h);
1890 const __m128i q8h_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1891 p16h = _mm_maddubs_epi16(q4h_1, q8h_1);
1892 p16h = _mm_madd_epi16(scale_h, p16h);
1893 sumi_1 = _mm_add_epi32(sumi_1, p16h);
1894
1895 }
1896
1897 __m256 vd = _mm256_set1_ps(d);
1898 __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
1899 acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(sumi)), acc);
1900
1901 }
1902
1903 acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m));
1904 acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m));
1905
1906 *s = hsum_float_8(acc) + _mm_cvtss_f32(acc_m);
1907
1908#else
1909 UNUSED(x);
1910 UNUSED(y);
1911 UNUSED(nb);
1912 UNUSED(kmask1);
1913 UNUSED(kmask2);
1914 UNUSED(kmask3);
1915 UNUSED(utmp);
1916 ggml_vec_dot_q4_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
1917#endif
1918}
1919
1920void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
1921 assert(n % QK_K == 0);
1922 assert(nrc == 1);
1923 UNUSED(nrc);
1924 UNUSED(bx);
1925 UNUSED(by);
1926 UNUSED(bs);
1927
1928 const block_q5_K * GGML_RESTRICT x = vx;
1929 const block_q8_K * GGML_RESTRICT y = vy;
1930
1931 const int nb = n / QK_K;
1932
1933 static const uint32_t kmask1 = 0x3f3f3f3f;
1934 static const uint32_t kmask2 = 0x0f0f0f0f;
1935 static const uint32_t kmask3 = 0x03030303;
1936
1937 uint32_t utmp[4];
1938
1939#if defined __AVX2__
1940
1941 const __m256i m4 = _mm256_set1_epi8(0xF);
1942 const __m128i mzero = _mm_setzero_si128();
1943 const __m256i mone = _mm256_set1_epi8(1);
1944
1945 __m256 acc = _mm256_setzero_ps();
1946
1947 float summs = 0.f;
1948
1949 for (int i = 0; i < nb; ++i) {
1950 const uint8_t * GGML_RESTRICT q5 = x[i].qs;
1951 const int8_t * GGML_RESTRICT q8 = y[i].qs;
1952
1953 const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
1954 const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);
1955
1956 memcpy(utmp, x[i].scales, 12);
1957 utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
1958 const uint32_t uaux = utmp[1] & kmask1;
1959 utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
1960 utmp[2] = uaux;
1961 utmp[0] &= kmask1;
1962
1963 const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]));
1964
1965 const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums);
1966 const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1));
1967 const __m128i prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s);
1968 const __m128i hsum = _mm_hadd_epi32(_mm_hadd_epi32(prod, mzero), mzero);
1969 summs += dmin * _mm_extract_epi32(hsum, 0);
1970
1971 const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0);
1972 const __m256i scales = MM256_SET_M128I(sc128, sc128);
1973
1974 const __m256i hbits = _mm256_loadu_si256((const __m256i*)x[i].qh);
1975 __m256i hmask = mone;
1976
1977 __m256i sumi = _mm256_setzero_si256();
1978
1979 int bit = 0;
1980
1981 for (int j = 0; j < QK_K/64; ++j) {
1982
1983 const __m256i scale_0 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+0));
1984 const __m256i scale_1 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+1));
1985
1986 const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5); q5 += 32;
1987
1988 const __m256i q5l_0 = _mm256_and_si256(q5bits, m4);
1989 const __m256i q5h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), bit++), 4);
1990 const __m256i q5_0 = _mm256_add_epi8(q5l_0, q5h_0);
1991 hmask = _mm256_slli_epi16(hmask, 1);
1992
1993 const __m256i q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), m4);
1994 const __m256i q5h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), bit++), 4);
1995 const __m256i q5_1 = _mm256_add_epi8(q5l_1, q5h_1);
1996 hmask = _mm256_slli_epi16(hmask, 1);
1997
1998 const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
1999 const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
2000
2001 __m256i p16_0 = _mm256_maddubs_epi16(q5_0, q8_0);
2002 __m256i p16_1 = _mm256_maddubs_epi16(q5_1, q8_1);
2003
2004 p16_0 = _mm256_madd_epi16(scale_0, p16_0);
2005 p16_1 = _mm256_madd_epi16(scale_1, p16_1);
2006
2007 sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1));
2008
2009 }
2010
2011 __m256 vd = _mm256_set1_ps(d);
2012 acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc);
2013
2014 }
2015
2016 *s = hsum_float_8(acc) + summs;
2017
2018#elif defined __AVX__
2019
2020 const __m128i m4 = _mm_set1_epi8(0xF);
2021 const __m128i mzero = _mm_setzero_si128();
2022 const __m128i mone = _mm_set1_epi8(1);
2023 const __m128i m2 = _mm_set1_epi8(2);
2024
2025 __m256 acc = _mm256_setzero_ps();
2026
2027 float summs = 0.f;
2028
2029 for (int i = 0; i < nb; ++i) {
2030
2031 const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
2032 const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);
2033
2034 const uint8_t * GGML_RESTRICT q5 = x[i].qs;
2035 const int8_t * GGML_RESTRICT q8 = y[i].qs;
2036
2037 memcpy(utmp, x[i].scales, 12);
2038 utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
2039 const uint32_t uaux = utmp[1] & kmask1;
2040 utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
2041 utmp[2] = uaux;
2042 utmp[0] &= kmask1;
2043
2044 const __m128i utmps = _mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]);
2045 const __m128i scales = _mm_cvtepu8_epi16(utmps);
2046 const __m128i mins = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(utmps, utmps));
2047
2048 const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)&y[i].bsums[0]);
2049 const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)&y[i].bsums[8]);
2050 const __m128i q8s = _mm_hadd_epi16(q8sums_0, q8sums_1);
2051 const __m128i prod = _mm_madd_epi16(mins, q8s);
2052 const __m128i hsum = _mm_hadd_epi32(_mm_hadd_epi32(prod, mzero), mzero);
2053 summs += dmin * _mm_extract_epi32(hsum, 0);
2054
2055 const __m128i hbits_0 = _mm_loadu_si128((const __m128i*)&x[i].qh[0]);
2056 const __m128i hbits_1 = _mm_loadu_si128((const __m128i*)&x[i].qh[16]);
2057 __m128i hmask = mone;
2058
2059 __m128i sumi_0 = _mm_setzero_si128();
2060 __m128i sumi_1 = _mm_setzero_si128();
2061
2062 int bit = 0;
2063
2064 __m128i shuffle = _mm_set1_epi16(0x0100);
2065 for (int j = 0; j < QK_K/64; ++j) {
2066
2067 const __m128i scale_0 = _mm_shuffle_epi8(scales, shuffle);
2068 shuffle = _mm_add_epi16(shuffle, m2);
2069 const __m128i scale_1 = _mm_shuffle_epi8(scales, shuffle);
2070 shuffle = _mm_add_epi16(shuffle, m2);
2071
2072 const __m128i q5bits_0 = _mm_loadu_si128((const __m128i*)q5); q5 += 16;
2073 const __m128i q5bits_1 = _mm_loadu_si128((const __m128i*)q5); q5 += 16;
2074
2075 __m128i q5l_0 = _mm_and_si128(q5bits_0, m4);
2076 __m128i q5l_1 = _mm_and_si128(q5bits_1, m4);
2077 __m128i q5h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_0, hmask), bit), 4);
2078 __m128i q5h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_1, hmask), bit++), 4);
2079 __m128i q5_0 = _mm_add_epi8(q5l_0, q5h_0);
2080 __m128i q5_1 = _mm_add_epi8(q5l_1, q5h_1);
2081 hmask = _mm_slli_epi16(hmask, 1);
2082
2083 __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
2084 __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
2085 __m128i p16_0 = _mm_maddubs_epi16(q5_0, q8_0);
2086 __m128i p16_1 = _mm_maddubs_epi16(q5_1, q8_1);
2087 p16_0 = _mm_madd_epi16(scale_0, p16_0);
2088 p16_1 = _mm_madd_epi16(scale_0, p16_1);
2089
2090 q5l_0 = _mm_and_si128(_mm_srli_epi16(q5bits_0, 4), m4);
2091 q5l_1 = _mm_and_si128(_mm_srli_epi16(q5bits_1, 4), m4);
2092 q5h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_0, hmask), bit), 4);
2093 q5h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_1, hmask), bit++), 4);
2094 q5_0 = _mm_add_epi8(q5l_0, q5h_0);
2095 q5_1 = _mm_add_epi8(q5l_1, q5h_1);
2096 hmask = _mm_slli_epi16(hmask, 1);
2097
2098 q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
2099 q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
2100 __m128i p16_2 = _mm_maddubs_epi16(q5_0, q8_0);
2101 __m128i p16_3 = _mm_maddubs_epi16(q5_1, q8_1);
2102 p16_2 = _mm_madd_epi16(scale_1, p16_2);
2103 p16_3 = _mm_madd_epi16(scale_1, p16_3);
2104
2105 sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
2106 sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3));
2107
2108 }
2109
2110 __m256 vd = _mm256_set1_ps(d);
2111 __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
2112 acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(sumi)), acc);
2113
2114 }
2115
2116 *s = hsum_float_8(acc) + summs;
2117
2118#else
2119 UNUSED(x);
2120 UNUSED(y);
2121 UNUSED(nb);
2122 UNUSED(kmask1);
2123 UNUSED(kmask2);
2124 UNUSED(kmask3);
2125 UNUSED(utmp);
2126 ggml_vec_dot_q5_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
2127#endif
2128}
2129
2130void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
2131 assert(n % QK_K == 0);
2132 assert(nrc == 1);
2133 UNUSED(nrc);
2134 UNUSED(bx);
2135 UNUSED(by);
2136 UNUSED(bs);
2137
2138 const block_q6_K * GGML_RESTRICT x = vx;
2139 const block_q8_K * GGML_RESTRICT y = vy;
2140
2141 const int nb = n / QK_K;
2142
2143#if defined __AVX2__
2144
2145 const __m256i m4 = _mm256_set1_epi8(0xF);
2146 const __m256i m2 = _mm256_set1_epi8(3);
2147 const __m256i m32s = _mm256_set1_epi8(32);
2148
2149 __m256 acc = _mm256_setzero_ps();
2150
2151 for (int i = 0; i < nb; ++i) {
2152
2153 const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
2154
2155 const uint8_t * GGML_RESTRICT q4 = x[i].ql;
2156 const uint8_t * GGML_RESTRICT qh = x[i].qh;
2157 const int8_t * GGML_RESTRICT q8 = y[i].qs;
2158
2159 const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales);
2160
2161 __m256i sumi = _mm256_setzero_si256();
2162
2163 int is = 0;
2164
2165 for (int j = 0; j < QK_K/128; ++j) {
2166
2167 const __m128i scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 0));
2168 const __m128i scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 1));
2169 const __m128i scale_2 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 2));
2170 const __m128i scale_3 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 3));
2171 is += 4;
2172
2173 const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4); q4 += 32;
2174 const __m256i q4bits2 = _mm256_loadu_si256((const __m256i*)q4); q4 += 32;
2175 const __m256i q4bitsH = _mm256_loadu_si256((const __m256i*)qh); qh += 32;
2176
2177 const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(q4bitsH, m2), 4);
2178 const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 2), m2), 4);
2179 const __m256i q4h_2 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 4), m2), 4);
2180 const __m256i q4h_3 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 6), m2), 4);
2181
2182 const __m256i q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0);
2183 const __m256i q4_1 = _mm256_or_si256(_mm256_and_si256(q4bits2, m4), q4h_1);
2184 const __m256i q4_2 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_2);
2185 const __m256i q4_3 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits2, 4), m4), q4h_3);
2186
2187 const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
2188 const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
2189 const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
2190 const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
2191
2192 __m256i q8s_0 = _mm256_maddubs_epi16(m32s, q8_0);
2193 __m256i q8s_1 = _mm256_maddubs_epi16(m32s, q8_1);
2194 __m256i q8s_2 = _mm256_maddubs_epi16(m32s, q8_2);
2195 __m256i q8s_3 = _mm256_maddubs_epi16(m32s, q8_3);
2196
2197 __m256i p16_0 = _mm256_maddubs_epi16(q4_0, q8_0);
2198 __m256i p16_1 = _mm256_maddubs_epi16(q4_1, q8_1);
2199 __m256i p16_2 = _mm256_maddubs_epi16(q4_2, q8_2);
2200 __m256i p16_3 = _mm256_maddubs_epi16(q4_3, q8_3);
2201
2202 p16_0 = _mm256_sub_epi16(p16_0, q8s_0);
2203 p16_1 = _mm256_sub_epi16(p16_1, q8s_1);
2204 p16_2 = _mm256_sub_epi16(p16_2, q8s_2);
2205 p16_3 = _mm256_sub_epi16(p16_3, q8s_3);
2206
2207 p16_0 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_0), p16_0);
2208 p16_1 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_1), p16_1);
2209 p16_2 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_2), p16_2);
2210 p16_3 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_3), p16_3);
2211
2212 sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1));
2213 sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_2, p16_3));
2214
2215 }
2216
2217 acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
2218 }
2219
2220 *s = hsum_float_8(acc);
2221
2222#elif defined __AVX__
2223
2224 const __m128i m3 = _mm_set1_epi8(3);
2225 const __m128i m15 = _mm_set1_epi8(15);
2226
2227 __m256 acc = _mm256_setzero_ps();
2228
2229 for (int i = 0; i < nb; ++i) {
2230
2231 const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
2232
2233 const uint8_t * GGML_RESTRICT q4 = x[i].ql;
2234 const uint8_t * GGML_RESTRICT qh = x[i].qh;
2235 const int8_t * GGML_RESTRICT q8 = y[i].qs;
2236
2237 // handle the q6_k -32 offset separately using bsums
2238 const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)y[i].bsums);
2239 const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)y[i].bsums + 1);
2240 const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales);
2241 const __m128i scales_16_0 = _mm_cvtepi8_epi16(scales);
2242 const __m128i scales_16_1 = _mm_cvtepi8_epi16(_mm_bsrli_si128(scales, 8));
2243 const __m128i q8sclsub_0 = _mm_slli_epi32(_mm_madd_epi16(q8sums_0, scales_16_0), 5);
2244 const __m128i q8sclsub_1 = _mm_slli_epi32(_mm_madd_epi16(q8sums_1, scales_16_1), 5);
2245
2246 __m128i sumi_0 = _mm_setzero_si128();
2247 __m128i sumi_1 = _mm_setzero_si128();
2248
2249 int is = 0;
2250
2251 for (int j = 0; j < QK_K/128; ++j) {
2252
2253 const __m128i q4bitsH_0 = _mm_loadu_si128((const __m128i*)qh); qh += 16;
2254 const __m128i q4bitsH_1 = _mm_loadu_si128((const __m128i*)qh); qh += 16;
2255
2256 const __m128i q4h_0 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, m3), 4);
2257 const __m128i q4h_1 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, m3), 4);
2258 const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, _mm_set1_epi8(12)), 2);
2259 const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, _mm_set1_epi8(12)), 2);
2260 const __m128i q4h_4 = _mm_and_si128(q4bitsH_0, _mm_set1_epi8(48));
2261 const __m128i q4h_5 = _mm_and_si128(q4bitsH_1, _mm_set1_epi8(48));
2262 const __m128i q4h_6 = _mm_srli_epi16(_mm_and_si128(q4bitsH_0, _mm_set1_epi8(-64)), 2);
2263 const __m128i q4h_7 = _mm_srli_epi16(_mm_and_si128(q4bitsH_1, _mm_set1_epi8(-64)), 2);
2264
2265 const __m128i q4bits1_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
2266 const __m128i q4bits1_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
2267 const __m128i q4bits2_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
2268 const __m128i q4bits2_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
2269
2270 const __m128i q4_0 = _mm_or_si128(_mm_and_si128(q4bits1_0, m15), q4h_0);
2271 const __m128i q4_1 = _mm_or_si128(_mm_and_si128(q4bits1_1, m15), q4h_1);
2272 const __m128i q4_2 = _mm_or_si128(_mm_and_si128(q4bits2_0, m15), q4h_2);
2273 const __m128i q4_3 = _mm_or_si128(_mm_and_si128(q4bits2_1, m15), q4h_3);
2274 const __m128i q4_4 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_0, 4), m15), q4h_4);
2275 const __m128i q4_5 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_1, 4), m15), q4h_5);
2276 const __m128i q4_6 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_0, 4), m15), q4h_6);
2277 const __m128i q4_7 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_1, 4), m15), q4h_7);
2278
2279 const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
2280 const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
2281 const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
2282 const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
2283 const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
2284 const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
2285 const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
2286 const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
2287
2288 __m128i p16_0 = _mm_maddubs_epi16(q4_0, q8_0);
2289 __m128i p16_1 = _mm_maddubs_epi16(q4_1, q8_1);
2290 __m128i p16_2 = _mm_maddubs_epi16(q4_2, q8_2);
2291 __m128i p16_3 = _mm_maddubs_epi16(q4_3, q8_3);
2292 __m128i p16_4 = _mm_maddubs_epi16(q4_4, q8_4);
2293 __m128i p16_5 = _mm_maddubs_epi16(q4_5, q8_5);
2294 __m128i p16_6 = _mm_maddubs_epi16(q4_6, q8_6);
2295 __m128i p16_7 = _mm_maddubs_epi16(q4_7, q8_7);
2296
2297 const __m128i scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 0));
2298 const __m128i scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 1));
2299 const __m128i scale_2 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 2));
2300 const __m128i scale_3 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 3));
2301 is += 4;
2302
2303 p16_0 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0);
2304 p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_0, 8)), p16_1);
2305 p16_2 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_2);
2306 p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_1, 8)), p16_3);
2307 p16_4 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_2), p16_4);
2308 p16_5 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_2, 8)), p16_5);
2309 p16_6 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_3), p16_6);
2310 p16_7 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_3, 8)), p16_7);
2311
2312 sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
2313 sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3));
2314 sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_4, p16_6));
2315 sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_5, p16_7));
2316
2317 }
2318
2319 sumi_0 = _mm_sub_epi32(sumi_0, q8sclsub_0);
2320 sumi_1 = _mm_sub_epi32(sumi_1, q8sclsub_1);
2321 const __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
2322 acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(sumi)), acc);
2323 }
2324
2325 *s = hsum_float_8(acc);
2326
2327#else
2328 UNUSED(x);
2329 UNUSED(y);
2330 UNUSED(nb);
2331 ggml_vec_dot_q6_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
2332#endif
2333}
2334
2335#if defined (__AVX__) || defined (__AVX2__)
2336static const int8_t keven_signs_q2xs[1024] = {
2337 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1,
2338 1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, -1,
2339 1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, -1,
2340 1, 1, -1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, 1,
2341 1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, -1,
2342 1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, 1,
2343 1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, 1,
2344 1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, -1,
2345 1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, -1,
2346 1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, 1,
2347 1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, 1,
2348 1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, -1,
2349 1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, 1,
2350 1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, -1,
2351 1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, -1,
2352 1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, 1,
2353 1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, -1,
2354 1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, 1,
2355 1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, 1,
2356 1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, -1,
2357 1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, 1,
2358 1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, -1,
2359 1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, -1,
2360 1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, 1,
2361 1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, 1,
2362 1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, -1,
2363 1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, -1,
2364 1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, 1,
2365 1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, -1,
2366 1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, 1,
2367 1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, 1,
2368 1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1,
2369};
2370#endif
2371
2372void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
2373 assert(n % QK_K == 0);
2374 assert(nrc == 1);
2375 UNUSED(nrc);
2376 UNUSED(bx);
2377 UNUSED(by);
2378 UNUSED(bs);
2379
2380 const block_iq2_xxs * GGML_RESTRICT x = vx;
2381 const block_q8_K * GGML_RESTRICT y = vy;
2382
2383 const int nb = n / QK_K;
2384
2385#if defined(__AVX2__)
2386
2387 const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
2388
2389 uint32_t aux32[4];
2390 const uint8_t * aux8 = (const uint8_t *)aux32;
2391
2392 __m256 accumf = _mm256_setzero_ps();
2393 for (int i = 0; i < nb; ++i) {
2394 const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
2395 const uint16_t * GGML_RESTRICT q2 = x[i].qs;
2396 const int8_t * GGML_RESTRICT q8 = y[i].qs;
2397 __m256i sumi1 = _mm256_setzero_si256();
2398 __m256i sumi2 = _mm256_setzero_si256();
2399 for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
2400 const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
2401 const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
2402 memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8;
2403 const __m256i q2_1 = _mm256_set_epi64x(iq2xxs_grid[aux8[ 3]], iq2xxs_grid[aux8[ 2]], iq2xxs_grid[aux8[1]], iq2xxs_grid[aux8[0]]);
2404 const __m256i q2_2 = _mm256_set_epi64x(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]], iq2xxs_grid[aux8[9]], iq2xxs_grid[aux8[8]]);
2405 const __m256i s2_1 = _mm256_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127],
2406 signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]);
2407 const __m256i s2_2 = _mm256_set_epi64x(signs64[(aux32[3] >> 21) & 127], signs64[(aux32[3] >> 14) & 127],
2408 signs64[(aux32[3] >> 7) & 127], signs64[(aux32[3] >> 0) & 127]);
2409 const __m256i q8s_1 = _mm256_sign_epi8(q8_1, s2_1);
2410 const __m256i q8s_2 = _mm256_sign_epi8(q8_2, s2_2);
2411 const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1);
2412 const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2);
2413 const uint16_t ls1 = aux32[1] >> 28;
2414 const uint16_t ls2 = aux32[3] >> 28;
2415 const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*ls1+1));
2416 const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*ls2+1));
2417 sumi1 = _mm256_add_epi32(sumi1, p1);
2418 sumi2 = _mm256_add_epi32(sumi2, p2);
2419 }
2420
2421 accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);
2422
2423 }
2424
2425 *s = 0.125f * hsum_float_8(accumf);
2426
2427#elif defined(__AVX__)
2428 const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
2429
2430 uint32_t aux32[4];
2431 const uint8_t * aux8 = (const uint8_t *)aux32;
2432
2433 __m256 accumf = _mm256_setzero_ps();
2434 for (int i = 0; i < nb; ++i) {
2435 const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
2436 const uint16_t * GGML_RESTRICT q2 = x[i].qs;
2437 const int8_t * GGML_RESTRICT q8 = y[i].qs;
2438 __m128i sumi1_0 = _mm_setzero_si128();
2439 __m128i sumi1_1 = _mm_setzero_si128();
2440 __m128i sumi2_0 = _mm_setzero_si128();
2441 __m128i sumi2_1 = _mm_setzero_si128();
2442 for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
2443 const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
2444 const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
2445 const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
2446 const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
2447 memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8;
2448 const __m128i q2_1_0 = _mm_set_epi64x(iq2xxs_grid[aux8[1]], iq2xxs_grid[aux8[0]]);
2449 const __m128i q2_1_1 = _mm_set_epi64x(iq2xxs_grid[aux8[3]], iq2xxs_grid[aux8[2]]);
2450 const __m128i q2_2_0 = _mm_set_epi64x(iq2xxs_grid[aux8[9]], iq2xxs_grid[aux8[8]]);
2451 const __m128i q2_2_1 = _mm_set_epi64x(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]]);
2452 const __m128i s2_1_0 = _mm_set_epi64x(signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]);
2453 const __m128i s2_1_1 = _mm_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127]);
2454 const __m128i s2_2_0 = _mm_set_epi64x(signs64[(aux32[3] >> 7) & 127], signs64[(aux32[3] >> 0) & 127]);
2455 const __m128i s2_2_1 = _mm_set_epi64x(signs64[(aux32[3] >> 21) & 127], signs64[(aux32[3] >> 14) & 127]);
2456 const __m128i q8s_1_0 = _mm_sign_epi8(q8_1_0, s2_1_0);
2457 const __m128i q8s_1_1 = _mm_sign_epi8(q8_1_1, s2_1_1);
2458 const __m128i q8s_2_0 = _mm_sign_epi8(q8_2_0, s2_2_0);
2459 const __m128i q8s_2_1 = _mm_sign_epi8(q8_2_1, s2_2_1);
2460 const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0);
2461 const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1);
2462 const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0);
2463 const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1);
2464 const uint16_t ls1 = aux32[1] >> 28;
2465 const uint16_t ls2 = aux32[3] >> 28;
2466 const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(2*ls1+1));
2467 const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(2*ls1+1));
2468 const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(2*ls2+1));
2469 const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(2*ls2+1));
2470 sumi1_0 = _mm_add_epi32(sumi1_0, p1_0);
2471 sumi1_1 = _mm_add_epi32(sumi1_1, p1_1);
2472 sumi2_0 = _mm_add_epi32(sumi2_0, p2_0);
2473 sumi2_1 = _mm_add_epi32(sumi2_1, p2_1);
2474 }
2475
2476 accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf);
2477
2478 }
2479
2480 *s = 0.125f * hsum_float_8(accumf);
2481
2482#else
2483 UNUSED(x);
2484 UNUSED(y);
2485 UNUSED(nb);
2486 ggml_vec_dot_iq2_xxs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
2487#endif
2488}
2489
2490void ggml_vec_dot_iq2_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
2491 assert(n % QK_K == 0);
2492 assert(nrc == 1);
2493 UNUSED(nrc);
2494 UNUSED(bx);
2495 UNUSED(by);
2496 UNUSED(bs);
2497
2498 const block_iq2_xs * GGML_RESTRICT x = vx;
2499 const block_q8_K * GGML_RESTRICT y = vy;
2500
2501 const int nb = n / QK_K;
2502
2503#if defined(__AVX2__)
2504
2505 const __m256i mone = _mm256_set1_epi8(1);
2506 static const char block_sign_shuffle_mask_1[32] = {
2507 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02,
2508 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06,
2509 };
2510 static const char block_sign_shuffle_mask_2[32] = {
2511 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a,
2512 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e,
2513 };
2514 static const uint8_t bit_selector_mask_bytes[32] = {
2515 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
2516 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
2517 };
2518
2519 const __m256i bit_selector_mask = _mm256_loadu_si256((const __m256i*)bit_selector_mask_bytes);
2520 const __m256i block_sign_shuffle_1 = _mm256_loadu_si256((const __m256i*)block_sign_shuffle_mask_1);
2521 const __m256i block_sign_shuffle_2 = _mm256_loadu_si256((const __m256i*)block_sign_shuffle_mask_2);
2522
2523 static const uint8_t k_bit_helper[32] = {
2524 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
2525 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
2526 };
2527 const __m256i bit_helper = _mm256_loadu_si256((const __m256i*)k_bit_helper);
2528 const __m256i m511 = _mm256_set1_epi16(511);
2529 const __m128i m4 = _mm_set1_epi8(0xf);
2530 const __m128i m1 = _mm_set1_epi8(1);
2531
2532 uint64_t aux64;
2533
2534 // somewhat hacky, but gives a significant boost in performance
2535 __m256i aux_gindex;
2536 const uint16_t * gindex = (const uint16_t *)&aux_gindex;
2537
2538 __m256 accumf = _mm256_setzero_ps();
2539 for (int i = 0; i < nb; ++i) {
2540 const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
2541 const uint16_t * GGML_RESTRICT q2 = x[i].qs;
2542 const int8_t * GGML_RESTRICT q8 = y[i].qs;
2543
2544 memcpy(&aux64, x[i].scales, 8);
2545 __m128i stmp = _mm_set1_epi64x(aux64);
2546 stmp = _mm_unpacklo_epi8(_mm_and_si128(stmp, m4), _mm_and_si128(_mm_srli_epi16(stmp, 4), m4));
2547 const __m128i scales = _mm_add_epi8(_mm_slli_epi16(stmp, 1), m1);
2548
2549 __m256i sumi1 = _mm256_setzero_si256();
2550 __m256i sumi2 = _mm256_setzero_si256();
2551 for (int ib32 = 0; ib32 < QK_K/32; ib32 += 4) {
2552
2553 const __m256i q2_data = _mm256_loadu_si256((const __m256i*)q2); q2 += 16;
2554 aux_gindex = _mm256_and_si256(q2_data, m511);
2555
2556 const __m256i partial_sign_bits = _mm256_srli_epi16(q2_data, 9);
2557 const __m256i partial_sign_bits_upper = _mm256_srli_epi16(q2_data, 13);
2558 const __m256i partial_sign_bits_for_counting = _mm256_xor_si256(partial_sign_bits, partial_sign_bits_upper);
2559
2560 const __m256i odd_bits = _mm256_shuffle_epi8(bit_helper, partial_sign_bits_for_counting);
2561 const __m256i full_sign_bits = _mm256_or_si256(partial_sign_bits, odd_bits);
2562
2563 const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
2564 const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
2565 const __m256i q8_3 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
2566 const __m256i q8_4 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
2567
2568 const __m256i q2_1 = _mm256_set_epi64x(iq2xs_grid[gindex[ 3]], iq2xs_grid[gindex[ 2]],
2569 iq2xs_grid[gindex[ 1]], iq2xs_grid[gindex[ 0]]);
2570 const __m256i q2_2 = _mm256_set_epi64x(iq2xs_grid[gindex[ 7]], iq2xs_grid[gindex[ 6]],
2571 iq2xs_grid[gindex[ 5]], iq2xs_grid[gindex[ 4]]);
2572 const __m256i q2_3 = _mm256_set_epi64x(iq2xs_grid[gindex[11]], iq2xs_grid[gindex[10]],
2573 iq2xs_grid[gindex[ 9]], iq2xs_grid[gindex[ 8]]);
2574 const __m256i q2_4 = _mm256_set_epi64x(iq2xs_grid[gindex[15]], iq2xs_grid[gindex[14]],
2575 iq2xs_grid[gindex[13]], iq2xs_grid[gindex[12]]);
2576
2577 const __m128i full_signs_l = _mm256_castsi256_si128(full_sign_bits);
2578 const __m128i full_signs_h = _mm256_extractf128_si256(full_sign_bits, 1);
2579 const __m256i full_signs_1 = MM256_SET_M128I(full_signs_l, full_signs_l);
2580 const __m256i full_signs_2 = MM256_SET_M128I(full_signs_h, full_signs_h);
2581
2582 __m256i signs;
2583 signs = _mm256_shuffle_epi8(full_signs_1, block_sign_shuffle_1);
2584 signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
2585 const __m256i q8s_1 = _mm256_sign_epi8(q8_1, _mm256_or_si256(signs, mone));
2586
2587 signs = _mm256_shuffle_epi8(full_signs_1, block_sign_shuffle_2);
2588 signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
2589 const __m256i q8s_2 = _mm256_sign_epi8(q8_2, _mm256_or_si256(signs, mone));
2590
2591 signs = _mm256_shuffle_epi8(full_signs_2, block_sign_shuffle_1);
2592 signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
2593 const __m256i q8s_3 = _mm256_sign_epi8(q8_3, _mm256_or_si256(signs, mone));
2594
2595 signs = _mm256_shuffle_epi8(full_signs_2, block_sign_shuffle_2);
2596 signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
2597 const __m256i q8s_4 = _mm256_sign_epi8(q8_4, _mm256_or_si256(signs, mone));
2598
2599 const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1);
2600 const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2);
2601 const __m256i dot3 = _mm256_maddubs_epi16(q2_3, q8s_3);
2602 const __m256i dot4 = _mm256_maddubs_epi16(q2_4, q8s_4);
2603
2604 const __m256i sc1 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+0)));
2605 const __m256i sc2 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+1)));
2606 const __m256i sc3 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+2)));
2607 const __m256i sc4 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+3)));
2608
2609 sumi1 = _mm256_add_epi32(sumi1, _mm256_madd_epi16(dot1, sc1));
2610 sumi2 = _mm256_add_epi32(sumi2, _mm256_madd_epi16(dot2, sc2));
2611 sumi1 = _mm256_add_epi32(sumi1, _mm256_madd_epi16(dot3, sc3));
2612 sumi2 = _mm256_add_epi32(sumi2, _mm256_madd_epi16(dot4, sc4));
2613 }
2614
2615 accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);
2616
2617 }
2618
2619 *s = 0.125f * hsum_float_8(accumf);
2620
2621#elif defined(__AVX__)
2622 const __m128i mone = _mm_set1_epi8(1);
2623 static const char block_sign_shuffle_mask_1[32] = {
2624 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02,
2625 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06,
2626 };
2627 static const char block_sign_shuffle_mask_2[32] = {
2628 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a,
2629 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e,
2630 };
2631 static const uint8_t bit_selector_mask_bytes[32] = {
2632 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
2633 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
2634 };
2635
2636 const __m128i bit_selector_mask_0 = _mm_loadu_si128((const __m128i*)bit_selector_mask_bytes);
2637 const __m128i bit_selector_mask_1 = _mm_loadu_si128((const __m128i*)bit_selector_mask_bytes + 1);
2638 const __m128i block_sign_shuffle_1_0 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_1);
2639 const __m128i block_sign_shuffle_1_1 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_1 + 1);
2640 const __m128i block_sign_shuffle_2_0 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_2);
2641 const __m128i block_sign_shuffle_2_1 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_2 + 1);
2642
2643 static const uint8_t k_bit_helper[32] = {
2644 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
2645 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
2646 };
2647 const __m128i bit_helper_0 = _mm_loadu_si128((const __m128i*)k_bit_helper);
2648 const __m128i bit_helper_1 = _mm_loadu_si128((const __m128i*)k_bit_helper + 1);
2649 const __m128i m511 = _mm_set1_epi16(511);
2650 const __m128i m4 = _mm_set1_epi8(0xf);
2651 const __m128i m1 = _mm_set1_epi8(1);
2652
2653 uint64_t aux64;
2654
2655 // somewhat hacky, but gives a significant boost in performance
2656 __m256i aux_gindex;
2657 const uint16_t * gindex = (const uint16_t *)&aux_gindex;
2658
2659 __m256 accumf = _mm256_setzero_ps();
2660 for (int i = 0; i < nb; ++i) {
2661 const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
2662 const uint16_t * GGML_RESTRICT q2 = x[i].qs;
2663 const int8_t * GGML_RESTRICT q8 = y[i].qs;
2664
2665 memcpy(&aux64, x[i].scales, 8);
2666 __m128i stmp = _mm_set1_epi64x(aux64);
2667 stmp = _mm_unpacklo_epi8(_mm_and_si128(stmp, m4), _mm_and_si128(_mm_srli_epi16(stmp, 4), m4));
2668 const __m128i scales = _mm_add_epi8(_mm_slli_epi16(stmp, 1), m1);
2669
2670 __m128i sumi1_0 = _mm_setzero_si128();
2671 __m128i sumi1_1 = _mm_setzero_si128();
2672 __m128i sumi2_0 = _mm_setzero_si128();
2673 __m128i sumi2_1 = _mm_setzero_si128();
2674 for (int ib32 = 0; ib32 < QK_K/32; ib32 += 4) {
2675
2676 const __m128i q2_data_0 = _mm_loadu_si128((const __m128i*)q2);
2677 const __m128i q2_data_1 = _mm_loadu_si128((const __m128i*)q2 + 1); q2 += 16;
2678 aux_gindex = MM256_SET_M128I(_mm_and_si128(q2_data_1, m511), _mm_and_si128(q2_data_0, m511));
2679
2680 const __m128i partial_sign_bits_0 = _mm_srli_epi16(q2_data_0, 9);
2681 const __m128i partial_sign_bits_1 = _mm_srli_epi16(q2_data_1, 9);
2682 const __m128i partial_sign_bits_upper_0 = _mm_srli_epi16(q2_data_0, 13);
2683 const __m128i partial_sign_bits_upper_1 = _mm_srli_epi16(q2_data_1, 13);
2684 const __m128i partial_sign_bits_for_counting_0 = _mm_xor_si128(partial_sign_bits_0, partial_sign_bits_upper_0);
2685 const __m128i partial_sign_bits_for_counting_1 = _mm_xor_si128(partial_sign_bits_1, partial_sign_bits_upper_1);
2686
2687 const __m128i odd_bits_0 = _mm_shuffle_epi8(bit_helper_0, partial_sign_bits_for_counting_0);
2688 const __m128i odd_bits_1 = _mm_shuffle_epi8(bit_helper_1, partial_sign_bits_for_counting_1);
2689 const __m128i full_sign_bits_0 = _mm_or_si128(partial_sign_bits_0, odd_bits_0);
2690 const __m128i full_sign_bits_1 = _mm_or_si128(partial_sign_bits_1, odd_bits_1);
2691
2692 const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
2693 const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
2694 const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
2695 const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
2696 const __m128i q8_3_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
2697 const __m128i q8_3_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
2698 const __m128i q8_4_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
2699 const __m128i q8_4_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
2700
2701 const __m128i q2_1_0 = _mm_set_epi64x(iq2xs_grid[gindex[1]], iq2xs_grid[gindex[0]]);
2702 const __m128i q2_1_1 = _mm_set_epi64x(iq2xs_grid[gindex[3]], iq2xs_grid[gindex[2]]);
2703 const __m128i q2_2_0 = _mm_set_epi64x(iq2xs_grid[gindex[5]], iq2xs_grid[gindex[4]]);
2704 const __m128i q2_2_1 = _mm_set_epi64x(iq2xs_grid[gindex[7]], iq2xs_grid[gindex[6]]);
2705 const __m128i q2_3_0 = _mm_set_epi64x(iq2xs_grid[gindex[9]], iq2xs_grid[gindex[8]]);
2706 const __m128i q2_3_1 = _mm_set_epi64x(iq2xs_grid[gindex[11]], iq2xs_grid[gindex[10]]);
2707 const __m128i q2_4_0 = _mm_set_epi64x(iq2xs_grid[gindex[13]], iq2xs_grid[gindex[12]]);
2708 const __m128i q2_4_1 = _mm_set_epi64x(iq2xs_grid[gindex[15]], iq2xs_grid[gindex[14]]);
2709
2710 // AVX2 full_signs_1 is full_sign_bits_0 here
2711 // AVX2 full_signs_2 is full_sign_bits_1 here
2712 __m128i signs_0, signs_1;
2713 signs_0 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_1_0);
2714 signs_1 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_1_1);
2715 signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0);
2716 signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1);
2717 const __m128i q8s_1_0 = _mm_sign_epi8(q8_1_0, _mm_or_si128(signs_0, mone));
2718 const __m128i q8s_1_1 = _mm_sign_epi8(q8_1_1, _mm_or_si128(signs_1, mone));
2719
2720 signs_0 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_2_0);
2721 signs_1 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_2_1);
2722 signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0);
2723 signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1);
2724 const __m128i q8s_2_0 = _mm_sign_epi8(q8_2_0, _mm_or_si128(signs_0, mone));
2725 const __m128i q8s_2_1 = _mm_sign_epi8(q8_2_1, _mm_or_si128(signs_1, mone));
2726
2727 signs_0 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_1_0);
2728 signs_1 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_1_1);
2729 signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0);
2730 signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1);
2731 const __m128i q8s_3_0 = _mm_sign_epi8(q8_3_0, _mm_or_si128(signs_0, mone));
2732 const __m128i q8s_3_1 = _mm_sign_epi8(q8_3_1, _mm_or_si128(signs_1, mone));
2733
2734 signs_0 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_2_0);
2735 signs_1 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_2_1);
2736 signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0);
2737 signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1);
2738 const __m128i q8s_4_0 = _mm_sign_epi8(q8_4_0, _mm_or_si128(signs_0, mone));
2739 const __m128i q8s_4_1 = _mm_sign_epi8(q8_4_1, _mm_or_si128(signs_1, mone));
2740
2741 const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0);
2742 const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1);
2743 const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0);
2744 const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1);
2745 const __m128i dot3_0 = _mm_maddubs_epi16(q2_3_0, q8s_3_0);
2746 const __m128i dot3_1 = _mm_maddubs_epi16(q2_3_1, q8s_3_1);
2747 const __m128i dot4_0 = _mm_maddubs_epi16(q2_4_0, q8s_4_0);
2748 const __m128i dot4_1 = _mm_maddubs_epi16(q2_4_1, q8s_4_1);
2749
2750 __m128i sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+0));
2751 const __m128i sc1_0 = _mm_cvtepi8_epi16(sc_tmp);
2752 const __m128i sc1_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8));
2753 sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+1));
2754 const __m128i sc2_0 = _mm_cvtepi8_epi16(sc_tmp);
2755 const __m128i sc2_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8));
2756 sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+2));
2757 const __m128i sc3_0 = _mm_cvtepi8_epi16(sc_tmp);
2758 const __m128i sc3_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8));
2759 sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+3));
2760 const __m128i sc4_0 = _mm_cvtepi8_epi16(sc_tmp);
2761 const __m128i sc4_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8));
2762
2763 sumi1_0 = _mm_add_epi32(sumi1_0, _mm_madd_epi16(dot1_0, sc1_0));
2764 sumi1_1 = _mm_add_epi32(sumi1_1, _mm_madd_epi16(dot1_1, sc1_1));
2765 sumi2_0 = _mm_add_epi32(sumi2_0, _mm_madd_epi16(dot2_0, sc2_0));
2766 sumi2_1 = _mm_add_epi32(sumi2_1, _mm_madd_epi16(dot2_1, sc2_1));
2767 sumi1_0 = _mm_add_epi32(sumi1_0, _mm_madd_epi16(dot3_0, sc3_0));
2768 sumi1_1 = _mm_add_epi32(sumi1_1, _mm_madd_epi16(dot3_1, sc3_1));
2769 sumi2_0 = _mm_add_epi32(sumi2_0, _mm_madd_epi16(dot4_0, sc4_0));
2770 sumi2_1 = _mm_add_epi32(sumi2_1, _mm_madd_epi16(dot4_1, sc4_1));
2771 }
2772
2773 accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf);
2774
2775 }
2776
2777 *s = 0.125f * hsum_float_8(accumf);
2778
2779#else
2780 UNUSED(x);
2781 UNUSED(y);
2782 UNUSED(nb);
2783 ggml_vec_dot_iq2_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
2784#endif
2785}
2786
2787void ggml_vec_dot_iq2_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
2788 assert(n % QK_K == 0);
2789 assert(nrc == 1);
2790 UNUSED(nrc);
2791 UNUSED(bx);
2792 UNUSED(by);
2793 UNUSED(bs);
2794
2795 const block_iq2_s * GGML_RESTRICT x = vx;
2796 const block_q8_K * GGML_RESTRICT y = vy;
2797
2798 const int nb = n / QK_K;
2799
2800#if defined(__AVX2__)
2801
2802 static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
2803 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
2804 };
2805
2806 static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
2807 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
2808 };
2809
2810 const __m128i m4 = _mm_set1_epi8(0xf);
2811 const __m128i m1 = _mm_set1_epi8(1);
2812
2813 const __m256i mask1 = _mm256_loadu_si256((const __m256i*)k_mask1);
2814 const __m256i mask2 = _mm256_loadu_si256((const __m256i*)k_mask2);
2815
2816 uint64_t aux64;
2817
2818 __m256 accumf = _mm256_setzero_ps();
2819 for (int i = 0; i < nb; ++i) {
2820 const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
2821 const uint8_t * GGML_RESTRICT qs = x[i].qs;
2822 const uint8_t * GGML_RESTRICT qh = x[i].qh;
2823 const uint16_t * GGML_RESTRICT signs = (const uint16_t *)(x[i].qs + QK_K/8);
2824 const int8_t * GGML_RESTRICT q8 = y[i].qs;
2825
2826 memcpy(&aux64, x[i].scales, 8);
2827 const __m128i scales8 = _mm_add_epi8(_mm_slli_epi16(_mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), m4), 1), m1);
2828 const __m256i scales16 = _mm256_cvtepi8_epi16(scales8); // 0 2 4 6 8 10 12 14 1 3 5 7 9 11 13 15
2829
2830 __m256i sumi1 = _mm256_setzero_si256();
2831 __m256i sumi2 = _mm256_setzero_si256();
2832 for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
2833 const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
2834 const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
2835 const __m256i q2_1 = _mm256_set_epi64x(iq2s_grid[qs[3] | ((qh[ib32+0] << 2) & 0x300)],
2836 iq2s_grid[qs[2] | ((qh[ib32+0] << 4) & 0x300)],
2837 iq2s_grid[qs[1] | ((qh[ib32+0] << 6) & 0x300)],
2838 iq2s_grid[qs[0] | ((qh[ib32+0] << 8) & 0x300)]);
2839 const __m256i q2_2 = _mm256_set_epi64x(iq2s_grid[qs[7] | ((qh[ib32+1] << 2) & 0x300)],
2840 iq2s_grid[qs[6] | ((qh[ib32+1] << 4) & 0x300)],
2841 iq2s_grid[qs[5] | ((qh[ib32+1] << 6) & 0x300)],
2842 iq2s_grid[qs[4] | ((qh[ib32+1] << 8) & 0x300)]);
2843 qs += 8;
2844
2845 __m256i aux256 = _mm256_set1_epi32(signs[0] | ((uint32_t) signs[1] << 16));
2846 aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2);
2847 const __m256i s2_1 = _mm256_cmpeq_epi8(aux256, mask2);
2848 const __m256i q8s_1 = _mm256_sub_epi8(_mm256_xor_si256(s2_1, q8_1), s2_1);
2849
2850 aux256 = _mm256_set1_epi32(signs[2] | ((uint32_t) signs[3] << 16));
2851 aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2);
2852 const __m256i s2_2 = _mm256_cmpeq_epi8(aux256, mask2);
2853 const __m256i q8s_2 = _mm256_sub_epi8(_mm256_xor_si256(s2_2, q8_2), s2_2);
2854
2855 signs += 4;
2856
2857 const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1); // blocks 2*ib32+0, 2*ib32+1
2858 const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2); // blocks 2*ib32+2, 2*ib32+3
2859
2860 const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_shuffle_epi8(scales16, get_scale_shuffle_k4(ib32+0)));
2861 const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_shuffle_epi8(scales16, get_scale_shuffle_k4(ib32+1)));
2862 sumi1 = _mm256_add_epi32(sumi1, p1);
2863 sumi2 = _mm256_add_epi32(sumi2, p2);
2864 }
2865
2866 accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);
2867
2868 }
2869
2870 *s = 0.125f * hsum_float_8(accumf);
2871
2872#elif defined(__AVX__)
2873 static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
2874 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
2875 };
2876
2877 static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
2878 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
2879 };
2880
2881 const __m128i m4 = _mm_set1_epi8(0xf);
2882 const __m128i m1 = _mm_set1_epi8(1);
2883
2884 const __m128i mask1_0 = _mm_loadu_si128((const __m128i*)k_mask1);
2885 const __m128i mask1_1 = _mm_loadu_si128((const __m128i*)k_mask1 + 1);
2886 const __m128i mask2_0 = _mm_loadu_si128((const __m128i*)k_mask2);
2887 const __m128i mask2_1 = _mm_loadu_si128((const __m128i*)k_mask2 + 1);
2888
2889 uint64_t aux64;
2890
2891 __m256 accumf = _mm256_setzero_ps();
2892 for (int i = 0; i < nb; ++i) {
2893 const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
2894 const uint8_t * GGML_RESTRICT qs = x[i].qs;
2895 const uint8_t * GGML_RESTRICT qh = x[i].qh;
2896 const uint16_t * GGML_RESTRICT signs = (const uint16_t *)(x[i].qs + QK_K/8);
2897 const int8_t * GGML_RESTRICT q8 = y[i].qs;
2898
2899 memcpy(&aux64, x[i].scales, 8);
2900 const __m128i scales8 = _mm_add_epi8(_mm_slli_epi16(_mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), m4), 1), m1);
2901 const __m128i scales16_0 = _mm_cvtepi8_epi16(scales8);
2902 const __m128i scales16_1 = _mm_cvtepi8_epi16(_mm_srli_si128(scales8, 8));
2903
2904 __m128i sumi1_0 = _mm_setzero_si128();
2905 __m128i sumi1_1 = _mm_setzero_si128();
2906 __m128i sumi2_0 = _mm_setzero_si128();
2907 __m128i sumi2_1 = _mm_setzero_si128();
2908 for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
2909 const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
2910 const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
2911 const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
2912 const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
2913 const __m128i q2_1_0 = _mm_set_epi64x(iq2s_grid[qs[1] | ((qh[ib32+0] << 6) & 0x300)],
2914 iq2s_grid[qs[0] | ((qh[ib32+0] << 8) & 0x300)]);
2915 const __m128i q2_1_1 = _mm_set_epi64x(iq2s_grid[qs[3] | ((qh[ib32+0] << 2) & 0x300)],
2916 iq2s_grid[qs[2] | ((qh[ib32+0] << 4) & 0x300)]);
2917 const __m128i q2_2_0 = _mm_set_epi64x(iq2s_grid[qs[5] | ((qh[ib32+1] << 6) & 0x300)],
2918 iq2s_grid[qs[4] | ((qh[ib32+1] << 8) & 0x300)]);
2919 const __m128i q2_2_1 = _mm_set_epi64x(iq2s_grid[qs[7] | ((qh[ib32+1] << 2) & 0x300)],
2920 iq2s_grid[qs[6] | ((qh[ib32+1] << 4) & 0x300)]);
2921 qs += 8;
2922
2923 __m128i aux128_0 = _mm_set1_epi32(signs[0] | ((uint32_t) signs[1] << 16));
2924 __m128i aux128_1 = aux128_0;
2925 aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0);
2926 aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1);
2927 const __m128i s2_1_0 = _mm_cmpeq_epi8(aux128_0, mask2_0);
2928 const __m128i s2_1_1 = _mm_cmpeq_epi8(aux128_1, mask2_1);
2929 const __m128i q8s_1_0 = _mm_sub_epi8(_mm_xor_si128(s2_1_0, q8_1_0), s2_1_0);
2930 const __m128i q8s_1_1 = _mm_sub_epi8(_mm_xor_si128(s2_1_1, q8_1_1), s2_1_1);
2931
2932 aux128_0 = _mm_set1_epi32(signs[2] | ((uint32_t) signs[3] << 16));
2933 aux128_1 = aux128_0;
2934 aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0);
2935 aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1);
2936 const __m128i s2_2_0 = _mm_cmpeq_epi8(aux128_0, mask2_0);
2937 const __m128i s2_2_1 = _mm_cmpeq_epi8(aux128_1, mask2_1);
2938 const __m128i q8s_2_0 = _mm_sub_epi8(_mm_xor_si128(s2_2_0, q8_2_0), s2_2_0);
2939 const __m128i q8s_2_1 = _mm_sub_epi8(_mm_xor_si128(s2_2_1, q8_2_1), s2_2_1);
2940
2941 signs += 4;
2942
2943 const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0);
2944 const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1);
2945 const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0);
2946 const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1);
2947
2948 const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_shuffle_epi8(scales16_0, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+0), 0)));
2949 const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_shuffle_epi8(scales16_1, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+0), 1)));
2950 const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_shuffle_epi8(scales16_0, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+1), 0)));
2951 const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_shuffle_epi8(scales16_1, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+1), 1)));
2952 sumi1_0 = _mm_add_epi32(sumi1_0, p1_0);
2953 sumi1_1 = _mm_add_epi32(sumi1_1, p1_1);
2954 sumi2_0 = _mm_add_epi32(sumi2_0, p2_0);
2955 sumi2_1 = _mm_add_epi32(sumi2_1, p2_1);
2956 }
2957
2958 accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf);
2959
2960 }
2961
2962 *s = 0.125f * hsum_float_8(accumf);
2963
2964#else
2965 UNUSED(x);
2966 UNUSED(y);
2967 UNUSED(nb);
2968 ggml_vec_dot_iq2_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
2969#endif
2970}
2971
2972void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
2973 assert(n % QK_K == 0);
2974 assert(nrc == 1);
2975 UNUSED(nrc);
2976 UNUSED(bx);
2977 UNUSED(by);
2978 UNUSED(bs);
2979
2980 const block_iq3_xxs * GGML_RESTRICT x = vx;
2981 const block_q8_K * GGML_RESTRICT y = vy;
2982
2983 const int nb = n / QK_K;
2984
2985#if defined(__AVX2__)
2986
2987 const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
2988
2989 uint32_t aux32[2];
2990
2991 __m256 accumf = _mm256_setzero_ps();
2992 for (int i = 0; i < nb; ++i) {
2993 const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
2994 const uint8_t * GGML_RESTRICT q3 = x[i].qs;
2995 const uint8_t * GGML_RESTRICT gas = x[i].qs + QK_K/4;
2996 const int8_t * GGML_RESTRICT q8 = y[i].qs;
2997 __m256i sumi1 = _mm256_setzero_si256();
2998 __m256i sumi2 = _mm256_setzero_si256();
2999 for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
3000 const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
3001 const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
3002 const __m256i q2_1 = _mm256_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]],
3003 iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]);
3004 q3 += 8;
3005 const __m256i q2_2 = _mm256_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]],
3006 iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]);
3007 q3 += 8;
3008 memcpy(aux32, gas, 8); gas += 8;
3009 const __m256i s2_1 = _mm256_set_epi64x(signs64[(aux32[0] >> 21) & 127], signs64[(aux32[0] >> 14) & 127],
3010 signs64[(aux32[0] >> 7) & 127], signs64[(aux32[0] >> 0) & 127]);
3011 const __m256i s2_2 = _mm256_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127],
3012 signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]);
3013 const __m256i q8s_1 = _mm256_sign_epi8(q8_1, s2_1);
3014 const __m256i q8s_2 = _mm256_sign_epi8(q8_2, s2_2);
3015 const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1);
3016 const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2);
3017 const uint16_t ls1 = aux32[0] >> 28;
3018 const uint16_t ls2 = aux32[1] >> 28;
3019 const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*ls1+1));
3020 const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*ls2+1));
3021 sumi1 = _mm256_add_epi32(sumi1, p1);
3022 sumi2 = _mm256_add_epi32(sumi2, p2);
3023 }
3024
3025 accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);
3026
3027 }
3028
3029 *s = 0.25f * hsum_float_8(accumf);
3030
3031#elif defined(__AVX__)
3032 const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
3033
3034 uint32_t aux32[2];
3035
3036 __m256 accumf = _mm256_setzero_ps();
3037 for (int i = 0; i < nb; ++i) {
3038 const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
3039 const uint8_t * GGML_RESTRICT q3 = x[i].qs;
3040 const uint8_t * GGML_RESTRICT gas = x[i].qs + QK_K/4;
3041 const int8_t * GGML_RESTRICT q8 = y[i].qs;
3042 __m128i sumi1_0 = _mm_setzero_si128();
3043 __m128i sumi1_1 = _mm_setzero_si128();
3044 __m128i sumi2_0 = _mm_setzero_si128();
3045 __m128i sumi2_1 = _mm_setzero_si128();
3046 for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
3047 const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
3048 const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
3049 const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
3050 const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
3051 const __m128i q2_1_0 = _mm_set_epi32(iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]);
3052 const __m128i q2_1_1 = _mm_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]]);
3053 q3 += 8;
3054 const __m128i q2_2_0 = _mm_set_epi32(iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]);
3055 const __m128i q2_2_1 = _mm_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]]);
3056 q3 += 8;
3057 memcpy(aux32, gas, 8); gas += 8;
3058 const __m128i s2_1_0 = _mm_set_epi64x(signs64[(aux32[0] >> 7) & 127], signs64[(aux32[0] >> 0) & 127]);
3059 const __m128i s2_1_1 = _mm_set_epi64x(signs64[(aux32[0] >> 21) & 127], signs64[(aux32[0] >> 14) & 127]);
3060 const __m128i s2_2_0 = _mm_set_epi64x(signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]);
3061 const __m128i s2_2_1 = _mm_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127]);
3062 const __m128i q8s_1_0 = _mm_sign_epi8(q8_1_0, s2_1_0);
3063 const __m128i q8s_1_1 = _mm_sign_epi8(q8_1_1, s2_1_1);
3064 const __m128i q8s_2_0 = _mm_sign_epi8(q8_2_0, s2_2_0);
3065 const __m128i q8s_2_1 = _mm_sign_epi8(q8_2_1, s2_2_1);
3066 const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0);
3067 const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1);
3068 const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0);
3069 const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1);
3070 const uint16_t ls1 = aux32[0] >> 28;
3071 const uint16_t ls2 = aux32[1] >> 28;
3072 const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(2*ls1+1));
3073 const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(2*ls1+1));
3074 const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(2*ls2+1));
3075 const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(2*ls2+1));
3076 sumi1_0 = _mm_add_epi32(sumi1_0, p1_0);
3077 sumi1_1 = _mm_add_epi32(sumi1_1, p1_1);
3078 sumi2_0 = _mm_add_epi32(sumi2_0, p2_0);
3079 sumi2_1 = _mm_add_epi32(sumi2_1, p2_1);
3080 }
3081
3082 accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf);
3083
3084 }
3085
3086 *s = 0.25f * hsum_float_8(accumf);
3087
3088#else
3089 UNUSED(x);
3090 UNUSED(y);
3091 UNUSED(nb);
3092 ggml_vec_dot_iq3_xxs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
3093#endif
3094}
3095
3096void ggml_vec_dot_iq3_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
3097 assert(n % QK_K == 0);
3098 assert(nrc == 1);
3099 UNUSED(nrc);
3100 UNUSED(bx);
3101 UNUSED(by);
3102 UNUSED(bs);
3103
3104 const block_iq3_s * GGML_RESTRICT x = vx;
3105 const block_q8_K * GGML_RESTRICT y = vy;
3106
3107 const int nb = n / QK_K;
3108
3109#if defined(__AVX2__)
3110
3111 static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
3112 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
3113 };
3114
3115 static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
3116 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
3117 };
3118
3119 const __m256i mask1 = _mm256_loadu_si256((const __m256i*)k_mask1);
3120 const __m256i mask2 = _mm256_loadu_si256((const __m256i*)k_mask2);
3121
3122 const __m256i idx_shift = _mm256_set_epi32(1, 2, 3, 4, 5, 6, 7, 8);
3123 const __m256i idx_mask = _mm256_set1_epi32(256);
3124
3125 typedef union {
3126 __m256i vec[2];
3127 uint32_t index[16];
3128 } index_t;
3129
3130 index_t idx;
3131
3132 __m256 accumf = _mm256_setzero_ps();
3133 for (int i = 0; i < nb; ++i) {
3134 const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
3135 const uint8_t * GGML_RESTRICT qs = x[i].qs;
3136 const uint8_t * GGML_RESTRICT qh = x[i].qh;
3137 const uint16_t * GGML_RESTRICT signs = (const uint16_t *)x[i].signs;
3138 const int8_t * GGML_RESTRICT q8 = y[i].qs;
3139 __m256i sumi1 = _mm256_setzero_si256();
3140 __m256i sumi2 = _mm256_setzero_si256();
3141 for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
3142 const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
3143 const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
3144 const __m256i idx_l = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)qs)); qs += 16;
3145 idx.vec[0] = _mm256_set1_epi32(qh[ib32+0]);
3146 idx.vec[1] = _mm256_set1_epi32(qh[ib32+1]);
3147 idx.vec[0] = _mm256_and_si256(_mm256_sllv_epi32(idx.vec[0], idx_shift), idx_mask);
3148 idx.vec[1] = _mm256_and_si256(_mm256_sllv_epi32(idx.vec[1], idx_shift), idx_mask);
3149 idx.vec[0] = _mm256_or_si256(idx.vec[0], _mm256_cvtepi16_epi32(_mm256_castsi256_si128(idx_l)));
3150 idx.vec[1] = _mm256_or_si256(idx.vec[1], _mm256_cvtepi16_epi32(_mm256_extractf128_si256(idx_l, 1)));
3151
3152 // At leat on my CPU (Ryzen 7950X), using _mm256_i32gather_epi32 is slower than _mm256_set_epi32. Strange.
3153 //const __m256i q2_1 = _mm256_i32gather_epi32((const int *)iq3s_grid, idx.vec[0], 4);
3154 //const __m256i q2_2 = _mm256_i32gather_epi32((const int *)iq3s_grid, idx.vec[1], 4);
3155 const __m256i q2_1 = _mm256_set_epi32(
3156 iq3s_grid[idx.index[7]], iq3s_grid[idx.index[6]], iq3s_grid[idx.index[5]], iq3s_grid[idx.index[4]],
3157 iq3s_grid[idx.index[3]], iq3s_grid[idx.index[2]], iq3s_grid[idx.index[1]], iq3s_grid[idx.index[0]]
3158 );
3159 const __m256i q2_2 = _mm256_set_epi32(
3160 iq3s_grid[idx.index[15]], iq3s_grid[idx.index[14]], iq3s_grid[idx.index[13]], iq3s_grid[idx.index[12]],
3161 iq3s_grid[idx.index[11]], iq3s_grid[idx.index[10]], iq3s_grid[idx.index[ 9]], iq3s_grid[idx.index[ 8]]
3162 );
3163
3164 __m256i aux256 = _mm256_set1_epi32(signs[0] | (signs[1] << 16));
3165 aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2);
3166 const __m256i s2_1 = _mm256_cmpeq_epi8(aux256, mask2);
3167 const __m256i q8s_1 = _mm256_sub_epi8(_mm256_xor_si256(s2_1, q8_1), s2_1);
3168
3169 aux256 = _mm256_set1_epi32(signs[2] | (signs[3] << 16));
3170 aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2);
3171 const __m256i s2_2 = _mm256_cmpeq_epi8(aux256, mask2);
3172 const __m256i q8s_2 = _mm256_sub_epi8(_mm256_xor_si256(s2_2, q8_2), s2_2);
3173
3174 signs += 4;
3175
3176 const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1);
3177 const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2);
3178 const uint16_t ls1 = x[i].scales[ib32/2] & 0xf;
3179 const uint16_t ls2 = x[i].scales[ib32/2] >> 4;
3180 const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*ls1+1));
3181 const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*ls2+1));
3182 sumi1 = _mm256_add_epi32(sumi1, p1);
3183 sumi2 = _mm256_add_epi32(sumi2, p2);
3184 }
3185
3186 accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);
3187
3188 }
3189
3190 *s = hsum_float_8(accumf);
3191
3192#elif defined(__AVX__)
3193 static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
3194 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
3195 };
3196
3197 static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
3198 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
3199 };
3200
3201 const __m128i mask1_0 = _mm_loadu_si128((const __m128i*)k_mask1);
3202 const __m128i mask1_1 = _mm_loadu_si128((const __m128i*)k_mask1 + 1);
3203 const __m128i mask2_0 = _mm_loadu_si128((const __m128i*)k_mask2);
3204 const __m128i mask2_1 = _mm_loadu_si128((const __m128i*)k_mask2 + 1);
3205
3206 const __m128i idx_mul_0 = _mm_set_epi32(32, 64, 128, 256);
3207 const __m128i idx_mul_1 = _mm_set_epi32(2, 4, 8, 16);
3208 const __m128i idx_mask = _mm_set1_epi32(256);
3209
3210 typedef union {
3211 __m128i vec[4];
3212 uint32_t index[16];
3213 } index_t;
3214
3215 index_t idx;
3216
3217 __m256 accumf = _mm256_setzero_ps();
3218 for (int i = 0; i < nb; ++i) {
3219 const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
3220 const uint8_t * GGML_RESTRICT qs = x[i].qs;
3221 const uint8_t * GGML_RESTRICT qh = x[i].qh;
3222 const uint16_t * GGML_RESTRICT signs = (const uint16_t *)x[i].signs;
3223 const int8_t * GGML_RESTRICT q8 = y[i].qs;
3224 __m128i sumi1_0 = _mm_setzero_si128();
3225 __m128i sumi1_1 = _mm_setzero_si128();
3226 __m128i sumi2_0 = _mm_setzero_si128();
3227 __m128i sumi2_1 = _mm_setzero_si128();
3228 for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
3229 const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
3230 const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
3231 const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
3232 const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
3233 const __m128i qs_tmp = _mm_loadu_si128((const __m128i *)qs);
3234 const __m128i idx_l_0 = _mm_cvtepu8_epi16(qs_tmp);
3235 const __m128i idx_l_1 = _mm_cvtepu8_epi16(_mm_srli_si128(qs_tmp, 8)); qs += 16;
3236 idx.vec[0] = _mm_set1_epi32(qh[ib32+0]);
3237 idx.vec[1] = idx.vec[0];
3238 idx.vec[2] = _mm_set1_epi32(qh[ib32+1]);
3239 idx.vec[3] = idx.vec[2];
3240
3241 idx.vec[0] = _mm_and_si128(_mm_mullo_epi32(idx.vec[0], idx_mul_0), idx_mask);
3242 idx.vec[1] = _mm_and_si128(_mm_mullo_epi32(idx.vec[1], idx_mul_1), idx_mask);
3243 idx.vec[2] = _mm_and_si128(_mm_mullo_epi32(idx.vec[2], idx_mul_0), idx_mask);
3244 idx.vec[3] = _mm_and_si128(_mm_mullo_epi32(idx.vec[3], idx_mul_1), idx_mask);
3245
3246 idx.vec[0] = _mm_or_si128(idx.vec[0], _mm_cvtepi16_epi32(idx_l_0));
3247 idx.vec[1] = _mm_or_si128(idx.vec[1], _mm_cvtepi16_epi32(_mm_srli_si128(idx_l_0, 8)));
3248 idx.vec[2] = _mm_or_si128(idx.vec[2], _mm_cvtepi16_epi32(idx_l_1));
3249 idx.vec[3] = _mm_or_si128(idx.vec[3], _mm_cvtepi16_epi32(_mm_srli_si128(idx_l_1, 8)));
3250
3251 const __m128i q2_1_0 = _mm_set_epi32(iq3s_grid[idx.index[3]], iq3s_grid[idx.index[2]], iq3s_grid[idx.index[1]], iq3s_grid[idx.index[0]]);
3252 const __m128i q2_1_1 = _mm_set_epi32(iq3s_grid[idx.index[7]], iq3s_grid[idx.index[6]], iq3s_grid[idx.index[5]], iq3s_grid[idx.index[4]]);
3253 const __m128i q2_2_0 = _mm_set_epi32(iq3s_grid[idx.index[11]], iq3s_grid[idx.index[10]], iq3s_grid[idx.index[9]], iq3s_grid[idx.index[8]]);
3254 const __m128i q2_2_1 = _mm_set_epi32(iq3s_grid[idx.index[15]], iq3s_grid[idx.index[14]], iq3s_grid[idx.index[13]], iq3s_grid[idx.index[12]]);
3255
3256 __m128i aux128_0 = _mm_set1_epi32(signs[0] | (signs[1] << 16));
3257 __m128i aux128_1 = aux128_0;
3258 aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0);
3259 aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1);
3260 const __m128i s2_1_0 = _mm_cmpeq_epi8(aux128_0, mask2_0);
3261 const __m128i s2_1_1 = _mm_cmpeq_epi8(aux128_1, mask2_1);
3262 const __m128i q8s_1_0 = _mm_sub_epi8(_mm_xor_si128(s2_1_0, q8_1_0), s2_1_0);
3263 const __m128i q8s_1_1 = _mm_sub_epi8(_mm_xor_si128(s2_1_1, q8_1_1), s2_1_1);
3264
3265 aux128_0 = _mm_set1_epi32(signs[2] | (signs[3] << 16));
3266 aux128_1 = aux128_0;
3267 aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0);
3268 aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1);
3269 const __m128i s2_2_0 = _mm_cmpeq_epi8(aux128_0, mask2_0);
3270 const __m128i s2_2_1 = _mm_cmpeq_epi8(aux128_1, mask2_1);
3271 const __m128i q8s_2_0 = _mm_sub_epi8(_mm_xor_si128(s2_2_0, q8_2_0), s2_2_0);
3272 const __m128i q8s_2_1 = _mm_sub_epi8(_mm_xor_si128(s2_2_1, q8_2_1), s2_2_1);
3273
3274 signs += 4;
3275
3276 const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0);
3277 const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1);
3278 const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0);
3279 const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1);
3280 const uint16_t ls1 = x[i].scales[ib32/2] & 0xf;
3281 const uint16_t ls2 = x[i].scales[ib32/2] >> 4;
3282 const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(2*ls1+1));
3283 const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(2*ls1+1));
3284 const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(2*ls2+1));
3285 const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(2*ls2+1));
3286 sumi1_0 = _mm_add_epi32(sumi1_0, p1_0);
3287 sumi1_1 = _mm_add_epi32(sumi1_1, p1_1);
3288 sumi2_0 = _mm_add_epi32(sumi2_0, p2_0);
3289 sumi2_1 = _mm_add_epi32(sumi2_1, p2_1);
3290 }
3291
3292 accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf);
3293
3294 }
3295
3296 *s = hsum_float_8(accumf);
3297
3298#else
3299 UNUSED(x);
3300 UNUSED(y);
3301 UNUSED(nb);
3302 ggml_vec_dot_iq3_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
3303#endif
3304}
3305
3306void ggml_vec_dot_iq1_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
3307 assert(n % QK_K == 0);
3308 assert(nrc == 1);
3309 UNUSED(nrc);
3310 UNUSED(bx);
3311 UNUSED(by);
3312 UNUSED(bs);
3313
3314 const block_iq1_s * GGML_RESTRICT x = vx;
3315 const block_q8_K * GGML_RESTRICT y = vy;
3316
3317 const int nb = n / QK_K;
3318
3319#if defined __AVX2__
3320
3321 __m256 accum = _mm256_setzero_ps();
3322 float accum1 = 0;
3323 for (int i = 0; i < nb; ++i) {
3324
3325 const int8_t * q8 = y[i].qs;
3326 const uint8_t * qs = x[i].qs;
3327 const uint16_t * qh = x[i].qh;
3328
3329 __m256i sumi = _mm256_setzero_si256();
3330 int sumi1 = 0;
3331 for (int ib = 0; ib < QK_K/32; ib += 2) {
3332#ifdef __BMI2__
3333 const uint64_t packed_idx1 = _pdep_u64(*(const uint32_t *)qs, 0x00ff00ff00ff00ffULL) | _pdep_u64(qh[ib], 0x700070007000700ULL);
3334 const uint64_t packed_idx2 = _pdep_u64(*(const uint32_t *)(qs + 4), 0x00ff00ff00ff00ffULL) | _pdep_u64(qh[ib + 1], 0x700070007000700ULL);
3335 const uint16_t *idx1 = (const uint16_t *)(&packed_idx1);
3336 const uint16_t *idx2 = (const uint16_t *)(&packed_idx2);
3337 const __m256i q1b_1 = _mm256_set_epi64x(iq1s_grid[idx1[3]], iq1s_grid[idx1[2]], iq1s_grid[idx1[1]], iq1s_grid[idx1[0]]);
3338 const __m256i q1b_2 = _mm256_set_epi64x(iq1s_grid[idx2[3]], iq1s_grid[idx2[2]], iq1s_grid[idx2[1]], iq1s_grid[idx2[0]]);
3339#else
3340 const __m256i q1b_1 = _mm256_set_epi64x(iq1s_grid[qs[3] | ((qh[ib+0] >> 1) & 0x700)], iq1s_grid[qs[2] | ((qh[ib+0] << 2) & 0x700)],
3341 iq1s_grid[qs[1] | ((qh[ib+0] << 5) & 0x700)], iq1s_grid[qs[0] | ((qh[ib+0] << 8) & 0x700)]);
3342 const __m256i q1b_2 = _mm256_set_epi64x(iq1s_grid[qs[7] | ((qh[ib+1] >> 1) & 0x700)], iq1s_grid[qs[6] | ((qh[ib+1] << 2) & 0x700)],
3343 iq1s_grid[qs[5] | ((qh[ib+1] << 5) & 0x700)], iq1s_grid[qs[4] | ((qh[ib+1] << 8) & 0x700)]);
3344#endif
3345 qs += 8;
3346 const __m256i q8b_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
3347 const __m256i q8b_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
3348
3349 const __m256i dot1 = mul_add_epi8(q1b_1, q8b_1);
3350 const __m256i dot2 = mul_add_epi8(q1b_2, q8b_2);
3351 const int16_t ls1 = 2*((qh[ib+0] >> 12) & 7) + 1;
3352 const int16_t ls2 = 2*((qh[ib+1] >> 12) & 7) + 1;
3353 const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(ls1));
3354 const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(ls2));
3355
3356 sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p1, p2));
3357 sumi1 += (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]) * (qh[ib+0] & 0x8000 ? -1 : 1) * ls1
3358 + (y[i].bsums[2*ib+2] + y[i].bsums[2*ib+3]) * (qh[ib+1] & 0x8000 ? -1 : 1) * ls2;
3359 }
3360
3361 const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
3362 accum = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(sumi), accum);
3363 accum1 += d * sumi1;
3364
3365 }
3366
3367 *s = hsum_float_8(accum) + IQ1S_DELTA * accum1;
3368
3369#elif defined __AVX__
3370 __m256 accum = _mm256_setzero_ps();
3371 float accum1 = 0;
3372 for (int i = 0; i < nb; ++i) {
3373
3374 const int8_t * q8 = y[i].qs;
3375 const uint8_t * qs = x[i].qs;
3376 const uint16_t * qh = x[i].qh;
3377
3378 __m128i sumi1_0 = _mm_setzero_si128();
3379 __m128i sumi1_1 = _mm_setzero_si128();
3380 int sumi1 = 0;
3381 for (int ib = 0; ib < QK_K/32; ib += 2) {
3382 const __m128i q1b_1_0 = _mm_set_epi64x(iq1s_grid[qs[1] | ((qh[ib+0] << 5) & 0x700)], iq1s_grid[qs[0] | ((qh[ib+0] << 8) & 0x700)]);
3383 const __m128i q1b_1_1 = _mm_set_epi64x(iq1s_grid[qs[3] | ((qh[ib+0] >> 1) & 0x700)], iq1s_grid[qs[2] | ((qh[ib+0] << 2) & 0x700)]);
3384 const __m128i q1b_2_0 = _mm_set_epi64x(iq1s_grid[qs[5] | ((qh[ib+1] << 5) & 0x700)], iq1s_grid[qs[4] | ((qh[ib+1] << 8) & 0x700)]);
3385 const __m128i q1b_2_1 = _mm_set_epi64x(iq1s_grid[qs[7] | ((qh[ib+1] >> 1) & 0x700)], iq1s_grid[qs[6] | ((qh[ib+1] << 2) & 0x700)]);
3386 qs += 8;
3387 const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
3388 const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
3389 const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
3390 const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
3391
3392 const __m128i dot1_0 = mul_add_epi8_sse(q1b_1_0, q8b_1_0);
3393 const __m128i dot1_1 = mul_add_epi8_sse(q1b_1_1, q8b_1_1);
3394 const __m128i dot2_0 = mul_add_epi8_sse(q1b_2_0, q8b_2_0);
3395 const __m128i dot2_1 = mul_add_epi8_sse(q1b_2_1, q8b_2_1);
3396 const int16_t ls1 = 2*((qh[ib+0] >> 12) & 7) + 1;
3397 const int16_t ls2 = 2*((qh[ib+1] >> 12) & 7) + 1;
3398 const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(ls1));
3399 const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(ls1));
3400 const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(ls2));
3401 const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(ls2));
3402
3403 sumi1_0 = _mm_add_epi32(sumi1_0, _mm_add_epi32(p1_0, p2_0));
3404 sumi1_1 = _mm_add_epi32(sumi1_1, _mm_add_epi32(p1_1, p2_1));
3405 sumi1 += (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]) * (qh[ib+0] & 0x8000 ? -1 : 1) * ls1
3406 + (y[i].bsums[2*ib+2] + y[i].bsums[2*ib+3]) * (qh[ib+1] & 0x8000 ? -1 : 1) * ls2;
3407 }
3408
3409 const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
3410 accum = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(sumi1_1, sumi1_0))), accum);
3411 accum1 += d * sumi1;
3412
3413 }
3414
3415 *s = hsum_float_8(accum) + IQ1S_DELTA * accum1;
3416
3417#else
3418 UNUSED(x);
3419 UNUSED(y);
3420 UNUSED(nb);
3421 ggml_vec_dot_iq1_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
3422#endif
3423}
3424
3425void ggml_vec_dot_iq1_m_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
3426 assert(n % QK_K == 0);
3427 assert(nrc == 1);
3428 UNUSED(nrc);
3429 UNUSED(bx);
3430 UNUSED(by);
3431 UNUSED(bs);
3432
3433 const block_iq1_m * GGML_RESTRICT x = vx;
3434 const block_q8_K * GGML_RESTRICT y = vy;
3435
3436 const int nb = n / QK_K;
3437
3438 iq1m_scale_t scale;
3439
3440#if defined __AVX2__
3441
3442 const __m256i mask = _mm256_set1_epi16(0x7);
3443 const __m256i mone = _mm256_set1_epi16(1);
3444 const __m256i mone8 = _mm256_set1_epi8(1);
3445 const __m256i mtwo8 = _mm256_set1_epi8(2);
3446 // VPSHUFB cannot cross 128-bit lanes so odd shifts go to upper half.
3447 const __m256i scales_shift = _mm256_set_epi64x(9, 3, 6, 0);
3448
3449 __m256 accum1 = _mm256_setzero_ps();
3450 __m256 accum2 = _mm256_setzero_ps();
3451 for (int i = 0; i < nb; ++i) {
3452
3453 const int8_t * q8 = y[i].qs;
3454 const uint8_t * qs = x[i].qs;
3455 const uint8_t * qh = x[i].qh;
3456 const uint16_t * sc = (const uint16_t *)x[i].scales;
3457
3458 scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
3459 // Extract 3-bit scales (16 values)
3460 __m256i scales = _mm256_set1_epi64x(*(const uint64_t*)sc);
3461 scales = _mm256_srlv_epi64(scales, scales_shift);
3462 scales = _mm256_add_epi16(_mm256_slli_epi16(_mm256_and_si256(scales, mask), 1), mone);
3463
3464 // Indices to repeat each scale 8 times.
3465 __m256i scales_idx1 = _mm256_set1_epi16(0x0100);
3466 __m256i scales_idx2 = _mm256_add_epi8(scales_idx1, _mm256_set1_epi8(8));
3467
3468 __m256i sumi1 = _mm256_setzero_si256();
3469 __m256i sumi2 = _mm256_setzero_si256();
3470 for (int ib = 0; ib < QK_K/32; ib += 2) {
3471#ifdef __BMI2__
3472 const uint64_t packed_idx1 = _pdep_u64(*(const uint32_t *)qs, 0x00ff00ff00ff00ffULL)
3473 | _pdep_u64(*(const uint16_t*)(qh) & 0x7777, 0xf000f000f000f00ULL);
3474 const uint64_t packed_idx2 = _pdep_u64(*(const uint32_t *)(qs + 4), 0x00ff00ff00ff00ffULL)
3475 | _pdep_u64(*(const uint16_t*)(qh + 2) & 0x7777, 0xf000f000f000f00ULL);
3476 const uint16_t *idx1 = (const uint16_t *)(&packed_idx1);
3477 const uint16_t *idx2 = (const uint16_t *)(&packed_idx2);
3478 const __m256i q1b_1 = _mm256_set_epi64x(iq1s_grid[idx1[3]], iq1s_grid[idx1[2]], iq1s_grid[idx1[1]], iq1s_grid[idx1[0]]);
3479 const __m256i q1b_2 = _mm256_set_epi64x(iq1s_grid[idx2[3]], iq1s_grid[idx2[2]], iq1s_grid[idx2[1]], iq1s_grid[idx2[0]]);
3480
3481 // Convert signs to bytes 0x81 (negative) or 0x01 (positive)
3482 const uint64_t delta_sign = _pdep_u64(*(const uint32_t*)(qh) & 0x88888888, 0xf0f0f0f0f0f0f0f0ULL);
3483 const __m256i delta1 = _mm256_or_si256(mone8, _mm256_cvtepi8_epi64(_mm_set1_epi32(delta_sign)));
3484 const __m256i delta2 = _mm256_or_si256(mone8, _mm256_cvtepi8_epi64(_mm_set1_epi32(delta_sign >> 32)));
3485#else
3486 const __m256i q1b_1 = _mm256_set_epi64x(
3487 iq1s_grid[qs[3] | (((uint16_t)qh[1] << 4) & 0x700)], iq1s_grid[qs[2] | (((uint16_t)qh[1] << 8) & 0x700)],
3488 iq1s_grid[qs[1] | (((uint16_t)qh[0] << 4) & 0x700)], iq1s_grid[qs[0] | (((uint16_t)qh[0] << 8) & 0x700)]
3489 );
3490 const __m256i q1b_2 = _mm256_set_epi64x(
3491 iq1s_grid[qs[7] | (((uint16_t)qh[3] << 4) & 0x700)], iq1s_grid[qs[6] | (((uint16_t)qh[3] << 8) & 0x700)],
3492 iq1s_grid[qs[5] | (((uint16_t)qh[2] << 4) & 0x700)], iq1s_grid[qs[4] | (((uint16_t)qh[2] << 8) & 0x700)]
3493 );
3494
3495 const __m256i delta1 = _mm256_set_epi64x(qh[1] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
3496 qh[1] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101,
3497 qh[0] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
3498 qh[0] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101);
3499 const __m256i delta2 = _mm256_set_epi64x(qh[3] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
3500 qh[3] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101,
3501 qh[2] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
3502 qh[2] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101);
3503#endif
3504 const __m256i q8b_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
3505 const __m256i q8b_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
3506
3507 const __m256i dot1 = mul_add_epi8(q1b_1, q8b_1);
3508 const __m256i dot2 = mul_add_epi8(q1b_2, q8b_2);
3509 const __m256i dot3 = _mm256_maddubs_epi16(mone8, _mm256_sign_epi8(q8b_1, delta1));
3510 const __m256i dot4 = _mm256_maddubs_epi16(mone8, _mm256_sign_epi8(q8b_2, delta2));
3511
3512 __m256i scale1 = _mm256_shuffle_epi8(scales, scales_idx1);
3513 __m256i scale2 = _mm256_shuffle_epi8(scales, scales_idx2);
3514
3515 scales_idx1 = _mm256_add_epi8(scales_idx1, mtwo8);
3516 scales_idx2 = _mm256_add_epi8(scales_idx2, mtwo8);
3517
3518 const __m256i p1 = _mm256_madd_epi16(dot1, scale1);
3519 const __m256i p2 = _mm256_madd_epi16(dot2, scale2);
3520 const __m256i p3 = _mm256_madd_epi16(dot3, scale1);
3521 const __m256i p4 = _mm256_madd_epi16(dot4, scale2);
3522
3523 sumi1 = _mm256_add_epi32(sumi1, _mm256_add_epi32(p1, p2));
3524 sumi2 = _mm256_add_epi32(sumi2, _mm256_add_epi32(p3, p4));
3525
3526 qs += 8; qh += 4;
3527 }
3528
3529 const __m256 d = _mm256_set1_ps(y[i].d * GGML_CPU_FP16_TO_FP32(scale.f16));
3530
3531 accum1 = _mm256_fmadd_ps(d, _mm256_cvtepi32_ps(sumi1), accum1);
3532 accum2 = _mm256_fmadd_ps(d, _mm256_cvtepi32_ps(sumi2), accum2);
3533 }
3534
3535 *s = hsum_float_8(accum1) + IQ1M_DELTA * hsum_float_8(accum2);
3536
3537#elif defined __AVX__
3538 const __m128i mask = _mm_set1_epi16(0x7);
3539 const __m128i mone = _mm_set1_epi16(1);
3540
3541 __m256 accum1 = _mm256_setzero_ps();
3542 __m256 accum2 = _mm256_setzero_ps();
3543 for (int i = 0; i < nb; ++i) {
3544
3545 const int8_t * q8 = y[i].qs;
3546 const uint8_t * qs = x[i].qs;
3547 const uint8_t * qh = x[i].qh;
3548 const uint16_t * sc = (const uint16_t *)x[i].scales;
3549
3550 scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
3551
3552 __m128i sumi1_0 = _mm_setzero_si128();
3553 __m128i sumi1_1 = _mm_setzero_si128();
3554 __m128i sumi2_0 = _mm_setzero_si128();
3555 __m128i sumi2_1 = _mm_setzero_si128();
3556 for (int ib = 0; ib < QK_K/32; ib += 2) {
3557 const __m128i q1b_1_0 = _mm_set_epi64x(
3558 iq1s_grid[qs[1] | (((uint16_t)qh[0] << 4) & 0x700)], iq1s_grid[qs[0] | (((uint16_t)qh[0] << 8) & 0x700)]);
3559 const __m128i q1b_1_1 = _mm_set_epi64x(
3560 iq1s_grid[qs[3] | (((uint16_t)qh[1] << 4) & 0x700)], iq1s_grid[qs[2] | (((uint16_t)qh[1] << 8) & 0x700)]);
3561 const __m128i q1b_2_0 = _mm_set_epi64x(
3562 iq1s_grid[qs[5] | (((uint16_t)qh[2] << 4) & 0x700)], iq1s_grid[qs[4] | (((uint16_t)qh[2] << 8) & 0x700)]);
3563 const __m128i q1b_2_1 = _mm_set_epi64x(
3564 iq1s_grid[qs[7] | (((uint16_t)qh[3] << 4) & 0x700)], iq1s_grid[qs[6] | (((uint16_t)qh[3] << 8) & 0x700)]);
3565 const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
3566 const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
3567 const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
3568 const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
3569
3570 const __m128i dot1_0 = mul_add_epi8_sse(q1b_1_0, q8b_1_0);
3571 const __m128i dot1_1 = mul_add_epi8_sse(q1b_1_1, q8b_1_1);
3572 const __m128i dot2_0 = mul_add_epi8_sse(q1b_2_0, q8b_2_0);
3573 const __m128i dot2_1 = mul_add_epi8_sse(q1b_2_1, q8b_2_1);
3574
3575 const __m128i delta1_0 = _mm_set_epi64x(qh[0] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
3576 qh[0] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101);
3577 const __m128i delta1_1 = _mm_set_epi64x(qh[1] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
3578 qh[1] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101);
3579 const __m128i delta2_0 = _mm_set_epi64x(qh[2] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
3580 qh[2] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101);
3581 const __m128i delta2_1 = _mm_set_epi64x(qh[3] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
3582 qh[3] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101);
3583
3584 const __m128i dot3_0 = mul_add_epi8_sse(delta1_0, q8b_1_0);
3585 const __m128i dot3_1 = mul_add_epi8_sse(delta1_1, q8b_1_1);
3586 const __m128i dot4_0 = mul_add_epi8_sse(delta2_0, q8b_2_0);
3587 const __m128i dot4_1 = mul_add_epi8_sse(delta2_1, q8b_2_1);
3588
3589 __m128i scale1_0 = _mm_set1_epi16(sc[ib/2] >> 0);
3590 __m128i scale1_1 = _mm_set1_epi16(sc[ib/2] >> 3);
3591 __m128i scale2_0 = _mm_set1_epi16(sc[ib/2] >> 6);
3592 __m128i scale2_1 = _mm_set1_epi16(sc[ib/2] >> 9);
3593
3594 scale1_0 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale1_0, mask), 1), mone);
3595 scale1_1 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale1_1, mask), 1), mone);
3596 scale2_0 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale2_0, mask), 1), mone);
3597 scale2_1 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale2_1, mask), 1), mone);
3598 const __m128i p1_0 = _mm_madd_epi16(dot1_0, scale1_0);
3599 const __m128i p1_1 = _mm_madd_epi16(dot1_1, scale1_1);
3600 const __m128i p2_0 = _mm_madd_epi16(dot2_0, scale2_0);
3601 const __m128i p2_1 = _mm_madd_epi16(dot2_1, scale2_1);
3602 const __m128i p3_0 = _mm_madd_epi16(dot3_0, scale1_0);
3603 const __m128i p3_1 = _mm_madd_epi16(dot3_1, scale1_1);
3604 const __m128i p4_0 = _mm_madd_epi16(dot4_0, scale2_0);
3605 const __m128i p4_1 = _mm_madd_epi16(dot4_1, scale2_1);
3606
3607 sumi1_0 = _mm_add_epi32(sumi1_0, _mm_add_epi32(p1_0, p2_0));
3608 sumi1_1 = _mm_add_epi32(sumi1_1, _mm_add_epi32(p1_1, p2_1));
3609 sumi2_0 = _mm_add_epi32(sumi2_0, _mm_add_epi32(p3_0, p4_0));
3610 sumi2_1 = _mm_add_epi32(sumi2_1, _mm_add_epi32(p3_1, p4_1));
3611
3612 qs += 8; qh += 4;
3613 }
3614
3615 const __m256 d = _mm256_set1_ps(y[i].d * GGML_CPU_FP16_TO_FP32(scale.f16));
3616
3617 accum1 = _mm256_add_ps(_mm256_mul_ps(d, _mm256_cvtepi32_ps(MM256_SET_M128I(sumi1_1, sumi1_0))), accum1);
3618 accum2 = _mm256_add_ps(_mm256_mul_ps(d, _mm256_cvtepi32_ps(MM256_SET_M128I(sumi2_1, sumi2_0))), accum2);
3619 }
3620
3621 *s = hsum_float_8(accum1) + IQ1M_DELTA * hsum_float_8(accum2);
3622
3623#else
3624 UNUSED(x);
3625 UNUSED(y);
3626 UNUSED(nb);
3627 UNUSED(scale);
3628 ggml_vec_dot_iq1_m_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
3629#endif
3630}
3631
3632void ggml_vec_dot_iq4_nl_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
3633 assert(nrc == 1);
3634 UNUSED(nrc);
3635 UNUSED(bx);
3636 UNUSED(by);
3637 UNUSED(bs);
3638 assert(n % QK4_NL == 0);
3639 static_assert(QK4_NL == QK8_0, "QK4_NL and QK8_0 must be the same");
3640
3641 const block_iq4_nl * GGML_RESTRICT x = vx;
3642 const block_q8_0 * GGML_RESTRICT y = vy;
3643
3644 const int nb = n / QK4_NL;
3645
3646 int ib = 0;
3647 float sumf = 0;
3648
3649#if defined __AVX2__
3650
3651 const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl);
3652 const __m128i m4b = _mm_set1_epi8(0x0f);
3653 const __m256i mone = _mm256_set1_epi16(1);
3654
3655 __m256 accum1 = _mm256_setzero_ps();
3656 __m256 accum2 = _mm256_setzero_ps();
3657 for (; ib + 1 < nb; ib += 2) {
3658 const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)x[ib + 0].qs);
3659 const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)x[ib + 1].qs);
3660 const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)y[ib + 0].qs);
3661 const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)y[ib + 1].qs);
3662 const __m256i q4b_1 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)),
3663 _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b)));
3664 const __m256i q4b_2 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)),
3665 _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b)));
3666 const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1);
3667 const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
3668 const __m256i p_1 = _mm256_madd_epi16(p16_1, mone);
3669 const __m256i p_2 = _mm256_madd_epi16(p16_2, mone);
3670 accum1 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 0].d)*GGML_CPU_FP16_TO_FP32(x[ib + 0].d)),
3671 _mm256_cvtepi32_ps(p_1), accum1);
3672 accum2 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 1].d)*GGML_CPU_FP16_TO_FP32(x[ib + 1].d)),
3673 _mm256_cvtepi32_ps(p_2), accum2);
3674 }
3675
3676 sumf = hsum_float_8(_mm256_add_ps(accum1, accum2));
3677
3678#elif defined __AVX__
3679 const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl);
3680 const __m128i m4b = _mm_set1_epi8(0x0f);
3681
3682 __m256 accum = _mm256_setzero_ps();
3683 for (; ib + 1 < nb; ib += 2) {
3684 const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)x[ib + 0].qs);
3685 const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs);
3686 const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs);
3687 const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs + 1);
3688 const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs);
3689 const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs + 1);
3690
3691 const __m128i q4b_1_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b));
3692 const __m128i q4b_1_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b));
3693 const __m128i q4b_2_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b));
3694 const __m128i q4b_2_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b));
3695
3696 const __m256 p = mul_sum_i8_quad_float(q4b_1_0, q4b_1_1, q4b_2_0, q4b_2_1, q8b_1_0, q8b_1_1, q8b_2_0, q8b_2_1);
3697 const __m256 deltas = quad_fp16_delta_float(x[ib].d, y[ib].d, x[ib + 1].d, y[ib + 1].d);
3698 accum = _mm256_add_ps(_mm256_mul_ps(deltas, p), accum);
3699 }
3700
3701 sumf = hsum_float_8(accum);
3702
3703#endif
3704 for (; ib < nb; ++ib) {
3705 const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_CPU_FP16_TO_FP32(x[ib].d);
3706 int sumi1 = 0, sumi2 = 0;
3707 for (int j = 0; j < QK4_NL/2; ++j) {
3708 sumi1 += y[ib].qs[j+ 0] * kvalues_iq4nl[x[ib].qs[j] & 0xf];
3709 sumi2 += y[ib].qs[j+QK4_NL/2] * kvalues_iq4nl[x[ib].qs[j] >> 4];
3710 }
3711 sumf += d * (sumi1 + sumi2);
3712 }
3713 *s = sumf;
3714}
3715
3716void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
3717 assert(nrc == 1);
3718 UNUSED(nrc);
3719 UNUSED(bx);
3720 UNUSED(by);
3721 UNUSED(bs);
3722 assert(n % QK_K == 0);
3723
3724 const block_iq4_xs * GGML_RESTRICT x = vx;
3725 const block_q8_K * GGML_RESTRICT y = vy;
3726
3727 const int nb = n / QK_K;
3728
3729#if defined __AVX2__
3730
3731 const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl);
3732 const __m128i m4b = _mm_set1_epi8(0x0f);
3733
3734 __m256 accum = _mm256_setzero_ps();
3735 for (int ibl = 0; ibl < nb; ++ibl) {
3736 const uint8_t * qs = x[ibl].qs;
3737 const int8_t * q8 = y[ibl].qs;
3738 uint16_t sh = x[ibl].scales_h;
3739 __m256i sumi1 = _mm256_setzero_si256();
3740 __m256i sumi2 = _mm256_setzero_si256();
3741 for (int ib = 0; ib < QK_K/32; ib += 2) {
3742 const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)qs); qs += 16;
3743 const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)qs); qs += 16;
3744 const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
3745 const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
3746 const __m256i q4b_1 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)),
3747 _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b)));
3748 const __m256i q4b_2 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)),
3749 _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b)));
3750 const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1);
3751 const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
3752 const int16_t ls1 = ((x[ibl].scales_l[ib/2] & 0xf) | ((sh << 4) & 0x30)) - 32;
3753 const int16_t ls2 = ((x[ibl].scales_l[ib/2] >> 4) | ((sh << 2) & 0x30)) - 32;
3754 sh >>= 4;
3755 const __m256i p_1 = _mm256_madd_epi16(p16_1, _mm256_set1_epi16(ls1));
3756 const __m256i p_2 = _mm256_madd_epi16(p16_2, _mm256_set1_epi16(ls2));
3757 sumi1 = _mm256_add_epi32(p_1, sumi1);
3758 sumi2 = _mm256_add_epi32(p_2, sumi2);
3759 }
3760 accum = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(x[ibl].d)*y[ibl].d),
3761 _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accum);
3762 }
3763
3764 *s = hsum_float_8(accum);
3765
3766#elif defined __AVX__
3767 const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl);
3768 const __m128i m4b = _mm_set1_epi8(0x0f);
3769
3770 __m256 accum = _mm256_setzero_ps();
3771 for (int ibl = 0; ibl < nb; ++ibl) {
3772 const uint8_t * qs = x[ibl].qs;
3773 const int8_t * q8 = y[ibl].qs;
3774 uint16_t sh = x[ibl].scales_h;
3775 __m128i sumi1_0 = _mm_setzero_si128();
3776 __m128i sumi1_1 = _mm_setzero_si128();
3777 __m128i sumi2_0 = _mm_setzero_si128();
3778 __m128i sumi2_1 = _mm_setzero_si128();
3779 for (int ib = 0; ib < QK_K/32; ib += 2) {
3780 const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)qs); qs += 16;
3781 const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)qs); qs += 16;
3782 const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
3783 const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
3784 const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
3785 const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
3786 const __m128i q4b_1_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b));
3787 const __m128i q4b_1_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b));
3788 const __m128i q4b_2_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b));
3789 const __m128i q4b_2_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b));
3790 const __m128i p16_1_0 = mul_add_epi8_sse(q4b_1_0, q8b_1_0);
3791 const __m128i p16_1_1 = mul_add_epi8_sse(q4b_1_1, q8b_1_1);
3792 const __m128i p16_2_0 = mul_add_epi8_sse(q4b_2_0, q8b_2_0);
3793 const __m128i p16_2_1 = mul_add_epi8_sse(q4b_2_1, q8b_2_1);
3794 const int16_t ls1 = ((x[ibl].scales_l[ib/2] & 0xf) | ((sh << 4) & 0x30)) - 32;
3795 const int16_t ls2 = ((x[ibl].scales_l[ib/2] >> 4) | ((sh << 2) & 0x30)) - 32;
3796 sh >>= 4;
3797 const __m128i p_1_0 = _mm_madd_epi16(p16_1_0, _mm_set1_epi16(ls1));
3798 const __m128i p_1_1 = _mm_madd_epi16(p16_1_1, _mm_set1_epi16(ls1));
3799 const __m128i p_2_0 = _mm_madd_epi16(p16_2_0, _mm_set1_epi16(ls2));
3800 const __m128i p_2_1 = _mm_madd_epi16(p16_2_1, _mm_set1_epi16(ls2));
3801 sumi1_0 = _mm_add_epi32(p_1_0, sumi1_0);
3802 sumi1_1 = _mm_add_epi32(p_1_1, sumi1_1);
3803 sumi2_0 = _mm_add_epi32(p_2_0, sumi2_0);
3804 sumi2_1 = _mm_add_epi32(p_2_1, sumi2_1);
3805 }
3806 __m128i sumi12_0 = _mm_add_epi32(sumi1_0, sumi2_0);
3807 __m128i sumi12_1 = _mm_add_epi32(sumi1_1, sumi2_1);
3808 accum = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(x[ibl].d)*y[ibl].d),
3809 _mm256_cvtepi32_ps(MM256_SET_M128I(sumi12_1, sumi12_0))), accum);
3810 }
3811
3812 *s = hsum_float_8(accum);
3813
3814#else
3815 UNUSED(x);
3816 UNUSED(y);
3817 UNUSED(nb);
3818 ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
3819#endif
3820}