Diffstat (limited to 'llama.cpp/ggml/src/ggml-cpu/arch/loongarch')
-rw-r--r-- llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c | 2159
1 files changed, 2159 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c b/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c
new file mode 100644
index 0000000..f531e91
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c
@@ -0,0 +1,2159 @@
1#define GGML_COMMON_IMPL_C
2#include "ggml-common.h"
3#include "ggml-quants.h"
4#include "ggml-impl.h"
5#include "ggml-cpu.h"
6#include "simd-mappings.h"
7
8#include "../../quants.h"
9#include "../../ggml-cpu-impl.h"
10
11#include <math.h>
12#include <string.h>
13#include <assert.h>
14#include <float.h>
15#include <stdlib.h> // for qsort
16#include <stdio.h> // for GGML_ASSERT
17
18#define GROUP_MAX_EPS 1e-15f
19#define GROUP_MAX_EPS_IQ3_XXS 1e-8f
20#define GROUP_MAX_EPS_IQ2_S 1e-8f
21#define GROUP_MAX_EPS_IQ1_M 1e-7f
22#define GROUP_MAX_EPS_IQ1_S 1e-12f
23
24#define UNUSED GGML_UNUSED
25
26#if defined(__loongarch_sx)
27
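// SSE-style helpers implemented with LSX intrinsics: saturating packs,
// maddubs/madd-style widening multiply-adds, a pshufb-like byte shuffle and
// horizontal adds. They are shared by the quantization and dot-product
// kernels below.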
28static __m128i lsx_packs_w(__m128i a, __m128i b) {
29 __m128i tmp, tmp1;
30 tmp = __lsx_vsat_w(a, 15);
31 tmp1 = __lsx_vsat_w(b, 15);
32 return __lsx_vpickev_h(tmp1, tmp);
33}
34
35static __m128i lsx_packs_h(__m128i a, __m128i b) {
36 __m128i tmp, tmp1;
37 tmp = __lsx_vsat_h(a, 7);
38 tmp1 = __lsx_vsat_h(b, 7);
39 return __lsx_vpickev_b(tmp1, tmp);
40}
41
42static __m128i lsx_packus_h(__m128i a, __m128i b) {
43 __m128i tmp, tmp1;
44 tmp = __lsx_vsat_hu(a, 7);
45 tmp1 = __lsx_vsat_hu(b, 7);
46 return __lsx_vpickev_b(tmp1, tmp);
47}
48
49static __m128i lsx_maddubs_h(__m128i a, __m128i b) {
50 __m128i tmp1, tmp2;
51 tmp1 = __lsx_vmulwev_h_b(a, b);
52 tmp2 = __lsx_vmulwod_h_b(a, b);
53 return __lsx_vsadd_h(tmp1, tmp2);
54}
55
56static __m128i lsx_madd_h(__m128i a, __m128i b) {
57 __m128i tmp1, tmp2;
58 tmp1 = __lsx_vmulwev_w_h(a, b);
59 tmp2 = __lsx_vmulwod_w_h(a, b);
60 return __lsx_vadd_w(tmp1, tmp2);
61}
62
63static __m128i lsx_set_w(int32_t a, int32_t b, int32_t c, int32_t d) {
64 v4i32 __ret = {d, c, b, a};
65 return (__m128i)__ret;
66}
67
68static __m128i lsx_shuffle_b(__m128i a, __m128i b) {
69 __m128i mask_f, zero, tmp0, tmp2, mask;
70 int f = 0x8f;
71 mask_f = __lsx_vreplgr2vr_b(f);
72 zero = __lsx_vldi(0);
73 tmp0 = __lsx_vand_v(b, mask_f); // keep the low 4 index bits and the sign bit
74 tmp0 = __lsx_vori_b(tmp0, 0x10); // set bit 4 so valid indices fall in 16..31 and select from 'a' in vshuf_b
75 mask = __lsx_vsle_b(zero, tmp0); // all-ones mask for lanes whose sign bit is clear
76 tmp2 = __lsx_vand_v(tmp0, mask); // force indices of sign-set lanes to 0, which picks a byte from 'zero'
77 return __lsx_vshuf_b(a, zero, tmp2);
78}
79
80static __m128i lsx_hadd_h(__m128i a, __m128i b) {
81 __m128i tmp1 = __lsx_vpickev_h(b, a);
82 __m128i tmp2 = __lsx_vpickod_h(b, a);
83 return __lsx_vadd_h(tmp1, tmp2);
84}
85
86static __m128i lsx_hadd_w(__m128i a, __m128i b) {
87 __m128i tmp1 = __lsx_vpickev_w(b, a);
88 __m128i tmp2 = __lsx_vpickod_w(b, a);
89 return __lsx_vadd_w(tmp1, tmp2);
90}
91
92static __m128 lsx_hadd_s(__m128 a, __m128 b) {
93 __m128 tmp1 = (__m128)__lsx_vpickev_w((__m128i)b, (__m128i)a);
94 __m128 tmp2 = (__m128)__lsx_vpickod_w((__m128i)b, (__m128i)a);
95
96 return __lsx_vfadd_s(tmp1, tmp2);
97}
98
99static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 c, const __m128 d) {
100 __m128 res_0 =lsx_hadd_s(a, b);
101 __m128 res_1 =lsx_hadd_s(c, d);
102 __m128 res =lsx_hadd_s(res_0, res_1);
103 res =lsx_hadd_s(res, res);
104 res =lsx_hadd_s(res, res);
105
106 return ((v4f32)res)[0];
107}
108
109// multiply int8_t, add results pairwise twice
110static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
111 // Get absolute values of x vectors
112 const __m128i ax = __lsx_vsigncov_b(x, x);
113 // Sign the values of the y vectors
114 const __m128i sy = __lsx_vsigncov_b(x, y);
115 // Perform multiplication and create 16-bit values
116 const __m128i dot = lsx_maddubs_h(ax, sy);
117 const __m128i ones = __lsx_vreplgr2vr_h(1);
118 return lsx_madd_h(ones, dot);
119}
120#endif
121
122#if defined(__loongarch_asx)
123
124#ifdef __clang__
125#define VREGS_PREFIX "$vr"
126#define XREGS_PREFIX "$xr"
127#else // GCC
128#define VREGS_PREFIX "$f"
129#define XREGS_PREFIX "$f"
130#endif
131#define __ALL_REGS "0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31"
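// The conversion helpers below use .irp/.ifc assembler loops to discover, at
// assembly time, which physical register the compiler allocated for each
// operand, so the same register can be addressed under both its 128-bit
// ($vr/LSX) and 256-bit ($xr/LASX) names without a dedicated intrinsic.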
132// Convert __m128i to __m256i
133static inline __m256i ____m256i(__m128i in) {
134 __m256i out = __lasx_xvldi(0);
135 __asm__ volatile (
136 ".irp i," __ALL_REGS "\n\t"
137 " .ifc %[out], " XREGS_PREFIX"\\i \n\t"
138 " .irp j," __ALL_REGS "\n\t"
139 " .ifc %[in], " VREGS_PREFIX "\\j \n\t"
140 " xvpermi.q $xr\\i, $xr\\j, 0x20 \n\t"
141 " .endif \n\t"
142 " .endr \n\t"
143 " .endif \n\t"
144 ".endr \n\t"
145 : [out] "+f" (out) : [in] "f" (in)
146 );
147 return out;
148}
149// Convert two __m128i to __m256i
150static inline __m256i lasx_set_q(__m128i inhi, __m128i inlo) {
151 __m256i out;
152 __asm__ volatile (
153 ".irp i," __ALL_REGS "\n\t"
154 " .ifc %[hi], " VREGS_PREFIX "\\i \n\t"
155 " .irp j," __ALL_REGS "\n\t"
156 " .ifc %[lo], " VREGS_PREFIX "\\j \n\t"
157 " xvpermi.q $xr\\i, $xr\\j, 0x20 \n\t"
158 " .endif \n\t"
159 " .endr \n\t"
160 " .endif \n\t"
161 ".endr \n\t"
162 ".ifnc %[out], %[hi] \n\t"
163 ".irp i," __ALL_REGS "\n\t"
164 " .ifc %[out], " XREGS_PREFIX "\\i \n\t"
165 " .irp j," __ALL_REGS "\n\t"
166 " .ifc %[hi], " VREGS_PREFIX "\\j \n\t"
167 " xvori.b $xr\\i, $xr\\j, 0 \n\t"
168 " .endif \n\t"
169 " .endr \n\t"
170 " .endif \n\t"
171 ".endr \n\t"
172 ".endif \n\t"
173 : [out] "=f" (out), [hi] "+f" (inhi)
174 : [lo] "f" (inlo)
175 );
176 return out;
177}
178// Convert __m256i low part to __m128i
179static inline __m128i lasx_extracti128_lo(__m256i in) {
180 __m128i out;
181 __asm__ volatile (
182 ".ifnc %[out], %[in] \n\t"
183 ".irp i," __ALL_REGS "\n\t"
184 " .ifc %[out], " VREGS_PREFIX "\\i \n\t"
185 " .irp j," __ALL_REGS "\n\t"
186 " .ifc %[in], " XREGS_PREFIX "\\j \n\t"
187 " vori.b $vr\\i, $vr\\j, 0 \n\t"
188 " .endif \n\t"
189 " .endr \n\t"
190 " .endif \n\t"
191 ".endr \n\t"
192 ".endif \n\t"
193 : [out] "=f" (out) : [in] "f" (in)
194 );
195 return out;
196}
197// Convert __m256i high part to __m128i
198static inline __m128i lasx_extracti128_hi(__m256i in) {
199 __m128i out;
200 __asm__ volatile (
201 ".irp i," __ALL_REGS "\n\t"
202 " .ifc %[out], " VREGS_PREFIX "\\i \n\t"
203 " .irp j," __ALL_REGS "\n\t"
204 " .ifc %[in], " XREGS_PREFIX "\\j \n\t"
205 " xvpermi.q $xr\\i, $xr\\j, 0x11 \n\t"
206 " .endif \n\t"
207 " .endr \n\t"
208 " .endif \n\t"
209 ".endr \n\t"
210 : [out] "=f" (out) : [in] "f" (in)
211 );
212 return out;
213}
214
215static __m256i lasx_set_w(int e7, int e6, int e5, int e4, int e3, int e2, int e1, int e0) {
216 v8i32 __ret = {e0, e1, e2, e3, e4, e5, e6, e7};
217 return (__m256i)__ret;
218}
219
220static __m256i lasx_set_d(int64_t a, int64_t b, int64_t c, int64_t d) {
221 v4i64 __ret = {d, c, b, a};
222 return (__m256i)__ret;
223}
224
225static __m256i lasx_insertf128( __m128i x, __m128i y) {
226 return lasx_set_q(x, y);
227}
228
229static __m256i lasx_shuffle_b(__m256i a, __m256i b) {
230 __m256i mask_f, zero, tmp0, tmp2, mask;
231 int f = 0x8f;
232 mask_f = __lasx_xvreplgr2vr_b(f);
233 zero = __lasx_xvldi(0);
234 tmp0 = __lasx_xvand_v(b, mask_f); // keep the low 4 index bits and the sign bit
235 tmp0 = __lasx_xvori_b(tmp0, 0x10); // set bit 4 so valid indices fall in 16..31 and select from 'a' in xvshuf_b
236 mask = __lasx_xvsle_b(zero, tmp0); // all-ones mask for lanes whose sign bit is clear
237 tmp2 = __lasx_xvand_v(tmp0, mask); // force indices of sign-set lanes to 0, which picks a byte from 'zero'
238 return __lasx_xvshuf_b(a, zero, tmp2);
239}
240
241static __m256i lasx_extu8_16(__m128i a) {
242 return __lasx_vext2xv_hu_bu(____m256i(a));
243}
244
245static __m256i lasx_ext8_16(__m128i a) {
246 return __lasx_vext2xv_h_b(____m256i(a));
247}
248
249static __m256i lasx_ext16_32(__m128i a) {
250 return __lasx_vext2xv_w_h(____m256i(a));
251}
252
253static __m128i lasx_extracti128( __m256i a, int pos) {
254 __m128i ret;
255 if( pos == 0)
256 {
257 ret = lasx_extracti128_lo(a);
258 } else {
259 ret = lasx_extracti128_hi(a);
260 }
261 return ret;
262}
263
264static __m128 lasx_extractf128( __m256 a, int pos) {
265 __m128 ret;
266 if( pos == 0)
267 {
268 ret = (__m128)lasx_extracti128_lo((__m256i)a);
269 } else {
270 ret = (__m128)lasx_extracti128_hi((__m256i)a);
271 }
272 return ret;
273}
274
275static __m256i lasx_maddubs_h(__m256i a, __m256i b) {
276 __m256i tmp1, tmp2;
277 tmp1 = __lasx_xvmulwev_h_b(a, b);
278 tmp2 = __lasx_xvmulwod_h_b(a, b);
279 return __lasx_xvsadd_h(tmp1, tmp2);
280}
281
282static __m256i lasx_madd_h(__m256i a, __m256i b) {
283 __m256i tmp1, tmp2;
284 tmp1 = __lasx_xvmulwev_w_h(a, b);
285 tmp2 = __lasx_xvmulwod_w_h(a, b);
286 return __lasx_xvadd_w(tmp1, tmp2);
287}
288
289static __m256i lasx_packs_w(__m256i a, __m256i b) {
290 __m256i tmp, tmp1;
291 tmp = __lasx_xvsat_w(a, 15);
292 tmp1 = __lasx_xvsat_w(b, 15);
293 return __lasx_xvpickev_h(tmp1, tmp);
294}
295
296static __m256i lasx_packs_h(__m256i a, __m256i b) {
297 __m256i tmp, tmp1;
298 tmp = __lasx_xvsat_h(a, 7);
299 tmp1 = __lasx_xvsat_h(b, 7);
300 return __lasx_xvpickev_b(tmp1, tmp);
301}
302
303static inline __m256i lasx_madd_h_b(__m256i a, __m256i b) {
304 __m256i tmp1, tmp2;
305 tmp1 = __lasx_xvmulwev_h_b(a, b);
306 tmp2 = __lasx_xvmulwod_h_b(a, b);
307 return __lasx_xvadd_h(tmp1, tmp2);
308}
309
310static inline __m256i lasx_xvrepl128vei_h(__m256i a, const unsigned int b) {
311 switch (b) {
312 case 0: return __lasx_xvrepl128vei_h(a, 0);
313 case 1: return __lasx_xvrepl128vei_h(a, 1);
314 case 2: return __lasx_xvrepl128vei_h(a, 2);
315 case 3: return __lasx_xvrepl128vei_h(a, 3);
316 case 4: return __lasx_xvrepl128vei_h(a, 4);
317 case 5: return __lasx_xvrepl128vei_h(a, 5);
318 case 6: return __lasx_xvrepl128vei_h(a, 6);
319 case 7: return __lasx_xvrepl128vei_h(a, 7);
320 default: __builtin_unreachable();
321 }
322}
323
324static inline __m256i lasx_xvandi_b_bit(__m256i a, const unsigned int b) {
325 switch (b) {
326 case 0: return __lasx_xvandi_b(a, 1 << 0);
327 case 1: return __lasx_xvandi_b(a, 1 << 1);
328 case 2: return __lasx_xvandi_b(a, 1 << 2);
329 case 3: return __lasx_xvandi_b(a, 1 << 3);
330 case 4: return __lasx_xvandi_b(a, 1 << 4);
331 case 5: return __lasx_xvandi_b(a, 1 << 5);
332 case 6: return __lasx_xvandi_b(a, 1 << 6);
333 case 7: return __lasx_xvandi_b(a, 1 << 7);
334 default: __builtin_unreachable();
335 }
336}
337
338// horizontally add 8 floats
339static inline float hsum_float_8(const __m256 x) {
340 __m128 res = lasx_extractf128(x, 1);
341 res = __lsx_vfadd_s(res, lasx_extractf128(x, 0));
342 res = __lsx_vfadd_s(res, (__m128)__lsx_vpickod_d((__m128i)res, (__m128i)res));
343 res = __lsx_vfadd_s(res, (__m128)__lsx_vinsgr2vr_w(__lsx_vldi(0), __lsx_vpickve2gr_w(res, 1), 0));
344 return ((v4f32)res)[0];
345}
346
347// horizontally add 8 int32_t
348static inline int hsum_i32_8(const __m256i a) {
349
350 __m256i tmp1 = __lasx_xvpermi_q(a, a, 0x11);
351 __m256i tmp2 = __lasx_xvpermi_q(a, a, 0x00);
352
353 __m128i tmp1_128 = lasx_extracti128_lo(tmp1);
354 __m128i tmp2_128 = lasx_extracti128_lo(tmp2);
355
356 __m128i sum128 = __lsx_vadd_w(tmp1_128, tmp2_128);
357
358 __m128i ev = __lsx_vpickev_w(sum128, sum128);
359 __m128i od = __lsx_vpickod_w(sum128, sum128);
360 __m128i sum64 = __lsx_vadd_w(ev, od);
361
362 int sum64_1, sum64_2;
363 sum64_1 = __lsx_vpickve2gr_w(sum64, 0);
364 sum64_2 = __lsx_vpickve2gr_w(sum64, 1);
365
366 return sum64_1 + sum64_2;
367}
368
369// horizontally add 4 int32_t
370static inline int hsum_i32_4(const __m128i a) {
371 __m128i ev = __lsx_vpickev_w(a, a);
372 __m128i od = __lsx_vpickod_w(a, a);
373 __m128i sum64 = __lsx_vadd_w(ev, od);
374
375 int sum64_1, sum64_2;
376 sum64_1 = __lsx_vpickve2gr_w(sum64, 0);
377 sum64_2 = __lsx_vpickve2gr_w(sum64, 1);
378
379 return sum64_1 + sum64_2;
380}
381
382// spread 32 bits to 32 bytes { 0x00, 0xFF }
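// The shuffle broadcasts source byte i to lanes 8*i..8*i+7; OR-ing lane k of
// each group with ~(1 << k) yields 0xFF exactly when bit k is set, and the
// final compare with -1 turns that into a full 0x00/0xFF byte mask.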
383static inline __m256i bytes_from_bits_32(const uint8_t * x) {
384
385 uint32_t x32;
386 memcpy(&x32, x, sizeof(uint32_t));
387 const __m256i shuf_mask = lasx_set_d(
388 0x0303030303030303, 0x0202020202020202,
389 0x0101010101010101, 0x0000000000000000);
390
391 __m256i bytes = lasx_shuffle_b(__lasx_xvreplgr2vr_w(x32), shuf_mask);
392 const __m256i bit_mask = __lasx_xvreplgr2vr_d(0x7fbfdfeff7fbfdfe);
393 bytes = __lasx_xvor_v(bytes, bit_mask);
394 return __lasx_xvseq_b(bytes, __lasx_xvreplgr2vr_d(-1));
395}
396
397// Unpack 32 4-bit fields into 32 bytes
398// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
399static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi) {
400 const __m128i lo = __lsx_vld((const __m128i *)rsi, 0);
401 __m128i hi = __lsx_vsrli_h(lo, 4);
402 return __lasx_xvandi_b(lasx_insertf128(hi, lo), 0xf);
403}
404
405// add int16_t pairwise and return as float vector
406static inline __m256 sum_i16_pairs_float(const __m256i x) {
407 __m256i v = __lasx_xvpackod_h(x, x);
408 __m256i summed_pairs = __lasx_xvaddwev_w_h(x, v);
409 return __lasx_xvffint_s_w(summed_pairs);
410}
411
412static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
413 // Perform multiplication and create 16-bit values
414 const __m256i dot = lasx_maddubs_h(ax, sy);
415 return sum_i16_pairs_float(dot);
416}
417
418// multiply int8_t, add results pairwise twice and return as float vector
419static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
420 const __m256i dot = lasx_madd_h_b(x, y);
421 return sum_i16_pairs_float(dot);
422}
423
424static inline __m128i packNibbles( __m256i bytes ) {
425 // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
426 const __m256i lowByte = __lasx_xvreplgr2vr_h(0xFF);
427 __m256i high = __lasx_xvandn_v(lowByte, bytes);
428 __m256i low = __lasx_xvand_v(lowByte, bytes);
429 high = __lasx_xvsrli_h(high, 4);
430 bytes = __lasx_xvor_v(low, high);
431 // Compress uint16_t lanes into bytes
432 __m128i *r0 = (__m128i *)&bytes;
433 __m256i tmp_h128 = __lasx_xvpermi_q(bytes, bytes, 0x11);
434 __m128i *r1 = (__m128i *)&tmp_h128;
435
436 __m128i zero = __lsx_vldi(0);
437 __m128i tmp, tmp2, tmp3;
438
439 tmp = __lsx_vmax_h(zero, *r0);
440 tmp2 = __lsx_vsat_hu(tmp, 7);
441
442 tmp = __lsx_vmax_h(zero, *r1);
443 tmp3 = __lsx_vsat_hu(tmp, 7);
444 return __lsx_vpickev_b(tmp3, tmp2);
445}
446#endif //__loongarch_asx
447
448void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
449 assert(QK8_0 == 32);
450 assert(k % QK8_0 == 0);
451 const int nb = k / QK8_0;
452
453 block_q8_0 * GGML_RESTRICT y = vy;
454
455#if defined(__loongarch_asx)
456 for (int i = 0; i < nb; i++) {
457 __m256 v0 = (__m256)__lasx_xvld( x , 0);
458 __m256 v1 = (__m256)__lasx_xvld( x , 32);
459 __m256 v2 = (__m256)__lasx_xvld( x , 64);
460 __m256 v3 = (__m256)__lasx_xvld( x , 96);
461 x += 32;
462
463 // Compute max(abs(e)) for the block
464 const __m256 sign_bit = __lasx_xvreplfr2vr_s( -0.0f );
465 __m256 max_abs = (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v0 );
466 max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v1 ) );
467 max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v2 ) );
468 max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v3 ) );
469
470 __m128 max4 = __lsx_vfmax_s( lasx_extractf128( max_abs, 1 ), lasx_extractf128( max_abs , 0) );
471 max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vpickod_d((__m128i) max4, (__m128i)max4 ) );
472 __m128 tmp = max4;
473 max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vinsgr2vr_w(tmp, __lsx_vpickve2gr_w( max4, 1 ), 0 ));
474 const float max_scalar = ((v4f32)max4)[0];
475
476 // Quantize these floats
477 const float d = max_scalar / 127.f;
478 y[i].d = GGML_CPU_FP32_TO_FP16(d);
479 const float id = ( max_scalar != 0.0f ) ? 127.f / max_scalar : 0.0f;
480 const __m256 mul = (__m256)__lasx_xvreplfr2vr_s( id );
481
482 // Apply the multiplier
483 v0 = __lasx_xvfmul_s( v0, mul );
484 v1 = __lasx_xvfmul_s( v1, mul );
485 v2 = __lasx_xvfmul_s( v2, mul );
486 v3 = __lasx_xvfmul_s( v3, mul );
487
488 // Round to nearest integer
489 __m256i i0 = __lasx_xvftintrne_w_s( v0 );
490 __m256i i1 = __lasx_xvftintrne_w_s( v1 );
491 __m256i i2 = __lasx_xvftintrne_w_s( v2 );
492 __m256i i3 = __lasx_xvftintrne_w_s( v3 );
493
494 __m128i ni0 = lasx_extracti128( i0, 0 );
495 __m128i ni1 = lasx_extracti128( i0, 1);
496 __m128i ni2 = lasx_extracti128( i1, 0);
497 __m128i ni3 = lasx_extracti128( i1, 1);
498 __m128i ni4 = lasx_extracti128( i2, 0);
499 __m128i ni5 = lasx_extracti128( i2, 1);
500 __m128i ni6 = lasx_extracti128( i3, 0);
501 __m128i ni7 = lasx_extracti128( i3, 1);
502
503 // Convert int32 to int16
504 ni0 = lsx_packs_w( ni0, ni1 );
505 ni2 = lsx_packs_w( ni2, ni3 );
506 ni4 = lsx_packs_w( ni4, ni5 );
507 ni6 = lsx_packs_w( ni6, ni7 );
508 // Convert int16 to int8
509 ni0 = lsx_packs_h( ni0, ni2 );
510 ni4 = lsx_packs_h( ni4, ni6 );
511
512 __lsx_vst(ni0, (__m128i *)(y[i].qs + 0), 0);
513 __lsx_vst(ni4, (__m128i *)(y[i].qs + 16), 0);
514
515 }
516#else
517 GGML_UNUSED(nb);
518 // scalar
519 quantize_row_q8_0_ref(x, y, k);
520#endif
521}
522
523void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
524 assert(k % QK8_1 == 0);
525 const int nb = k / QK8_1;
526
527 block_q8_1 * GGML_RESTRICT y = vy;
528
529#if defined(__loongarch_asx)
530 for (int i = 0; i < nb; i++) {
531 __m256 v0 = (__m256)__lasx_xvld( x , 0 );
532 __m256 v1 = (__m256)__lasx_xvld( x , 32 );
533 __m256 v2 = (__m256)__lasx_xvld( x , 64 );
534 __m256 v3 = (__m256)__lasx_xvld( x , 96 );
535 x += 32;
536
537 // Compute max(abs(e)) for the block
538 const __m256 sign_bit = __lasx_xvreplfr2vr_s( -0.0f );
539 __m256 max_abs = (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v0 );
540 max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v1 ) );
541 max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v2 ) );
542 max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v3 ) );
543
544 __m128 max4 = __lsx_vfmax_s( lasx_extractf128( max_abs, 1 ), lasx_extractf128( max_abs, 0) );
545 max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vpickod_d((__m128i) max4, (__m128i)max4 ) );
546 __m128 tmp = max4;
547 max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vextrins_w((__m128i)tmp, (__m128i)max4, 0x1 ));
548 const float max_scalar = ((v4f32)max4)[0];
549
550 // Quantize these floats
551 const float d = max_scalar / 127.f;
552 y[i].d = GGML_CPU_FP32_TO_FP16(d);
553 const float id = ( max_scalar != 0.0f ) ? 127.f / max_scalar : 0.0f;
554 const __m256 mul = __lasx_xvreplfr2vr_s( id );
555
556 // Apply the multiplier
557 v0 = __lasx_xvfmul_s( v0, mul );
558 v1 = __lasx_xvfmul_s( v1, mul );
559 v2 = __lasx_xvfmul_s( v2, mul );
560 v3 = __lasx_xvfmul_s( v3, mul );
561
562 // Round to nearest integer
563 __m256i i0 = __lasx_xvftintrne_w_s( v0 );
564 __m256i i1 = __lasx_xvftintrne_w_s( v1 );
565 __m256i i2 = __lasx_xvftintrne_w_s( v2 );
566 __m256i i3 = __lasx_xvftintrne_w_s( v3 );
567
568 __m128i ni0 = lasx_extracti128(i0, 0);
569 __m128i ni1 = lasx_extracti128( i0, 1);
570 __m128i ni2 = lasx_extracti128( i1, 0);
571 __m128i ni3 = lasx_extracti128( i1, 1);
572 __m128i ni4 = lasx_extracti128( i2, 0 );
573 __m128i ni5 = lasx_extracti128( i2, 1);
574 __m128i ni6 = lasx_extracti128( i3, 0);
575 __m128i ni7 = lasx_extracti128( i3, 1);
576
577 // Compute the sum of the quants and set y[i].s
578 const __m128i s0 = __lsx_vadd_w(__lsx_vadd_w(ni0, ni1), __lsx_vadd_w(ni2, ni3));
579 const __m128i s1 = __lsx_vadd_w(__lsx_vadd_w(ni4, ni5), __lsx_vadd_w(ni6, ni7));
580 y[i].s = GGML_CPU_FP32_TO_FP16(d * hsum_i32_4(__lsx_vadd_w(s0, s1)));
581
582 // Convert int32 to int16
583 ni0 = lsx_packs_w( ni0, ni1 );
584 ni2 = lsx_packs_w( ni2, ni3 );
585 ni4 = lsx_packs_w( ni4, ni5 );
586 ni6 = lsx_packs_w( ni6, ni7 );
587 // Convert int16 to int8
588 ni0 = lsx_packs_h( ni0, ni2 );
589 ni4 = lsx_packs_h( ni4, ni6 );
590
591 __lsx_vst(ni0, (__m128i *)(y[i].qs + 0), 0);
592 __lsx_vst(ni4, (__m128i *)(y[i].qs + 16), 0);
593 }
594#else
595 GGML_UNUSED(nb);
596 // scalar
597 quantize_row_q8_1_ref(x, y, k);
598#endif
599}
600
601
602//===================================== Dot products =================================
603
604//
605// Helper functions
606//
607
608#if defined(__loongarch_asx)
609// shuffles to pick the required scales in dot products
610static inline __m256i get_scale_shuffle_q3k(int i) {
611 static const uint8_t k_shuffle[128] = {
612 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
613 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
614 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,
615 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13, 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,
616 };
617 return __lasx_xvld((const __m256i*)k_shuffle + i, 0);
618}
619static inline __m256i get_scale_shuffle_k4(int i) {
620 static const uint8_t k_shuffle[256] = {
621 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,
622 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
623 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5,
624 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
625 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9,
626 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,
627 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,
628 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15
629 };
630 return __lasx_xvld((const __m256i*)k_shuffle + i, 0);
631}
632static inline __m128i get_scale_shuffle(int i) {
633 static const uint8_t k_shuffle[128] = {
634 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1,
635 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3,
636 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5,
637 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7,
638 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9,
639 10,10,10,10,10,10,10,10, 11,11,11,11,11,11,11,11,
640 12,12,12,12,12,12,12,12, 13,13,13,13,13,13,13,13,
641 14,14,14,14,14,14,14,14, 15,15,15,15,15,15,15,15
642 };
643 return __lsx_vld((const __m128i*)k_shuffle + i, 0);
644}
645#endif
646
647void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
648 const int qk = QK8_0;
649 const int nb = n / qk;
650
651 assert(n % qk == 0);
652 assert(nrc == 1);
653 UNUSED(nrc);
654 UNUSED(bx);
655 UNUSED(by);
656 UNUSED(bs);
657
658 const block_q4_0 * GGML_RESTRICT x = vx;
659 const block_q8_0 * GGML_RESTRICT y = vy;
660
661 int ib = 0;
662 float sumf = 0;
663
664#if defined(__loongarch_asx)
665 // Initialize accumulator with zeros
666 __m256 acc = (__m256)__lasx_xvldi(0);
667
668 // Main loop
669 for (; ib < nb; ++ib) {
670 /* Compute combined scale for the block */
671 const __m256 d = __lasx_xvreplfr2vr_s( GGML_CPU_FP16_TO_FP32(x[ib].d) * GGML_CPU_FP16_TO_FP32(y[ib].d) );
672
673 __m256i qx = bytes_from_nibbles_32(x[ib].qs);
674
675 // Now we have a vector with bytes in the [ 0 .. 15 ] interval. Offset them into the [ -8 .. +7 ] interval.
676 const __m256i off = __lasx_xvreplgr2vr_b( 8 );
677 qx = __lasx_xvsub_b( qx, off );
678
679 __m256i qy = __lasx_xvld((const __m256i *)y[ib].qs, 0);
680
681 const __m256 q = mul_sum_i8_pairs_float(qx, qy);
682
683 /* Multiply q with scale and accumulate */
684 acc = __lasx_xvfmadd_s( d, q, acc );
685 }
686
687 sumf = hsum_float_8(acc);
688
689#elif defined(__loongarch_sx)
690 // set constants
691 const __m128i low_mask = __lsx_vreplgr2vr_b(0xF);
692 const __m128i off = __lsx_vreplgr2vr_b(8);
693
694 // Initialize accumulator with zeros
695 __m128 acc_0 = (__m128)__lsx_vldi(0);
696 __m128 acc_1 = (__m128)__lsx_vldi(0);
697 __m128 acc_2 = (__m128)__lsx_vldi(0);
698 __m128 acc_3 = (__m128)__lsx_vldi(0);
699
700 for (; ib + 1 < nb; ib += 2) {
701
702 // Compute combined scale for the block 0 and 1
703 const float ft0 = GGML_CPU_FP16_TO_FP32(x[ib].d) * GGML_CPU_FP16_TO_FP32(y[ib].d);
704 const __m128 d_0_1 = (__m128)(v4f32){ft0, ft0, ft0, ft0};
705
706 const __m128i tmp_0_1 = __lsx_vld((const __m128i *)x[ib].qs, 0);
707
708 __m128i bx_0 = __lsx_vand_v(low_mask, tmp_0_1);
709 __m128i by_0 = __lsx_vld((const __m128i *)y[ib].qs, 0);
710 bx_0 = __lsx_vsub_b(bx_0, off);
711 const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0);
712
713 __m128i bx_1 = __lsx_vand_v(low_mask, __lsx_vsrli_d(tmp_0_1, 4));
714 __m128i by_1 = __lsx_vld((const __m128i *)(y[ib].qs + 16), 0);
715 bx_1 = __lsx_vsub_b(bx_1, off);
716 const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1);
717
718 // Compute combined scale for the block 2 and 3
719 const float ft1 = GGML_CPU_FP16_TO_FP32(x[ib + 1].d) * GGML_CPU_FP16_TO_FP32(y[ib + 1].d);
720 const __m128 d_2_3 = (__m128)(v4f32){ft1, ft1, ft1, ft1};
721
722 const __m128i tmp_2_3 = __lsx_vld((const __m128i *)x[ib + 1].qs, 0);
723
724 __m128i bx_2 = __lsx_vand_v(low_mask, tmp_2_3);
725 __m128i by_2 = __lsx_vld((const __m128i *)y[ib + 1].qs, 0);
726 bx_2 = __lsx_vsub_b(bx_2, off);
727 const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2);
728
729 __m128i bx_3 = __lsx_vand_v(low_mask, __lsx_vsrli_d(tmp_2_3, 4));
730 __m128i by_3 = __lsx_vld((const __m128i *)(y[ib + 1].qs + 16), 0);
731 bx_3 = __lsx_vsub_b(bx_3, off);
732 const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3);
733
734 // Convert int32_t to float
735 __m128 p0 = __lsx_vffint_s_w(i32_0);
736 __m128 p1 = __lsx_vffint_s_w(i32_1);
737 __m128 p2 = __lsx_vffint_s_w(i32_2);
738 __m128 p3 = __lsx_vffint_s_w(i32_3);
739
740 // Apply the scale
741 __m128 p0_d = __lsx_vfmul_s( d_0_1, p0 );
742 __m128 p1_d = __lsx_vfmul_s( d_0_1, p1 );
743 __m128 p2_d = __lsx_vfmul_s( d_2_3, p2 );
744 __m128 p3_d = __lsx_vfmul_s( d_2_3, p3 );
745
746 // Accumulate
747 acc_0 = __lsx_vfadd_s(p0_d, acc_0);
748 acc_1 = __lsx_vfadd_s(p1_d, acc_1);
749 acc_2 = __lsx_vfadd_s(p2_d, acc_2);
750 acc_3 = __lsx_vfadd_s(p3_d, acc_3);
751 }
752
753 sumf = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3);
754
755#endif
756 for (; ib < nb; ++ib) {
757 int sumi0 = 0;
758 int sumi1 = 0;
759
760 for (int j = 0; j < qk/2; ++j) {
761 const int v0 = (x[ib].qs[j] & 0x0F) - 8;
762 const int v1 = (x[ib].qs[j] >> 4) - 8;
763
764 sumi0 += (v0 * y[ib].qs[j]);
765 sumi1 += (v1 * y[ib].qs[j + qk/2]);
766 }
767
768 int sumi = sumi0 + sumi1;
769 sumf += sumi*GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d);
770 }
771
772 *s = sumf;
773}
774
775void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
776 const int qk = QK8_1;
777 const int nb = n / qk;
778
779 assert(n % qk == 0);
780 assert(nrc == 1);
781 UNUSED(nrc);
782 UNUSED(bx);
783 UNUSED(by);
784 UNUSED(bs);
785
786 const block_q4_1 * GGML_RESTRICT x = vx;
787 const block_q8_1 * GGML_RESTRICT y = vy;
788
789 int ib = 0;
790 float sumf = 0;
791
792#if defined(__loongarch_asx)
793 // Initialize accumulator with zeros
794 __m256 acc = (__m256)__lasx_xvldi(0);
795
796 float summs = 0;
797
798 // Main loop
799 for (; ib < nb; ++ib) {
800 const float d0 = GGML_CPU_FP16_TO_FP32(x[ib].d);
801 const float d1 = GGML_CPU_FP16_TO_FP32(y[ib].d);
802
803 summs += GGML_CPU_FP16_TO_FP32(x[ib].m) * GGML_CPU_FP16_TO_FP32(y[ib].s);
804
805 const __m256 d0v = __lasx_xvreplfr2vr_s( d0 );
806 const __m256 d1v = __lasx_xvreplfr2vr_s( d1 );
807
808 // Compute combined scales
809 const __m256 d0d1 = __lasx_xvfmul_s( d0v, d1v );
810
811 // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
812 const __m256i qx = bytes_from_nibbles_32(x[ib].qs);
813 const __m256i qy = __lasx_xvld( (const __m256i *)y[ib].qs, 0);
814
815 const __m256 xy = mul_sum_us8_pairs_float(qx, qy);
816
817 // Accumulate d0*d1*x*y
818 acc = __lasx_xvfmadd_s( d0d1, xy, acc );
819 }
820
821 sumf = hsum_float_8(acc) + summs;
822
823 *s = sumf;
824#else
825 UNUSED(nb);
826 UNUSED(x);
827 UNUSED(y);
828 UNUSED(ib);
829 UNUSED(sumf);
830 ggml_vec_dot_q4_1_q8_1_generic(n, s, bs, vx, bx, vy, by, nrc);
831#endif
832}
833
834void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
835 const int qk = QK8_0;
836 const int nb = n / qk;
837
838 int ib = 0;
839 float sumf = 0;
840
841 assert(n % qk == 0);
842 assert(qk == QK5_0);
843 assert(nrc == 1);
844 UNUSED(nrc);
845 UNUSED(bx);
846 UNUSED(by);
847 UNUSED(bs);
848
849 const block_q5_0 * GGML_RESTRICT x = vx;
850 const block_q8_0 * GGML_RESTRICT y = vy;
851
852#if defined(__loongarch_asx)
853 // Initialize accumulator with zeros
854 __m256 acc = (__m256)__lasx_xvldi(0);
855
856 // Main loop
857 for (; ib < nb; ++ib) {
858 /* Compute combined scale for the block */
859 const __m256 d = __lasx_xvreplfr2vr_s(GGML_CPU_FP16_TO_FP32(x[ib].d) * GGML_CPU_FP16_TO_FP32(y[ib].d)); //FIXME
860
861 __m256i qx = bytes_from_nibbles_32(x[ib].qs);
862 __m256i bxhi = bytes_from_bits_32(x[ib].qh);
863 bxhi = __lasx_xvandn_v(bxhi, __lasx_xvreplgr2vr_b((char)0xF0));
864 qx = __lasx_xvor_v(qx, bxhi);
865
866 __m256i qy = __lasx_xvld((const __m256i *)y[ib].qs, 0);
867
868 const __m256 q = mul_sum_i8_pairs_float(qx, qy);
869
870 /* Multiply q with scale and accumulate */
871 acc = __lasx_xvfmadd_s(d, q, acc);
872 }
873
874 sumf = hsum_float_8(acc);
875
876 *s = sumf;
877#else
878 UNUSED(nb);
879 UNUSED(ib);
880 UNUSED(sumf);
881 UNUSED(x);
882 UNUSED(y);
883 ggml_vec_dot_q5_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
884#endif
885}
886
887void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
888 const int qk = QK8_1;
889 const int nb = n / qk;
890
891 int ib = 0;
892 float sumf = 0;
893
894 assert(n % qk == 0);
895 assert(qk == QK5_1);
896 assert(nrc == 1);
897 UNUSED(nrc);
898 UNUSED(bx);
899 UNUSED(by);
900 UNUSED(bs);
901
902 const block_q5_1 * GGML_RESTRICT x = vx;
903 const block_q8_1 * GGML_RESTRICT y = vy;
904
905#if defined(__loongarch_asx)
906 // Initialize accumulator with zeros
907 __m256 acc = (__m256)__lasx_xvldi(0);
908
909 float summs = 0.0f;
910
911 // Main loop
912 for (; ib < nb; ++ib) {
913 const __m256 dx = __lasx_xvreplfr2vr_s(GGML_CPU_FP16_TO_FP32(x[ib].d));
914
915 summs += GGML_CPU_FP16_TO_FP32(x[ib].m) * GGML_CPU_FP16_TO_FP32(y[ib].s);
916
917 __m256i qx = bytes_from_nibbles_32(x[ib].qs);
918 __m256i bxhi = bytes_from_bits_32(x[ib].qh);
919 bxhi = __lasx_xvand_v(bxhi, __lasx_xvreplgr2vr_b(0x10));
920 qx = __lasx_xvor_v(qx, bxhi);
921
922 const __m256 dy = __lasx_xvreplfr2vr_s(GGML_CPU_FP16_TO_FP32(y[ib].d));
923 const __m256i qy = __lasx_xvld((const __m256i *)y[ib].qs, 0);
924
925 const __m256 q = mul_sum_us8_pairs_float(qx, qy);
926
927 acc = __lasx_xvfmadd_s(q, __lasx_xvfmul_s(dx, dy), acc);
928 }
929
930 sumf = hsum_float_8(acc) + summs;
931
932 *s = sumf;
933#else
934 UNUSED(nb);
935 UNUSED(ib);
936 UNUSED(sumf);
937 UNUSED(x);
938 UNUSED(y);
939 ggml_vec_dot_q5_1_q8_1_generic(n, s, bs, vx, bx, vy, by, nrc);
940#endif
941}
942
943void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
944 const int qk = QK8_0;
945 const int nb = n / qk;
946
947 assert(n % qk == 0);
948 assert(nrc == 1);
949 UNUSED(nrc);
950 UNUSED(bx);
951 UNUSED(by);
952 UNUSED(bs);
953
954 const block_q8_0 * GGML_RESTRICT x = vx;
955 const block_q8_0 * GGML_RESTRICT y = vy;
956
957 int ib = 0;
958 float sumf = 0;
959
960#if defined(__loongarch_asx)
961 // Initialize accumulator with zeros
962 __m256 acc = (__m256)__lasx_xvldi(0);
963
964 // Main loop
965 for (; ib < nb; ++ib) {
966 // Compute combined scale for the block
967 const __m256 d = __lasx_xvreplfr2vr_s(GGML_CPU_FP16_TO_FP32(x[ib].d) * GGML_CPU_FP16_TO_FP32(y[ib].d));
968 __m256i qx = __lasx_xvld((const __m256i *)x[ib].qs, 0);
969 __m256i qy = __lasx_xvld((const __m256i *)y[ib].qs, 0);
970
971 const __m256 q = mul_sum_i8_pairs_float(qx, qy);
972
973 // Multiply q with scale and accumulate
974 acc = __lasx_xvfmadd_s( d, q, acc );
975 }
976
977 sumf = hsum_float_8(acc);
978
979 *s = sumf;
980#else
981 UNUSED(nb);
982 UNUSED(ib);
983 UNUSED(sumf);
984 UNUSED(x);
985 UNUSED(y);
986 ggml_vec_dot_q8_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
987#endif
988}
989
990void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
991 assert(nrc == 1);
992 UNUSED(nrc);
993 UNUSED(bx);
994 UNUSED(by);
995 UNUSED(bs);
996
997 const block_q2_K * GGML_RESTRICT x = vx;
998 const block_q8_K * GGML_RESTRICT y = vy;
999
1000 const int nb = n / QK_K;
1001
1002#if defined __loongarch_asx
1003
1004 __m256 acc = (__m256)__lasx_xvldi(0);
1005
1006 for (int i = 0; i < nb; ++i) {
1007
1008 const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
1009 const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);
1010
1011 const uint8_t * GGML_RESTRICT q2 = x[i].qs;
1012 const int8_t * GGML_RESTRICT q8 = y[i].qs;
1013
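// each byte of x[i].scales holds a 4-bit sub-block scale in its low nibble
// and a 4-bit sub-block min in its high nibble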
1014 const __m128i mins_and_scales128 = __lsx_vld((const __m128i*)x[i].scales, 0);
1015 const __m128i scales128 = __lsx_vandi_b(mins_and_scales128, 0xf);
1016 const __m256i mins = lasx_ext8_16(__lsx_vsrli_b(mins_and_scales128, 4));
1017 const __m256i prod = lasx_madd_h(mins, __lasx_xvld((const __m256i*)y[i].bsums, 0));
1018
1019 acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(dmin), __lasx_xvffint_s_w(prod), acc);
1020
1021 const v16i8 shuffle_mask = {0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15};
1022 const __m256i scales_shuffled = lasx_ext8_16(__lsx_vshuf_b(scales128, scales128, (__m128i)shuffle_mask));
1023
1024 __m256i sumi = __lasx_xvldi(0);
1025
1026 for (int j = 0; j < QK_K/128; ++j) {
1027
1028 const __m256i q2bits = __lasx_xvld((const __m256i*)q2, 0); q2 += 32;
1029
1030 const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
1031 const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
1032 const __m256i q8_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
1033 const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
1034
1035 const __m256i q2_0 = __lasx_xvandi_b(q2bits, 3);
1036 const __m256i q2_1 = __lasx_xvandi_b(__lasx_xvsrli_b(q2bits, 2), 3);
1037 const __m256i q2_2 = __lasx_xvandi_b(__lasx_xvsrli_b(q2bits, 4), 3);
1038 const __m256i q2_3 = __lasx_xvsrli_b(q2bits, 6);
1039
1040 __m256i p0 = lasx_madd_h_b(q2_0, q8_0);
1041 __m256i p1 = lasx_madd_h_b(q2_1, q8_1);
1042 __m256i p2 = lasx_madd_h_b(q2_2, q8_2);
1043 __m256i p3 = lasx_madd_h_b(q2_3, q8_3);
1044
1045 p0 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 0), p0);
1046 p1 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 1), p1);
1047 p2 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 2), p2);
1048 p3 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 3), p3);
1049
1050 p0 = __lasx_xvadd_w(p0, p1);
1051 p2 = __lasx_xvadd_w(p2, p3);
1052
1053 sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p0, p2));
1054 }
1055
1056 acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), acc);
1057
1058 }
1059
1060 *s = hsum_float_8(acc);
1061
1062#else
1063 UNUSED(x);
1064 UNUSED(y);
1065 UNUSED(nb);
1066 ggml_vec_dot_q2_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
1067#endif
1068}
1069
1070void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
1071 assert(n % QK_K == 0);
1072 assert(nrc == 1);
1073 UNUSED(nrc);
1074 UNUSED(bx);
1075 UNUSED(by);
1076 UNUSED(bs);
1077
1078 const uint32_t kmask1 = 0x03030303;
1079 const uint32_t kmask2 = 0x0f0f0f0f;
1080
1081 const block_q3_K * GGML_RESTRICT x = vx;
1082 const block_q8_K * GGML_RESTRICT y = vy;
1083
1084 const int nb = n / QK_K;
1085
1086#if defined __loongarch_asx
1087
1088 const __m128i m32 = __lsx_vreplgr2vr_b(32);
1089
1090 __m256 acc = (__m256)__lasx_xvldi(0);
1091
1092 uint32_t aux[3];
1093
1094 for (int i = 0; i < nb; ++i) {
1095
1096 const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
1097 const uint8_t * GGML_RESTRICT q3 = x[i].qs;
1098 const int8_t * GGML_RESTRICT q8 = y[i].qs;
1099 // Set up scales
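// x[i].scales packs sixteen 6-bit scales into 12 bytes; rebuild them from the
// three 32-bit words and subtract 32 to recover the signed values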
1100 memcpy(aux, x[i].scales, 12);
1101 __m128i scales128 = lsx_set_w(
1102 ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4),
1103 ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4),
1104 (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4),
1105 (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4));
1106 scales128 = __lsx_vsub_b(scales128, m32);
1107
1108 const v16i8 shuffle_mask = {0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15};
1109 const __m256i scales_shuffled = lasx_ext8_16(__lsx_vshuf_b(scales128, scales128, (__m128i)shuffle_mask));
1110
1111 // high bit
1112 const __m256i hbits = __lasx_xvld((const __m256i*)x[i].hmask, 0);
1113
1114 // integer accumulator
1115 __m256i sumi = __lasx_xvldi(0);
1116
1117 for (int j = 0; j < QK_K/128; ++j) {
1118 // load low 2 bits
1119 const __m256i q3bits = __lasx_xvld((const __m256i*)q3, 0); q3 += 32;
1120
1121 // prepare low and high bits
1122 const __m256i q3l_0 = __lasx_xvandi_b(q3bits, 3);
1123 const __m256i q3l_1 = __lasx_xvandi_b(__lasx_xvsrli_b(q3bits, 2), 3);
1124 const __m256i q3l_2 = __lasx_xvandi_b(__lasx_xvsrli_b(q3bits, 4), 3);
1125 const __m256i q3l_3 = __lasx_xvsrli_b(q3bits, 6);
1126 const __m256i q3h_0 = __lasx_xvslli_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 4 * j + 0), 0), 2);
1127 const __m256i q3h_1 = __lasx_xvslli_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 4 * j + 1), 0), 2);
1128 const __m256i q3h_2 = __lasx_xvslli_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 4 * j + 2), 0), 2);
1129 const __m256i q3h_3 = __lasx_xvslli_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 4 * j + 3), 0), 2);
1130 const __m256i q3_0 = __lasx_xvor_v(q3h_0, q3l_0);
1131 const __m256i q3_1 = __lasx_xvor_v(q3h_1, q3l_1);
1132 const __m256i q3_2 = __lasx_xvor_v(q3h_2, q3l_2);
1133 const __m256i q3_3 = __lasx_xvor_v(q3h_3, q3l_3);
1134
1135 // load Q8 quants
1136 const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
1137 const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
1138 const __m256i q8_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
1139 const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
1140
1141 __m256i p16_0 = lasx_madd_h_b(q8_0, q3_0);
1142 __m256i p16_1 = lasx_madd_h_b(q8_1, q3_1);
1143 __m256i p16_2 = lasx_madd_h_b(q8_2, q3_2);
1144 __m256i p16_3 = lasx_madd_h_b(q8_3, q3_3);
1145
1146 // multiply with scales
1147 p16_0 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 0), p16_0);
1148 p16_1 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 1), p16_1);
1149 p16_2 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 2), p16_2);
1150 p16_3 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 3), p16_3);
1151
1152 // accumulate
1153 p16_0 = __lasx_xvadd_w(p16_0, p16_1);
1154 p16_2 = __lasx_xvadd_w(p16_2, p16_3);
1155 sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_2));
1156 }
1157 // multiply with block scale and accumulate
1158 acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), acc);
1159 }
1160
1161 *s = hsum_float_8(acc);
1162
1163#else
1164 UNUSED(kmask1);
1165 UNUSED(kmask2);
1166 UNUSED(x);
1167 UNUSED(y);
1168 UNUSED(nb);
1169 ggml_vec_dot_q3_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
1170#endif
1171}
1172
1173void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
1174 assert(n % QK_K == 0);
1175 assert(nrc == 1);
1176 UNUSED(nrc);
1177 UNUSED(bx);
1178 UNUSED(by);
1179 UNUSED(bs);
1180
1181 const block_q4_K * GGML_RESTRICT x = vx;
1182 const block_q8_K * GGML_RESTRICT y = vy;
1183
1184 const int nb = n / QK_K;
1185
1186 static const uint32_t kmask1 = 0x3f3f3f3f;
1187 static const uint32_t kmask2 = 0x0f0f0f0f;
1188 static const uint32_t kmask3 = 0x03030303;
1189
1190 uint32_t utmp[4];
1191
1192#if defined __loongarch_asx
1193
1194 __m256 acc = (__m256)__lasx_xvldi(0);
1195 __m128 acc_m = (__m128)__lsx_vldi(0);
1196
1197 for (int i = 0; i < nb; ++i) {
1198
1199 const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
1200 const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);
1201
1202 memcpy(utmp, x[i].scales, 12);
1203 utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
1204 const uint32_t uaux = utmp[1] & kmask1;
1205 utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
1206 utmp[2] = uaux;
1207 utmp[0] &= kmask1;
1208
1209 const uint8_t * GGML_RESTRICT q4 = x[i].qs;
1210 const int8_t * GGML_RESTRICT q8 = y[i].qs;
1211
1212 const __m128i mins_and_scales128 = lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0]);
1213 const __m128i mins128 = __lsx_vexth_h_b(mins_and_scales128);
1214 const __m128i scales128 = __lsx_vsllwil_h_b(mins_and_scales128, 0);
1215
1216 const __m256i q8sums = __lasx_xvld((const __m256i*)y[i].bsums, 0);
1217 const __m128i q8s = lsx_hadd_h(lasx_extracti128(q8sums, 0), lasx_extracti128(q8sums, 1));
1218 const __m128i prod = lsx_madd_h(mins128, q8s);
1219 acc_m = __lsx_vfmadd_s(__lsx_vreplfr2vr_s(dmin), __lsx_vffint_s_w(prod), acc_m);
1220
1221 const __m256i scales = lasx_insertf128(scales128, scales128);
1222
1223 __m256i sumi = __lasx_xvldi(0);
1224
1225 for (int j = 0; j < QK_K/64; ++j) {
1226
1227 const __m256i scale_l = lasx_xvrepl128vei_h(scales, 2 * j + 0);
1228 const __m256i scale_h = lasx_xvrepl128vei_h(scales, 2 * j + 1);
1229
1230 const __m256i q4bits = __lasx_xvld((const __m256i*)q4, 0); q4 += 32;
1231 const __m256i q4l = __lasx_xvandi_b(q4bits, 0xf);
1232 const __m256i q4h = __lasx_xvsrli_b(q4bits, 4);
1233
1234 const __m256i q8l = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
1235 __m256i p16l = lasx_madd_h_b(q4l, q8l);
1236 p16l = lasx_madd_h(scale_l, p16l);
1237
1238 const __m256i q8h = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
1239 __m256i p16h = lasx_madd_h_b(q4h, q8h);
1240 p16h = lasx_madd_h(scale_h, p16h);
1241 const __m256i sumj = __lasx_xvadd_w(p16l, p16h);
1242
1243 sumi = __lasx_xvadd_w(sumi, sumj);
1244 }
1245
1246 __m256 vd = __lasx_xvreplfr2vr_s(d);
1247 acc = __lasx_xvfmadd_s(vd, __lasx_xvffint_s_w(sumi), acc);
1248
1249 }
1250
1251 acc_m = __lsx_vfadd_s(acc_m, (__m128)__lsx_vpermi_w((__m128i)acc_m, (__m128i)acc_m, 0xee));
1252 __m128i tmp1 = __lsx_vinsgr2vr_w(__lsx_vldi(0), __lsx_vpickve2gr_w((__m128i)acc_m, 1), 0);
1253 acc_m = __lsx_vfadd_s(acc_m, (__m128)tmp1);
1254
1255
1256 *s = hsum_float_8(acc) + ((v4f32)acc_m)[0];
1257
1258#else
1259 UNUSED(x);
1260 UNUSED(y);
1261 UNUSED(nb);
1262 UNUSED(kmask1);
1263 UNUSED(kmask2);
1264 UNUSED(kmask3);
1265 UNUSED(utmp);
1266 ggml_vec_dot_q4_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
1267#endif
1268}
1269
1270void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
1271 assert(n % QK_K == 0);
1272 assert(nrc == 1);
1273 UNUSED(nrc);
1274 UNUSED(bx);
1275 UNUSED(by);
1276 UNUSED(bs);
1277
1278 const block_q5_K * GGML_RESTRICT x = vx;
1279 const block_q8_K * GGML_RESTRICT y = vy;
1280
1281 const int nb = n / QK_K;
1282
1283 static const uint32_t kmask1 = 0x3f3f3f3f;
1284 static const uint32_t kmask2 = 0x0f0f0f0f;
1285 static const uint32_t kmask3 = 0x03030303;
1286
1287 uint32_t utmp[4];
1288
1289#if defined __loongarch_asx
1290
1291 __m256 acc = (__m256)__lasx_xvldi(0);
1292 __m128 acc_m = (__m128)__lsx_vldi(0);
1293
1294 for (int i = 0; i < nb; ++i) {
1295
1296 const uint8_t * GGML_RESTRICT q5 = x[i].qs;
1297 const int8_t * GGML_RESTRICT q8 = y[i].qs;
1298
1299 const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
1300 const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);
1301
1302 memcpy(utmp, x[i].scales, 12);
1303 utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
1304 const uint32_t uaux = utmp[1] & kmask1;
1305 utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
1306 utmp[2] = uaux;
1307 utmp[0] &= kmask1;
1308
1309 const __m128i mins_and_scales128 = lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0]);
1310 const __m128i mins128 = __lsx_vexth_h_b(mins_and_scales128);
1311 const __m128i scales128 = __lsx_vsllwil_h_b(mins_and_scales128, 0);
1312
1313 const __m256i q8sums = __lasx_xvld((const __m256i*)y[i].bsums, 0);
1314 const __m128i q8s = lsx_hadd_h(lasx_extracti128(q8sums, 0), lasx_extracti128(q8sums, 1));
1315 const __m128i prod = lsx_madd_h(mins128, q8s);
1316 acc_m = __lsx_vfmadd_s(__lsx_vreplfr2vr_s(dmin), __lsx_vffint_s_w(prod), acc_m);
1317
1318 const __m256i scales = lasx_insertf128(scales128, scales128);
1319
1320 const __m256i hbits = __lasx_xvld((const __m256i*)x[i].qh, 0);
1321
1322 __m256i sumi = __lasx_xvldi(0);
1323
1324 for (int j = 0; j < QK_K/64; ++j) {
1325
1326 const __m256i scale_0 = lasx_xvrepl128vei_h(scales, 2 * j + 0);
1327 const __m256i scale_1 = lasx_xvrepl128vei_h(scales, 2 * j + 1);
1328
1329 const __m256i q5bits = __lasx_xvld((const __m256i*)q5, 0); q5 += 32;
1330
1331 const __m256i q5l_0 = __lasx_xvandi_b(q5bits, 0xf);
1332 const __m256i q5l_1 = __lasx_xvsrli_b(q5bits, 4);
1333 const __m256i q5h_0 = __lasx_xvnori_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 2 * j + 0), 0), 0xef);
1334 const __m256i q5h_1 = __lasx_xvnori_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 2 * j + 1), 0), 0xef);
1335 const __m256i q5_0 = __lasx_xvor_v(q5l_0, q5h_0);
1336 const __m256i q5_1 = __lasx_xvor_v(q5l_1, q5h_1);
1337
1338 const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
1339 const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
1340
1341 __m256i p16_0 = lasx_madd_h_b(q5_0, q8_0);
1342 __m256i p16_1 = lasx_madd_h_b(q5_1, q8_1);
1343
1344 p16_0 = lasx_madd_h(scale_0, p16_0);
1345 p16_1 = lasx_madd_h(scale_1, p16_1);
1346
1347 sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_1));
1348
1349 }
1350
1351 __m256 vd = __lasx_xvreplfr2vr_s(d);
1352 acc = __lasx_xvfmadd_s(vd, __lasx_xvffint_s_w(sumi), acc);
1353
1354 }
1355
1356 acc_m = __lsx_vfadd_s(acc_m, (__m128)__lsx_vbsrl_v(acc_m, 8));
1357 acc_m = __lsx_vfadd_s(acc_m, (__m128)__lsx_vbsrl_v(acc_m, 4));
1358
1359 *s = hsum_float_8(acc) + ((v4f32)acc_m)[0];
1360
1361#else
1362 UNUSED(x);
1363 UNUSED(y);
1364 UNUSED(nb);
1365 UNUSED(kmask1);
1366 UNUSED(kmask2);
1367 UNUSED(kmask3);
1368 UNUSED(utmp);
1369 ggml_vec_dot_q5_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
1370#endif
1371}
1372
1373void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
1374 assert(n % QK_K == 0);
1375 assert(nrc == 1);
1376 UNUSED(nrc);
1377 UNUSED(bx);
1378 UNUSED(by);
1379 UNUSED(bs);
1380
1381 const block_q6_K * GGML_RESTRICT x = vx;
1382 const block_q8_K * GGML_RESTRICT y = vy;
1383
1384 const int nb = n / QK_K;
1385
1386#if defined __loongarch_asx
1387
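// q6_K stores its 6-bit quants biased by +32 (range 0..63); m32s is
// subtracted before the multiply to recover the signed values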
1388 const __m256i m32s = __lasx_xvreplgr2vr_b(32);
1389
1390 __m256 acc = (__m256)__lasx_xvldi(0);
1391
1392 for (int i = 0; i < nb; ++i) {
1393
1394 const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
1395
1396 const uint8_t * GGML_RESTRICT q4 = x[i].ql;
1397 const uint8_t * GGML_RESTRICT qh = x[i].qh;
1398 const int8_t * GGML_RESTRICT q8 = y[i].qs;
1399
1400 const __m128i scales128 = __lsx_vld((const __m128i*)x[i].scales, 0);
1401 const v16i8 shuffle_mask = {0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15};
1402 const __m256i scales_shuffled = lasx_ext8_16(__lsx_vshuf_b(scales128, scales128, (__m128i)shuffle_mask));
1403
1404 __m256i sumi = __lasx_xvldi(0);
1405
1406 for (int j = 0; j < QK_K/128; ++j) {
1407
1408 const __m256i q4bits1 = __lasx_xvld((const __m256i*)q4, 0); q4 += 32;
1409 const __m256i q4bits2 = __lasx_xvld((const __m256i*)q4, 0); q4 += 32;
1410 const __m256i q4bitsH = __lasx_xvld((const __m256i*)qh, 0); qh += 32;
1411
1412 const __m256i q4h_0 = __lasx_xvslli_b(__lasx_xvandi_b(q4bitsH, 3), 4);
1413 const __m256i q4h_1 = __lasx_xvslli_b(__lasx_xvandi_b(q4bitsH, 3 << 2), 2);
1414 const __m256i q4h_2 = __lasx_xvandi_b(q4bitsH, 3 << 4);
1415 const __m256i q4h_3 = __lasx_xvsrli_b(__lasx_xvandi_b(q4bitsH, 3 << 6), 2);
1416
1417 const __m256i q4_0 = __lasx_xvor_v(__lasx_xvandi_b(q4bits1, 0xf), q4h_0);
1418 const __m256i q4_1 = __lasx_xvor_v(__lasx_xvandi_b(q4bits2, 0xf), q4h_1);
1419 const __m256i q4_2 = __lasx_xvor_v(__lasx_xvsrli_b(q4bits1, 4), q4h_2);
1420 const __m256i q4_3 = __lasx_xvor_v(__lasx_xvsrli_b(q4bits2, 4), q4h_3);
1421
1422 const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
1423 const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
1424 const __m256i q8_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
1425 const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
1426
1427 __m256i p16_0 = lasx_madd_h_b(__lasx_xvsub_b(q4_0, m32s), q8_0);
1428 __m256i p16_1 = lasx_madd_h_b(__lasx_xvsub_b(q4_1, m32s), q8_1);
1429 __m256i p16_2 = lasx_madd_h_b(__lasx_xvsub_b(q4_2, m32s), q8_2);
1430 __m256i p16_3 = lasx_madd_h_b(__lasx_xvsub_b(q4_3, m32s), q8_3);
1431
1432 p16_0 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 0), p16_0);
1433 p16_1 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 1), p16_1);
1434 p16_2 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 2), p16_2);
1435 p16_3 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 3), p16_3);
1436
1437 sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_1));
1438 sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_2, p16_3));
1439 }
1440
1441 acc = __lasx_xvfmadd_s((__m256)__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), acc);
1442 }
1443
1444 *s = hsum_float_8(acc);
1445
1446#else
1447 UNUSED(x);
1448 UNUSED(y);
1449 UNUSED(nb);
1450 ggml_vec_dot_q6_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
1451#endif
1452}
1453
1454#if defined(__loongarch_asx)
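// Sign table for the IQ2 formats: each 64-bit entry holds eight +1/-1 bytes.
// The 7-bit index supplies the first seven signs and the eighth is chosen so
// that the total number of -1 signs is even (hence "keven").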
1455static const int8_t keven_signs_q2xs[1024] = {
1456 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1,
1457 1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, -1,
1458 1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, -1,
1459 1, 1, -1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, 1,
1460 1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, -1,
1461 1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, 1,
1462 1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, 1,
1463 1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, -1,
1464 1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, -1,
1465 1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, 1,
1466 1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, 1,
1467 1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, -1,
1468 1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, 1,
1469 1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, -1,
1470 1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, -1,
1471 1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, 1,
1472 1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, -1,
1473 1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, 1,
1474 1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, 1,
1475 1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, -1,
1476 1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, 1,
1477 1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, -1,
1478 1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, -1,
1479 1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, 1,
1480 1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, 1,
1481 1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, -1,
1482 1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, -1,
1483 1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, 1,
1484 1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, -1,
1485 1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, 1,
1486 1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, 1,
1487 1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1,
1488};
1489#endif
1490
1491void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
1492 assert(n % QK_K == 0);
1493 assert(nrc == 1);
1494 UNUSED(nrc);
1495 UNUSED(bx);
1496 UNUSED(by);
1497 UNUSED(bs);
1498
1499 const block_iq2_xxs * GGML_RESTRICT x = vx;
1500 const block_q8_K * GGML_RESTRICT y = vy;
1501
1502 const int nb = n / QK_K;
1503
1504#if defined(__loongarch_asx)
1505
1506 const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
1507
1508 uint32_t aux32[4];
1509 const uint8_t * aux8 = (const uint8_t *)aux32;
1510
1511 __m256 accumf = (__m256)__lasx_xvldi(0);
1512 for (int i = 0; i < nb; ++i) {
1513 const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
1514 const uint16_t * GGML_RESTRICT q2 = x[i].qs;
1515 const int8_t * GGML_RESTRICT q8 = y[i].qs;
1516 __m256i sumi1 = __lasx_xvldi(0);
1517 __m256i sumi2 = __lasx_xvldi(0);
1518 for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
1519 const __m256i q8_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
1520 const __m256i q8_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
1521 memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8;
1522
1523 const __m256i q2_1 = lasx_set_d(iq2xxs_grid[aux8[ 3]], iq2xxs_grid[aux8[ 2]], iq2xxs_grid[aux8[1]], iq2xxs_grid[aux8[0]]);
1524 const __m256i q2_2 = lasx_set_d(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]], iq2xxs_grid[aux8[9]], iq2xxs_grid[aux8[8]]);
1525 const __m256i s2_1 = lasx_set_d(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127],
1526 signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]);
1527 const __m256i s2_2 = lasx_set_d(signs64[(aux32[3] >> 21) & 127], signs64[(aux32[3] >> 14) & 127],
1528 signs64[(aux32[3] >> 7) & 127], signs64[(aux32[3] >> 0) & 127]);
1529 const __m256i q8s_1 = __lasx_xvsigncov_b(s2_1, q8_1);
1530 const __m256i q8s_2 = __lasx_xvsigncov_b(s2_2, q8_2);
1531 const __m256i dot1 = lasx_maddubs_h(q2_1, q8s_1);
1532 const __m256i dot2 = lasx_maddubs_h(q2_2, q8s_2);
1533 const uint16_t ls1 = aux32[1] >> 28;
1534 const uint16_t ls2 = aux32[3] >> 28;
1535 const __m256i p1 = lasx_madd_h(dot1, __lasx_xvreplgr2vr_h(2*ls1+1));
1536 const __m256i p2 = lasx_madd_h(dot2, __lasx_xvreplgr2vr_h(2*ls2+1));
1537 sumi1 = __lasx_xvadd_w(sumi1, p1);
1538 sumi2 = __lasx_xvadd_w(sumi2, p2);
1539 }
1540
1541 accumf = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accumf);
1542 }
1543
1544 *s = 0.125f * hsum_float_8(accumf);
1545
1546#else
1547 UNUSED(x);
1548 UNUSED(y);
1549 UNUSED(nb);
1550 ggml_vec_dot_iq2_xxs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
1551#endif
1552}
1553
1554void ggml_vec_dot_iq2_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
1555 assert(n % QK_K == 0);
1556 assert(nrc == 1);
1557 UNUSED(nrc);
1558 UNUSED(bx);
1559 UNUSED(by);
1560 UNUSED(bs);
1561
1562 const block_iq2_xs * GGML_RESTRICT x = vx;
1563 const block_q8_K * GGML_RESTRICT y = vy;
1564
1565 const int nb = n / QK_K;
1566
1567#if defined(__loongarch_asx)
1568
1569 const __m256i mone = __lasx_xvreplgr2vr_b(1);
1570 static const char block_sign_shuffle_mask_1[32] = {
1571 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02,
1572 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06,
1573 };
1574 static const char block_sign_shuffle_mask_2[32] = {
1575 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a,
1576 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e,
1577 };
1578 static const uint8_t bit_selector_mask_bytes[32] = {
1579 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
1580 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
1581 };
1582
1583 const __m256i bit_selector_mask = __lasx_xvld((const __m256i*)bit_selector_mask_bytes, 0);
1584 const __m256i block_sign_shuffle_1 = __lasx_xvld((const __m256i*)block_sign_shuffle_mask_1, 0);
1585 const __m256i block_sign_shuffle_2 = __lasx_xvld((const __m256i*)block_sign_shuffle_mask_2, 0);
1586
1587 static const uint8_t k_bit_helper[32] = {
1588 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
1589 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
1590 };
1591 const __m256i bit_helper = __lasx_xvld((const __m256i*)k_bit_helper, 0);
1592 const __m256i m511 = __lasx_xvreplgr2vr_h(511);
1593 const __m128i m4 = __lsx_vreplgr2vr_b(0xf);
1594 const __m128i m1 = __lsx_vreplgr2vr_b(1);
1595
1596 uint64_t aux64;
1597
1598    // somewhat hacky: the grid indices are read back out of the vector register through a scalar uint16_t alias (gindex), but this gives a significant boost in performance
1599 __m256i aux_gindex;
1600 const uint16_t * gindex = (const uint16_t *)&aux_gindex;
1601
1602 __m256 accumf = (__m256)__lasx_xvldi(0);
1603 for (int i = 0; i < nb; ++i) {
1604 const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
1605 const uint16_t * GGML_RESTRICT q2 = x[i].qs;
1606 const int8_t * GGML_RESTRICT q8 = y[i].qs;
1607
1608 memcpy(&aux64, x[i].scales, 8);
1609 __m128i stmp = __lsx_vreplgr2vr_d(aux64);
1610 stmp = __lsx_vilvl_b( __lsx_vand_v(__lsx_vsrli_h(stmp, 4), m4), __lsx_vand_v(stmp, m4));
1611 const __m128i scales = __lsx_vadd_b(__lsx_vslli_h(stmp, 1), m1);
1612
1613 __m256i sumi1 = __lasx_xvldi(0);
1614 __m256i sumi2 = __lasx_xvldi(0);
1615 for (int ib32 = 0; ib32 < QK_K/32; ib32 += 4) {
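            // Four 32-value sub-blocks per iteration. Each uint16 of q2 packs a 9-bit grid
            // index (masked with m511) and 7 explicit sign bits; the 8th sign of every group
            // of 8 values is the parity of those 7 bits and is reconstructed below via the
            // k_bit_helper shuffle (odd_bits), so full_sign_bits carries all 8 signs.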
1616
1617 const __m256i q2_data = __lasx_xvld((const __m256i*)q2, 0); q2 += 16;
1618 aux_gindex = __lasx_xvand_v(q2_data, m511);
1619
1620 const __m256i partial_sign_bits = __lasx_xvsrli_h(q2_data, 9);
1621 const __m256i partial_sign_bits_upper = __lasx_xvsrli_h(q2_data, 13);
1622 const __m256i partial_sign_bits_for_counting = __lasx_xvxor_v(partial_sign_bits, partial_sign_bits_upper);
1623
1624 const __m256i odd_bits = lasx_shuffle_b(bit_helper, partial_sign_bits_for_counting);
1625 const __m256i full_sign_bits = __lasx_xvor_v(partial_sign_bits, odd_bits);
1626
1627 const __m256i q8_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
1628 const __m256i q8_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
1629 const __m256i q8_3 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
1630 const __m256i q8_4 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
1631
1632 const __m256i q2_1 = lasx_set_d(iq2xs_grid[gindex[ 3]], iq2xs_grid[gindex[ 2]],
1633 iq2xs_grid[gindex[ 1]], iq2xs_grid[gindex[ 0]]);
1634 const __m256i q2_2 = lasx_set_d(iq2xs_grid[gindex[ 7]], iq2xs_grid[gindex[ 6]],
1635 iq2xs_grid[gindex[ 5]], iq2xs_grid[gindex[ 4]]);
1636 const __m256i q2_3 = lasx_set_d(iq2xs_grid[gindex[11]], iq2xs_grid[gindex[10]],
1637 iq2xs_grid[gindex[ 9]], iq2xs_grid[gindex[ 8]]);
1638 const __m256i q2_4 = lasx_set_d(iq2xs_grid[gindex[15]], iq2xs_grid[gindex[14]],
1639 iq2xs_grid[gindex[13]], iq2xs_grid[gindex[12]]);
1640
1641 const __m128i full_signs_l = lasx_extracti128(full_sign_bits, 0);
1642 const __m128i full_signs_h = lasx_extracti128(full_sign_bits, 1);
1643 const __m256i full_signs_1 = lasx_insertf128(full_signs_l, full_signs_l);
1644 const __m256i full_signs_2 = lasx_insertf128(full_signs_h, full_signs_h);
1645
1646 __m256i signs;
1647 signs = lasx_shuffle_b(full_signs_1, block_sign_shuffle_1);
1648 signs = __lasx_xvseq_b(__lasx_xvand_v(signs, bit_selector_mask), bit_selector_mask);
1649 const __m256i q8s_1 = __lasx_xvsigncov_b(__lasx_xvor_v(signs, mone), q8_1);
1650
1651 signs = lasx_shuffle_b(full_signs_1, block_sign_shuffle_2);
1652 signs = __lasx_xvseq_b(__lasx_xvand_v(signs, bit_selector_mask), bit_selector_mask);
1653 const __m256i q8s_2 = __lasx_xvsigncov_b(__lasx_xvor_v(signs, mone), q8_2);
1654
1655 signs = lasx_shuffle_b(full_signs_2, block_sign_shuffle_1);
1656 signs = __lasx_xvseq_b(__lasx_xvand_v(signs, bit_selector_mask), bit_selector_mask);
1657 const __m256i q8s_3 = __lasx_xvsigncov_b(__lasx_xvor_v(signs, mone), q8_3);
1658
1659 signs = lasx_shuffle_b(full_signs_2, block_sign_shuffle_2);
1660 signs = __lasx_xvseq_b(__lasx_xvand_v(signs, bit_selector_mask), bit_selector_mask);
1661 const __m256i q8s_4 = __lasx_xvsigncov_b(__lasx_xvor_v(signs, mone), q8_4);
1662
1663 const __m256i dot1 = lasx_maddubs_h(q2_1, q8s_1);
1664 const __m256i dot2 = lasx_maddubs_h(q2_2, q8s_2);
1665 const __m256i dot3 = lasx_maddubs_h(q2_3, q8s_3);
1666 const __m256i dot4 = lasx_maddubs_h(q2_4, q8s_4);
1667
1668 const __m256i sc1 = lasx_ext8_16(lsx_shuffle_b(scales, get_scale_shuffle(ib32+0)));
1669 const __m256i sc2 = lasx_ext8_16(lsx_shuffle_b(scales, get_scale_shuffle(ib32+1)));
1670 const __m256i sc3 = lasx_ext8_16(lsx_shuffle_b(scales, get_scale_shuffle(ib32+2)));
1671 const __m256i sc4 = lasx_ext8_16(lsx_shuffle_b(scales, get_scale_shuffle(ib32+3)));
1672
1673 sumi1 = __lasx_xvadd_w(sumi1, lasx_madd_h(dot1, sc1));
1674 sumi2 = __lasx_xvadd_w(sumi2, lasx_madd_h(dot2, sc2));
1675 sumi1 = __lasx_xvadd_w(sumi1, lasx_madd_h(dot3, sc3));
1676 sumi2 = __lasx_xvadd_w(sumi2, lasx_madd_h(dot4, sc4));
1677 }
1678
1679 accumf = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accumf);
1680
1681 }
1682
1683 *s = 0.125f * hsum_float_8(accumf);
1684
1685#else
1686 UNUSED(x);
1687 UNUSED(y);
1688 UNUSED(nb);
1689 ggml_vec_dot_iq2_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
1690#endif
1691}
1692
1693void ggml_vec_dot_iq2_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
1694 assert(n % QK_K == 0);
1695 assert(nrc == 1);
1696 UNUSED(nrc);
1697 UNUSED(bx);
1698 UNUSED(by);
1699 UNUSED(bs);
1700
1701 const block_iq2_s * GGML_RESTRICT x = vx;
1702 const block_q8_K * GGML_RESTRICT y = vy;
1703
1704 const int nb = n / QK_K;
1705
1706#if defined(__loongarch_asx)
1707
1708 static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
1709 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
1710 };
1711
1712 static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
1713 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
1714 };
1715
1716
1717 const __m128i m4 = __lsx_vreplgr2vr_b(0xf);
1718 const __m128i m1 = __lsx_vreplgr2vr_b(1);
1719
1720 const __m256i mask1 = __lasx_xvld((const __m256i*)k_mask1, 0);
1721 const __m256i mask2 = __lasx_xvld((const __m256i*)k_mask2, 0);
1722 uint64_t aux64;
1723
1724 __m256 accumf = (__m256)__lasx_xvldi(0);
1725 for (int i = 0; i < nb; ++i) {
1726 const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
1727 const uint8_t * GGML_RESTRICT qs = x[i].qs;
1728 const uint8_t * GGML_RESTRICT qh = x[i].qh;
1729 const uint16_t * GGML_RESTRICT signs = (const uint16_t *)(x[i].qs + QK_K/8);
1730 const int8_t * GGML_RESTRICT q8 = y[i].qs;
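        // iq2_s layout: qs holds the low 8 bits of each grid index, qh supplies the top
        // 2 bits (four indices per qh byte), and explicit sign bits live in the second
        // half of qs (read as uint16_t at offset QK_K/8); signs are applied below as
        // (q8 ^ s) - s, i.e. negate where the sign bit is set.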
1731
1732        __m128i tmp1 = __lsx_vldi(0);  // zero-init so the vinsgr2vr_d inserts below never read an indeterminate register
1733 memcpy(&aux64, x[i].scales, 8);
1734 tmp1 = __lsx_vinsgr2vr_d(tmp1, aux64, 0);
1735 tmp1 = __lsx_vinsgr2vr_d(tmp1, aux64 >> 4, 1);
1736 const __m128i scales8 = __lsx_vadd_b(__lsx_vslli_h(__lsx_vand_v(tmp1, m4), 1), m1);
1737 const __m256i scales16 = lasx_ext8_16(scales8); // 0 2 4 6 8 10 12 14 1 3 5 7 9 11 13 15
1738
1739 __m256i sumi1 = __lasx_xvldi(0);
1740 __m256i sumi2 = __lasx_xvldi(0);
1741 for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
1742 const __m256i q8_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
1743 const __m256i q8_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
1744 const __m256i q2_1 = lasx_set_d(iq2s_grid[qs[3] | ((qh[ib32+0] << 2) & 0x300)],
1745 iq2s_grid[qs[2] | ((qh[ib32+0] << 4) & 0x300)],
1746 iq2s_grid[qs[1] | ((qh[ib32+0] << 6) & 0x300)],
1747 iq2s_grid[qs[0] | ((qh[ib32+0] << 8) & 0x300)]);
1748 const __m256i q2_2 = lasx_set_d(iq2s_grid[qs[7] | ((qh[ib32+1] << 2) & 0x300)],
1749 iq2s_grid[qs[6] | ((qh[ib32+1] << 4) & 0x300)],
1750 iq2s_grid[qs[5] | ((qh[ib32+1] << 6) & 0x300)],
1751 iq2s_grid[qs[4] | ((qh[ib32+1] << 8) & 0x300)]);
1752 qs += 8;
1753
1754 __m256i aux256 = __lasx_xvreplgr2vr_w(signs[0] | ((uint32_t) signs[1] << 16));
1755 aux256 = __lasx_xvand_v(lasx_shuffle_b(aux256,mask1), mask2);
1756 const __m256i s2_1 = __lasx_xvseq_b(aux256, mask2);
1757 const __m256i q8s_1 = __lasx_xvsub_b(__lasx_xvxor_v(s2_1, q8_1), s2_1);
1758
1759 aux256 = __lasx_xvreplgr2vr_w(signs[2] | ((uint32_t) signs[3] << 16));
1760 aux256 = __lasx_xvand_v(lasx_shuffle_b(aux256,mask1), mask2);
1761 const __m256i s2_2 = __lasx_xvseq_b(aux256, mask2);
1762 const __m256i q8s_2 = __lasx_xvsub_b(__lasx_xvxor_v(s2_2, q8_2), s2_2);
1763
1764 signs += 4;
1765
1766 const __m256i dot1 = lasx_maddubs_h(q2_1, q8s_1); // blocks 2*ib32+0, 2*ib32+1
1767 const __m256i dot2 = lasx_maddubs_h(q2_2, q8s_2); // blocks 2*ib32+2, 2*ib32+3
1768
1769 const __m256i p1 = lasx_madd_h(dot1, lasx_shuffle_b(scales16, get_scale_shuffle_k4(ib32+0)));
1770 const __m256i p2 = lasx_madd_h(dot2, lasx_shuffle_b(scales16, get_scale_shuffle_k4(ib32+1)));
1771 sumi1 = __lasx_xvadd_w(sumi1, p1);
1772 sumi2 = __lasx_xvadd_w(sumi2, p2);
1773 }
1774
1775 accumf = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accumf);
1776 }
1777
1778 *s = 0.125f * hsum_float_8(accumf);
1779
1780#else
1781 UNUSED(x);
1782 UNUSED(y);
1783 UNUSED(nb);
1784 ggml_vec_dot_iq2_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
1785#endif
1786}
1787
1788void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
1789 assert(n % QK_K == 0);
1790 assert(nrc == 1);
1791 UNUSED(nrc);
1792 UNUSED(bx);
1793 UNUSED(by);
1794 UNUSED(bs);
1795
1796 const block_iq3_xxs * GGML_RESTRICT x = vx;
1797 const block_q8_K * GGML_RESTRICT y = vy;
1798
1799 const int nb = n / QK_K;
1800
1801#if defined(__loongarch_asx)
1802
1803 const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
1804
1805 uint32_t aux32[2];
1806
1807 __m256 accumf = (__m256)__lasx_xvldi(0);
1808 for (int i = 0; i < nb; ++i) {
1809 const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
1810 const uint8_t * GGML_RESTRICT q3 = x[i].qs;
1811 const uint8_t * GGML_RESTRICT gas = x[i].qs + QK_K/4;
1812 const int8_t * GGML_RESTRICT q8 = y[i].qs;
1813 __m256i sumi1 = __lasx_xvldi(0);
1814 __m256i sumi2 = __lasx_xvldi(0);
1815 for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
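            // Two 32-value sub-blocks per iteration: 8 grid-index bytes each from q3 (every
            // iq3xxs_grid entry packs 4 values), plus one 32-bit word from gas carrying four
            // 7-bit sign-pattern indices and a 4-bit scale, the same scheme as iq2_xxs above.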
1816 const __m256i q8_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
1817 const __m256i q8_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
1818 const __m256i q2_1 = lasx_set_w(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]],
1819 iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]);
1820 q3 += 8;
1821 const __m256i q2_2 = lasx_set_w(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]],
1822 iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]);
1823 q3 += 8;
1824 memcpy(aux32, gas, 8); gas += 8;
1825
1826 const __m256i s2_1 = lasx_set_d(signs64[(aux32[0] >> 21) & 127], signs64[(aux32[0] >> 14) & 127],
1827 signs64[(aux32[0] >> 7) & 127], signs64[(aux32[0] >> 0) & 127]);
1828 const __m256i s2_2 = lasx_set_d(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127],
1829 signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]);
1830 const __m256i q8s_1 = __lasx_xvsigncov_b(s2_1, q8_1);
1831 const __m256i q8s_2 = __lasx_xvsigncov_b(s2_2, q8_2);
1832 const __m256i dot1 = lasx_maddubs_h(q2_1, q8s_1);
1833 const __m256i dot2 = lasx_maddubs_h(q2_2, q8s_2);
1834 const uint16_t ls1 = aux32[0] >> 28;
1835 const uint16_t ls2 = aux32[1] >> 28;
1836
1837 const __m256i p1 = lasx_madd_h(dot1, __lasx_xvreplgr2vr_h(2*ls1+1));
1838 const __m256i p2 = lasx_madd_h(dot2, __lasx_xvreplgr2vr_h(2*ls2+1));
1839 sumi1 = __lasx_xvadd_w(sumi1, p1);
1840 sumi2 = __lasx_xvadd_w(sumi2, p2);
1841 }
1842
1843 accumf = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accumf);
1844 }
1845
1846 *s = 0.25f * hsum_float_8(accumf);
1847
1848#else
1849 UNUSED(x);
1850 UNUSED(y);
1851 UNUSED(nb);
1852 ggml_vec_dot_iq3_xxs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
1853#endif
1854}
1855
1856void ggml_vec_dot_iq3_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
1857 assert(n % QK_K == 0);
1858 assert(nrc == 1);
1859 UNUSED(nrc);
1860 UNUSED(bx);
1861 UNUSED(by);
1862 UNUSED(bs);
1863
1864 const block_iq3_s * GGML_RESTRICT x = vx;
1865 const block_q8_K * GGML_RESTRICT y = vy;
1866
1867 const int nb = n / QK_K;
1868
1869#if defined(__loongarch_asx)
1870
1871 static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
1872 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
1873 };
1874
1875 static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
1876 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
1877 };
1878
1879 const __m256i mask1 = __lasx_xvld((const __m256i*)k_mask1, 0);
1880 const __m256i mask2 = __lasx_xvld((const __m256i*)k_mask2, 0);
1881
1882 __m256i idx_shift = lasx_set_w(1, 2, 3, 4, 5, 6, 7, 8);
1883 const __m256i idx_mask = __lasx_xvreplgr2vr_w(256);
1884
1885 typedef union {
1886 __m256i vec[2];
1887 uint32_t index[16];
1888 } index_t;
1889
1890 index_t idx;
1891
1892 __m256 accumf = (__m256)__lasx_xvldi(0);
1893 for (int i = 0; i < nb; ++i) {
1894 const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
1895 const uint8_t * GGML_RESTRICT qs = x[i].qs;
1896 const uint8_t * GGML_RESTRICT qh = x[i].qh;
1897 const uint16_t * GGML_RESTRICT signs = (const uint16_t *)x[i].signs;
1898 const int8_t * GGML_RESTRICT q8 = y[i].qs;
1899 __m256i sumi1 = __lasx_xvldi(0);
1900 __m256i sumi2 = __lasx_xvldi(0);
1901 for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
1902 const __m256i q8_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
1903 const __m256i q8_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
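            // Assemble the 9-bit grid indices: the low 8 bits of each index come from qs,
            // and the 9th bit for each of the 8 values covered by one qh byte is obtained
            // by shifting that byte so the relevant bit lands at position 8 (idx_shift)
            // and masking with 256 (idx_mask).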
1904 const __m256i idx_l = lasx_extu8_16(__lsx_vld(qs, 0)); qs += 16;
1905 idx.vec[0] = __lasx_xvreplgr2vr_w(qh[ib32+0]);
1906 idx.vec[1] = __lasx_xvreplgr2vr_w(qh[ib32+1]);
1907 idx.vec[0] = __lasx_xvand_v(__lasx_xvsll_w(idx.vec[0], idx_shift), idx_mask);
1908 idx.vec[1] = __lasx_xvand_v(__lasx_xvsll_w(idx.vec[1], idx_shift), idx_mask);
1909 idx.vec[0] = __lasx_xvor_v(idx.vec[0], lasx_ext16_32(lasx_extracti128(idx_l, 0)));
1910 idx.vec[1] = __lasx_xvor_v(idx.vec[1], lasx_ext16_32(lasx_extracti128(idx_l, 1)));
1911
1912            // Note kept from the original AVX2 implementation: at least on a Ryzen 7950X, using _mm256_i32gather_epi32 is slower than _mm256_set_epi32. Strange.
1913 //const __m256i q2_1 = _mm256_i32gather_epi32((const int *)iq3s_grid, idx.vec[0], 4);
1914 //const __m256i q2_2 = _mm256_i32gather_epi32((const int *)iq3s_grid, idx.vec[1], 4);
1915 const __m256i q2_1 = lasx_set_w(
1916 iq3s_grid[idx.index[7]], iq3s_grid[idx.index[6]], iq3s_grid[idx.index[5]], iq3s_grid[idx.index[4]],
1917 iq3s_grid[idx.index[3]], iq3s_grid[idx.index[2]], iq3s_grid[idx.index[1]], iq3s_grid[idx.index[0]]
1918 );
1919 const __m256i q2_2 = lasx_set_w(
1920 iq3s_grid[idx.index[15]], iq3s_grid[idx.index[14]], iq3s_grid[idx.index[13]], iq3s_grid[idx.index[12]],
1921 iq3s_grid[idx.index[11]], iq3s_grid[idx.index[10]], iq3s_grid[idx.index[ 9]], iq3s_grid[idx.index[ 8]]
1922 );
1923
1924            __m256i aux256 = __lasx_xvreplgr2vr_w(signs[0] | ((uint32_t) signs[1] << 16));
1925 aux256 = __lasx_xvand_v(lasx_shuffle_b(aux256,mask1), mask2);
1926 const __m256i s2_1 = __lasx_xvseq_b(aux256, mask2);
1927 const __m256i q8s_1 = __lasx_xvsub_b(__lasx_xvxor_v(s2_1, q8_1), s2_1);
1928
1929            aux256 = __lasx_xvreplgr2vr_w(signs[2] | ((uint32_t) signs[3] << 16));
1930 aux256 = __lasx_xvand_v(lasx_shuffle_b(aux256,mask1), mask2);
1931 const __m256i s2_2 = __lasx_xvseq_b(aux256, mask2);
1932 const __m256i q8s_2 = __lasx_xvsub_b(__lasx_xvxor_v(s2_2, q8_2), s2_2);
1933
1934 signs += 4;
1935
1936 const __m256i dot1 = lasx_maddubs_h(q2_1, q8s_1);
1937 const __m256i dot2 = lasx_maddubs_h(q2_2, q8s_2);
1938 const uint16_t ls1 = x[i].scales[ib32/2] & 0xf;
1939 const uint16_t ls2 = x[i].scales[ib32/2] >> 4;
1940 const __m256i p1 = lasx_madd_h(dot1, __lasx_xvreplgr2vr_h(2*ls1+1));
1941 const __m256i p2 = lasx_madd_h(dot2, __lasx_xvreplgr2vr_h(2*ls2+1));
1942 sumi1 = __lasx_xvadd_w(sumi1, p1);
1943 sumi2 = __lasx_xvadd_w(sumi2, p2);
1944 }
1945
1946 accumf = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accumf);
1947 }
1948
1949 *s = hsum_float_8(accumf);
1950
1951#else
1952 UNUSED(x);
1953 UNUSED(y);
1954 UNUSED(nb);
1955 ggml_vec_dot_iq3_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
1956#endif
1957}
1958
1959#if defined(__loongarch_asx)
1960static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
1961 const __m256i a = __lasx_xvmulwev_h_b(x, y);
1962 const __m256i b = __lasx_xvmulwod_h_b(x, y);
1963 return __lasx_xvadd_h(a, b);
1964}
1965#endif
1966
1967void ggml_vec_dot_iq1_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
1968 assert(n % QK_K == 0);
1969 assert(nrc == 1);
1970 UNUSED(nrc);
1971 UNUSED(bx);
1972 UNUSED(by);
1973 UNUSED(bs);
1974
1975 const block_iq1_s * GGML_RESTRICT x = vx;
1976 const block_q8_K * GGML_RESTRICT y = vy;
1977
1978 const int nb = n / QK_K;
1979
1980#if defined(__loongarch_asx)
1981
1982 __m256 accum = (__m256)__lasx_xvldi(0);
1983 float accum1 = 0;
1984 for (int i = 0; i < nb; ++i) {
1985
1986 const int8_t * q8 = y[i].qs;
1987 const uint8_t * qs = x[i].qs;
1988 const uint16_t * qh = x[i].qh;
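        // iq1_s layout: qs holds the low 8 bits of each grid index; each uint16 of qh packs
        // the 3 high index bits for four indices, a 3-bit scale in bits 12-14 (used as 2*s+1),
        // and in bit 15 the sign of the per-block IQ1S_DELTA shift, which is accumulated
        // separately through y[i].bsums into sumi1.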
1989
1990 __m256i sumi = __lasx_xvldi(0);
1991 int sumi1 = 0;
1992 for (int ib = 0; ib < QK_K/32; ib += 2) {
1993            __m256i q1b_1 = __lasx_xvinsgr2vr_d(__lasx_xvldi(0), iq1s_grid[qs[0] | ((qh[ib+0] << 8) & 0x700)], 0);
1994 q1b_1 = __lasx_xvinsgr2vr_d(q1b_1, iq1s_grid[qs[1] | ((qh[ib+0] << 5) & 0x700)], 1);
1995 q1b_1 = __lasx_xvinsgr2vr_d(q1b_1, iq1s_grid[qs[2] | ((qh[ib+0] << 2) & 0x700)], 2);
1996 q1b_1 = __lasx_xvinsgr2vr_d(q1b_1, iq1s_grid[qs[3] | ((qh[ib+0] >> 1) & 0x700)], 3);
1997
1998            __m256i q1b_2 = __lasx_xvinsgr2vr_d(__lasx_xvldi(0), iq1s_grid[qs[4] | ((qh[ib+1] << 8) & 0x700)], 0);
1999 q1b_2 = __lasx_xvinsgr2vr_d(q1b_2, iq1s_grid[qs[5] | ((qh[ib+1] << 5) & 0x700)], 1);
2000 q1b_2 = __lasx_xvinsgr2vr_d(q1b_2, iq1s_grid[qs[6] | ((qh[ib+1] << 2) & 0x700)], 2);
2001 q1b_2 = __lasx_xvinsgr2vr_d(q1b_2, iq1s_grid[qs[7] | ((qh[ib+1] >> 1) & 0x700)], 3);
2002
2003 qs += 8;
2004 const __m256i q8b_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
2005 const __m256i q8b_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
2006
2007 const __m256i dot1 = mul_add_epi8(q1b_1, q8b_1);
2008 const __m256i dot2 = mul_add_epi8(q1b_2, q8b_2);
2009 const int16_t ls1 = 2*((qh[ib+0] >> 12) & 7) + 1;
2010 const int16_t ls2 = 2*((qh[ib+1] >> 12) & 7) + 1;
2011
2012 __m256i tmp1, tmp5, tmp6;
2013 tmp1 = __lasx_xvreplgr2vr_h(ls1);
2014 tmp5 = __lasx_xvmulwev_w_h(dot1, tmp1);
2015 tmp6 = __lasx_xvmulwod_w_h(dot1, tmp1);
2016 const __m256i p1 = __lasx_xvadd_w(tmp5, tmp6);
2017
2018 tmp1 = __lasx_xvreplgr2vr_h(ls2);
2019 tmp5 = __lasx_xvmulwev_w_h(dot2, tmp1);
2020 tmp6 = __lasx_xvmulwod_w_h(dot2, tmp1);
2021 const __m256i p2 = __lasx_xvadd_w(tmp5, tmp6);
2022
2023 sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p1, p2));
2024 sumi1 += (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]) * (qh[ib+0] & 0x8000 ? -1 : 1) * ls1
2025 + (y[i].bsums[2*ib+2] + y[i].bsums[2*ib+3]) * (qh[ib+1] & 0x8000 ? -1 : 1) * ls2;
2026 }
2027
2028 const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
2029 accum = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), accum);
2030 accum1 += d * sumi1;
2031 }
2032
2033 *s = hsum_float_8(accum) + IQ1S_DELTA * accum1;
2034
2035#else
2036 UNUSED(x);
2037 UNUSED(y);
2038 UNUSED(nb);
2039 ggml_vec_dot_iq1_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
2040#endif
2041}
2042
2043void ggml_vec_dot_iq4_nl_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
2044 assert(nrc == 1);
2045 UNUSED(nrc);
2046 UNUSED(bx);
2047 UNUSED(by);
2048 UNUSED(bs);
2049 assert(n % QK4_NL == 0);
2050 static_assert(QK4_NL == QK8_0, "QK4_NL and QK8_0 must be the same");
2051
2052 const block_iq4_nl * GGML_RESTRICT x = vx;
2053 const block_q8_0 * GGML_RESTRICT y = vy;
2054
2055 const int nb = n / QK4_NL;
2056
2057 int ib = 0;
2058 float sumf = 0;
2059
2060#if defined (__loongarch_asx)
2061
2062 const __m128i values128 = __lsx_vld((const __m128i*)kvalues_iq4nl, 0);
2063 const __m128i m4b = __lsx_vreplgr2vr_b(0x0f);
2064 const __m256i mone = __lasx_xvreplgr2vr_h(1);
2065
2066 __m256 accum1 = (__m256)__lasx_xvldi(0);
2067 __m256 accum2 = (__m256)__lasx_xvldi(0);
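    // Each 4-bit quant indexes the 16-entry non-linear codebook kvalues_iq4nl via a byte
    // shuffle; the resulting int8 values are multiplied with q8 (mul_add_epi8), widened to
    // 32-bit sums by the madd against mone, and folded into the accumulators with d_x * d_y.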
2068 for (; ib + 1 < nb; ib += 2) {
2069 const __m128i q4bits_1 = __lsx_vld((const __m128i*)x[ib + 0].qs, 0);
2070 const __m128i q4bits_2 = __lsx_vld((const __m128i*)x[ib + 1].qs, 0);
2071 const __m256i q8b_1 = __lasx_xvld((const __m256i *)y[ib + 0].qs, 0);
2072 const __m256i q8b_2 = __lasx_xvld((const __m256i *)y[ib + 1].qs, 0);
2073 const __m256i q4b_1 = lasx_insertf128(lsx_shuffle_b(values128, __lsx_vand_v(__lsx_vsrli_h(q4bits_1, 4), m4b)),
2074 lsx_shuffle_b(values128, __lsx_vand_v(q4bits_1, m4b)));
2075 const __m256i q4b_2 = lasx_insertf128(lsx_shuffle_b(values128, __lsx_vand_v(__lsx_vsrli_h(q4bits_2, 4), m4b)),
2076 lsx_shuffle_b(values128, __lsx_vand_v(q4bits_2, m4b)));
2077 const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1);
2078 const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
2079 const __m256i p_1 = lasx_madd_h(p16_1, mone);
2080 const __m256i p_2 = lasx_madd_h(p16_2, mone);
2081 accum1 = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(GGML_CPU_FP16_TO_FP32(y[ib + 0].d)*GGML_CPU_FP16_TO_FP32(x[ib + 0].d)),
2082 __lasx_xvffint_s_w(p_1), accum1);
2083 accum2 = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(GGML_CPU_FP16_TO_FP32(y[ib + 1].d)*GGML_CPU_FP16_TO_FP32(x[ib + 1].d)),
2084 __lasx_xvffint_s_w(p_2), accum2);
2085 }
2086
2087 sumf = hsum_float_8(__lasx_xvfadd_s(accum1, accum2));
2088
2089#endif
2090 for (; ib < nb; ++ib) {
2091 const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_CPU_FP16_TO_FP32(x[ib].d);
2092 int sumi1 = 0, sumi2 = 0;
2093 for (int j = 0; j < QK4_NL/2; ++j) {
2094 sumi1 += y[ib].qs[j+ 0] * kvalues_iq4nl[x[ib].qs[j] & 0xf];
2095 sumi2 += y[ib].qs[j+QK4_NL/2] * kvalues_iq4nl[x[ib].qs[j] >> 4];
2096 }
2097 sumf += d * (sumi1 + sumi2);
2098 }
2099 *s = sumf;
2100}
2101
2102void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
2103 assert(nrc == 1);
2104 UNUSED(nrc);
2105 UNUSED(bx);
2106 UNUSED(by);
2107 UNUSED(bs);
2108 assert(n % QK_K == 0);
2109
2110 const block_iq4_xs * GGML_RESTRICT x = vx;
2111 const block_q8_K * GGML_RESTRICT y = vy;
2112
2113 const int nb = n / QK_K;
2114
2115#if defined(__loongarch_asx)
2116
2117 const __m128i values128 = __lsx_vld((const __m128i*)kvalues_iq4nl, 0);
2118
2119 __m256 accum = (__m256)__lasx_xvldi(0);
2120
2121 for (int ibl = 0; ibl < nb; ++ibl) {
2122 const uint8_t * qs = x[ibl].qs;
2123 const int8_t * q8 = y[ibl].qs;
2124 uint16_t sh = x[ibl].scales_h;
2125 __m256i sumi1 = __lasx_xvldi(0);
2126 __m256i sumi2 = __lasx_xvldi(0);
2127 for (int ib = 0; ib < QK_K/32; ib += 2) {
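            // 6-bit sub-block scales: the low 4 bits come from the scales_l nibbles and the
            // top 2 bits are consumed from scales_h (sh, shifted out 4 bits per pair of
            // sub-blocks), with a -32 offset.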
2128 const __m128i q4bits_1 = __lsx_vld((const __m128i*)qs, 0); qs += 16;
2129 const __m128i q4bits_2 = __lsx_vld((const __m128i*)qs, 0); qs += 16;
2130 const __m256i q8b_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
2131 const __m256i q8b_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
2132 const __m256i q4b_1 = lasx_insertf128(__lsx_vshuf_b(values128, values128, __lsx_vsrli_b(q4bits_1, 4)),
2133 __lsx_vshuf_b(values128, values128, __lsx_vandi_b(q4bits_1, 0xf)));
2134 const __m256i q4b_2 = lasx_insertf128(__lsx_vshuf_b(values128, values128, __lsx_vsrli_b(q4bits_2, 4)),
2135 __lsx_vshuf_b(values128, values128, __lsx_vandi_b(q4bits_2, 0xf)));
2136 const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1);
2137 const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
2138 const int16_t ls1 = ((x[ibl].scales_l[ib/2] & 0xf) | ((sh << 4) & 0x30)) - 32;
2139 const int16_t ls2 = ((x[ibl].scales_l[ib/2] >> 4) | ((sh << 2) & 0x30)) - 32;
2140 sh >>= 4;
2141 const __m256i p_1 = lasx_madd_h(p16_1, __lasx_xvreplgr2vr_h(ls1));
2142 const __m256i p_2 = lasx_madd_h(p16_2, __lasx_xvreplgr2vr_h(ls2));
2143 sumi1 = __lasx_xvadd_w(p_1, sumi1);
2144 sumi2 = __lasx_xvadd_w(p_2, sumi2);
2145 }
2146 accum = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(GGML_CPU_FP16_TO_FP32(x[ibl].d)*y[ibl].d),
2147 __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accum);
2148 }
2149
2150 *s = hsum_float_8(accum);
2151
2152#else
2153 UNUSED(x);
2154 UNUSED(y);
2155 UNUSED(nb);
2156 ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
2157#endif
2158}
2159