#define GGML_COMMON_IMPL_CPP
#define GGML_COMMON_DECL_CPP
#include "ggml-common.h"
#include "ggml-backend-impl.h"

#include "ggml-impl.h"
#include "ggml-cpu.h"
#include "ggml-cpu-impl.h"
#include "simd-mappings.h"
#include "traits.h"

#include <cmath>
#include <cstring>
#include <cassert>
#include <cstdlib> // for qsort
#include <cstdio> // for GGML_ASSERT

#define GGML_CPU_CLANG_WORKAROUND
#include "../../repack.h"

#if defined(__GNUC__)
#pragma GCC diagnostic ignored "-Woverlength-strings"
#endif

#define UNUSED GGML_UNUSED

#if defined(__aarch64__) && defined(__ARM_NEON) && (defined(__ARM_FEATURE_MATMUL_INT8) || defined(__ARM_FEATURE_DOTPROD))
// Helper for decoding scales and mins of Q4_K and Q5_K block formats
static inline void decode_q_Kx8_6bit_scales(const uint8_t * scales_in, int16x8_t * out_mins, int8_t * out_scales) {
    constexpr uint32_t kmask1 = 0x3f3f3f3f;
    constexpr uint32_t kmask2 = 0x0f0f0f0f;
    constexpr uint32_t kmask3 = 0x03030303;
    constexpr uint8_t scales_size = 12;

    uint32_t sm[3];
    memcpy(sm, scales_in, scales_size);

    const uint32_t mins_0_3 = sm[1] & kmask1;
    const uint32_t mins_4_7 = ((sm[2] >> 4) & kmask2) | (((sm[1] >> 6) & kmask3) << 4);
    const uint32x2_t mins_u32 = { mins_0_3, mins_4_7 };

    *out_mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins_u32)));

    uint32_t scales_u32[2];
    scales_u32[0] = sm[0] & kmask1;
    scales_u32[1] = (sm[2] & kmask2) | (((sm[0] >> 6) & kmask3) << 4);
    memcpy(out_scales, scales_u32, 8);
}
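// For reference, this is the vectorized form of the scalar K-quant unpacking
// (cf. get_scale_min_k4() in the generic dequantization code): for sub-block
// j < 4 the 6-bit scale and min live in the low bits of bytes j and j + 4,
// while for j >= 4 the low 4 bits come from byte j + 4 and the high 2 bits
// are the top bits of bytes j - 4 (scales) and j (mins). The three masks
// above select exactly those bit groups, four lanes at a time.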
#endif

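// Quantizes 4 rows of k floats at once into interleaved q8_0 blocks: each row
// gets its own fp16 scale d = amax / 127, and the quantized values are stored
// in 4-byte groups cycling through the 4 rows, so a blocklen == 4 kernel can
// load one 16-byte vector holding the same 4-element slice of all 4 rows.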
void ggml_quantize_mat_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
    assert(QK8_0 == 32);
    assert(k % QK8_0 == 0);
    const int nb = k / QK8_0;

    block_q8_0x4 * GGML_RESTRICT y = (block_q8_0x4 *) vy;

#if defined(__ARM_NEON)
    float32x4_t srcv[4][8];
    float id[4];

    for (int i = 0; i < nb; i++) {
        float32x4_t asrcv[8];
        float32x4_t amaxv[8];

        for (int row_iter = 0; row_iter < 4; row_iter++) {
            for (int j = 0; j < 8; j++) srcv[row_iter][j] = vld1q_f32(x + row_iter * k + i * 32 + 4 * j);
            for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[row_iter][j]);

            for (int j = 0; j < 4; j++) amaxv[2 * j] = vmaxq_f32(asrcv[2 * j], asrcv[2 * j + 1]);
            for (int j = 0; j < 2; j++) amaxv[4 * j] = vmaxq_f32(amaxv[4 * j], amaxv[4 * j + 2]);
            for (int j = 0; j < 1; j++) amaxv[8 * j] = vmaxq_f32(amaxv[8 * j], amaxv[8 * j + 4]);

            const float amax = vmaxvq_f32(amaxv[0]);

            const float d = amax / ((1 << 7) - 1);
            id[row_iter] = d ? 1.0f / d : 0.0f;

            y[i].d[row_iter] = GGML_CPU_FP32_TO_FP16(d);
        }

        for (int j = 0; j < 8; j++) {
            float32x4_t v = vmulq_n_f32(srcv[0][j], id[0]);
            int32x4_t vi = vcvtnq_s32_f32(v);
            y[i].qs[16 * j + 0] = vgetq_lane_s32(vi, 0);
            y[i].qs[16 * j + 1] = vgetq_lane_s32(vi, 1);
            y[i].qs[16 * j + 2] = vgetq_lane_s32(vi, 2);
            y[i].qs[16 * j + 3] = vgetq_lane_s32(vi, 3);

            v = vmulq_n_f32(srcv[1][j], id[1]);
            vi = vcvtnq_s32_f32(v);
            y[i].qs[16 * j + 4] = vgetq_lane_s32(vi, 0);
            y[i].qs[16 * j + 5] = vgetq_lane_s32(vi, 1);
            y[i].qs[16 * j + 6] = vgetq_lane_s32(vi, 2);
            y[i].qs[16 * j + 7] = vgetq_lane_s32(vi, 3);

            v = vmulq_n_f32(srcv[2][j], id[2]);
            vi = vcvtnq_s32_f32(v);
            y[i].qs[16 * j + 8] = vgetq_lane_s32(vi, 0);
            y[i].qs[16 * j + 9] = vgetq_lane_s32(vi, 1);
            y[i].qs[16 * j + 10] = vgetq_lane_s32(vi, 2);
            y[i].qs[16 * j + 11] = vgetq_lane_s32(vi, 3);

            v = vmulq_n_f32(srcv[3][j], id[3]);
            vi = vcvtnq_s32_f32(v);
            y[i].qs[16 * j + 12] = vgetq_lane_s32(vi, 0);
            y[i].qs[16 * j + 13] = vgetq_lane_s32(vi, 1);
            y[i].qs[16 * j + 14] = vgetq_lane_s32(vi, 2);
            y[i].qs[16 * j + 15] = vgetq_lane_s32(vi, 3);
        }
    }
#else
    UNUSED(nb);
    UNUSED(y);
    ggml_quantize_mat_q8_0_4x4_generic(x, vy, k);
#endif
}

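// Same as above, but with an 8-byte interleave: values are stored in 8-byte
// groups per row, matching the blocklen == 8 kernels below.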
void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
    assert(QK8_0 == 32);
    assert(k % QK8_0 == 0);
    const int nb = k / QK8_0;

    block_q8_0x4 * GGML_RESTRICT y = (block_q8_0x4 *) vy;

#if defined(__ARM_NEON)
    float32x4_t srcv[4][8];
    float id[4];

    for (int i = 0; i < nb; i++) {
        float32x4_t asrcv[8];
        float32x4_t amaxv[8];

        for (int row_iter = 0; row_iter < 4; row_iter++) {
            for (int j = 0; j < 8; j++) srcv[row_iter][j] = vld1q_f32(x + row_iter * k + i * 32 + 4 * j);
            for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[row_iter][j]);

            for (int j = 0; j < 4; j++) amaxv[2 * j] = vmaxq_f32(asrcv[2 * j], asrcv[2 * j + 1]);
            for (int j = 0; j < 2; j++) amaxv[4 * j] = vmaxq_f32(amaxv[4 * j], amaxv[4 * j + 2]);
            for (int j = 0; j < 1; j++) amaxv[8 * j] = vmaxq_f32(amaxv[8 * j], amaxv[8 * j + 4]);

            const float amax = vmaxvq_f32(amaxv[0]);

            const float d = amax / ((1 << 7) - 1);
            id[row_iter] = d ? 1.0f / d : 0.0f;

            y[i].d[row_iter] = GGML_CPU_FP32_TO_FP16(d);
        }

        for (int j = 0; j < 4; j++) {
            float32x4_t v = vmulq_n_f32(srcv[0][2 * j], id[0]);
            int32x4_t vi = vcvtnq_s32_f32(v);
            y[i].qs[32 * j + 0] = vgetq_lane_s32(vi, 0);
            y[i].qs[32 * j + 1] = vgetq_lane_s32(vi, 1);
            y[i].qs[32 * j + 2] = vgetq_lane_s32(vi, 2);
            y[i].qs[32 * j + 3] = vgetq_lane_s32(vi, 3);
            v = vmulq_n_f32(srcv[0][2 * j + 1], id[0]);
            vi = vcvtnq_s32_f32(v);
            y[i].qs[32 * j + 4] = vgetq_lane_s32(vi, 0);
            y[i].qs[32 * j + 5] = vgetq_lane_s32(vi, 1);
            y[i].qs[32 * j + 6] = vgetq_lane_s32(vi, 2);
            y[i].qs[32 * j + 7] = vgetq_lane_s32(vi, 3);

            v = vmulq_n_f32(srcv[1][2 * j], id[1]);
            vi = vcvtnq_s32_f32(v);
            y[i].qs[32 * j + 8] = vgetq_lane_s32(vi, 0);
            y[i].qs[32 * j + 9] = vgetq_lane_s32(vi, 1);
            y[i].qs[32 * j + 10] = vgetq_lane_s32(vi, 2);
            y[i].qs[32 * j + 11] = vgetq_lane_s32(vi, 3);
            v = vmulq_n_f32(srcv[1][2 * j + 1], id[1]);
            vi = vcvtnq_s32_f32(v);
            y[i].qs[32 * j + 12] = vgetq_lane_s32(vi, 0);
            y[i].qs[32 * j + 13] = vgetq_lane_s32(vi, 1);
            y[i].qs[32 * j + 14] = vgetq_lane_s32(vi, 2);
            y[i].qs[32 * j + 15] = vgetq_lane_s32(vi, 3);

            v = vmulq_n_f32(srcv[2][2 * j], id[2]);
            vi = vcvtnq_s32_f32(v);
            y[i].qs[32 * j + 16] = vgetq_lane_s32(vi, 0);
            y[i].qs[32 * j + 17] = vgetq_lane_s32(vi, 1);
            y[i].qs[32 * j + 18] = vgetq_lane_s32(vi, 2);
            y[i].qs[32 * j + 19] = vgetq_lane_s32(vi, 3);
            v = vmulq_n_f32(srcv[2][2 * j + 1], id[2]);
            vi = vcvtnq_s32_f32(v);
            y[i].qs[32 * j + 20] = vgetq_lane_s32(vi, 0);
            y[i].qs[32 * j + 21] = vgetq_lane_s32(vi, 1);
            y[i].qs[32 * j + 22] = vgetq_lane_s32(vi, 2);
            y[i].qs[32 * j + 23] = vgetq_lane_s32(vi, 3);

            v = vmulq_n_f32(srcv[3][2 * j], id[3]);
            vi = vcvtnq_s32_f32(v);
            y[i].qs[32 * j + 24] = vgetq_lane_s32(vi, 0);
            y[i].qs[32 * j + 25] = vgetq_lane_s32(vi, 1);
            y[i].qs[32 * j + 26] = vgetq_lane_s32(vi, 2);
            y[i].qs[32 * j + 27] = vgetq_lane_s32(vi, 3);
            v = vmulq_n_f32(srcv[3][2 * j + 1], id[3]);
            vi = vcvtnq_s32_f32(v);
            y[i].qs[32 * j + 28] = vgetq_lane_s32(vi, 0);
            y[i].qs[32 * j + 29] = vgetq_lane_s32(vi, 1);
            y[i].qs[32 * j + 30] = vgetq_lane_s32(vi, 2);
            y[i].qs[32 * j + 31] = vgetq_lane_s32(vi, 3);
        }
    }

#else
    UNUSED(nb);
    UNUSED(y);
    ggml_quantize_mat_q8_0_4x8_generic(x, vy, k);
#endif
}

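// GEMV over q4_0 blocks repacked 4 columns wide with a 4-byte interleave.
// The low nibbles are shifted into the high bits (b << 4) and the high
// nibbles are masked in place (b & 0xf0), so every product carries an extra
// factor of 16 that vcvtq_n_f32_s32(ret, 4) divides back out. Note this
// assumes the repacking step already folded q4_0's +8 offset into the stored
// nibbles (two's complement), so no explicit subtraction is needed here.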
void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
    const int qk = QK8_0;
    const int nb = n / qk;
    const int ncols_interleaved = 4;
    const int blocklen = 4;

    assert (n % qk == 0);
    assert (nc % ncols_interleaved == 0);

    UNUSED(s);
    UNUSED(bs);
    UNUSED(vx);
    UNUSED(vy);
    UNUSED(nr);
    UNUSED(nc);
    UNUSED(nb);
    UNUSED(ncols_interleaved);
    UNUSED(blocklen);

#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
    const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx;

    for (int c = 0; c < nc; c += ncols_interleaved) {
        const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
        float32x4_t acc = vdupq_n_f32(0);
        for (int b = 0; b < nb; b++) {
            int8x16_t b0 = vld1q_s8((const int8_t *) b_ptr->qs);
            int8x16_t b1 = vld1q_s8((const int8_t *) b_ptr->qs + 16);
            int8x16_t b2 = vld1q_s8((const int8_t *) b_ptr->qs + 32);
            int8x16_t b3 = vld1q_s8((const int8_t *) b_ptr->qs + 48);
            float16x4_t bd = vld1_f16((const __fp16 *) b_ptr->d);

            int8x16_t a0 = vld1q_s8(a_ptr->qs);
            int8x16_t a1 = vld1q_s8(a_ptr->qs + qk/2);
            float16x4_t ad = vld1_dup_f16((const __fp16 *) &a_ptr->d);

            int32x4_t ret = vdupq_n_s32(0);

            ret = vdotq_laneq_s32(ret, b0 << 4, a0, 0);
            ret = vdotq_laneq_s32(ret, b1 << 4, a0, 1);
            ret = vdotq_laneq_s32(ret, b2 << 4, a0, 2);
            ret = vdotq_laneq_s32(ret, b3 << 4, a0, 3);

            ret = vdotq_laneq_s32(ret, b0 & 0xf0U, a1, 0);
            ret = vdotq_laneq_s32(ret, b1 & 0xf0U, a1, 1);
            ret = vdotq_laneq_s32(ret, b2 & 0xf0U, a1, 2);
            ret = vdotq_laneq_s32(ret, b3 & 0xf0U, a1, 3);

            acc = vfmaq_f32(acc, vcvtq_n_f32_s32(ret, 4),
                            vmulq_f32(vcvt_f32_f16(ad), vcvt_f32_f16(bd)));
            a_ptr++;
            b_ptr++;
        }
        vst1q_f32(s, acc);
        s += ncols_interleaved;
    }
    return;
#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
    ggml_gemv_q4_0_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
}

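// blocklen == 8 variant: each 8-byte group of activations is broadcast to
// both vector halves (vld1q_dup_s64), so a single dot product serves two of
// the four interleaved columns; vpaddq_s32 then folds the two partial
// accumulators down to one lane per column.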
void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
    const int qk = QK8_0;
    const int nb = n / qk;
    const int ncols_interleaved = 4;
    const int blocklen = 8;

    assert (n % qk == 0);
    assert (nc % ncols_interleaved == 0);

    UNUSED(s);
    UNUSED(bs);
    UNUSED(vx);
    UNUSED(vy);
    UNUSED(nr);
    UNUSED(nc);
    UNUSED(nb);
    UNUSED(ncols_interleaved);
    UNUSED(blocklen);

#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
    const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx;

    for (int c = 0; c < nc; c += ncols_interleaved) {
        const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
        float32x4_t acc = vdupq_n_f32(0);
        for (int b = 0; b < nb; b++) {
            int8x16_t b0 = vld1q_s8((const int8_t *) b_ptr->qs);
            int8x16_t b1 = vld1q_s8((const int8_t *) b_ptr->qs + 16);
            int8x16_t b2 = vld1q_s8((const int8_t *) b_ptr->qs + 32);
            int8x16_t b3 = vld1q_s8((const int8_t *) b_ptr->qs + 48);
            float16x4_t bd = vld1_f16((const __fp16 *) b_ptr->d);

            int8x16_t a0 = (int8x16_t) vld1q_dup_s64((const int64_t *) a_ptr->qs);
            int8x16_t a1 = (int8x16_t) vld1q_dup_s64((const int64_t *) a_ptr->qs + 1);
            int8x16_t a2 = (int8x16_t) vld1q_dup_s64((const int64_t *) a_ptr->qs + 2);
            int8x16_t a3 = (int8x16_t) vld1q_dup_s64((const int64_t *) a_ptr->qs + 3);
            float16x4_t ad = vld1_dup_f16((const __fp16 *) &a_ptr->d);

            int32x4_t ret0 = vdupq_n_s32(0);
            int32x4_t ret1 = vdupq_n_s32(0);

            ret0 = vdotq_s32(ret0, b0 << 4, a0);
            ret1 = vdotq_s32(ret1, b1 << 4, a0);
            ret0 = vdotq_s32(ret0, b2 << 4, a1);
            ret1 = vdotq_s32(ret1, b3 << 4, a1);

            ret0 = vdotq_s32(ret0, b0 & 0xf0U, a2);
            ret1 = vdotq_s32(ret1, b1 & 0xf0U, a2);
            ret0 = vdotq_s32(ret0, b2 & 0xf0U, a3);
            ret1 = vdotq_s32(ret1, b3 & 0xf0U, a3);

            int32x4_t ret = vpaddq_s32(ret0, ret1);

            acc = vfmaq_f32(acc, vcvtq_n_f32_s32(ret, 4),
                            vmulq_f32(vcvt_f32_f16(ad), vcvt_f32_f16(bd)));
            a_ptr++;
            b_ptr++;
        }
        vst1q_f32(s, acc);
        s += ncols_interleaved;
    }
    return;
#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
    ggml_gemv_q4_0_4x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
}

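// 8-column q4_0 GEMV. When the SVE vector length matches QK8_0 (256-bit SVE),
// a hand-written assembly loop handles a whole 8-column group per iteration;
// all other configurations fall through to the generic implementation.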
void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
    const int qk = QK8_0;
    const int nb = n / qk;
    const int ncols_interleaved = 8;
    const int blocklen = 8;

    assert (n % qk == 0);
    assert (nc % ncols_interleaved == 0);

    UNUSED(s);
    UNUSED(bs);
    UNUSED(vx);
    UNUSED(vy);
    UNUSED(nr);
    UNUSED(nc);
    UNUSED(nb);
    UNUSED(ncols_interleaved);
    UNUSED(blocklen);

#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__)
#if defined(__ARM_FEATURE_SVE)
    if (ggml_cpu_get_sve_cnt() == QK8_0) {
        const void * b_ptr = vx;
        const void * a_ptr = vy;
        float * res_ptr = s;

        __asm__ __volatile__(
            "ptrue p0.b\n"
            "add %x[b_ptr], %x[b_ptr], #0x10\n"
            "1:" // Column loop
            "add x22, %x[a_ptr], #0x2\n"
            "mov z31.b, #0x0\n"
            "mov x21, %x[nb]\n"
            "2:" // Block loop
            "ld1b { z30.b }, p0/Z, [%x[b_ptr]]\n"
            "ld1b { z29.b }, p0/Z, [%x[b_ptr], #1, MUL VL]\n"
            "mov z28.s, #0x0\n"
            "mov z27.s, #0x0\n"
            "ld1rd { z26.d }, p0/Z, [x22]\n"
            "ld1b { z25.b }, p0/Z, [%x[b_ptr], #2, MUL VL]\n"
            "sub x20, x22, #0x2\n"
            "sub x21, x21, #0x1\n"
            "ld1b { z24.b }, p0/Z, [%x[b_ptr], #3, MUL VL]\n"
            "ld1rd { z23.d }, p0/Z, [x22, #8]\n"
            "lsl z22.b, z30.b, #0x4\n"
            "lsl z16.b, z29.b, #0x4\n"
            "and z30.b, z30.b, #0xf0\n"
            "and z29.b, z29.b, #0xf0\n"
            "ld1rd { z21.d }, p0/Z, [x22, #16]\n"
            "ld1rd { z20.d }, p0/Z, [x22, #24]\n"
            "lsl z19.b, z25.b, #0x4\n"
            "and z25.b, z25.b, #0xf0\n"
            "ld1rh { z17.h }, p0/Z, [x20]\n"
            "ld1h { z18.s }, p0/Z, [%x[b_ptr], #-1, MUL VL]\n"
            "sdot z28.s, z22.b, z26.b\n"
            "sdot z27.s, z16.b, z26.b\n"
            "lsl z16.b, z24.b, #0x4\n"
            "add x22, x22, #0x22\n"
            "and z24.b, z24.b, #0xf0\n"
            "add %x[b_ptr], %x[b_ptr], #0x90\n"
            "fcvt z17.s, p0/m, z17.h\n"
            "fcvt z18.s, p0/m, z18.h\n"
            "sdot z28.s, z19.b, z23.b\n"
            "sdot z27.s, z16.b, z23.b\n"
            "fmul z18.s, z18.s, z17.s\n"
            "sdot z28.s, z30.b, z21.b\n"
            "sdot z27.s, z29.b, z21.b\n"
            "sdot z28.s, z25.b, z20.b\n"
            "sdot z27.s, z24.b, z20.b\n"
            "uzp1 z17.s, z28.s, z27.s\n"
            "uzp2 z16.s, z28.s, z27.s\n"
            "add z17.s, z17.s, z16.s\n"
            "asr z17.s, z17.s, #0x4\n"
            "scvtf z17.s, p0/m, z17.s\n"
            "fmla z31.s, p0/M, z17.s, z18.s\n"
            "cbnz x21, 2b\n"
            "sub %x[nc], %x[nc], #0x8\n"
            "st1w { z31.s }, p0, [%x[res_ptr]]\n"
            "add %x[res_ptr], %x[res_ptr], #0x20\n"
            "cbnz %x[nc], 1b\n"
            : [b_ptr] "+&r" (b_ptr), [res_ptr] "+&r" (res_ptr), [nc] "+&r" (nc)
            : [a_ptr] "r" (a_ptr), [nb] "r" (nb)
            : "memory", "p0", "x20", "x21", "x22", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31"
        );
        return;
    }
#endif // #if defined(__ARM_FEATURE_SVE)

#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__)
    ggml_gemv_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
}

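// GEMV for iq4_nl: the 4-bit indices select entries of the non-linear
// kvalues_iq4nl codebook with a single table lookup (vqtbl1q_s8) before the
// usual lane-wise dot-product accumulation.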
void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
    const int qk = QK8_0;
    const int nb = n / qk;
    const int ncols_interleaved = 4;
    const int blocklen = 4;

    assert (n % qk == 0);
    assert (nc % ncols_interleaved == 0);

    UNUSED(s);
    UNUSED(bs);
    UNUSED(vx);
    UNUSED(vy);
    UNUSED(nr);
    UNUSED(nc);
    UNUSED(nb);
    UNUSED(ncols_interleaved);
    UNUSED(blocklen);

#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
    const int8x16_t kvalues = vld1q_s8(kvalues_iq4nl);
    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
    float * res_ptr = s;

    for (int x = 0; x < nc / ncols_interleaved; x++) {
        const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb);

        float32x4_t sumf = vdupq_n_f32(0);
        for (int l = 0; l < nb; l++) {
            uint8x16_t b_0 = vld1q_u8(b_ptr[l].qs + 0);
            uint8x16_t b_1 = vld1q_u8(b_ptr[l].qs + 16);
            uint8x16_t b_2 = vld1q_u8(b_ptr[l].qs + 32);
            uint8x16_t b_3 = vld1q_u8(b_ptr[l].qs + 48);

            int8x16_t b_0_hi = vqtbl1q_s8(kvalues, b_0 >> 4);
            int8x16_t b_0_lo = vqtbl1q_s8(kvalues, b_0 & 0x0F);
            int8x16_t b_1_hi = vqtbl1q_s8(kvalues, b_1 >> 4);
            int8x16_t b_1_lo = vqtbl1q_s8(kvalues, b_1 & 0x0F);
            int8x16_t b_2_hi = vqtbl1q_s8(kvalues, b_2 >> 4);
            int8x16_t b_2_lo = vqtbl1q_s8(kvalues, b_2 & 0x0F);
            int8x16_t b_3_hi = vqtbl1q_s8(kvalues, b_3 >> 4);
            int8x16_t b_3_lo = vqtbl1q_s8(kvalues, b_3 & 0x0F);

            int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 0);
            int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16);

            int32x4_t sumi = vdupq_n_s32(0);
            sumi = vdotq_laneq_s32(sumi, b_0_lo, a_0, 0);
            sumi = vdotq_laneq_s32(sumi, b_0_hi, a_1, 0);
            sumi = vdotq_laneq_s32(sumi, b_1_lo, a_0, 1);
            sumi = vdotq_laneq_s32(sumi, b_1_hi, a_1, 1);
            sumi = vdotq_laneq_s32(sumi, b_2_lo, a_0, 2);
            sumi = vdotq_laneq_s32(sumi, b_2_hi, a_1, 2);
            sumi = vdotq_laneq_s32(sumi, b_3_lo, a_0, 3);
            sumi = vdotq_laneq_s32(sumi, b_3_hi, a_1, 3);

            float32x4_t a_d = vcvt_f32_f16(vld1_dup_f16((const float16_t *) &a_ptr[l].d));
            float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *) b_ptr[l].d));
            float32x4_t d = a_d * b_d;

            sumf = vmlaq_f32(sumf, d, vcvtq_f32_s32(sumi));
        }

        vst1q_f32(res_ptr + x * 4, sumf);
    }
    return;
#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
    ggml_gemv_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
}

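// GEMV for q4_K repacked 8 columns wide (4-byte interleave). Per superblock
// this rebuilds the fp32 scale/min pairs from the packed fp16 d/dmin, decodes
// the 6-bit sub-block scales with decode_q_Kx8_6bit_scales(), and applies the
// min correction once per block from the precomputed q8_K bsums instead of
// subtracting mins inside the inner dot-product loop.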
void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
    constexpr int qk = QK_K;
    const int nb = n / qk;

    constexpr int ncols_interleaved = 8;
    constexpr int blocklen = 8;

    assert(n % qk == 0);
    assert(nc % ncols_interleaved == 0);

    UNUSED(nb);
    UNUSED(ncols_interleaved);
    UNUSED(blocklen);

#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
    constexpr int col_groups = ncols_interleaved / 4; // 0123 and 4567
    const uint8x16_t m4b = vdupq_n_u8(0x0f);

    // 1x8 tile = 2 x 4
    float32x4_t acc_f32[col_groups];

    const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy;

    for (int x = 0; x < nc / ncols_interleaved; x++) {
        const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);

        for (int i = 0; i < col_groups; i++) {
            acc_f32[i] = vdupq_n_f32(0);
        }

        for (int b = 0; b < nb; b++) {
            float32x4_t q4_d_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d));     // d0 d1 d2 d3
            float32x4_t q4_d_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d + 4)); // d4 d5 d6 d7
            float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d);
            float32x4_t sb_scale_0123 = vmulq_f32(q4_d_0, q8_d);
            float32x4_t sb_scale_4567 = vmulq_f32(q4_d_1, q8_d);
            float32x4_t q4_dmin_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin));     // dmin 0..3
            float32x4_t q4_dmin_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin + 4)); // dmin 4..7
            float32x4_t sb_min_0123 = vmulq_f32(q4_dmin_0, q8_d);
            float32x4_t sb_min_4567 = vmulq_f32(q4_dmin_1, q8_d);

            // interleaved bias_acc: [0]->r0 0123, [1]->r0 4567
            int32x4_t bias_acc[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };
            int32x4_t acc_lo[col_groups];
            int32x4_t acc_hi[col_groups];

            // Each bsum entry covers 16 quants; the pairwise add leaves the 8 per-32-element sums of the whole block
            const int16x8_t bsums = vpaddq_s16(vld1q_s16(q8_ptr[b].bsums), vld1q_s16(q8_ptr[b].bsums + 8));
            int16_t bsums_arr[8];
            vst1q_s16(bsums_arr, bsums);
            for (int sb = 0; sb < QK_K / 64; sb++) {
                for (int i = 0; i < col_groups; i++) {
                    acc_lo[i] = vdupq_n_s32(0);
                    acc_hi[i] = vdupq_n_s32(0);
                }
                // Need scales for the low and high nibbles
                // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
                int16x8_t q4sb_mins[2];
                int16x8_t q4sb_scales[2];
                for (int i = 0; i < 2; i++) {
                    int8_t aux_q4sb[8];
                    const int offset = sb * 24 + i * 12;
                    decode_q_Kx8_6bit_scales(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
                    q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb));
                }

                int8x16_t q8_qs[64 / 16];
                for (int i = 0; i < 64 / 16; i++) {
                    q8_qs[i] = vld1q_s8(q8_ptr[b].qs + sb * 64 + i * 16);
                }

                for (int c = 0; c < col_groups; c++) {
                    uint8x16_t q4_cols[8];
                    for (int i = 0; i < 8; i++) {
                        q4_cols[i] = vld1q_u8(q4_ptr[b].qs + sb * QK_K + i * 32 + 16 * c);
                    }

                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[0], m4b)), q8_qs[0], 0);
                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[1], m4b)), q8_qs[0], 1);
                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[2], m4b)), q8_qs[0], 2);
                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[3], m4b)), q8_qs[0], 3);
                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[4], m4b)), q8_qs[1], 0);
                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[5], m4b)), q8_qs[1], 1);
                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[6], m4b)), q8_qs[1], 2);
                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[7], m4b)), q8_qs[1], 3);

                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[0], 4)), q8_qs[2], 0);
                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[1], 4)), q8_qs[2], 1);
                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[2], 4)), q8_qs[2], 2);
                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[3], 4)), q8_qs[2], 3);
                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[4], 4)), q8_qs[3], 0);
                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[5], 4)), q8_qs[3], 1);
                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[6], 4)), q8_qs[3], 2);
                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[7], 4)), q8_qs[3], 3);
                }

                // Scales
                // row c0123 blk0 and blk1
                const int16x4_t sc_0123_lo = vget_low_s16(q4sb_scales[0]);
                const int16x4_t sc_0123_hi = vget_low_s16(q4sb_scales[1]);
                const float32x4_t sumf_0123 = vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_0123_lo), acc_lo[0]),
                                                                      vmulq_s32(vmovl_s16(sc_0123_hi), acc_hi[0])));
                acc_f32[0] = vfmaq_f32(acc_f32[0], sb_scale_0123, sumf_0123);
                // row c4567 blk0 and blk1
                const int16x4_t sc_4567_lo = vget_high_s16(q4sb_scales[0]);
                const int16x4_t sc_4567_hi = vget_high_s16(q4sb_scales[1]);
                const float32x4_t sumf_4567 = vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_4567_lo), acc_lo[1]),
                                                                      vmulq_s32(vmovl_s16(sc_4567_hi), acc_hi[1])));
                acc_f32[1] = vfmaq_f32(acc_f32[1], sb_scale_4567, sumf_4567);

                // Bias Correction
                const int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[2 * sb + 0]);
                const int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[2 * sb + 1]);

                bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_lo, vget_low_s16(q4sb_mins[0]));
                bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_hi, vget_low_s16(q4sb_mins[1]));
                bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_lo, vget_high_s16(q4sb_mins[0]));
                bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_hi, vget_high_s16(q4sb_mins[1]));
            } // for sb

            acc_f32[0] = vmlsq_f32(acc_f32[0], vcvtq_f32_s32(bias_acc[0]), sb_min_0123);
            acc_f32[1] = vmlsq_f32(acc_f32[1], vcvtq_f32_s32(bias_acc[1]), sb_min_4567);
        } // for b

        int base = x * ncols_interleaved;
        vst1q_f32(s + base, acc_f32[0]);
        vst1q_f32(s + base + 4, acc_f32[1]);
    } // for x
    return;
#endif // #if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
    ggml_gemv_q4_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);
}

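// blocklen == 8 variant of the q4_K GEMV above: activations are broadcast in
// 8-byte groups so each dot product serves a column pair, and vpaddq_s32
// folds the pair accumulators back to one lane per column before scaling.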
void ggml_gemv_q4_K_8x8_q8_K(int n,
                             float * GGML_RESTRICT s,
                             size_t bs,
                             const void * GGML_RESTRICT vx,
                             const void * GGML_RESTRICT vy,
                             int nr,
                             int nc) {
    constexpr int qk = QK_K;
    const int nb = n / qk;

    constexpr int ncols_interleaved = 8;
    constexpr int blocklen = 8;

    assert(n % qk == 0);
    assert(nc % ncols_interleaved == 0);

    UNUSED(nb);
    UNUSED(ncols_interleaved);
    UNUSED(blocklen);

#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
    constexpr int col_pairs = ncols_interleaved / 2;
    const uint8x16_t m4b = vdupq_n_u8(0x0f);

    // 1x8 tile = 2 x 4
    float32x4_t acc_f32[ncols_interleaved / 4];

    const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy;

    for (int x = 0; x < nc / ncols_interleaved; x++) {
        const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);

        for (int i = 0; i < ncols_interleaved / 4; i++) {
            acc_f32[i] = vdupq_n_f32(0);
        }

        for (int b = 0; b < nb; b++) {
            float32x4_t q4_d_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d));     // d0 d1 d2 d3
            float32x4_t q4_d_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d + 4)); // d4 d5 d6 d7
            float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d);
            float32x4_t sb_scale_0 = vmulq_f32(q4_d_0, q8_d);
            float32x4_t sb_scale_1 = vmulq_f32(q4_d_1, q8_d);
            float32x4_t q4_dmin_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin));     // dmin 0..3
            float32x4_t q4_dmin_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin + 4)); // dmin 4..7
            float32x4_t sb_min_0 = vmulq_f32(q4_dmin_0, q8_d);
            float32x4_t sb_min_1 = vmulq_f32(q4_dmin_1, q8_d);

            // interleaved bias_acc: [0]->r0 0123, [1]->r0 4567
            int32x4_t bias_acc[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };
            // 2 sb each iteration
            int32x4_t acc_lo[col_pairs];
            int32x4_t acc_hi[col_pairs];

            // Each bsum entry covers 16 quants; the pairwise add leaves the 8 per-32-element sums of the whole block
            const int16x8_t bsums = vpaddq_s16(vld1q_s16(q8_ptr[b].bsums), vld1q_s16(q8_ptr[b].bsums + 8));
            int16_t bsums_arr[8];
            vst1q_s16(bsums_arr, bsums);
            for (int sb = 0; sb < QK_K / 64; sb++) {
                for (int i = 0; i < col_pairs; i++) {
                    acc_lo[i] = vdupq_n_s32(0);
                    acc_hi[i] = vdupq_n_s32(0);
                }
                // Need scales for the low and high nibbles
                // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
                int16x8_t q4sb_mins[2]; // int16, as it's needed for bias_acc later
                int16x8_t q4sb_scales[2];
                for (int i = 0; i < 2; i++) {
                    int8_t aux_q4sb[8];
                    const int offset = sb * 24 + i * 12;
                    decode_q_Kx8_6bit_scales(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
                    q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb));
                }

                const uint8_t * q4_base = q4_ptr[b].qs + sb * QK_K;

                // Load the 64 quants from q8_K duplicated so the dot products line up with the
                // interleaved columns; the same q8 values are reused for the low and high q4 nibbles
                const int8_t * q8_base = q8_ptr[b].qs + sb * 64;
                int8x16_t q8_qs[8];
                for (int i = 0; i < 8; i++) {
                    q8_qs[i] = (int8x16_t) vld1q_dup_s64((const int64_t *) (q8_base + i * 8));
                }

                // Q4s columns iterated in pairs (01, 23, 45, 67)
                for (int cp = 0; cp < col_pairs; cp++) {
                    uint8x16_t q4_qs_cp_0 = vld1q_u8(q4_base + 16 * cp);
                    uint8x16_t q4_qs_cp_1 = vld1q_u8(q4_base + 16 * cp + 64);
                    uint8x16_t q4_qs_cp_2 = vld1q_u8(q4_base + 16 * cp + 128);
                    uint8x16_t q4_qs_cp_3 = vld1q_u8(q4_base + 16 * cp + 192);

                    acc_lo[cp] =
                        ggml_vdotq_s32(acc_lo[cp], vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_0, m4b)), q8_qs[0]); // 0 .. 7
                    acc_lo[cp] =
                        ggml_vdotq_s32(acc_lo[cp], vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_1, m4b)), q8_qs[1]); // 8 ..15
                    acc_lo[cp] =
                        ggml_vdotq_s32(acc_lo[cp], vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_2, m4b)), q8_qs[2]); // 16..23
                    acc_lo[cp] =
                        ggml_vdotq_s32(acc_lo[cp], vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_3, m4b)), q8_qs[3]); // 24..31

                    acc_hi[cp] =
                        ggml_vdotq_s32(acc_hi[cp], vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_0, 4)), q8_qs[4]); // 32..39
                    acc_hi[cp] =
                        ggml_vdotq_s32(acc_hi[cp], vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_1, 4)), q8_qs[5]); // 40..47
                    acc_hi[cp] =
                        ggml_vdotq_s32(acc_hi[cp], vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_2, 4)), q8_qs[6]); // 48..55
                    acc_hi[cp] =
                        ggml_vdotq_s32(acc_hi[cp], vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_3, 4)), q8_qs[7]); // 56..63
                }

                // Iterates over a pair of column pairs (4 columns) to use a single 128-bit register
                // p == 0 -> cols 0123, p == 2 -> cols 4567
                for (int i = 0, p = 0; p < col_pairs; i++, p += 2) {
                    int16x4_t group_scales_lo = p == 0 ? vget_low_s16(q4sb_scales[0]) : vget_high_s16(q4sb_scales[0]);
                    int16x4_t group_scales_hi = p == 0 ? vget_low_s16(q4sb_scales[1]) : vget_high_s16(q4sb_scales[1]);
                    float32x4_t sb_scale = p == 0 ? sb_scale_0 : sb_scale_1;

                    // 0123 or 4567
                    float32x4_t sumf_0 =
                        vcvtq_f32_s32(vmulq_s32(vmovl_s16(group_scales_lo), vpaddq_s32(acc_lo[p], acc_lo[p + 1])));
                    acc_f32[i] = vfmaq_f32(acc_f32[i], sb_scale, sumf_0);

                    float32x4_t sumf_1 =
                        vcvtq_f32_s32(vmulq_s32(vmovl_s16(group_scales_hi), vpaddq_s32(acc_hi[p], acc_hi[p + 1])));
                    acc_f32[i] = vfmaq_f32(acc_f32[i], sb_scale, sumf_1);
                }

                // Accumulate the bias from bsums and mins
                // Each pair of subblocks shares the same bsums
                // Load each scalar bsum and broadcast it to a vector (vdup_n_s16)
                int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[2 * sb + 0]);
                int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[2 * sb + 1]);

                // cols 0-3 bias
                bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_lo, vget_low_s16(q4sb_mins[0]));
                bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_hi, vget_low_s16(q4sb_mins[1]));

                // cols 4-7 bias
                bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_lo, vget_high_s16(q4sb_mins[0]));
                bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_hi, vget_high_s16(q4sb_mins[1]));
            } // for sb

            acc_f32[0] = vmlsq_f32(acc_f32[0], vcvtq_f32_s32(bias_acc[0]), sb_min_0);
            acc_f32[1] = vmlsq_f32(acc_f32[1], vcvtq_f32_s32(bias_acc[1]), sb_min_1);
        } // for b

        int base = x * ncols_interleaved;
        vst1q_f32(s + base, acc_f32[0]);
        vst1q_f32(s + base + 4, acc_f32[1]);
    } // for x
    return;
#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
    ggml_gemv_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
}

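// q5_K GEMV (8-byte interleave). Same structure as the q4_K kernel above,
// plus the fifth bit: qh is loaded once per block, and for each sub-block two
// high bits are peeled off, merged into the low/high nibbles with
// vsliq_n_u8 / vorrq_u8 before the dot products, then shifted out.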
void ggml_gemv_q5_K_8x8_q8_K(int n,
                             float * GGML_RESTRICT s,
                             size_t bs,
                             const void * GGML_RESTRICT vx,
                             const void * GGML_RESTRICT vy,
                             int nr,
                             int nc) {
    constexpr int qk = QK_K;
    const int nb = n / qk;

    constexpr int ncols_interleaved = 8;
    constexpr int blocklen = 8;

    assert(n % qk == 0);
    assert(nc % ncols_interleaved == 0);

    UNUSED(nb);
    UNUSED(ncols_interleaved);
    UNUSED(blocklen);

#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
    constexpr int col_pairs = ncols_interleaved / 2;
    const uint8x16_t m4b = vdupq_n_u8(0x0f);
    const uint8x16_t mone = vdupq_n_u8(1);
    const uint8x16_t mtwo = vdupq_n_u8(2);

    // 1x8 tile = 2 x 4
    float32x4_t acc_f32[ncols_interleaved / 4];

    const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy;

    for (int x = 0; x < nc / ncols_interleaved; x++) {
        const block_q5_Kx8 * GGML_RESTRICT q5_ptr = (const block_q5_Kx8 *) vx + (x * nb);

        for (int i = 0; i < ncols_interleaved / 4; i++) {
            acc_f32[i] = vdupq_n_f32(0);
        }

        for (int b = 0; b < nb; b++) {
            float32x4_t q5_d_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d));     // d0 d1 d2 d3
            float32x4_t q5_d_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d + 4)); // d4 d5 d6 d7
            float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d);
            float32x4_t sb_scale_0 = vmulq_f32(q5_d_0, q8_d);
            float32x4_t sb_scale_1 = vmulq_f32(q5_d_1, q8_d);
            float32x4_t q5_dmin_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin));     // dmin 0..3
            float32x4_t q5_dmin_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin + 4)); // dmin 4..7
            float32x4_t sb_min_0 = vmulq_f32(q5_dmin_0, q8_d);
            float32x4_t sb_min_1 = vmulq_f32(q5_dmin_1, q8_d);

            // 2 sb each iteration
            int32x4_t acc_lo[col_pairs];
            int32x4_t acc_hi[col_pairs];

            // Each bsum entry covers 16 quants; the pairwise add leaves the 8 per-32-element sums of the whole block
            const int16x8_t bsums = vpaddq_s16(vld1q_s16(q8_ptr[b].bsums), vld1q_s16(q8_ptr[b].bsums + 8));
            int16_t bsums_arr[8];
            vst1q_s16(bsums_arr, bsums);

            // Load qh once per block and shift after each subblock
            const uint8_t * qh_base = q5_ptr[b].qh;
            uint8x16_t qh[col_pairs][4];
            for (int cp = 0; cp < col_pairs; cp++) {
                qh[cp][0] = vld1q_u8(qh_base + 16 * cp);
                qh[cp][1] = vld1q_u8(qh_base + 16 * cp + 64);
                qh[cp][2] = vld1q_u8(qh_base + 16 * cp + 128);
                qh[cp][3] = vld1q_u8(qh_base + 16 * cp + 192);
            }

            for (int sb = 0; sb < QK_K / 64; sb++) {
                for (int i = 0; i < col_pairs; i++) {
                    acc_lo[i] = vdupq_n_s32(0);
                    acc_hi[i] = vdupq_n_s32(0);
                }
                // Need scales for the low and high nibbles
                // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
                int16x8_t q5sb_mins[2]; // int16, as it's needed for bias_acc later
                int16x8_t q5sb_scales[2];
                for (int i = 0; i < 2; i++) {
                    int8_t aux_q5sb[8];
                    const int offset = sb * 24 + i * 12;
                    decode_q_Kx8_6bit_scales(&q5_ptr[b].scales[offset], &q5sb_mins[i], aux_q5sb);
                    q5sb_scales[i] = vmovl_s8(vld1_s8(aux_q5sb));
                }

                const uint8_t * qs_base = q5_ptr[b].qs + sb * QK_K;

                // Load the 64 quants from q8_K duplicated to use vecdots with the interleaved columns
                const int8_t * q8_base = q8_ptr[b].qs + sb * 64;
                int8x16_t q8_qs[8];
                for (int i = 0; i < 8; i++) {
                    q8_qs[i] = (int8x16_t) vld1q_dup_s64((const int64_t *) (q8_base + i * 8));
                }

                // Q5s column pair loop unrolled
                {
                    // Cols 01
                    uint8x16_t qs_0 = vld1q_u8(qs_base);
                    uint8x16_t qs_1 = vld1q_u8(qs_base + 64);
                    uint8x16_t qs_2 = vld1q_u8(qs_base + 128);
                    uint8x16_t qs_3 = vld1q_u8(qs_base + 192);

                    uint8x16_t hbit_lo_0 = vandq_u8(qh[0][0], mone);
                    uint8x16_t hbit_lo_1 = vandq_u8(qh[0][1], mone);
                    uint8x16_t hbit_lo_2 = vandq_u8(qh[0][2], mone);
                    uint8x16_t hbit_lo_3 = vandq_u8(qh[0][3], mone);
                    uint8x16_t hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[0][0], mtwo), 3);
                    uint8x16_t hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[0][1], mtwo), 3);
                    uint8x16_t hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[0][2], mtwo), 3);
                    uint8x16_t hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[0][3], mtwo), 3);

                    qh[0][0] = vshrq_n_u8(qh[0][0], 2);
                    qh[0][1] = vshrq_n_u8(qh[0][1], 2);
                    qh[0][2] = vshrq_n_u8(qh[0][2], 2);
                    qh[0][3] = vshrq_n_u8(qh[0][3], 2);

                    acc_lo[0] = ggml_vdotq_s32(
                        acc_lo[0], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_0, m4b), hbit_lo_0, 4)), q8_qs[0]);
                    acc_lo[0] = ggml_vdotq_s32(
                        acc_lo[0], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_1, m4b), hbit_lo_1, 4)), q8_qs[1]);
                    acc_lo[0] = ggml_vdotq_s32(
                        acc_lo[0], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_2, m4b), hbit_lo_2, 4)), q8_qs[2]);
                    acc_lo[0] = ggml_vdotq_s32(
                        acc_lo[0], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_3, m4b), hbit_lo_3, 4)), q8_qs[3]);
                    acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)),
                                               q8_qs[4]);
                    acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)),
                                               q8_qs[5]);
                    acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)),
                                               q8_qs[6]);
                    acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)),
                                               q8_qs[7]);

                    // Cols 23
                    qs_0 = vld1q_u8(qs_base + 16);
                    qs_1 = vld1q_u8(qs_base + 80);
                    qs_2 = vld1q_u8(qs_base + 144);
                    qs_3 = vld1q_u8(qs_base + 208);

                    hbit_lo_0 = vandq_u8(qh[1][0], mone);
                    hbit_lo_1 = vandq_u8(qh[1][1], mone);
                    hbit_lo_2 = vandq_u8(qh[1][2], mone);
                    hbit_lo_3 = vandq_u8(qh[1][3], mone);
                    hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[1][0], mtwo), 3);
                    hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[1][1], mtwo), 3);
                    hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[1][2], mtwo), 3);
                    hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[1][3], mtwo), 3);

                    qh[1][0] = vshrq_n_u8(qh[1][0], 2);
                    qh[1][1] = vshrq_n_u8(qh[1][1], 2);
                    qh[1][2] = vshrq_n_u8(qh[1][2], 2);
                    qh[1][3] = vshrq_n_u8(qh[1][3], 2);

                    acc_lo[1] = ggml_vdotq_s32(
                        acc_lo[1], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_0, m4b), hbit_lo_0, 4)), q8_qs[0]);
                    acc_lo[1] = ggml_vdotq_s32(
                        acc_lo[1], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_1, m4b), hbit_lo_1, 4)), q8_qs[1]);
                    acc_lo[1] = ggml_vdotq_s32(
                        acc_lo[1], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_2, m4b), hbit_lo_2, 4)), q8_qs[2]);
                    acc_lo[1] = ggml_vdotq_s32(
                        acc_lo[1], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_3, m4b), hbit_lo_3, 4)), q8_qs[3]);
                    acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)),
                                               q8_qs[4]);
                    acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)),
                                               q8_qs[5]);
                    acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)),
                                               q8_qs[6]);
                    acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)),
                                               q8_qs[7]);

                    // Cols 45
                    qs_0 = vld1q_u8(qs_base + 32);
                    qs_1 = vld1q_u8(qs_base + 96);
                    qs_2 = vld1q_u8(qs_base + 160);
                    qs_3 = vld1q_u8(qs_base + 224);

                    hbit_lo_0 = vandq_u8(qh[2][0], mone);
                    hbit_lo_1 = vandq_u8(qh[2][1], mone);
                    hbit_lo_2 = vandq_u8(qh[2][2], mone);
                    hbit_lo_3 = vandq_u8(qh[2][3], mone);
                    hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[2][0], mtwo), 3);
                    hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[2][1], mtwo), 3);
                    hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[2][2], mtwo), 3);
                    hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[2][3], mtwo), 3);

                    qh[2][0] = vshrq_n_u8(qh[2][0], 2);
                    qh[2][1] = vshrq_n_u8(qh[2][1], 2);
                    qh[2][2] = vshrq_n_u8(qh[2][2], 2);
                    qh[2][3] = vshrq_n_u8(qh[2][3], 2);

                    acc_lo[2] = ggml_vdotq_s32(
                        acc_lo[2], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_0, m4b), hbit_lo_0, 4)), q8_qs[0]);
                    acc_lo[2] = ggml_vdotq_s32(
                        acc_lo[2], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_1, m4b), hbit_lo_1, 4)), q8_qs[1]);
                    acc_lo[2] = ggml_vdotq_s32(
                        acc_lo[2], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_2, m4b), hbit_lo_2, 4)), q8_qs[2]);
                    acc_lo[2] = ggml_vdotq_s32(
                        acc_lo[2], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_3, m4b), hbit_lo_3, 4)), q8_qs[3]);
                    acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)),
                                               q8_qs[4]);
                    acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)),
                                               q8_qs[5]);
                    acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)),
                                               q8_qs[6]);
                    acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)),
                                               q8_qs[7]);

                    // Cols 67
                    qs_0 = vld1q_u8(qs_base + 48);
                    qs_1 = vld1q_u8(qs_base + 112);
                    qs_2 = vld1q_u8(qs_base + 176);
                    qs_3 = vld1q_u8(qs_base + 240);

                    hbit_lo_0 = vandq_u8(qh[3][0], mone);
                    hbit_lo_1 = vandq_u8(qh[3][1], mone);
                    hbit_lo_2 = vandq_u8(qh[3][2], mone);
                    hbit_lo_3 = vandq_u8(qh[3][3], mone);
                    hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[3][0], mtwo), 3);
                    hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[3][1], mtwo), 3);
                    hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[3][2], mtwo), 3);
                    hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[3][3], mtwo), 3);

                    qh[3][0] = vshrq_n_u8(qh[3][0], 2);
                    qh[3][1] = vshrq_n_u8(qh[3][1], 2);
                    qh[3][2] = vshrq_n_u8(qh[3][2], 2);
                    qh[3][3] = vshrq_n_u8(qh[3][3], 2);

                    acc_lo[3] = ggml_vdotq_s32(
                        acc_lo[3], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_0, m4b), hbit_lo_0, 4)), q8_qs[0]);
                    acc_lo[3] = ggml_vdotq_s32(
                        acc_lo[3], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_1, m4b), hbit_lo_1, 4)), q8_qs[1]);
                    acc_lo[3] = ggml_vdotq_s32(
                        acc_lo[3], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_2, m4b), hbit_lo_2, 4)), q8_qs[2]);
                    acc_lo[3] = ggml_vdotq_s32(
                        acc_lo[3], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_3, m4b), hbit_lo_3, 4)), q8_qs[3]);
                    acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)),
                                               q8_qs[4]);
                    acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)),
                                               q8_qs[5]);
                    acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)),
                                               q8_qs[6]);
                    acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)),
                                               q8_qs[7]);
                }

                // Prepare bsum vectors for bias computation
                // Each pair of subblocks shares the same bsums
                int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[2 * sb + 0]);
                int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[2 * sb + 1]);

                // Iterates over a pair of column pairs (4 columns) to use a single 128-bit register
                // p == 0 -> cols 0123, p == 2 -> cols 4567
                for (int i = 0, p = 0; p < col_pairs; i++, p += 2) {
                    int16x4_t group_scales_lo = p == 0 ? vget_low_s16(q5sb_scales[0]) : vget_high_s16(q5sb_scales[0]);
                    int16x4_t group_scales_hi = p == 0 ? vget_low_s16(q5sb_scales[1]) : vget_high_s16(q5sb_scales[1]);
                    int16x4_t group_mins_lo = p == 0 ? vget_low_s16(q5sb_mins[0]) : vget_high_s16(q5sb_mins[0]);
                    int16x4_t group_mins_hi = p == 0 ? vget_low_s16(q5sb_mins[1]) : vget_high_s16(q5sb_mins[1]);
                    float32x4_t sb_scale = p == 0 ? sb_scale_0 : sb_scale_1;
                    float32x4_t sb_min = p == 0 ? sb_min_0 : sb_min_1;

                    // 0123 or 4567
                    float32x4_t sumf_0 =
                        vcvtq_f32_s32(vmulq_s32(vmovl_s16(group_scales_lo), vpaddq_s32(acc_lo[p], acc_lo[p + 1])));
                    acc_f32[i] = vfmaq_f32(acc_f32[i], sb_scale, sumf_0);

                    float32x4_t sumf_1 =
                        vcvtq_f32_s32(vmulq_s32(vmovl_s16(group_scales_hi), vpaddq_s32(acc_hi[p], acc_hi[p + 1])));
                    acc_f32[i] = vfmaq_f32(acc_f32[i], sb_scale, sumf_1);

                    // FUSED BIAS: Compute and subtract bias immediately
                    // bias = (bsums_lo * mins_lo + bsums_hi * mins_hi) * sb_min
                    int32x4_t bias = vmull_s16(bsums_vec_lo, group_mins_lo);
                    bias = vmlal_s16(bias, bsums_vec_hi, group_mins_hi);
                    float32x4_t bias_f32 = vcvtq_f32_s32(bias);
                    acc_f32[i] = vmlsq_f32(acc_f32[i], sb_min, bias_f32);
                }
            } // for sb
        } // for b

        int base = x * ncols_interleaved;
        vst1q_f32(s + base, acc_f32[0]);
        vst1q_f32(s + base + 4, acc_f32[1]);
    } // for x
    return;
#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
    ggml_gemv_q5_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
}

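// q6_K GEMV (4-byte interleave). q6 values are rebuilt as (low4 | high2 << 4)
// without the usual -32 offset; that offset is instead folded into a
// per-column bias, (sum over sub-blocks of bsum * scale) << 5, computed once
// per block and subtracted after the integer accumulation.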
void ggml_gemv_q6_K_8x4_q8_K(int n,
                             float * GGML_RESTRICT s,
                             size_t bs,
                             const void * GGML_RESTRICT vx,
                             const void * GGML_RESTRICT vy,
                             int nr,
                             int nc) {
    constexpr int qk = QK_K;
    const int nb = n / qk;

    constexpr int ncols_interleaved = 8;
    constexpr int blocklen = 4;

    assert(n % qk == 0);
    assert(nc % ncols_interleaved == 0);

    UNUSED(nb);
    UNUSED(ncols_interleaved);
    UNUSED(blocklen);

#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
    constexpr int col_groups = ncols_interleaved / 4;
    const uint8x16_t m4b = vdupq_n_u8(0x0f);
    const uint8x16_t mask_lo = vdupq_n_u8(0x03);
    const uint8x16_t mask_hi = vdupq_n_u8(0x30);

    // 1x8 tile = 2 x 4
    float32x4_t acc_f32[2];

    const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy;

    for (int x = 0; x < nc / ncols_interleaved; x++) {
        const block_q6_Kx8 * GGML_RESTRICT q6_ptr = (const block_q6_Kx8 *) vx + (x * nb);

        for (int i = 0; i < col_groups; i++) {
            acc_f32[i] = vdupq_n_f32(0);
        }

        for (int b = 0; b < nb; b++) {
            float32x4_t q6_d_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d));     // d0 d1 d2 d3
            float32x4_t q6_d_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d + 4)); // d4 d5 d6 d7
            float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d);
            float32x4_t sb_scale_0 = vmulq_f32(q6_d_0, q8_d);
            float32x4_t sb_scale_1 = vmulq_f32(q6_d_1, q8_d);

            int32x4_t acc[col_groups];
            for (int i = 0; i < col_groups; i++) {
                acc[i] = vdupq_n_s32(0);
            }

            // Load all 16 scales once and widen to int16 (Q6_K has 16 scales per block)
            // Reused for bias and dequantization later
            int16_t q6_scales[16 * 8];
            for (int i = 0; i < 16; i++) {
                int16x8_t scales = vmovl_s8(vld1_s8(q6_ptr[b].scales + i * 8));
                vst1q_s16(q6_scales + i * 8, scales);
            }

            // Compute bias per column using q8 bsums and preloaded scales to skip the -32 shift
            int32x4_t bias_lo = vdupq_n_s32(0);
            int32x4_t bias_hi = vdupq_n_s32(0);

            // Load bsums in chunks of 4 to process with vectorized operations
            for (int i = 0; i < 16; i += 4) {
                int16x4_t bsums_vec = vld1_s16(q8_ptr[b].bsums + i);
                int16x4_t scales_lo_0 = vld1_s16(q6_scales + (i + 0) * 8);
                int16x4_t scales_hi_0 = vld1_s16(q6_scales + (i + 0) * 8 + 4);
                int16x4_t scales_lo_1 = vld1_s16(q6_scales + (i + 1) * 8);
                int16x4_t scales_hi_1 = vld1_s16(q6_scales + (i + 1) * 8 + 4);
                int16x4_t scales_lo_2 = vld1_s16(q6_scales + (i + 2) * 8);
                int16x4_t scales_hi_2 = vld1_s16(q6_scales + (i + 2) * 8 + 4);
                int16x4_t scales_lo_3 = vld1_s16(q6_scales + (i + 3) * 8);
                int16x4_t scales_hi_3 = vld1_s16(q6_scales + (i + 3) * 8 + 4);

                bias_lo = vmlal_lane_s16(bias_lo, scales_lo_0, bsums_vec, 0);
                bias_hi = vmlal_lane_s16(bias_hi, scales_hi_0, bsums_vec, 0);
                bias_lo = vmlal_lane_s16(bias_lo, scales_lo_1, bsums_vec, 1);
                bias_hi = vmlal_lane_s16(bias_hi, scales_hi_1, bsums_vec, 1);
                bias_lo = vmlal_lane_s16(bias_lo, scales_lo_2, bsums_vec, 2);
                bias_hi = vmlal_lane_s16(bias_hi, scales_hi_2, bsums_vec, 2);
                bias_lo = vmlal_lane_s16(bias_lo, scales_lo_3, bsums_vec, 3);
                bias_hi = vmlal_lane_s16(bias_hi, scales_hi_3, bsums_vec, 3);
            }
            bias_lo = vshlq_n_s32(bias_lo, 5);
            bias_hi = vshlq_n_s32(bias_hi, 5);

            // Process two 128-value halves per superblock
            for (int half = 0; half < 2; half++) {
                const uint8_t * ql_base = q6_ptr[b].ql + half * 512;
                const uint8_t * qh_base = q6_ptr[b].qh + half * 256;

                // A subblock (sb) is a run of weights that share one scale. q6_K scales
                // cover 16 elements each, and every loaded byte carries 2 elements across
                // 2 halves, so there are 256 / (16 * 2 * 2) = 4 subblock iterations
                for (int sb = 0; sb < QK_K / 64; sb++) {
                    const int8_t * q8_base_l = q8_ptr[b].qs + half * 128 + sb * 16;
                    const int8_t * q8_base_h = q8_base_l + 64;

                    // Load and duplicate q8 values (each register covers four interleaved columns of q6)
                    int8x16_t q8_l[4];
                    int8x16_t q8_h[4];
                    for (int i = 0; i < 4; i++) {
                        q8_l[i] = (int8x16_t) vld1q_dup_s32((const int32_t *) (q8_base_l + i * 4));
                        q8_h[i] = (int8x16_t) vld1q_dup_s32((const int32_t *) (q8_base_h + i * 4));
                    }

                    const int ql_off_base = sb * QK_K / 2;
                    const int qh_off_base = ql_off_base & 255; // wraps after 256 bytes

                    // Load 4 vectors at once (64 bytes each for ql_0, ql_1, qh_0, qh_1)
                    uint8x16x4_t q6_ql_0 = vld1q_u8_x4(ql_base + ql_off_base);
                    uint8x16x4_t q6_ql_1 = vld1q_u8_x4(ql_base + ql_off_base + 64);
                    uint8x16x4_t q6_qh_0 = vld1q_u8_x4(qh_base + qh_off_base);
                    uint8x16x4_t q6_qh_1 = vld1q_u8_x4(qh_base + qh_off_base + 64);

                    // Adjust qh for subblocks 2 and 3 (shift right by 2)
                    if (sb > 1) {
                        q6_qh_0.val[0] = vshrq_n_u8(q6_qh_0.val[0], 2);
                        q6_qh_0.val[1] = vshrq_n_u8(q6_qh_0.val[1], 2);
                        q6_qh_0.val[2] = vshrq_n_u8(q6_qh_0.val[2], 2);
                        q6_qh_0.val[3] = vshrq_n_u8(q6_qh_0.val[3], 2);
                        q6_qh_1.val[0] = vshrq_n_u8(q6_qh_1.val[0], 2);
                        q6_qh_1.val[1] = vshrq_n_u8(q6_qh_1.val[1], 2);
                        q6_qh_1.val[2] = vshrq_n_u8(q6_qh_1.val[2], 2);
                        q6_qh_1.val[3] = vshrq_n_u8(q6_qh_1.val[3], 2);
                    }

                    const uint8x16_t q6_ql[8] = { q6_ql_0.val[0], q6_ql_0.val[1], q6_ql_0.val[2], q6_ql_0.val[3],
                                                  q6_ql_1.val[0], q6_ql_1.val[1], q6_ql_1.val[2], q6_ql_1.val[3] };
                    const uint8x16_t q6_qh[8] = { q6_qh_0.val[0], q6_qh_0.val[1], q6_qh_0.val[2], q6_qh_0.val[3],
                                                  q6_qh_1.val[0], q6_qh_1.val[1], q6_qh_1.val[2], q6_qh_1.val[3] };

                    // Process column groups (0-3, 4-7)
                    for (int g = 0; g < col_groups; g++) {
                        int32x4_t sb_acc_l = vdupq_n_s32(0);
                        int32x4_t sb_acc_h = vdupq_n_s32(0);

                        for (int chunk = 0; chunk < 4; chunk++) {
                            const int idx = chunk * 2 + g;

                            const uint8x16_t q6_qs_l = q6_ql[idx];
                            const uint8x16_t q6_qs_h = q6_qh[idx];

                            // Extract high 2 bits for upper nibble reconstruction
                            const uint8x16_t q6_qs_hh = vandq_u8(q6_qs_h, mask_hi);

                            // q6 = (low4 | high2<<4), without -32 bias (handled via bsums)
                            const int8x16_t q6_l =
                                vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_qs_l, m4b), vandq_u8(q6_qs_h, mask_lo), 4));
                            const int8x16_t q6_h = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_qs_l, 4), q6_qs_hh));

                            sb_acc_l = vdotq_s32(sb_acc_l, q6_l, q8_l[chunk]);
                            sb_acc_h = vdotq_s32(sb_acc_h, q6_h, q8_h[chunk]);
                        }

                        const int scale_idx_l = half * 8 + sb;
                        const int scale_idx_h = half * 8 + sb + 4;

                        const int32x4_t scale_vec_l = vmovl_s16(vld1_s16(q6_scales + scale_idx_l * 8 + g * 4));
                        const int32x4_t scale_vec_h = vmovl_s16(vld1_s16(q6_scales + scale_idx_h * 8 + g * 4));

                        acc[g] = vmlaq_s32(acc[g], sb_acc_l, scale_vec_l);
                        acc[g] = vmlaq_s32(acc[g], sb_acc_h, scale_vec_h);
                    }
                }
            } // for half

            // Bias correction
            acc[0] = vsubq_s32(acc[0], bias_lo);
            acc[1] = vsubq_s32(acc[1], bias_hi);

            // Apply superblock scale (no mins for q6_K)
            // acc[g] has [c0, c1, c2, c3]
            float32x4_t w_0123 = vmulq_f32(vcvtq_f32_s32(acc[0]), sb_scale_0);
            float32x4_t w_4567 = vmulq_f32(vcvtq_f32_s32(acc[1]), sb_scale_1);

            acc_f32[0] = vaddq_f32(acc_f32[0], w_0123);
            acc_f32[1] = vaddq_f32(acc_f32[1], w_4567);
        } // for b

        int base = x * ncols_interleaved;
        vst1q_f32(s + base, acc_f32[0]);
        vst1q_f32(s + base + 4, acc_f32[1]);
    } // for x
    return;
#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
    ggml_gemv_q6_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);
}

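// blocklen == 8 variant of the q6_K GEMV: column pairs are processed with
// dup-loaded 8-byte q8 groups, reduced to one lane per column with vpadd_s32,
// and scaled per pair from the widened scale table.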
void ggml_gemv_q6_K_8x8_q8_K(int n,
                             float * GGML_RESTRICT s,
                             size_t bs,
                             const void * GGML_RESTRICT vx,
                             const void * GGML_RESTRICT vy,
                             int nr,
                             int nc) {
    constexpr int qk = QK_K;
    const int nb = n / qk;

    constexpr int ncols_interleaved = 8;
    constexpr int blocklen = 8;

    assert(n % qk == 0);
    assert(nc % ncols_interleaved == 0);

    UNUSED(nb);
    UNUSED(ncols_interleaved);
    UNUSED(blocklen);

#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
    constexpr int col_pairs = ncols_interleaved / 2;
    const uint8x16_t m4b = vdupq_n_u8(0x0f);
    const uint8x16_t mask_lo = vdupq_n_u8(0x03);
    const uint8x16_t mask_hi = vdupq_n_u8(0x30);

    // 1x8 tile = 2 x 4
    float32x4_t acc_f32[2];

    const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy;

    for (int x = 0; x < nc / ncols_interleaved; x++) {
        const block_q6_Kx8 * GGML_RESTRICT q6_ptr = (const block_q6_Kx8 *) vx + (x * nb);

        acc_f32[0] = vdupq_n_f32(0);
        acc_f32[1] = vdupq_n_f32(0);

        for (int b = 0; b < nb; b++) {
            float32x4_t q6_d_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d));     // d0 d1 d2 d3
            float32x4_t q6_d_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d + 4)); // d4 d5 d6 d7
            float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d);
            float32x4_t sb_scale_0 = vmulq_f32(q6_d_0, q8_d);
            float32x4_t sb_scale_1 = vmulq_f32(q6_d_1, q8_d);

            int32x2_t acc[col_pairs];
            for (int i = 0; i < col_pairs; i++) {
                acc[i] = vdup_n_s32(0);
            }

            // Load all 16 scales once and widen to int16 (Q6_K has 16 scales per block)
            // Reused for bias and dequantization later
            int16_t q6_scales[16 * 8];
            for (int i = 0; i < 16; i++) {
                int16x8_t scales = vmovl_s8(vld1_s8(q6_ptr[b].scales + i * 8));
                vst1q_s16(q6_scales + i * 8, scales);
            }

            // Compute bias per column using q8 bsums and preloaded scales to skip the -32 shift
            int32x4_t bias_lo = vdupq_n_s32(0);
            int32x4_t bias_hi = vdupq_n_s32(0);

            // Load bsums in chunks of 4 to process with vectorized operations
            for (int i = 0; i < 16; i += 4) {
                int16x4_t bsums_vec = vld1_s16(q8_ptr[b].bsums + i);
                int16x4_t scales_lo_0 = vld1_s16(q6_scales + (i + 0) * 8);
                int16x4_t scales_hi_0 = vld1_s16(q6_scales + (i + 0) * 8 + 4);
                int16x4_t scales_lo_1 = vld1_s16(q6_scales + (i + 1) * 8);
                int16x4_t scales_hi_1 = vld1_s16(q6_scales + (i + 1) * 8 + 4);
                int16x4_t scales_lo_2 = vld1_s16(q6_scales + (i + 2) * 8);
                int16x4_t scales_hi_2 = vld1_s16(q6_scales + (i + 2) * 8 + 4);
                int16x4_t scales_lo_3 = vld1_s16(q6_scales + (i + 3) * 8);
                int16x4_t scales_hi_3 = vld1_s16(q6_scales + (i + 3) * 8 + 4);

                bias_lo = vmlal_lane_s16(bias_lo, scales_lo_0, bsums_vec, 0);
                bias_hi = vmlal_lane_s16(bias_hi, scales_hi_0, bsums_vec, 0);
                bias_lo = vmlal_lane_s16(bias_lo, scales_lo_1, bsums_vec, 1);
                bias_hi = vmlal_lane_s16(bias_hi, scales_hi_1, bsums_vec, 1);
                bias_lo = vmlal_lane_s16(bias_lo, scales_lo_2, bsums_vec, 2);
                bias_hi = vmlal_lane_s16(bias_hi, scales_hi_2, bsums_vec, 2);
                bias_lo = vmlal_lane_s16(bias_lo, scales_lo_3, bsums_vec, 3);
                bias_hi = vmlal_lane_s16(bias_hi, scales_hi_3, bsums_vec, 3);
            }
            bias_lo = vshlq_n_s32(bias_lo, 5);
            bias_hi = vshlq_n_s32(bias_hi, 5);

            // Process two 128-value halves per superblock
            for (int half = 0; half < 2; half++) {
                const uint8_t * ql_base = q6_ptr[b].ql + half * 512;
                const uint8_t * qh_base = q6_ptr[b].qh + half * 256;

                // A subblock (sb) is a run of weights that share one scale. q6_K scales
                // cover 16 elements each, and every loaded byte carries 2 elements across
                // 2 halves, so there are 256 / (16 * 2 * 2) = 4 subblock iterations
1357 for (int sb = 0; sb < QK_K / 64; sb++) {
1358 const int8_t * q8_base_l = q8_ptr[b].qs + half * 128 + sb * 16;
1359 const int8_t * q8_base_h = q8_base_l + 64;
1360
1361 // Load and duplicate q8 values (each register covers two interleaved columns of q6)
1362 int8x16_t q8_l[2];
1363 int8x16_t q8_h[2];
1364 for (int i = 0; i < 2; i++) {
1365 q8_l[i] = (int8x16_t) vld1q_dup_s64((const int64_t *) (q8_base_l + i * 8));
1366 q8_h[i] = (int8x16_t) vld1q_dup_s64((const int64_t *) (q8_base_h + i * 8));
1367 }
1368
1369 const int ql_off_base = sb * QK_K / 2;
1370 const int qh_off_base = ql_off_base & 255; // wraps after 256 bytes
1371
1372 // Load 4 vectors at once (64 bytes each for ql_0, ql_1, qh_0, qh_1)
1373 uint8x16x4_t q6_ql_0 = vld1q_u8_x4(ql_base + ql_off_base);
1374 uint8x16x4_t q6_ql_1 = vld1q_u8_x4(ql_base + ql_off_base + 64);
1375 uint8x16x4_t q6_qh_0 = vld1q_u8_x4(qh_base + qh_off_base);
1376 uint8x16x4_t q6_qh_1 = vld1q_u8_x4(qh_base + qh_off_base + 64);
1377
1378 // Adjust qh for subblocks 2 and 3 (shift right by 2)
1379 if (sb > 1) {
1380 q6_qh_0.val[0] = vshrq_n_u8(q6_qh_0.val[0], 2);
1381 q6_qh_0.val[1] = vshrq_n_u8(q6_qh_0.val[1], 2);
1382 q6_qh_0.val[2] = vshrq_n_u8(q6_qh_0.val[2], 2);
1383 q6_qh_0.val[3] = vshrq_n_u8(q6_qh_0.val[3], 2);
1384 q6_qh_1.val[0] = vshrq_n_u8(q6_qh_1.val[0], 2);
1385 q6_qh_1.val[1] = vshrq_n_u8(q6_qh_1.val[1], 2);
1386 q6_qh_1.val[2] = vshrq_n_u8(q6_qh_1.val[2], 2);
1387 q6_qh_1.val[3] = vshrq_n_u8(q6_qh_1.val[3], 2);
1388 }
1389
1390 // Process column pairs (0-1, 2-3, 4-5, 6-7)
1391 for (int cp = 0; cp < col_pairs; cp++) {
1392 const uint8x16_t q6_qs_cp_0_l = q6_ql_0.val[cp];
1393 const uint8x16_t q6_qs_cp_1_l = q6_ql_1.val[cp];
1394 const uint8x16_t q6_qs_cp_0_h = q6_qh_0.val[cp];
1395 const uint8x16_t q6_qs_cp_1_h = q6_qh_1.val[cp];
1396
                        // Extract the 2 high bits used to reconstruct the values for the second 64 elements of the half (q6_h0/q6_h1)
1398 const uint8x16_t q6_qs_cp_0_hh = vandq_u8(q6_qs_cp_0_h, mask_hi);
1399 const uint8x16_t q6_qs_cp_1_hh = vandq_u8(q6_qs_cp_1_h, mask_hi);
1400
1401 // q6 = (low4 | high2<<4), without -32 bias (handled via bsums)
1402 const int8x16_t q6_l0 = vreinterpretq_s8_u8(
1403 vsliq_n_u8(vandq_u8(q6_qs_cp_0_l, m4b), vandq_u8(q6_qs_cp_0_h, mask_lo), 4));
1404 const int8x16_t q6_l1 = vreinterpretq_s8_u8(
1405 vsliq_n_u8(vandq_u8(q6_qs_cp_1_l, m4b), vandq_u8(q6_qs_cp_1_h, mask_lo), 4));
1406 const int8x16_t q6_h0 =
1407 vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_qs_cp_0_l, 4), q6_qs_cp_0_hh));
1408 const int8x16_t q6_h1 =
1409 vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_qs_cp_1_l, 4), q6_qs_cp_1_hh));
1410
1411 int32x4_t sb_acc_l = vdupq_n_s32(0);
1412 sb_acc_l = vdotq_s32(sb_acc_l, q6_l0, q8_l[0]);
1413 sb_acc_l = vdotq_s32(sb_acc_l, q6_l1, q8_l[1]);
1414
1415 int32x4_t sb_acc_h = vdupq_n_s32(0);
1416 sb_acc_h = vdotq_s32(sb_acc_h, q6_h0, q8_h[0]);
1417 sb_acc_h = vdotq_s32(sb_acc_h, q6_h1, q8_h[1]);
1418
1419 // Pairwise add to get per-column sums: [col0, col1]
1420 int32x2_t sum_l = vpadd_s32(vget_low_s32(sb_acc_l), vget_high_s32(sb_acc_l));
1421 int32x2_t sum_h = vpadd_s32(vget_low_s32(sb_acc_h), vget_high_s32(sb_acc_h));
1422
1423 const int scale_idx_l = half * 8 + sb;
1424 const int scale_idx_h = half * 8 + sb + 4;
1425
1426 // Access scales using array indexing (scales are interleaved by column)
1427 const int32x2_t scale_vec_l = { (int32_t) q6_scales[scale_idx_l * 8 + cp * 2],
1428 (int32_t) q6_scales[scale_idx_l * 8 + cp * 2 + 1] };
1429 const int32x2_t scale_vec_h = { (int32_t) q6_scales[scale_idx_h * 8 + cp * 2],
1430 (int32_t) q6_scales[scale_idx_h * 8 + cp * 2 + 1] };
1431
1432 // Accumulate scaled results
1433 acc[cp] = vmla_s32(acc[cp], sum_l, scale_vec_l);
1434 acc[cp] = vmla_s32(acc[cp], sum_h, scale_vec_h);
1435 }
1436 }
1437 } // for half
1438
1439 // Bias correction
1440 acc[0] = vsub_s32(acc[0], vget_low_s32(bias_lo));
1441 acc[1] = vsub_s32(acc[1], vget_high_s32(bias_lo));
1442 acc[2] = vsub_s32(acc[2], vget_low_s32(bias_hi));
1443 acc[3] = vsub_s32(acc[3], vget_high_s32(bias_hi));
1444
1445 // Apply superblock scale (no mins for q6_K)
1446 // acc[cp] has [c0, c1]
1447 float32x2_t w_01 = vmul_f32(vcvt_f32_s32(acc[0]), vget_low_f32(sb_scale_0));
1448 float32x2_t w_23 = vmul_f32(vcvt_f32_s32(acc[1]), vget_high_f32(sb_scale_0));
1449 float32x2_t w_45 = vmul_f32(vcvt_f32_s32(acc[2]), vget_low_f32(sb_scale_1));
1450 float32x2_t w_67 = vmul_f32(vcvt_f32_s32(acc[3]), vget_high_f32(sb_scale_1));
1451
1452 acc_f32[0] = vaddq_f32(acc_f32[0], vcombine_f32(w_01, w_23));
1453 acc_f32[1] = vaddq_f32(acc_f32[1], vcombine_f32(w_45, w_67));
1454 } // for b
1455
1456 int base = x * ncols_interleaved;
1457 vst1q_f32(s + base, acc_f32[0]);
1458 vst1q_f32(s + base + 4, acc_f32[1]);
1459 } // for x
1460 return;
1461#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
1462 ggml_gemv_q6_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
1463}
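
// Reference for the NEON kernel above (scalar, per output column c):
//   s[c] = sum_b d6[b][c] * d8[b]
//              * sum_k scale[b][k/16][c] * (q6[b][k][c] - 32) * q8[b][k]
// where the -32 term is folded out via the bsums, as done in the bias step.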
1464
1465void ggml_gemv_q8_0_4x4_q8_0(int n,
1466 float * GGML_RESTRICT s,
1467 size_t bs,
1468 const void * GGML_RESTRICT vx,
1469 const void * GGML_RESTRICT vy,
1470 int nr,
1471 int nc) {
1472 const int qk = QK8_0;
1473 const int nb = n / qk;
1474 const int ncols_interleaved = 4;
1475 const int blocklen = 4;
1476
1477 assert(n % qk == 0);
1478 assert(nc % ncols_interleaved == 0);
1479
1480 UNUSED(nb);
1481 UNUSED(ncols_interleaved);
1482 UNUSED(blocklen);
1483
1484#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
1485 const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx;
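    // vx holds 4 interleaved columns per block (block_q8_0x4: 4 fp16 deltas
    // followed by the 4 columns' quants, interleaved 4 bytes at a time); vy is
    // a single q8_0 row. Each vdotq_laneq_s32 below dots one 4-element lane of
    // the activation against the matching 16 bytes of the 4 columns, so `ret`
    // ends up holding the 4 per-column integer sums for the block.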
1486
1487 for (int c = 0; c < nc; c += ncols_interleaved) {
1488 const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
1489 float32x4_t acc = vdupq_n_f32(0);
1490 for (int b = 0; b < nb; b++) {
1491 int8x16x4_t b_low = vld1q_s8_x4((const int8_t *) b_ptr->qs);
1492 int8x16x4_t b_high = vld1q_s8_x4((const int8_t *) b_ptr->qs + 64);
1493 float16x4_t bd = vld1_f16((const __fp16 *) b_ptr->d);
1494
1495 int8x16x2_t a = vld1q_s8_x2(a_ptr->qs);
1496 float16x4_t ad = vld1_dup_f16((const __fp16 *) &a_ptr->d);
1497
1498 int32x4_t ret = vdupq_n_s32(0);
1499
1500 ret = vdotq_laneq_s32(ret, b_low.val[0], a.val[0], 0);
1501 ret = vdotq_laneq_s32(ret, b_low.val[1], a.val[0], 1);
1502 ret = vdotq_laneq_s32(ret, b_low.val[2], a.val[0], 2);
1503 ret = vdotq_laneq_s32(ret, b_low.val[3], a.val[0], 3);
1504
1505 ret = vdotq_laneq_s32(ret, b_high.val[0], a.val[1], 0);
1506 ret = vdotq_laneq_s32(ret, b_high.val[1], a.val[1], 1);
1507 ret = vdotq_laneq_s32(ret, b_high.val[2], a.val[1], 2);
1508 ret = vdotq_laneq_s32(ret, b_high.val[3], a.val[1], 3);
1509
1510 acc = vfmaq_f32(acc, vcvtq_f32_s32(ret), vmulq_f32(vcvt_f32_f16(ad), vcvt_f32_f16(bd)));
1511 a_ptr++;
1512 b_ptr++;
1513 }
1514 vst1q_f32(s, acc);
1515 s += ncols_interleaved;
1516 }
1517 return;
1518
1519#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
1520 ggml_gemv_q8_0_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
1521}
1522
1523void ggml_gemv_q8_0_4x8_q8_0(int n,
1524 float * GGML_RESTRICT s,
1525 size_t bs,
1526 const void * GGML_RESTRICT vx,
1527 const void * GGML_RESTRICT vy,
1528 int nr,
1529 int nc) {
1530 const int qk = QK8_0;
1531 const int nb = n / qk;
1532 const int ncols_interleaved = 4;
1533 const int blocklen = 8;
1534
1535 assert(n % qk == 0);
1536 assert(nc % ncols_interleaved == 0);
1537
1538 UNUSED(nb);
1539 UNUSED(ncols_interleaved);
1540 UNUSED(blocklen);
1541
1542#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
1543 const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx;
1544
1545 for (int c = 0; c < nc; c += ncols_interleaved) {
1546 const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
1547 float32x4_t acc = vdupq_n_f32(0);
1548
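        // With blocklen 8 the columns are interleaved 8 bytes at a time, so
        // each 8-byte activation chunk is duplicated across a full register:
        // ret0 accumulates columns 0-1, ret1 columns 2-3, and the final
        // vpaddq_s32 folds them into a single [c0, c1, c2, c3] vector.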
1549 for (int b = 0; b < nb; b++) {
1550 int8x16x4_t b_low = vld1q_s8_x4((const int8_t *) b_ptr->qs);
1551 int8x16x4_t b_high = vld1q_s8_x4((const int8_t *) b_ptr->qs + 64);
1552 float16x4_t bd = vld1_f16((const __fp16 *) b_ptr->d);
1553
1554 int8x8x4_t a_chunks = vld1_s8_x4(a_ptr->qs);
1555 int8x16_t a0 = vcombine_s8(a_chunks.val[0], a_chunks.val[0]);
1556 int8x16_t a1 = vcombine_s8(a_chunks.val[1], a_chunks.val[1]);
1557 int8x16_t a2 = vcombine_s8(a_chunks.val[2], a_chunks.val[2]);
1558 int8x16_t a3 = vcombine_s8(a_chunks.val[3], a_chunks.val[3]);
1559 float16x4_t ad = vld1_dup_f16((const __fp16 *) &a_ptr->d);
1560
1561 int32x4_t ret0 = vdupq_n_s32(0);
1562 int32x4_t ret1 = vdupq_n_s32(0);
1563
1564 // 0..7
1565 ret0 = vdotq_s32(ret0, b_low.val[0], a0);
1566 ret1 = vdotq_s32(ret1, b_low.val[1], a0);
1567 // 8..15
1568 ret0 = vdotq_s32(ret0, b_low.val[2], a1);
1569 ret1 = vdotq_s32(ret1, b_low.val[3], a1);
1570 // 16..23
1571 ret0 = vdotq_s32(ret0, b_high.val[0], a2);
1572 ret1 = vdotq_s32(ret1, b_high.val[1], a2);
1573 // 24..31
1574 ret0 = vdotq_s32(ret0, b_high.val[2], a3);
1575 ret1 = vdotq_s32(ret1, b_high.val[3], a3);
1576
1577 int32x4_t ret = vpaddq_s32(ret0, ret1);
1578
1579 acc = vfmaq_f32(acc, vcvtq_f32_s32(ret), vmulq_f32(vcvt_f32_f16(ad), vcvt_f32_f16(bd)));
1580 a_ptr++;
1581 b_ptr++;
1582 }
1583 vst1q_f32(s, acc);
1584 s += ncols_interleaved;
1585 }
1586 return;
1587
1588#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
1589 ggml_gemv_q8_0_4x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
1590}
1591
1592void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
1593 const int qk = QK8_0;
1594 const int nb = n / qk;
1595 const int ncols_interleaved = 4;
1596 const int blocklen = 4;
1597
    assert(n % qk == 0);
    assert(nr % 4 == 0);
    assert(nc % ncols_interleaved == 0);
1601
1602 UNUSED(s);
1603 UNUSED(bs);
1604 UNUSED(vx);
1605 UNUSED(vy);
1606 UNUSED(nr);
1607 UNUSED(nc);
1608 UNUSED(nb);
1609 UNUSED(ncols_interleaved);
1610 UNUSED(blocklen);
1611
1612#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
1613 const void * b_ptr = vx;
1614 const void * a_ptr = vy;
1615 float * res_ptr = s;
1616 size_t res_stride = bs * sizeof(float);
1617
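    // Hand-written kernel (structure inferred from the code): x9 = nb * 0x88
    // is the byte stride of one interleaved activation row (sizeof(block_q8_0x4)
    // == 136). The main path handles 16 rows x 4 columns per iteration, with a
    // 4-row tail loop at label 5. Weight nibbles are shifted left by 4 before
    // each sdot, and `scvtf ..., #0x4` converts with 4 fractional bits, i.e.
    // divides the accumulator by 16 to compensate.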
1618 __asm__ __volatile__(
1619 "mov x10, %x[nr]\n"
1620 "mov x9, #0x88\n"
1621 "cmp x10, #0x10\n"
1622 "mul x9, %x[nb], x9\n"
1623 "blt 4f\n"
1624 "1:" // Row loop
1625 "add x28, %x[b_ptr], #0x8\n"
1626 "mov x27, %x[nc]\n"
1627 "add x26, %x[res_ptr], %x[res_stride], LSL #4\n"
1628 "2:" // Column loop
1629 "add x25, %x[a_ptr], #0x8\n"
1630 "movi v15.16b, #0x0\n"
1631 "movi v19.16b, #0x0\n"
1632 "mov x24, %x[nb]\n"
1633 "add x23, x25, x9\n"
1634 "movi v18.16b, #0x0\n"
1635 "movi v14.16b, #0x0\n"
1636 "add x22, x23, x9\n"
1637 "movi v11.16b, #0x0\n"
1638 "movi v13.16b, #0x0\n"
1639 "add x21, x22, x9\n"
1640 "movi v23.16b, #0x0\n"
1641 "movi v16.16b, #0x0\n"
1642 "movi v25.16b, #0x0\n"
1643 "movi v7.16b, #0x0\n"
1644 "movi v0.16b, #0x0\n"
1645 "movi v4.16b, #0x0\n"
1646 "movi v5.16b, #0x0\n"
1647 "movi v21.16b, #0x0\n"
1648 "movi v8.16b, #0x0\n"
1649 "movi v1.16b, #0x0\n"
1650 "3:" // Block loop
1651 "ldr q3, [x28, #0x0]\n"
1652 "ldr q31, [x25, #0x0]\n"
1653 "movi v28.16b, #0x4\n"
1654 "movi v10.4s, #0x0\n"
1655 "ldr q22, [x28, #0x10]\n"
1656 "ldr q6, [x25, #0x10]\n"
1657 "movi v29.4s, #0x0\n"
1658 "movi v9.4s, #0x0\n"
1659 "ldr q27, [x28, #0x20]\n"
1660 "ldr q30, [x28, #0x30]\n"
1661 "movi v20.4s, #0x0\n"
1662 "movi v24.16b, #0xf0\n"
1663 "ldr d2, [x25, #-0x8]\n"
1664 "ldr d26, [x23, #-0x8]\n"
1665 "sshl v12.16b, v3.16b, v28.16b\n"
1666 "sub x20, x28, #0x8\n"
1667 "ldr d17, [x20, #0x0]\n"
1668 "and v3.16b, v3.16b, v24.16b\n"
1669 "subs x24, x24, #0x1\n"
1670 "add x28, x28, #0x48\n"
1671 ".inst 0x4f9fe18a // sdot v10.4s, v12.16b, v31.4b[0]\n"
1672 ".inst 0x4fbfe19d // sdot v29.4s, v12.16b, v31.4b[1]\n"
1673 ".inst 0x4f9fe989 // sdot v9.4s, v12.16b, v31.4b[2]\n"
1674 ".inst 0x4fbfe994 // sdot v20.4s, v12.16b, v31.4b[3]\n"
1675 "sshl v31.16b, v22.16b, v28.16b\n"
1676 "and v22.16b, v22.16b, v24.16b\n"
1677 "fcvtl v17.4s, v17.4h\n"
1678 "fcvtl v2.4s, v2.4h\n"
1679 "fcvtl v26.4s, v26.4h\n"
1680 ".inst 0x4f86e3ea // sdot v10.4s, v31.16b, v6.4b[0]\n"
1681 ".inst 0x4fa6e3fd // sdot v29.4s, v31.16b, v6.4b[1]\n"
1682 ".inst 0x4f86ebe9 // sdot v9.4s, v31.16b, v6.4b[2]\n"
1683 ".inst 0x4fa6ebf4 // sdot v20.4s, v31.16b, v6.4b[3]\n"
1684 "sshl v6.16b, v27.16b, v28.16b\n"
1685 "sshl v28.16b, v30.16b, v28.16b\n"
1686 "and v27.16b, v27.16b, v24.16b\n"
1687 "and v30.16b, v30.16b, v24.16b\n"
1688 "ldr q24, [x25, #0x20]\n"
1689 ".inst 0x4f98e0ca // sdot v10.4s, v6.16b, v24.4b[0]\n"
1690 ".inst 0x4fb8e0dd // sdot v29.4s, v6.16b, v24.4b[1]\n"
1691 ".inst 0x4f98e8c9 // sdot v9.4s, v6.16b, v24.4b[2]\n"
1692 ".inst 0x4fb8e8d4 // sdot v20.4s, v6.16b, v24.4b[3]\n"
1693 "ldr q24, [x25, #0x30]\n"
1694 ".inst 0x4f98e38a // sdot v10.4s, v28.16b, v24.4b[0]\n"
1695 ".inst 0x4fb8e39d // sdot v29.4s, v28.16b, v24.4b[1]\n"
1696 ".inst 0x4f98eb89 // sdot v9.4s, v28.16b, v24.4b[2]\n"
1697 ".inst 0x4fb8eb94 // sdot v20.4s, v28.16b, v24.4b[3]\n"
1698 "ldr q24, [x25, #0x40]\n"
1699 ".inst 0x4f98e06a // sdot v10.4s, v3.16b, v24.4b[0]\n"
1700 ".inst 0x4fb8e07d // sdot v29.4s, v3.16b, v24.4b[1]\n"
1701 ".inst 0x4f98e869 // sdot v9.4s, v3.16b, v24.4b[2]\n"
1702 ".inst 0x4fb8e874 // sdot v20.4s, v3.16b, v24.4b[3]\n"
1703 "ldr q24, [x25, #0x50]\n"
1704 ".inst 0x4f98e2ca // sdot v10.4s, v22.16b, v24.4b[0]\n"
1705 ".inst 0x4fb8e2dd // sdot v29.4s, v22.16b, v24.4b[1]\n"
1706 ".inst 0x4f98eac9 // sdot v9.4s, v22.16b, v24.4b[2]\n"
1707 ".inst 0x4fb8ead4 // sdot v20.4s, v22.16b, v24.4b[3]\n"
1708 "ldr q24, [x25, #0x60]\n"
1709 ".inst 0x4f98e36a // sdot v10.4s, v27.16b, v24.4b[0]\n"
1710 ".inst 0x4fb8e37d // sdot v29.4s, v27.16b, v24.4b[1]\n"
1711 ".inst 0x4f98eb69 // sdot v9.4s, v27.16b, v24.4b[2]\n"
1712 ".inst 0x4fb8eb74 // sdot v20.4s, v27.16b, v24.4b[3]\n"
1713 "ldr q24, [x25, #0x70]\n"
1714 "add x25, x25, #0x88\n"
1715 ".inst 0x4f98e3ca // sdot v10.4s, v30.16b, v24.4b[0]\n"
1716 ".inst 0x4fb8e3dd // sdot v29.4s, v30.16b, v24.4b[1]\n"
1717 ".inst 0x4f98ebc9 // sdot v9.4s, v30.16b, v24.4b[2]\n"
1718 ".inst 0x4fb8ebd4 // sdot v20.4s, v30.16b, v24.4b[3]\n"
1719 "fmul v24.4s, v17.4s, v2.s[0]\n"
1720 "scvtf v10.4s, v10.4s, #0x4\n"
1721 "scvtf v29.4s, v29.4s, #0x4\n"
1722 "scvtf v9.4s, v9.4s, #0x4\n"
1723 "scvtf v20.4s, v20.4s, #0x4\n"
1724 "fmla v15.4s, v10.4s, v24.4s\n"
1725 "ldr q24, [x23, #0x0]\n"
1726 "fmul v10.4s, v17.4s, v2.s[1]\n"
1727 "fmla v19.4s, v29.4s, v10.4s\n"
1728 "ldr q10, [x23, #0x10]\n"
1729 "fmul v29.4s, v17.4s, v2.s[2]\n"
1730 "fmul v2.4s, v17.4s, v2.s[3]\n"
1731 "fmla v18.4s, v9.4s, v29.4s\n"
1732 "movi v9.4s, #0x0\n"
1733 "movi v29.4s, #0x0\n"
1734 ".inst 0x4f98e189 // sdot v9.4s, v12.16b, v24.4b[0]\n"
1735 ".inst 0x4fb8e19d // sdot v29.4s, v12.16b, v24.4b[1]\n"
1736 "fmla v14.4s, v20.4s, v2.4s\n"
1737 "movi v20.4s, #0x0\n"
1738 "movi v2.4s, #0x0\n"
1739 ".inst 0x4f98e994 // sdot v20.4s, v12.16b, v24.4b[2]\n"
1740 ".inst 0x4fb8e982 // sdot v2.4s, v12.16b, v24.4b[3]\n"
1741 "ldr q24, [x23, #0x20]\n"
1742 ".inst 0x4f8ae3e9 // sdot v9.4s, v31.16b, v10.4b[0]\n"
1743 ".inst 0x4faae3fd // sdot v29.4s, v31.16b, v10.4b[1]\n"
1744 ".inst 0x4f8aebf4 // sdot v20.4s, v31.16b, v10.4b[2]\n"
1745 ".inst 0x4faaebe2 // sdot v2.4s, v31.16b, v10.4b[3]\n"
1746 "ldr q10, [x23, #0x30]\n"
1747 ".inst 0x4f98e0c9 // sdot v9.4s, v6.16b, v24.4b[0]\n"
1748 ".inst 0x4fb8e0dd // sdot v29.4s, v6.16b, v24.4b[1]\n"
1749 ".inst 0x4f98e8d4 // sdot v20.4s, v6.16b, v24.4b[2]\n"
1750 ".inst 0x4fb8e8c2 // sdot v2.4s, v6.16b, v24.4b[3]\n"
1751 "ldr q24, [x23, #0x40]\n"
1752 ".inst 0x4f8ae389 // sdot v9.4s, v28.16b, v10.4b[0]\n"
1753 ".inst 0x4faae39d // sdot v29.4s, v28.16b, v10.4b[1]\n"
1754 ".inst 0x4f8aeb94 // sdot v20.4s, v28.16b, v10.4b[2]\n"
1755 ".inst 0x4faaeb82 // sdot v2.4s, v28.16b, v10.4b[3]\n"
1756 "ldr q10, [x23, #0x50]\n"
1757 ".inst 0x4f98e069 // sdot v9.4s, v3.16b, v24.4b[0]\n"
1758 ".inst 0x4fb8e07d // sdot v29.4s, v3.16b, v24.4b[1]\n"
1759 ".inst 0x4f98e874 // sdot v20.4s, v3.16b, v24.4b[2]\n"
1760 ".inst 0x4fb8e862 // sdot v2.4s, v3.16b, v24.4b[3]\n"
1761 "ldr q24, [x23, #0x60]\n"
1762 ".inst 0x4f8ae2c9 // sdot v9.4s, v22.16b, v10.4b[0]\n"
1763 ".inst 0x4faae2dd // sdot v29.4s, v22.16b, v10.4b[1]\n"
1764 ".inst 0x4f8aead4 // sdot v20.4s, v22.16b, v10.4b[2]\n"
1765 ".inst 0x4faaeac2 // sdot v2.4s, v22.16b, v10.4b[3]\n"
1766 "ldr q10, [x23, #0x70]\n"
1767 "add x23, x23, #0x88\n"
1768 ".inst 0x4f98e369 // sdot v9.4s, v27.16b, v24.4b[0]\n"
1769 ".inst 0x4fb8e37d // sdot v29.4s, v27.16b, v24.4b[1]\n"
1770 ".inst 0x4f98eb74 // sdot v20.4s, v27.16b, v24.4b[2]\n"
1771 ".inst 0x4fb8eb62 // sdot v2.4s, v27.16b, v24.4b[3]\n"
1772 "ldr q24, [x22, #0x0]\n"
1773 ".inst 0x4f8ae3c9 // sdot v9.4s, v30.16b, v10.4b[0]\n"
1774 ".inst 0x4faae3dd // sdot v29.4s, v30.16b, v10.4b[1]\n"
1775 ".inst 0x4f8aebd4 // sdot v20.4s, v30.16b, v10.4b[2]\n"
1776 ".inst 0x4faaebc2 // sdot v2.4s, v30.16b, v10.4b[3]\n"
1777 "fmul v10.4s, v17.4s, v26.s[0]\n"
1778 "scvtf v9.4s, v9.4s, #0x4\n"
1779 "scvtf v29.4s, v29.4s, #0x4\n"
1780 "scvtf v20.4s, v20.4s, #0x4\n"
1781 "scvtf v2.4s, v2.4s, #0x4\n"
1782 "fmla v11.4s, v9.4s, v10.4s\n"
1783 "ldr q9, [x22, #0x10]\n"
1784 "fmul v10.4s, v17.4s, v26.s[1]\n"
1785 "fmla v13.4s, v29.4s, v10.4s\n"
1786 "ldr d29, [x22, #-0x8]\n"
1787 "fmul v10.4s, v17.4s, v26.s[2]\n"
1788 "fmul v26.4s, v17.4s, v26.s[3]\n"
1789 "fcvtl v29.4s, v29.4h\n"
1790 "fmla v23.4s, v20.4s, v10.4s\n"
1791 "movi v20.4s, #0x0\n"
1792 "movi v10.4s, #0x0\n"
1793 "fmla v16.4s, v2.4s, v26.4s\n"
1794 "movi v26.4s, #0x0\n"
1795 "movi v2.4s, #0x0\n"
1796 ".inst 0x4f98e194 // sdot v20.4s, v12.16b, v24.4b[0]\n"
1797 ".inst 0x4fb8e18a // sdot v10.4s, v12.16b, v24.4b[1]\n"
1798 ".inst 0x4f98e99a // sdot v26.4s, v12.16b, v24.4b[2]\n"
1799 ".inst 0x4fb8e982 // sdot v2.4s, v12.16b, v24.4b[3]\n"
1800 "ldr q24, [x22, #0x20]\n"
1801 ".inst 0x4f89e3f4 // sdot v20.4s, v31.16b, v9.4b[0]\n"
1802 ".inst 0x4fa9e3ea // sdot v10.4s, v31.16b, v9.4b[1]\n"
1803 ".inst 0x4f89ebfa // sdot v26.4s, v31.16b, v9.4b[2]\n"
1804 ".inst 0x4fa9ebe2 // sdot v2.4s, v31.16b, v9.4b[3]\n"
1805 "ldr q9, [x22, #0x30]\n"
1806 ".inst 0x4f98e0d4 // sdot v20.4s, v6.16b, v24.4b[0]\n"
1807 ".inst 0x4fb8e0ca // sdot v10.4s, v6.16b, v24.4b[1]\n"
1808 ".inst 0x4f98e8da // sdot v26.4s, v6.16b, v24.4b[2]\n"
1809 ".inst 0x4fb8e8c2 // sdot v2.4s, v6.16b, v24.4b[3]\n"
1810 "ldr q24, [x22, #0x40]\n"
1811 ".inst 0x4f89e394 // sdot v20.4s, v28.16b, v9.4b[0]\n"
1812 ".inst 0x4fa9e38a // sdot v10.4s, v28.16b, v9.4b[1]\n"
1813 ".inst 0x4f89eb9a // sdot v26.4s, v28.16b, v9.4b[2]\n"
1814 ".inst 0x4fa9eb82 // sdot v2.4s, v28.16b, v9.4b[3]\n"
1815 "ldr q9, [x22, #0x50]\n"
1816 ".inst 0x4f98e074 // sdot v20.4s, v3.16b, v24.4b[0]\n"
1817 ".inst 0x4fb8e06a // sdot v10.4s, v3.16b, v24.4b[1]\n"
1818 ".inst 0x4f98e87a // sdot v26.4s, v3.16b, v24.4b[2]\n"
1819 ".inst 0x4fb8e862 // sdot v2.4s, v3.16b, v24.4b[3]\n"
1820 "ldr q24, [x22, #0x60]\n"
1821 ".inst 0x4f89e2d4 // sdot v20.4s, v22.16b, v9.4b[0]\n"
1822 ".inst 0x4fa9e2ca // sdot v10.4s, v22.16b, v9.4b[1]\n"
1823 ".inst 0x4f89eada // sdot v26.4s, v22.16b, v9.4b[2]\n"
1824 ".inst 0x4fa9eac2 // sdot v2.4s, v22.16b, v9.4b[3]\n"
1825 "ldr q9, [x22, #0x70]\n"
1826 "add x22, x22, #0x88\n"
1827 ".inst 0x4f98e374 // sdot v20.4s, v27.16b, v24.4b[0]\n"
1828 ".inst 0x4fb8e36a // sdot v10.4s, v27.16b, v24.4b[1]\n"
1829 ".inst 0x4f98eb7a // sdot v26.4s, v27.16b, v24.4b[2]\n"
1830 ".inst 0x4fb8eb62 // sdot v2.4s, v27.16b, v24.4b[3]\n"
1831 "ldr q24, [x21, #0x0]\n"
1832 ".inst 0x4f89e3d4 // sdot v20.4s, v30.16b, v9.4b[0]\n"
1833 ".inst 0x4fa9e3ca // sdot v10.4s, v30.16b, v9.4b[1]\n"
1834 ".inst 0x4f89ebda // sdot v26.4s, v30.16b, v9.4b[2]\n"
1835 ".inst 0x4fa9ebc2 // sdot v2.4s, v30.16b, v9.4b[3]\n"
1836 "fmul v9.4s, v17.4s, v29.s[0]\n"
1837 "scvtf v20.4s, v20.4s, #0x4\n"
1838 "scvtf v10.4s, v10.4s, #0x4\n"
1839 "scvtf v26.4s, v26.4s, #0x4\n"
1840 "scvtf v2.4s, v2.4s, #0x4\n"
1841 "fmla v25.4s, v20.4s, v9.4s\n"
1842 "ldr q9, [x21, #0x10]\n"
1843 "fmul v20.4s, v17.4s, v29.s[1]\n"
1844 "fmla v7.4s, v10.4s, v20.4s\n"
1845 "ldr d20, [x21, #-0x8]\n"
1846 "fmul v10.4s, v17.4s, v29.s[2]\n"
1847 "fmul v29.4s, v17.4s, v29.s[3]\n"
1848 "fcvtl v20.4s, v20.4h\n"
1849 "fmla v0.4s, v26.4s, v10.4s\n"
1850 "movi v26.4s, #0x0\n"
1851 "movi v10.4s, #0x0\n"
1852 "fmla v4.4s, v2.4s, v29.4s\n"
1853 "movi v2.4s, #0x0\n"
1854 "movi v29.4s, #0x0\n"
1855 ".inst 0x4f98e19a // sdot v26.4s, v12.16b, v24.4b[0]\n"
1856 ".inst 0x4fb8e18a // sdot v10.4s, v12.16b, v24.4b[1]\n"
1857 ".inst 0x4f98e982 // sdot v2.4s, v12.16b, v24.4b[2]\n"
1858 ".inst 0x4fb8e99d // sdot v29.4s, v12.16b, v24.4b[3]\n"
1859 "ldr q12, [x21, #0x20]\n"
1860 "fmul v24.4s, v17.4s, v20.s[0]\n"
1861 ".inst 0x4f89e3fa // sdot v26.4s, v31.16b, v9.4b[0]\n"
1862 ".inst 0x4fa9e3ea // sdot v10.4s, v31.16b, v9.4b[1]\n"
1863 ".inst 0x4f89ebe2 // sdot v2.4s, v31.16b, v9.4b[2]\n"
1864 ".inst 0x4fa9ebfd // sdot v29.4s, v31.16b, v9.4b[3]\n"
1865 "ldr q9, [x21, #0x30]\n"
1866 "fmul v31.4s, v17.4s, v20.s[1]\n"
1867 ".inst 0x4f8ce0da // sdot v26.4s, v6.16b, v12.4b[0]\n"
1868 ".inst 0x4face0ca // sdot v10.4s, v6.16b, v12.4b[1]\n"
1869 ".inst 0x4f8ce8c2 // sdot v2.4s, v6.16b, v12.4b[2]\n"
1870 ".inst 0x4face8dd // sdot v29.4s, v6.16b, v12.4b[3]\n"
1871 "ldr q12, [x21, #0x40]\n"
1872 "fmul v6.4s, v17.4s, v20.s[2]\n"
1873 "fmul v20.4s, v17.4s, v20.s[3]\n"
1874 ".inst 0x4f89e39a // sdot v26.4s, v28.16b, v9.4b[0]\n"
1875 ".inst 0x4fa9e38a // sdot v10.4s, v28.16b, v9.4b[1]\n"
1876 ".inst 0x4f89eb82 // sdot v2.4s, v28.16b, v9.4b[2]\n"
1877 ".inst 0x4fa9eb9d // sdot v29.4s, v28.16b, v9.4b[3]\n"
1878 "ldr q9, [x21, #0x50]\n"
1879 ".inst 0x4f8ce07a // sdot v26.4s, v3.16b, v12.4b[0]\n"
1880 ".inst 0x4face06a // sdot v10.4s, v3.16b, v12.4b[1]\n"
1881 ".inst 0x4f8ce862 // sdot v2.4s, v3.16b, v12.4b[2]\n"
1882 ".inst 0x4face87d // sdot v29.4s, v3.16b, v12.4b[3]\n"
1883 "ldr q12, [x21, #0x60]\n"
1884 ".inst 0x4f89e2da // sdot v26.4s, v22.16b, v9.4b[0]\n"
1885 ".inst 0x4fa9e2ca // sdot v10.4s, v22.16b, v9.4b[1]\n"
1886 ".inst 0x4f89eac2 // sdot v2.4s, v22.16b, v9.4b[2]\n"
1887 ".inst 0x4fa9eadd // sdot v29.4s, v22.16b, v9.4b[3]\n"
1888 "ldr q17, [x21, #0x70]\n"
1889 "add x21, x21, #0x88\n"
1890 ".inst 0x4f8ce37a // sdot v26.4s, v27.16b, v12.4b[0]\n"
1891 ".inst 0x4face36a // sdot v10.4s, v27.16b, v12.4b[1]\n"
1892 ".inst 0x4f8ceb62 // sdot v2.4s, v27.16b, v12.4b[2]\n"
1893 ".inst 0x4faceb7d // sdot v29.4s, v27.16b, v12.4b[3]\n"
1894 ".inst 0x4f91e3da // sdot v26.4s, v30.16b, v17.4b[0]\n"
1895 ".inst 0x4fb1e3ca // sdot v10.4s, v30.16b, v17.4b[1]\n"
1896 ".inst 0x4f91ebc2 // sdot v2.4s, v30.16b, v17.4b[2]\n"
1897 ".inst 0x4fb1ebdd // sdot v29.4s, v30.16b, v17.4b[3]\n"
1898 "scvtf v26.4s, v26.4s, #0x4\n"
1899 "scvtf v10.4s, v10.4s, #0x4\n"
1900 "fmla v5.4s, v26.4s, v24.4s\n"
1901 "scvtf v2.4s, v2.4s, #0x4\n"
1902 "scvtf v29.4s, v29.4s, #0x4\n"
1903 "fmla v21.4s, v10.4s, v31.4s\n"
1904 "fmla v8.4s, v2.4s, v6.4s\n"
1905 "fmla v1.4s, v29.4s, v20.4s\n"
1906 "bgt 3b\n"
1907 "mov x20, %x[res_ptr]\n"
1908 "subs x27, x27, #0x4\n"
1909 "add %x[res_ptr], %x[res_ptr], #0x10\n"
1910 "str q15, [x20, #0x0]\n"
1911 "add x20, x20, %x[res_stride]\n"
1912 "str q19, [x20, #0x0]\n"
1913 "add x20, x20, %x[res_stride]\n"
1914 "str q18, [x20, #0x0]\n"
1915 "add x20, x20, %x[res_stride]\n"
1916 "str q14, [x20, #0x0]\n"
1917 "add x20, x20, %x[res_stride]\n"
1918 "str q11, [x20, #0x0]\n"
1919 "add x20, x20, %x[res_stride]\n"
1920 "str q13, [x20, #0x0]\n"
1921 "add x20, x20, %x[res_stride]\n"
1922 "str q23, [x20, #0x0]\n"
1923 "add x20, x20, %x[res_stride]\n"
1924 "str q16, [x20, #0x0]\n"
1925 "add x20, x20, %x[res_stride]\n"
1926 "str q25, [x20, #0x0]\n"
1927 "add x20, x20, %x[res_stride]\n"
1928 "str q7, [x20, #0x0]\n"
1929 "add x20, x20, %x[res_stride]\n"
1930 "str q0, [x20, #0x0]\n"
1931 "add x20, x20, %x[res_stride]\n"
1932 "str q4, [x20, #0x0]\n"
1933 "add x20, x20, %x[res_stride]\n"
1934 "str q5, [x20, #0x0]\n"
1935 "add x20, x20, %x[res_stride]\n"
1936 "str q21, [x20, #0x0]\n"
1937 "add x20, x20, %x[res_stride]\n"
1938 "str q8, [x20, #0x0]\n"
1939 "add x20, x20, %x[res_stride]\n"
1940 "str q1, [x20, #0x0]\n"
1941 "bne 2b\n"
1942 "mov x20, #0x4\n"
1943 "sub x10, x10, #0x10\n"
1944 "cmp x10, #0x10\n"
1945 "mov %x[res_ptr], x26\n"
1946 "madd %x[a_ptr], x20, x9, %x[a_ptr]\n"
1947 "bge 1b\n"
1948 "4:" // Row loop skip
1949 "cbz x10, 9f\n"
1950 "5:" // Row tail: Row loop
1951 "add x24, %x[b_ptr], #0x8\n"
1952 "mov x23, %x[nc]\n"
1953 "add x22, %x[res_ptr], %x[res_stride], LSL #2\n"
1954 "6:" // Row tail: Column loop
1955 "movi v15.16b, #0x0\n"
1956 "movi v19.16b, #0x0\n"
1957 "add x25, %x[a_ptr], #0x8\n"
1958 "mov x21, %x[nb]\n"
1959 "movi v18.16b, #0x0\n"
1960 "movi v14.16b, #0x0\n"
1961 "7:" // Row tail: Block loop
1962 "ldr q7, [x24, #0x0]\n"
1963 "ldr q5, [x25, #0x0]\n"
1964 "movi v9.16b, #0x4\n"
1965 "movi v4.4s, #0x0\n"
1966 "ldr q3, [x24, #0x10]\n"
1967 "ldr q2, [x25, #0x10]\n"
1968 "movi v1.4s, #0x0\n"
1969 "movi v0.4s, #0x0\n"
1970 "ldr q13, [x24, #0x20]\n"
1971 "ldr q31, [x25, #0x20]\n"
1972 "movi v30.4s, #0x0\n"
1973 "movi v29.16b, #0xf0\n"
1974 "ldr q28, [x24, #0x30]\n"
1975 "ldr q27, [x25, #0x30]\n"
1976 "sshl v20.16b, v7.16b, v9.16b\n"
1977 "sub x20, x24, #0x8\n"
1978 "ldr q26, [x25, #0x40]\n"
1979 "ldr q25, [x25, #0x50]\n"
1980 "sshl v17.16b, v3.16b, v9.16b\n"
1981 "and v7.16b, v7.16b, v29.16b\n"
1982 "ldr q24, [x25, #0x60]\n"
1983 "ldr q16, [x25, #0x70]\n"
1984 "sshl v22.16b, v13.16b, v9.16b\n"
1985 "and v3.16b, v3.16b, v29.16b\n"
1986 "ldr d21, [x20, #0x0]\n"
1987 "ldr d12, [x25, #-0x8]\n"
1988 ".inst 0x4f85e284 // sdot v4.4s, v20.16b, v5.4b[0]\n"
1989 ".inst 0x4fa5e281 // sdot v1.4s, v20.16b, v5.4b[1]\n"
1990 ".inst 0x4f85ea80 // sdot v0.4s, v20.16b, v5.4b[2]\n"
1991 ".inst 0x4fa5ea9e // sdot v30.4s, v20.16b, v5.4b[3]\n"
1992 "sshl v9.16b, v28.16b, v9.16b\n"
1993 "subs x21, x21, #0x1\n"
1994 "and v13.16b, v13.16b, v29.16b\n"
1995 "and v28.16b, v28.16b, v29.16b\n"
1996 "add x25, x25, #0x88\n"
1997 "add x24, x24, #0x48\n"
1998 "fcvtl v21.4s, v21.4h\n"
1999 "fcvtl v12.4s, v12.4h\n"
2000 ".inst 0x4f82e224 // sdot v4.4s, v17.16b, v2.4b[0]\n"
2001 ".inst 0x4fa2e221 // sdot v1.4s, v17.16b, v2.4b[1]\n"
2002 ".inst 0x4f82ea20 // sdot v0.4s, v17.16b, v2.4b[2]\n"
2003 ".inst 0x4fa2ea3e // sdot v30.4s, v17.16b, v2.4b[3]\n"
2004 "fmul v11.4s, v21.4s, v12.s[0]\n"
2005 "fmul v23.4s, v21.4s, v12.s[1]\n"
2006 "fmul v17.4s, v21.4s, v12.s[2]\n"
2007 ".inst 0x4f9fe2c4 // sdot v4.4s, v22.16b, v31.4b[0]\n"
2008 "fmul v6.4s, v21.4s, v12.s[3]\n"
2009 ".inst 0x4fbfe2c1 // sdot v1.4s, v22.16b, v31.4b[1]\n"
2010 ".inst 0x4f9feac0 // sdot v0.4s, v22.16b, v31.4b[2]\n"
2011 ".inst 0x4fbfeade // sdot v30.4s, v22.16b, v31.4b[3]\n"
2012 ".inst 0x4f9be124 // sdot v4.4s, v9.16b, v27.4b[0]\n"
2013 ".inst 0x4fbbe121 // sdot v1.4s, v9.16b, v27.4b[1]\n"
2014 ".inst 0x4f9be920 // sdot v0.4s, v9.16b, v27.4b[2]\n"
2015 ".inst 0x4fbbe93e // sdot v30.4s, v9.16b, v27.4b[3]\n"
2016 ".inst 0x4f9ae0e4 // sdot v4.4s, v7.16b, v26.4b[0]\n"
2017 ".inst 0x4fbae0e1 // sdot v1.4s, v7.16b, v26.4b[1]\n"
2018 ".inst 0x4f9ae8e0 // sdot v0.4s, v7.16b, v26.4b[2]\n"
2019 ".inst 0x4fbae8fe // sdot v30.4s, v7.16b, v26.4b[3]\n"
2020 ".inst 0x4f99e064 // sdot v4.4s, v3.16b, v25.4b[0]\n"
2021 ".inst 0x4fb9e061 // sdot v1.4s, v3.16b, v25.4b[1]\n"
2022 ".inst 0x4f99e860 // sdot v0.4s, v3.16b, v25.4b[2]\n"
2023 ".inst 0x4fb9e87e // sdot v30.4s, v3.16b, v25.4b[3]\n"
2024 ".inst 0x4f98e1a4 // sdot v4.4s, v13.16b, v24.4b[0]\n"
2025 ".inst 0x4fb8e1a1 // sdot v1.4s, v13.16b, v24.4b[1]\n"
2026 ".inst 0x4f98e9a0 // sdot v0.4s, v13.16b, v24.4b[2]\n"
2027 ".inst 0x4fb8e9be // sdot v30.4s, v13.16b, v24.4b[3]\n"
2028 ".inst 0x4f90e384 // sdot v4.4s, v28.16b, v16.4b[0]\n"
2029 ".inst 0x4fb0e381 // sdot v1.4s, v28.16b, v16.4b[1]\n"
2030 ".inst 0x4f90eb80 // sdot v0.4s, v28.16b, v16.4b[2]\n"
2031 ".inst 0x4fb0eb9e // sdot v30.4s, v28.16b, v16.4b[3]\n"
2032 "scvtf v4.4s, v4.4s, #0x4\n"
2033 "scvtf v1.4s, v1.4s, #0x4\n"
2034 "scvtf v0.4s, v0.4s, #0x4\n"
2035 "fmla v15.4s, v4.4s, v11.4s\n"
2036 "scvtf v30.4s, v30.4s, #0x4\n"
2037 "fmla v19.4s, v1.4s, v23.4s\n"
2038 "fmla v18.4s, v0.4s, v17.4s\n"
2039 "fmla v14.4s, v30.4s, v6.4s\n"
2040 "bgt 7b\n"
2041 "mov x20, %x[res_ptr]\n"
2042 "cmp x10, #0x1\n"
2043 "str q15, [x20, #0x0]\n"
2044 "add x20, x20, %x[res_stride]\n"
2045 "ble 8f\n"
2046 "cmp x10, #0x2\n"
2047 "str q19, [x20, #0x0]\n"
2048 "add x20, x20, %x[res_stride]\n"
2049 "ble 8f\n"
2050 "cmp x10, #0x3\n"
2051 "str q18, [x20, #0x0]\n"
2052 "add x20, x20, %x[res_stride]\n"
2053 "ble 8f\n"
2054 "str q14, [x20, #0x0]\n"
2055 "8:" // Row tail: Accumulator store skip
2056 "subs x23, x23, #0x4\n"
2057 "add %x[res_ptr], %x[res_ptr], #0x10\n"
2058 "bne 6b\n"
2059 "subs x10, x10, #0x4\n"
2060 "add %x[a_ptr], %x[a_ptr], x9\n"
2061 "mov %x[res_ptr], x22\n"
2062 "bgt 5b\n"
2063 "9:" // Row tail: Row loop skip
2064 : [a_ptr] "+&r" (a_ptr), [res_ptr] "+&r" (res_ptr)
2065 : [b_ptr] "r" (b_ptr), [nr] "r" (nr), [nb] "r" (nb), [res_stride] "r" (res_stride), [nc] "r" (nc)
2066 : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"
2067 );
2068 return;
#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
2070 ggml_gemm_q4_0_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
2071}
2072
2073void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
2074 const int qk = QK8_0;
2075 const int nb = n / qk;
2076 const int ncols_interleaved = 4;
2077 const int blocklen = 8;
2078
    assert(n % qk == 0);
    assert(nr % 4 == 0);
    assert(nc % ncols_interleaved == 0);
2082
2083 UNUSED(s);
2084 UNUSED(bs);
2085 UNUSED(vx);
2086 UNUSED(vy);
2087 UNUSED(nr);
2088 UNUSED(nc);
2089 UNUSED(nb);
2090 UNUSED(ncols_interleaved);
2091 UNUSED(blocklen);
2092
2093#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
2094 const void * b_ptr = vx;
2095 const void * a_ptr = vy;
2096 float * res_ptr = s;
2097 size_t res_stride = bs * sizeof(float);
2098
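    // Same tiling as the 4x4 kernel, but the 8-byte blocklen enables the int8
    // matrix-multiply instruction: each smmla accumulates a 2x2 int32 tile
    // (two rows x two columns), and the uzp1/uzp2 pairs de-interleave those
    // tiles back into per-row accumulators before scaling.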
2099 __asm__ __volatile__(
2100 "mov x10, %x[nr]\n"
2101 "mov x9, #0x88\n"
2102 "cmp x10, #0x10\n"
2103 "mul x9, %x[nb], x9\n"
2104 "blt 4f\n"
2105 "1:" // Row loop
2106 "add x28, %x[b_ptr], #0x8\n"
2107 "mov x27, %x[nc]\n"
2108 "add x26, %x[res_ptr], %x[res_stride], LSL #4\n"
2109 "2:" // Column loop
2110 "add x25, %x[a_ptr], #0x8\n"
2111 "movi v2.16b, #0x0\n"
2112 "movi v10.16b, #0x0\n"
2113 "mov x24, %x[nb]\n"
2114 "add x23, x25, x9\n"
2115 "movi v12.16b, #0x0\n"
2116 "movi v28.16b, #0x0\n"
2117 "add x22, x23, x9\n"
2118 "movi v11.16b, #0x0\n"
2119 "movi v13.16b, #0x0\n"
2120 "add x21, x22, x9\n"
2121 "movi v22.16b, #0x0\n"
2122 "movi v23.16b, #0x0\n"
2123 "movi v25.16b, #0x0\n"
2124 "movi v5.16b, #0x0\n"
2125 "movi v7.16b, #0x0\n"
2126 "movi v4.16b, #0x0\n"
2127 "movi v6.16b, #0x0\n"
2128 "movi v30.16b, #0x0\n"
2129 "movi v24.16b, #0x0\n"
2130 "movi v14.16b, #0x0\n"
2131 "3:" // Block loop
2132 "ldr q21, [x28, #0x0]\n"
2133 "ldr q16, [x28, #0x10]\n"
2134 "movi v1.16b, #0x4\n"
2135 "movi v19.4s, #0x0\n"
2136 "ldr q27, [x25, #0x0]\n"
2137 "ldr q15, [x25, #0x10]\n"
2138 "movi v26.4s, #0x0\n"
2139 "movi v18.4s, #0x0\n"
2140 "ldr q29, [x28, #0x20]\n"
2141 "ldr q3, [x28, #0x30]\n"
2142 "movi v17.4s, #0x0\n"
2143 "movi v0.16b, #0xf0\n"
2144 "ldr d20, [x25, #-0x8]\n"
2145 "ldr d9, [x23, #-0x8]\n"
2146 "sshl v8.16b, v21.16b, v1.16b\n"
2147 "sshl v31.16b, v16.16b, v1.16b\n"
2148 "and v21.16b, v21.16b, v0.16b\n"
2149 "and v16.16b, v16.16b, v0.16b\n"
2150 "sub x20, x28, #0x8\n"
2151 "subs x24, x24, #0x1\n"
2152 "add x28, x28, #0x48\n"
2153 ".inst 0x4e88a773 // smmla v19.4s, v27.16b, v8.16b\n"
2154 ".inst 0x4e9fa77a // smmla v26.4s, v27.16b, v31.16b\n"
2155 "ldr q27, [x25, #0x20]\n"
2156 ".inst 0x4e88a5f2 // smmla v18.4s, v15.16b, v8.16b\n"
2157 ".inst 0x4e9fa5f1 // smmla v17.4s, v15.16b, v31.16b\n"
2158 "sshl v15.16b, v29.16b, v1.16b\n"
2159 "sshl v1.16b, v3.16b, v1.16b\n"
2160 "and v29.16b, v29.16b, v0.16b\n"
2161 "and v3.16b, v3.16b, v0.16b\n"
2162 "ldr q0, [x25, #0x30]\n"
2163 "fcvtl v20.4s, v20.4h\n"
2164 ".inst 0x4e8fa773 // smmla v19.4s, v27.16b, v15.16b\n"
2165 "fcvtl v9.4s, v9.4h\n"
2166 ".inst 0x4e81a77a // smmla v26.4s, v27.16b, v1.16b\n"
2167 "ldr q27, [x25, #0x40]\n"
2168 ".inst 0x4e8fa412 // smmla v18.4s, v0.16b, v15.16b\n"
2169 ".inst 0x4e81a411 // smmla v17.4s, v0.16b, v1.16b\n"
2170 "ldr q0, [x25, #0x50]\n"
2171 ".inst 0x4e95a773 // smmla v19.4s, v27.16b, v21.16b\n"
2172 ".inst 0x4e90a77a // smmla v26.4s, v27.16b, v16.16b\n"
2173 "ldr q27, [x25, #0x60]\n"
2174 ".inst 0x4e95a412 // smmla v18.4s, v0.16b, v21.16b\n"
2175 ".inst 0x4e90a411 // smmla v17.4s, v0.16b, v16.16b\n"
2176 "ldr q0, [x25, #0x70]\n"
2177 "add x25, x25, #0x88\n"
2178 ".inst 0x4e9da773 // smmla v19.4s, v27.16b, v29.16b\n"
2179 ".inst 0x4e83a77a // smmla v26.4s, v27.16b, v3.16b\n"
2180 "ldr d27, [x20, #0x0]\n"
2181 ".inst 0x4e9da412 // smmla v18.4s, v0.16b, v29.16b\n"
2182 ".inst 0x4e83a411 // smmla v17.4s, v0.16b, v3.16b\n"
2183 "fcvtl v27.4s, v27.4h\n"
2184 "uzp1 v0.2d, v19.2d, v26.2d\n"
2185 "uzp2 v26.2d, v19.2d, v26.2d\n"
2186 "fmul v19.4s, v27.4s, v20.s[0]\n"
2187 "scvtf v0.4s, v0.4s, #0x4\n"
2188 "scvtf v26.4s, v26.4s, #0x4\n"
2189 "fmla v2.4s, v0.4s, v19.4s\n"
2190 "ldr q19, [x23, #0x0]\n"
2191 "uzp1 v0.2d, v18.2d, v17.2d\n"
2192 "uzp2 v18.2d, v18.2d, v17.2d\n"
2193 "fmul v17.4s, v27.4s, v20.s[1]\n"
2194 "scvtf v0.4s, v0.4s, #0x4\n"
2195 "scvtf v18.4s, v18.4s, #0x4\n"
2196 "fmla v10.4s, v26.4s, v17.4s\n"
2197 "ldr q17, [x23, #0x10]\n"
2198 "fmul v26.4s, v27.4s, v20.s[2]\n"
2199 "fmul v20.4s, v27.4s, v20.s[3]\n"
2200 "fmla v12.4s, v0.4s, v26.4s\n"
2201 "ldr d0, [x22, #-0x8]\n"
2202 "ldr d26, [x21, #-0x8]\n"
2203 "fcvtl v0.4s, v0.4h\n"
2204 "fmla v28.4s, v18.4s, v20.4s\n"
2205 "movi v20.4s, #0x0\n"
2206 "movi v18.4s, #0x0\n"
2207 ".inst 0x4e88a674 // smmla v20.4s, v19.16b, v8.16b\n"
2208 ".inst 0x4e9fa672 // smmla v18.4s, v19.16b, v31.16b\n"
2209 "ldr q19, [x23, #0x20]\n"
2210 "fcvtl v26.4s, v26.4h\n"
2211 ".inst 0x4e8fa674 // smmla v20.4s, v19.16b, v15.16b\n"
2212 ".inst 0x4e81a672 // smmla v18.4s, v19.16b, v1.16b\n"
2213 "ldr q19, [x23, #0x40]\n"
2214 ".inst 0x4e95a674 // smmla v20.4s, v19.16b, v21.16b\n"
2215 ".inst 0x4e90a672 // smmla v18.4s, v19.16b, v16.16b\n"
2216 "ldr q19, [x23, #0x60]\n"
2217 ".inst 0x4e9da674 // smmla v20.4s, v19.16b, v29.16b\n"
2218 ".inst 0x4e83a672 // smmla v18.4s, v19.16b, v3.16b\n"
2219 "uzp1 v19.2d, v20.2d, v18.2d\n"
2220 "scvtf v19.4s, v19.4s, #0x4\n"
2221 "uzp2 v20.2d, v20.2d, v18.2d\n"
2222 "fmul v18.4s, v27.4s, v9.s[0]\n"
2223 "scvtf v20.4s, v20.4s, #0x4\n"
2224 "fmla v11.4s, v19.4s, v18.4s\n"
2225 "ldr q18, [x22, #0x0]\n"
2226 "fmul v19.4s, v27.4s, v9.s[1]\n"
2227 "fmla v13.4s, v20.4s, v19.4s\n"
2228 "movi v19.4s, #0x0\n"
2229 "movi v20.4s, #0x0\n"
2230 ".inst 0x4e88a633 // smmla v19.4s, v17.16b, v8.16b\n"
2231 ".inst 0x4e9fa634 // smmla v20.4s, v17.16b, v31.16b\n"
2232 "ldr q17, [x23, #0x30]\n"
2233 ".inst 0x4e8fa633 // smmla v19.4s, v17.16b, v15.16b\n"
2234 ".inst 0x4e81a634 // smmla v20.4s, v17.16b, v1.16b\n"
2235 "ldr q17, [x23, #0x50]\n"
2236 ".inst 0x4e95a633 // smmla v19.4s, v17.16b, v21.16b\n"
2237 ".inst 0x4e90a634 // smmla v20.4s, v17.16b, v16.16b\n"
2238 "ldr q17, [x23, #0x70]\n"
2239 "add x23, x23, #0x88\n"
2240 ".inst 0x4e9da633 // smmla v19.4s, v17.16b, v29.16b\n"
2241 ".inst 0x4e83a634 // smmla v20.4s, v17.16b, v3.16b\n"
2242 "uzp1 v17.2d, v19.2d, v20.2d\n"
2243 "scvtf v17.4s, v17.4s, #0x4\n"
2244 "uzp2 v20.2d, v19.2d, v20.2d\n"
2245 "fmul v19.4s, v27.4s, v9.s[2]\n"
2246 "fmul v9.4s, v27.4s, v9.s[3]\n"
2247 "scvtf v20.4s, v20.4s, #0x4\n"
2248 "fmla v22.4s, v17.4s, v19.4s\n"
2249 "ldr q17, [x22, #0x10]\n"
2250 "movi v19.4s, #0x0\n"
2251 ".inst 0x4e88a653 // smmla v19.4s, v18.16b, v8.16b\n"
2252 "fmla v23.4s, v20.4s, v9.4s\n"
2253 "movi v20.4s, #0x0\n"
2254 "movi v9.4s, #0x0\n"
2255 ".inst 0x4e9fa654 // smmla v20.4s, v18.16b, v31.16b\n"
2256 "ldr q18, [x22, #0x20]\n"
2257 ".inst 0x4e88a629 // smmla v9.4s, v17.16b, v8.16b\n"
2258 ".inst 0x4e8fa653 // smmla v19.4s, v18.16b, v15.16b\n"
2259 ".inst 0x4e81a654 // smmla v20.4s, v18.16b, v1.16b\n"
2260 "ldr q18, [x22, #0x40]\n"
2261 ".inst 0x4e95a653 // smmla v19.4s, v18.16b, v21.16b\n"
2262 ".inst 0x4e90a654 // smmla v20.4s, v18.16b, v16.16b\n"
2263 "ldr q18, [x22, #0x60]\n"
2264 ".inst 0x4e9da653 // smmla v19.4s, v18.16b, v29.16b\n"
2265 ".inst 0x4e83a654 // smmla v20.4s, v18.16b, v3.16b\n"
2266 "movi v18.4s, #0x0\n"
2267 ".inst 0x4e9fa632 // smmla v18.4s, v17.16b, v31.16b\n"
2268 "ldr q17, [x22, #0x30]\n"
2269 ".inst 0x4e8fa629 // smmla v9.4s, v17.16b, v15.16b\n"
2270 ".inst 0x4e81a632 // smmla v18.4s, v17.16b, v1.16b\n"
2271 "ldr q17, [x22, #0x50]\n"
2272 ".inst 0x4e95a629 // smmla v9.4s, v17.16b, v21.16b\n"
2273 ".inst 0x4e90a632 // smmla v18.4s, v17.16b, v16.16b\n"
2274 "ldr q17, [x22, #0x70]\n"
2275 "add x22, x22, #0x88\n"
2276 ".inst 0x4e9da629 // smmla v9.4s, v17.16b, v29.16b\n"
2277 ".inst 0x4e83a632 // smmla v18.4s, v17.16b, v3.16b\n"
2278 "uzp1 v17.2d, v19.2d, v20.2d\n"
2279 "uzp2 v20.2d, v19.2d, v20.2d\n"
2280 "fmul v19.4s, v27.4s, v0.s[0]\n"
2281 "scvtf v17.4s, v17.4s, #0x4\n"
2282 "scvtf v20.4s, v20.4s, #0x4\n"
2283 "fmla v25.4s, v17.4s, v19.4s\n"
2284 "ldr q19, [x21, #0x0]\n"
2285 "fmul v17.4s, v27.4s, v0.s[1]\n"
2286 "fmla v5.4s, v20.4s, v17.4s\n"
2287 "ldr q17, [x21, #0x10]\n"
2288 "uzp1 v20.2d, v9.2d, v18.2d\n"
2289 "uzp2 v9.2d, v9.2d, v18.2d\n"
2290 "fmul v18.4s, v27.4s, v0.s[2]\n"
2291 "fmul v0.4s, v27.4s, v0.s[3]\n"
2292 "scvtf v20.4s, v20.4s, #0x4\n"
2293 "scvtf v9.4s, v9.4s, #0x4\n"
2294 "fmla v7.4s, v20.4s, v18.4s\n"
2295 "movi v20.4s, #0x0\n"
2296 "movi v18.4s, #0x0\n"
2297 ".inst 0x4e88a674 // smmla v20.4s, v19.16b, v8.16b\n"
2298 ".inst 0x4e9fa672 // smmla v18.4s, v19.16b, v31.16b\n"
2299 "ldr q19, [x21, #0x20]\n"
2300 "fmla v4.4s, v9.4s, v0.4s\n"
2301 "movi v9.4s, #0x0\n"
2302 "movi v0.4s, #0x0\n"
2303 ".inst 0x4e88a629 // smmla v9.4s, v17.16b, v8.16b\n"
2304 "fmul v8.4s, v27.4s, v26.s[0]\n"
2305 ".inst 0x4e9fa620 // smmla v0.4s, v17.16b, v31.16b\n"
2306 "ldr q17, [x21, #0x30]\n"
2307 ".inst 0x4e8fa674 // smmla v20.4s, v19.16b, v15.16b\n"
2308 "fmul v31.4s, v27.4s, v26.s[1]\n"
2309 ".inst 0x4e81a672 // smmla v18.4s, v19.16b, v1.16b\n"
2310 "ldr q19, [x21, #0x40]\n"
2311 ".inst 0x4e8fa629 // smmla v9.4s, v17.16b, v15.16b\n"
2312 "fmul v15.4s, v27.4s, v26.s[2]\n"
2313 "fmul v27.4s, v27.4s, v26.s[3]\n"
2314 ".inst 0x4e81a620 // smmla v0.4s, v17.16b, v1.16b\n"
2315 "ldr q1, [x21, #0x50]\n"
2316 ".inst 0x4e95a674 // smmla v20.4s, v19.16b, v21.16b\n"
2317 ".inst 0x4e90a672 // smmla v18.4s, v19.16b, v16.16b\n"
2318 "ldr q26, [x21, #0x60]\n"
2319 ".inst 0x4e95a429 // smmla v9.4s, v1.16b, v21.16b\n"
2320 ".inst 0x4e90a420 // smmla v0.4s, v1.16b, v16.16b\n"
2321 "ldr q21, [x21, #0x70]\n"
2322 "add x21, x21, #0x88\n"
2323 ".inst 0x4e9da754 // smmla v20.4s, v26.16b, v29.16b\n"
2324 ".inst 0x4e83a752 // smmla v18.4s, v26.16b, v3.16b\n"
2325 ".inst 0x4e9da6a9 // smmla v9.4s, v21.16b, v29.16b\n"
2326 ".inst 0x4e83a6a0 // smmla v0.4s, v21.16b, v3.16b\n"
2327 "uzp1 v29.2d, v20.2d, v18.2d\n"
2328 "uzp2 v21.2d, v20.2d, v18.2d\n"
2329 "scvtf v29.4s, v29.4s, #0x4\n"
2330 "uzp1 v18.2d, v9.2d, v0.2d\n"
2331 "uzp2 v16.2d, v9.2d, v0.2d\n"
2332 "scvtf v21.4s, v21.4s, #0x4\n"
2333 "fmla v6.4s, v29.4s, v8.4s\n"
2334 "scvtf v18.4s, v18.4s, #0x4\n"
2335 "scvtf v16.4s, v16.4s, #0x4\n"
2336 "fmla v30.4s, v21.4s, v31.4s\n"
2337 "fmla v24.4s, v18.4s, v15.4s\n"
2338 "fmla v14.4s, v16.4s, v27.4s\n"
2339 "bgt 3b\n"
2340 "mov x20, %x[res_ptr]\n"
2341 "subs x27, x27, #0x4\n"
2342 "add %x[res_ptr], %x[res_ptr], #0x10\n"
2343 "str q2, [x20, #0x0]\n"
2344 "add x20, x20, %x[res_stride]\n"
2345 "str q10, [x20, #0x0]\n"
2346 "add x20, x20, %x[res_stride]\n"
2347 "str q12, [x20, #0x0]\n"
2348 "add x20, x20, %x[res_stride]\n"
2349 "str q28, [x20, #0x0]\n"
2350 "add x20, x20, %x[res_stride]\n"
2351 "str q11, [x20, #0x0]\n"
2352 "add x20, x20, %x[res_stride]\n"
2353 "str q13, [x20, #0x0]\n"
2354 "add x20, x20, %x[res_stride]\n"
2355 "str q22, [x20, #0x0]\n"
2356 "add x20, x20, %x[res_stride]\n"
2357 "str q23, [x20, #0x0]\n"
2358 "add x20, x20, %x[res_stride]\n"
2359 "str q25, [x20, #0x0]\n"
2360 "add x20, x20, %x[res_stride]\n"
2361 "str q5, [x20, #0x0]\n"
2362 "add x20, x20, %x[res_stride]\n"
2363 "str q7, [x20, #0x0]\n"
2364 "add x20, x20, %x[res_stride]\n"
2365 "str q4, [x20, #0x0]\n"
2366 "add x20, x20, %x[res_stride]\n"
2367 "str q6, [x20, #0x0]\n"
2368 "add x20, x20, %x[res_stride]\n"
2369 "str q30, [x20, #0x0]\n"
2370 "add x20, x20, %x[res_stride]\n"
2371 "str q24, [x20, #0x0]\n"
2372 "add x20, x20, %x[res_stride]\n"
2373 "str q14, [x20, #0x0]\n"
2374 "bne 2b\n"
2375 "mov x20, #0x4\n"
2376 "sub x10, x10, #0x10\n"
2377 "cmp x10, #0x10\n"
2378 "mov %x[res_ptr], x26\n"
2379 "madd %x[a_ptr], x20, x9, %x[a_ptr]\n"
2380 "bge 1b\n"
2381 "4:" // Row loop skip
2382 "cbz x10, 9f\n"
2383 "5:" // Row tail: Row loop
2384 "add x24, %x[b_ptr], #0x8\n"
2385 "mov x23, %x[nc]\n"
2386 "add x22, %x[res_ptr], %x[res_stride], LSL #2\n"
2387 "6:" // Row tail: Column loop
2388 "movi v2.16b, #0x0\n"
2389 "movi v10.16b, #0x0\n"
2390 "add x25, %x[a_ptr], #0x8\n"
2391 "mov x21, %x[nb]\n"
2392 "movi v12.16b, #0x0\n"
2393 "movi v28.16b, #0x0\n"
2394 "7:" // Row tail: Block loop
2395 "ldr q6, [x24, #0x0]\n"
2396 "ldr q5, [x24, #0x10]\n"
2397 "movi v17.16b, #0x4\n"
2398 "movi v8.4s, #0x0\n"
2399 "ldr q4, [x25, #0x0]\n"
2400 "ldr q13, [x25, #0x10]\n"
2401 "movi v27.4s, #0x0\n"
2402 "movi v0.4s, #0x0\n"
2403 "ldr q31, [x24, #0x20]\n"
2404 "ldr q14, [x24, #0x30]\n"
2405 "movi v29.4s, #0x0\n"
2406 "movi v22.16b, #0xf0\n"
2407 "ldr q11, [x25, #0x20]\n"
2408 "ldr q23, [x25, #0x30]\n"
2409 "sshl v21.16b, v6.16b, v17.16b\n"
2410 "sshl v16.16b, v5.16b, v17.16b\n"
2411 "ldr q20, [x25, #0x40]\n"
2412 "ldr q26, [x25, #0x50]\n"
2413 "and v6.16b, v6.16b, v22.16b\n"
2414 "and v5.16b, v5.16b, v22.16b\n"
2415 "ldr q25, [x25, #0x60]\n"
2416 "ldr q3, [x25, #0x70]\n"
2417 "sshl v19.16b, v31.16b, v17.16b\n"
2418 "sshl v18.16b, v14.16b, v17.16b\n"
2419 "ldr d17, [x25, #-0x8]\n"
2420 ".inst 0x4e95a488 // smmla v8.4s, v4.16b, v21.16b\n"
2421 ".inst 0x4e90a49b // smmla v27.4s, v4.16b, v16.16b\n"
2422 "and v31.16b, v31.16b, v22.16b\n"
2423 ".inst 0x4e95a5a0 // smmla v0.4s, v13.16b, v21.16b\n"
2424 ".inst 0x4e90a5bd // smmla v29.4s, v13.16b, v16.16b\n"
2425 "and v14.16b, v14.16b, v22.16b\n"
2426 "sub x20, x24, #0x8\n"
2427 "ldr d16, [x20, #0x0]\n"
2428 "subs x21, x21, #0x1\n"
2429 "add x25, x25, #0x88\n"
2430 "fcvtl v17.4s, v17.4h\n"
2431 "add x24, x24, #0x48\n"
2432 ".inst 0x4e93a568 // smmla v8.4s, v11.16b, v19.16b\n"
2433 ".inst 0x4e92a57b // smmla v27.4s, v11.16b, v18.16b\n"
2434 ".inst 0x4e93a6e0 // smmla v0.4s, v23.16b, v19.16b\n"
2435 ".inst 0x4e92a6fd // smmla v29.4s, v23.16b, v18.16b\n"
2436 "fcvtl v16.4s, v16.4h\n"
2437 ".inst 0x4e86a688 // smmla v8.4s, v20.16b, v6.16b\n"
2438 ".inst 0x4e85a69b // smmla v27.4s, v20.16b, v5.16b\n"
2439 "fmul v23.4s, v16.4s, v17.s[0]\n"
2440 "fmul v21.4s, v16.4s, v17.s[1]\n"
2441 "fmul v1.4s, v16.4s, v17.s[2]\n"
2442 "fmul v20.4s, v16.4s, v17.s[3]\n"
2443 ".inst 0x4e86a740 // smmla v0.4s, v26.16b, v6.16b\n"
2444 ".inst 0x4e85a75d // smmla v29.4s, v26.16b, v5.16b\n"
2445 ".inst 0x4e9fa728 // smmla v8.4s, v25.16b, v31.16b\n"
2446 ".inst 0x4e8ea73b // smmla v27.4s, v25.16b, v14.16b\n"
2447 ".inst 0x4e9fa460 // smmla v0.4s, v3.16b, v31.16b\n"
2448 ".inst 0x4e8ea47d // smmla v29.4s, v3.16b, v14.16b\n"
2449 "uzp1 v19.2d, v8.2d, v27.2d\n"
2450 "uzp2 v18.2d, v8.2d, v27.2d\n"
2451 "scvtf v19.4s, v19.4s, #0x4\n"
2452 "uzp1 v17.2d, v0.2d, v29.2d\n"
2453 "uzp2 v16.2d, v0.2d, v29.2d\n"
2454 "scvtf v18.4s, v18.4s, #0x4\n"
2455 "fmla v2.4s, v19.4s, v23.4s\n"
2456 "scvtf v17.4s, v17.4s, #0x4\n"
2457 "scvtf v16.4s, v16.4s, #0x4\n"
2458 "fmla v10.4s, v18.4s, v21.4s\n"
2459 "fmla v12.4s, v17.4s, v1.4s\n"
2460 "fmla v28.4s, v16.4s, v20.4s\n"
2461 "bgt 7b\n"
2462 "mov x20, %x[res_ptr]\n"
2463 "cmp x10, #0x1\n"
2464 "str q2, [x20, #0x0]\n"
2465 "add x20, x20, %x[res_stride]\n"
2466 "ble 8f\n"
2467 "cmp x10, #0x2\n"
2468 "str q10, [x20, #0x0]\n"
2469 "add x20, x20, %x[res_stride]\n"
2470 "ble 8f\n"
2471 "cmp x10, #0x3\n"
2472 "str q12, [x20, #0x0]\n"
2473 "add x20, x20, %x[res_stride]\n"
2474 "ble 8f\n"
2475 "str q28, [x20, #0x0]\n"
2476 "8:" // Row tail: Accumulator store skip
2477 "subs x23, x23, #0x4\n"
2478 "add %x[res_ptr], %x[res_ptr], #0x10\n"
2479 "bne 6b\n"
2480 "subs x10, x10, #0x4\n"
2481 "add %x[a_ptr], %x[a_ptr], x9\n"
2482 "mov %x[res_ptr], x22\n"
2483 "bgt 5b\n"
2484 "9:" // Row tail: Row loop skip
2485 : [a_ptr] "+&r" (a_ptr), [res_ptr] "+&r" (res_ptr)
2486 : [b_ptr] "r" (b_ptr), [nr] "r" (nr), [nb] "r" (nb), [res_stride] "r" (res_stride), [nc] "r" (nc)
2487 : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"
2488 );
2489 return;
2490#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
2491 ggml_gemm_q4_0_4x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
2492}
2493
2494void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
2495 const int qk = QK8_0;
2496 const int nb = n / qk;
2497 const int ncols_interleaved = 8;
2498 const int blocklen = 8;
2499
    assert(n % qk == 0);
    assert(nr % 4 == 0);
    assert(nc % ncols_interleaved == 0);
2503
2504 UNUSED(s);
2505 UNUSED(bs);
2506 UNUSED(vx);
2507 UNUSED(vy);
2508 UNUSED(nr);
2509 UNUSED(nc);
2510 UNUSED(nb);
2511 UNUSED(ncols_interleaved);
2512 UNUSED(blocklen);
2513
2514#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__)
2515#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
2516 if (ggml_cpu_get_sve_cnt() == QK8_0) {
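        // This path is taken only when ggml_cpu_get_sve_cnt() -- the SVE
        // register size in bytes -- equals QK8_0 (32), i.e. 256-bit SVE: one z
        // register then holds 8 int32/float lanes, a full row of the 8
        // interleaved columns.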
2517 const void * b_ptr = vx;
2518 const void * a_ptr = vy;
2519 float * res_ptr = s;
2520 size_t res_stride = bs * sizeof(float);
2521
2522 __asm__ __volatile__(
2523 "mov x20, #0x4\n"
2524 "mov x13, %x[nr]\n"
2525 "mov z28.s, #-0x4\n"
2526 "mov x12, #0x88\n"
2527 "ptrue p1.b\n"
2528 "whilelt p0.s, XZR, x20\n"
2529 "cmp x13, #0x10\n"
2530 "mul x12, %x[nb], x12\n"
2531 "blt 4f\n"
2532 "1:" // Row loop
2533 "add x11, %x[b_ptr], #0x10\n"
2534 "mov x10, %x[nc]\n"
2535 "add x9, %x[res_ptr], %x[res_stride], LSL #4\n"
2536 "2:" // Column loop
2537 "add x28, %x[a_ptr], #0x8\n"
2538 "mov z24.b, #0x0\n"
2539 "mov z15.b, #0x0\n"
2540 "mov x27, %x[nb]\n"
2541 "add x26, x28, x12\n"
2542 "mov z12.b, #0x0\n"
2543 "mov z0.b, #0x0\n"
2544 "add x25, x26, x12\n"
2545 "mov z13.b, #0x0\n"
2546 "mov z1.b, #0x0\n"
2547 "add x24, x25, x12\n"
2548 "mov z20.b, #0x0\n"
2549 "mov z25.b, #0x0\n"
2550 "mov z11.b, #0x0\n"
2551 "mov z16.b, #0x0\n"
2552 "mov z19.b, #0x0\n"
2553 "mov z26.b, #0x0\n"
2554 "mov z8.b, #0x0\n"
2555 "mov z29.b, #0x0\n"
2556 "mov z27.b, #0x0\n"
2557 "mov z10.b, #0x0\n"
2558 "3:" // Block loop
2559 "ld1b { z30.b }, p1/Z, [x11]\n"
2560 "ld1b { z21.b }, p1/Z, [x11, #1, MUL VL]\n"
2561 "mov z18.s, #0x0\n"
2562 "mov z7.s, #0x0\n"
2563 "ld1rqb { z3.b }, p1/Z, [x28]\n"
2564 "ld1rqb { z5.b }, p1/Z, [x28, #16]\n"
2565 "mov z9.s, #0x0\n"
2566 "mov z22.s, #0x0\n"
2567 "ld1b { z4.b }, p1/Z, [x11, #2, MUL VL]\n"
2568 "ld1b { z17.b }, p1/Z, [x11, #3, MUL VL]\n"
2569 "sub x20, x11, #0x10\n"
2570 "sub x23, x28, #0x8\n"
2571 "lsl z31.b, z30.b, #0x4\n"
2572 "lsl z6.b, z21.b, #0x4\n"
2573 "ld1h { z23.s }, p1/Z, [x20]\n"
2574 "sub x22, x26, #0x8\n"
2575 "and z30.b, z30.b, #0xf0\n"
2576 "and z21.b, z21.b, #0xf0\n"
2577 "sub x21, x25, #0x8\n"
2578 "sub x20, x24, #0x8\n"
2579 "lsl z14.b, z4.b, #0x4\n"
2580 "lsl z2.b, z17.b, #0x4\n"
2581 "subs x27, x27, #0x1\n"
2582 "add x11, x11, #0x90\n"
2583 ".inst 0x451f9872 // smmla z18.s, z3.b, z31.b\n"
2584 ".inst 0x45069867 // smmla z7.s, z3.b, z6.b\n"
2585 "ld1rqb { z3.b }, p1/Z, [x28, #32]\n"
2586 "and z4.b, z4.b, #0xf0\n"
2587 ".inst 0x451f98a9 // smmla z9.s, z5.b, z31.b\n"
2588 ".inst 0x450698b6 // smmla z22.s, z5.b, z6.b\n"
2589 "ld1rqb { z5.b }, p1/Z, [x28, #48]\n"
2590 "and z17.b, z17.b, #0xf0\n"
2591 "fcvt z23.s, p1/m, z23.h\n"
2592 ".inst 0x450e9872 // smmla z18.s, z3.b, z14.b\n"
2593 ".inst 0x45029867 // smmla z7.s, z3.b, z2.b\n"
2594 "ld1rqb { z3.b }, p1/Z, [x28, #64]\n"
2595 ".inst 0x450e98a9 // smmla z9.s, z5.b, z14.b\n"
2596 ".inst 0x450298b6 // smmla z22.s, z5.b, z2.b\n"
2597 "ld1rqb { z5.b }, p1/Z, [x28, #80]\n"
2598 "fscale z23.s, p1/m, z23.s, z28.s\n"
2599 ".inst 0x451e9872 // smmla z18.s, z3.b, z30.b\n"
2600 ".inst 0x45159867 // smmla z7.s, z3.b, z21.b\n"
2601 "ld1rqb { z3.b }, p1/Z, [x28, #96]\n"
2602 ".inst 0x451e98a9 // smmla z9.s, z5.b, z30.b\n"
2603 ".inst 0x451598b6 // smmla z22.s, z5.b, z21.b\n"
2604 "ld1rqb { z5.b }, p1/Z, [x28, #112]\n"
2605 "add x28, x28, #0x88\n"
2606 ".inst 0x45049872 // smmla z18.s, z3.b, z4.b\n"
2607 ".inst 0x45119867 // smmla z7.s, z3.b, z17.b\n"
2608 "ld1h { z3.s }, p0/Z, [x23]\n"
2609 ".inst 0x450498a9 // smmla z9.s, z5.b, z4.b\n"
2610 ".inst 0x451198b6 // smmla z22.s, z5.b, z17.b\n"
2611 "fcvt z3.s, p1/m, z3.h\n"
2612 "uzp1 z5.d, z18.d, z7.d\n"
2613 "uzp2 z18.d, z18.d, z7.d\n"
2614 "mov z3.q, z3.q[0]\n"
2615 "uzp1 z7.d, z9.d, z22.d\n"
2616 "uzp2 z22.d, z9.d, z22.d\n"
2617 "fmul z9.s, z23.s, z3.s[0]\n"
2618 "scvtf z5.s, p1/m, z5.s\n"
2619 "scvtf z18.s, p1/m, z18.s\n"
2620 "scvtf z7.s, p1/m, z7.s\n"
2621 "scvtf z22.s, p1/m, z22.s\n"
2622 "fmla z24.s, p1/M, z5.s, z9.s\n"
2623 "ld1rqb { z5.b }, p1/Z, [x26]\n"
2624 "fmul z9.s, z23.s, z3.s[1]\n"
2625 "fmla z15.s, p1/M, z18.s, z9.s\n"
2626 "ld1rqb { z18.b }, p1/Z, [x26, #16]\n"
2627 "fmul z9.s, z23.s, z3.s[2]\n"
2628 "fmul z3.s, z23.s, z3.s[3]\n"
2629 "fmla z12.s, p1/M, z7.s, z9.s\n"
2630 "mov z9.s, #0x0\n"
2631 "ld1h { z7.s }, p0/Z, [x22]\n"
2632 ".inst 0x451f98a9 // smmla z9.s, z5.b, z31.b\n"
2633 "fmla z0.s, p1/M, z22.s, z3.s\n"
2634 "mov z22.s, #0x0\n"
2635 "ld1h { z3.s }, p0/Z, [x21]\n"
2636 ".inst 0x450698b6 // smmla z22.s, z5.b, z6.b\n"
2637 "ld1rqb { z5.b }, p1/Z, [x26, #32]\n"
2638 "fcvt z7.s, p1/m, z7.h\n"
2639 "fcvt z3.s, p1/m, z3.h\n"
2640 ".inst 0x450e98a9 // smmla z9.s, z5.b, z14.b\n"
2641 ".inst 0x450298b6 // smmla z22.s, z5.b, z2.b\n"
2642 "ld1rqb { z5.b }, p1/Z, [x26, #64]\n"
2643 "mov z7.q, z7.q[0]\n"
2644 "mov z3.q, z3.q[0]\n"
2645 ".inst 0x451e98a9 // smmla z9.s, z5.b, z30.b\n"
2646 ".inst 0x451598b6 // smmla z22.s, z5.b, z21.b\n"
2647 "ld1rqb { z5.b }, p1/Z, [x26, #96]\n"
2648 ".inst 0x450498a9 // smmla z9.s, z5.b, z4.b\n"
2649 ".inst 0x451198b6 // smmla z22.s, z5.b, z17.b\n"
2650 "uzp1 z5.d, z9.d, z22.d\n"
2651 "scvtf z5.s, p1/m, z5.s\n"
2652 "uzp2 z22.d, z9.d, z22.d\n"
2653 "fmul z9.s, z23.s, z7.s[0]\n"
2654 "scvtf z22.s, p1/m, z22.s\n"
2655 "fmla z13.s, p1/M, z5.s, z9.s\n"
2656 "ld1rqb { z9.b }, p1/Z, [x25]\n"
2657 "fmul z5.s, z23.s, z7.s[1]\n"
2658 "fmla z1.s, p1/M, z22.s, z5.s\n"
2659 "mov z5.s, #0x0\n"
2660 "mov z22.s, #0x0\n"
2661 ".inst 0x451f9a45 // smmla z5.s, z18.b, z31.b\n"
2662 ".inst 0x45069a56 // smmla z22.s, z18.b, z6.b\n"
2663 "ld1rqb { z18.b }, p1/Z, [x26, #48]\n"
2664 ".inst 0x450e9a45 // smmla z5.s, z18.b, z14.b\n"
2665 ".inst 0x45029a56 // smmla z22.s, z18.b, z2.b\n"
2666 "ld1rqb { z18.b }, p1/Z, [x26, #80]\n"
2667 ".inst 0x451e9a45 // smmla z5.s, z18.b, z30.b\n"
2668 ".inst 0x45159a56 // smmla z22.s, z18.b, z21.b\n"
2669 "ld1rqb { z18.b }, p1/Z, [x26, #112]\n"
2670 "add x26, x26, #0x88\n"
2671 ".inst 0x45049a45 // smmla z5.s, z18.b, z4.b\n"
2672 ".inst 0x45119a56 // smmla z22.s, z18.b, z17.b\n"
2673 "uzp1 z18.d, z5.d, z22.d\n"
2674 "scvtf z18.s, p1/m, z18.s\n"
2675 "uzp2 z22.d, z5.d, z22.d\n"
2676 "fmul z5.s, z23.s, z7.s[2]\n"
2677 "fmul z7.s, z23.s, z7.s[3]\n"
2678 "scvtf z22.s, p1/m, z22.s\n"
2679 "fmla z20.s, p1/M, z18.s, z5.s\n"
2680 "ld1rqb { z18.b }, p1/Z, [x25, #16]\n"
2681 "ld1h { z5.s }, p0/Z, [x20]\n"
2682 "fcvt z5.s, p1/m, z5.h\n"
2683 "fmla z25.s, p1/M, z22.s, z7.s\n"
2684 "mov z22.s, #0x0\n"
2685 "mov z7.s, #0x0\n"
2686 ".inst 0x451f9936 // smmla z22.s, z9.b, z31.b\n"
2687 ".inst 0x45069927 // smmla z7.s, z9.b, z6.b\n"
2688 "ld1rqb { z9.b }, p1/Z, [x25, #32]\n"
2689 "mov z5.q, z5.q[0]\n"
2690 ".inst 0x450e9936 // smmla z22.s, z9.b, z14.b\n"
2691 ".inst 0x45029927 // smmla z7.s, z9.b, z2.b\n"
2692 "ld1rqb { z9.b }, p1/Z, [x25, #64]\n"
2693 ".inst 0x451e9936 // smmla z22.s, z9.b, z30.b\n"
2694 ".inst 0x45159927 // smmla z7.s, z9.b, z21.b\n"
2695 "ld1rqb { z9.b }, p1/Z, [x25, #96]\n"
2696 ".inst 0x45049936 // smmla z22.s, z9.b, z4.b\n"
2697 ".inst 0x45119927 // smmla z7.s, z9.b, z17.b\n"
2698 "uzp1 z9.d, z22.d, z7.d\n"
2699 "scvtf z9.s, p1/m, z9.s\n"
2700 "uzp2 z22.d, z22.d, z7.d\n"
2701 "fmul z7.s, z23.s, z3.s[0]\n"
2702 "scvtf z22.s, p1/m, z22.s\n"
2703 "fmla z11.s, p1/M, z9.s, z7.s\n"
2704 "ld1rqb { z9.b }, p1/Z, [x24]\n"
2705 "fmul z7.s, z23.s, z3.s[1]\n"
2706 "fmla z16.s, p1/M, z22.s, z7.s\n"
2707 "mov z22.s, #0x0\n"
2708 "mov z7.s, #0x0\n"
2709 ".inst 0x451f9a56 // smmla z22.s, z18.b, z31.b\n"
2710 ".inst 0x45069a47 // smmla z7.s, z18.b, z6.b\n"
2711 "ld1rqb { z18.b }, p1/Z, [x25, #48]\n"
2712 ".inst 0x450e9a56 // smmla z22.s, z18.b, z14.b\n"
2713 ".inst 0x45029a47 // smmla z7.s, z18.b, z2.b\n"
2714 "ld1rqb { z18.b }, p1/Z, [x25, #80]\n"
2715 ".inst 0x451e9a56 // smmla z22.s, z18.b, z30.b\n"
2716 ".inst 0x45159a47 // smmla z7.s, z18.b, z21.b\n"
2717 "ld1rqb { z18.b }, p1/Z, [x25, #112]\n"
2718 "add x25, x25, #0x88\n"
2719 ".inst 0x45049a56 // smmla z22.s, z18.b, z4.b\n"
2720 ".inst 0x45119a47 // smmla z7.s, z18.b, z17.b\n"
2721 "uzp1 z18.d, z22.d, z7.d\n"
2722 "scvtf z18.s, p1/m, z18.s\n"
2723 "uzp2 z7.d, z22.d, z7.d\n"
2724 "fmul z22.s, z23.s, z3.s[2]\n"
2725 "fmul z3.s, z23.s, z3.s[3]\n"
2726 "scvtf z7.s, p1/m, z7.s\n"
2727 "fmla z19.s, p1/M, z18.s, z22.s\n"
2728 "ld1rqb { z18.b }, p1/Z, [x24, #16]\n"
2729 "fmul z22.s, z23.s, z5.s[0]\n"
2730 "fmla z26.s, p1/M, z7.s, z3.s\n"
2731 "mov z3.s, #0x0\n"
2732 "mov z7.s, #0x0\n"
2733 ".inst 0x451f9923 // smmla z3.s, z9.b, z31.b\n"
2734 ".inst 0x45069927 // smmla z7.s, z9.b, z6.b\n"
2735 "ld1rqb { z9.b }, p1/Z, [x24, #32]\n"
2736 ".inst 0x450e9923 // smmla z3.s, z9.b, z14.b\n"
2737 ".inst 0x45029927 // smmla z7.s, z9.b, z2.b\n"
2738 "mov z9.s, #0x0\n"
2739 ".inst 0x451f9a49 // smmla z9.s, z18.b, z31.b\n"
2740 "mov z31.s, #0x0\n"
2741 ".inst 0x45069a5f // smmla z31.s, z18.b, z6.b\n"
2742 "ld1rqb { z6.b }, p1/Z, [x24, #48]\n"
2743 "ld1rqb { z18.b }, p1/Z, [x24, #64]\n"
2744 ".inst 0x450e98c9 // smmla z9.s, z6.b, z14.b\n"
2745 "fmul z14.s, z23.s, z5.s[1]\n"
2746 ".inst 0x450298df // smmla z31.s, z6.b, z2.b\n"
2747 "ld1rqb { z6.b }, p1/Z, [x24, #80]\n"
2748 "fmul z2.s, z23.s, z5.s[2]\n"
2749 "fmul z23.s, z23.s, z5.s[3]\n"
2750 ".inst 0x451e9a43 // smmla z3.s, z18.b, z30.b\n"
2751 ".inst 0x45159a47 // smmla z7.s, z18.b, z21.b\n"
2752 "ld1rqb { z5.b }, p1/Z, [x24, #96]\n"
2753 ".inst 0x451e98c9 // smmla z9.s, z6.b, z30.b\n"
2754 ".inst 0x451598df // smmla z31.s, z6.b, z21.b\n"
2755 "ld1rqb { z18.b }, p1/Z, [x24, #112]\n"
2756 "add x24, x24, #0x88\n"
2757 ".inst 0x450498a3 // smmla z3.s, z5.b, z4.b\n"
2758 ".inst 0x451198a7 // smmla z7.s, z5.b, z17.b\n"
2759 ".inst 0x45049a49 // smmla z9.s, z18.b, z4.b\n"
2760 ".inst 0x45119a5f // smmla z31.s, z18.b, z17.b\n"
2761 "uzp1 z18.d, z3.d, z7.d\n"
2762 "uzp2 z5.d, z3.d, z7.d\n"
2763 "scvtf z18.s, p1/m, z18.s\n"
2764 "uzp1 z6.d, z9.d, z31.d\n"
2765 "uzp2 z9.d, z9.d, z31.d\n"
2766 "scvtf z5.s, p1/m, z5.s\n"
2767 "fmla z8.s, p1/M, z18.s, z22.s\n"
2768 "scvtf z6.s, p1/m, z6.s\n"
2769 "scvtf z9.s, p1/m, z9.s\n"
2770 "fmla z29.s, p1/M, z5.s, z14.s\n"
2771 "fmla z27.s, p1/M, z6.s, z2.s\n"
2772 "fmla z10.s, p1/M, z9.s, z23.s\n"
2773 "bgt 3b\n"
2774 "mov x20, %x[res_ptr]\n"
2775 "subs x10, x10, #0x8\n"
2776 "add %x[res_ptr], %x[res_ptr], #0x20\n"
2777 "st1w { z24.s }, p1, [x20]\n"
2778 "add x20, x20, %x[res_stride]\n"
2779 "st1w { z15.s }, p1, [x20]\n"
2780 "add x20, x20, %x[res_stride]\n"
2781 "st1w { z12.s }, p1, [x20]\n"
2782 "add x20, x20, %x[res_stride]\n"
2783 "st1w { z0.s }, p1, [x20]\n"
2784 "add x20, x20, %x[res_stride]\n"
2785 "st1w { z13.s }, p1, [x20]\n"
2786 "add x20, x20, %x[res_stride]\n"
2787 "st1w { z1.s }, p1, [x20]\n"
2788 "add x20, x20, %x[res_stride]\n"
2789 "st1w { z20.s }, p1, [x20]\n"
2790 "add x20, x20, %x[res_stride]\n"
2791 "st1w { z25.s }, p1, [x20]\n"
2792 "add x20, x20, %x[res_stride]\n"
2793 "st1w { z11.s }, p1, [x20]\n"
2794 "add x20, x20, %x[res_stride]\n"
2795 "st1w { z16.s }, p1, [x20]\n"
2796 "add x20, x20, %x[res_stride]\n"
2797 "st1w { z19.s }, p1, [x20]\n"
2798 "add x20, x20, %x[res_stride]\n"
2799 "st1w { z26.s }, p1, [x20]\n"
2800 "add x20, x20, %x[res_stride]\n"
2801 "st1w { z8.s }, p1, [x20]\n"
2802 "add x20, x20, %x[res_stride]\n"
2803 "st1w { z29.s }, p1, [x20]\n"
2804 "add x20, x20, %x[res_stride]\n"
2805 "st1w { z27.s }, p1, [x20]\n"
2806 "add x20, x20, %x[res_stride]\n"
2807 "st1w { z10.s }, p1, [x20]\n"
2808 "bne 2b\n"
2809 "mov x20, #0x4\n"
2810 "sub x13, x13, #0x10\n"
2811 "cmp x13, #0x10\n"
2812 "mov %x[res_ptr], x9\n"
2813 "madd %x[a_ptr], x20, x12, %x[a_ptr]\n"
2814 "bge 1b\n"
2815 "4:" // Row loop skip
2816 "cbz x13, 9f\n"
2817 "5:" // Row tail: Row loop
2818 "add x25, %x[b_ptr], #0x10\n"
2819 "mov x24, %x[nc]\n"
2820 "add x23, %x[res_ptr], %x[res_stride], LSL #2\n"
2821 "6:" // Row tail: Column loop
2822 "mov z24.b, #0x0\n"
2823 "mov z15.b, #0x0\n"
2824 "add x28, %x[a_ptr], #0x8\n"
2825 "mov x22, %x[nb]\n"
2826 "mov z12.b, #0x0\n"
2827 "mov z0.b, #0x0\n"
2828 "7:" // Row tail: Block loop
2829 "ld1b { z3.b }, p1/Z, [x25]\n"
2830 "ld1b { z6.b }, p1/Z, [x25, #1, MUL VL]\n"
2831 "mov z2.s, #0x0\n"
2832 "mov z25.s, #0x0\n"
2833 "ld1rqb { z26.b }, p1/Z, [x28]\n"
2834 "ld1rqb { z21.b }, p1/Z, [x28, #16]\n"
2835 "mov z27.s, #0x0\n"
2836 "mov z19.s, #0x0\n"
2837 "ld1b { z29.b }, p1/Z, [x25, #2, MUL VL]\n"
2838 "ld1b { z16.b }, p1/Z, [x25, #3, MUL VL]\n"
2839 "sub x21, x25, #0x10\n"
2840 "sub x20, x28, #0x8\n"
2841 "lsl z20.b, z3.b, #0x4\n"
2842 "lsl z4.b, z6.b, #0x4\n"
2843 "ld1rqb { z10.b }, p1/Z, [x28, #32]\n"
2844 "ld1rqb { z23.b }, p1/Z, [x28, #48]\n"
2845 "and z3.b, z3.b, #0xf0\n"
2846 "and z6.b, z6.b, #0xf0\n"
2847 "ld1rqb { z11.b }, p1/Z, [x28, #64]\n"
2848 "ld1rqb { z7.b }, p1/Z, [x28, #80]\n"
2849 "lsl z8.b, z29.b, #0x4\n"
2850 "lsl z14.b, z16.b, #0x4\n"
2851 "ld1rqb { z18.b }, p1/Z, [x28, #96]\n"
2852 "ld1rqb { z30.b }, p1/Z, [x28, #112]\n"
2853 ".inst 0x45149b42 // smmla z2.s, z26.b, z20.b\n"
2854 ".inst 0x45049b59 // smmla z25.s, z26.b, z4.b\n"
2855 "and z29.b, z29.b, #0xf0\n"
2856 "ld1h { z17.s }, p1/Z, [x21]\n"
2857 ".inst 0x45149abb // smmla z27.s, z21.b, z20.b\n"
2858 ".inst 0x45049ab3 // smmla z19.s, z21.b, z4.b\n"
2859 "and z16.b, z16.b, #0xf0\n"
2860 "ld1h { z4.s }, p0/Z, [x20]\n"
2861 "subs x22, x22, #0x1\n"
2862 "add x28, x28, #0x88\n"
2863 "fcvt z17.s, p1/m, z17.h\n"
2864 "add x25, x25, #0x90\n"
2865 ".inst 0x45089942 // smmla z2.s, z10.b, z8.b\n"
2866 ".inst 0x450e9959 // smmla z25.s, z10.b, z14.b\n"
2867 "fcvt z4.s, p1/m, z4.h\n"
2868 ".inst 0x45089afb // smmla z27.s, z23.b, z8.b\n"
2869 ".inst 0x450e9af3 // smmla z19.s, z23.b, z14.b\n"
2870 "fscale z17.s, p1/m, z17.s, z28.s\n"
2871 "mov z4.q, z4.q[0]\n"
2872 ".inst 0x45039962 // smmla z2.s, z11.b, z3.b\n"
2873 ".inst 0x45069979 // smmla z25.s, z11.b, z6.b\n"
2874 "fmul z23.s, z17.s, z4.s[0]\n"
2875 "fmul z9.s, z17.s, z4.s[1]\n"
2876 "fmul z21.s, z17.s, z4.s[2]\n"
2877 "fmul z4.s, z17.s, z4.s[3]\n"
2878 ".inst 0x450398fb // smmla z27.s, z7.b, z3.b\n"
2879 ".inst 0x450698f3 // smmla z19.s, z7.b, z6.b\n"
2880 ".inst 0x451d9a42 // smmla z2.s, z18.b, z29.b\n"
2881 ".inst 0x45109a59 // smmla z25.s, z18.b, z16.b\n"
2882 ".inst 0x451d9bdb // smmla z27.s, z30.b, z29.b\n"
2883 ".inst 0x45109bd3 // smmla z19.s, z30.b, z16.b\n"
2884 "uzp1 z31.d, z2.d, z25.d\n"
2885 "uzp2 z13.d, z2.d, z25.d\n"
2886 "scvtf z31.s, p1/m, z31.s\n"
2887 "uzp1 z17.d, z27.d, z19.d\n"
2888 "uzp2 z18.d, z27.d, z19.d\n"
2889 "scvtf z13.s, p1/m, z13.s\n"
2890 "fmla z24.s, p1/M, z31.s, z23.s\n"
2891 "scvtf z17.s, p1/m, z17.s\n"
2892 "scvtf z18.s, p1/m, z18.s\n"
2893 "fmla z15.s, p1/M, z13.s, z9.s\n"
2894 "fmla z12.s, p1/M, z17.s, z21.s\n"
2895 "fmla z0.s, p1/M, z18.s, z4.s\n"
2896 "bgt 7b\n"
2897 "mov x20, %x[res_ptr]\n"
2898 "cmp x13, #0x1\n"
2899 "st1w { z24.s }, p1, [x20]\n"
2900 "add x20, x20, %x[res_stride]\n"
2901 "ble 8f\n"
2902 "cmp x13, #0x2\n"
2903 "st1w { z15.s }, p1, [x20]\n"
2904 "add x20, x20, %x[res_stride]\n"
2905 "ble 8f\n"
2906 "cmp x13, #0x3\n"
2907 "st1w { z12.s }, p1, [x20]\n"
2908 "add x20, x20, %x[res_stride]\n"
2909 "ble 8f\n"
2910 "st1w { z0.s }, p1, [x20]\n"
2911 "8:" // Row tail: Accumulator store skip
2912 "subs x24, x24, #0x8\n"
2913 "add %x[res_ptr], %x[res_ptr], #0x20\n"
2914 "bne 6b\n"
2915 "subs x13, x13, #0x4\n"
2916 "add %x[a_ptr], %x[a_ptr], x12\n"
2917 "mov %x[res_ptr], x23\n"
2918 "bgt 5b\n"
2919 "9:" // Row tail: Row loop skip
2920 : [a_ptr] "+&r" (a_ptr), [res_ptr] "+&r" (res_ptr)
2921 : [b_ptr] "r" (b_ptr), [nr] "r" (nr), [nb] "r" (nb), [res_stride] "r" (res_stride), [nc] "r" (nc)
2922 : "cc", "memory", "p0", "p1", "x9", "x10", "x11", "x12", "x13", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31"
2923 );
2924 return;
2925 }
2926#endif // #if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
2927
2928#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__)
2929 ggml_gemm_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
2930}
2931
2932void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
2933 const int qk = QK8_0;
2934 const int nb = n / qk;
2935 const int ncols_interleaved = 4;
2936 const int blocklen = 4;
2937
2938 assert (n % qk == 0);
2939 assert (nr % 4 == 0);
2940 assert (nc % ncols_interleaved == 0);
2941
2942 UNUSED(s);
2943 UNUSED(bs);
2944 UNUSED(vx);
2945 UNUSED(vy);
2946 UNUSED(nr);
2947 UNUSED(nc);
2948 UNUSED(nb);
2949 UNUSED(ncols_interleaved);
2950 UNUSED(blocklen);
2951
2952#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
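    // IQ4_NL stores 4-bit indices into the 16-entry non-linear codebook
    // kvalues_iq4nl (see ggml-common.h); vqtbl1q_s8 maps each nibble to its
    // signed int8 codebook value with a single table lookup.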
2953 const int8x16_t kvalues = vld1q_s8(kvalues_iq4nl);
2954
2955 for (int y = 0; y < nr / 4; y++) {
2956 const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
2957 for (int x = 0; x < nc / ncols_interleaved; x++) {
2958 const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb);
2959
2960 float32x4_t sumf[4];
2961 for (int m = 0; m < 4; m++) {
2962 sumf[m] = vdupq_n_f32(0);
2963 }
2964
2965 for (int l = 0; l < nb; l++) {
2966 float32x4_t a_d = vcvt_f32_f16(vld1_f16((const float16_t *)a_ptr[l].d));
2967 float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *)b_ptr[l].d));
2968
2969 int32x4_t sumi_0 = vdupq_n_s32(0);
2970 int32x4_t sumi_1 = vdupq_n_s32(0);
2971 int32x4_t sumi_2 = vdupq_n_s32(0);
2972 int32x4_t sumi_3 = vdupq_n_s32(0);
2973
2974 for (int k = 0; k < 4; k++) {
2975 int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 16 * k + 0);
2976 int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16 * k + 64);
2977
2978 uint8x16_t b = vld1q_u8(b_ptr[l].qs + 16 * k);
2979 int8x16_t b_hi = vqtbl1q_s8(kvalues, b >> 4);
2980 int8x16_t b_lo = vqtbl1q_s8(kvalues, b & 0xF);
2981
2982 sumi_0 = vdotq_laneq_s32(sumi_0, b_lo, a_0, 0);
2983 sumi_1 = vdotq_laneq_s32(sumi_1, b_lo, a_0, 1);
2984 sumi_2 = vdotq_laneq_s32(sumi_2, b_lo, a_0, 2);
2985 sumi_3 = vdotq_laneq_s32(sumi_3, b_lo, a_0, 3);
2986 sumi_0 = vdotq_laneq_s32(sumi_0, b_hi, a_1, 0);
2987 sumi_1 = vdotq_laneq_s32(sumi_1, b_hi, a_1, 1);
2988 sumi_2 = vdotq_laneq_s32(sumi_2, b_hi, a_1, 2);
2989 sumi_3 = vdotq_laneq_s32(sumi_3, b_hi, a_1, 3);
2990 }
2991
2992 sumf[0] = vmlaq_f32(sumf[0], vmulq_laneq_f32(b_d, a_d, 0), vcvtq_f32_s32(sumi_0));
2993 sumf[1] = vmlaq_f32(sumf[1], vmulq_laneq_f32(b_d, a_d, 1), vcvtq_f32_s32(sumi_1));
2994 sumf[2] = vmlaq_f32(sumf[2], vmulq_laneq_f32(b_d, a_d, 2), vcvtq_f32_s32(sumi_2));
2995 sumf[3] = vmlaq_f32(sumf[3], vmulq_laneq_f32(b_d, a_d, 3), vcvtq_f32_s32(sumi_3));
2996 }
2997
2998 for (int m = 0; m < 4; m++) {
2999 vst1q_f32(s + (y * 4 + m) * bs + x * 4, sumf[m]);
3000 }
3001 }
3002 }
3003 return;
#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
3005 ggml_gemm_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
3006}
3007
3008void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
3009 constexpr int qk = QK_K;
3010 const int nb = n / qk;
3011
3012 constexpr int ncols_interleaved = 8;
3013 constexpr int blocklen = 4;
3014
3015 assert(n % qk == 0);
3016 assert(nr % 4 == 0);
3017 assert(nc % ncols_interleaved == 0);
3018
3019 UNUSED(nb);
3020 UNUSED(ncols_interleaved);
3021 UNUSED(blocklen);
3022
3023#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
3024 constexpr int q8_k_blocklen = 4;
    constexpr int acc_size = 2 * 4; // 4 q8 rows × 2 column groups (0-3, 4-7)
3026 const uint8x16_t m4b = vdupq_n_u8(0x0f);
3027
    // 8 accumulators, indexed acc_f32[2*row + col_group]: 4 q8 rows × 2 column groups
3029 float32x4_t acc_f32[acc_size];
3030
3031 for (int y = 0; y < nr / q8_k_blocklen; y++) {
3032 const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
3033
3034 for (int x = 0; x < nc / ncols_interleaved; x++) {
3035 const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);
3036
3037 for (int i = 0; i < acc_size; i++) {
3038 acc_f32[i] = vdupq_n_f32(0);
3039 }
3040
3041 for (int b = 0; b < nb; b++) {
                // q4_K superblock scales (d) for columns 0-3 and 4-7
                float32x4_t q4_d_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d));
                float32x4_t q4_d_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d + 4));
                // q8 row scales (d) for rows 0-3
                float32x4_t q8_d_0123 = vld1q_f32(q8_ptr[b].d);
3047 // mins
3048 float32x4_t q4_dmin_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin));
3049 float32x4_t q4_dmin_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin + 4));
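                // Q4_K dequantizes as x = d * sc * q - dmin * m per 32-value subblock, so
                // every row/column dot product splits into two terms:
                //   (d4 * d8) * sum_sb(sc_sb * dot(q4_sb, q8_sb)) - (dmin * d8) * sum_sb(m_sb * bsum_sb)
                // The integer sums are accumulated below; the float factors computed here
                // are applied once per superblock.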
3050
3051 // Precomputation of scales and mins
3052 float32x4_t sbd_scale_0123[q8_k_blocklen];
3053 float32x4_t sbd_scale_4567[q8_k_blocklen];
3054 float32x4_t sbd_min_0123[q8_k_blocklen];
3055 float32x4_t sbd_min_4567[q8_k_blocklen];
3056
3057 sbd_scale_0123[0] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 0);
3058 sbd_scale_4567[0] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 0);
3059 sbd_min_0123[0] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 0);
3060 sbd_min_4567[0] = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 0);
3061
3062 sbd_scale_0123[1] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 1);
3063 sbd_scale_4567[1] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 1);
3064 sbd_min_0123[1] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 1);
3065 sbd_min_4567[1] = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 1);
3066
3067 sbd_scale_0123[2] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 2);
3068 sbd_scale_4567[2] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 2);
3069 sbd_min_0123[2] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 2);
3070 sbd_min_4567[2] = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 2);
3071
3072 sbd_scale_0123[3] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 3);
3073 sbd_scale_4567[3] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 3);
3074 sbd_min_0123[3] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 3);
3075 sbd_min_4567[3] = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 3);
3076
                // Precompute bsums: q8_K stores one partial sum per 16 values while the
                // q4_K mins cover 32, so each vpaddq adds adjacent pairs, yielding all
                // eight per-subblock bsums for one row
3078 const int16x8_t bsums[q8_k_blocklen] = {
3079 vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)),
3080 vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)),
3081 vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)),
3082 vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)),
3083 };
3084 int16_t bsums_arr[QK_K / 64][8];
3085 for (int q8_row = 0; q8_row < 4; q8_row++) {
3086 vst1q_s16(bsums_arr[q8_row], bsums[q8_row]);
3087 }
3088
                // interleaved bias_acc: [0]->r0 c0123, [1]->r0 c4567, [2]->r1 c0123, ...
3090 int32x4_t bias_acc[acc_size];
3091 for (int i = 0; i < acc_size; i++) {
3092 bias_acc[i] = vdupq_n_s32(0);
3093 }
3094
3095 for (int sb = 0; sb < QK_K / 64; sb++) {
                    // Integer accumulators for the qs dot products: 4 rows × 2 column
                    // groups, with low and high nibbles kept separate
3097 int32x4_t acc_lo[acc_size];
3098 int32x4_t acc_hi[acc_size];
3099 for (int i = 0; i < acc_size; i++) {
3100 acc_lo[i] = vdupq_n_s32(0);
3101 acc_hi[i] = vdupq_n_s32(0);
3102 }
3103 // Need scales for the low and high nibbles
3104 // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
3105 int16x8_t q4sb_scales[2];
3106 int16x8_t q4sb_mins[2];
3107 for (int i = 0; i < 2; i++) {
3108 int8_t aux_q4sb[8];
3109 const int offset = sb * 24 + i * 12;
3110 decode_q_Kx8_6bit_scales(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
3111 q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb));
3112 }
3113
                    constexpr int reads_per_sb = 8; // 8 loads × 16 bytes = 128 bytes => 32 qs × 4 rows
3115 for (int k = 0; k < reads_per_sb; k++) {
3116 const int8x16_t q8_blk0 = vld1q_s8(q8_ptr[b].qs + sb * 256 + 16 * k);
3117 const int8x16_t q8_blk1 = vld1q_s8(q8_ptr[b].qs + sb * 256 + 16 * k + 128);
3118
3119 // 0..3 & 32..35
3120 const uint8x16_t q4_0123 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 32 * k);
3121 const uint8x16_t q4_4567 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 32 * k + 16);
3122
3123 const int8x16_t q4_0123_lo = vreinterpretq_s8_u8(vandq_u8(q4_0123, m4b));
3124 const int8x16_t q4_0123_hi = vreinterpretq_s8_u8(vshrq_n_u8(q4_0123, 4));
3125
3126 acc_lo[0] = vdotq_laneq_s32(acc_lo[0], q4_0123_lo, q8_blk0, 0); // 0..3 r0 c0123
3127 acc_lo[1] = vdotq_laneq_s32(acc_lo[1], q4_0123_lo, q8_blk0, 1); // 0..3 r1 c0123
3128 acc_lo[2] = vdotq_laneq_s32(acc_lo[2], q4_0123_lo, q8_blk0, 2); // 0..3 r2 c0123
3129 acc_lo[3] = vdotq_laneq_s32(acc_lo[3], q4_0123_lo, q8_blk0, 3); // 0..3 r3 c0123
3130
3131 acc_hi[0] = vdotq_laneq_s32(acc_hi[0], q4_0123_hi, q8_blk1, 0); // 32..35 r0 c0123
3132 acc_hi[1] = vdotq_laneq_s32(acc_hi[1], q4_0123_hi, q8_blk1, 1); // 32..35 r1 c0123
3133 acc_hi[2] = vdotq_laneq_s32(acc_hi[2], q4_0123_hi, q8_blk1, 2); // 32..35 r2 c0123
3134 acc_hi[3] = vdotq_laneq_s32(acc_hi[3], q4_0123_hi, q8_blk1, 3); // 32..35 r3 c0123
3135
3136 const int8x16_t q4_4567_lo = vreinterpretq_s8_u8(vandq_u8(q4_4567, m4b));
3137 const int8x16_t q4_4567_hi = vreinterpretq_s8_u8(vshrq_n_u8(q4_4567, 4));
3138
3139 acc_lo[4] = vdotq_laneq_s32(acc_lo[4], q4_4567_lo, q8_blk0, 0); // 0..3 r0 c4567
3140 acc_lo[5] = vdotq_laneq_s32(acc_lo[5], q4_4567_lo, q8_blk0, 1); // 0..3 r1 c4567
3141 acc_lo[6] = vdotq_laneq_s32(acc_lo[6], q4_4567_lo, q8_blk0, 2); // 0..3 r2 c4567
3142 acc_lo[7] = vdotq_laneq_s32(acc_lo[7], q4_4567_lo, q8_blk0, 3); // 0..3 r3 c4567
3143
3144 acc_hi[4] = vdotq_laneq_s32(acc_hi[4], q4_4567_hi, q8_blk1, 0); // 32..35 r0 c4567
3145 acc_hi[5] = vdotq_laneq_s32(acc_hi[5], q4_4567_hi, q8_blk1, 1); // 32..35 r1 c4567
3146 acc_hi[6] = vdotq_laneq_s32(acc_hi[6], q4_4567_hi, q8_blk1, 2); // 32..35 r2 c4567
3147 acc_hi[7] = vdotq_laneq_s32(acc_hi[7], q4_4567_hi, q8_blk1, 3); // 32..35 r3 c4567
3148 }
3149
3150 // Scale and bias application
3151 // acc is stored interleaved to match output layout
3152 const int16x4_t sc_0123_lo = vget_low_s16(q4sb_scales[0]);
3153 const int16x4_t sc_4567_lo = vget_high_s16(q4sb_scales[0]);
3154 const int16x4_t sc_0123_hi = vget_low_s16(q4sb_scales[1]);
3155 const int16x4_t sc_4567_hi = vget_high_s16(q4sb_scales[1]);
3156 for (int row = 0; row < q8_k_blocklen; row++) {
3157 // Bias correction
3158 // row c0123 blk0 and blk1
3159 const float32x4_t sumf_0123 =
3160 vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_0123_lo), acc_lo[row]),
3161 vmulq_s32(vmovl_s16(sc_0123_hi), acc_hi[row])));
3162 acc_f32[2 * row] = vfmaq_f32(acc_f32[2 * row], sbd_scale_0123[row], sumf_0123);
3163
3164 // row c4567 blk0 and blk1
3165 const float32x4_t sumf_4567 =
3166 vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_4567_lo), acc_lo[row + 4]),
3167 vmulq_s32(vmovl_s16(sc_4567_hi), acc_hi[row + 4])));
3168 acc_f32[2 * row + 1] = vfmaq_f32(acc_f32[2 * row + 1], sbd_scale_4567[row], sumf_4567);
3169
3170 // Bias
3171 const int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][row * 2]);
3172 const int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][row * 2 + 1]);
3173
3174 // row c0123 blk0 and blk1
3175 bias_acc[2 * row] = vmlal_s16(bias_acc[2 * row], bsums_vec_lo, vget_low_s16(q4sb_mins[0]));
3176 bias_acc[2 * row] = vmlal_s16(bias_acc[2 * row], bsums_vec_hi, vget_low_s16(q4sb_mins[1]));
3177
3178 // row c4567 blk0 and blk1
3179 bias_acc[2 * row + 1] =
3180 vmlal_s16(bias_acc[2 * row + 1], bsums_vec_lo, vget_high_s16(q4sb_mins[0]));
3181 bias_acc[2 * row + 1] =
3182 vmlal_s16(bias_acc[2 * row + 1], bsums_vec_hi, vget_high_s16(q4sb_mins[1]));
3183 }
3184 } // for sb
3185
3186 for (int row = 0; row < q8_k_blocklen; row++) {
3187 acc_f32[2 * row] = vmlsq_f32(acc_f32[2 * row], vcvtq_f32_s32(bias_acc[2 * row]), sbd_min_0123[row]);
3188 acc_f32[2 * row + 1] =
3189 vmlsq_f32(acc_f32[2 * row + 1], vcvtq_f32_s32(bias_acc[2 * row + 1]), sbd_min_4567[row]);
3190 }
3191 } // for b
3192
3193 for (int i = 0; i < q8_k_blocklen; i++) {
3194 int row = y * q8_k_blocklen + i;
3195 for (int j = 0; j < 2; j++) {
3196 int col = x * ncols_interleaved + j * 4;
3197 int offset = row * bs + col;
3198 vst1q_f32(s + offset, acc_f32[2 * i + j]);
3199 }
3200 }
3201 } // for x
3202 } // for y
3203 return;
3204#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
3205 ggml_gemm_q4_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);
3206}
3207
3208void ggml_gemm_q4_K_8x8_q8_K(int n,
3209 float * GGML_RESTRICT s,
3210 size_t bs,
3211 const void * GGML_RESTRICT vx,
3212 const void * GGML_RESTRICT vy,
3213 int nr,
3214 int nc) {
3215 constexpr int qk = QK_K;
3216 const int nb = n / qk;
3217
3218 constexpr int ncols_interleaved = 8;
3219 constexpr int blocklen = 8;
3220
3221 assert(n % qk == 0);
3222 assert(nr % 4 == 0);
3223 assert(nc % ncols_interleaved == 0);
3224
3225 UNUSED(nb);
3226 UNUSED(ncols_interleaved);
3227 UNUSED(blocklen);
3228
3229#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
3230 constexpr int q8_k_blocklen = 4;
3231 const uint8x16_t m4b = vdupq_n_u8(0x0f);
3232
    // 8 accumulators: 4 q8 rows × 2 column groups (0-3, 4-7)
3234 float32x4_t acc_f32[blocklen];
3235
3236 for (int y = 0; y < nr / q8_k_blocklen; y++) {
3237 const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
3238
3239 for (int x = 0; x < nc / ncols_interleaved; x++) {
3240 const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);
3241
3242 for (int i = 0; i < blocklen; i++) {
3243 acc_f32[i] = vdupq_n_f32(0);
3244 }
3245
3246 for (int b = 0; b < nb; b++) {
                // Adjacent bsums pairs belong to the same 32-value q8_k subblock, so
                // vpaddq merges them
3248 const int16x8_t bsums[4]{
3249 vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)),
3250 vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)),
3251 vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)),
3252 vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)),
3253 };
3254 int16_t bsums_arr[4][8];
3255 for (int q8_row = 0; q8_row < 4; q8_row++) {
3256 vst1q_s16(bsums_arr[q8_row], bsums[q8_row]);
3257 }
3258
3259 int32x4_t sb_acc[4]; // Aux accumulators to store subblock (partial) results
3260 int32x4_t acc[8]; // rows 01 stored in [0][1][2][3] rows 23 stored in [4][5][6][7]
3261 int32x4_t bias_acc[8]; // interleaved bias_acc: [0]->r0 0123, [1]->r0 4567, [2]->r1 0123 ...
3262 for (int i = 0; i < 8; i++) {
3263 acc[i] = vdupq_n_s32(0);
3264 bias_acc[i] = vdupq_n_s32(0);
3265 }
3266
3267 for (int sb = 0; sb < QK_K / 64; sb++) {
3268 // Need scales for the low and high nibbles
3269 // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
3270 int8_t q4sb_scales[2][8];
                    int16x8_t q4sb_mins[2]; // int16, as it's needed for bias_acc later
3272 for (int i = 0; i < 2; i++) {
3273 const int offset = sb * 24 + i * 12;
3274 decode_q_Kx8_6bit_scales(&q4_ptr[b].scales[offset], &q4sb_mins[i], q4sb_scales[i]);
3275 }
3276
3277 // q8_ptr[b].qs has interleaved Q8 rows (01, 23)
3278 const int8_t * q8_base = q8_ptr[b].qs + sb * 256;
3279
3280 int8x16_t q8_qs_01[8];
3281 int8x16_t q8_qs_23[8];
3282
                    // Load 32 bytes per row pair, one subblock at a time
3284 for (int i = 0; i < 8; i++) {
3285 const int offset = i * 32; // 16 for row 01, 16 for row 23
3286 q8_qs_01[i] = vld1q_s8(q8_base + offset);
3287 q8_qs_23[i] = vld1q_s8(q8_base + offset + 16);
3288 }
3289
3290 const int8x16_t q8s[2][8] = {
3291 { q8_qs_01[0], q8_qs_01[1], q8_qs_01[2], q8_qs_01[3],
3292 q8_qs_01[4], q8_qs_01[5], q8_qs_01[6], q8_qs_01[7] },
3293 { q8_qs_23[0], q8_qs_23[1], q8_qs_23[2], q8_qs_23[3],
3294 q8_qs_23[4], q8_qs_23[5], q8_qs_23[6], q8_qs_23[7] },
3295 };
3296
3297 // Q4s columns iterated in pairs (01, 23, 45, 67)
3298 for (int cp = 0; cp < ncols_interleaved / 2; cp++) {
3299 for (int i = 0; i < 4; i++) {
3300 sb_acc[i] = vdupq_n_s32(0);
3301 }
3302
3303 uint8x16_t q4_qs_cp_0 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 0); // 0 .. 7 & 32..39
3304 uint8x16_t q4_qs_cp_1 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 64); // 8 ..15 & 40..47
3305 uint8x16_t q4_qs_cp_2 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 128); // 16..23 & 48..55
3306 uint8x16_t q4_qs_cp_3 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 192); // 24..31 & 56..63
3307 const int8x16_t q4_nibbles[2][4] = {
3308 {
3309 vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_0, m4b)),
3310 vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_1, m4b)),
3311 vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_2, m4b)),
3312 vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_3, m4b)),
3313 },
3314 {
3315 vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_0, 4)),
3316 vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_1, 4)),
3317 vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_2, 4)),
3318 vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_3, 4)),
3319 }
3320 };
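                        // vmmlaq_s32 (SMMLA) treats each int8x16 operand as a 2x8 row-major
                        // matrix and accumulates the 2x2 product A * B^T into the int32x4:
                        // with A = a q4 column pair and B = a q8 row pair, one instruction
                        // yields lanes { c0·r0, c0·r1, c1·r0, c1·r1 }.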
3321
                        // Multiply-add the qs for each q8 row pair (rp: rows 01, then 23)
                        // and each 32-value half of the subblock (blk: low, then high nibbles)
3324 for (int rp = 0; rp < 2; rp++) {
3325 for (int blk = 0; blk < 2; blk++) {
3326 const int8x16_t * q8 = &q8s[rp][4 * blk];
3327 const int8x16_t * q4 = q4_nibbles[blk];
3328 int32x4_t acc = sb_acc[2 * rp + blk];
3329 // mul add for each qs in the same subblock
3330 for (int qs_offset = 0; qs_offset < 4; qs_offset++) {
3331 acc = vmmlaq_s32(acc, q4[qs_offset], q8[qs_offset]);
3332 }
3333 sb_acc[2 * rp + blk] = acc;
3334 }
3335 }
3336
3337 // Scales[i] corresponds to column i
3338 const int scale_offset = cp * 2;
3339 const int32_t scale_00 = q4sb_scales[0][scale_offset];
3340 const int32_t scale_01 = q4sb_scales[0][scale_offset + 1];
3341 const int32_t scale_10 = q4sb_scales[1][scale_offset];
3342 const int32_t scale_11 = q4sb_scales[1][scale_offset + 1];
3343 const int32x4_t block_scale_0 = vcombine_s32(vdup_n_s32(scale_00), vdup_n_s32(scale_01));
3344 const int32x4_t block_scale_1 = vcombine_s32(vdup_n_s32(scale_10), vdup_n_s32(scale_11));
3345
3346 acc[cp] = vmlaq_s32(acc[cp], sb_acc[0], block_scale_0);
3347 acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc[2], block_scale_0);
3348 acc[cp] = vmlaq_s32(acc[cp], sb_acc[1], block_scale_1);
3349 acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc[3], block_scale_1);
3350 }
3351
                    // Accumulate the bias term: bsums × mins
3353 for (int q8_row = 0; q8_row < 4; q8_row++) {
                        // Each pair of subblocks shares the same bsums
                        // Load the scalar bsum and broadcast it to a vector (vdup_n_s16)
3356 int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][q8_row * 2]);
3357 int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][q8_row * 2 + 1]);
3358
3359 bias_acc[2 * q8_row] =
3360 vmlal_s16(bias_acc[2 * q8_row], bsums_vec_lo, vget_low_s16(q4sb_mins[0]));
3361 bias_acc[2 * q8_row] =
3362 vmlal_s16(bias_acc[2 * q8_row], bsums_vec_hi, vget_low_s16(q4sb_mins[1]));
3363 bias_acc[2 * q8_row + 1] =
3364 vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_lo, vget_high_s16(q4sb_mins[0]));
3365 bias_acc[2 * q8_row + 1] =
3366 vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_hi, vget_high_s16(q4sb_mins[1]));
3367 }
3368 } // for sb
3369
                // Reorder the i8mm 2x2 tiles into the bias and output layout
3371 for (int i = 0; i < 8; i++) {
3372 int32x2x2_t aux = vzip_s32(vget_low_s32(acc[i]), vget_high_s32(acc[i]));
3373 acc[i] = vcombine_s32(aux.val[0], aux.val[1]);
3374 }
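                // After the zip, acc[i] holds { c_even·r_even, c_odd·r_even,
                // c_even·r_odd, c_odd·r_odd }; the vcombine pairs below then gather
                // four consecutive columns of a single row into each vector.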
3375 int32x4_t reorder_acc[8] = {
3376 vcombine_s32(vget_low_s32(acc[0]), vget_low_s32(acc[1])),
3377 vcombine_s32(vget_low_s32(acc[2]), vget_low_s32(acc[3])),
3378 vcombine_s32(vget_high_s32(acc[0]), vget_high_s32(acc[1])),
3379 vcombine_s32(vget_high_s32(acc[2]), vget_high_s32(acc[3])),
3380 vcombine_s32(vget_low_s32(acc[4]), vget_low_s32(acc[5])),
3381 vcombine_s32(vget_low_s32(acc[6]), vget_low_s32(acc[7])),
3382 vcombine_s32(vget_high_s32(acc[4]), vget_high_s32(acc[5])),
3383 vcombine_s32(vget_high_s32(acc[6]), vget_high_s32(acc[7])),
3384 };
3385
3386 for (int i = 0; i < q8_k_blocklen; i++) {
3387 for (int j = 0; j < 2; j++) {
3388 float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d[i]);
3389 float32x4_t q4_dmin = vcvt_f32_f16(vld1_f16((const __fp16 *) (q4_ptr[b].dmin + j * 4)));
3390 const float32x4_t dmins = vmulq_f32(q4_dmin, q8_d);
3391
3392 float32x4_t q4_d = vcvt_f32_f16(vld1_f16((const __fp16 *) (q4_ptr[b].d + j * 4)));
3393 const float32x4_t scale = vmulq_f32(q4_d, q8_d);
3394
3395 acc_f32[2 * i + j] = vmlsq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(bias_acc[2 * i + j]), dmins);
3396 acc_f32[2 * i + j] =
3397 vmlaq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(reorder_acc[2 * i + j]), scale);
3398 }
3399 }
3400 } // for b
3401
3402 // With the previous reorder, the tile is already in the correct memory layout.
3403 for (int i = 0; i < q8_k_blocklen; i++) {
3404 int row = y * q8_k_blocklen + i;
3405 for (int j = 0; j < 2; j++) {
3406 int col = x * ncols_interleaved + j * 4;
3407 int offset = row * bs + col;
3408 vst1q_f32(s + offset, acc_f32[2 * i + j]);
3409 }
3410 }
3411 } // for x
3412 } // for y
3413 return;
3414#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
3415 ggml_gemm_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
3416}
3417
3418void ggml_gemm_q5_K_8x8_q8_K(int n,
3419 float * GGML_RESTRICT s,
3420 size_t bs,
3421 const void * GGML_RESTRICT vx,
3422 const void * GGML_RESTRICT vy,
3423 int nr,
3424 int nc) {
3425 constexpr int qk = QK_K;
3426 const int nb = n / qk;
3427
3428 constexpr int ncols_interleaved = 8;
3429 constexpr int blocklen = 8;
3430
3431 assert(n % qk == 0);
3432 assert(nr % 4 == 0);
3433 assert(nc % ncols_interleaved == 0);
3434
3435 UNUSED(nb);
3436 UNUSED(ncols_interleaved);
3437 UNUSED(blocklen);
3438
3439#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
3440 constexpr int q8_k_blocklen = 4;
3441 constexpr int col_pairs = ncols_interleaved / 2;
3442 const uint8x16_t m4b = vdupq_n_u8(0x0f);
3443 const uint8x16_t mone = vdupq_n_u8(1);
3444 const uint8x16_t mtwo = vdupq_n_u8(2);
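    // Q5_K uses the same superblock layout as Q4_K (6-bit scales and mins) plus
    // one extra high bit per weight stored in qh:
    //   q = (nibble | hbit << 4), dequantized as x = d * sc * q - dmin * m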
3445
    // 8 accumulators: 4 q8 rows × 2 column groups (0-3, 4-7)
3447 float32x4_t acc_f32[blocklen];
3448
3449 for (int y = 0; y < nr / q8_k_blocklen; y++) {
3450 const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
3451
3452 for (int x = 0; x < nc / ncols_interleaved; x++) {
3453 const block_q5_Kx8 * GGML_RESTRICT q5_ptr = (const block_q5_Kx8 *) vx + (x * nb);
3454
3455 for (int i = 0; i < blocklen; i++) {
3456 acc_f32[i] = vdupq_n_f32(0);
3457 }
3458
3459 for (int b = 0; b < nb; b++) {
                // Adjacent bsums pairs belong to the same 32-value q8_k subblock, so
                // vpaddq merges them
3461 const int16x8_t bsums[4]{
3462 vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)),
3463 vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)),
3464 vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)),
3465 vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)),
3466 };
3467 int16_t bsums_arr[4][8];
3468 for (int q8_row = 0; q8_row < 4; q8_row++) {
3469 vst1q_s16(bsums_arr[q8_row], bsums[q8_row]);
3470 }
3471
3472 int32x4_t sb_acc[4]; // Aux accumulators to store subblock (partial) results
3473 int32x4_t acc[8]; // rows 01 stored in [0][1][2][3] rows 23 stored in [4][5][6][7]
3474 int32x4_t bias_acc[8]; // interleaved bias_acc: [0]->r0 0123, [1]->r0 4567, [2]->r1 0123 ...
3475 for (int i = 0; i < 8; i++) {
3476 acc[i] = vdupq_n_s32(0);
3477 bias_acc[i] = vdupq_n_s32(0);
3478 }
3479
3480 // Load qh once per block and shift after each subblock
3481 const uint8_t * qh_base = q5_ptr[b].qh;
3482 uint8x16_t qh[col_pairs][4];
3483 for (int cp = 0; cp < col_pairs; cp++) {
3484 qh[cp][0] = vld1q_u8(qh_base + 16 * cp);
3485 qh[cp][1] = vld1q_u8(qh_base + 16 * cp + 64);
3486 qh[cp][2] = vld1q_u8(qh_base + 16 * cp + 128);
3487 qh[cp][3] = vld1q_u8(qh_base + 16 * cp + 192);
3488 }
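                // Each subblock consumes two bits of every qh byte: bit 0 (mone)
                // supplies the high bit for the low-nibble values and bit 1 (mtwo)
                // for the high-nibble values; qh is shifted right by 2 after each
                // subblock to expose the next pair of bits.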
3489
3490 for (int sb = 0; sb < QK_K / 64; sb++) {
3491 // Need scales for the low and high nibbles
3492 // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
3493 int8_t q5sb_scales[2][8];
                    int16x8_t q5sb_mins[2]; // int16, as it's needed for bias_acc later
3495 for (int i = 0; i < 2; i++) {
3496 const int offset = sb * 24 + i * 12;
3497 decode_q_Kx8_6bit_scales(&q5_ptr[b].scales[offset], &q5sb_mins[i], q5sb_scales[i]);
3498 }
3499
3500 // q8_ptr[b].qs has interleaved Q8 rows (01, 23)
3501 const int8_t * q8_base = q8_ptr[b].qs + sb * 256;
3502
3503 int8x16_t q8_qs_01[8];
3504 int8x16_t q8_qs_23[8];
3505
                    // Load 32 bytes per row pair, one subblock at a time
3507 for (int i = 0; i < 8; i++) {
3508 const int offset = i * 32; // 16 for row 01, 16 for row 23
3509 q8_qs_01[i] = vld1q_s8(q8_base + offset);
3510 q8_qs_23[i] = vld1q_s8(q8_base + offset + 16);
3511 }
3512
3513 const int8x16_t q8s[2][8] = {
3514 { q8_qs_01[0], q8_qs_01[1], q8_qs_01[2], q8_qs_01[3], q8_qs_01[4], q8_qs_01[5], q8_qs_01[6],
3515 q8_qs_01[7] },
3516 { q8_qs_23[0], q8_qs_23[1], q8_qs_23[2], q8_qs_23[3], q8_qs_23[4], q8_qs_23[5], q8_qs_23[6],
3517 q8_qs_23[7] },
3518 };
3519
3520 // Q5s columns iterated in pairs (01, 23, 45, 67)
3521 for (int cp = 0; cp < col_pairs; cp++) {
3522 for (int i = 0; i < 4; i++) {
3523 sb_acc[i] = vdupq_n_s32(0);
3524 }
3525
3526 uint8x16_t qs_cp_0 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 16 * cp + 0); // 0 .. 7 & 32..39
3527 uint8x16_t qs_cp_1 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 16 * cp + 64); // 8 ..15 & 40..47
3528 uint8x16_t qs_cp_2 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 16 * cp + 128); // 16..23 & 48..55
3529 uint8x16_t qs_cp_3 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 16 * cp + 192); // 24..31 & 56..63
3530
                        // This is the only part of the algorithm that differs from Q4_K:
                        // extract the high bits and pack them into 5-bit weights
3533 uint8x16_t hbit_lo_0 = vandq_u8(qh[cp][0], mone);
3534 uint8x16_t hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[cp][0], mtwo), 3);
3535 qh[cp][0] = vshrq_n_u8(qh[cp][0], 2);
                        // From here on it matches Q4_K: i8mm multiply-accumulates the packed weights.
3537 const int8x16_t qs_lo_0 = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_cp_0, m4b), hbit_lo_0, 4));
3538 int32x4_t acc_0 = sb_acc[0];
3539 acc_0 = vmmlaq_s32(acc_0, qs_lo_0, q8s[0][0]);
3540 int32x4_t acc_2 = sb_acc[2];
3541 acc_2 = vmmlaq_s32(acc_2, qs_lo_0, q8s[1][0]);
3542 const int8x16_t qs_hi_0 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_0, 4), hbit_hi_0));
3543 int32x4_t acc_1 = sb_acc[1];
3544 acc_1 = vmmlaq_s32(acc_1, qs_hi_0, q8s[0][4]);
3545 int32x4_t acc_3 = sb_acc[3];
3546 acc_3 = vmmlaq_s32(acc_3, qs_hi_0, q8s[1][4]);
3547
                        // Repeat for the remaining three loads (values 8..15, 16..23, 24..31)
3549 uint8x16_t hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[cp][1], mtwo), 3);
3550 uint8x16_t hbit_lo_1 = vandq_u8(qh[cp][1], mone);
3551 qh[cp][1] = vshrq_n_u8(qh[cp][1], 2);
3552 const int8x16_t qs_lo_1 = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_cp_1, m4b), hbit_lo_1, 4));
3553 acc_0 = vmmlaq_s32(acc_0, qs_lo_1, q8s[0][1]);
3554 acc_2 = vmmlaq_s32(acc_2, qs_lo_1, q8s[1][1]);
3555 const int8x16_t qs_hi_1 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_1, 4), hbit_hi_1));
3556 acc_1 = vmmlaq_s32(acc_1, qs_hi_1, q8s[0][5]);
3557 acc_3 = vmmlaq_s32(acc_3, qs_hi_1, q8s[1][5]);
3558
3559 uint8x16_t hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[cp][2], mtwo), 3);
3560 uint8x16_t hbit_lo_2 = vandq_u8(qh[cp][2], mone);
3561 qh[cp][2] = vshrq_n_u8(qh[cp][2], 2);
3562 const int8x16_t qs_lo_2 = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_cp_2, m4b), hbit_lo_2, 4));
3563 acc_0 = vmmlaq_s32(acc_0, qs_lo_2, q8s[0][2]);
3564 acc_2 = vmmlaq_s32(acc_2, qs_lo_2, q8s[1][2]);
3565 const int8x16_t qs_hi_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_2, 4), hbit_hi_2));
3566 acc_1 = vmmlaq_s32(acc_1, qs_hi_2, q8s[0][6]);
3567 acc_3 = vmmlaq_s32(acc_3, qs_hi_2, q8s[1][6]);
3568
3569 uint8x16_t hbit_lo_3 = vandq_u8(qh[cp][3], mone);
3570 uint8x16_t hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[cp][3], mtwo), 3);
3571 qh[cp][3] = vshrq_n_u8(qh[cp][3], 2);
3572 const int8x16_t qs_lo_3 = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_cp_3, m4b), hbit_lo_3, 4));
3573 acc_0 = vmmlaq_s32(acc_0, qs_lo_3, q8s[0][3]);
3574 sb_acc[0] = acc_0;
3575 acc_2 = vmmlaq_s32(acc_2, qs_lo_3, q8s[1][3]);
3576 sb_acc[2] = acc_2;
3577
3578 // Scales[i] corresponds to column i
3579 const int scale_offset = cp * 2;
3580 const int32_t s0 = q5sb_scales[0][scale_offset];
3581 const int32_t s1 = q5sb_scales[0][scale_offset + 1];
3582 const int32x4_t block_scale = vcombine_s32(vdup_n_s32(s0), vdup_n_s32(s1));
3583 acc[cp] = vmlaq_s32(acc[cp], sb_acc[0], block_scale);
3584 acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc[2], block_scale);
3585
3586 const int8x16_t qs_hi_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_3, 4), hbit_hi_3));
3587 acc_1 = vmmlaq_s32(acc_1, qs_hi_3, q8s[0][7]);
3588 sb_acc[1] = acc_1;
3589 acc_3 = vmmlaq_s32(acc_3, qs_hi_3, q8s[1][7]);
3590 sb_acc[3] = acc_3;
3591
3592 const int32_t s2 = q5sb_scales[1][scale_offset];
3593 const int32_t s3 = q5sb_scales[1][scale_offset + 1];
3594 const int32x4_t block_scale2 = vcombine_s32(vdup_n_s32(s2), vdup_n_s32(s3));
3595 acc[cp] = vmlaq_s32(acc[cp], sb_acc[1], block_scale2);
3596 acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc[3], block_scale2);
3597 }
3598
                    // Accumulate the bias term: bsums × mins
3600 for (int q8_row = 0; q8_row < 4; q8_row++) {
                        // Each pair of subblocks shares the same bsums
                        // Load the scalar bsum and broadcast it to a vector (vdup_n_s16)
3603 int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][q8_row * 2]);
3604 int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][q8_row * 2 + 1]);
3605
3606 bias_acc[2 * q8_row] =
3607 vmlal_s16(bias_acc[2 * q8_row], bsums_vec_lo, vget_low_s16(q5sb_mins[0]));
3608 bias_acc[2 * q8_row] =
3609 vmlal_s16(bias_acc[2 * q8_row], bsums_vec_hi, vget_low_s16(q5sb_mins[1]));
3610 bias_acc[2 * q8_row + 1] =
3611 vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_lo, vget_high_s16(q5sb_mins[0]));
3612 bias_acc[2 * q8_row + 1] =
3613 vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_hi, vget_high_s16(q5sb_mins[1]));
3614 }
3615 } // for sb
3616
                // Reorder the i8mm 2x2 tiles into the bias and output layout
3618 for (int i = 0; i < 8; i++) {
3619 int32x2x2_t aux = vzip_s32(vget_low_s32(acc[i]), vget_high_s32(acc[i]));
3620 acc[i] = vcombine_s32(aux.val[0], aux.val[1]);
3621 }
3622 int32x4_t reorder_acc[8] = {
3623 vcombine_s32(vget_low_s32(acc[0]), vget_low_s32(acc[1])),
3624 vcombine_s32(vget_low_s32(acc[2]), vget_low_s32(acc[3])),
3625 vcombine_s32(vget_high_s32(acc[0]), vget_high_s32(acc[1])),
3626 vcombine_s32(vget_high_s32(acc[2]), vget_high_s32(acc[3])),
3627 vcombine_s32(vget_low_s32(acc[4]), vget_low_s32(acc[5])),
3628 vcombine_s32(vget_low_s32(acc[6]), vget_low_s32(acc[7])),
3629 vcombine_s32(vget_high_s32(acc[4]), vget_high_s32(acc[5])),
3630 vcombine_s32(vget_high_s32(acc[6]), vget_high_s32(acc[7])),
3631 };
3632
3633 for (int i = 0; i < q8_k_blocklen; i++) {
3634 for (int j = 0; j < 2; j++) {
3635 float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d[i]);
3636 float32x4_t q5_dmin = vcvt_f32_f16(vld1_f16((const __fp16 *) (q5_ptr[b].dmin + j * 4)));
3637 const float32x4_t dmins = vmulq_f32(q5_dmin, q8_d);
3638
3639 float32x4_t q5_d = vcvt_f32_f16(vld1_f16((const __fp16 *) (q5_ptr[b].d + j * 4)));
3640 const float32x4_t scale = vmulq_f32(q5_d, q8_d);
3641
3642 acc_f32[2 * i + j] = vmlsq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(bias_acc[2 * i + j]), dmins);
3643 acc_f32[2 * i + j] =
3644 vmlaq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(reorder_acc[2 * i + j]), scale);
3645 }
3646 }
3647 } // for b
3648
3649 // With the previous reorder, the tile is already in the correct memory layout.
3650 for (int i = 0; i < q8_k_blocklen; i++) {
3651 int row = y * q8_k_blocklen + i;
3652 for (int j = 0; j < 2; j++) {
3653 int col = x * ncols_interleaved + j * 4;
3654 int offset = row * bs + col;
3655 vst1q_f32(s + offset, acc_f32[2 * i + j]);
3656 }
3657 }
3658 } // for x
3659 } // for y
3660 return;
3661#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
3662 ggml_gemm_q5_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
3663}
3664
3665void ggml_gemm_q6_K_8x4_q8_K(int n,
3666 float * GGML_RESTRICT s,
3667 size_t bs,
3668 const void * GGML_RESTRICT vx,
3669 const void * GGML_RESTRICT vy,
3670 int nr,
3671 int nc) {
3672 constexpr int qk = QK_K;
3673 const int nb = n / qk;
3674
3675 constexpr int ncols_interleaved = 8;
3676 constexpr int blocklen = 4;
3677
3678 assert(n % qk == 0);
3679 assert(nr % 4 == 0);
3680 assert(nc % ncols_interleaved == 0);
3681
3682 UNUSED(nb);
3683 UNUSED(ncols_interleaved);
3684 UNUSED(blocklen);
3685
3686#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
3687 constexpr int q8_k_blocklen = 4;
3688 constexpr int col_groups = ncols_interleaved / 4;
3689 constexpr int acc_size = q8_k_blocklen * col_groups; // 4 rows, 2 column groups
3690 const uint8x16_t m4b = vdupq_n_u8(0x0f);
3691 const uint8x16_t mask_lo = vdupq_n_u8(0x03);
3692 const uint8x16_t mask_hi = vdupq_n_u8(0x30);
3693 const int8x16_t m32s = vdupq_n_s8(32);
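    // Q6_K packs each weight as 4 low bits (ql) and 2 high bits (qh), with one
    // signed 8-bit scale per 16 values and no mins:
    //   q = (ql | qh << 4) - 32, dequantized as x = d * sc * q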
3694
3695 float32x4_t acc_f32[acc_size];
3696
3697 for (int y = 0; y < nr / q8_k_blocklen; y++) {
3698 const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
3699
3700 for (int x = 0; x < nc / ncols_interleaved; x++) {
3701 const block_q6_Kx8 * GGML_RESTRICT q6_ptr = (const block_q6_Kx8 *) vx + (x * nb);
3702
3703 for (int i = 0; i < acc_size; i++) {
3704 acc_f32[i] = vdupq_n_f32(0);
3705 }
3706
3707 for (int b = 0; b < nb; b++) {
3708 float32x4_t q6_d_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d));
3709 float32x4_t q6_d_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d + 4));
3710 float32x4_t q8_d_0123 = vld1q_f32(q8_ptr[b].d);
3711
3712 float32x4_t sbd_scale_0123[q8_k_blocklen];
3713 float32x4_t sbd_scale_4567[q8_k_blocklen];
3714
3715 sbd_scale_0123[0] = vmulq_laneq_f32(q6_d_0123, q8_d_0123, 0);
3716 sbd_scale_4567[0] = vmulq_laneq_f32(q6_d_4567, q8_d_0123, 0);
3717 sbd_scale_0123[1] = vmulq_laneq_f32(q6_d_0123, q8_d_0123, 1);
3718 sbd_scale_4567[1] = vmulq_laneq_f32(q6_d_4567, q8_d_0123, 1);
3719 sbd_scale_0123[2] = vmulq_laneq_f32(q6_d_0123, q8_d_0123, 2);
3720 sbd_scale_4567[2] = vmulq_laneq_f32(q6_d_4567, q8_d_0123, 2);
3721 sbd_scale_0123[3] = vmulq_laneq_f32(q6_d_0123, q8_d_0123, 3);
3722 sbd_scale_4567[3] = vmulq_laneq_f32(q6_d_4567, q8_d_0123, 3);
3723
3724 int32x4_t acc_s32[acc_size];
3725 for (int i = 0; i < acc_size; i++) {
3726 acc_s32[i] = vdupq_n_s32(0);
3727 }
3728
3729 int16_t q6_scales[8 * 16];
3730 for (int i = 0; i < 16; i++) {
3731 int16x8_t scales = vmovl_s8(vld1_s8(q6_ptr[b].scales + i * 8));
3732 vst1q_s16(q6_scales + i * 8, scales);
3733 }
3734
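                // Each superblock is split into two 128-value halves: ql advances by
                // 512 bytes per half (8 columns × 64 bytes of packed low bits), qh by 256.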
3735 for (int half = 0; half < 2; half++) {
3736 const uint8_t * ql_base = q6_ptr[b].ql + half * 512;
3737 const uint8_t * qh_base = q6_ptr[b].qh + half * 256;
3738
3739 for (int sb = 0; sb < QK_K / 64; sb++) {
3740 int32x4_t acc_lo[acc_size];
3741 int32x4_t acc_hi[acc_size];
3742 for (int i = 0; i < acc_size; i++) {
3743 acc_lo[i] = vdupq_n_s32(0);
3744 acc_hi[i] = vdupq_n_s32(0);
3745 }
3746
3747 const int8_t * q8_base_l = q8_ptr[b].qs + half * 512 + sb * 64;
3748 const int8_t * q8_base_h = q8_ptr[b].qs + half * 512 + 256 + sb * 64;
3749
3750 // 4 rows * 16 elements per scale
3751 // 4 reads of 16 bytes each
3752 constexpr int reads_per_sb = 4;
3753 int8x16_t q8_l[reads_per_sb];
3754 int8x16_t q8_h[reads_per_sb];
3755 for (int k = 0; k < reads_per_sb; k++) {
3756 q8_l[k] = vld1q_s8(q8_base_l + 16 * k);
3757 q8_h[k] = vld1q_s8(q8_base_h + 16 * k);
3758 }
3759
3760 const int ql_off_base = sb * QK_K / 2;
3761 const int qh_off_base = ql_off_base & 255;
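                    // qh holds 2 bits per weight (half the size of ql), so the offset
                    // wraps after 256 bytes; subblocks 2 and 3 revisit the same bytes and
                    // take the upper bit pair via the shift below.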
3762
3763 uint8x16_t q6_ql_0123[reads_per_sb];
3764 uint8x16_t q6_ql_4567[reads_per_sb];
3765 uint8x16_t q6_qh_0123[reads_per_sb];
3766 uint8x16_t q6_qh_4567[reads_per_sb];
3767
3768 for (int k = 0; k < reads_per_sb; k++) {
3769 q6_ql_0123[k] = vld1q_u8(ql_base + ql_off_base + k * 32);
3770 q6_ql_4567[k] = vld1q_u8(ql_base + ql_off_base + k * 32 + 16);
3771 q6_qh_0123[k] = vld1q_u8(qh_base + qh_off_base + k * 32);
3772 q6_qh_4567[k] = vld1q_u8(qh_base + qh_off_base + k * 32 + 16);
3773 }
3774
3775 if (sb > 1) {
3776 for (int k = 0; k < reads_per_sb; k++) {
3777 q6_qh_0123[k] = vshrq_n_u8(q6_qh_0123[k], 2);
3778 q6_qh_4567[k] = vshrq_n_u8(q6_qh_4567[k], 2);
3779 }
3780 }
3781
3782 for (int k = 0; k < reads_per_sb; k++) {
3783 // q = (ql | qh) - 32
3784 const uint8x16_t hbit_lo_0123 = vandq_u8(q6_qh_0123[k], mask_lo);
3785 const uint8x16_t hbit_hi_0123 = vandq_u8(q6_qh_0123[k], mask_hi);
3786 const uint8x16_t hbit_lo_4567 = vandq_u8(q6_qh_4567[k], mask_lo);
3787 const uint8x16_t hbit_hi_4567 = vandq_u8(q6_qh_4567[k], mask_hi);
3788
3789 const int8x16_t q6_0123_lo = vsubq_s8(
3790 vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_ql_0123[k], m4b), hbit_lo_0123, 4)), m32s);
3791 const int8x16_t q6_0123_hi = vsubq_s8(
3792 vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_ql_0123[k], 4), hbit_hi_0123)), m32s);
3793
3794 acc_lo[0] = vdotq_laneq_s32(acc_lo[0], q6_0123_lo, q8_l[k], 0); // 0..3 r0 c0123
3795 acc_lo[1] = vdotq_laneq_s32(acc_lo[1], q6_0123_lo, q8_l[k], 1); // 0..3 r1 c0123
3796 acc_lo[2] = vdotq_laneq_s32(acc_lo[2], q6_0123_lo, q8_l[k], 2); // 0..3 r2 c0123
3797 acc_lo[3] = vdotq_laneq_s32(acc_lo[3], q6_0123_lo, q8_l[k], 3); // 0..3 r3 c0123
3798
3799 acc_hi[0] = vdotq_laneq_s32(acc_hi[0], q6_0123_hi, q8_h[k], 0); // 64..67 r0 c0123
3800 acc_hi[1] = vdotq_laneq_s32(acc_hi[1], q6_0123_hi, q8_h[k], 1); // 64..67 r1 c0123
3801 acc_hi[2] = vdotq_laneq_s32(acc_hi[2], q6_0123_hi, q8_h[k], 2); // 64..67 r2 c0123
3802 acc_hi[3] = vdotq_laneq_s32(acc_hi[3], q6_0123_hi, q8_h[k], 3); // 64..67 r3 c0123
3803
3804 const int8x16_t q6_4567_lo = vsubq_s8(
3805 vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_ql_4567[k], m4b), hbit_lo_4567, 4)), m32s);
3806 const int8x16_t q6_4567_hi = vsubq_s8(
3807 vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_ql_4567[k], 4), hbit_hi_4567)), m32s);
3808
3809 acc_lo[4] = vdotq_laneq_s32(acc_lo[4], q6_4567_lo, q8_l[k], 0); // 0..3 r0 c4567
3810 acc_lo[5] = vdotq_laneq_s32(acc_lo[5], q6_4567_lo, q8_l[k], 1); // 0..3 r1 c4567
3811 acc_lo[6] = vdotq_laneq_s32(acc_lo[6], q6_4567_lo, q8_l[k], 2); // 0..3 r2 c4567
3812 acc_lo[7] = vdotq_laneq_s32(acc_lo[7], q6_4567_lo, q8_l[k], 3); // 0..3 r3 c4567
3813
3814 acc_hi[4] = vdotq_laneq_s32(acc_hi[4], q6_4567_hi, q8_h[k], 0); // 64..67 r0 c4567
3815 acc_hi[5] = vdotq_laneq_s32(acc_hi[5], q6_4567_hi, q8_h[k], 1); // 64..67 r1 c4567
3816 acc_hi[6] = vdotq_laneq_s32(acc_hi[6], q6_4567_hi, q8_h[k], 2); // 64..67 r2 c4567
3817 acc_hi[7] = vdotq_laneq_s32(acc_hi[7], q6_4567_hi, q8_h[k], 3); // 64..67 r3 c4567
3818 }
3819
3820 // Scale and bias
3821 const int scale_idx_l = half * 8 + sb;
3822 const int scale_idx_h = half * 8 + sb + 4;
3823
3824 for (int g = 0; g < col_groups; g++) {
3825 const int16x4_t scales_l16 = vld1_s16(q6_scales + scale_idx_l * 8 + g * 4);
3826 const int16x4_t scales_h16 = vld1_s16(q6_scales + scale_idx_h * 8 + g * 4);
3827 const int32x4_t scale_vec_l = vmovl_s16(scales_l16);
3828 const int32x4_t scale_vec_h = vmovl_s16(scales_h16);
3829 const int acc_offset = g * q8_k_blocklen;
3830
3831 for (int row = 0; row < q8_k_blocklen; row++) {
3832 const int idx = row * 2 + g;
3833 acc_s32[idx] = vmlaq_s32(acc_s32[idx], acc_lo[acc_offset + row], scale_vec_l);
3834 acc_s32[idx] = vmlaq_s32(acc_s32[idx], acc_hi[acc_offset + row], scale_vec_h);
3835 }
3836 }
3837 }
3838 }
3839
3840 // Finally we apply the superblock scales
3841 for (int row = 0; row < q8_k_blocklen; row++) {
3842 const int idx0 = 2 * row;
3843 const int idx1 = 2 * row + 1;
3844 const int32x4_t acc_0123 = acc_s32[idx0];
3845 const int32x4_t acc_4567 = acc_s32[idx1];
3846
3847 acc_f32[idx0] = vmlaq_f32(acc_f32[idx0], vcvtq_f32_s32(acc_0123), sbd_scale_0123[row]);
3848 acc_f32[idx1] = vmlaq_f32(acc_f32[idx1], vcvtq_f32_s32(acc_4567), sbd_scale_4567[row]);
3849 }
3850 } // for b
3851
3852 for (int i = 0; i < q8_k_blocklen; i++) {
3853 int row = y * q8_k_blocklen + i;
3854 for (int j = 0; j < 2; j++) {
3855 int col = x * ncols_interleaved + j * 4;
3856 int offset = row * bs + col;
3857 vst1q_f32(s + offset, acc_f32[2 * i + j]);
3858 }
3859 }
3860 } // for x
3861 } // for y
3862 return;
3863#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
3864 ggml_gemm_q6_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);
3865}
3866
3867void ggml_gemm_q6_K_8x8_q8_K(int n,
3868 float * GGML_RESTRICT s,
3869 size_t bs,
3870 const void * GGML_RESTRICT vx,
3871 const void * GGML_RESTRICT vy,
3872 int nr,
3873 int nc) {
3874 constexpr int qk = QK_K;
3875 const int nb = n / qk;
3876
3877 constexpr int ncols_interleaved = 8;
3878 constexpr int blocklen = 8;
3879
3880 assert(n % qk == 0);
3881 assert(nr % 4 == 0);
3882 assert(nc % ncols_interleaved == 0);
3883
3884 UNUSED(nb);
3885 UNUSED(ncols_interleaved);
3886 UNUSED(blocklen);
3887
3888#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
3889 constexpr int q8_k_blocklen = 4;
3890 const uint8x16_t m4b = vdupq_n_u8(0x0f);
3891 const uint8x16_t mask_lo = vdupq_n_u8(0x03);
3892 const uint8x16_t mask_hi = vdupq_n_u8(0x30);
3893 const int8x16_t m32s = vdupq_n_s8(32);
3894
3895 // 8 accumulators: 4 q8 rows × 2 col groups (0-3, 4-7)
3896 float32x4_t acc_f32[blocklen];
3897
3898 for (int y = 0; y < nr / q8_k_blocklen; y++) {
3899 const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
3900
3901 for (int x = 0; x < nc / ncols_interleaved; x++) {
3902 const block_q6_Kx8 * GGML_RESTRICT q6_ptr = (const block_q6_Kx8 *) vx + (x * nb);
3903
3904 for (int i = 0; i < blocklen; i++) {
3905 acc_f32[i] = vdupq_n_f32(0);
3906 }
3907
3908 for (int b = 0; b < nb; b++) {
3909 int32x4_t acc[8]; // rows 01 stored in [0][1][2][3], rows 23 stored in [4][5][6][7]
3910 for (int i = 0; i < 8; i++) {
3911 acc[i] = vdupq_n_s32(0);
3912 }
3913
3914 // Q6_K has simple 8-bit scales, 16 per block (one per 16 values)
                // Widened to int16 once here and reused by every subblock below
3916 int16_t q6_scales[16 * 8];
3917 for (int i = 0; i < 16; ++i) {
3918 int16x8_t s16 = vmovl_s8(vld1_s8(q6_ptr[b].scales + i * 8));
3919 vst1q_s16(q6_scales + i * 8, s16);
3920 }
3921
3922 // Process two 128-value halves per superblock
3923 for (int half = 0; half < 2; half++) {
3924
3925 const uint8_t * ql_base = q6_ptr[b].ql + half * 512;
3926 const uint8_t * qh_base = q6_ptr[b].qh + half * 256;
3927
3928 // A subblock (sb) is a set of weights that share the scale
3929 // Since q6_K scales are per 16 elements
3930 // num sbs -> 256 elements / (16 elements/scale * 2 elements/byte * 2 halves)
3931 for (int sb = 0; sb < QK_K / 64; sb++) {
                        // Q6_K keeps the high-nibble weights 64 values ahead (instead of 32),
                        // so q8 must be read from two separate regions
3934 const int8_t * q8_base_l = q8_ptr[b].qs + half * 512 + sb * 64;
3935 const int8_t * q8_base_h = q8_ptr[b].qs + half * 512 + 256 + sb * 64;
3936
3937 int8x16_t q8_l_01[2];
3938 int8x16_t q8_l_23[2];
3939 for (int i = 0; i < 2; i++) {
3940 const int offset = i * 32;
3941 q8_l_01[i] = vld1q_s8(q8_base_l + offset); // 0..7 & 8..15 (r01)
3942 q8_l_23[i] = vld1q_s8(q8_base_l + offset + 16); // 0..7 & 8..15 (r23)
3943 }
3944
3945 int8x16_t q8_h_01[2];
3946 int8x16_t q8_h_23[2];
3947 for (int i = 0; i < 2; i++) {
3948 const int offset = i * 32;
3949 q8_h_01[i] = vld1q_s8(q8_base_h + offset);
3950 q8_h_23[i] = vld1q_s8(q8_base_h + offset + 16);
3951 }
3952
3953 const int ql_off_base = sb * QK_K / 2;
3954
3955 uint8x16_t q6_ql_0[4];
3956 uint8x16_t q6_ql_1[4];
3957 for (int k = 0; k < 4; k++) {
3958 q6_ql_0[k] = vld1q_u8(ql_base + ql_off_base + 16 * k);
3959 q6_ql_1[k] = vld1q_u8(ql_base + ql_off_base + 64 + 16 * k);
3960 }
3961
3962 const int qh_off_base = (sb * QK_K / 2) & 255; // wrap after 256 bytes
3963 uint8x16_t q6_qh_0[4];
3964 uint8x16_t q6_qh_1[4];
3965 for (int k = 0; k < 4; k++) {
3966 q6_qh_0[k] = vld1q_u8(qh_base + qh_off_base + 16 * k);
3967 q6_qh_1[k] = vld1q_u8(qh_base + qh_off_base + 64 + 16 * k);
3968 }
3969
                        // Select the proper high bits for subblocks 2 and 3
3971 if (sb > 1) {
3972 for (int k = 0; k < 4; k++) {
3973 q6_qh_0[k] = vshrq_n_u8(q6_qh_0[k], 2);
3974 q6_qh_1[k] = vshrq_n_u8(q6_qh_1[k], 2);
3975 }
3976 }
3977
3978 // Process column pairs (0-1, 2-3, 4-5, 6-7)
3979 for (int cp = 0; cp < ncols_interleaved / 2; cp++) {
3980 const uint8x16_t q6_qs_cp_0_l = q6_ql_0[cp];
3981 const uint8x16_t q6_qs_cp_1_l = q6_ql_1[cp];
3982 const uint8x16_t q6_qs_cp_0_h = q6_qh_0[cp];
3983 const uint8x16_t q6_qs_cp_1_h = q6_qh_1[cp];
3984
3985 // Extract high 2 bits for upper nibble reconstruction
3986 const uint8x16_t q6_qs_cp_0_hh = vandq_u8(q6_qs_cp_0_h, mask_hi);
3987 const uint8x16_t q6_qs_cp_1_hh = vandq_u8(q6_qs_cp_1_h, mask_hi);
3988
3989 // q6 = (low4 | high2<<4) - 32
                            // vsliq_n_u8 does the shift-left and insert in one instruction (as in Q5_K)
3991 const int8x16_t q6_l0 = vsubq_s8(
3992 vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_qs_cp_0_l, m4b), vandq_u8(q6_qs_cp_0_h, mask_lo), 4)),
3993 m32s);
3994 const int8x16_t q6_l1 = vsubq_s8(
3995 vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_qs_cp_1_l, m4b), vandq_u8(q6_qs_cp_1_h, mask_lo), 4)),
3996 m32s);
3997 const int8x16_t q6_h0 = vsubq_s8(
3998 vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_qs_cp_0_l, 4), q6_qs_cp_0_hh)), m32s);
3999 const int8x16_t q6_h1 = vsubq_s8(
4000 vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_qs_cp_1_l, 4), q6_qs_cp_1_hh)), m32s);
4001
4002 // row pair 0, base_l
4003 int32x4_t sb_acc_0l = vmmlaq_s32(vdupq_n_s32(0), q6_l0, q8_l_01[0]);
4004 sb_acc_0l = vmmlaq_s32(sb_acc_0l, q6_l1, q8_l_01[1]);
4005 // row pair 0, base_h
4006 int32x4_t sb_acc_0h = vmmlaq_s32(vdupq_n_s32(0), q6_h0, q8_h_01[0]);
4007 sb_acc_0h = vmmlaq_s32(sb_acc_0h, q6_h1, q8_h_01[1]);
4008 // row pair 1, base_l
4009 int32x4_t sb_acc_1l = vmmlaq_s32(vdupq_n_s32(0), q6_l0, q8_l_23[0]);
4010 sb_acc_1l = vmmlaq_s32(sb_acc_1l, q6_l1, q8_l_23[1]);
4011 // row pair 1, base_h
4012 int32x4_t sb_acc_1h = vmmlaq_s32(vdupq_n_s32(0), q6_h0, q8_h_23[0]);
4013 sb_acc_1h = vmmlaq_s32(sb_acc_1h, q6_h1, q8_h_23[1]);
4014
4015 const int scale_idx_l = half * 8 + sb;
4016 const int scale_idx_h = half * 8 + sb + 4;
4017
4018 const int32x4_t scale_vec_l = {
4019 q6_scales[scale_idx_l * 8 + cp * 2 + 0],
4020 q6_scales[scale_idx_l * 8 + cp * 2 + 0],
4021 q6_scales[scale_idx_l * 8 + cp * 2 + 1],
4022 q6_scales[scale_idx_l * 8 + cp * 2 + 1],
4023 };
4024 const int32x4_t scale_vec_h = {
4025 q6_scales[scale_idx_h * 8 + cp * 2 + 0],
4026 q6_scales[scale_idx_h * 8 + cp * 2 + 0],
4027 q6_scales[scale_idx_h * 8 + cp * 2 + 1],
4028 q6_scales[scale_idx_h * 8 + cp * 2 + 1],
4029 };
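                            // The {s0, s0, s1, s1} lane layout matches the SMMLA tile:
                            // lanes 0-1 hold column 2*cp (both rows), lanes 2-3 column 2*cp+1.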
4030
4031 acc[cp] = vmlaq_s32(acc[cp], sb_acc_0l, scale_vec_l);
4032 acc[cp] = vmlaq_s32(acc[cp], sb_acc_0h, scale_vec_h);
4033 acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc_1l, scale_vec_l);
4034 acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc_1h, scale_vec_h);
4035 }
4036 }
4037 } // for half
4038
4039 // Reorder i8mm output to match memory layout
4040 for (int i = 0; i < 8; i++) {
4041 int32x2x2_t aux = vzip_s32(vget_low_s32(acc[i]), vget_high_s32(acc[i]));
4042 acc[i] = vcombine_s32(aux.val[0], aux.val[1]);
4043 }
4044 int32x4_t reorder_acc[8] = {
4045 vcombine_s32(vget_low_s32(acc[0]), vget_low_s32(acc[1])),
4046 vcombine_s32(vget_low_s32(acc[2]), vget_low_s32(acc[3])),
4047 vcombine_s32(vget_high_s32(acc[0]), vget_high_s32(acc[1])),
4048 vcombine_s32(vget_high_s32(acc[2]), vget_high_s32(acc[3])),
4049 vcombine_s32(vget_low_s32(acc[4]), vget_low_s32(acc[5])),
4050 vcombine_s32(vget_low_s32(acc[6]), vget_low_s32(acc[7])),
4051 vcombine_s32(vget_high_s32(acc[4]), vget_high_s32(acc[5])),
4052 vcombine_s32(vget_high_s32(acc[6]), vget_high_s32(acc[7])),
4053 };
4054
4055 // Apply superblock scale (no mins for q6_K)
4056 for (int i = 0; i < q8_k_blocklen; i++) {
4057 for (int j = 0; j < 2; j++) {
4058 float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d[i]);
4059 float32x4_t q6_d = vcvt_f32_f16(vld1_f16((const __fp16 *) (q6_ptr[b].d + j * 4)));
4060 const float32x4_t scale = vmulq_f32(q6_d, q8_d);
4061
4062 acc_f32[2 * i + j] =
4063 vmlaq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(reorder_acc[2 * i + j]), scale);
4064 }
4065 }
4066 } // for b
4067
4068 // Store results
4069 for (int i = 0; i < q8_k_blocklen; i++) {
4070 int row = y * q8_k_blocklen + i;
4071 for (int j = 0; j < 2; j++) {
4072 int col = x * ncols_interleaved + j * 4;
4073 int offset = row * bs + col;
4074 vst1q_f32(s + offset, acc_f32[2 * i + j]);
4075 }
4076 }
4077 } // for x
4078 } // for y
4079 return;
4080#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
4081 ggml_gemm_q6_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
4082}
4083
4084void ggml_gemm_q8_0_4x4_q8_0(int n,
4085 float * GGML_RESTRICT s,
4086 size_t bs,
4087 const void * GGML_RESTRICT vx,
4088 const void * GGML_RESTRICT vy,
4089 int nr,
4090 int nc) {
4091 const int qk = QK8_0;
4092 const int nb = n / qk;
4093 const int ncols_interleaved = 4;
4094 const int blocklen = 4;
4095
4096 assert(n % qk == 0);
4097 assert(nr % 4 == 0);
4098 assert(nc % ncols_interleaved == 0);
4099
4100 UNUSED(nb);
4101 UNUSED(ncols_interleaved);
4102 UNUSED(blocklen);
4103
4104#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
4105 for (int y = 0; y < nr / 4; y++) {
4106 const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
4107 for (int x = 0; x < nc / ncols_interleaved; x++) {
4108 const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx + (x * nb);
4109
4110 float32x4_t sumf[4];
4111 for (int m = 0; m < 4; m++) {
4112 sumf[m] = vdupq_n_f32(0);
4113 }
4114
4115 for (int l = 0; l < nb; l++) {
4116 float32x4_t a_d = vcvt_f32_f16(vld1_f16((const float16_t *) a_ptr[l].d));
4117 float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *) b_ptr[l].d));
4118
4119 int32x4_t sumi_0 = vdupq_n_s32(0);
4120 int32x4_t sumi_1 = vdupq_n_s32(0);
4121 int32x4_t sumi_2 = vdupq_n_s32(0);
4122 int32x4_t sumi_3 = vdupq_n_s32(0);
4123
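                // vdotq_laneq_s32(r, b, a, m) adds dot(b[4j..4j+3], a[4m..4m+3]) to each
                // lane r[j]: every 4-byte group of b holds one column's next 4 qs, and
                // lane m of a holds row m's matching 4 qs, so sumi_m accumulates row m
                // against all four columns at once.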
4124 for (int k_group = 0; k_group < 8; k_group += 4) {
4125 int8x16x4_t a = vld1q_s8_x4(a_ptr[l].qs + 16 * k_group);
4126 int8x16x4_t b = vld1q_s8_x4(b_ptr[l].qs + 16 * k_group);
4127
4128 for (int k = 0; k < 4; k++) {
4129 sumi_0 = vdotq_laneq_s32(sumi_0, b.val[k], a.val[k], 0);
4130 sumi_1 = vdotq_laneq_s32(sumi_1, b.val[k], a.val[k], 1);
4131 sumi_2 = vdotq_laneq_s32(sumi_2, b.val[k], a.val[k], 2);
4132 sumi_3 = vdotq_laneq_s32(sumi_3, b.val[k], a.val[k], 3);
4133 }
4134 }
4135
4136 sumf[0] = vmlaq_f32(sumf[0], vmulq_laneq_f32(b_d, a_d, 0), vcvtq_f32_s32(sumi_0));
4137 sumf[1] = vmlaq_f32(sumf[1], vmulq_laneq_f32(b_d, a_d, 1), vcvtq_f32_s32(sumi_1));
4138 sumf[2] = vmlaq_f32(sumf[2], vmulq_laneq_f32(b_d, a_d, 2), vcvtq_f32_s32(sumi_2));
4139 sumf[3] = vmlaq_f32(sumf[3], vmulq_laneq_f32(b_d, a_d, 3), vcvtq_f32_s32(sumi_3));
4140 }
4141
4142 for (int m = 0; m < 4; m++) {
4143 vst1q_f32(s + (y * 4 + m) * bs + x * 4, sumf[m]);
4144 }
4145 }
4146 }
4147 return;
4148#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
4149 ggml_gemm_q8_0_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
4150}
4151
4152void ggml_gemm_q8_0_4x8_q8_0(int n,
4153 float * GGML_RESTRICT s,
4154 size_t bs,
4155 const void * GGML_RESTRICT vx,
4156 const void * GGML_RESTRICT vy,
4157 int nr,
4158 int nc) {
4159 const int qk = QK8_0;
4160 const int nb = n / qk;
4161 const int ncols_interleaved = 4;
4162 const int blocklen = 8;
4163
4164 assert(n % qk == 0);
4165 assert(nr % 4 == 0);
4166 assert(nc % ncols_interleaved == 0);
4167
4168 UNUSED(nb);
4169 UNUSED(ncols_interleaved);
4170 UNUSED(blocklen);
4171
4172#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
4173 const block_q8_0x4 * b_ptr_base = (const block_q8_0x4 *) vx;
4174
4175 for (int y = 0; y < nr; y += 4) {
4176 const block_q8_0x4 * a_ptr_base = (const block_q8_0x4 *) vy + (y / 4) * nb;
4177
4178 for (int x = 0; x < nc; x += ncols_interleaved) {
4179 const block_q8_0x4 * b_ptr = b_ptr_base + (x / 4) * nb;
4180 const block_q8_0x4 * a_ptr = a_ptr_base;
4181
4182 float32x4_t acc_f32[4];
4183 for (int i = 0; i < 4; i++) {
4184 acc_f32[i] = vdupq_n_f32(0);
4185 }
4186
4187 for (int b = 0; b < nb; b++) {
4188 int32x4_t acc[4];
4189 for (int i = 0; i < 4; i++) {
4190 acc[i] = vdupq_n_s32(0);
4191 }
4192
4193 // Process 4 chunks of 8 positions each
4194 for (int chunk = 0; chunk < 4; chunk++) {
4195 int8x16_t a01 = vld1q_s8(a_ptr->qs + chunk * 32);
4196 int8x16_t a23 = vld1q_s8(a_ptr->qs + chunk * 32 + 16);
4197 int8x16_t b01 = vld1q_s8(b_ptr->qs + chunk * 32);
4198 int8x16_t b23 = vld1q_s8(b_ptr->qs + chunk * 32 + 16);
4199
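                    // a01/a23 hold two activation rows each as a 2x8 int8 matrix, b01/b23
                    // two interleaved columns; each vmmlaq_s32 (SMMLA) accumulates the 2x2
                    // tile A * B^T, producing the four row/column combinations noted below.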
4200 acc[0] = vmmlaq_s32(acc[0], a01, b01);
4201 acc[1] = vmmlaq_s32(acc[1], a01, b23);
4202 acc[2] = vmmlaq_s32(acc[2], a23, b01);
4203 acc[3] = vmmlaq_s32(acc[3], a23, b23);
4204 }
4205
4206 // Reorder outputs from 2×2 tiles to row-major
4207 // acc[0] = [r0c0, r0c1, r1c0, r1c1]
4208 // acc[1] = [r0c2, r0c3, r1c2, r1c3]
4209 // acc[2] = [r2c0, r2c1, r3c0, r3c1]
4210 // acc[3] = [r2c2, r2c3, r3c2, r3c3]
4211 int32x4_t row0 = vcombine_s32(vget_low_s32(acc[0]), vget_low_s32(acc[1]));
4212 int32x4_t row1 = vcombine_s32(vget_high_s32(acc[0]), vget_high_s32(acc[1]));
4213 int32x4_t row2 = vcombine_s32(vget_low_s32(acc[2]), vget_low_s32(acc[3]));
4214 int32x4_t row3 = vcombine_s32(vget_high_s32(acc[2]), vget_high_s32(acc[3]));
4215
4216 // Scales
4217 float32x4_t a_d = vcvt_f32_f16(vld1_f16((const __fp16 *) a_ptr->d));
4218 float32x4_t b_d = vcvt_f32_f16(vld1_f16((const __fp16 *) b_ptr->d));
4219
4220 acc_f32[0] = vfmaq_f32(acc_f32[0], vcvtq_f32_s32(row0), vmulq_laneq_f32(b_d, a_d, 0));
4221 acc_f32[1] = vfmaq_f32(acc_f32[1], vcvtq_f32_s32(row1), vmulq_laneq_f32(b_d, a_d, 1));
4222 acc_f32[2] = vfmaq_f32(acc_f32[2], vcvtq_f32_s32(row2), vmulq_laneq_f32(b_d, a_d, 2));
4223 acc_f32[3] = vfmaq_f32(acc_f32[3], vcvtq_f32_s32(row3), vmulq_laneq_f32(b_d, a_d, 3));
4224
4225 a_ptr++;
4226 b_ptr++;
4227 }
4228
4229 for (int row = 0; row < 4; row++) {
4230 vst1q_f32(s + (y + row) * bs + x, acc_f32[row]);
4231 }
4232 }
4233 }
4234 return;
4235#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
4236 ggml_gemm_q8_0_4x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
4237}