Diffstat (limited to 'llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c')
-rw-r--r--  llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c | 4052
1 file changed, 4052 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c b/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c
new file mode 100644
index 0000000..b390ab6
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c
@@ -0,0 +1,4052 @@
#define GGML_COMMON_IMPL_C
#include "ggml-common.h"
#include "ggml-quants.h"
#include "ggml-impl.h"
#include "ggml-cpu.h"
#include "simd-mappings.h"

#include "../../quants.h"
#include "../../ggml-cpu-impl.h"

#include <math.h>
#include <string.h>
#include <assert.h>
#include <float.h>
#include <stdlib.h> // for qsort
#include <stdio.h>  // for GGML_ASSERT

#define GROUP_MAX_EPS 1e-15f
#define GROUP_MAX_EPS_IQ3_XXS 1e-8f
#define GROUP_MAX_EPS_IQ2_S 1e-8f
#define GROUP_MAX_EPS_IQ1_M 1e-7f
#define GROUP_MAX_EPS_IQ1_S 1e-12f

#define UNUSED GGML_UNUSED

#if defined(__ARM_NEON)
#define B1(c,s,n) 0x ## n ## c , 0x ## n ## s
#define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s)
#define B3(c,s,n) B2(c,s,n ## c), B2(c,s,n ## s)
#define B4(c,s,n) B3(c,s,n ## c), B3(c,s,n ## s)
#define B5(c,s,n) B4(c,s,n ## c), B4(c,s,n ## s)
#define B6(c,s,n) B5(c,s,n ## c), B5(c,s,n ## s)
#define B7(c,s,n) B6(c,s,n ## c), B6(c,s,n ## s)
#define B8(c,s ) B7(c,s, c), B7(c,s, s)

// precomputed tables for expanding 8 bits to 8 bytes:
static const uint64_t table_b2b_0[1 << 8] = { B8(00, 10) }; // ( b) << 4
static const uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4
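// byte j of table_b2b_0[b] is (bit j of b) << 4 and byte j of table_b2b_1[b] is
// (!bit j of b) << 4, so a single table lookup expands one byte of packed high
// bits into a 0x00/0x10 mask covering eight quants at once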
#endif

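// quantize a row of floats to q8_0: per 32-value block, take d = absmax/127 and
// round each value to the nearest multiple of d, storing the int8 quants and d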
void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
    assert(QK8_0 == 32);
    assert(k % QK8_0 == 0);
    const int nb = k / QK8_0;

    block_q8_0 * GGML_RESTRICT y = vy;

#if defined(__ARM_NEON)
    for (int i = 0; i < nb; i++) {
        float32x4_t srcv [8];
        float32x4_t asrcv[8];
        float32x4_t amaxv[8];

        for (int j = 0; j < 8; j++) srcv[j]  = vld1q_f32(x + i*32 + 4*j);
        for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]);

        for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]);
        for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]);
        for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]);

        const float amax = vmaxvq_f32(amaxv[0]);

        const float d = amax / ((1 << 7) - 1);
        const float id = d ? 1.0f/d : 0.0f;

        y[i].d = GGML_CPU_FP32_TO_FP16(d);

        for (int j = 0; j < 8; j++) {
            const float32x4_t v = vmulq_n_f32(srcv[j], id);
            const int32x4_t  vi = vcvtnq_s32_f32(v);

            y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0);
            y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1);
            y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2);
            y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3);
        }
    }
#else
    GGML_UNUSED(nb);
    // scalar
    quantize_row_q8_0_ref(x, y, k);
#endif
}

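// same scheme as q8_0, but each block also stores s = d * sum(quants), which lets
// the q4_1/q5_1 dot products fold in their per-block minimum without a second pass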
void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
    assert(k % QK8_1 == 0);
    const int nb = k / QK8_1;

    block_q8_1 * GGML_RESTRICT y = vy;
#if defined(__ARM_NEON)
    for (int i = 0; i < nb; i++) {
        float32x4_t srcv [8];
        float32x4_t asrcv[8];
        float32x4_t amaxv[8];

        for (int j = 0; j < 8; j++) srcv[j]  = vld1q_f32(x + i*32 + 4*j);
        for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]);

        for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]);
        for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]);
        for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]);

        const float amax = vmaxvq_f32(amaxv[0]);

        const float d = amax / ((1 << 7) - 1);
        const float id = d ? 1.0f/d : 0.0f;

        y[i].d = GGML_CPU_FP32_TO_FP16(d);

        int32x4_t accv = vdupq_n_s32(0);

        for (int j = 0; j < 8; j++) {
            const float32x4_t v = vmulq_n_f32(srcv[j], id);
            const int32x4_t  vi = vcvtnq_s32_f32(v);

            y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0);
            y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1);
            y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2);
            y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3);

            accv = vaddq_s32(accv, vi);
        }

        y[i].s = GGML_CPU_FP32_TO_FP16(d * vaddvq_s32(accv));
    }
#else
    GGML_UNUSED(nb);
    // scalar
    quantize_row_q8_1_ref(x, y, k);
#endif
}

// placeholder implementation for Apple targets
void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {
    quantize_row_q8_K_ref(x, y, k);
}

//===================================== Dot products =================================

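// q4_0 * q8_0 dot product: unpack the 4-bit quants to int8, recenter by subtracting 8,
// accumulate integer dot products per block, then scale each block sum by d_x * d_y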
void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    const int qk = QK8_0;
    const int nb = n / qk;

    assert(n % qk == 0);
#if defined(__ARM_FEATURE_MATMUL_INT8)
    assert((nrc == 2) || (nrc == 1));
#else
    assert(nrc == 1);
#endif
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_q4_0 * GGML_RESTRICT x = vx;
    const block_q8_0 * GGML_RESTRICT y = vy;

#if defined(__ARM_FEATURE_MATMUL_INT8)
    if (nrc == 2) {
        const block_q4_0 * GGML_RESTRICT vx0 = vx;
        const block_q4_0 * GGML_RESTRICT vx1 = (const block_q4_0 *) ((const uint8_t*)vx + bx);
        const block_q8_0 * GGML_RESTRICT vy0 = vy;
        const block_q8_0 * GGML_RESTRICT vy1 = (const block_q8_0 *) ((const uint8_t*)vy + by);

        float32x4_t sumv0 = vdupq_n_f32(0.0f);

        for (int i = 0; i < nb; i++) {
            const block_q4_0 * GGML_RESTRICT b_x0 = &vx0[i];
            const block_q4_0 * GGML_RESTRICT b_x1 = &vx1[i];
            const block_q8_0 * GGML_RESTRICT b_y0 = &vy0[i];
            const block_q8_0 * GGML_RESTRICT b_y1 = &vy1[i];

            const uint8x16_t m4b = vdupq_n_u8(0x0F);
            const int8x16_t  s8b = vdupq_n_s8(0x8);

            const uint8x16_t v0_0 = vld1q_u8(b_x0->qs);
            const uint8x16_t v0_1 = vld1q_u8(b_x1->qs);

            // 4-bit -> 8-bit
            const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
            const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
            const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
            const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));

            // sub 8
            const int8x16_t x0_l = vsubq_s8(v0_0l, s8b);
            const int8x16_t x0_h = vsubq_s8(v0_0h, s8b);
            const int8x16_t x1_l = vsubq_s8(v0_1l, s8b);
            const int8x16_t x1_h = vsubq_s8(v0_1h, s8b);

            // load y
            const int8x16_t y0_l = vld1q_s8(b_y0->qs);
            const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16);
            const int8x16_t y1_l = vld1q_s8(b_y1->qs);
            const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);

            float32_t _scale[4] = {
                GGML_CPU_FP16_TO_FP32(b_x0->d)*GGML_CPU_FP16_TO_FP32(b_y0->d),
                GGML_CPU_FP16_TO_FP32(b_x0->d)*GGML_CPU_FP16_TO_FP32(b_y1->d),
                GGML_CPU_FP16_TO_FP32(b_x1->d)*GGML_CPU_FP16_TO_FP32(b_y0->d),
                GGML_CPU_FP16_TO_FP32(b_x1->d)*GGML_CPU_FP16_TO_FP32(b_y1->d)
            };
            float32x4_t scale = vld1q_f32(_scale);

            int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
            int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));

            int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
            int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));

            int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
            int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));

            int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
            int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));

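            // each vmmlaq_s32 treats its operands as 2x8 int8 matrices and accumulates the
            // 2x2 matrix of row dot products, so the chain below yields all four combinations
            // x0.y0, x0.y1, x1.y0, x1.y1 in one vector; _scale above is laid out to match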
            sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)),
                                                                l1, r1)), l2, r2)), l3, r3))), scale);
        }

        float32x4_t sumv1 = vextq_f32 (sumv0, sumv0, 2);
        float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);

        vst1_f32(s,      vget_low_f32 (sumv2));
        vst1_f32(s + bs, vget_high_f32(sumv2));

        return;
    }
#endif

    int ib = 0;
    float sumf = 0;

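    // the SVE path dispatches on the runtime vector length (128/256/512 bits) so one
    // binary can use whatever register width the core provides; each case consumes two
    // blocks per iteration and any remainder falls through to the scalar loop below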
#if defined(__ARM_FEATURE_SVE)
    svfloat32_t sumv0 = svdup_n_f32(0.0f);
    svfloat32_t sumv1 = svdup_n_f32(0.0f);

    const int vector_length = ggml_cpu_get_sve_cnt()*8;

    // VLA Implementation using switch case
    switch (vector_length) {
        case 128:
            {
                // predicate for activating higher lanes for 4 float32 elements
                const svbool_t ph4 = svptrue_pat_b32(SV_VL4);

                for (; ib + 1 < nb; ib += 2) {
                    const block_q4_0 * GGML_RESTRICT x0 = &x[ib + 0];
                    const block_q4_0 * GGML_RESTRICT x1 = &x[ib + 1];
                    const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0];
                    const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1];

                    // load x
                    const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs);
                    const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs);

                    // 4-bit -> 8-bit
                    const svint8_t qx0l = svreinterpret_s8_u8(svand_n_u8_m(svptrue_b8(), qx0r, 0x0F));
                    const svint8_t qx0h = svreinterpret_s8_u8(svlsr_n_u8_m(svptrue_b8(), qx0r, 0x04));
                    const svint8_t qx1l = svreinterpret_s8_u8(svand_n_u8_m(svptrue_b8(), qx1r, 0x0F));
                    const svint8_t qx1h = svreinterpret_s8_u8(svlsr_n_u8_m(svptrue_b8(), qx1r, 0x04));

                    // sub 8
                    const svint8_t qx0ls = svsub_n_s8_x(svptrue_b8(), qx0h, 8);
                    const svint8_t qx0hs = svsub_n_s8_x(svptrue_b8(), qx0l, 8);
                    const svint8_t qx1ls = svsub_n_s8_x(svptrue_b8(), qx1h, 8);
                    const svint8_t qx1hs = svsub_n_s8_x(svptrue_b8(), qx1l, 8);

                    // load y
                    const svint8_t qy0h = svld1_s8(svptrue_b8(), y0->qs);
                    const svint8_t qy0l = svld1_s8(svptrue_b8(), y0->qs + 16);
                    const svint8_t qy1h = svld1_s8(svptrue_b8(), y1->qs);
                    const svint8_t qy1l = svld1_s8(svptrue_b8(), y1->qs + 16);

                    // dot product
                    sumv0 = svmla_n_f32_x(ph4, sumv0, svcvt_f32_s32_x(ph4, svadd_x(ph4,
                                    svdot_s32(svdup_n_s32(0), qx0ls, qy0l),
                                    svdot_s32(svdup_n_s32(0), qx0hs, qy0h))), GGML_CPU_FP16_TO_FP32(x0->d)*GGML_CPU_FP16_TO_FP32(y0->d));
                    sumv1 = svmla_n_f32_x(ph4, sumv1, svcvt_f32_s32_x(ph4, svadd_x(ph4,
                                    svdot_s32(svdup_n_s32(0), qx1ls, qy1l),
                                    svdot_s32(svdup_n_s32(0), qx1hs, qy1h))), GGML_CPU_FP16_TO_FP32(x1->d)*GGML_CPU_FP16_TO_FP32(y1->d));
                }

                sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
            } break;
        case 256:
            {
                // predicate for activating higher lanes for 16 int8 elements
                const svbool_t ph16 = svptrue_pat_b8(SV_VL16);
                // predicate for activating lower lanes for 16 int8 elements
                const svbool_t pl16 = svnot_b_z(svptrue_b8(), ph16);

                for (; ib + 1 < nb; ib += 2) {
                    const block_q4_0 * GGML_RESTRICT x0 = &x[ib + 0];
                    const block_q4_0 * GGML_RESTRICT x1 = &x[ib + 1];
                    const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0];
                    const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1];

                    // load x
                    const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs);
                    const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs);

                    // 4-bit -> 8-bit
                    const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx0r, 0x0F), 0x04));
                    const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx1r, 0x0F), 0x04));

                    // sub 8
                    const svint8_t qx0s = svsub_n_s8_x(svptrue_b8(), qx0, 8);
                    const svint8_t qx1s = svsub_n_s8_x(svptrue_b8(), qx1, 8);

                    // load y
                    const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
                    const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);

                    // dot product
                    sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(),
                                    svdot_s32(svdup_n_s32(0), qx0s, qy0)), GGML_CPU_FP16_TO_FP32(x0->d)*GGML_CPU_FP16_TO_FP32(y0->d));
                    sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(),
                                    svdot_s32(svdup_n_s32(0), qx1s, qy1)), GGML_CPU_FP16_TO_FP32(x1->d)*GGML_CPU_FP16_TO_FP32(y1->d));
                }

                sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
            } break;
        case 512:
            {
                // predicate for activating higher lanes for 32 int8 elements
                const svbool_t ph32 = svptrue_pat_b8(SV_VL32);

                // predicate for activating higher lanes for 16 int8 elements
                const svbool_t ph16 = svptrue_pat_b8(SV_VL16);
                // predicate for activating lower lanes for 16 int8 elements from first 32 int8 activated lanes
                const svbool_t pl16 = svnot_b_z(ph32, ph16);

                for (; ib + 1 < nb; ib += 2) {
                    const block_q4_0 * GGML_RESTRICT x0 = &x[ib + 0];
                    const block_q4_0 * GGML_RESTRICT x1 = &x[ib + 1];
                    const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0];
                    const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1];

                    // load x
                    const svuint8_t qx0r = svld1rq_u8(ph32, x0->qs);
                    const svuint8_t qx1r = svld1rq_u8(ph32, x1->qs);

                    // 4-bit -> 8-bit
                    const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx0r, 0x0F), 0x04));
                    const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx1r, 0x0F), 0x04));

                    // sub 8
                    const svint8_t qx0s = svsub_n_s8_x(ph32, qx0, 8);
                    const svint8_t qx1s = svsub_n_s8_x(ph32, qx1, 8);

                    // load y
                    const svint8_t qy0 = svld1_s8(ph32, y0->qs);
                    const svint8_t qy1 = svld1_s8(ph32, y1->qs);

                    // dot product
                    sumv0 = svmla_n_f32_x(ph32, sumv0, svcvt_f32_s32_x(ph32,
                                    svdot_s32(svdup_n_s32(0), qx0s, qy0)), GGML_CPU_FP16_TO_FP32(x0->d)*GGML_CPU_FP16_TO_FP32(y0->d));
                    sumv1 = svmla_n_f32_x(ph32, sumv1, svcvt_f32_s32_x(ph32,
                                    svdot_s32(svdup_n_s32(0), qx1s, qy1)), GGML_CPU_FP16_TO_FP32(x1->d)*GGML_CPU_FP16_TO_FP32(y1->d));
                }

                sumf = svaddv_f32(ph32, svadd_f32_x(ph32, sumv0, sumv1));
            } break;
        default:
            assert(false && "Unsupported vector length");
            break;
    }

#elif defined(__ARM_NEON)
    float32x4_t sumv0 = vdupq_n_f32(0.0f);
    float32x4_t sumv1 = vdupq_n_f32(0.0f);

    for (; ib + 1 < nb; ib += 2) {
        const block_q4_0 * GGML_RESTRICT x0 = &x[ib + 0];
        const block_q4_0 * GGML_RESTRICT x1 = &x[ib + 1];
        const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0];
        const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1];

        const uint8x16_t m4b = vdupq_n_u8(0x0F);
        const int8x16_t  s8b = vdupq_n_s8(0x8);

        const uint8x16_t v0_0 = vld1q_u8(x0->qs);
        const uint8x16_t v0_1 = vld1q_u8(x1->qs);

        // 4-bit -> 8-bit
        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
        const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
        const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
        const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));

        // sub 8
        const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
        const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
        const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
        const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);

        // load y
        const int8x16_t v1_0l = vld1q_s8(y0->qs);
        const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
        const int8x16_t v1_1l = vld1q_s8(y1->qs);
        const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);

        // dot product into int32x4_t
        const int32x4_t p_0 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l), v0_0hs, v1_0h);
        const int32x4_t p_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l), v0_1hs, v1_1h);

        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_CPU_FP16_TO_FP32(x0->d)*GGML_CPU_FP16_TO_FP32(y0->d));
        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_CPU_FP16_TO_FP32(x1->d)*GGML_CPU_FP16_TO_FP32(y1->d));
    }

    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
#endif
    for (; ib < nb; ++ib) {
        int sumi0 = 0;
        int sumi1 = 0;

        for (int j = 0; j < qk/2; ++j) {
            const int v0 = (x[ib].qs[j] & 0x0F) - 8;
            const int v1 = (x[ib].qs[j] >>   4) - 8;

            sumi0 += (v0 * y[ib].qs[j]);
            sumi1 += (v1 * y[ib].qs[j + qk/2]);
        }

        int sumi = sumi0 + sumi1;
        sumf += sumi*GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d);
    }

    *s = sumf;
}

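// q4_1 * q8_1 dot product: q4_1 stores an explicit per-block minimum m, so
// sum(x*y) = d_x*d_y * sum(qx*qy) + m_x * s_y, with s_y precomputed by quantize_row_q8_1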
void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    const int qk = QK8_1;
    const int nb = n / qk;

    assert(n % qk == 0);
#if defined(__ARM_FEATURE_MATMUL_INT8)
    assert((nrc == 2) || (nrc == 1));
#else
    assert(nrc == 1);
#endif
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_q4_1 * GGML_RESTRICT x = vx;
    const block_q8_1 * GGML_RESTRICT y = vy;

#if defined(__ARM_FEATURE_MATMUL_INT8)
    if (nrc == 2) {
        const block_q4_1 * GGML_RESTRICT vx0 = vx;
        const block_q4_1 * GGML_RESTRICT vx1 = (const block_q4_1 *) ((const uint8_t*)vx + bx);
        const block_q8_1 * GGML_RESTRICT vy0 = vy;
        const block_q8_1 * GGML_RESTRICT vy1 = (const block_q8_1 *) ((const uint8_t*)vy + by);

        float32x4_t sumv0  = vdupq_n_f32(0.0f);
        float32x4_t summs0 = vdupq_n_f32(0.0f);

        for (int i = 0; i < nb; i++) {
            const block_q4_1 * GGML_RESTRICT b_x0 = &vx0[i];
            const block_q4_1 * GGML_RESTRICT b_x1 = &vx1[i];
            const block_q8_1 * GGML_RESTRICT b_y0 = &vy0[i];
            const block_q8_1 * GGML_RESTRICT b_y1 = &vy1[i];

            float32_t summs_t[4] = {
                GGML_CPU_FP16_TO_FP32(b_x0->m) * GGML_CPU_FP16_TO_FP32(b_y0->s),
                GGML_CPU_FP16_TO_FP32(b_x1->m) * GGML_CPU_FP16_TO_FP32(b_y0->s),
                GGML_CPU_FP16_TO_FP32(b_x0->m) * GGML_CPU_FP16_TO_FP32(b_y1->s),
                GGML_CPU_FP16_TO_FP32(b_x1->m) * GGML_CPU_FP16_TO_FP32(b_y1->s)
            };
            summs0 = vaddq_f32(summs0, vld1q_f32(summs_t));

            const uint8x16_t m4b = vdupq_n_u8(0x0F);

            const uint8x16_t v0_0 = vld1q_u8(b_x0->qs);
            const uint8x16_t v0_1 = vld1q_u8(b_x1->qs);

            // 4-bit -> 8-bit
            const int8x16_t x0_l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
            const int8x16_t x0_h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
            const int8x16_t x1_l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
            const int8x16_t x1_h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));

            // load y
            const int8x16_t y0_l = vld1q_s8(b_y0->qs);
            const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16);
            const int8x16_t y1_l = vld1q_s8(b_y1->qs);
            const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);

            // mmla into int32x4_t
            float32_t _scale[4] = {
                GGML_CPU_FP16_TO_FP32(b_x0->d)*GGML_CPU_FP16_TO_FP32(b_y0->d),
                GGML_CPU_FP16_TO_FP32(b_x0->d)*GGML_CPU_FP16_TO_FP32(b_y1->d),
                GGML_CPU_FP16_TO_FP32(b_x1->d)*GGML_CPU_FP16_TO_FP32(b_y0->d),
                GGML_CPU_FP16_TO_FP32(b_x1->d)*GGML_CPU_FP16_TO_FP32(b_y1->d)
            };
            float32x4_t scale = vld1q_f32(_scale);

            int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
            int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));

            int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
            int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));

            int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
            int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));

            int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
            int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
            sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)),
                                                                l1, r1)), l2, r2)), l3, r3))), scale);
        }

        float32x4_t sumv1 = vextq_f32 (sumv0, sumv0, 2);
        float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);

        sumv2 = vaddq_f32(sumv2, summs0);

        vst1_f32(s,      vget_low_f32 (sumv2));
        vst1_f32(s + bs, vget_high_f32(sumv2));

        return;
    }
#endif

    int ib = 0;
    float sumf = 0;

#if defined(__ARM_NEON)
    float32x4_t sumv0 = vdupq_n_f32(0.0f);
    float32x4_t sumv1 = vdupq_n_f32(0.0f);

    float summs = 0;

    for (; ib + 1 < nb; ib += 2) {
        const block_q4_1 * GGML_RESTRICT x0 = &x[ib + 0];
        const block_q4_1 * GGML_RESTRICT x1 = &x[ib + 1];
        const block_q8_1 * GGML_RESTRICT y0 = &y[ib + 0];
        const block_q8_1 * GGML_RESTRICT y1 = &y[ib + 1];

        summs += GGML_CPU_FP16_TO_FP32(x0->m) * GGML_CPU_FP16_TO_FP32(y0->s) + GGML_CPU_FP16_TO_FP32(x1->m) * GGML_CPU_FP16_TO_FP32(y1->s);

        const uint8x16_t m4b = vdupq_n_u8(0x0F);

        const uint8x16_t v0_0 = vld1q_u8(x0->qs);
        const uint8x16_t v0_1 = vld1q_u8(x1->qs);

        // 4-bit -> 8-bit
        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
        const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
        const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
        const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));

        // load y
        const int8x16_t v1_0l = vld1q_s8(y0->qs);
        const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
        const int8x16_t v1_1l = vld1q_s8(y1->qs);
        const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);

        // dot product into int32x4_t
        const int32x4_t p_0 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l), v0_0h, v1_0h);
        const int32x4_t p_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l), v0_1h, v1_1h);

        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_CPU_FP16_TO_FP32(x0->d)*GGML_CPU_FP16_TO_FP32(y0->d));
        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_CPU_FP16_TO_FP32(x1->d)*GGML_CPU_FP16_TO_FP32(y1->d));
    }

    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs;

#endif
    for (; ib < nb; ++ib) {
        int sumi0 = 0;
        int sumi1 = 0;

        for (int j = 0; j < qk/2; ++j) {
            const int v0 = (x[ib].qs[j] & 0x0F);
            const int v1 = (x[ib].qs[j] >>   4);

            sumi0 += (v0 * y[ib].qs[j]);
            sumi1 += (v1 * y[ib].qs[j + qk/2]);
        }

        int sumi = sumi0 + sumi1;
        sumf += (GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d))*sumi + GGML_CPU_FP16_TO_FP32(x[ib].m)*GGML_CPU_FP16_TO_FP32(y[ib].s);
    }

    *s = sumf;
}

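// mxfp4 * q8_0 dot product: each 4-bit code indexes the signed kvalues_mxfp4 table
// (one tbl instruction per 16 nibbles), and the block scale is an E8M0 exponent
// rather than an fp16 value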
void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);
    assert(n % QK_MXFP4 == 0);
    static_assert(QK_MXFP4 == QK8_0, "QK_MXFP4 and QK8_0 must be the same");

    const block_mxfp4 * GGML_RESTRICT x = vx;
    const block_q8_0  * GGML_RESTRICT y = vy;

    const int nb = n / QK_MXFP4;

    int ib = 0;
    float sumf = 0;

#if defined __ARM_NEON
    const int8x16_t values = vld1q_s8(kvalues_mxfp4);
    const uint8x16_t m4b = vdupq_n_u8(0x0f);
    uint8x16x2_t q4bits;
    int8x16x4_t q4b;
    int8x16x4_t q8b;
    int32x4_t prod_1;
    int32x4_t prod_2;

    for (; ib + 1 < nb; ib += 2) {
        q4bits.val[0] = vld1q_u8(x[ib + 0].qs);
        q4bits.val[1] = vld1q_u8(x[ib + 1].qs);
        q8b.val[0]    = vld1q_s8(y[ib + 0].qs);
        q8b.val[1]    = vld1q_s8(y[ib + 0].qs + 16);
        q8b.val[2]    = vld1q_s8(y[ib + 1].qs);
        q8b.val[3]    = vld1q_s8(y[ib + 1].qs + 16);

        q4b.val[0] = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[0], m4b));
        q4b.val[1] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[0], 4));
        q4b.val[2] = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[1], m4b));
        q4b.val[3] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[1], 4));

        prod_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[0], q8b.val[0]), q4b.val[1], q8b.val[1]);
        prod_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[2], q8b.val[2]), q4b.val[3], q8b.val[3]);

        sumf +=
            GGML_E8M0_TO_FP32_HALF(x[ib + 0].e) * GGML_CPU_FP16_TO_FP32(y[ib + 0].d) * vaddvq_s32(prod_1) +
            GGML_E8M0_TO_FP32_HALF(x[ib + 1].e) * GGML_CPU_FP16_TO_FP32(y[ib + 1].d) * vaddvq_s32(prod_2);
    }

#endif
    for (; ib < nb; ++ib) {
        const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_E8M0_TO_FP32_HALF(x[ib].e);
        int sumi1 = 0;
        int sumi2 = 0;
        for (int j = 0; j < QK_MXFP4/2; ++j) {
            sumi1 += y[ib].qs[j + 0]          * kvalues_mxfp4[x[ib].qs[j] & 0xf];
            sumi2 += y[ib].qs[j + QK_MXFP4/2] * kvalues_mxfp4[x[ib].qs[j] >>  4];
        }
        sumf += d * (sumi1 + sumi2);
    }
    *s = sumf;
}

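// q5_0 * q8_0 dot product: the fifth bit of each quant is packed in qh; table_b2b_1
// expands it to a byte mask so the nibbles can be recentered to [-16, 15] with a
// single vector subtract (v - 0x10 when the high bit is clear, v - 0 when set)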
void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    const int qk = QK8_0;
    const int nb = n / qk;

    int ib = 0;
    float sumf = 0;

    assert(n % qk == 0);
    assert(qk == QK5_0);
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_q5_0 * GGML_RESTRICT x = vx;
    const block_q8_0 * GGML_RESTRICT y = vy;

#if defined(__ARM_NEON)
    float32x4_t sumv0 = vdupq_n_f32(0.0f);
    float32x4_t sumv1 = vdupq_n_f32(0.0f);

    uint32_t qh0;
    uint32_t qh1;

    uint64_t tmp0[4];
    uint64_t tmp1[4];

    for (; ib + 1 < nb; ib += 2) {
        const block_q5_0 * GGML_RESTRICT x0 = &x[ib];
        const block_q5_0 * GGML_RESTRICT x1 = &x[ib + 1];
        const block_q8_0 * GGML_RESTRICT y0 = &y[ib];
        const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1];

        const uint8x16_t m4b = vdupq_n_u8(0x0F);

        // extract the 5th bit via lookup table ((!b) << 4)
        memcpy(&qh0, x0->qh, sizeof(qh0));
        memcpy(&qh1, x1->qh, sizeof(qh1));

        tmp0[0] = table_b2b_1[(qh0 >>  0) & 0xFF];
        tmp0[1] = table_b2b_1[(qh0 >>  8) & 0xFF];
        tmp0[2] = table_b2b_1[(qh0 >> 16) & 0xFF];
        tmp0[3] = table_b2b_1[(qh0 >> 24)       ];

        tmp1[0] = table_b2b_1[(qh1 >>  0) & 0xFF];
        tmp1[1] = table_b2b_1[(qh1 >>  8) & 0xFF];
        tmp1[2] = table_b2b_1[(qh1 >> 16) & 0xFF];
        tmp1[3] = table_b2b_1[(qh1 >> 24)       ];

        const int8x16_t qhl0 = vld1q_s8((const int8_t *)(tmp0 + 0));
        const int8x16_t qhh0 = vld1q_s8((const int8_t *)(tmp0 + 2));
        const int8x16_t qhl1 = vld1q_s8((const int8_t *)(tmp1 + 0));
        const int8x16_t qhh1 = vld1q_s8((const int8_t *)(tmp1 + 2));

        const uint8x16_t v0_0 = vld1q_u8(x0->qs);
        const uint8x16_t v0_1 = vld1q_u8(x1->qs);

        // 4-bit -> 8-bit
        int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
        int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
        int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
        int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));

        // add high bit and sub 16 (equivalent to sub 0x10 when bit is zero)
        const int8x16_t v0_0lf = vsubq_s8(v0_0l, qhl0);
        const int8x16_t v0_0hf = vsubq_s8(v0_0h, qhh0);
        const int8x16_t v0_1lf = vsubq_s8(v0_1l, qhl1);
        const int8x16_t v0_1hf = vsubq_s8(v0_1h, qhh1);

        // load y
        const int8x16_t v1_0l = vld1q_s8(y0->qs);
        const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
        const int8x16_t v1_1l = vld1q_s8(y1->qs);
        const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);

        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
                        ggml_vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
                        ggml_vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_CPU_FP16_TO_FP32(x0->d)*GGML_CPU_FP16_TO_FP32(y0->d));
        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
                        ggml_vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
                        ggml_vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_CPU_FP16_TO_FP32(x1->d)*GGML_CPU_FP16_TO_FP32(y1->d));
    }

    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);

#endif
    for (; ib < nb; ++ib) {
        uint32_t qh;
        memcpy(&qh, x[ib].qh, sizeof(qh));

        int sumi0 = 0;
        int sumi1 = 0;

        for (int j = 0; j < qk/2; ++j) {
            const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
            const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12));

            const int32_t x0 = (int8_t)(((x[ib].qs[j] & 0x0F) | xh_0) - 16);
            const int32_t x1 = (int8_t)(((x[ib].qs[j] >>   4) | xh_1) - 16);

            sumi0 += (x0 * y[ib].qs[j]);
            sumi1 += (x1 * y[ib].qs[j + qk/2]);
        }

        int sumi = sumi0 + sumi1;
        sumf += (GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d)) * sumi;
    }

    *s = sumf;
}

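// q5_1 * q8_1 dot product: here the fifth bit is OR'ed in via table_b2b_0 (the quants
// stay unsigned), and the per-block minimum contributes m_x * s_y exactly as in q4_1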
void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    const int qk = QK8_1;
    const int nb = n / qk;

    int ib = 0;
    float sumf = 0;

    assert(n % qk == 0);
    assert(qk == QK5_1);
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_q5_1 * GGML_RESTRICT x = vx;
    const block_q8_1 * GGML_RESTRICT y = vy;

#if defined(__ARM_NEON)
    float32x4_t sumv0 = vdupq_n_f32(0.0f);
    float32x4_t sumv1 = vdupq_n_f32(0.0f);

    float summs0 = 0.0f;
    float summs1 = 0.0f;

    uint32_t qh0;
    uint32_t qh1;

    uint64_t tmp0[4];
    uint64_t tmp1[4];

    for (; ib + 1 < nb; ib += 2) {
        const block_q5_1 * GGML_RESTRICT x0 = &x[ib];
        const block_q5_1 * GGML_RESTRICT x1 = &x[ib + 1];
        const block_q8_1 * GGML_RESTRICT y0 = &y[ib];
        const block_q8_1 * GGML_RESTRICT y1 = &y[ib + 1];

        const uint8x16_t m4b = vdupq_n_u8(0x0F);

        summs0 += GGML_CPU_FP16_TO_FP32(x0->m) * GGML_CPU_FP16_TO_FP32(y0->s);
        summs1 += GGML_CPU_FP16_TO_FP32(x1->m) * GGML_CPU_FP16_TO_FP32(y1->s);

        // extract the 5th bit via lookup table ((b) << 4)
        memcpy(&qh0, x0->qh, sizeof(qh0));
        memcpy(&qh1, x1->qh, sizeof(qh1));

        tmp0[0] = table_b2b_0[(qh0 >>  0) & 0xFF];
        tmp0[1] = table_b2b_0[(qh0 >>  8) & 0xFF];
        tmp0[2] = table_b2b_0[(qh0 >> 16) & 0xFF];
        tmp0[3] = table_b2b_0[(qh0 >> 24)       ];

        tmp1[0] = table_b2b_0[(qh1 >>  0) & 0xFF];
        tmp1[1] = table_b2b_0[(qh1 >>  8) & 0xFF];
        tmp1[2] = table_b2b_0[(qh1 >> 16) & 0xFF];
        tmp1[3] = table_b2b_0[(qh1 >> 24)       ];

        const int8x16_t qhl0 = vld1q_s8((const int8_t *)(tmp0 + 0));
        const int8x16_t qhh0 = vld1q_s8((const int8_t *)(tmp0 + 2));
        const int8x16_t qhl1 = vld1q_s8((const int8_t *)(tmp1 + 0));
        const int8x16_t qhh1 = vld1q_s8((const int8_t *)(tmp1 + 2));

        const uint8x16_t v0_0 = vld1q_u8(x0->qs);
        const uint8x16_t v0_1 = vld1q_u8(x1->qs);

        // 4-bit -> 8-bit
        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
        const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
        const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
        const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));

        // add high bit
        const int8x16_t v0_0lf = vorrq_s8(v0_0l, qhl0);
        const int8x16_t v0_0hf = vorrq_s8(v0_0h, qhh0);
        const int8x16_t v0_1lf = vorrq_s8(v0_1l, qhl1);
        const int8x16_t v0_1hf = vorrq_s8(v0_1h, qhh1);

        // load y
        const int8x16_t v1_0l = vld1q_s8(y0->qs);
        const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
        const int8x16_t v1_1l = vld1q_s8(y1->qs);
        const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);

        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
                        ggml_vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
                        ggml_vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_CPU_FP16_TO_FP32(x0->d)*GGML_CPU_FP16_TO_FP32(y0->d));
        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
                        ggml_vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
                        ggml_vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_CPU_FP16_TO_FP32(x1->d)*GGML_CPU_FP16_TO_FP32(y1->d));
    }

    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs0 + summs1;

#endif
    for (; ib < nb; ++ib) {
        uint32_t qh;
        memcpy(&qh, x[ib].qh, sizeof(qh));

        int sumi0 = 0;
        int sumi1 = 0;

        for (int j = 0; j < qk/2; ++j) {
            const uint8_t xh_0 = ((qh >> (j +  0)) << 4) & 0x10;
            const uint8_t xh_1 = ((qh >> (j + 12))     ) & 0x10;

            const int32_t x0 = (x[ib].qs[j] & 0xF) | xh_0;
            const int32_t x1 = (x[ib].qs[j] >>  4) | xh_1;

            sumi0 += (x0 * y[ib].qs[j]);
            sumi1 += (x1 * y[ib].qs[j + qk/2]);
        }

        int sumi = sumi0 + sumi1;
        sumf += (GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d))*sumi + GGML_CPU_FP16_TO_FP32(x[ib].m)*GGML_CPU_FP16_TO_FP32(y[ib].s);
    }

    *s = sumf;
}

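// q8_0 * q8_0 dot product: both operands are already int8, so each block reduces to a
// plain integer dot product scaled by d_x * d_y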
void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    const int qk = QK8_0;
    const int nb = n / qk;

    assert(n % qk == 0);
#if defined(__ARM_FEATURE_MATMUL_INT8)
    assert((nrc == 2) || (nrc == 1));
#else
    assert(nrc == 1);
#endif
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_q8_0 * GGML_RESTRICT x = vx;
    const block_q8_0 * GGML_RESTRICT y = vy;

#if defined(__ARM_FEATURE_MATMUL_INT8)
    if (nrc == 2) {
        const block_q8_0 * GGML_RESTRICT vx0 = vx;
        const block_q8_0 * GGML_RESTRICT vx1 = (const block_q8_0 *) ((const uint8_t*)vx + bx);
        const block_q8_0 * GGML_RESTRICT vy0 = vy;
        const block_q8_0 * GGML_RESTRICT vy1 = (const block_q8_0 *) ((const uint8_t*)vy + by);

        float32x4_t sumv0 = vdupq_n_f32(0.0f);

        for (int i = 0; i < nb; i++) {
            const block_q8_0 * GGML_RESTRICT b_x0 = &vx0[i];
            const block_q8_0 * GGML_RESTRICT b_y0 = &vy0[i];

            const block_q8_0 * GGML_RESTRICT b_x1 = &vx1[i];
            const block_q8_0 * GGML_RESTRICT b_y1 = &vy1[i];

            const int8x16_t x0_l = vld1q_s8(b_x0->qs);
            const int8x16_t x0_h = vld1q_s8(b_x0->qs + 16);
            const int8x16_t x1_l = vld1q_s8(b_x1->qs);
            const int8x16_t x1_h = vld1q_s8(b_x1->qs + 16);

            // load y
            const int8x16_t y0_l = vld1q_s8(b_y0->qs);
            const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16);
            const int8x16_t y1_l = vld1q_s8(b_y1->qs);
            const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);

            float32_t _scale[4] = {
                GGML_CPU_FP16_TO_FP32(b_x0->d)*GGML_CPU_FP16_TO_FP32(b_y0->d),
                GGML_CPU_FP16_TO_FP32(b_x0->d)*GGML_CPU_FP16_TO_FP32(b_y1->d),
                GGML_CPU_FP16_TO_FP32(b_x1->d)*GGML_CPU_FP16_TO_FP32(b_y0->d),
                GGML_CPU_FP16_TO_FP32(b_x1->d)*GGML_CPU_FP16_TO_FP32(b_y1->d)
            };
            float32x4_t scale = vld1q_f32(_scale);

            int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
            int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));

            int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
            int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));

            int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
            int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));

            int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
            int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));

            sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)),
                                                                l1, r1)), l2, r2)), l3, r3))), scale);
        }

        float32x4_t sumv1 = vextq_f32 (sumv0, sumv0, 2);
        float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);

        vst1_f32(s,      vget_low_f32 (sumv2));
        vst1_f32(s + bs, vget_high_f32(sumv2));

        return;
    }
#endif

    int ib = 0;
    float sumf = 0;

#if defined(__ARM_FEATURE_SVE)
    svfloat32_t sumv0 = svdup_n_f32(0.0f);
    svfloat32_t sumv1 = svdup_n_f32(0.0f);

    const int vector_length = ggml_cpu_get_sve_cnt()*8;

    // VLA Implementation for SVE
    switch (vector_length) {
        case 128:
            {
                // predicate for activating lanes for 16 int8 elements
                const svbool_t ph16 = svptrue_pat_b8 (SV_VL16);
                const svbool_t pl16 = svptrue_pat_b32(SV_VL4);

                for (; ib + 1 < nb; ib += 2) {
                    const block_q8_0 * GGML_RESTRICT x0 = &x[ib + 0];
                    const block_q8_0 * GGML_RESTRICT x1 = &x[ib + 1];
                    const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0];
                    const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1];

                    // load x
                    const svint8_t qx0_0 = svld1_s8(ph16, x0->qs);
                    const svint8_t qx0_1 = svld1_s8(ph16, x0->qs+16);
                    const svint8_t qx1_0 = svld1_s8(ph16, x1->qs);
                    const svint8_t qx1_1 = svld1_s8(ph16, x1->qs+16);

                    // load y
                    const svint8_t qy0_0 = svld1_s8(ph16, y0->qs);
                    const svint8_t qy0_1 = svld1_s8(ph16, y0->qs+16);
                    const svint8_t qy1_0 = svld1_s8(ph16, y1->qs);
                    const svint8_t qy1_1 = svld1_s8(ph16, y1->qs+16);

                    sumv0 = svmla_n_f32_x(pl16, sumv0, svcvt_f32_s32_x(pl16, svadd_x(pl16,
                                    svdot_s32(svdup_n_s32(0), qx0_0, qy0_0),
                                    svdot_s32(svdup_n_s32(0), qx0_1, qy0_1))), GGML_CPU_FP16_TO_FP32(x0->d)*GGML_CPU_FP16_TO_FP32(y0->d));
                    sumv1 = svmla_n_f32_x(pl16, sumv1, svcvt_f32_s32_x(pl16, svadd_x(pl16,
                                    svdot_s32(svdup_n_s32(0), qx1_0, qy1_0),
                                    svdot_s32(svdup_n_s32(0), qx1_1, qy1_1))), GGML_CPU_FP16_TO_FP32(x1->d)*GGML_CPU_FP16_TO_FP32(y1->d));
                }

                sumf = svaddv_f32(pl16, svadd_f32_x(pl16, sumv0, sumv1));
            } break;
        case 256:
            {
                for (; ib + 1 < nb; ib += 2) {
                    const block_q8_0 * GGML_RESTRICT x0 = &x[ib + 0];
                    const block_q8_0 * GGML_RESTRICT x1 = &x[ib + 1];
                    const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0];
                    const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1];

                    // load x
                    const svint8_t qx0 = svld1_s8(svptrue_b8(), x0->qs);
                    const svint8_t qx1 = svld1_s8(svptrue_b8(), x1->qs);

                    // load y
                    const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
                    const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);

                    sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(),
                                    svdot_s32(svdup_n_s32(0), qx0, qy0)), GGML_CPU_FP16_TO_FP32(x0->d)*GGML_CPU_FP16_TO_FP32(y0->d));
                    sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(),
                                    svdot_s32(svdup_n_s32(0), qx1, qy1)), GGML_CPU_FP16_TO_FP32(x1->d)*GGML_CPU_FP16_TO_FP32(y1->d));
                }

                sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
            } break;
        case 512:
            {
                // predicate for activating high 256 bit
                const svbool_t ph32 = svptrue_pat_b8(SV_VL32);
                // predicate for activating low 256 bit
                const svbool_t pl32 = svnot_b_z(svptrue_b8(), ph32);

                // predicate for activating high lanes for 8 float32 elements
                const svbool_t ph8 = svptrue_pat_b32(SV_VL8);
                // predicate for activating low lanes for 8 float32 elements
                const svbool_t pl8 = svnot_b_z(svptrue_b32(), ph8);

                svfloat32_t sumv00 = svdup_n_f32(0.0f);

                for (; ib + 1 < nb; ib += 2) {
                    const block_q8_0 * GGML_RESTRICT x0 = &x[ib + 0];
                    const block_q8_0 * GGML_RESTRICT x1 = &x[ib + 1];
                    const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0];
                    const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1];

                    // load 32 int8_t into the first half of the vector and another 32 int8_t into the
                    // lower lanes of the second half, combining them into one 64-element vector
                    // load x
                    const svint8_t qx_32 = svld1_s8(ph32, x0->qs);
                    svint8_t qx_64 = svld1_s8(pl32, x0->qs + 2);
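                    // the "+ 2" works because block_q8_0 is 2 bytes of d followed by 32 bytes of qs,
                    // so lanes 32..63 of this predicated load land exactly on x[ib + 1].qs;
                    // the y load below uses the same trick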

                    qx_64 = svadd_s8_x(svptrue_b8(), qx_32, qx_64);

                    // load y
                    const svint8_t qy_32 = svld1_s8(ph32, y0->qs);
                    svint8_t qy_64 = svld1_s8(pl32, y0->qs + 2);

                    qy_64 = svadd_s8_x(svptrue_b8(), qy_32, qy_64);

                    // scale creation
                    const float32_t deq1 = GGML_CPU_FP16_TO_FP32(x0->d)*GGML_CPU_FP16_TO_FP32(y0->d);
                    const float32_t deq2 = GGML_CPU_FP16_TO_FP32(x1->d)*GGML_CPU_FP16_TO_FP32(y1->d);

                    // duplicate deq1 in first half of vector and deq2 in second half of vector
                    const svfloat32_t temp = svdup_f32_m(svdup_f32_z(ph8, deq1), pl8, deq2);

                    const svfloat32_t sumvt = svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx_64, qy_64));

                    sumv00 = svmla_f32_m(svptrue_b32(), sumv00, sumvt, temp);
                }

                sumf = svaddv_f32(svptrue_b32(), sumv00);
                break;
            }
        default:
            assert(false && "Unsupported vector length");
            break;
    }
#elif defined(__ARM_NEON)
    float32x4_t sumv0 = vdupq_n_f32(0.0f);
    float32x4_t sumv1 = vdupq_n_f32(0.0f);

    for (; ib + 1 < nb; ib += 2) {
        const block_q8_0 * GGML_RESTRICT x0 = &x[ib + 0];
        const block_q8_0 * GGML_RESTRICT x1 = &x[ib + 1];
        const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0];
        const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1];

        const int8x16_t x0_0 = vld1q_s8(x0->qs);
        const int8x16_t x0_1 = vld1q_s8(x0->qs + 16);
        const int8x16_t x1_0 = vld1q_s8(x1->qs);
        const int8x16_t x1_1 = vld1q_s8(x1->qs + 16);

        // load y
        const int8x16_t y0_0 = vld1q_s8(y0->qs);
        const int8x16_t y0_1 = vld1q_s8(y0->qs + 16);
        const int8x16_t y1_0 = vld1q_s8(y1->qs);
        const int8x16_t y1_1 = vld1q_s8(y1->qs + 16);

        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
                        ggml_vdotq_s32(vdupq_n_s32(0), x0_0, y0_0),
                        ggml_vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), GGML_CPU_FP16_TO_FP32(x0->d)*GGML_CPU_FP16_TO_FP32(y0->d));

        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
                        ggml_vdotq_s32(vdupq_n_s32(0), x1_0, y1_0),
                        ggml_vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), GGML_CPU_FP16_TO_FP32(x1->d)*GGML_CPU_FP16_TO_FP32(y1->d));
    }

    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
#endif
    for (; ib < nb; ++ib) {
        int sumi = 0;

        for (int j = 0; j < qk; j++) {
            sumi += x[ib].qs[j]*y[ib].qs[j];
        }

        sumf += sumi*(GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d));
    }

    *s = sumf;
}

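// tq1_0 * q8_K dot product: tq1_0 packs ternary digits base-3 (5 per qs byte, 4 per qh
// byte); multiplying a byte by 3^k brings digit k to the top, and (q + (q >> 1)) >> 6
// effectively computes floor(q*3/256), i.e. that digit in {0, 1, 2}; the bsums
// subtraction at the end recenters the digits to {-1, 0, 1}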
void ggml_vec_dot_tq1_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_tq1_0 * GGML_RESTRICT x = vx;
    const block_q8_K * GGML_RESTRICT y = vy;

    const int nb = n / QK_K;

#if defined(__ARM_NEON)
    float sumf = 0.0f;

    uint8_t k_shift[16] = {1, 1, 1, 1, 3, 3, 3, 3, 9, 9, 9, 9, 27, 27, 27, 27};

    const uint8x16_t shift = vld1q_u8(k_shift);

    for (int i = 0; i < nb; ++i) {
#if defined(__ARM_FEATURE_DOTPROD)
        int32x4_t sumi0 = vdupq_n_s32(0);
        int32x4_t sumi1 = vdupq_n_s32(0);
#else
        int16x8_t sumi0 = vdupq_n_s16(0);
        int16x8_t sumi1 = vdupq_n_s16(0);
#endif

        // first 32 bytes, 5 ternary values per byte
        {
            uint8x16_t qx0 = vld1q_u8(x[i].qs + 0);
            uint8x16_t qx1 = vld1q_u8(x[i].qs + 16);
            uint8x16_t qx2 = vmulq_u8(qx0, vdupq_n_u8(3));
            uint8x16_t qx3 = vmulq_u8(qx1, vdupq_n_u8(3));
            uint8x16_t qx4 = vmulq_u8(qx0, vdupq_n_u8(9));
            uint8x16_t qx5 = vmulq_u8(qx1, vdupq_n_u8(9));
            uint8x16_t qx6 = vmulq_u8(qx0, vdupq_n_u8(27));
            uint8x16_t qx7 = vmulq_u8(qx1, vdupq_n_u8(27));
            uint8x16_t qx8 = vmulq_u8(qx0, vdupq_n_u8(81));
            uint8x16_t qx9 = vmulq_u8(qx1, vdupq_n_u8(81));

            // multiply by 3 and keep the 2 bits above 8 bits
            int8x16_t sqx0 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx0, vshrq_n_u8(qx0, 1)), 6));
            int8x16_t sqx1 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx1, vshrq_n_u8(qx1, 1)), 6));
            int8x16_t sqx2 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx2, vshrq_n_u8(qx2, 1)), 6));
            int8x16_t sqx3 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx3, vshrq_n_u8(qx3, 1)), 6));
            int8x16_t sqx4 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx4, vshrq_n_u8(qx4, 1)), 6));
            int8x16_t sqx5 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx5, vshrq_n_u8(qx5, 1)), 6));
            int8x16_t sqx6 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx6, vshrq_n_u8(qx6, 1)), 6));
            int8x16_t sqx7 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx7, vshrq_n_u8(qx7, 1)), 6));
            int8x16_t sqx8 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx8, vshrq_n_u8(qx8, 1)), 6));
            int8x16_t sqx9 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx9, vshrq_n_u8(qx9, 1)), 6));

            const int8x16_t qy0 = vld1q_s8(y[i].qs + 0);
            const int8x16_t qy1 = vld1q_s8(y[i].qs + 16);
            const int8x16_t qy2 = vld1q_s8(y[i].qs + 32);
            const int8x16_t qy3 = vld1q_s8(y[i].qs + 48);
            const int8x16_t qy4 = vld1q_s8(y[i].qs + 64);
            const int8x16_t qy5 = vld1q_s8(y[i].qs + 80);
            const int8x16_t qy6 = vld1q_s8(y[i].qs + 96);
            const int8x16_t qy7 = vld1q_s8(y[i].qs + 112);
            const int8x16_t qy8 = vld1q_s8(y[i].qs + 128);
            const int8x16_t qy9 = vld1q_s8(y[i].qs + 144);

#if defined(__ARM_FEATURE_DOTPROD)
            sumi0 = vdotq_s32(sumi0, sqx0, qy0);
            sumi1 = vdotq_s32(sumi1, sqx1, qy1);
            sumi0 = vdotq_s32(sumi0, sqx2, qy2);
            sumi1 = vdotq_s32(sumi1, sqx3, qy3);
            sumi0 = vdotq_s32(sumi0, sqx4, qy4);
            sumi1 = vdotq_s32(sumi1, sqx5, qy5);
            sumi0 = vdotq_s32(sumi0, sqx6, qy6);
            sumi1 = vdotq_s32(sumi1, sqx7, qy7);
            sumi0 = vdotq_s32(sumi0, sqx8, qy8);
            sumi1 = vdotq_s32(sumi1, sqx9, qy9);
#else
            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx0), vget_low_s8(qy0));
            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx0), vget_high_s8(qy0));
            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx1), vget_low_s8(qy1));
            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx1), vget_high_s8(qy1));
            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx2), vget_low_s8(qy2));
            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx2), vget_high_s8(qy2));
            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx3), vget_low_s8(qy3));
            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx3), vget_high_s8(qy3));
            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx4), vget_low_s8(qy4));
            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx4), vget_high_s8(qy4));
            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx5), vget_low_s8(qy5));
            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx5), vget_high_s8(qy5));
            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx6), vget_low_s8(qy6));
            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx6), vget_high_s8(qy6));
            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx7), vget_low_s8(qy7));
            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx7), vget_high_s8(qy7));
            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx8), vget_low_s8(qy8));
            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx8), vget_high_s8(qy8));
            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx9), vget_low_s8(qy9));
            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx9), vget_high_s8(qy9));
#endif
        }

        // last 16 bytes with 5 values per byte, plus the 4 qh bytes with 4 values per byte
        {
            uint8x16_t qx0 = vld1q_u8(x[i].qs + 32);
            uint8x16_t qx1 = vmulq_u8(qx0, vdupq_n_u8(3));
            uint8x16_t qx2 = vmulq_u8(qx0, vdupq_n_u8(9));
            uint8x16_t qx3 = vmulq_u8(qx0, vdupq_n_u8(27));
            uint8x16_t qx4 = vmulq_u8(qx0, vdupq_n_u8(81));
            uint32_t qh;
            memcpy(&qh, x[i].qh, sizeof(qh)); // potentially unaligned
            uint8x16_t qx5 = vreinterpretq_u8_u32(vdupq_n_u32(qh));
            qx5 = vmulq_u8(qx5, shift);

            // multiply by 3 and keep the 2 bits above 8 bits
            int8x16_t sqx0 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx0, vshrq_n_u8(qx0, 1)), 6));
            int8x16_t sqx1 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx1, vshrq_n_u8(qx1, 1)), 6));
            int8x16_t sqx2 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx2, vshrq_n_u8(qx2, 1)), 6));
            int8x16_t sqx3 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx3, vshrq_n_u8(qx3, 1)), 6));
            int8x16_t sqx4 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx4, vshrq_n_u8(qx4, 1)), 6));
            int8x16_t sqx5 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx5, vshrq_n_u8(qx5, 1)), 6));

            const int8x16_t qy0 = vld1q_s8(y[i].qs + 160);
            const int8x16_t qy1 = vld1q_s8(y[i].qs + 176);
            const int8x16_t qy2 = vld1q_s8(y[i].qs + 192);
            const int8x16_t qy3 = vld1q_s8(y[i].qs + 208);
            const int8x16_t qy4 = vld1q_s8(y[i].qs + 224);
            const int8x16_t qy5 = vld1q_s8(y[i].qs + 240);

#if defined(__ARM_FEATURE_DOTPROD)
            sumi0 = vdotq_s32(sumi0, sqx0, qy0);
            sumi1 = vdotq_s32(sumi1, sqx1, qy1);
            sumi0 = vdotq_s32(sumi0, sqx2, qy2);
            sumi1 = vdotq_s32(sumi1, sqx3, qy3);
            sumi0 = vdotq_s32(sumi0, sqx4, qy4);
            sumi1 = vdotq_s32(sumi1, sqx5, qy5);
#else
            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx0), vget_low_s8(qy0));
            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx0), vget_high_s8(qy0));
            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx1), vget_low_s8(qy1));
            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx1), vget_high_s8(qy1));
            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx2), vget_low_s8(qy2));
            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx2), vget_high_s8(qy2));
            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx3), vget_low_s8(qy3));
            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx3), vget_high_s8(qy3));
            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx4), vget_low_s8(qy4));
            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx4), vget_high_s8(qy4));
            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx5), vget_low_s8(qy5));
            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx5), vget_high_s8(qy5));
#endif
        }

        const int16x8_t ysum0 = vld1q_s16(y[i].bsums);
        const int16x8_t ysum1 = vld1q_s16(y[i].bsums + 8);

        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;

#if defined(__ARM_FEATURE_DOTPROD)
        sumi0 = vaddq_s32(sumi0, sumi1);
        sumi0 = vsubq_s32(sumi0, vpaddlq_s16(vaddq_s16(ysum0, ysum1)));

        sumf += d * (float) vaddvq_s32(sumi0);
#else
        sumi0 = vaddq_s16(sumi0, sumi1);
        sumi0 = vsubq_s16(sumi0, vaddq_s16(ysum0, ysum1));

        sumf += d * (float) vaddlvq_s16(sumi0);
#endif
    }

    *s = sumf;

#else
    UNUSED(x);
    UNUSED(y);
    UNUSED(nb);
    ggml_vec_dot_tq1_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}

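// tq2_0 * q8_K dot product: four 2-bit values per byte, unpacked with shifts and a
// mask; as with tq1_0 the stored values are offset by 1, so subtracting the bsums
// recenters them to {-1, 0, 1}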
void ggml_vec_dot_tq2_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_tq2_0 * GGML_RESTRICT x = vx;
    const block_q8_K * GGML_RESTRICT y = vy;

    const int nb = n / QK_K;

#if defined(__ARM_NEON)
    float sumf = 0.0f;

    const uint8x16_t m3 = vdupq_n_u8(3);

    for (int i = 0; i < nb; ++i) {
#if defined(__ARM_FEATURE_DOTPROD)
        int32x4_t sumi0 = vdupq_n_s32(0);
        int32x4_t sumi1 = vdupq_n_s32(0);
#else
        int16x8_t sumi0 = vdupq_n_s16(0);
        int16x8_t sumi1 = vdupq_n_s16(0);
#endif

        for (size_t j = 0; j < sizeof(x->qs); j += 32) {
            uint8x16_t qx0 = vld1q_u8(x[i].qs + j);
            uint8x16_t qx1 = vld1q_u8(x[i].qs + j + 16);
            uint8x16_t qx2 = vshrq_n_u8(qx0, 2);
            uint8x16_t qx3 = vshrq_n_u8(qx1, 2);
            uint8x16_t qx4 = vshrq_n_u8(qx0, 4);
            uint8x16_t qx5 = vshrq_n_u8(qx1, 4);
            uint8x16_t qx6 = vshrq_n_u8(qx0, 6);
            uint8x16_t qx7 = vshrq_n_u8(qx1, 6);

            int8x16_t sqx0 = vreinterpretq_s8_u8(vandq_u8(qx0, m3));
            int8x16_t sqx1 = vreinterpretq_s8_u8(vandq_u8(qx1, m3));
            int8x16_t sqx2 = vreinterpretq_s8_u8(vandq_u8(qx2, m3));
            int8x16_t sqx3 = vreinterpretq_s8_u8(vandq_u8(qx3, m3));
            int8x16_t sqx4 = vreinterpretq_s8_u8(vandq_u8(qx4, m3));
            int8x16_t sqx5 = vreinterpretq_s8_u8(vandq_u8(qx5, m3));
            int8x16_t sqx6 = vreinterpretq_s8_u8(vandq_u8(qx6, m3));
            int8x16_t sqx7 = vreinterpretq_s8_u8(vandq_u8(qx7, m3));

            const int8x16_t qy0 = vld1q_s8(y[i].qs + j*4 + 0);
            const int8x16_t qy1 = vld1q_s8(y[i].qs + j*4 + 16);
            const int8x16_t qy2 = vld1q_s8(y[i].qs + j*4 + 32);
            const int8x16_t qy3 = vld1q_s8(y[i].qs + j*4 + 48);
            const int8x16_t qy4 = vld1q_s8(y[i].qs + j*4 + 64);
            const int8x16_t qy5 = vld1q_s8(y[i].qs + j*4 + 80);
            const int8x16_t qy6 = vld1q_s8(y[i].qs + j*4 + 96);
            const int8x16_t qy7 = vld1q_s8(y[i].qs + j*4 + 112);

#if defined(__ARM_FEATURE_DOTPROD)
            sumi0 = vdotq_s32(sumi0, sqx0, qy0);
            sumi1 = vdotq_s32(sumi1, sqx1, qy1);
            sumi0 = vdotq_s32(sumi0, sqx2, qy2);
            sumi1 = vdotq_s32(sumi1, sqx3, qy3);
            sumi0 = vdotq_s32(sumi0, sqx4, qy4);
            sumi1 = vdotq_s32(sumi1, sqx5, qy5);
            sumi0 = vdotq_s32(sumi0, sqx6, qy6);
            sumi1 = vdotq_s32(sumi1, sqx7, qy7);
#else
            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx0), vget_low_s8(qy0));
            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx0), vget_high_s8(qy0));
            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx1), vget_low_s8(qy1));
            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx1), vget_high_s8(qy1));
            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx2), vget_low_s8(qy2));
            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx2), vget_high_s8(qy2));
            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx3), vget_low_s8(qy3));
            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx3), vget_high_s8(qy3));
            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx4), vget_low_s8(qy4));
            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx4), vget_high_s8(qy4));
            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx5), vget_low_s8(qy5));
            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx5), vget_high_s8(qy5));
            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx6), vget_low_s8(qy6));
            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx6), vget_high_s8(qy6));
            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx7), vget_low_s8(qy7));
            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx7), vget_high_s8(qy7));
#endif
        }

        const int16x8_t ysum0 = vld1q_s16(y[i].bsums);
        const int16x8_t ysum1 = vld1q_s16(y[i].bsums + 8);

        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;

#if defined(__ARM_FEATURE_DOTPROD)
        sumi0 = vaddq_s32(sumi0, sumi1);
        sumi0 = vsubq_s32(sumi0, vpaddlq_s16(vaddq_s16(ysum0, ysum1)));

        sumf += d * (float) vaddvq_s32(sumi0);
#else
        sumi0 = vaddq_s16(sumi0, sumi1);
        sumi0 = vsubq_s16(sumi0, vaddq_s16(ysum0, ysum1));

        sumf += d * (float) vaddlvq_s16(sumi0);
#endif
    }

    *s = sumf;

#else
    UNUSED(x);
    UNUSED(y);
    UNUSED(nb);
    ggml_vec_dot_tq2_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}

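// q2_K * q8_K dot product: each 256-value superblock holds 16 sub-blocks of 16 quants,
// with a 4-bit scale and 4-bit min packed per byte in x[i].scales; the mins are folded
// in through y's bsums and the scales through per-sub-block integer dot products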
void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_q2_K * GGML_RESTRICT x = vx;
    const block_q8_K * GGML_RESTRICT y = vy;

    const int nb = n / QK_K;

#ifdef __ARM_FEATURE_SVE
    const int vector_length = svcntb()*8;
    const svuint8_t m3s = svdup_n_u8(0x3);
    const svuint32_t m4s = svdup_n_u32(0xF);
    const svint32_t vzero_sv = svdup_n_s32(0);
    svfloat32_t acc_sum = svdup_n_f32(0);
    svbool_t pred_s32 = svptrue_pat_b32(SV_VL4);

    switch (vector_length) {
        case 128:
            for (int i = 0; i < nb; ++i) {
                const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
                svfloat32_t d_broad = svdup_n_f32((float32_t)d);
                const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);
                svfloat32_t dmin_broad = svdup_n_f32((float32_t)dmin);

                const uint8_t * GGML_RESTRICT q2 = x[i].qs;
                const int8_t  * GGML_RESTRICT q8_sv = y[i].qs;
                const uint8_t * GGML_RESTRICT sc = x[i].scales;

                svuint32_t mins_and_scales_sve = svld1ub_u32(svptrue_b32(), sc);
                const svint32_t mins_sv_1 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_b32(), mins_and_scales_sve, 4));

                mins_and_scales_sve = svld1ub_u32(svptrue_b32(), sc+4);
                const svint32_t mins_sv_2 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_b32(), mins_and_scales_sve, 4));

                svint32_t q8sums_sv_1 = svld1sh_s32(svptrue_b32(), y[i].bsums);
                svint32_t q8sums_sv_2 = svld1sh_s32(svptrue_b32(), y[i].bsums+4);

                const svint32_t s0 = svadd_s32_x(svptrue_b32(), svmul_s32_x(svptrue_b32(), mins_sv_1, q8sums_sv_1), svmul_s32_x(svptrue_b32(), mins_sv_2, q8sums_sv_2));

                mins_and_scales_sve = svld1ub_u32(svptrue_b32(), sc+8);
                const svint32_t mins_sv_3 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_b32(), mins_and_scales_sve, 4));

                mins_and_scales_sve = svld1ub_u32(svptrue_b32(), sc+12);
                const svint32_t mins_sv_4 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_b32(), mins_and_scales_sve, 4));

                q8sums_sv_1 = svld1sh_s32(svptrue_b32(), y[i].bsums+8);
                q8sums_sv_2 = svld1sh_s32(svptrue_b32(), y[i].bsums+12);

                svint32_t s1 = svadd_s32_x(svptrue_b32(), svmul_s32_x(svptrue_b32(), mins_sv_3, q8sums_sv_1), svmul_s32_x(svptrue_b32(), mins_sv_4, q8sums_sv_2));

                svfloat32_t temp = svcvt_f32_s32_x(svptrue_b32(), svadd_s32_x(svptrue_b32(), s0, s1));

                acc_sum = svmla_f32_m(svptrue_b32(), acc_sum, temp, dmin_broad);

                svint32_t sumi1 = svdup_n_s32(0);

                {
                    const svuint8_t q2bits_1 = svld1_u8(svptrue_b8(), q2);
                    svint8_t q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), q2bits_1, m3s));
                    svint8_t q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
                    const svint32_t scales_sv = svreinterpret_s32_u32(svand_u32_m(svptrue_b32(), svld1ub_u32(svptrue_b32(), sc), m4s));

                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv, 0));

                    const svuint8_t q2bits_3 = svld1_u8(svptrue_b8(), q2+16);
                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), q2bits_3, m3s));
                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;

                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv, 1));

                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_1, 2), m3s));
                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;

                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv, 2));

                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_3, 2), m3s));
                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;

                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv, 3));

                    const svint32_t scales_sv_1 = svreinterpret_s32_u32(svand_u32_m(svptrue_b32(), svld1ub_u32(svptrue_b32(), sc+4), m4s));

                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_1, 4), m3s));
                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;

                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_1, 0));

                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_3, 4), m3s));
                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;

                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_1, 1));

                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_1, 6), m3s));
                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;

                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_1, 2));

                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_3, 6), m3s));
                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;

                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_1, 3));

                    //-------------------------------

                    q2 += 32;
                    const svint32_t scales_sv_2 = svreinterpret_s32_u32(svand_u32_m(svptrue_b32(), svld1ub_u32(svptrue_b32(), sc+8), m4s));
1529 const svuint8_t q2bits_2 = svld1_u8(svptrue_b8(), q2);
1530
1531 q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), q2bits_2, m3s));
1532 q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
1533
1534 sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_2, 0));
1535
1536 const svuint8_t q2bits_4 = svld1_u8(svptrue_b8(), q2+16);
1537 q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), q2bits_4, m3s));
1538 q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
1539
1540 sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_2, 1));
1541
1542
1543 q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_2, 2), m3s));
1544 q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
1545
1546 sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_2, 2));
1547
1548 q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_4, 2), m3s));
1549 q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
1550
1551 sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_2, 3));
1552
1553
1554 const svint32_t scales_sv_3 = svreinterpret_s32_u32(svand_u32_m(svptrue_b32(), svld1ub_u32(svptrue_b32(), sc+12), m4s));
1555
1556 q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_2, 4), m3s));
1557 q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
1558
1559 sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_3, 0));
1560
1561 q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_4, 4), m3s));
1562 q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
1563
1564 sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_3, 1));
1565
1566
1567
1568 q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_2, 6), m3s));
1569 q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
1570
1571 sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_3, 2));
1572
1573 q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_4, 6), m3s));
1574 q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
1575
1576 sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_3, 3));
1577 }
1578 acc_sum = svmla_f32_m(svptrue_b32(), acc_sum, svcvt_f32_s32_x(svptrue_b32(), sumi1), d_broad);
1579 }
1580 *s = svaddv_f32(svptrue_b32(), acc_sum);
1581 break;
1582
1583 case 256:
1584 case 512:
1585 for (int i = 0; i < nb; ++i) {
1586 const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
1587 svfloat32_t d_broad = svdup_n_f32((float32_t)d);
1588 const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);
1589 svfloat32_t dmin_broad = svdup_n_f32((float32_t)dmin);
1590
1591 const uint8_t * GGML_RESTRICT q2 = x[i].qs;
1592 const int8_t * GGML_RESTRICT q8_sv = y[i].qs;
1593 const uint8_t * GGML_RESTRICT sc = x[i].scales;
1594
1595 const svuint32_t mins_and_scales_sve = svld1ub_u32(svptrue_pat_b32(SV_VL8), sc); sc += 8;
1596 const svint32_t scales_sv = svreinterpret_s32_u32(svand_u32_m(svptrue_pat_b32(SV_VL8), mins_and_scales_sve, m4s));
1597 const svint32_t mins_sv_1 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_pat_b32(SV_VL8), mins_and_scales_sve, 4));
1598 svint32_t q8sums_sv_1 = svld1sh_s32(svptrue_pat_b32(SV_VL8), y[i].bsums);
1599
1600 const svuint32_t mins_and_scales_sve_1 = svld1ub_u32(svptrue_pat_b32(SV_VL8), sc);
1601 const svint32_t scales_sv_1 = svreinterpret_s32_u32(svand_u32_m(svptrue_pat_b32(SV_VL8), mins_and_scales_sve_1, m4s));
1602 const svint32_t mins_sv_2 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_pat_b32(SV_VL8), mins_and_scales_sve_1, 4));
1603
1604 svint32_t q8sums_sv_2 = svld1sh_s32(svptrue_pat_b32(SV_VL8), y[i].bsums+8);
1605
1606 svfloat32_t temp = svcvt_f32_s32_x(svptrue_pat_b32(SV_VL8), svadd_s32_x(svptrue_pat_b32(SV_VL8), svmul_s32_x(svptrue_pat_b32(SV_VL8), mins_sv_1, q8sums_sv_1), svmul_s32_x(svptrue_pat_b32(SV_VL8), mins_sv_2, q8sums_sv_2)));
1607
1608 acc_sum = svmla_f32_m(svptrue_pat_b32(SV_VL8), acc_sum, temp, dmin_broad);
1609
1610 svint32_t sumi1 = svdup_n_s32(0);
1611
1612 {
1613 const svuint8_t q2bits_1 = svld1_u8(svptrue_pat_b8(SV_VL32), q2);
1614 svint8_t q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), q2bits_1, m3s));
1615 svint8_t q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
1616
1617 svint32_t scale_1 = svsel(pred_s32, svdup_lane_s32(scales_sv, 0), svdup_lane_s32(scales_sv, 1));
1618 sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_1);
1619
1620 q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_1, 2), m3s));
1621 q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
1622
1623 svint32_t scale_2 = svsel(pred_s32, svdup_lane_s32(scales_sv, 2), svdup_lane_s32(scales_sv, 3));
1624 sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(svdup_n_s32(0), q2bytes_sv, q8bytes_sv), scale_2);
1625
1626 q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_1, 4), m3s));
1627 q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
1628
1629 scale_1 = svsel(pred_s32, svdup_lane_s32(scales_sv, 4), svdup_lane_s32(scales_sv, 5));
1630 sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_1);
1631
1632 q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_1, 6), m3s));
1633 q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
1634
1635 scale_2 = svsel(pred_s32, svdup_lane_s32(scales_sv, 6), svdup_lane_s32(scales_sv, 7));
1636 sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_2);
1637
1638 q2 += 32;
1639
1640 const svuint8_t q2bits_2 = svld1_u8(svptrue_pat_b8(SV_VL32), q2);
1641 q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), q2bits_2, m3s));
1642 q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
1643
1644 scale_1 = svsel(pred_s32, svdup_lane_s32(scales_sv_1, 0), svdup_lane_s32(scales_sv_1, 1));
1645 sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_1);
1646
1647 q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_2, 2), m3s));
1648 q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
1649
1650 scale_2 = svsel(pred_s32, svdup_lane_s32(scales_sv_1, 2), svdup_lane_s32(scales_sv_1, 3));
1651 sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_2);
1652
1653 q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_2, 4), m3s));
1654 q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
1655
1656 scale_1 = svsel(pred_s32, svdup_lane_s32(scales_sv_1, 4), svdup_lane_s32(scales_sv_1, 5));
1657 sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_1);
1658
1659 q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_2, 6), m3s));
1660 q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
1661
1662 scale_2 = svsel(pred_s32, svdup_lane_s32(scales_sv_1, 6), svdup_lane_s32(scales_sv_1, 7));
1663 sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_2);
1664 }
1665 acc_sum = svmla_f32_m(svptrue_pat_b32(SV_VL8), acc_sum, svcvt_f32_s32_x(svptrue_pat_b32(SV_VL8), sumi1), d_broad);
1666 }
1667 *s = svaddv_f32(svptrue_pat_b32(SV_VL8), acc_sum);
1668 break;
1669
1670 default:
1671 assert(false && "Unsupported vector length");
1672 break;
1673 }
1674
1675#elif __ARM_NEON
1676 const uint8x16_t m3 = vdupq_n_u8(0x3);
1677 const uint8x16_t m4 = vdupq_n_u8(0xF);
1678
1679 const int32x4_t vzero = vdupq_n_s32(0);
1680
1681 ggml_int8x16x2_t q2bytes;
1682 uint8_t aux[16];
1683
1684 float sum = 0;
1685
1686 for (int i = 0; i < nb; ++i) {
1687 const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
1688 const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);
1689
1690 const uint8_t * GGML_RESTRICT q2 = x[i].qs;
1691 const int8_t * GGML_RESTRICT q8 = y[i].qs;
1692 const uint8_t * GGML_RESTRICT sc = x[i].scales;
1693
1694 const uint8x16_t mins_and_scales = vld1q_u8(sc);
1695 const uint8x16_t scales = vandq_u8(mins_and_scales, m4);
1696 vst1q_u8(aux, scales);
1697
1698 const uint8x16_t mins = vshrq_n_u8(mins_and_scales, 4);
1699 const ggml_int16x8x2_t q8sums = ggml_vld1q_s16_x2(y[i].bsums);
1700 const ggml_int16x8x2_t mins16 = {{vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))), vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins)))}};
1701 const int32x4_t s0 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[0]), vget_low_s16 (q8sums.val[0])),
1702 vmull_s16(vget_high_s16(mins16.val[0]), vget_high_s16(q8sums.val[0])));
1703 const int32x4_t s1 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[1]), vget_low_s16 (q8sums.val[1])),
1704 vmull_s16(vget_high_s16(mins16.val[1]), vget_high_s16(q8sums.val[1])));
1705 sum += dmin * vaddvq_s32(vaddq_s32(s0, s1));
1706
1707 int isum = 0;
1708 int is = 0;
1709
1710// We use this macro instead of a function call because, for some reason,
1711// the code runs 2-3% slower with a function call, even when the function is declared inline
1712#define MULTIPLY_ACCUM_WITH_SCALE(index)\
1713 isum += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * aux[is+(index)];\
1714 isum += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * aux[is+1+(index)];
1715
1716#define SHIFT_MULTIPLY_ACCUM_WITH_SCALE(shift, index)\
1717 q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;\
1718 q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[0], (shift)), m3));\
1719 q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[1], (shift)), m3));\
1720 MULTIPLY_ACCUM_WITH_SCALE((index));
1721
1722 for (int j = 0; j < QK_K/128; ++j) {
1723 const ggml_uint8x16x2_t q2bits = ggml_vld1q_u8_x2(q2); q2 += 32;
1724
1725 ggml_int8x16x2_t q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
1726 q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[0], m3));
1727 q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[1], m3));
1728
1729 MULTIPLY_ACCUM_WITH_SCALE(0);
1730
1731 SHIFT_MULTIPLY_ACCUM_WITH_SCALE(2, 2);
1732 SHIFT_MULTIPLY_ACCUM_WITH_SCALE(4, 4);
1733 SHIFT_MULTIPLY_ACCUM_WITH_SCALE(6, 6);
1734
1735 is += 8;
1736 }
1737
1738 sum += d * isum;
1739 }
1740
1741 *s = sum;
1742
1743#else
1744 UNUSED(x);
1745 UNUSED(y);
1746 UNUSED(nb);
1747 ggml_vec_dot_q2_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
1748#endif
1749}
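// An illustrative scalar sketch of the q2_K path above (an assumption for
// clarity, not part of the upstream code). Each of the 16 bytes in
// x[i].scales packs a 4-bit sub-block scale in its low nibble and a 4-bit min
// in its high nibble; the min contribution only needs the per-sub-block sums
// of the q8 activations, which are precomputed in y[i].bsums:
//
//     int32_t isum = 0, imin = 0;
//     for (int j = 0; j < QK_K/16; ++j) {
//         const int sc = x[i].scales[j] & 0xF;
//         const int mn = x[i].scales[j] >> 4;
//         int32_t dot = 0;
//         for (int l = 0; l < 16; ++l) {
//             dot += q2[16*j + l] * q8[16*j + l];  // q2 in {0..3}, unpacked from x[i].qs
//         }
//         isum += sc * dot;
//         imin += mn * y[i].bsums[j];
//     }
//     // d = y[i].d * fp16(x[i].d), dmin_abs = y[i].d * fp16(x[i].dmin)
//     sumf += d * isum - dmin_abs * imin;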
1750
1751void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
1752 assert(n % QK_K == 0);
1753 assert(nrc == 1);
1754 UNUSED(nrc);
1755 UNUSED(bx);
1756 UNUSED(by);
1757 UNUSED(bs);
1758
1759 const uint32_t kmask1 = 0x03030303;
1760 const uint32_t kmask2 = 0x0f0f0f0f;
1761
1762 const block_q3_K * GGML_RESTRICT x = vx;
1763 const block_q8_K * GGML_RESTRICT y = vy;
1764
1765 const int nb = n / QK_K;
1766
1767#if defined(__ARM_FEATURE_SVE)
1768
1769 uint32_t aux[3];
1770 uint32_t utmp[4];
1771
1772 const int8_t m32 = 32;
1773 const int vector_length = svcntb()*8;
1774 const svuint8_t m3b_sv = svdup_n_u8(0x3);
1775 const svint32_t vzero_sv = svdup_n_s32(0);
1776
1777 const svuint8_t m0_sv = svdup_n_u8(1);
1778 const svuint8_t m1_sv = svlsl_n_u8_x(svptrue_b8(), m0_sv, 1);
1779 const svuint8_t m2_sv = svlsl_n_u8_x(svptrue_b8(), m0_sv, 2);
1780 const svuint8_t m3_sv = svlsl_n_u8_x(svptrue_b8(), m0_sv, 3);
1781
1782 float sum = 0;
1783
1784 for (int i = 0; i < nb; ++i) {
1785
1786 const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
1787
1788 const uint8_t * GGML_RESTRICT q3_sv = x[i].qs;
1789 const uint8_t * GGML_RESTRICT qh_sv = x[i].hmask;
1790 const int8_t * GGML_RESTRICT q8_sv = y[i].qs;
1791
1792 // Set up scales
1793 memcpy(aux, x[i].scales, 12);
1794 utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
1795 utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
1796 utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
1797 utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);
1798
1799 int8_t * scale = (int8_t *)utmp;
1800
1801 for (int j = 0; j < 16; ++j) scale[j] -= m32;
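// The memcpy/utmp shuffle above unpacks the sixteen 6-bit signed scales that
// q3_K stores in 12 bytes: the first 8 bytes carry the low 4 bits of two
// scales each (scale j in the low nibble, scale j+8 in the high nibble), the
// last 4 bytes carry the 2-bit high parts, and the reassembled bytes are
// re-centered by subtracting 32 so that scale[j] falls in [-32, 31].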
1802
1803 switch (vector_length) {
1804 case 128:
1805 {
1806 svuint8_t qhbits_sv_1 = svld1_u8(svptrue_b8(), qh_sv);
1807 svuint8_t qhbits_sv_2 = svld1_u8(svptrue_b8(), qh_sv+16);
1808 svuint8_t q3h_sv;
1809
1810 svint32_t sumi1_1 = svdup_n_s32(0);
1811 svint8_t q3bytes_sv;
1812
1813 for (int j = 0; j < QK_K/128; ++j) {
1814
1815 const svuint8_t q3bits_sv = svld1_u8(svptrue_b8(), q3_sv); q3_sv += 16;
1816 const svuint8_t q3bits_sv_1 = svld1_u8(svptrue_b8(), q3_sv); q3_sv += 16;
1817 svint8_t q8bytes_1_sv_1 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
1818 svint8_t q8bytes_1_sv_2 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
1819
1820 q3h_sv = svlsl_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m0_sv, qhbits_sv_1), 2);
1821 q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), q3bits_sv, m3b_sv)), svreinterpret_s8_u8(q3h_sv));
1822
1823 sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), svdup_n_s32((int32_t)scale[0]));
1824
1825 q3h_sv = svlsl_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m0_sv, qhbits_sv_2), 2);
1826 q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), q3bits_sv_1, m3b_sv)), svreinterpret_s8_u8(q3h_sv));
1827
1828 sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), svdup_n_s32((int32_t)scale[1]));
1829
1830 q8bytes_1_sv_1 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
1831 q8bytes_1_sv_2 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
1832
1833 q3h_sv = svlsl_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m1_sv, qhbits_sv_1), 1);
1834 q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv, 2), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
1835
1836 sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), svdup_n_s32((int32_t)scale[2]));
1837
1838 q3h_sv = svlsl_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m1_sv, qhbits_sv_2), 1);
1839 q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv_1, 2), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
1840
1841 sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), svdup_n_s32((int32_t)scale[3]));
1842
1843
1844 scale += 4;
1845 q8bytes_1_sv_1 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
1846 q8bytes_1_sv_2 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
1847
1848 q3h_sv = svbic_u8_x(svptrue_b8(), m2_sv, qhbits_sv_1);
1849 q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv, 4), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
1850
1851 sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), svdup_n_s32((int32_t)scale[0]));
1852
1853 q3h_sv = svbic_u8_x(svptrue_b8(), m2_sv, qhbits_sv_2);
1854 q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv_1, 4), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
1855
1856 sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), svdup_n_s32((int32_t)scale[1]));
1857
1858
1859 q8bytes_1_sv_1 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
1860 q8bytes_1_sv_2 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
1861
1862 q3h_sv = svlsr_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m3_sv, qhbits_sv_1), 1);
1863 q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv, 6), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
1864
1865 sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), svdup_n_s32((int32_t)scale[2]));
1866
1867 q3h_sv = svlsr_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m3_sv, qhbits_sv_2), 1);
1868 q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv_1, 6), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
1869
1870 sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), svdup_n_s32((int32_t)scale[3]));
1871
1872 if (j == 0) {
1873 qhbits_sv_1 = svlsr_n_u8_x(svptrue_b8(), qhbits_sv_1, 4);
1874 qhbits_sv_2 = svlsr_n_u8_x(svptrue_b8(), qhbits_sv_2, 4);
1875 }
1876
1877 scale += 4;
1878 }
1879
1880 sum += d * (svaddv_s32(svptrue_b32(), sumi1_1));
1881 } break;
1882 case 256:
1883 case 512:
1884 {
1885 svuint8_t qhbits_sv = svld1_u8(svptrue_pat_b8(SV_VL32), qh_sv);
1886 svuint8_t q3h_sv;
1887
1888 svint32_t sumi1_1 = svdup_n_s32(0);
1889 svint8_t q3bytes_sv;
1890
1891 for (int j = 0; j < QK_K/128; ++j) {
1892
1893 const svuint8_t q3bits_sv = svld1_u8(svptrue_pat_b8(SV_VL32), q3_sv); q3_sv += 32;
1894 svint8_t q8bytes_1_sv_1 = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
1895 svint8_t q8bytes_1_sv_2 = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
1896
1897 q3h_sv = svlsl_n_u8_x(svptrue_pat_b8(SV_VL32), svbic_u8_x(svptrue_pat_b8(SV_VL32), m0_sv, qhbits_sv), 2);
1898 q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32), svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), q3bits_sv, m3b_sv)), svreinterpret_s8_u8(q3h_sv));
1899
1900
1901 svint32_t scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4), svdup_n_s32((int32_t)scale[0]), svdup_n_s32((int32_t)scale[1]));
1902 sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), scale_1);
1903
1904 q3h_sv = svlsl_n_u8_x(svptrue_pat_b8(SV_VL32), svbic_u8_x(svptrue_pat_b8(SV_VL32), m1_sv, qhbits_sv), 1);
1905 q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32), svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q3bits_sv, 2), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
1906
1907 scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4), svdup_n_s32((int32_t)scale[2]), svdup_n_s32((int32_t)scale[3]));
1908 sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), scale_1);
1909
1910 scale += 4;
1911 q8bytes_1_sv_1 = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
1912 q8bytes_1_sv_2 = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
1913
1914 q3h_sv = svbic_u8_x(svptrue_pat_b8(SV_VL32), m2_sv, qhbits_sv);
1915 q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32), svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q3bits_sv, 4), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
1916
1917 scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4), svdup_n_s32((int32_t)scale[0]), svdup_n_s32((int32_t)scale[1]));
1918 sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), scale_1);
1919
1920 q3h_sv = svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), svbic_u8_x(svptrue_pat_b8(SV_VL32), m3_sv, qhbits_sv), 1);
1921 q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32), svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q3bits_sv, 6), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
1922
1923 scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4), svdup_n_s32((int32_t)scale[2]), svdup_n_s32((int32_t)scale[3]));
1924 sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), scale_1);
1925
1926 if (j == 0) {
1927 qhbits_sv = svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), qhbits_sv, 4);
1928 }
1929
1930 scale += 4;
1931 }
1932
1933 sum += d * (svaddv_s32(svptrue_pat_b32(SV_VL8), sumi1_1));
1934 } break;
1935 default:
1936 assert(false && "Unsupported vector length");
1937 break;
1938 }
1939 }
1940 *s = sum;
1941
1942#elif __ARM_NEON
1943
1944 uint32_t aux[3];
1945 uint32_t utmp[4];
1946
1947 const uint8x16_t m3b = vdupq_n_u8(0x3);
1948 const int32x4_t vzero = vdupq_n_s32(0);
1949
1950 const uint8x16_t m0 = vdupq_n_u8(1);
1951 const uint8x16_t m1 = vshlq_n_u8(m0, 1);
1952 const uint8x16_t m2 = vshlq_n_u8(m0, 2);
1953 const uint8x16_t m3 = vshlq_n_u8(m0, 3);
1954 const int8_t m32 = 32;
1955
1956 ggml_int8x16x4_t q3bytes;
1957
1958 float sum = 0;
1959
1960 for (int i = 0; i < nb; ++i) {
1961
1962 const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
1963
1964 const uint8_t * GGML_RESTRICT q3 = x[i].qs;
1965 const uint8_t * GGML_RESTRICT qh = x[i].hmask;
1966 const int8_t * GGML_RESTRICT q8 = y[i].qs;
1967
1968 ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh);
1969
1970 ggml_uint8x16x4_t q3h;
1971
1972 int32_t isum = 0;
1973
1974 // Set up scales
1975 memcpy(aux, x[i].scales, 12);
1976 utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
1977 utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
1978 utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
1979 utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);
1980
1981 int8_t * scale = (int8_t *)utmp;
1982 for (int j = 0; j < 16; ++j) scale[j] -= m32;
1983
1984 for (int j = 0; j < QK_K/128; ++j) {
1985
1986 const ggml_uint8x16x2_t q3bits = ggml_vld1q_u8_x2(q3); q3 += 32;
1987 const ggml_int8x16x4_t q8bytes_1 = ggml_vld1q_s8_x4(q8); q8 += 64;
1988 const ggml_int8x16x4_t q8bytes_2 = ggml_vld1q_s8_x4(q8); q8 += 64;
1989
1990 q3h.val[0] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[0]), 2);
1991 q3h.val[1] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[1]), 2);
1992 q3h.val[2] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[0]), 1);
1993 q3h.val[3] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[1]), 1);
1994
1995 q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[0], m3b)), vreinterpretq_s8_u8(q3h.val[0]));
1996 q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[1], m3b)), vreinterpretq_s8_u8(q3h.val[1]));
1997 q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 2), m3b)), vreinterpretq_s8_u8(q3h.val[2]));
1998 q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 2), m3b)), vreinterpretq_s8_u8(q3h.val[3]));
1999
2000 isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[0], q8bytes_1.val[0])) * scale[0];
2001 isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[1], q8bytes_1.val[1])) * scale[1];
2002 isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[2], q8bytes_1.val[2])) * scale[2];
2003 isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[3], q8bytes_1.val[3])) * scale[3];
2004
2005 scale += 4;
2006
2007 q3h.val[0] = vbicq_u8(m2, qhbits.val[0]);
2008 q3h.val[1] = vbicq_u8(m2, qhbits.val[1]);
2009 q3h.val[2] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[0]), 1);
2010 q3h.val[3] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[1]), 1);
2011
2012 q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 4), m3b)), vreinterpretq_s8_u8(q3h.val[0]));
2013 q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 4), m3b)), vreinterpretq_s8_u8(q3h.val[1]));
2014 q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 6), m3b)), vreinterpretq_s8_u8(q3h.val[2]));
2015 q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 6), m3b)), vreinterpretq_s8_u8(q3h.val[3]));
2016
2017 isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[0], q8bytes_2.val[0])) * scale[0];
2018 isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[1], q8bytes_2.val[1])) * scale[1];
2019 isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[2], q8bytes_2.val[2])) * scale[2];
2020 isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[3], q8bytes_2.val[3])) * scale[3];
2021
2022 scale += 4;
2023
2024 if (j == 0) {
2025 qhbits.val[0] = vshrq_n_u8(qhbits.val[0], 4);
2026 qhbits.val[1] = vshrq_n_u8(qhbits.val[1], 4);
2027 }
2028
2029 }
2030 sum += d * isum;
2031
2032 }
2033
2034 *s = sum;
2035
2036#else
2037 UNUSED(kmask1);
2038 UNUSED(kmask2);
2039 UNUSED(x);
2040 UNUSED(y);
2041 UNUSED(nb);
2042 ggml_vec_dot_q3_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
2043#endif
2044
2045}
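// An illustrative scalar sketch of the q3_K element decode above (an
// assumption for clarity, not part of the upstream code). Each weight is two
// low bits from x[i].qs plus one high bit from x[i].hmask; vbicq/svbic clears
// the mask bit where the high bit is SET, so the shifted q3h term equals 4
// exactly when the high bit is clear, yielding a signed value in [-4, 3]
// without a separate sign pass:
//
//     const int lo   = (q3[l] >> shift) & 3;    // 2 low bits (shift hypothetical)
//     const int hbit = (qh[l] >> bitpos) & 1;   // 1 high bit (bitpos hypothetical)
//     const int v    = lo - (hbit ? 0 : 4);     // v in [-4, 3]
//     isum += v * q8[l] * scale_j;              // scale_j: 6-bit sub-block scale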
2046
2047#ifdef __ARM_FEATURE_SVE
2048static inline svuint32_t ggml_decode_q4scales_and_mins_for_mmla(const uint32_t * vx_scales) {
2049 const svbool_t pg_all = svptrue_pat_b32(SV_VL4);
2050 const svbool_t pg_false = svpfalse_b(); // 0x0000
2051 const svbool_t pg_lo_8 = svwhilelt_b8_s32(0, 8); // 0x00ff
2052 const svbool_t pg_odd = svzip1_b32(pg_false, pg_lo_8);
2053
2054 svuint32_t vutmp_hi, vutmp_lo;
2055 svuint32_t vx01 = svld1_u32(pg_lo_8, vx_scales);
2056 vutmp_hi = svzip1_u32(vx01, vx01);
2057 vutmp_hi = svlsr_n_u32_m(pg_odd, vutmp_hi, 2);
2058 vutmp_hi = svreinterpret_u32_u64(svand_n_u64_x(pg_all, svreinterpret_u64_u32(vutmp_hi), UINT64_C(0x303030303f3f3f3f)));
2059 const svuint32_t vx2 = svdup_u32(vx_scales[2]);
2060 vutmp_lo = svlsr_u32_x(pg_all, vx2, svreinterpret_u32_s32(svindex_s32(-2, 2)));
2061 vutmp_lo = svand_n_u32_z(pg_odd, vutmp_lo, UINT32_C(0x0f0f0f0f));
2062 svuint32_t vutmp = svorr_u32_z(pg_all, vutmp_hi, vutmp_lo);
2063 return vutmp;
2064}
2065#endif
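// For reference: q4_K and q5_K pack eight 6-bit scales and eight 6-bit mins
// into the 12-byte scales field. Bytes 0-3 hold scales 0-3 in their low 6
// bits, bytes 4-7 hold mins 0-3 likewise, and the upper 2 bits of those bytes
// together with the nibbles of bytes 8-11 hold scales/mins 4-7. The kmask1
// (0x3f3f3f3f), kmask2 (0x0f0f0f0f) and kmask3 (0x03030303) constants below
// select exactly these bit groups; the SVE helper above performs the same
// unpack with predicated shifts so the result can feed the mmla kernels.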
2066
2067void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
2068 assert(n % QK_K == 0);
2069#ifdef __ARM_FEATURE_MATMUL_INT8
2070 assert((nrc == 2) || (nrc == 1));
2071#else
2072 assert(nrc == 1);
2073#endif
2074 UNUSED(nrc);
2075 UNUSED(bx);
2076 UNUSED(by);
2077 UNUSED(bs);
2078
2079 const block_q4_K * GGML_RESTRICT x = vx;
2080 const block_q8_K * GGML_RESTRICT y = vy;
2081
2082 const int nb = n / QK_K;
2083
2084 static const uint32_t kmask1 = 0x3f3f3f3f;
2085 static const uint32_t kmask2 = 0x0f0f0f0f;
2086 static const uint32_t kmask3 = 0x03030303;
2087
2088 uint32_t utmp[4];
2089#ifdef __ARM_FEATURE_SVE
2090 const int vector_length = ggml_cpu_get_sve_cnt()*8;
2091#endif
2092
2093#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
2094 if (nrc == 2) {
2095 svbool_t pg32_2 = svptrue_pat_b32(SV_VL2);
2096
2097 const block_q4_K * GGML_RESTRICT vx0 = vx;
2098 const block_q8_K * GGML_RESTRICT vy0 = vy;
2099 const block_q4_K * GGML_RESTRICT vx1 = (const block_q4_K *) ((const uint8_t*)vx + bx);
2100 const block_q8_K * GGML_RESTRICT vy1 = (const block_q8_K *) ((const uint8_t*)vy + by);
2101
2102 union {
2103 uint32_t u32[8];
2104 uint64_t u64[4];
2105 } new_utmp;
2106
2107 svfloat32_t sumf1 = svdup_n_f32(0);
2108
2109 switch (vector_length) {
2110 case 128:
2111 {
2112 svbool_t pg_false = svpfalse_b();
2113 svbool_t pg_lo_8 = svwhilelt_b8_s32(0, 8);
2114 svbool_t vmins_mask1 = svzip1_b32(pg_lo_8, pg_false);
2115 svbool_t vmins_mask2 = svzip1_b32(pg_false, pg_lo_8);
2116 svbool_t pg128_all = svptrue_pat_b8(SV_VL16);
2117 for (int i = 0; i < nb; ++i) {
2118 svfloat32_t vy_d = svuzp1_f32(svdup_n_f32(vy0[i].d), svdup_n_f32(vy1[i].d));
2119 svfloat32_t vx_d = svzip1_f32(svdup_n_f32(GGML_FP16_TO_FP32(vx0[i].d)), svdup_n_f32(GGML_FP16_TO_FP32(vx1[i].d)));
2120 svfloat32_t svsuper_block_scales = svmul_f32_x(pg128_all, vy_d, vx_d);
2121 svfloat32_t vx_dmins = svzip1_f32(svdup_n_f32(GGML_FP16_TO_FP32(vx0[i].dmin)), svdup_n_f32(GGML_FP16_TO_FP32(vx1[i].dmin)));
2122 svfloat32_t vy_dmins = svuzp1_f32(svdup_n_f32(vy0[i].d), svdup_n_f32(vy1[i].d));
2123 svfloat32_t svdmins = svmul_n_f32_x(pg128_all, svmul_f32_x(pg128_all, vy_dmins, vx_dmins), -1);
2124 const uint8_t * GGML_RESTRICT q4_0 = vx0[i].qs;
2125 const int8_t * GGML_RESTRICT q8_0 = vy0[i].qs;
2126 const uint8_t * GGML_RESTRICT q4_1 = vx1[i].qs;
2127 const int8_t * GGML_RESTRICT q8_1 = vy1[i].qs;
2128 svint16_t lo = svld1_s16(pg128_all, vy0[i].bsums + 0);
2129 svint16_t hi = svld1_s16(pg128_all, vy0[i].bsums + 8);
2130 svint16_t sum_tmp1 = svuzp1_s16(lo, hi);
2131 svint16_t sum_tmp2 = svuzp2_s16(lo, hi);
2132 svint16_t svq8sums_0 = svadd_s16_x(pg128_all, sum_tmp1, sum_tmp2);
2133 lo = svld1_s16(pg128_all, vy1[i].bsums + 0);
2134 hi = svld1_s16(pg128_all, vy1[i].bsums + 8);
2135 sum_tmp1 = svuzp1(lo, hi);
2136 sum_tmp2 = svuzp2(lo, hi);
2137 svint16_t svq8sums_1 = svadd_s16_x(pg128_all, sum_tmp1, sum_tmp2);
2138 svuint32_t decoded_scales0 = ggml_decode_q4scales_and_mins_for_mmla((const uint32_t *)vx0[i].scales);
2139 svuint32_t decoded_scales1 = ggml_decode_q4scales_and_mins_for_mmla((const uint32_t *)vx1[i].scales);
2140 svuint32x2_t decoded_scales = svcreate2_u32(decoded_scales0, decoded_scales1);
2141 svst2_u32(pg128_all, new_utmp.u32, decoded_scales);
2142 svint16_t svmins8_0 = svreinterpret_s16_u16(svunpklo_u16(svreinterpret_u8_u32(svuzp1_u32(svld1_u32(vmins_mask1, new_utmp.u32+4), svdup_n_u32(0)))));
2143 svint16_t svmins8_1 = svreinterpret_s16_u16(svunpklo_u16(svreinterpret_u8_u32(svuzp2_u32(svld1_u32(vmins_mask2, new_utmp.u32+4), svdup_n_u32(0)))));
2144 svint32_t svsumfs_tmp1 = svreinterpret_s32_s64(svdot_s64(svdup_n_s64(0), svq8sums_0, svmins8_0));
2145 svint32_t svsumfs_tmp2 = svreinterpret_s32_s64(svdot_s64(svdup_n_s64(0), svq8sums_0, svmins8_1));
2146 svint32_t svsumfs_tmp3 = svtrn1_s32(svsumfs_tmp1, svsumfs_tmp2);
2147 svint32_t svsumfs_tmp4 = svreinterpret_s32_s64(svdot_s64(svdup_n_s64(0), svq8sums_1, svmins8_0));
2148 svint32_t svsumfs_tmp5 = svreinterpret_s32_s64(svdot_s64(svdup_n_s64(0), svq8sums_1, svmins8_1));
2149 svint32_t svsumfs_tmp6 = svtrn1_s32(svsumfs_tmp4, svsumfs_tmp5);
2150 svint32_t svsumfs_tmp7 = svreinterpret_s32_s64(svtrn2_s64(svreinterpret_s64_s32(svsumfs_tmp3), svreinterpret_s64_s32(svsumfs_tmp6)));
2151 svint32_t svsumfs_tmp8 = svreinterpret_s32_s64(svtrn1_s64(svreinterpret_s64_s32(svsumfs_tmp3), svreinterpret_s64_s32(svsumfs_tmp6)));
2152 svint32_t svsumfs_tmp = svadd_s32_x(pg128_all, svsumfs_tmp7, svsumfs_tmp8);
2153 svint32_t svscales, sumi1, sumi2;
2154 svint32_t acc_sumif1 = svdup_n_s32(0);
2155 svint32_t acc_sumif2 = svdup_n_s32(0);
2156 svint8_t q4bytes_0_l, q4bytes_0_h, q4bytes_1_l, q4bytes_1_h, l0, l1, l2, l3,
2157 q8bytes_0_h, q8bytes_0_l, q8bytes_1_h, q8bytes_1_l, r0, r1, r2, r3;
2158#pragma GCC unroll 1
2159 for (int j = 0; j < QK_K/64; ++j) {
2160 q4bytes_0_l = svreinterpret_s8_u8(svand_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_0), 0xf));
2161 q4bytes_1_l = svreinterpret_s8_u8(svand_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_1), 0xf));
2162 q4bytes_0_h = svreinterpret_s8_u8(svand_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_0+16), 0xf));
2163 q4bytes_1_h = svreinterpret_s8_u8(svand_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_1+16), 0xf));
2164 l0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q4bytes_0_l), svreinterpret_s64_s8(q4bytes_1_l)));
2165 l1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q4bytes_0_l), svreinterpret_s64_s8(q4bytes_1_l)));
2166 l2 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q4bytes_0_h), svreinterpret_s64_s8(q4bytes_1_h)));
2167 l3 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q4bytes_0_h), svreinterpret_s64_s8(q4bytes_1_h)));
2168 q8bytes_0_h = svld1_s8(pg128_all, q8_0);
2169 q8bytes_1_h = svld1_s8(pg128_all, q8_1);
2170 q8bytes_0_l = svld1_s8(pg128_all, q8_0+16);
2171 q8bytes_1_l = svld1_s8(pg128_all, q8_1+16);
2172 r0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0_h), svreinterpret_s64_s8(q8bytes_1_h)));
2173 r1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0_h), svreinterpret_s64_s8(q8bytes_1_h)));
2174 r2 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0_l), svreinterpret_s64_s8(q8bytes_1_l)));
2175 r3 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0_l), svreinterpret_s64_s8(q8bytes_1_l)));
2176 sumi1 = svmmla_s32(svmmla_s32(svmmla_s32(svmmla_s32(svdup_n_s32(0), r0, l0), r1, l1), r2, l2), r3, l3);
2177 svscales = svreinterpret_s32_u32(svlsr_n_u32_x(pg128_all, svlsl_n_u32_x(pg128_all, svreinterpret_u32_u64(svdup_n_u64(new_utmp.u64[j/2])), 8*(4-2*(j%2)-1)), 24));
2178 acc_sumif1 = svmla_s32_x(pg128_all, acc_sumif1, svscales, sumi1);
2179
2180 q4bytes_0_l = svreinterpret_s8_u8(svlsr_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_0), 4));
2181 q4bytes_1_l = svreinterpret_s8_u8(svlsr_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_1), 4));
2182 q4bytes_0_h = svreinterpret_s8_u8(svlsr_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_0+16), 4));
2183 q4bytes_1_h = svreinterpret_s8_u8(svlsr_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_1+16), 4));
2184 l0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q4bytes_0_l), svreinterpret_s64_s8(q4bytes_1_l)));
2185 l1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q4bytes_0_l), svreinterpret_s64_s8(q4bytes_1_l)));
2186 l2 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q4bytes_0_h), svreinterpret_s64_s8(q4bytes_1_h)));
2187 l3 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q4bytes_0_h), svreinterpret_s64_s8(q4bytes_1_h)));
2188 q8bytes_0_h = svld1_s8(pg128_all, q8_0+32);
2189 q8bytes_1_h = svld1_s8(pg128_all, q8_1+32);
2190 q8bytes_0_l = svld1_s8(pg128_all, q8_0+48);
2191 q8bytes_1_l = svld1_s8(pg128_all, q8_1+48);
2192 r0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0_h), svreinterpret_s64_s8(q8bytes_1_h)));
2193 r1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0_h), svreinterpret_s64_s8(q8bytes_1_h)));
2194 r2 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0_l), svreinterpret_s64_s8(q8bytes_1_l)));
2195 r3 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0_l), svreinterpret_s64_s8(q8bytes_1_l)));
2196 sumi2 = svmmla_s32(svmmla_s32(svmmla_s32(svmmla_s32(svdup_n_s32(0), r0, l0), r1, l1), r2, l2), r3, l3);
2197 svscales = svreinterpret_s32_u32(svlsr_n_u32_x(pg128_all, svlsl_n_u32_x(pg128_all, svreinterpret_u32_u64(svdup_n_u64(new_utmp.u64[j/2])), 8*(4-2*(j%2)-2)), 24));
2198 acc_sumif2 = svmla_s32_x(pg128_all, acc_sumif2, svscales, sumi2);
2199 q4_0 += 32; q4_1 += 32; q8_0 += 64; q8_1 += 64;
2200 }
2201 sumf1 = svmla_f32_x(pg128_all,
2202 svmla_f32_x(pg128_all,
2203 sumf1,
2204 svcvt_f32_x(pg128_all,
2205 svadd_s32_x(pg128_all, acc_sumif1, acc_sumif2)),
2206 svsuper_block_scales),
2207 svdmins,
2208 svcvt_f32_s32_x(pg128_all, svsumfs_tmp));
2209 } // end of for nb
2210 } // end of case 128
2211 break;
2212 case 256:
2213 case 512:
2214 {
2215 const svbool_t pg32_4 = svptrue_pat_b32(SV_VL4);
2216 const svbool_t pg8_16 = svptrue_pat_b8(SV_VL16);
2217 const svbool_t pg256_all = svptrue_pat_b8(SV_ALL);
2218 for (int i = 0; i < nb; ++i) {
2219 const uint8_t * GGML_RESTRICT q4_0 = vx0[i].qs;
2220 const int8_t * GGML_RESTRICT q8_0 = vy0[i].qs;
2221 const uint8_t * GGML_RESTRICT q4_1 = vx1[i].qs;
2222 const int8_t * GGML_RESTRICT q8_1 = vy1[i].qs;
2223 svint32_t svscales, sumi1, sumi2;
2224 svint32_t acc_sumif1 = svdup_n_s32(0);
2225 svint32_t acc_sumif2 = svdup_n_s32(0);
2226 svint8_t l0, l1, l2, l3, r0, r1, r2, r3;
2227 svfloat32_t vx_d = svzip1_f32(svdup_n_f32(GGML_FP16_TO_FP32(vx0[i].d)), svdup_n_f32(GGML_FP16_TO_FP32(vx1[i].d)));
2228 svfloat64_t vy_d_tmp = svreinterpret_f64_f32(svuzp1_f32(svdup_n_f32(vy0[i].d), svdup_n_f32(vy1[i].d)));
2229 svfloat32_t vy_d = svreinterpret_f32_f64(svuzp1_f64(vy_d_tmp, vy_d_tmp));
2230 svfloat32_t svsuper_block_scales = svmul_f32_z(pg32_4, vy_d, vx_d);
2231 svfloat32_t vx_dmins = svzip1_f32(svdup_n_f32(GGML_FP16_TO_FP32(vx0[i].dmin)), svdup_n_f32(GGML_FP16_TO_FP32(vx1[i].dmin)));
2232 svfloat64_t vy_dmins_tmp = svreinterpret_f64_f32(svuzp1_f32(svdup_n_f32(vy0[i].d), svdup_n_f32(vy1[i].d)));
2233 svfloat32_t vy_dmins = svreinterpret_f32_f64(svuzp1_f64(vy_dmins_tmp, vy_dmins_tmp));
2234 svfloat32_t svdmins = svmul_n_f32_x(pg32_4, svmul_f32_x(pg32_4, vx_dmins, vy_dmins), -1);
2235 svint16_t rc1 = svuzp1_s16(svld1_s16(pg256_all, vy0[i].bsums), svld1_s16(pg256_all, vy1[i].bsums));
2236 svint16_t rc2 = svuzp2_s16(svld1_s16(pg256_all, vy0[i].bsums), svld1_s16(pg256_all, vy1[i].bsums));
2237 svint16_t svq8sums = svadd_s16_x(pg256_all, rc1, rc2);
2238 svuint32_t decoded_scales0 = ggml_decode_q4scales_and_mins_for_mmla((const uint32_t *)vx0[i].scales);
2239 svuint32_t decoded_scales1 = ggml_decode_q4scales_and_mins_for_mmla((const uint32_t *)vx1[i].scales);
2240 svuint32x2_t decoded_scales = svcreate2_u32(decoded_scales0, decoded_scales1);
2241 svst2_u32(pg8_16, new_utmp.u32, decoded_scales);
2242 svint16_t new_svq8sums_0 = svreinterpret_s16_u64(svtrn1_u64(svreinterpret_u64_s16(svq8sums), svreinterpret_u64_s16(svq8sums)));
2243 svint16_t new_svq8sums_1 = svreinterpret_s16_u64(svtrn2_u64(svreinterpret_u64_s16(svq8sums), svreinterpret_u64_s16(svq8sums)));
2244 svuint64_t new_mins_0 = svdup_u64(new_utmp.u64[2]);
2245 svuint64_t new_mins_1 = svdup_u64(new_utmp.u64[3]);
2246 svint16_t new_svmins8_0 = svreinterpret_s16_u16(svunpklo_u16(svreinterpret_u8_u64(new_mins_0)));
2247 svint16_t new_svmins8_1 = svreinterpret_s16_u16(svunpklo_u16(svreinterpret_u8_u64(new_mins_1)));
2248 svint64_t dot_prod_0 = svdot_s64(svdup_s64(0), new_svmins8_0, new_svq8sums_0);
2249 svint64_t dot_prod_1 = svdot_s64(dot_prod_0, new_svmins8_1, new_svq8sums_1);
2250 svfloat32_t converted_dot_prod_1 = svcvt_f32_s64_x(pg256_all, dot_prod_1);
2251 svfloat32_t svsumfs_tmp = svuzp1_f32(converted_dot_prod_1, converted_dot_prod_1);
2252
2253#pragma GCC unroll 1
2254 for (int j = 0; j < QK_K/64; ++j) {
2255 svuint8_t q4bytes_0 = svand_n_u8_x(pg256_all, svld1_u8(pg256_all, q4_0), 0xf);
2256 svuint8_t q4bytes_1 = svand_n_u8_x(pg256_all, svld1_u8(pg256_all, q4_1), 0xf);
2257 svuint8_t q4bytes_2 = svlsr_n_u8_x(pg256_all, svld1_u8(pg256_all, q4_0), 4);
2258 svuint8_t q4bytes_3 = svlsr_n_u8_x(pg256_all, svld1_u8(pg256_all, q4_1), 4);
2259 l0 = svreinterpret_s8_u64(svzip1_u64(svreinterpret_u64_u8(q4bytes_0), svreinterpret_u64_u8(q4bytes_1)));
2260 l1 = svreinterpret_s8_u64(svzip2_u64(svreinterpret_u64_u8(q4bytes_0), svreinterpret_u64_u8(q4bytes_1)));
2261 l2 = svreinterpret_s8_u64(svzip1_u64(svreinterpret_u64_u8(q4bytes_2), svreinterpret_u64_u8(q4bytes_3)));
2262 l3 = svreinterpret_s8_u64(svzip2_u64(svreinterpret_u64_u8(q4bytes_2), svreinterpret_u64_u8(q4bytes_3)));
2263 svint8_t q8bytes_0 = svld1_s8(pg256_all, q8_0);
2264 svint8_t q8bytes_1 = svld1_s8(pg256_all, q8_1);
2265 svint8_t q8bytes_2 = svld1_s8(pg256_all, q8_0+32);
2266 svint8_t q8bytes_3 = svld1_s8(pg256_all, q8_1+32);
2267 r0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0), svreinterpret_s64_s8(q8bytes_1)));
2268 r1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0), svreinterpret_s64_s8(q8bytes_1)));
2269 r2 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_2), svreinterpret_s64_s8(q8bytes_3)));
2270 r3 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_2), svreinterpret_s64_s8(q8bytes_3)));
2271 sumi1 = svmmla(svmmla(svdup_n_s32(0), r0, l0), r1, l1);
2272 svscales = svreinterpret_s32_u32(svlsr_n_u32_x(pg256_all, svlsl_n_u32_x(pg256_all, svreinterpret_u32_u64(svdup_n_u64(new_utmp.u64[j/2])), 8*(4-2*(j%2)-1)), 24));
2273 acc_sumif1 = svmla_s32_x(pg256_all, acc_sumif1, svscales, sumi1);
2274 sumi2 = svmmla(svmmla(svdup_n_s32(0), r2, l2), r3, l3);
2275 svscales = svreinterpret_s32_u32(svlsr_n_u32_x(pg256_all, svlsl_n_u32_x(pg256_all, svreinterpret_u32_u64(svdup_n_u64(new_utmp.u64[j/2])), 8*(4-2*(j%2)-2)), 24));
2276 acc_sumif2 = svmla_s32_x(pg256_all, acc_sumif2, svscales, sumi2);
2277 q4_0 += 32; q4_1 += 32; q8_0 += 64; q8_1 += 64;
2278 }
2279 svint32_t acc_sumif = svadd_s32_x(pg256_all, acc_sumif1, acc_sumif2);
2280 svint32_t swap_acc_sumif = svext_s32(acc_sumif, acc_sumif, 4);
2281 acc_sumif = svadd_s32_x(pg32_4, acc_sumif, swap_acc_sumif);
2282 sumf1 = svmla_f32_x(pg32_4,
2283 svmla_f32_x(pg32_4,
2284 sumf1,
2285 svcvt_f32_x(pg32_4, acc_sumif),
2286 svsuper_block_scales),
2287 svdmins,
2288 svsumfs_tmp);
2289 } // end of for nb
2290 } // end of case 256-512
2291 break;
2292 default:
2293 assert(false && "Unsupported vector length");
2294 break;
2295 }
2296
2297 svst1_f32(pg32_2, s, sumf1);
2298 svst1_f32(pg32_2, s + bs, svreinterpret_f32_u8(svext_u8(svreinterpret_u8_f32(sumf1), svdup_n_u8(0), 8)));
2299
2300 return;
2301 }
2302#elif defined(__ARM_FEATURE_MATMUL_INT8)
2303 if (nrc == 2) {
2304 const block_q4_K * GGML_RESTRICT x0 = x;
2305 const block_q4_K * GGML_RESTRICT x1 = (const block_q4_K *) ((const uint8_t *)vx + bx);
2306 const block_q8_K * GGML_RESTRICT y0 = y;
2307 const block_q8_K * GGML_RESTRICT y1 = (const block_q8_K *) ((const uint8_t *)vy + by);
2308
2309 const uint8x16_t m4b = vdupq_n_u8(0x0f);
2310
2311 float32x4_t vfsum = vdupq_n_f32(0.0f);
2312
2313 for (int i = 0; i < nb; ++i, ++x0, ++x1, ++y0, ++y1) {
2314 const uint8_t * GGML_RESTRICT qx0 = x0->qs;
2315 const uint8_t * GGML_RESTRICT qx1 = x1->qs;
2316 const int8_t * GGML_RESTRICT qy0 = y0->qs;
2317 const int8_t * GGML_RESTRICT qy1 = y1->qs;
2318
2319 // decode scales and mins
2320 int8_t x0_scales[8], x1_scales[8];
2321 int16x8_t x0_mins, x1_mins;
2322 {
2323 uint32_t scales_mins[3];
2324 memcpy(scales_mins, x0->scales, 12);
2325 const uint32_t mins_0_3 = scales_mins[1] & kmask1;
2326 const uint32_t mins_4_7 = ((scales_mins[2] >> 4) & kmask2) | (((scales_mins[1] >> 6) & kmask3) << 4);
2327 const uint32x2_t mins = {mins_0_3, mins_4_7};
2328 x0_mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins)));
2329 uint32_t scales[2];
2330 scales[0] = scales_mins[0] & kmask1; // scales 0~3
2331 scales[1] = (scales_mins[2] & kmask2) | (((scales_mins[0] >> 6) & kmask3) << 4); // scales 4~7
2332 memcpy(x0_scales, scales, 8);
2333 }
2334 {
2335 uint32_t scales_mins[3];
2336 memcpy(scales_mins, x1->scales, 12);
2337 const uint32_t mins_0_3 = scales_mins[1] & kmask1;
2338 const uint32_t mins_4_7 = ((scales_mins[2] >> 4) & kmask2) | (((scales_mins[1] >> 6) & kmask3) << 4);
2339 const uint32x2_t mins = {mins_0_3, mins_4_7};
2340 x1_mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins)));
2341 uint32_t scales[2];
2342 scales[0] = scales_mins[0] & kmask1; // scales 0~3
2343 scales[1] = (scales_mins[2] & kmask2) | (((scales_mins[0] >> 6) & kmask3) << 4); // scales 4~7
2344 memcpy(x1_scales, scales, 8);
2345 }
2346
2347 int32x4_t visum = {0};
2348
2349 // process 64 data points per iteration, 256 data points in total
2350 for (int j = 0; j < QK_K / 64; ++j, qx0 += 32, qx1 += 32, qy0 += 64, qy1 += 64) {
2351 const int8x16x4_t vy0 = vld1q_s8_x4(qy0);
2352 const int8x16x4_t vy1 = vld1q_s8_x4(qy1);
2353
2354 int8x16_t vx0[4], vx1[4];
2355 {
2356 const uint8x16x2_t vv = vld1q_u8_x2(qx0);
2357 vx0[0] = vreinterpretq_s8_u8(vandq_u8(vv.val[0], m4b));
2358 vx0[1] = vreinterpretq_s8_u8(vandq_u8(vv.val[1], m4b));
2359 vx0[2] = vreinterpretq_s8_u8(vshrq_n_u8(vv.val[0], 4));
2360 vx0[3] = vreinterpretq_s8_u8(vshrq_n_u8(vv.val[1], 4));
2361 }
2362 {
2363 const uint8x16x2_t vv = vld1q_u8_x2(qx1);
2364 vx1[0] = vreinterpretq_s8_u8(vandq_u8(vv.val[0], m4b));
2365 vx1[1] = vreinterpretq_s8_u8(vandq_u8(vv.val[1], m4b));
2366 vx1[2] = vreinterpretq_s8_u8(vshrq_n_u8(vv.val[0], 4));
2367 vx1[3] = vreinterpretq_s8_u8(vshrq_n_u8(vv.val[1], 4));
2368 }
2369
2370 // process 32 data points (which share the same block scale) per iteration
2371 for (int k = 0; k < 2; ++k) {
2372 const int blk = j * 2 + k;
2373 const int32x4_t block_scale = {
2374 x0_scales[blk],
2375 x0_scales[blk],
2376 x1_scales[blk],
2377 x1_scales[blk],
2378 };
2379
2380 int32x4_t vr = {0};
2381 for (int l = 0; l < 2; ++l) {
2382 const int idx = k * 2 + l;
2383 const int64x2_t vx0_s64 = vreinterpretq_s64_s8(vx0[idx]);
2384 const int64x2_t vx1_s64 = vreinterpretq_s64_s8(vx1[idx]);
2385 const int64x2_t vy0_s64 = vreinterpretq_s64_s8(vy0.val[idx]);
2386 const int64x2_t vy1_s64 = vreinterpretq_s64_s8(vy1.val[idx]);
2387 const int8x16_t vx_l = vreinterpretq_s8_s64(vzip1q_s64(vx0_s64, vx1_s64));
2388 const int8x16_t vx_h = vreinterpretq_s8_s64(vzip2q_s64(vx0_s64, vx1_s64));
2389 const int8x16_t vy_l = vreinterpretq_s8_s64(vzip1q_s64(vy0_s64, vy1_s64));
2390 const int8x16_t vy_h = vreinterpretq_s8_s64(vzip2q_s64(vy0_s64, vy1_s64));
2391 vr = vmmlaq_s32(vr, vx_l, vy_l);
2392 vr = vmmlaq_s32(vr, vx_h, vy_h);
2393 }
2394 // apply the block scale; this will NOT overflow:
2395 // block_scale * sum_256(int4*int8) <= 2^(8+8+4+8) = 2^28, well within int32
2396 visum = vmlaq_s32(visum, vr, block_scale);
2397 }
2398 }
2399
2400 // adjust bias, apply superblock scale
2401 {
2402 int32_t bias[4];
2403 // no obvious uplift from 16-bit SVE sdot here, so just use NEON mul-add
2404 const int16x8_t y0_sums = vpaddq_s16(vld1q_s16(y0->bsums), vld1q_s16(y0->bsums+8));
2405 const int16x8_t y1_sums = vpaddq_s16(vld1q_s16(y1->bsums), vld1q_s16(y1->bsums+8));
2406 bias[0] = vaddvq_s32(vaddq_s32(vmull_s16(vget_low_s16(y0_sums), vget_low_s16(x0_mins)),
2407 vmull_s16(vget_high_s16(y0_sums), vget_high_s16(x0_mins))));
2408 bias[1] = vaddvq_s32(vaddq_s32(vmull_s16(vget_low_s16(y1_sums), vget_low_s16(x0_mins)),
2409 vmull_s16(vget_high_s16(y1_sums), vget_high_s16(x0_mins))));
2410 bias[2] = vaddvq_s32(vaddq_s32(vmull_s16(vget_low_s16(y0_sums), vget_low_s16(x1_mins)),
2411 vmull_s16(vget_high_s16(y0_sums), vget_high_s16(x1_mins))));
2412 bias[3] = vaddvq_s32(vaddq_s32(vmull_s16(vget_low_s16(y1_sums), vget_low_s16(x1_mins)),
2413 vmull_s16(vget_high_s16(y1_sums), vget_high_s16(x1_mins))));
2414 const float32x4_t dmins = {
2415 GGML_CPU_FP16_TO_FP32(x0->dmin) * y0->d,
2416 GGML_CPU_FP16_TO_FP32(x0->dmin) * y1->d,
2417 GGML_CPU_FP16_TO_FP32(x1->dmin) * y0->d,
2418 GGML_CPU_FP16_TO_FP32(x1->dmin) * y1->d,
2419 };
2420 vfsum = vmlsq_f32(vfsum, vcvtq_f32_s32(vld1q_s32(bias)), dmins);
2421
2422 const float32x4_t superblock_scale = {
2423 GGML_CPU_FP16_TO_FP32(x0->d) * y0->d,
2424 GGML_CPU_FP16_TO_FP32(x0->d) * y1->d,
2425 GGML_CPU_FP16_TO_FP32(x1->d) * y0->d,
2426 GGML_CPU_FP16_TO_FP32(x1->d) * y1->d,
2427 };
2428 vfsum = vmlaq_f32(vfsum, vcvtq_f32_s32(visum), superblock_scale);
2429 }
2430 }
2431
2432 // vfsum = ABCD -> ACBD
2433 // AC -> s, BD -> (s+bs)
2434 vfsum = vzip1q_f32(vfsum, vextq_f32(vfsum, vfsum, 2));
2435 vst1_f32(s, vget_low_f32 (vfsum));
2436 vst1_f32(s + bs, vget_high_f32(vfsum));
2437
2438 return;
2439 }
2440#endif
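// A note on the nrc == 2 paths above: smmla (vmmlaq_s32 / svmmla_s32)
// multiplies a 2x8 int8 tile by an 8x2 int8 tile into a 2x2 int32 tile, so
// zipping the 64-bit halves of two rows of x with two rows of y produces all
// four cross dot products in one accumulator:
//
//     vr = { x0.y0, x0.y1,
//            x1.y0, x1.y1 };
//
// This is why the bias and superblock-scale vectors are built in the order
// {x0*y0, x0*y1, x1*y0, x1*y1} before the final shuffle writes the two output
// rows to s and s + bs.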
2441
2442#ifdef __ARM_FEATURE_SVE
2443 float sumf = 0;
2444 for (int i = 0; i < nb; ++i) {
2445
2446 const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
2447 const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);
2448
2449 const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8));
2450
2451 memcpy(utmp, x[i].scales, K_SCALE_SIZE);
2452
2453 uint32x2_t mins8 = { 0 };
2454 mins8 = vset_lane_u32(utmp[1] & kmask1, mins8, 0);
2455 mins8 = vset_lane_u32(((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4), mins8, 1);
2456
2457 utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
2458 utmp[0] &= kmask1;
2459
2460 const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8)));
2461 const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)),
2462 vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)));
2463 sumf -= dmin * vaddvq_s32(prod);
2464
2465 const uint8_t * scales = (const uint8_t *)utmp;
2466
2467 const uint8_t * GGML_RESTRICT q4 = x[i].qs;
2468 const int8_t * GGML_RESTRICT q8 = y[i].qs;
2469
2470 const svuint8_t m4b = svdup_n_u8(0xf);
2471 const svint32_t mzero = svdup_n_s32(0);
2472 svint32_t sumi1 = svdup_n_s32(0);
2473 svint32_t sumi1_1 = svdup_n_s32(0);
2474 svint32_t sumi1_2 = svdup_n_s32(0);
2475 svint32_t sumi2 = svdup_n_s32(0);
2476 svint32_t sumi2_1 = svdup_n_s32(0);
2477 svint32_t sumi2_2 = svdup_n_s32(0);
2478 switch (vector_length) {
2479 case 128:
2480 {
2481 for (int j = 0; j < QK_K/64; ++j) {
2482 svint8_t q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4), m4b));
2483 svint8_t q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
2484 sumi1_1 = svmla_n_s32_x(svptrue_b32(), sumi1_1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]);
2485 q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4+16), m4b));
2486 q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
2487 sumi1_2 = svmla_n_s32_x(svptrue_b32(), sumi1_2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]);
2488
2489 q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4), 4));
2490 q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
2491 sumi2_1 = svmla_n_s32_x(svptrue_b32(), sumi2_1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]);
2492 q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4+16), 4));
2493 q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
2494 sumi2_2 = svmla_n_s32_x(svptrue_b32(), sumi2_2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]);
2495 q4 += 32;
2496 }
2497 sumi1 = svadd_s32_x(svptrue_b32(), sumi1_1, sumi1_2);
2498 sumi2 = svadd_s32_x(svptrue_b32(), sumi2_1, sumi2_2);
2499 sumf += d * (svaddv_s32(svptrue_b32(), svadd_s32_x(svptrue_b32(), sumi1, sumi2)));
2500 } break;
2501 case 256:
2502 case 512:
2503 {
2504 for (int j = 0; j < QK_K/64; ++j) {
2505 const svuint8_t q4bits = svld1_u8(svptrue_pat_b8(SV_VL32), q4); q4 += 32;
2506 svint8_t q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_pat_b8(SV_VL32), q4bits, m4b));
2507 svint8_t q8bytes = svld1_s8(svptrue_pat_b8(SV_VL32), q8); q8 += 32;
2508 sumi1 = svmla_n_s32_x(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]);
2509
2510 q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q4bits, 4));
2511 q8bytes = svld1_s8(svptrue_pat_b8(SV_VL32), q8); q8 += 32;
2512 sumi2 = svmla_n_s32_x(svptrue_pat_b32(SV_VL8), sumi2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]);
2513 }
2514 sumf += d * (svaddv_s32(svptrue_pat_b32(SV_VL8), svadd_s32_x(svptrue_pat_b32(SV_VL8), sumi1, sumi2)));
2515 } break;
2516 default:
2517 assert(false && "Unsupported vector length");
2518 break;
2519 }
2520 }
2521 *s = sumf;
2522#elif defined __ARM_NEON
2523 const uint8x16_t m4b = vdupq_n_u8(0xf);
2524 const int32x4_t mzero = vdupq_n_s32(0);
2525
2526 ggml_int8x16x2_t q4bytes;
2527 ggml_int8x16x2_t q8bytes;
2528
2529 float sumf = 0;
2530
2531 for (int i = 0; i < nb; ++i) {
2532
2533 const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
2534 const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);
2535
2536 const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8));
2537
2538 memcpy(utmp, x[i].scales, 12);
2539
2540 uint32x2_t mins8 = { 0 };
2541 mins8 = vset_lane_u32(utmp[1] & kmask1, mins8, 0);
2542 mins8 = vset_lane_u32(((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4), mins8, 1);
2543
2544 utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
2545 utmp[0] &= kmask1;
2546
2547 const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8)));
2548 const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)),
2549 vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)));
2550 sumf -= dmin * vaddvq_s32(prod);
2551
2552 const uint8_t * scales = (const uint8_t *)utmp;
2553
2554 const uint8_t * GGML_RESTRICT q4 = x[i].qs;
2555 const int8_t * GGML_RESTRICT q8 = y[i].qs;
2556
2557 int32_t sumi1 = 0;
2558 int32_t sumi2 = 0;
2559
2560 for (int j = 0; j < QK_K/64; ++j) {
2561 const ggml_uint8x16x2_t q4bits = ggml_vld1q_u8_x2(q4); q4 += 32;
2562
2563 q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
2564 q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
2565 q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
2566
2567 const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
2568 sumi1 += vaddvq_s32(p1) * scales[2*j+0];
2569
2570 q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
2571 q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
2572 q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
2573
2574 const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
2575
2576 sumi2 += vaddvq_s32(p2) * scales[2*j+1];
2577 }
2578
2579 sumf += d * (sumi1 + sumi2);
2580
2581 }
2582
2583 *s = sumf;
2584
2585#else
2586 UNUSED(x);
2587 UNUSED(y);
2588 UNUSED(nb);
2589 UNUSED(kmask1);
2590 UNUSED(kmask2);
2591 UNUSED(kmask3);
2592 UNUSED(utmp);
2593 ggml_vec_dot_q4_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
2594#endif
2595}
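// An illustrative scalar sketch of the single-row (nrc == 1) q4_K path above
// (an assumption for clarity, not part of the upstream code). Each byte of
// x[i].qs packs two weights: the low nibble belongs to the first 32-value
// sub-block of a 64-value chunk and the high nibble to the second. scales[]
// and mins[] stand for the eight 6-bit values decoded from x[i].scales (see
// the layout note above ggml_vec_dot_q4_K_q8_K):
//
//     int32_t sumi = 0, imin = 0;
//     for (int j = 0; j < QK_K/64; ++j) {
//         int32_t d1 = 0, d2 = 0;
//         for (int l = 0; l < 32; ++l) {
//             d1 += (q4[32*j + l] & 0xF) * q8[64*j + l];       // low nibbles
//             d2 += (q4[32*j + l] >>  4) * q8[64*j + l + 32];  // high nibbles
//         }
//         sumi += scales[2*j+0] * d1 + scales[2*j+1] * d2;
//     }
//     for (int j = 0; j < QK_K/16; j += 2) {
//         imin += mins[j/2] * (y[i].bsums[j] + y[i].bsums[j+1]);
//     }
//     sumf += d * sumi - dmin * imin;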
2596
2597void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
2598 assert(n % QK_K == 0);
2599 assert(nrc == 1);
2600 UNUSED(nrc);
2601 UNUSED(bx);
2602 UNUSED(by);
2603 UNUSED(bs);
2604
2605 const block_q5_K * GGML_RESTRICT x = vx;
2606 const block_q8_K * GGML_RESTRICT y = vy;
2607
2608 const int nb = n / QK_K;
2609
2610 static const uint32_t kmask1 = 0x3f3f3f3f;
2611 static const uint32_t kmask2 = 0x0f0f0f0f;
2612 static const uint32_t kmask3 = 0x03030303;
2613
2614 uint32_t utmp[4];
2615
2616
2617#ifdef __ARM_NEON
2618 const uint8x16_t m4b = vdupq_n_u8(0xf);
2619 const uint8x16_t mone = vdupq_n_u8(1);
2620 const uint8x16_t mtwo = vdupq_n_u8(2);
2621 const int32x4_t mzero = vdupq_n_s32(0);
2622
2623 ggml_int8x16x4_t q5bytes;
2624
2625 float sumf = 0;
2626
2627 for (int i = 0; i < nb; ++i) {
2628
2629 const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
2630 const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);
2631
2632 const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8));
2633
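        // x[i].scales packs 8 six-bit scales and 8 six-bit mins into 12 bytes; the
        // kmask shuffle below rearranges them so that utmp[0..1] hold the scales and
        // utmp[2..3] hold the mins as plain bytes (same trick as in the q4_K kernel).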
2634 memcpy(utmp, x[i].scales, 12);
2635 utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
2636 const uint32_t uaux = utmp[1] & kmask1;
2637 utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
2638 utmp[2] = uaux;
2639 utmp[0] &= kmask1;
2640
2641 const uint8x8_t mins8 = vld1_u8((const uint8_t*)utmp + 8);
2642 const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(mins8));
2643 const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)),
2644 vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)));
2645 int32_t sumi_mins = vaddvq_s32(prod);
2646
2647 const uint8_t * scales = (const uint8_t *)utmp;
2648
2649 const uint8_t * GGML_RESTRICT q5 = x[i].qs;
2650 const uint8_t * GGML_RESTRICT qh = x[i].qh;
2651 const int8_t * GGML_RESTRICT q8 = y[i].qs;
2652
2653 ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh);
2654
2655 ggml_uint8x16x4_t q5h;
2656
2657 int32_t sumi = 0;
2658
2659 for (int j = 0; j < QK_K/64; ++j) {
2660
2661 const ggml_uint8x16x2_t q5bits = ggml_vld1q_u8_x2(q5); q5 += 32;
2662 const ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64;
2663
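                // pull the 5th bit of each value out of qh: bit 0 of each qh byte for
                // the first 32 values, bit 1 for the next 32, moved up to bit position 4
                // and OR-ed with the low nibbles; qh then shifts down 2 bits per iteration.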
2664 q5h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4);
2665 q5h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4);
2666 q5h.val[2] = vshlq_n_u8(vandq_u8(mtwo, qhbits.val[0]), 3);
2667 q5h.val[3] = vshlq_n_u8(vandq_u8(mtwo, qhbits.val[1]), 3);
2668 qhbits.val[0] = vshrq_n_u8(qhbits.val[0], 2);
2669 qhbits.val[1] = vshrq_n_u8(qhbits.val[1], 2);
2670
2671 q5bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[0], m4b), q5h.val[0]));
2672 q5bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[1], m4b), q5h.val[1]));
2673 q5bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[0], 4), q5h.val[2]));
2674 q5bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[1], 4), q5h.val[3]));
2675
2676 sumi += vaddvq_s32(ggml_vdotq_s32(ggml_vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]), q5bytes.val[1], q8bytes.val[1])) * *scales++;
2677 sumi += vaddvq_s32(ggml_vdotq_s32(ggml_vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]), q5bytes.val[3], q8bytes.val[3])) * *scales++;
2678 }
2679
2680 sumf += d * sumi - dmin * sumi_mins;
2681 }
2682
2683 *s = sumf;
2684
2685#else
2686 UNUSED(x);
2687 UNUSED(y);
2688 UNUSED(nb);
2689 UNUSED(kmask1);
2690 UNUSED(kmask2);
2691 UNUSED(kmask3);
2692 UNUSED(utmp);
2693 ggml_vec_dot_q5_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
2694#endif
2695}
2696
2697void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
2698 assert(n % QK_K == 0);
2699#ifdef __ARM_FEATURE_MATMUL_INT8
2700 assert((nrc == 2) || (nrc == 1));
2701#else
2702 assert(nrc == 1);
2703#endif
2704 UNUSED(nrc);
2705 UNUSED(bx);
2706 UNUSED(by);
2707 UNUSED(bs);
2708
2709 const block_q6_K * GGML_RESTRICT x = vx;
2710 const block_q8_K * GGML_RESTRICT y = vy;
2711
2712 const int nb = n / QK_K;
2713
2714#ifdef __ARM_FEATURE_SVE
2715 const int vector_length = ggml_cpu_get_sve_cnt()*8;
2716#endif
2717#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
2718 if (nrc == 2) {
2719 const svbool_t pg32_2 = svptrue_pat_b32(SV_VL2);
2720
2721 svfloat32_t sum = svdup_n_f32(0);
2722
2723 const block_q6_K * GGML_RESTRICT vx0 = vx;
2724 const block_q8_K * GGML_RESTRICT vy0 = vy;
2725 const block_q6_K * GGML_RESTRICT vx1 = (const block_q6_K *) ((const uint8_t*)vx + bx);
2726 const block_q8_K * GGML_RESTRICT vy1 = (const block_q8_K *) ((const uint8_t*)vy + by);
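    // nrc == 2: compute a 2x2 tile of results (two q6_K rows times two q8_K rows) in
    // one pass with the SVE int8 matrix-multiply (svmmla) instead of four separate
    // dot products; the tile is written out to s and s + bs at the end.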
2727
2728 switch (vector_length) {
2729 case 128:
2730 {
2731 const svbool_t pg128_all = svptrue_pat_b8(SV_ALL);
2732 for (int i = 0; i < nb; ++i) {
2733 const uint8_t * GGML_RESTRICT ql0 = vx0[i].ql;
2734 const uint8_t * GGML_RESTRICT qh0 = vx0[i].qh;
2735 const uint8_t * GGML_RESTRICT ql1 = vx1[i].ql;
2736 const uint8_t * GGML_RESTRICT qh1 = vx1[i].qh;
2737 const int8_t * GGML_RESTRICT q80 = vy0[i].qs;
2738 const int8_t * GGML_RESTRICT q81 = vy1[i].qs;
2739
2740 const int8_t * GGML_RESTRICT scale0 = vx0[i].scales;
2741 const int8_t * GGML_RESTRICT scale1 = vx1[i].scales;
2742
2743 svfloat32_t vy_d = svuzp1_f32(svdup_n_f32(vy0[i].d), svdup_n_f32(vy1[i].d));
2744 svfloat32_t vx_d = svzip1_f32(svdup_n_f32(GGML_FP16_TO_FP32(vx0[i].d)), svdup_n_f32(GGML_FP16_TO_FP32(vx1[i].d)));
2745 svfloat32_t svsuper_block_scales = svmul_f32_x(pg128_all, vy_d, vx_d);
2746                    // accumulate the q8 block sums (bsums) against the q6 scales (128-bit route)
2747 const svint16_t q8sums_01 = svld1_s16(pg128_all, vy0[i].bsums);
2748 const svint16_t q8sums_02 = svld1_s16(pg128_all, vy0[i].bsums + 8);
2749 const svint16_t q8sums_11 = svld1_s16(pg128_all, vy1[i].bsums);
2750 const svint16_t q8sums_12 = svld1_s16(pg128_all, vy1[i].bsums + 8);
2751 const svint64x2_t q6scales_0_tmp = svld2_s64(pg128_all, (const int64_t *)scale0);
2752 const svint16_t q6scales_01 = svunpklo_s16(svreinterpret_s8_s64(svget2_s64(q6scales_0_tmp, 0)));
2753 const svint16_t q6scales_02 = svunpklo_s16(svreinterpret_s8_s64(svget2_s64(q6scales_0_tmp, 1)));
2754 const svint64x2_t q6scales_1_tmp = svld2_s64(pg128_all, (const int64_t *)scale1);
2755 const svint16_t q6scales_11 = svunpklo_s16(svreinterpret_s8_s64(svget2_s64(q6scales_1_tmp, 0)));
2756 const svint16_t q6scales_12 = svunpklo_s16(svreinterpret_s8_s64(svget2_s64(q6scales_1_tmp, 1)));
2757 const svint64_t prod = svdup_n_s64(0);
2758
2759 svint32_t isum_tmp1 = svreinterpret_s32_s64(svdot_s64(svdot_s64(prod, q8sums_01, q6scales_01), q8sums_02, q6scales_02));
2760 svint32_t isum_tmp2 = svreinterpret_s32_s64(svdot_s64(svdot_s64(prod, q8sums_01, q6scales_11), q8sums_02, q6scales_12));
2761 svint32_t isum_tmp3 = svtrn1_s32(isum_tmp1, isum_tmp2);
2762 svint32_t isum_tmp4 = svreinterpret_s32_s64(svdot_s64(svdot_s64(prod, q8sums_11, q6scales_01), q8sums_12, q6scales_02));
2763 svint32_t isum_tmp5 = svreinterpret_s32_s64(svdot_s64(svdot_s64(prod, q8sums_11, q6scales_11), q8sums_12, q6scales_12));
2764 svint32_t isum_tmp6 = svtrn1_s32(isum_tmp4, isum_tmp5);
2765 svint32_t isum_tmp7 = svreinterpret_s32_s64(svtrn2_s64(svreinterpret_s64_s32(isum_tmp3), svreinterpret_s64_s32(isum_tmp6)));
2766 svint32_t isum_tmp8 = svreinterpret_s32_s64(svtrn1_s64(svreinterpret_s64_s32(isum_tmp3), svreinterpret_s64_s32(isum_tmp6)));
2767 svint32_t svisum_mins = svadd_s32_x(pg128_all, isum_tmp7, isum_tmp8);
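                    // q6_K codes are unsigned 0..63 with an implicit -32 offset; instead
                    // of subtracting 32 per element, 32 * dot(bsums, scales) is folded in
                    // at the end via the svmla with -32 below.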
2768
2769                    // main int8 matrix-multiply (mmla) loop
2770 svint8_t l0, l1, r0, r1;
2771 svint32_t isum_tmp = svdup_n_s32(0);
2772 for (int j = 0; j < QK_K/128; ++j) {
2773 for (int k = 0; k < 8; ++k) {
2774 svuint8_t qhbits_0 = svld1_u8(pg128_all, qh0+16*(k%2));
2775 svuint8_t qhbits_1 = svld1_u8(pg128_all, qh1+16*(k%2));
2776 svuint8_t q6bits_0 = svld1_u8(pg128_all, ql0+16*(k%4));
2777 svuint8_t q6bits_1 = svld1_u8(pg128_all, ql1+16*(k%4));
2778 const int ql_pos = (k/4)*4;
2779 svuint8_t q6bytes_0_lo = (ql_pos < 4) ? svand_n_u8_x(pg128_all, q6bits_0, 0xf) : svlsr_n_u8_x(pg128_all, q6bits_0, 4);
2780 svuint8_t q6bytes_1_lo = (ql_pos < 4) ? svand_n_u8_x(pg128_all, q6bits_1, 0xf) : svlsr_n_u8_x(pg128_all, q6bits_1, 4);
2781 const int qh_pos = (k/2)*2;
2782 svuint8_t q6bytes_0_hi = svand_n_u8_x(pg128_all, qhbits_0, 0x3 << qh_pos);
2783 svuint8_t q6bytes_1_hi = svand_n_u8_x(pg128_all, qhbits_1, 0x3 << qh_pos);
2784 svint8_t q6bytes_0, q6bytes_1;
2785 if (qh_pos <= 4) {
2786 q6bytes_0 = svreinterpret_s8_u8(svmla_n_u8_x(pg128_all, q6bytes_0_lo, q6bytes_0_hi, 1 << (4 - qh_pos)));
2787 q6bytes_1 = svreinterpret_s8_u8(svmla_n_u8_x(pg128_all, q6bytes_1_lo, q6bytes_1_hi, 1 << (4 - qh_pos)));
2788 } else {
2789 q6bytes_0 = svreinterpret_s8_u8(svorr_u8_x(pg128_all, q6bytes_0_lo, svlsr_n_u8_x(pg128_all, q6bytes_0_hi, (qh_pos - 4))));
2790 q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg128_all, q6bytes_1_lo, svlsr_n_u8_x(pg128_all, q6bytes_1_hi, (qh_pos - 4))));
2791 }
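                            // the two qh bits selected by qh_pos are moved to bits 4..5:
                            // scaled up with the mla when they start below bit 4, shifted
                            // down otherwise, so q6bytes always ends up as (hi << 4) | lo.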
2792 svint8_t q8bytes_0 = svld1_s8(pg128_all, q80+16*(k%8));
2793 svint8_t q8bytes_1 = svld1_s8(pg128_all, q81+16*(k%8));
2794 l0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q6bytes_0), svreinterpret_s64_s8(q6bytes_1)));
2795 l1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q6bytes_0), svreinterpret_s64_s8(q6bytes_1)));
2796 r0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0), svreinterpret_s64_s8(q8bytes_1)));
2797 r1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0), svreinterpret_s64_s8(q8bytes_1)));
2798 svint32_t svscale = svzip1_s32(svdup_n_s32(scale0[k]), svdup_n_s32(scale1[k]));
2799 isum_tmp = svmla_s32_x(pg128_all, isum_tmp, svmmla_s32(svmmla_s32(svdup_n_s32(0), r0, l0), r1, l1), svscale);
2800 }
2801 qh0 += 32; qh1 += 32;
2802 ql0 += 64; ql1 += 64;
2803 q80 += 128; q81 += 128;
2804 scale0 += 8; scale1 += 8;
2805 }
2806 sum = svmla_f32_x(pg128_all, sum,
2807 svcvt_f32_x(pg128_all, svmla_s32_x(pg128_all, isum_tmp,
2808 svisum_mins, svdup_n_s32(-32))),
2809 svsuper_block_scales);
2810 }
2811 } // end of case 128
2812 break;
2813 case 256:
2814 case 512:
2815 {
2816 const svbool_t pg256_all = svptrue_pat_b8(SV_ALL);
2817 const svbool_t pg32_4 = svptrue_pat_b32(SV_VL4);
2818 for (int i = 0; i < nb; ++i) {
2819 const uint8_t * GGML_RESTRICT ql0 = vx0[i].ql;
2820 const uint8_t * GGML_RESTRICT qh0 = vx0[i].qh;
2821 const uint8_t * GGML_RESTRICT ql1 = vx1[i].ql;
2822 const uint8_t * GGML_RESTRICT qh1 = vx1[i].qh;
2823 const int8_t * GGML_RESTRICT q80 = vy0[i].qs;
2824 const int8_t * GGML_RESTRICT q81 = vy1[i].qs;
2825
2826 const int8_t * GGML_RESTRICT scale0 = vx0[i].scales;
2827 const int8_t * GGML_RESTRICT scale1 = vx1[i].scales;
2828 svfloat32_t vx_d = svzip1_f32(svdup_n_f32(GGML_FP16_TO_FP32(vx0[i].d)), svdup_n_f32(GGML_FP16_TO_FP32(vx1[i].d)));
2829 svfloat64_t vy_d_tmp = svreinterpret_f64_f32(svuzp1_f32(svdup_n_f32(vy0[i].d), svdup_n_f32(vy1[i].d)));
2830 svfloat32_t vy_d = svreinterpret_f32_f64(svuzp1_f64(vy_d_tmp, vy_d_tmp));
2831 svfloat32_t svsuper_block_scales = svmul_f32_x(pg32_4, vy_d, vx_d);
2832                    // accumulate the q8 block sums (bsums) against the q6 scales (256-bit route)
2833 const svint16_t q8sums_0 = svld1_s16(pg256_all, vy0[i].bsums);
2834 const svint16_t q8sums_1 = svld1_s16(pg256_all, vy1[i].bsums);
2835 const svint16_t q6scales_0 = svunpklo_s16(svld1_s8(pg256_all, scale0));
2836 const svint16_t q6scales_1 = svunpklo_s16(svld1_s8(pg256_all, scale1));
2837 const svint64_t prod = svdup_n_s64(0);
2838 svint32_t isum_tmp1 = svreinterpret_s32_s64(svdot_s64(prod, q8sums_0, q6scales_0));
2839 svint32_t isum_tmp2 = svreinterpret_s32_s64(svdot_s64(prod, q8sums_0, q6scales_1));
2840 svint32_t isum_tmp3 = svreinterpret_s32_s64(svdot_s64(prod, q8sums_1, q6scales_0));
2841 svint32_t isum_tmp4 = svreinterpret_s32_s64(svdot_s64(prod, q8sums_1, q6scales_1));
2842 svint32_t isum_tmp5 = svtrn1_s32(isum_tmp1, isum_tmp2);
2843 svint32_t isum_tmp6 = svtrn1_s32(isum_tmp3, isum_tmp4);
2844 svint32_t isum_tmp7 = svreinterpret_s32_s64(svtrn2_s64(svreinterpret_s64_s32(isum_tmp5), svreinterpret_s64_s32(isum_tmp6)));
2845 svint32_t isum_tmp8 = svreinterpret_s32_s64(svtrn1_s64(svreinterpret_s64_s32(isum_tmp5), svreinterpret_s64_s32(isum_tmp6)));
2846 svint32_t isum_tmp9 = svadd_s32_x(pg256_all, isum_tmp7, isum_tmp8);
2847 svint32_t isum_tmp10 = svreinterpret_s32_u8(svext_u8(svreinterpret_u8_s32(isum_tmp9), svreinterpret_u8_s32(isum_tmp9), 16));
2848 svint32_t svisum_mins = svadd_s32_z(pg32_4, isum_tmp9, isum_tmp10);
2849
2850                    // main int8 matrix-multiply (mmla) loop
2851 svint8_t l0, l1, r0, r1;
2852 svint32_t isum_tmp = svdup_n_s32(0);
2853 for (int j = 0; j < QK_K/128; ++j) {
2854                    for (int k = 0; k < 8; k+=2) { // process 2 blocks per iteration
2855 svuint8_t qhbits_0 = svld1_u8(pg256_all, qh0);
2856 svuint8_t qhbits_1 = svld1_u8(pg256_all, qh1);
2857 svuint8_t q6bits_0 = svld1_u8(pg256_all, ql0+32*((k%4)/2));
2858 svuint8_t q6bits_1 = svld1_u8(pg256_all, ql1+32*((k%4)/2));
2859 const int ql_pos = (k/4)*4;
2860 svuint8_t q6bytes_0_lo = (ql_pos < 4) ? svand_n_u8_x(pg256_all, q6bits_0, 0xf) : svlsr_n_u8_x(pg256_all, q6bits_0, 4);
2861 svuint8_t q6bytes_1_lo = (ql_pos < 4) ? svand_n_u8_x(pg256_all, q6bits_1, 0xf) : svlsr_n_u8_x(pg256_all, q6bits_1, 4);
2862 const int qh_pos = (k/2)*2;
2863 svuint8_t q6bytes_0_hi = svand_n_u8_x(pg256_all, qhbits_0, 0x3 << qh_pos);
2864 svuint8_t q6bytes_1_hi = svand_n_u8_x(pg256_all, qhbits_1, 0x3 << qh_pos);
2865 svint8_t q6bytes_0, q6bytes_1;
2866 if (qh_pos <= 4) {
2867 q6bytes_0 = svreinterpret_s8_u8(svmla_n_u8_x(pg256_all, q6bytes_0_lo, q6bytes_0_hi, 1 << (4 - qh_pos)));
2868 q6bytes_1 = svreinterpret_s8_u8(svmla_n_u8_x(pg256_all, q6bytes_1_lo, q6bytes_1_hi, 1 << (4 - qh_pos)));
2869 } else {
2870 q6bytes_0 = svreinterpret_s8_u8(svorr_u8_x(pg256_all, q6bytes_0_lo, svlsr_n_u8_x(pg256_all, q6bytes_0_hi, (qh_pos - 4))));
2871 q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg256_all, q6bytes_1_lo, svlsr_n_u8_x(pg256_all, q6bytes_1_hi, (qh_pos - 4))));
2872 }
2873 svint8_t q8bytes_0 = svld1_s8(pg256_all, q80+32*(k/2));
2874 svint8_t q8bytes_1 = svld1_s8(pg256_all, q81+32*(k/2));
2875 l0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q6bytes_0), svreinterpret_s64_s8(q6bytes_1)));
2876 l1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q6bytes_0), svreinterpret_s64_s8(q6bytes_1)));
2877 r0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0), svreinterpret_s64_s8(q8bytes_1)));
2878 r1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0), svreinterpret_s64_s8(q8bytes_1)));
2879 svint32_t svscale0 = svzip1_s32(svdup_n_s32(scale0[k]), svdup_n_s32(scale1[k]));
2880 svint32_t svscale1 = svzip1_s32(svdup_n_s32(scale0[k+1]), svdup_n_s32(scale1[k+1]));
2881 isum_tmp = svmla_s32_x(pg256_all, isum_tmp, svmmla_s32(svdup_n_s32(0), r0, l0), svscale0);
2882 isum_tmp = svmla_s32_x(pg256_all, isum_tmp, svmmla_s32(svdup_n_s32(0), r1, l1), svscale1);
2883 }
2884 qh0 += 32; qh1 += 32;
2885 ql0 += 64; ql1 += 64;
2886 q80 += 128; q81 += 128;
2887 scale0 += 8; scale1 += 8;
2888 } // end of for
2889 svint32_t swap_isum_tmp = svext_s32(isum_tmp, isum_tmp, 4);
2890 isum_tmp = svadd_s32_x(pg32_4, isum_tmp, swap_isum_tmp);
2891 sum = svmla_f32_x(pg32_4, sum,
2892 svcvt_f32_x(pg32_4, svmla_s32_x(pg32_4, isum_tmp,
2893 svisum_mins, svdup_n_s32(-32))),
2894 svsuper_block_scales);
2895 }
2896 } // end of case 256
2897 break;
2898 default:
2899 assert(false && "Unsupported vector length");
2900 break;
2901 } // end of switch
2902
2903 svst1_f32(pg32_2, s, sum);
2904 svst1_f32(pg32_2, s + bs, svreinterpret_f32_u8(svext_u8(svreinterpret_u8_f32(sum), svdup_n_u8(0), 8)));
2905
2906 return;
2907 }
2908#elif defined(__ARM_FEATURE_MATMUL_INT8)
2909 if (nrc == 2) {
2910 const block_q6_K * GGML_RESTRICT x0 = x;
2911 const block_q6_K * GGML_RESTRICT x1 = (const block_q6_K *) ((const uint8_t *)vx + bx);
2912 const block_q8_K * GGML_RESTRICT y0 = y;
2913 const block_q8_K * GGML_RESTRICT y1 = (const block_q8_K *) ((const uint8_t *)vy + by);
2914
2915 float32x4_t vfsum = vdupq_n_f32(0.0f);
2916
2917 for (int i = 0; i < nb; ++i, ++x0, ++x1, ++y0, ++y1) {
2918 const uint8_t * GGML_RESTRICT ql0 = x0->ql;
2919 const uint8_t * GGML_RESTRICT ql1 = x1->ql;
2920 const uint8_t * GGML_RESTRICT qh0 = x0->qh;
2921 const uint8_t * GGML_RESTRICT qh1 = x1->qh;
2922 const int8_t * GGML_RESTRICT qy0 = y0->qs;
2923 const int8_t * GGML_RESTRICT qy1 = y1->qs;
2924
2925 const uint8x16_t mone = vdupq_n_u8(0x30);
2926 const uint8x16_t m4b = vdupq_n_u8(0x0f);
2927
2928 int32x4_t visum = vdupq_n_s32(0);
2929
2930            // process 8 blocks per iteration, 16 blocks in total
2931 for (int j = 0; j < 2; ++j, qh0 += 32, ql0 += 64, qh1 += 32, ql1 += 64) {
2932 int8x16_t vx0[8], vx1[8];
2933
2934 // de-quantize vx0[8]
2935 {
2936 const uint8x16x2_t qh_bits = vld1q_u8_x2(qh0);
2937 const uint8x16x4_t ql_bits = vld1q_u8_x4(ql0);
2938
2939 uint8x16_t q6h_0 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 4));
2940 uint8x16_t q6h_1 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 4));
2941 uint8x16_t q6h_2 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 2));
2942 uint8x16_t q6h_3 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 2));
2943
2944 vx0[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[0], m4b), q6h_0));
2945 vx0[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[1], m4b), q6h_1));
2946 vx0[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[2], m4b), q6h_2));
2947 vx0[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[3], m4b), q6h_3));
2948
2949 q6h_0 = vandq_u8(mone, qh_bits.val[0]);
2950 q6h_1 = vandq_u8(mone, qh_bits.val[1]);
2951 q6h_2 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[0], 2));
2952 q6h_3 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[1], 2));
2953
2954 vx0[4] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[0], 4), q6h_0));
2955 vx0[5] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[1], 4), q6h_1));
2956 vx0[6] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[2], 4), q6h_2));
2957 vx0[7] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[3], 4), q6h_3));
2958 }
2959
2960 // de-quantize vx1[8]
2961 {
2962 const uint8x16x2_t qh_bits = vld1q_u8_x2(qh1);
2963 const uint8x16x4_t ql_bits = vld1q_u8_x4(ql1);
2964
2965 uint8x16_t q6h_0 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 4));
2966 uint8x16_t q6h_1 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 4));
2967 uint8x16_t q6h_2 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 2));
2968 uint8x16_t q6h_3 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 2));
2969
2970 vx1[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[0], m4b), q6h_0));
2971 vx1[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[1], m4b), q6h_1));
2972 vx1[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[2], m4b), q6h_2));
2973 vx1[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[3], m4b), q6h_3));
2974
2975 q6h_0 = vandq_u8(mone, qh_bits.val[0]);
2976 q6h_1 = vandq_u8(mone, qh_bits.val[1]);
2977 q6h_2 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[0], 2));
2978 q6h_3 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[1], 2));
2979
2980 vx1[4] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[0], 4), q6h_0));
2981 vx1[5] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[1], 4), q6h_1));
2982 vx1[6] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[2], 4), q6h_2));
2983 vx1[7] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[3], 4), q6h_3));
2984 }
2985
2986                // process 16 elements (one block sharing a single scale) per iteration
2987                // - vx = concat(ql, qh); the -32 offset is folded into the bias below
2988                // - r1,r2,r3,r4 = smmla(vx, vy)
2989 for (int k = 0; k < 8; ++k) {
2990 const int blk = j * 8 + k;
2991
2992 const int8x16_t vy0 = vld1q_s8(qy0);
2993 const int8x16_t vy1 = vld1q_s8(qy1);
2994 qy0 += 16;
2995 qy1 += 16;
2996
2997 const int32x4_t block_scale = {
2998 x0->scales[blk],
2999 x0->scales[blk],
3000 x1->scales[blk],
3001 x1->scales[blk],
3002 };
3003
3004                // calculate four results at once with an outer product
3005 const int8x16_t vx_l = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(vx0[k]), vreinterpretq_s64_s8(vx1[k])));
3006 const int8x16_t vx_h = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(vx0[k]), vreinterpretq_s64_s8(vx1[k])));
3007 const int8x16_t vy_l = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(vy0), vreinterpretq_s64_s8(vy1)));
3008 const int8x16_t vy_h = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(vy0), vreinterpretq_s64_s8(vy1)));
3009 int32x4_t vr = vdupq_n_s32(0);
3010 vr = vmmlaq_s32(vr, vx_l, vy_l);
3011 vr = vmmlaq_s32(vr, vx_h, vy_h);
3012
3013                // apply the block scale; this will NOT overflow:
3014 // block_scale * sum_256(int6*int8) <= 2^(8+8+6+8) = 30 bits
3015 visum = vmlaq_s32(visum, vr, block_scale);
3016 }
3017 }
3018
3019 // adjust bias, apply superblock scale
3020 {
3021 int32_t bias[4];
3022            // NEON has no int16 dot product, so fall back to separate multiplies and adds
3023 const int16x8x2_t q8sums0 = vld1q_s16_x2(y0->bsums);
3024 const int16x8x2_t q8sums1 = vld1q_s16_x2(y1->bsums);
3025
3026 int8x16_t scales_s8 = vld1q_s8(x0->scales);
3027 const int16x8x2_t q6scales0 = {{vmovl_s8(vget_low_s8(scales_s8)), vmovl_s8(vget_high_s8(scales_s8))}};
3028 scales_s8 = vld1q_s8(x1->scales);
3029 const int16x8x2_t q6scales1 = {{vmovl_s8(vget_low_s8(scales_s8)), vmovl_s8(vget_high_s8(scales_s8))}};
3030
3031 int32x4_t prod;
3032 prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[0]), vget_low_s16 (q6scales0.val[0])),
3033 vmull_s16(vget_high_s16(q8sums0.val[0]), vget_high_s16(q6scales0.val[0]))),
3034 vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[1]), vget_low_s16 (q6scales0.val[1])),
3035 vmull_s16(vget_high_s16(q8sums0.val[1]), vget_high_s16(q6scales0.val[1]))));
3036 bias[0] = vaddvq_s32(prod);
3037 prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[0]), vget_low_s16 (q6scales0.val[0])),
3038 vmull_s16(vget_high_s16(q8sums1.val[0]), vget_high_s16(q6scales0.val[0]))),
3039 vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[1]), vget_low_s16 (q6scales0.val[1])),
3040 vmull_s16(vget_high_s16(q8sums1.val[1]), vget_high_s16(q6scales0.val[1]))));
3041 bias[1] = vaddvq_s32(prod);
3042 prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[0]), vget_low_s16 (q6scales1.val[0])),
3043 vmull_s16(vget_high_s16(q8sums0.val[0]), vget_high_s16(q6scales1.val[0]))),
3044 vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[1]), vget_low_s16 (q6scales1.val[1])),
3045 vmull_s16(vget_high_s16(q8sums0.val[1]), vget_high_s16(q6scales1.val[1]))));
3046 bias[2] = vaddvq_s32(prod);
3047 prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[0]), vget_low_s16 (q6scales1.val[0])),
3048 vmull_s16(vget_high_s16(q8sums1.val[0]), vget_high_s16(q6scales1.val[0]))),
3049 vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[1]), vget_low_s16 (q6scales1.val[1])),
3050 vmull_s16(vget_high_s16(q8sums1.val[1]), vget_high_s16(q6scales1.val[1]))));
3051 bias[3] = vaddvq_s32(prod);
3052
3053 const int32x4_t vibias = vmulq_n_s32(vld1q_s32(bias), 32);
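            // same -32 offset handling as the SVE route: visum was accumulated on the
            // raw 0..63 codes, so 32 * dot(bsums, scales) is subtracted once per superblock.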
3054
3055 const float32x4_t superblock_scale = {
3056 GGML_CPU_FP16_TO_FP32(x0->d) * y0->d,
3057 GGML_CPU_FP16_TO_FP32(x0->d) * y1->d,
3058 GGML_CPU_FP16_TO_FP32(x1->d) * y0->d,
3059 GGML_CPU_FP16_TO_FP32(x1->d) * y1->d,
3060 };
3061
3062 visum = vsubq_s32(visum, vibias);
3063 vfsum = vmlaq_f32(vfsum, vcvtq_f32_s32(visum), superblock_scale);
3064 }
3065 }
3066
3067 // vfsum = ABCD -> ACBD
3068 // AC -> s, BD -> (s+bs)
3069 vfsum = vzip1q_f32(vfsum, vextq_f32(vfsum, vfsum, 2));
3070 vst1_f32(s, vget_low_f32 (vfsum));
3071 vst1_f32(s + bs, vget_high_f32(vfsum));
3072
3073 return;
3074 }
3075#endif
3076
3077#ifdef __ARM_FEATURE_SVE
3078 float sum = 0;
3079 svuint8_t m4b = svdup_n_u8(0xf);
3080 svint32_t vzero = svdup_n_s32(0);
3081 svuint8_t mone = svdup_n_u8(0x30);
3082 svint8_t q6bytes_1, q6bytes_2, q6bytes_3, q6bytes_4;
3083 svuint8_t q6h_1, q6h_2, q6h_3, q6h_4;
3084
3085 for (int i = 0; i < nb; ++i) {
3086 const float d_all = GGML_CPU_FP16_TO_FP32(x[i].d);
3087
3088 const uint8_t * GGML_RESTRICT q6 = x[i].ql;
3089 const uint8_t * GGML_RESTRICT qh = x[i].qh;
3090 const int8_t * GGML_RESTRICT q8 = y[i].qs;
3091
3092 const int8_t * GGML_RESTRICT scale = x[i].scales;
3093
3094 const svbool_t pg16_8 = svptrue_pat_b16(SV_VL8);
3095 const svint16_t q8sums_1 = svld1_s16(pg16_8, y[i].bsums);
3096 const svint16_t q8sums_2 = svld1_s16(pg16_8, y[i].bsums + 8);
3097 const svint16_t q6scales_1 = svunpklo_s16(svld1_s8(svptrue_pat_b8(SV_VL8), scale));
3098 const svint16_t q6scales_2 = svunpklo_s16(svld1_s8(svptrue_pat_b8(SV_VL8), scale + 8));
3099 const svint64_t prod = svdup_n_s64(0);
3100 int32_t isum_mins = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(prod, q8sums_1, q6scales_1),
3101 svdot_s64(prod, q8sums_2, q6scales_2)));
3102 int32_t isum = 0;
3103
3104 switch (vector_length) {
3105 case 128:
3106 {
3107 const svbool_t pg32_4 = svptrue_pat_b32(SV_VL4);
3108 const svbool_t pg8_16 = svptrue_pat_b8(SV_VL16);
3109 svint32_t isum_tmp = svdup_n_s32(0);
3110 for (int j = 0; j < QK_K/128; ++j) {
3111 svuint8_t qhbits_1 = svld1_u8(pg8_16, qh);
3112 svuint8_t qhbits_2 = svld1_u8(pg8_16, qh+16);
3113 qh += 32;
3114 svuint8_t q6bits_1 = svld1_u8(pg8_16, q6);
3115 svuint8_t q6bits_2 = svld1_u8(pg8_16, q6+16);
3116 svuint8_t q6bits_3 = svld1_u8(pg8_16, q6+32);
3117 svuint8_t q6bits_4 = svld1_u8(pg8_16, q6+48);
3118 q6 += 64;
3119 svint8_t q8bytes_1 = svld1_s8(pg8_16, q8);
3120 svint8_t q8bytes_2 = svld1_s8(pg8_16, q8+16);
3121 svint8_t q8bytes_3 = svld1_s8(pg8_16, q8+32);
3122 svint8_t q8bytes_4 = svld1_s8(pg8_16, q8+48);
3123 q8 += 64;
3124
3125 q6h_1 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_1, 4));
3126 q6h_2 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_2, 4));
3127 q6h_3 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_1, 2));
3128 q6h_4 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_2, 2));
3129 q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_1, m4b), q6h_1));
3130 q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_2, m4b), q6h_2));
3131 q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_3, m4b), q6h_3));
3132 q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_4, m4b), q6h_4));
3133 isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale[0]);
3134 isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale[1]);
3135 isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale[2]);
3136 isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale[3]);
3137
3138 scale += 4;
3139 q8bytes_1 = svld1_s8(pg8_16, q8);
3140 q8bytes_2 = svld1_s8(pg8_16, q8+16);
3141 q8bytes_3 = svld1_s8(pg8_16, q8+32);
3142 q8bytes_4 = svld1_s8(pg8_16, q8+48);
3143 q8 += 64;
3144
3145 q6h_1 = svand_u8_x(pg16_8, mone, qhbits_1);
3146 q6h_2 = svand_u8_x(pg16_8, mone, qhbits_2);
3147 q6h_3 = svand_u8_x(pg16_8, mone, svlsr_n_u8_x(pg16_8, qhbits_1, 2));
3148 q6h_4 = svand_u8_x(pg16_8, mone, svlsr_n_u8_x(pg16_8, qhbits_2, 2));
3149 q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_1, 4), q6h_1));
3150 q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_2, 4), q6h_2));
3151 q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_3, 4), q6h_3));
3152 q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_4, 4), q6h_4));
3153 isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale[0]);
3154 isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale[1]);
3155 isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale[2]);
3156 isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale[3]);
3157 scale += 4;
3158 }
3159 isum += svaddv_s32(pg32_4, isum_tmp);
3160 sum += d_all * y[i].d * (isum - 32 * isum_mins);
3161 }
3162 break;
3163 case 256:
3164 case 512:
3165 {
3166 const svbool_t pg8_2 = svptrue_pat_b8(SV_VL2);
3167 const svbool_t pg32_8 = svptrue_pat_b32(SV_VL8);
3168 const svbool_t pg8_32 = svptrue_pat_b8(SV_VL32);
3169 svint32_t isum_tmp = svdup_n_s32(0);
3170 for (int j = 0; j < QK_K/128; j++) {
3171 svuint8_t qhbits_1 = svld1_u8(pg8_32, qh);
3172 qh += 32;
3173 svuint8_t q6bits_1 = svld1_u8(pg8_32, q6);
3174 svuint8_t q6bits_2 = svld1_u8(pg8_32, q6+32);
3175 q6 += 64;
3176 svint8_t q8bytes_1 = svld1_s8(pg8_32, q8);
3177 svint8_t q8bytes_2 = svld1_s8(pg8_32, q8+32);
3178 svint8_t q8bytes_3 = svld1_s8(pg8_32, q8+64);
3179 svint8_t q8bytes_4 = svld1_s8(pg8_32, q8+96);
3180 q8 += 128;
3181 q6h_1 = svand_u8_x(pg8_32, mone, svlsl_n_u8_x(pg8_32, qhbits_1, 4));
3182 q6h_2 = svand_u8_x(pg8_32, mone, svlsl_n_u8_x(pg8_32, qhbits_1, 2));
3183 q6h_3 = svand_u8_x(pg8_32, mone, qhbits_1);
3184 q6h_4 = svand_u8_x(pg8_32, mone, svlsr_n_u8_x(pg8_32, qhbits_1, 2));
3185 q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svand_u8_x(pg8_32, q6bits_1, m4b), q6h_1));
3186 q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svand_u8_x(pg8_32, q6bits_2, m4b), q6h_2));
3187 q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svlsr_n_u8_x(pg8_32, q6bits_1, 4), q6h_3));
3188 q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svlsr_n_u8_x(pg8_32, q6bits_2, 4), q6h_4));
3189
3190 svint8_t scale_lane_1_tmp = svld1_s8(pg8_2, scale);
3191                    scale_lane_1_tmp = svzip1_s8(scale_lane_1_tmp, scale_lane_1_tmp);
3192                    scale_lane_1_tmp = svzip1_s8(scale_lane_1_tmp, scale_lane_1_tmp);
3193 svint8_t scale_lane_2_tmp = svld1_s8(pg8_2, scale+2);
3194 scale_lane_2_tmp = svzip1_s8(scale_lane_2_tmp, scale_lane_2_tmp);
3195 scale_lane_2_tmp = svzip1_s8(scale_lane_2_tmp, scale_lane_2_tmp);
3196 svint8_t scale_lane_3_tmp = svld1_s8(pg8_2, scale+4);
3197 scale_lane_3_tmp = svzip1_s8(scale_lane_3_tmp, scale_lane_3_tmp);
3198 scale_lane_3_tmp = svzip1_s8(scale_lane_3_tmp, scale_lane_3_tmp);
3199 svint8_t scale_lane_4_tmp = svld1_s8(pg8_2, scale+6);
3200 scale_lane_4_tmp = svzip1_s8(scale_lane_4_tmp, scale_lane_4_tmp);
3201 scale_lane_4_tmp = svzip1_s8(scale_lane_4_tmp, scale_lane_4_tmp);
3202 svint32_t scale_lane_1 = svunpklo_s32(svunpklo_s16(scale_lane_1_tmp));
3203 svint32_t scale_lane_2 = svunpklo_s32(svunpklo_s16(scale_lane_2_tmp));
3204 svint32_t scale_lane_3 = svunpklo_s32(svunpklo_s16(scale_lane_3_tmp));
3205 svint32_t scale_lane_4 = svunpklo_s32(svunpklo_s16(scale_lane_4_tmp));
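                    // each pair of int8 scales is replicated 4x (the two zips) and widened
                    // to int32 so one scale lines up with each 32-bit lane of the dot
                    // products below (one scale per 16 quantized values).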
3206
3207 isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale_lane_1);
3208 isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale_lane_2);
3209 isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale_lane_3);
3210 isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale_lane_4);
3211 scale += 8;
3212 }
3213 isum += svaddv_s32(pg32_8, isum_tmp);
3214 sum += d_all * y[i].d * (isum - 32 * isum_mins);
3215 }
3216 break;
3217 default:
3218 assert(false && "Unsupported vector length");
3219 break;
3220 }
3221 }
3222
3223 *s = sum;
3224
3225#elif __ARM_NEON
3226 float sum = 0;
3227
3228 const uint8x16_t m4b = vdupq_n_u8(0xF);
3229 const int32x4_t vzero = vdupq_n_s32(0);
3230 //const int8x16_t m32s = vdupq_n_s8(32);
3231
3232 const uint8x16_t mone = vdupq_n_u8(3);
3233
3234 ggml_int8x16x4_t q6bytes;
3235 ggml_uint8x16x4_t q6h;
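    // plain NEON route: rebuild each 6-bit code as (ql & 0xF) | ((qh & 3) << 4), dot it
    // against q8 in groups of 16 sharing one scale, and fold the implicit -32 offset
    // into isum_mins (the commented-out m32s variant did the subtraction per element).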
3236
3237 for (int i = 0; i < nb; ++i) {
3238
3239 const float d_all = GGML_CPU_FP16_TO_FP32(x[i].d);
3240
3241 const uint8_t * GGML_RESTRICT q6 = x[i].ql;
3242 const uint8_t * GGML_RESTRICT qh = x[i].qh;
3243 const int8_t * GGML_RESTRICT q8 = y[i].qs;
3244
3245 const int8_t * GGML_RESTRICT scale = x[i].scales;
3246
3247 const ggml_int16x8x2_t q8sums = ggml_vld1q_s16_x2(y[i].bsums);
3248 const int8x16_t scales = vld1q_s8(scale);
3249 const ggml_int16x8x2_t q6scales = {{vmovl_s8(vget_low_s8(scales)), vmovl_s8(vget_high_s8(scales))}};
3250
3251 const int32x4_t prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[0]), vget_low_s16 (q6scales.val[0])),
3252 vmull_s16(vget_high_s16(q8sums.val[0]), vget_high_s16(q6scales.val[0]))),
3253 vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[1]), vget_low_s16 (q6scales.val[1])),
3254 vmull_s16(vget_high_s16(q8sums.val[1]), vget_high_s16(q6scales.val[1]))));
3255 int32_t isum_mins = vaddvq_s32(prod);
3256
3257 int32_t isum = 0;
3258
3259 for (int j = 0; j < QK_K/128; ++j) {
3260
3261 ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh); qh += 32;
3262 ggml_uint8x16x4_t q6bits = ggml_vld1q_u8_x4(q6); q6 += 64;
3263 ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64;
3264
3265 q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4);
3266 q6h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4);
3267 uint8x16_t shifted = vshrq_n_u8(qhbits.val[0], 2);
3268 q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
3269 shifted = vshrq_n_u8(qhbits.val[1], 2);
3270 q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
3271
3272 //q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0])), m32s);
3273 //q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1])), m32s);
3274 //q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2])), m32s);
3275 //q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3])), m32s);
3276 q6bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0]));
3277 q6bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1]));
3278 q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2]));
3279 q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3]));
3280
3281 isum += vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
3282 vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
3283 vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
3284 vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
3285
3286 scale += 4;
3287
3288 q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64;
3289
3290 shifted = vshrq_n_u8(qhbits.val[0], 4);
3291 q6h.val[0] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
3292 shifted = vshrq_n_u8(qhbits.val[1], 4);
3293 q6h.val[1] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
3294 shifted = vshrq_n_u8(qhbits.val[0], 6);
3295 q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
3296 shifted = vshrq_n_u8(qhbits.val[1], 6);
3297 q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
3298
3299 //q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[0])), m32s);
3300 //q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[1])), m32s);
3301 //q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2])), m32s);
3302 //q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3])), m32s);
3303 q6bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[0]));
3304 q6bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[1]));
3305 q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2]));
3306 q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3]));
3307
3308 isum += vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
3309 vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
3310 vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
3311 vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
3312 scale += 4;
3313 }
3314 //sum += isum * d_all * y[i].d;
3315 sum += d_all * y[i].d * (isum - 32 * isum_mins);
3316
3317 }
3318 *s = sum;
3319#else
3320 UNUSED(x);
3321 UNUSED(y);
3322 UNUSED(nb);
3323 ggml_vec_dot_q6_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
3324#endif
3325}
3326
3327#if defined (__ARM_NEON)
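// Sign table shared by the iq2/iq3 "xxs" kernels: entry i holds 8 bytes of +/-1 where
// the low 7 bits of i give the first 7 signs and the 8th sign is chosen so the number
// of -1s is even, letting a 7-bit index encode the signs of 8 values.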
3328static const int8_t keven_signs_q2xs[1024] = {
3329 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1,
3330 1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, -1,
3331 1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, -1,
3332 1, 1, -1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, 1,
3333 1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, -1,
3334 1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, 1,
3335 1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, 1,
3336 1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, -1,
3337 1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, -1,
3338 1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, 1,
3339 1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, 1,
3340 1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, -1,
3341 1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, 1,
3342 1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, -1,
3343 1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, -1,
3344 1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, 1,
3345 1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, -1,
3346 1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, 1,
3347 1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, 1,
3348 1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, -1,
3349 1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, 1,
3350 1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, -1,
3351 1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, -1,
3352 1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, 1,
3353 1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, 1,
3354 1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, -1,
3355 1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, -1,
3356 1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, 1,
3357 1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, -1,
3358 1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, 1,
3359 1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, 1,
3360 1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1,
3361};
3362#endif
3363
3364void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
3365 assert(n % QK_K == 0);
3366 assert(nrc == 1);
3367 UNUSED(nrc);
3368 UNUSED(bx);
3369 UNUSED(by);
3370 UNUSED(bs);
3371
3372 const block_iq2_xxs * GGML_RESTRICT x = vx;
3373 const block_q8_K * GGML_RESTRICT y = vy;
3374
3375 const int nb = n / QK_K;
3376
3377#if defined(__ARM_NEON)
3378
3379 const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
3380
3381 uint32_t aux32[4];
3382 const uint8_t * aux8 = (const uint8_t *)aux32;
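    // per 32-value group, iq2_xxs stores 4 byte-sized grid indices (aux32[0]/aux32[2])
    // and a word of 4x7-bit sign-table indices topped by a 4-bit scale (aux32[1]/aux32[3]);
    // the scale enters as (0.5 + s) and the final 0.25 factor restores the units.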
3383
3384 ggml_int8x16x4_t q2u;
3385 ggml_int8x16x4_t q2s;
3386 ggml_int8x16x4_t q8b;
3387
3388 float sumf = 0;
3389 for (int i = 0; i < nb; ++i) {
3390 const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
3391 const uint16_t * GGML_RESTRICT q2 = x[i].qs;
3392 const int8_t * GGML_RESTRICT q8 = y[i].qs;
3393 float sumf1 = 0, sumf2 = 0;
3394 for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
3395 q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
3396 memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8;
3397 q2u.val[0] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 0])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 1])));
3398 q2u.val[1] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 2])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 3])));
3399 q2u.val[2] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 8])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 9])));
3400 q2u.val[3] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[10])), vld1_s8((const void *)(iq2xxs_grid + aux8[11])));
3401 q2s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 7) & 127))));
3402 q2s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 21) & 127))));
3403 q2s.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[3] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[3] >> 7) & 127))));
3404 q2s.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[3] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[3] >> 21) & 127))));
3405 q2u.val[0] = vmulq_s8(q2u.val[0], q2s.val[0]);
3406 q2u.val[1] = vmulq_s8(q2u.val[1], q2s.val[1]);
3407 q2u.val[2] = vmulq_s8(q2u.val[2], q2s.val[2]);
3408 q2u.val[3] = vmulq_s8(q2u.val[3], q2s.val[3]);
3409 const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[0], q8b.val[0]), q2u.val[1], q8b.val[1]);
3410 const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[2], q8b.val[2]), q2u.val[3], q8b.val[3]);
3411 sumf1 += vaddvq_s32(p1) * (0.5f + (aux32[1] >> 28));
3412 sumf2 += vaddvq_s32(p2) * (0.5f + (aux32[3] >> 28));
3413 }
3414 sumf += d*(sumf1 + sumf2);
3415 }
3416 *s = 0.25f * sumf;
3417
3418#else
3419 UNUSED(x);
3420 UNUSED(y);
3421 UNUSED(nb);
3422 ggml_vec_dot_iq2_xxs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
3423#endif
3424}
3425
3426void ggml_vec_dot_iq2_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
3427 assert(n % QK_K == 0);
3428 assert(nrc == 1);
3429 UNUSED(nrc);
3430 UNUSED(bx);
3431 UNUSED(by);
3432 UNUSED(bs);
3433
3434 const block_iq2_xs * GGML_RESTRICT x = vx;
3435 const block_q8_K * GGML_RESTRICT y = vy;
3436
3437 const int nb = n / QK_K;
3438
3439#if defined(__ARM_NEON)
3440
3441 const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
3442
3443 ggml_int8x16x4_t q2u;
3444 ggml_int8x16x4_t q2s;
3445 ggml_int8x16x4_t q8b;
3446
3447 int32x4x4_t scales32;
3448
3449 float sumf = 0;
3450 for (int i = 0; i < nb; ++i) {
3451 const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
3452 const uint16_t * GGML_RESTRICT q2 = x[i].qs;
3453 const int8_t * GGML_RESTRICT q8 = y[i].qs;
3454 const uint8x8_t scales8 = vld1_u8(x[i].scales);
3455 const uint8x8_t scales_l = vand_u8(scales8, vdup_n_u8(0xf));
3456 const uint8x8_t scales_h = vshr_n_u8(scales8, 4);
3457 uint8x16_t scales = vcombine_u8(vzip1_u8(scales_l, scales_h), vzip2_u8(scales_l, scales_h));
3458 scales = vaddq_u8(vshlq_n_u8(scales, 1), vdupq_n_u8(1));
3459 const uint16x8_t scales1 = vmovl_u8(vget_low_u8(scales));
3460 const uint16x8_t scales2 = vmovl_u8(vget_high_u8(scales));
3461 scales32.val[0] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(scales1)));
3462 scales32.val[1] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales1)));
3463 scales32.val[2] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(scales2)));
3464 scales32.val[3] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales2)));
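        // the 4-bit scales are expanded to one 32-bit lane per 16-value sub-block as
        // 2*s + 1; together with the final 0.125 factor the effective sub-block scale
        // is d * (2*s + 1) / 8.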
3465 int32x4_t sumi = vdupq_n_s32(0);
3466 for (int ib64 = 0; ib64 < QK_K/64; ++ib64) {
3467 q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
3468 q2u.val[0] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[0] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[1] & 511))));
3469 q2u.val[1] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[2] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[3] & 511))));
3470 q2u.val[2] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[4] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[5] & 511))));
3471 q2u.val[3] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[6] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[7] & 511))));
3472 q2s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[0] >> 9))), vld1_s8((const void *)(signs64 + (q2[1] >> 9))));
3473 q2s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[2] >> 9))), vld1_s8((const void *)(signs64 + (q2[3] >> 9))));
3474 q2s.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[4] >> 9))), vld1_s8((const void *)(signs64 + (q2[5] >> 9))));
3475 q2s.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[6] >> 9))), vld1_s8((const void *)(signs64 + (q2[7] >> 9))));
3476 q2u.val[0] = vmulq_s8(q2u.val[0], q2s.val[0]);
3477 q2u.val[1] = vmulq_s8(q2u.val[1], q2s.val[1]);
3478 q2u.val[2] = vmulq_s8(q2u.val[2], q2s.val[2]);
3479 q2u.val[3] = vmulq_s8(q2u.val[3], q2s.val[3]);
3480 const int32x4_t p1 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[0], q8b.val[0]);
3481 const int32x4_t p2 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[1], q8b.val[1]);
3482 const int32x4_t p3 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[2], q8b.val[2]);
3483 const int32x4_t p4 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[3], q8b.val[3]);
3484 const int32x4_t p = vpaddq_s32(vpaddq_s32(p1, p2), vpaddq_s32(p3, p4));
3485 sumi = vmlaq_s32(sumi, p, scales32.val[ib64]);
3486 q2 += 8;
3487 }
3488 sumf += d*vaddvq_s32(sumi);
3489 }
3490 *s = 0.125f * sumf;
3491
3492#else
3493 UNUSED(x);
3494 UNUSED(y);
3495 UNUSED(nb);
3496 ggml_vec_dot_iq2_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
3497#endif
3498}
3499
3500void ggml_vec_dot_iq2_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
3501 assert(n % QK_K == 0);
3502 assert(nrc == 1);
3503 UNUSED(nrc);
3504 UNUSED(bx);
3505 UNUSED(by);
3506 UNUSED(bs);
3507
3508 const block_iq2_s * GGML_RESTRICT x = vx;
3509 const block_q8_K * GGML_RESTRICT y = vy;
3510
3511 const int nb = n / QK_K;
3512
3513#if defined(__ARM_NEON)
3514
3515 static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
3516 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
3517 };
3518
3519 static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,};
3520
3521 const ggml_uint8x16x2_t mask1 = ggml_vld1q_u8_x2(k_mask1);
3522 const uint8x16_t mask2 = vld1q_u8(k_mask2);
3523 const uint8x16_t m1 = vdupq_n_u8(1);
3524 const int32x4_t vzero = vdupq_n_s32(0);
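    // sign expansion: broadcast a 32-bit sign word, table-lookup (mask1) replicates each
    // byte across 8 lanes, AND with the per-lane bit masks (mask2) and compare-equal turn
    // set bits into 0xFF; OR with 1 then yields a vector of +/-1 multipliers.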
3525
3526 uint8x16x2_t vs;
3527 ggml_int8x16x4_t q2s;
3528 ggml_int8x16x4_t q8b;
3529
3530 float sumf = 0;
3531 for (int i = 0; i < nb; ++i) {
3532
3533 const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
3534
3535 const uint8_t * GGML_RESTRICT qs = x[i].qs;
3536 const uint8_t * GGML_RESTRICT qh = x[i].qh;
3537 const uint16_t * GGML_RESTRICT signs = (const uint16_t *)(x[i].qs + QK_K/8);
3538 const int8_t * GGML_RESTRICT q8 = y[i].qs;
3539
3540 int sumi1 = 0, sumi2 = 0;
3541 for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
3542 q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
3543 q2s.val[0] = vcombine_s8(vld1_s8((const int8_t *)(iq2s_grid + (qs[0] | ((qh[ib32+0] << 8) & 0x300)))),
3544 vld1_s8((const int8_t *)(iq2s_grid + (qs[1] | ((qh[ib32+0] << 6) & 0x300)))));
3545 q2s.val[1] = vcombine_s8(vld1_s8((const int8_t *)(iq2s_grid + (qs[2] | ((qh[ib32+0] << 4) & 0x300)))),
3546 vld1_s8((const int8_t *)(iq2s_grid + (qs[3] | ((qh[ib32+0] << 2) & 0x300)))));
3547 q2s.val[2] = vcombine_s8(vld1_s8((const int8_t *)(iq2s_grid + (qs[4] | ((qh[ib32+1] << 8) & 0x300)))),
3548 vld1_s8((const int8_t *)(iq2s_grid + (qs[5] | ((qh[ib32+1] << 6) & 0x300)))));
3549 q2s.val[3] = vcombine_s8(vld1_s8((const int8_t *)(iq2s_grid + (qs[6] | ((qh[ib32+1] << 4) & 0x300)))),
3550 vld1_s8((const int8_t *)(iq2s_grid + (qs[7] | ((qh[ib32+1] << 2) & 0x300)))));
3551 qs += 8;
3552
3553 vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | ((uint32_t) signs[1] << 16)));
3554 vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
3555 vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
3556 vs.val[0] = vceqq_u8(vs.val[0], mask2);
3557 vs.val[1] = vceqq_u8(vs.val[1], mask2);
3558
3559 q2s.val[0] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[0], m1)), q2s.val[0]);
3560 q2s.val[1] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[1], m1)), q2s.val[1]);
3561
3562 vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[2] | ((uint32_t) signs[3] << 16)));
3563 vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
3564 vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
3565 vs.val[0] = vceqq_u8(vs.val[0], mask2);
3566 vs.val[1] = vceqq_u8(vs.val[1], mask2);
3567
3568 signs += 4;
3569
3570 q2s.val[2] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[0], m1)), q2s.val[2]);
3571 q2s.val[3] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[1], m1)), q2s.val[3]);
3572
3573 const int32x4_t p1 = ggml_vdotq_s32(vzero, q2s.val[0], q8b.val[0]);
3574 const int32x4_t p2 = ggml_vdotq_s32(vzero, q2s.val[1], q8b.val[1]);
3575 const int32x4_t p3 = ggml_vdotq_s32(vzero, q2s.val[2], q8b.val[2]);
3576 const int32x4_t p4 = ggml_vdotq_s32(vzero, q2s.val[3], q8b.val[3]);
3577
3578 sumi1 += vaddvq_s32(p1) * (1 + 2*(x[i].scales[ib32+0] & 0xf));
3579 sumi2 += vaddvq_s32(p2) * (1 + 2*(x[i].scales[ib32+0] >> 4));
3580 sumi1 += vaddvq_s32(p3) * (1 + 2*(x[i].scales[ib32+1] & 0xf));
3581 sumi2 += vaddvq_s32(p4) * (1 + 2*(x[i].scales[ib32+1] >> 4));
3582 }
3583 sumf += d*(sumi1 + sumi2);
3584 }
3585
3586 *s = 0.125f * sumf;
3587
3588#else
3589 UNUSED(x);
3590 UNUSED(y);
3591 UNUSED(nb);
3592 ggml_vec_dot_iq2_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
3593#endif
3594
3595}
3596
3597void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
3598 assert(n % QK_K == 0);
3599 assert(nrc == 1);
3600 UNUSED(nrc);
3601 UNUSED(bx);
3602 UNUSED(by);
3603 UNUSED(bs);
3604
3605 const block_iq3_xxs * GGML_RESTRICT x = vx;
3606 const block_q8_K * GGML_RESTRICT y = vy;
3607
3608 const int nb = n / QK_K;
3609
3610#if defined(__ARM_NEON)
3611
3612 const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
3613
3614 uint32_t aux32[2];
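    // iq3_xxs: qs holds one byte-sized index per 4 values into the 256-entry
    // iq3xxs_grid of packed uint32s; the sign/scale words follow at qs + QK_K/4
    // ("gas") and are decoded exactly as in the iq2_xxs kernel above.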
3615
3616 ggml_int8x16x4_t q3s;
3617 ggml_int8x16x4_t q8b;
3618
3619 float sumf = 0;
3620 for (int i = 0; i < nb; ++i) {
3621 const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
3622 const uint8_t * GGML_RESTRICT q3 = x[i].qs;
3623 const uint8_t * GGML_RESTRICT gas = x[i].qs + QK_K/4;
3624 const int8_t * GGML_RESTRICT q8 = y[i].qs;
3625 float sumf1 = 0, sumf2 = 0;
3626 for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
3627 q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
3628 memcpy(aux32, gas, 2*sizeof(uint32_t)); gas += 2*sizeof(uint32_t);
3629 const uint32x4_t aux32x4_0 = ggml_vld1q_u32(iq3xxs_grid[q3[ 0]], iq3xxs_grid[q3[ 1]], iq3xxs_grid[q3[ 2]], iq3xxs_grid[q3[ 3]]);
3630 const uint32x4_t aux32x4_1 = ggml_vld1q_u32(iq3xxs_grid[q3[ 4]], iq3xxs_grid[q3[ 5]], iq3xxs_grid[q3[ 6]], iq3xxs_grid[q3[ 7]]);
3631 const uint32x4_t aux32x4_2 = ggml_vld1q_u32(iq3xxs_grid[q3[ 8]], iq3xxs_grid[q3[ 9]], iq3xxs_grid[q3[10]], iq3xxs_grid[q3[11]]);
3632 const uint32x4_t aux32x4_3 = ggml_vld1q_u32(iq3xxs_grid[q3[12]], iq3xxs_grid[q3[13]], iq3xxs_grid[q3[14]], iq3xxs_grid[q3[15]]);
3633 q3 += 16;
3634 q3s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[0] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[0] >> 7) & 127))));
3635 q3s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[0] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[0] >> 21) & 127))));
3636 q3s.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 7) & 127))));
3637 q3s.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 21) & 127))));
3638 q3s.val[0] = vmulq_s8(q3s.val[0], vreinterpretq_s8_u32(aux32x4_0));
3639 q3s.val[1] = vmulq_s8(q3s.val[1], vreinterpretq_s8_u32(aux32x4_1));
3640 q3s.val[2] = vmulq_s8(q3s.val[2], vreinterpretq_s8_u32(aux32x4_2));
3641 q3s.val[3] = vmulq_s8(q3s.val[3], vreinterpretq_s8_u32(aux32x4_3));
3642 const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[0], q8b.val[0]), q3s.val[1], q8b.val[1]);
3643 const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[2], q8b.val[2]), q3s.val[3], q8b.val[3]);
3644 sumf1 += vaddvq_s32(p1) * (0.5f + (aux32[0] >> 28));
3645 sumf2 += vaddvq_s32(p2) * (0.5f + (aux32[1] >> 28));
3646 }
3647 sumf += d*(sumf1 + sumf2);
3648 }
3649 *s = 0.5f * sumf;
3650
3651#else
3652 UNUSED(x);
3653 UNUSED(y);
3654 UNUSED(nb);
3655 ggml_vec_dot_iq3_xxs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
3656#endif
3657}
3658
3659void ggml_vec_dot_iq3_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
3660 assert(n % QK_K == 0);
3661 assert(nrc == 1);
3662 UNUSED(nrc);
3663 UNUSED(bx);
3664 UNUSED(by);
3665 UNUSED(bs);
3666
3667 const block_iq3_s * GGML_RESTRICT x = vx;
3668 const block_q8_K * GGML_RESTRICT y = vy;
3669
3670 const int nb = n / QK_K;
3671
3672#if defined(__ARM_NEON)
3673
3674 typedef union {
3675 uint16x8_t vec_index;
3676 uint16_t index[8];
3677 } vec_index_t;
3678
3679 static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
3680 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
3681 };
3682
3683 static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,};
3684
3685 static const int16_t k_shift[8] = {8, 7, 6, 5, 4, 3, 2, 1};
3686
3687 const ggml_uint8x16x2_t mask1 = ggml_vld1q_u8_x2(k_mask1);
3688 const uint8x16_t mask2 = vld1q_u8(k_mask2);
3689
3690 const int16x8_t hshift = vld1q_s16(k_shift);
3691 const uint16x8_t m256 = vdupq_n_u16(256);
3692 const uint8x16_t m1 = vdupq_n_u8(1);
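    // iq3_s grid indices are 9 bits: 8 from qs plus one high bit from qh; hshift
    // left-shifts the broadcast qh byte so its k-th bit lands at bit 8 for element k,
    // and m256 masks everything else off.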
3693
3694 uint8x16x2_t vs;
3695 ggml_int8x16x4_t q3s;
3696 ggml_int8x16x4_t q8b;
3697 vec_index_t idx;
3698
3699 uint32_t scales32[2];
3700 const uint8_t * scales8 = (const uint8_t *)scales32;
3701
3702 float sumf = 0;
3703 for (int i = 0; i < nb; ++i) {
3704 const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
3705 const uint8_t * GGML_RESTRICT qs = x[i].qs;
3706 const uint8_t * GGML_RESTRICT qh = x[i].qh;
3707 const uint16_t * GGML_RESTRICT signs = (const uint16_t *)x[i].signs;
3708 const int8_t * GGML_RESTRICT q8 = y[i].qs;
3709
3710 memcpy(scales32, x[i].scales, 4);
3711 scales32[1] = (((scales32[0] >> 4) & 0x0f0f0f0f) << 1) | 0x01010101;
3712 scales32[0] = ((scales32[0] & 0x0f0f0f0f) << 1) | 0x01010101;
3713
3714 int sumi1 = 0, sumi2 = 0;
3715 for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
3716 q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
3717
3718 const uint8x16_t idx_l = vld1q_u8(qs); qs += 16;
3719 idx.vec_index = vorrq_u16(vmovl_u8(vget_low_u8 (idx_l)), vandq_u16(vshlq_u16(vdupq_n_u16(qh[ib32+0]), hshift), m256));
3720 const uint32x4_t aux32x4_0 = ggml_vld1q_u32(iq3s_grid[idx.index[0]], iq3s_grid[idx.index[1]],
3721 iq3s_grid[idx.index[2]], iq3s_grid[idx.index[3]]);
3722 const uint32x4_t aux32x4_1 = ggml_vld1q_u32(iq3s_grid[idx.index[4]], iq3s_grid[idx.index[5]],
3723 iq3s_grid[idx.index[6]], iq3s_grid[idx.index[7]]);
3724 idx.vec_index = vorrq_u16(vmovl_u8(vget_high_u8(idx_l)), vandq_u16(vshlq_u16(vdupq_n_u16(qh[ib32+1]), hshift), m256));
3725 const uint32x4_t aux32x4_2 = ggml_vld1q_u32(iq3s_grid[idx.index[0]], iq3s_grid[idx.index[1]],
3726 iq3s_grid[idx.index[2]], iq3s_grid[idx.index[3]]);
3727 const uint32x4_t aux32x4_3 = ggml_vld1q_u32(iq3s_grid[idx.index[4]], iq3s_grid[idx.index[5]],
3728 iq3s_grid[idx.index[6]], iq3s_grid[idx.index[7]]);
3729
3730
3731 vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | ((uint32_t) signs[1] << 16)));
3732 vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
3733 vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
3734 vs.val[0] = vorrq_u8(vceqq_u8(vs.val[0], mask2), m1);
3735 vs.val[1] = vorrq_u8(vceqq_u8(vs.val[1], mask2), m1);
3736
3737 q3s.val[0] = vmulq_s8(vreinterpretq_s8_u8(vs.val[0]), vreinterpretq_s8_u32(aux32x4_0));
3738 q3s.val[1] = vmulq_s8(vreinterpretq_s8_u8(vs.val[1]), vreinterpretq_s8_u32(aux32x4_1));
3739
3740 vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[2] | ((uint32_t) signs[3] << 16)));
3741 vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
3742 vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
3743 vs.val[0] = vorrq_u8(vceqq_u8(vs.val[0], mask2), m1);
3744 vs.val[1] = vorrq_u8(vceqq_u8(vs.val[1], mask2), m1);
3745
3746 signs += 4;
3747
3748 q3s.val[2] = vmulq_s8(vreinterpretq_s8_u8(vs.val[0]), vreinterpretq_s8_u32(aux32x4_2));
3749 q3s.val[3] = vmulq_s8(vreinterpretq_s8_u8(vs.val[1]), vreinterpretq_s8_u32(aux32x4_3));
3750
3751 const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[0], q8b.val[0]), q3s.val[1], q8b.val[1]);
3752 const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[2], q8b.val[2]), q3s.val[3], q8b.val[3]);
3753
3754 sumi1 += vaddvq_s32(p1) * scales8[ib32/2+0];
3755 sumi2 += vaddvq_s32(p2) * scales8[ib32/2+4];
3756 }
3757 sumf += d*(sumi1 + sumi2);
3758 }
3759 *s = sumf;
3760
3761#else
3762 UNUSED(x);
3763 UNUSED(y);
3764 UNUSED(nb);
3765 ggml_vec_dot_iq3_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
3766#endif
3767}
3768
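// [Added exposition, not upstream code] iq3_s packs eight 4-bit block scales
// into x[i].scales; the two shift/or lines above expand them in-register to
// eight bytes holding (2*s + 1), read back through scales8. Equivalent scalar
// form (illustrative helper; assumes little-endian, as the SIMD path does):
static inline void iq3_s_unpack_scales_sketch(const uint8_t scales[4], uint8_t out[8]) {
    for (int k = 0; k < 4; ++k) {
        out[k + 0] = 2*(scales[k] & 0xf) + 1;  // low nibbles -> scales 0..3
        out[k + 4] = 2*(scales[k] >>  4) + 1;  // high nibbles -> scales 4..7
    }
}
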
void ggml_vec_dot_iq1_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(n % QK_K == 0);
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_iq1_s * GGML_RESTRICT x = vx;
    const block_q8_K  * GGML_RESTRICT y = vy;

    const int nb = n / QK_K;

#if defined __ARM_NEON

    ggml_int8x16x4_t q1b;
    ggml_int8x16x4_t q8b;

    float sumf = 0;
    for (int i = 0; i < nb; ++i) {

        const int8_t   * q8 = y[i].qs;
        const uint8_t  * qs = x[i].qs;
        const uint16_t * qh = x[i].qh;

        int sumi1 = 0, sumi2 = 0, sumi3 = 0;

        for (int ib = 0; ib < QK_K/32; ib += 2) {

            q1b.val[0] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[0] | ((qh[ib+0] << 8) & 0x700)))),
                                     vld1_s8((const int8_t *)(iq1s_grid + (qs[1] | ((qh[ib+0] << 5) & 0x700)))));
            q1b.val[1] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[2] | ((qh[ib+0] << 2) & 0x700)))),
                                     vld1_s8((const int8_t *)(iq1s_grid + (qs[3] | ((qh[ib+0] >> 1) & 0x700)))));
            q1b.val[2] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[4] | ((qh[ib+1] << 8) & 0x700)))),
                                     vld1_s8((const int8_t *)(iq1s_grid + (qs[5] | ((qh[ib+1] << 5) & 0x700)))));
            q1b.val[3] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[6] | ((qh[ib+1] << 2) & 0x700)))),
                                     vld1_s8((const int8_t *)(iq1s_grid + (qs[7] | ((qh[ib+1] >> 1) & 0x700)))));
            qs += 8;

            q8b = ggml_vld1q_s8_x4(q8); q8 += 64;

            const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q1b.val[0], q8b.val[0]), q1b.val[1], q8b.val[1]);
            const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q1b.val[2], q8b.val[2]), q1b.val[3], q8b.val[3]);

            const int ls1 = 2*((qh[ib+0] >> 12) & 7) + 1;
            const int ls2 = 2*((qh[ib+1] >> 12) & 7) + 1;
            sumi1 += vaddvq_s32(p1) * ls1;
            sumi2 += vaddvq_s32(p2) * ls2;
            sumi3 += (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]) * ls1 * (qh[ib+0] & 0x8000 ? -1 : 1)
                   + (y[i].bsums[2*ib+2] + y[i].bsums[2*ib+3]) * ls2 * (qh[ib+1] & 0x8000 ? -1 : 1);

        }

        sumf += y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d) * (sumi1 + sumi2 + IQ1S_DELTA * sumi3);
    }

    *s = sumf;

#else
    UNUSED(x);
    UNUSED(y);
    UNUSED(nb);
    ggml_vec_dot_iq1_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}

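// [Added exposition, not upstream code] In iq1_s each byte qs[j] gives the low
// 8 bits and qh the next 3 bits of an 11-bit index into iq1s_grid (2048
// uint64_t entries, each packing eight int8 weights); bits 12..14 of qh hold
// the block scale (2*s + 1) and bit 15 the sign of the uniform IQ1S_DELTA
// shift folded in through bsums. Scalar sketch of the index reconstruction
// (helper name illustrative; j in 0..3 mirrors the <<8, <<5, <<2, >>1 pattern):
static inline uint16_t iq1_s_grid_index_sketch(uint8_t qs_byte, uint16_t qh_word, int j) {
    const int shift = 8 - 3*j;  // 8, 5, 2, -1: places qh bits 3j..3j+2 at bits 8..10
    const uint32_t hi = shift >= 0 ? ((uint32_t)qh_word << shift) : ((uint32_t)qh_word >> -shift);
    return (uint16_t)(qs_byte | (hi & 0x700));
}
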
void ggml_vec_dot_iq1_m_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(n % QK_K == 0);
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_iq1_m * GGML_RESTRICT x = vx;
    const block_q8_K  * GGML_RESTRICT y = vy;

    const int nb = n / QK_K;

    iq1m_scale_t scale;

#if defined __ARM_NEON
    const int32x4_t mask  = vdupq_n_s32(0x7);
    const int32x4_t mone  = vdupq_n_s32(1);
    const int32x4_t mzero = vdupq_n_s32(0);

    ggml_int8x16x4_t deltas;
    deltas.val[0] = vcombine_s8(vdup_n_s8(+1), vdup_n_s8(+1));
    deltas.val[1] = vcombine_s8(vdup_n_s8(-1), vdup_n_s8(+1));
    deltas.val[2] = vcombine_s8(vdup_n_s8(+1), vdup_n_s8(-1));
    deltas.val[3] = vcombine_s8(vdup_n_s8(-1), vdup_n_s8(-1));

    ggml_int8x16x4_t q1b;
    ggml_int8x16x4_t q8b;

    uint32_t aux32;
    const uint8_t * aux8 = (const uint8_t *)&aux32;

    float sumf = 0;
    for (int i = 0; i < nb; ++i) {

        const int8_t   * q8 = y[i].qs;
        const uint8_t  * qs = x[i].qs;
        const uint8_t  * qh = x[i].qh;
        const uint16_t * sc = (const uint16_t *)x[i].scales;

        scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);

        int32x4_t sumi1 = mzero;
        int32x4_t sumi2 = mzero;

        for (int ib = 0; ib < QK_K/32; ib += 2) {

            q1b.val[0] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[0] | ((qh[0] << 8) & 0x700)))),
                                     vld1_s8((const int8_t *)(iq1s_grid + (qs[1] | ((qh[0] << 4) & 0x700)))));
            q1b.val[1] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[2] | ((qh[1] << 8) & 0x700)))),
                                     vld1_s8((const int8_t *)(iq1s_grid + (qs[3] | ((qh[1] << 4) & 0x700)))));
            q1b.val[2] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[4] | ((qh[2] << 8) & 0x700)))),
                                     vld1_s8((const int8_t *)(iq1s_grid + (qs[5] | ((qh[2] << 4) & 0x700)))));
            q1b.val[3] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[6] | ((qh[3] << 8) & 0x700)))),
                                     vld1_s8((const int8_t *)(iq1s_grid + (qs[7] | ((qh[3] << 4) & 0x700)))));

            q8b = ggml_vld1q_s8_x4(q8); q8 += 64;

            const int32x4_t p1 = vpaddq_s32(ggml_vdotq_s32(mzero, q1b.val[0], q8b.val[0]), ggml_vdotq_s32(mzero, q1b.val[1], q8b.val[1]));
            const int32x4_t p2 = vpaddq_s32(ggml_vdotq_s32(mzero, q1b.val[2], q8b.val[2]), ggml_vdotq_s32(mzero, q1b.val[3], q8b.val[3]));
            const int32x4_t p12 = vpaddq_s32(p1, p2);

            const uint32_t * qh32 = (const uint32_t *)qh; // qh is 4-byte aligned, so it can be read as a uint32_t
            aux32 = ((qh32[0] >> 3) & 0x01010101) | ((qh32[0] >> 6) & 0x02020202);

            const int32x4_t p3 = vpaddq_s32(ggml_vdotq_s32(mzero, deltas.val[aux8[0]], q8b.val[0]), ggml_vdotq_s32(mzero, deltas.val[aux8[1]], q8b.val[1]));
            const int32x4_t p4 = vpaddq_s32(ggml_vdotq_s32(mzero, deltas.val[aux8[2]], q8b.val[2]), ggml_vdotq_s32(mzero, deltas.val[aux8[3]], q8b.val[3]));
            const int32x4_t p34 = vpaddq_s32(p3, p4);

            int32x4_t scales_4 = ggml_vld1q_u32(sc[ib/2] >> 0, sc[ib/2] >> 3, sc[ib/2] >> 6, sc[ib/2] >> 9);

            scales_4 = vaddq_s32(vshlq_n_s32(vandq_s32(scales_4, mask), 1), mone);

            sumi1 = vmlaq_s32(sumi1, scales_4, p12);
            sumi2 = vmlaq_s32(sumi2, scales_4, p34);

            qs += 8; qh += 4;

        }

        sumf += y[i].d * GGML_CPU_FP16_TO_FP32(scale.f16) * (vaddvq_s32(sumi1) + IQ1M_DELTA * vaddvq_s32(sumi2));
    }

    *s = sumf;

#else
    UNUSED(x);
    UNUSED(y);
    UNUSED(nb);
    UNUSED(scale);
    ggml_vec_dot_iq1_m_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}

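// [Added exposition, not upstream code] For iq1_m, bits 3 and 7 of each qh
// byte choose the sign of the IQ1M_DELTA correction for each group of 8
// weights; the NEON path materializes the four +/-1 patterns in
// deltas.val[0..3] and picks one per 16 q8 values through the 2-bit
// selectors packed into aux32. Scalar view of the selection (illustrative
// helper):
static inline void iq1_m_delta_signs_sketch(uint8_t qh_byte, int8_t * lo, int8_t * hi) {
    *lo = (qh_byte & 0x08) ? -1 : +1;  // bit 3 -> sign for the first 8 weights
    *hi = (qh_byte & 0x80) ? -1 : +1;  // bit 7 -> sign for the second 8 weights
}
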
void ggml_vec_dot_iq4_nl_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);
    assert(n % QK4_NL == 0);
    static_assert(QK4_NL == QK8_0, "QK4_NL and QK8_0 must be the same");

    const block_iq4_nl * GGML_RESTRICT x = vx;
    const block_q8_0   * GGML_RESTRICT y = vy;

    const int nb = n / QK4_NL;

    int ib = 0;
    float sumf = 0;

#if defined __ARM_NEON
    const int8x16_t values = vld1q_s8(kvalues_iq4nl);
    const uint8x16_t m4b = vdupq_n_u8(0x0f);
    uint8x16x2_t q4bits;
    int8x16x4_t q4b;
    int8x16x4_t q8b;
    int32x4_t prod_1, prod_2;

    for (; ib + 1 < nb; ib += 2) {

        q4bits.val[0] = vld1q_u8(x[ib + 0].qs);
        q4bits.val[1] = vld1q_u8(x[ib + 1].qs);
        q8b.val[0]    = vld1q_s8(y[ib + 0].qs);
        q8b.val[1]    = vld1q_s8(y[ib + 0].qs + 16);
        q8b.val[2]    = vld1q_s8(y[ib + 1].qs);
        q8b.val[3]    = vld1q_s8(y[ib + 1].qs + 16);

        q4b.val[0] = ggml_vqtbl1q_s8(values, vandq_u8  (q4bits.val[0], m4b));
        q4b.val[1] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[0], 4));
        q4b.val[2] = ggml_vqtbl1q_s8(values, vandq_u8  (q4bits.val[1], m4b));
        q4b.val[3] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[1], 4));

        prod_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[0], q8b.val[0]), q4b.val[1], q8b.val[1]);
        prod_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[2], q8b.val[2]), q4b.val[3], q8b.val[3]);

        sumf +=
            GGML_CPU_FP16_TO_FP32(x[ib+0].d) * GGML_CPU_FP16_TO_FP32(y[ib + 0].d) * vaddvq_s32(prod_1) +
            GGML_CPU_FP16_TO_FP32(x[ib+1].d) * GGML_CPU_FP16_TO_FP32(y[ib + 1].d) * vaddvq_s32(prod_2);
    }

#endif
    for (; ib < nb; ++ib) {
        const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_CPU_FP16_TO_FP32(x[ib].d);
        int sumi1 = 0, sumi2 = 0;
        for (int j = 0; j < QK4_NL/2; ++j) {
            sumi1 += y[ib].qs[j+       0] * kvalues_iq4nl[x[ib].qs[j] & 0xf];
            sumi2 += y[ib].qs[j+QK4_NL/2] * kvalues_iq4nl[x[ib].qs[j] >>  4];
        }
        sumf += d * (sumi1 + sumi2);
    }
    *s = sumf;
}

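// [Added exposition, not upstream code] ggml_vqtbl1q_s8(values, idx) acts as a
// 16-entry per-lane lookup table: each 4-bit index selects one value of the
// non-uniform iq4_nl codebook, exactly what the scalar tail loop spells out
// as kvalues_iq4nl[x[ib].qs[j] & 0xf]. Scalar model of the vector lookup
// (illustrative helper; indices are pre-masked to 0..15 as in the kernel):
static inline void iq4_nl_lookup_sketch(const int8_t values[16], const uint8_t idx[16], int8_t out[16]) {
    for (int j = 0; j < 16; ++j) {
        out[j] = values[idx[j] & 0xf];  // one table lookup per lane
    }
}
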
void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);
    assert(n % QK_K == 0);

    const block_iq4_xs * GGML_RESTRICT x = vx;
    const block_q8_K   * GGML_RESTRICT y = vy;

    const int nb = n / QK_K;

#if defined __ARM_NEON
    const int8x16_t values = vld1q_s8(kvalues_iq4nl);
    const uint8x16_t m4b = vdupq_n_u8(0x0f);
    ggml_uint8x16x2_t q4bits;
    ggml_int8x16x4_t q4b;
    ggml_int8x16x4_t q8b;
    int32x4_t prod_1, prod_2;

    float sumf = 0;

    for (int ibl = 0; ibl < nb; ++ibl) {

        const int8_t  * q8 = y[ibl].qs;
        const uint8_t * q4 = x[ibl].qs;
        uint16_t h = x[ibl].scales_h;

        int sumi1 = 0, sumi2 = 0;
        for (int ib = 0; ib < QK_K/64; ++ib) {

            q4bits = ggml_vld1q_u8_x2(q4); q4 += 32;
            q8b    = ggml_vld1q_s8_x4(q8); q8 += 64;

            q4b.val[0] = ggml_vqtbl1q_s8(values, vandq_u8  (q4bits.val[0], m4b));
            q4b.val[1] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[0], 4));
            q4b.val[2] = ggml_vqtbl1q_s8(values, vandq_u8  (q4bits.val[1], m4b));
            q4b.val[3] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[1], 4));

            prod_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[0], q8b.val[0]), q4b.val[1], q8b.val[1]);
            prod_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[2], q8b.val[2]), q4b.val[3], q8b.val[3]);

            int ls1 = ((x[ibl].scales_l[ib] & 0xf) | ((h << 4) & 0x30)) - 32;
            int ls2 = ((x[ibl].scales_l[ib] >>  4) | ((h << 2) & 0x30)) - 32;
            h >>= 4;
            sumi1 += vaddvq_s32(prod_1) * ls1;
            sumi2 += vaddvq_s32(prod_2) * ls2;

        }

        sumf += GGML_CPU_FP16_TO_FP32(x[ibl].d) * y[ibl].d * (sumi1 + sumi2);
    }

    *s = sumf;

#else
    UNUSED(x);
    UNUSED(y);
    UNUSED(nb);
    ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}

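// [Added exposition, not upstream code] iq4_xs carries one 6-bit scale per
// 32-weight sub-block: a 4-bit low part in the scales_l nibbles and a 2-bit
// high part consumed from scales_h two bits at a time (h >>= 4 per 64-weight
// step); subtracting 32 recenters the scale to [-32, 31]. Scalar form of the
// ls1/ls2 decode above (illustrative helper; which = 0 or 1 selects the
// first or second 32-weight half):
static inline int iq4_xs_scale_sketch(uint8_t scales_l, uint16_t h, int which) {
    const int lo = which ? (scales_l >> 4) : (scales_l & 0xf);
    const int hi = which ? ((h >> 2) & 3)  : (h & 3);
    return ((hi << 4) | lo) - 32;
}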