#define GGML_COMMON_IMPL_C
#include "ggml-common.h"
#include "ggml-quants.h"
#include "ggml-impl.h"
#include "ggml-cpu.h"
#include "simd-mappings.h"

#include "../../quants.h"
#include "../../ggml-cpu-impl.h"

#include <math.h>
#include <string.h>
#include <assert.h>
#include <float.h>
#include <stdlib.h> // for qsort
#include <stdio.h>  // for GGML_ASSERT

#define GROUP_MAX_EPS 1e-15f
#define GROUP_MAX_EPS_IQ3_XXS 1e-8f
#define GROUP_MAX_EPS_IQ2_S 1e-8f
#define GROUP_MAX_EPS_IQ1_M 1e-7f
#define GROUP_MAX_EPS_IQ1_S 1e-12f

#define UNUSED GGML_UNUSED

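// Quantize a row of floats into Q8_0 blocks: 32 int8 values sharing one fp16 scale.
// The RVV path reduces the per-block absolute maximum, derives d = amax/127, then
// scales, rounds and narrows all 32 elements to int8 in one vector pass.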
void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
    assert(QK8_0 == 32);
    assert(k % QK8_0 == 0);
    const int nb = k / QK8_0;

    block_q8_0 * GGML_RESTRICT y = vy;

#if defined(__riscv_v)

    size_t vl = QK8_0;

    for (int i = 0; i < nb; i++) {
        // load elements
        vfloat32m8_t v_x = __riscv_vle32_v_f32m8(x+i*QK8_0, vl);

        vfloat32m8_t vfabs = __riscv_vfabs_v_f32m8(v_x, vl);
        vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl);
        vfloat32m1_t vmax = __riscv_vfredmax_vs_f32m8_f32m1(vfabs, tmp, vl);
        float amax = __riscv_vfmv_f_s_f32m1_f32(vmax);

        const float d = amax / ((1 << 7) - 1);
        const float id = d ? 1.0f/d : 0.0f;

        y[i].d = GGML_CPU_FP32_TO_FP16(d);

        vfloat32m8_t x0 = __riscv_vfmul_vf_f32m8(v_x, id, vl);

        // convert to integer
        vint16m4_t vi = __riscv_vfncvt_x_f_w_i16m4(x0, vl);
        vint8m2_t vs = __riscv_vncvt_x_x_w_i8m2(vi, vl);

        // store result
        __riscv_vse8_v_i8m2(y[i].qs, vs, vl);
    }
#else
    GGML_UNUSED(nb);
    // scalar
    quantize_row_q8_0_ref(x, y, k);
#endif
}

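// Quantize a row of floats into Q8_1 blocks. Same scheme as Q8_0, but each block also
// stores s = d * sum(q) as fp16, which lets the q4_1/q5_1 dot products fold the block
// offset into a single multiply. The RVV path obtains the sum with a widening
// reduction over the freshly quantized int8 values.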
void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
    assert(k % QK8_1 == 0);
    const int nb = k / QK8_1;

    block_q8_1 * GGML_RESTRICT y = vy;

#if defined(__riscv_v)

    size_t vl = QK8_1;

    for (int i = 0; i < nb; i++) {
        // load elements
        vfloat32m8_t v_x = __riscv_vle32_v_f32m8(x+i*QK8_1, vl);

        vfloat32m8_t vfabs = __riscv_vfabs_v_f32m8(v_x, vl);
        vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl);
        vfloat32m1_t vmax = __riscv_vfredmax_vs_f32m8_f32m1(vfabs, tmp, vl);
        float amax = __riscv_vfmv_f_s_f32m1_f32(vmax);

        const float d = amax / ((1 << 7) - 1);
        const float id = d ? 1.0f/d : 0.0f;

        y[i].d = GGML_CPU_FP32_TO_FP16(d);

        vfloat32m8_t x0 = __riscv_vfmul_vf_f32m8(v_x, id, vl);

        // convert to integer
        vint16m4_t vi = __riscv_vfncvt_x_f_w_i16m4(x0, vl);
        vint8m2_t vs = __riscv_vncvt_x_x_w_i8m2(vi, vl);

        // store result
        __riscv_vse8_v_i8m2(y[i].qs, vs, vl);

        // compute sum for y[i].s
        vint16m1_t tmp2 = __riscv_vmv_v_x_i16m1(0, vl);
        vint16m1_t vwrs = __riscv_vwredsum_vs_i8m2_i16m1(vs, tmp2, vl);

        // set y[i].s
        int sum = __riscv_vmv_x_s_i16m1_i16(vwrs);
        y[i].s = GGML_CPU_FP32_TO_FP16(sum*d);
    }

#else
    GGML_UNUSED(nb);
    // scalar
    quantize_row_q8_1_ref(x, y, k);
#endif
}

//===================================== Dot products =================================

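// Dot product of a Q4_0 row with a Q8_0 row. Each Q4_0 block packs 32 4-bit values
// (two per byte, biased by 8) under one fp16 scale; the RVV path unpacks both nibble
// halves, removes the bias, and accumulates the int8 products with a widening
// multiply-accumulate and a single reduction per block.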
void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
#if defined(__riscv_v)
    const int qk = QK8_0;
    const int nb = n / qk;

    assert(n % qk == 0);
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_q4_0 * GGML_RESTRICT x = vx;
    const block_q8_0 * GGML_RESTRICT y = vy;

    int ib = 0;
    float sumf = 0;

    size_t vl = qk / 2;

    for (; ib < nb; ++ib) {
        // load elements
        vuint8m1_t tx = __riscv_vle8_v_u8m1(x[ib].qs, vl);

        vint8m1_t y0 = __riscv_vle8_v_i8m1(y[ib].qs, vl);
        vint8m1_t y1 = __riscv_vle8_v_i8m1(y[ib].qs+16, vl);

        // extract the lower nibbles, then the upper nibbles
        vuint8m1_t x_a = __riscv_vand_vx_u8m1(tx, 0x0F, vl);
        vuint8m1_t x_l = __riscv_vsrl_vx_u8m1(tx, 0x04, vl);

        vint8m1_t x_ai = __riscv_vreinterpret_v_u8m1_i8m1(x_a);
        vint8m1_t x_li = __riscv_vreinterpret_v_u8m1_i8m1(x_l);

        // subtract offset
        vint8m1_t v0 = __riscv_vsub_vx_i8m1(x_ai, 8, vl);
        vint8m1_t v1 = __riscv_vsub_vx_i8m1(x_li, 8, vl);

        vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl);
        vint16m2_t vec_mul2 = __riscv_vwmacc_vv_i16m2(vec_mul1, v1, y1, vl);

        vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
        vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl);

        int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);

        sumf += sumi*GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d);
    }

    *s = sumf;
#else
    ggml_vec_dot_q4_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}

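// Dot product of a Q4_1 row with a Q8_1 row. Q4_1 keeps its nibbles unsigned and adds
// a per-block minimum m, so each block contributes (d*dy)*sum(q4*q8) plus the
// precomputed offset term m*s taken from the Q8_1 block.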
void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
#if defined(__riscv_v)
    const int qk = QK8_1;
    const int nb = n / qk;

    assert(n % qk == 0);
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_q4_1 * GGML_RESTRICT x = vx;
    const block_q8_1 * GGML_RESTRICT y = vy;

    int ib = 0;
    float sumf = 0;

    size_t vl = qk / 2;

    for (; ib < nb; ++ib) {
        // load elements
        vuint8m1_t tx = __riscv_vle8_v_u8m1(x[ib].qs, vl);

        vint8m1_t y0 = __riscv_vle8_v_i8m1(y[ib].qs, vl);
        vint8m1_t y1 = __riscv_vle8_v_i8m1(y[ib].qs+16, vl);

        // extract the lower nibbles, then the upper nibbles
        vuint8m1_t x_a = __riscv_vand_vx_u8m1(tx, 0x0F, vl);
        vuint8m1_t x_l = __riscv_vsrl_vx_u8m1(tx, 0x04, vl);

        vint8m1_t v0 = __riscv_vreinterpret_v_u8m1_i8m1(x_a);
        vint8m1_t v1 = __riscv_vreinterpret_v_u8m1_i8m1(x_l);

        vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl);
        vint16m2_t vec_mul2 = __riscv_vwmacc_vv_i16m2(vec_mul1, v1, y1, vl);

        vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
        vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl);

        int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);

        sumf += (GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d))*sumi + GGML_CPU_FP16_TO_FP32(x[ib].m)*GGML_CPU_FP16_TO_FP32(y[ib].s);
    }

    *s = sumf;
#else
    ggml_vec_dot_q4_1_q8_1_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}

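// Dot product of a Q5_0 row with a Q8_0 row. The fifth bit of each value lives in the
// qh bitfield, loaded here as a mask register: the mask is inverted and 0x10 is
// subtracted from the lanes whose high bit is clear, which is equivalent to OR-ing in
// the fifth bit and re-centering the values to [-16, 15].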
void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
#if defined(__riscv_v)
    const int qk = QK8_0;
    const int nb = n / qk;

    int ib = 0;
    float sumf = 0;

    assert(n % qk == 0);
    assert(qk == QK5_0);
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_q5_0 * GGML_RESTRICT x = vx;
    const block_q8_0 * GGML_RESTRICT y = vy;

    size_t vl;
    size_t vlenb = __riscv_vlenb();

    for (; ib < nb; ++ib) {
        vl = qk / 2;
        vuint8m1_t v0 = __riscv_vle8_v_u8m1(x[ib].qs, vl);
        vint8m1_t v0l = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(v0, 0x0F, vl));
        vint8m1_t v0h = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(v0, 4, vl));
        vint8m2_t v0c;
        if (vlenb == 16) {
            v0c = __riscv_vcreate_v_i8m1_i8m2(v0l, v0h);
        } else {
            v0l = __riscv_vslideup_vx_i8m1(v0l, v0h, 16, 32);
            v0c = __riscv_vlmul_ext_v_i8m1_i8m2(v0l);
        }

        vl = qk;
        vbool4_t qh = __riscv_vlm_v_b4(x[ib].qh, vl);
        qh = __riscv_vmnand_mm_b4(qh, qh, vl);
        vint8m2_t v0f = __riscv_vsub_vx_i8m2_mu(qh, v0c, v0c, 0x10, vl);
        vint8m2_t v1 = __riscv_vle8_v_i8m2(y[ib].qs, vl);
        vint16m4_t mul = __riscv_vwmul_vv_i16m4(v0f, v1, vl);
        vint32m1_t zero = __riscv_vmv_v_x_i32m1(0, vl);
        vint32m1_t sum = __riscv_vwredsum_vs_i16m4_i32m1(mul, zero, vl);
        int32_t sumi = __riscv_vmv_x_s_i32m1_i32(sum);

        sumf += (GGML_CPU_FP16_TO_FP32(x[ib].d) * GGML_CPU_FP16_TO_FP32(y[ib].d)) * sumi;
    }

    *s = sumf;
#else
    ggml_vec_dot_q5_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}

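// Dot product of a Q5_1 row with a Q8_1 row. As in Q5_0 the high bits come from qh,
// but the values stay unsigned: 0x10 is OR-ed into the lanes whose mask bit is set,
// and the Q8_1 offset term m*s is added per block.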
void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
#if defined(__riscv_v)
    const int qk = QK8_1;
    const int nb = n / qk;

    int ib = 0;
    float sumf = 0;

    assert(n % qk == 0);
    assert(qk == QK5_1);
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_q5_1 * GGML_RESTRICT x = vx;
    const block_q8_1 * GGML_RESTRICT y = vy;

    size_t vl;
    size_t vlenb = __riscv_vlenb();

    for (; ib < nb; ++ib) {
        vl = qk / 2;
        vuint8m1_t v0 = __riscv_vle8_v_u8m1(x[ib].qs, vl);
        vint8m1_t v0l = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(v0, 0x0F, vl));
        vint8m1_t v0h = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(v0, 4, vl));
        vint8m2_t v0c;
        if (vlenb == 16) {
            v0c = __riscv_vcreate_v_i8m1_i8m2(v0l, v0h);
        } else {
            v0l = __riscv_vslideup_vx_i8m1(v0l, v0h, 16, 32);
            v0c = __riscv_vlmul_ext_v_i8m1_i8m2(v0l);
        }

        vl = qk;
        vbool4_t qh = __riscv_vlm_v_b4(x[ib].qh, vl);
        vint8m2_t v0f = __riscv_vor_vx_i8m2_mu(qh, v0c, v0c, 0x10, vl);
        vint8m2_t v1 = __riscv_vle8_v_i8m2(y[ib].qs, vl);
        vint16m4_t mul = __riscv_vwmul_vv_i16m4(v0f, v1, vl);
        vint32m1_t zero = __riscv_vmv_v_x_i32m1(0, vl);
        vint32m1_t sum = __riscv_vwredsum_vs_i16m4_i32m1(mul, zero, vl);
        int32_t sumi = __riscv_vmv_x_s_i32m1_i32(sum);

        sumf += (GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d))*sumi + GGML_CPU_FP16_TO_FP32(x[ib].m)*GGML_CPU_FP16_TO_FP32(y[ib].s);
    }

    *s = sumf;
#else
    ggml_vec_dot_q5_1_q8_1_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}

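// Dot product of two Q8_0 rows: a plain int8 dot product per 32-element block, scaled
// by the product of the two fp16 block scales.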
void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    const int qk = QK8_0;
    const int nb = n / qk;

    assert(n % qk == 0);
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_q8_0 * GGML_RESTRICT x = vx;
    const block_q8_0 * GGML_RESTRICT y = vy;

    int ib = 0;
    float sumf = 0;

#if defined(__riscv_v)
    size_t vl = qk;

    for (; ib < nb; ++ib) {
        // load elements
        vint8m2_t bx_0 = __riscv_vle8_v_i8m2(x[ib].qs, vl);
        vint8m2_t by_0 = __riscv_vle8_v_i8m2(y[ib].qs, vl);

        vint16m4_t vw_mul = __riscv_vwmul_vv_i16m4(bx_0, by_0, vl);

        vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, vl);
        vint32m1_t v_sum = __riscv_vwredsum_vs_i16m4_i32m1(vw_mul, v_zero, vl);

        int sumi = __riscv_vmv_x_s_i32m1_i32(v_sum);

        sumf += sumi*(GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d));
    }

    *s = sumf;
#else
    UNUSED(nb);
    UNUSED(x);
    UNUSED(y);
    UNUSED(ib);
    UNUSED(sumf);

    ggml_vec_dot_q8_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}

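// Dot product of a Q2_K row with a Q8_K row. Each 256-element super-block stores
// 2-bit values plus a 4-bit scale and 4-bit min per 16-element group. The kernel
// first folds the mins against the Q8_K block sums (the dmin term), then accumulates
// the 2-bit sub-block products. Three paths: XTheadVector asm, and RVV asm/intrinsics
// tuned for 128-bit and 256-bit vector lengths.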
void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_q2_K * GGML_RESTRICT x = vx;
    const block_q8_K * GGML_RESTRICT y = vy;

    const int nb = n / QK_K;

#if defined __riscv_xtheadvector

    float sumf = 0;
    uint8_t atmp[16];

    for (int i = 0; i < nb; ++i) {
        const uint8_t * q2 = x[i].qs;
        const int8_t * q8 = y[i].qs;
        const uint8_t * sc = x[i].scales;
        const float dall = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
        const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);
        uint8_t * patmp = atmp;
        int vsums;
        int tmp;
        __asm__ __volatile__(
            "th.vsetvli zero, %[vl16], e8, m1\n\t"
            "th.vmv.v.x v8, zero\n\t"
            "th.vlb.v v1, (%[sc])\n\t"
            "th.vand.vi v0, v1, 0xF\n\t"
            "th.vsrl.vi v1, v1, 4\n\t"
            "th.vsb.v v0, (%[scale])\n\t"
            "th.vwaddu.vx v16, v1, zero\n\t"
            "th.vsetvli zero, %[vl16], e16, m2\n\t"
            "th.vlh.v v2, (%[bsums])\n\t"
            "th.vwmul.vv v4, v16, v2\n\t"
            "th.vsetvli zero, %[vl16], e32, m4\n\t"
            "th.vredsum.vs v8, v4, v8\n\t"
            "th.vmv.x.s %[vsums], v8"
            : [tmp] "=&r" (tmp), [vsums] "=&r" (vsums)
            : [sc] "r" (sc), [scale] "r" (atmp), [bsums] "r" (y[i].bsums)
            , [vl16] "r" (16)
            : "memory"
            , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
            , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
            , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
            , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
        );
        sumf += dmin * vsums;
        int isum = 0;

        for (int j = 0; j < QK_K/128; ++j) {
            __asm__ __volatile__(
                "th.vsetvli zero, %[vl32], e8, m2\n\t"
                "th.vlb.v v0, (%[q2])\n\t"
                "th.vsrl.vi v2, v0, 2\n\t"
                "th.vsrl.vi v4, v0, 4\n\t"
                "th.vsrl.vi v6, v0, 6\n\t"
                "th.vand.vi v0, v0, 0x3\n\t"
                "th.vand.vi v2, v2, 0x3\n\t"
                "th.vand.vi v4, v4, 0x3\n\t"
                "th.vsetvli zero, %[vl128], e8, m8\n\t"
                "th.vlb.v v8, (%[q8])\n\t"
                "th.vsetvli zero, %[vl64], e8, m4\n\t"
                "th.vwmul.vv v16, v0, v8\n\t"
                "th.vwmul.vv v24, v4, v12\n\t"
                "th.vsetvli zero, %[vl16], e16, m2\n\t"
                "th.vmv.v.x v0, zero\n\t"
                "th.vwredsum.vs v10, v16, v0\n\t"
                "th.vwredsum.vs v9, v18, v0\n\t"
                "th.vwredsum.vs v8, v20, v0\n\t"
                "th.vwredsum.vs v7, v22, v0\n\t"
                "th.vwredsum.vs v11, v24, v0\n\t"
                "th.vwredsum.vs v12, v26, v0\n\t"
                "th.vwredsum.vs v13, v28, v0\n\t"
                "th.vwredsum.vs v14, v30, v0\n\t"
                "li %[tmp], 4\n\t"
                "th.vsetvli zero, %[tmp], e32, m1\n\t"
                "th.vslideup.vi v10, v9, 1\n\t"
                "th.vslideup.vi v8, v7, 1\n\t"
                "th.vslideup.vi v11, v12, 1\n\t"
                "th.vslideup.vi v13, v14, 1\n\t"
                "th.vslideup.vi v10, v8, 2\n\t"
                "th.vslideup.vi v11, v13, 2\n\t"
                "li %[tmp], 8\n\t"
                "th.vsetvli zero, %[tmp], e32, m2\n\t"
                "th.vlbu.v v12, (%[scale])\n\t"
                "th.vmul.vv v10, v10, v12\n\t"
                "th.vredsum.vs v0, v10, v0\n\t"
                "th.vmv.x.s %[tmp], v0\n\t"
                "add %[isum], %[isum], %[tmp]"
                : [tmp] "=&r" (tmp), [isum] "+&r" (isum)
                : [q2] "r" (q2), [scale] "r" (patmp), [q8] "r" (q8)
                , [vl16] "r" (16), [vl32] "r" (32), [vl64] "r" (64), [vl128] "r" (128)
                : "memory"
                , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
                , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
                , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
                , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
            );
            q2 += 32; q8 += 128; patmp += 8;
        }

        sumf += dall * isum;
    }

    *s = sumf;

#elif defined __riscv_v

    float sumf = 0;
    uint8_t atmp[16];

    const int vector_length = __riscv_vlenb() * 8;
    uint8_t temp_01[32] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                            1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 };

    switch (vector_length) {
    case 256:
        for (int i = 0; i < nb; ++i) {
            const uint8_t * q2 = x[i].qs;
            const int8_t * q8 = y[i].qs;
            const uint8_t * sc = x[i].scales;

            const float dall = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
            const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);

            size_t vl = 16;

            vuint8m1_t scales = __riscv_vle8_v_u8m1(sc, vl);
            vuint8m1_t aux = __riscv_vand_vx_u8m1(scales, 0x0F, vl);

            vint16m1_t q8sums = __riscv_vle16_v_i16m1(y[i].bsums, vl);

            vuint8mf2_t scales_2 = __riscv_vle8_v_u8mf2(sc, vl);
            vuint8mf2_t mins8 = __riscv_vsrl_vx_u8mf2(scales_2, 0x4, vl);
            vint16m1_t mins = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(mins8, vl));
            vint32m2_t prod = __riscv_vwmul_vv_i32m2(q8sums, mins, vl);
            vint32m1_t vsums = __riscv_vredsum_vs_i32m2_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl);

            sumf += dmin * __riscv_vmv_x_s_i32m1_i32(vsums);

            vl = 32;

            vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
            vuint8m1_t v_b = __riscv_vle8_v_u8m1(temp_01, vl);

            uint8_t is = 0;
            int isum = 0;

            for (int j = 0; j < QK_K / 128; ++j) {
                // load Q2
                vuint8m1_t q2_x = __riscv_vle8_v_u8m1(q2, vl);

                vuint8m1_t q2_0 = __riscv_vand_vx_u8m1(q2_x, 0x03, vl);
                vuint8m1_t q2_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x2, vl), 0x03, vl);
                vuint8m1_t q2_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x4, vl), 0x03, vl);
                vuint8m1_t q2_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x6, vl), 0x03, vl);

                // duplicate scale elements for product
                vuint8m1_t sc0 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 0 + is, vl), vl);
                vuint8m1_t sc1 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 2 + is, vl), vl);
                vuint8m1_t sc2 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 4 + is, vl), vl);
                vuint8m1_t sc3 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 6 + is, vl), vl);

                vint16m2_t p0 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_0, sc0, vl));
                vint16m2_t p1 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_1, sc1, vl));
                vint16m2_t p2 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_2, sc2, vl));
                vint16m2_t p3 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_3, sc3, vl));

                // load Q8
                vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl);
                vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8 + 32, vl);
                vint8m1_t q8_2 = __riscv_vle8_v_i8m1(q8 + 64, vl);
                vint8m1_t q8_3 = __riscv_vle8_v_i8m1(q8 + 96, vl);

                vint32m4_t s0 = __riscv_vwmul_vv_i32m4(p0, __riscv_vwcvt_x_x_v_i16m2(q8_0, vl), vl);
                vint32m4_t s1 = __riscv_vwmul_vv_i32m4(p1, __riscv_vwcvt_x_x_v_i16m2(q8_1, vl), vl);
                vint32m4_t s2 = __riscv_vwmul_vv_i32m4(p2, __riscv_vwcvt_x_x_v_i16m2(q8_2, vl), vl);
                vint32m4_t s3 = __riscv_vwmul_vv_i32m4(p3, __riscv_vwcvt_x_x_v_i16m2(q8_3, vl), vl);

                vint32m1_t isum0 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s0, s1, vl), vzero, vl);
                vint32m1_t isum1 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s2, s3, vl), isum0, vl);

                isum += __riscv_vmv_x_s_i32m1_i32(isum1);

                q2 += 32;
                q8 += 128;
                is = 8;
            }

            sumf += dall * isum;
        }
        break;
    case 128:
        for (int i = 0; i < nb; ++i) {
            const uint8_t * q2 = x[i].qs;
            const int8_t * q8 = y[i].qs;
            const uint8_t * sc = x[i].scales;
            const float dall = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
            const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);
            uint8_t * patmp = atmp;
            int vsums;
            int tmp, t1, t2, t3, t4, t5, t6, t7;
            __asm__ __volatile__(
                "vsetivli zero, 16, e8, m1\n\t"
                "vmv.v.x v8, zero\n\t"
                "lb zero, 15(%[sc])\n\t"
                "vle8.v v1, (%[sc])\n\t"
                "vle8.v v2, (%[bsums])\n\t"
                "addi %[tmp], %[bsums], 16\n\t"
                "vand.vi v0, v1, 0xF\n\t"
                "vsrl.vi v1, v1, 4\n\t"
                "vle8.v v3, (%[tmp])\n\t"
                "vse8.v v0, (%[scale])\n\t"
                "vsetivli zero, 16, e16, m2\n\t"
                "vzext.vf2 v0, v1\n\t"
                "vwmul.vv v4, v0, v2\n\t"
                "vsetivli zero, 16, e32, m4\n\t"
                "vredsum.vs v8, v4, v8\n\t"
                "vmv.x.s %[vsums], v8"
                : [tmp] "=&r" (tmp), [vsums] "=&r" (vsums)
                : [sc] "r" (sc), [scale] "r" (atmp), [bsums] "r" (y[i].bsums)
                : "memory"
                , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
                , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
                , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
                , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
            );
            sumf += dmin * vsums;
            int isum = 0;

            for (int j = 0; j < QK_K/128; ++j) {
                __asm__ __volatile__(
                    "lb zero, 31(%[q2])\n\t"
                    "addi %[tmp], %[q2], 16\n\t"
                    "addi %[t1], %[q8], 16\n\t"
                    "vsetivli zero, 16, e8, m1\n\t"
                    "vle8.v v0, (%[q2])\n\t"
                    "vle8.v v1, (%[tmp])\n\t"
                    "vsrl.vi v2, v0, 2\n\t"
                    "vsrl.vi v3, v1, 2\n\t"
                    "vsrl.vi v4, v0, 4\n\t"
                    "addi %[tmp], %[q8], 32\n\t"
                    "vle8.v v8, (%[q8])\n\t"
                    "vle8.v v9, (%[t1])\n\t"
                    "addi %[t1], %[t1], 32\n\t"
                    "vsrl.vi v5, v1, 4\n\t"
                    "vsrl.vi v6, v0, 6\n\t"
                    "vsrl.vi v7, v1, 6\n\t"
                    "vle8.v v10, (%[tmp])\n\t"
                    "vle8.v v11, (%[t1])\n\t"
                    "addi %[tmp], %[tmp], 32\n\t"
                    "addi %[t1], %[t1], 32\n\t"
                    "vand.vi v0, v0, 0x3\n\t"
                    "vand.vi v1, v1, 0x3\n\t"
                    "vand.vi v2, v2, 0x3\n\t"
                    "vle8.v v12, (%[tmp])\n\t"
                    "vle8.v v13, (%[t1])\n\t"
                    "addi %[tmp], %[tmp], 32\n\t"
                    "addi %[t1], %[t1], 32\n\t"
                    "vand.vi v3, v3, 0x3\n\t"
                    "vand.vi v4, v4, 0x3\n\t"
                    "vand.vi v5, v5, 0x3\n\t"
                    "vle8.v v14, (%[tmp])\n\t"
                    "vle8.v v15, (%[t1])\n\t"
                    "vwmul.vv v16, v0, v8\n\t"
                    "vwmul.vv v18, v1, v9\n\t"
                    "vwmul.vv v20, v2, v10\n\t"
                    "vwmul.vv v22, v3, v11\n\t"
                    "vwmul.vv v24, v4, v12\n\t"
                    "vwmul.vv v26, v5, v13\n\t"
                    "vwmul.vv v28, v6, v14\n\t"
                    "vwmul.vv v30, v7, v15\n\t"
                    "vsetivli zero, 8, e16, m1\n\t"
                    "vmv.v.x v0, zero\n\t"
                    "lbu %[tmp], 0(%[scale])\n\t"
                    "vwredsum.vs v8, v16, v0\n\t"
                    "vwredsum.vs v9, v18, v0\n\t"
                    "lbu %[t1], 1(%[scale])\n\t"
                    "vwredsum.vs v10, v20, v0\n\t"
                    "vwredsum.vs v11, v22, v0\n\t"
                    "lbu %[t2], 2(%[scale])\n\t"
                    "vwredsum.vs v12, v24, v0\n\t"
                    "vwredsum.vs v13, v26, v0\n\t"
                    "lbu %[t3], 3(%[scale])\n\t"
                    "vwredsum.vs v14, v28, v0\n\t"
                    "vwredsum.vs v15, v30, v0\n\t"
                    "lbu %[t4], 4(%[scale])\n\t"
                    "vwredsum.vs v8, v17, v8\n\t"
                    "vwredsum.vs v9, v19, v9\n\t"
                    "lbu %[t5], 5(%[scale])\n\t"
                    "vwredsum.vs v10, v21, v10\n\t"
                    "vwredsum.vs v11, v23, v11\n\t"
                    "lbu %[t6], 6(%[scale])\n\t"
                    "vwredsum.vs v12, v25, v12\n\t"
                    "vwredsum.vs v13, v27, v13\n\t"
                    "lbu %[t7], 7(%[scale])\n\t"
                    "vwredsum.vs v14, v29, v14\n\t"
                    "vwredsum.vs v15, v31, v15\n\t"
                    "vsetivli zero, 4, e32, m1\n\t"
                    "vmul.vx v0, v8, %[tmp]\n\t"
                    "vmul.vx v1, v9, %[t1]\n\t"
                    "vmacc.vx v0, %[t2], v10\n\t"
                    "vmacc.vx v1, %[t3], v11\n\t"
                    "vmacc.vx v0, %[t4], v12\n\t"
                    "vmacc.vx v1, %[t5], v13\n\t"
                    "vmacc.vx v0, %[t6], v14\n\t"
                    "vmacc.vx v1, %[t7], v15\n\t"
                    "vmv.x.s %[tmp], v0\n\t"
                    "vmv.x.s %[t1], v1\n\t"
                    "add %[isum], %[isum], %[tmp]\n\t"
                    "add %[isum], %[isum], %[t1]"
                    : [tmp] "=&r" (tmp), [t1] "=&r" (t1), [t2] "=&r" (t2), [t3] "=&r" (t3)
                    , [t4] "=&r" (t4), [t5] "=&r" (t5), [t6] "=&r" (t6), [t7] "=&r" (t7)
                    , [isum] "+&r" (isum)
                    : [q2] "r" (q2), [scale] "r" (patmp), [q8] "r" (q8)
                    : "memory"
                    , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
                    , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
                    , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
                    , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
                );
                q2 += 32; q8 += 128; patmp += 8;
            }

            sumf += dall * isum;
        }
        break;
    default:
        assert(false && "Unsupported vector length");
        break;
    }

    *s = sumf;

#else

    UNUSED(x);
    UNUSED(y);
    UNUSED(nb);

    ggml_vec_dot_q2_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}

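// Dot product of a Q3_K row with a Q8_K row. The 3-bit values are split into 2-bit
// low pairs in qs plus a high-bit mask in hmask; the 6-bit scales are unpacked from a
// packed 12-byte layout (kmask1/kmask2) and re-centered by -32. Same three code paths
// as q2_K.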
void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(n % QK_K == 0);
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const uint32_t kmask1 = 0x03030303;
    const uint32_t kmask2 = 0x0f0f0f0f;

    const block_q3_K * GGML_RESTRICT x = vx;
    const block_q8_K * GGML_RESTRICT y = vy;

    const int nb = n / QK_K;

#if defined __riscv_xtheadvector

    uint32_t utmp[4];
    float sumf = 0;

    for (int i = 0; i < nb; ++i) {
        const uint8_t * GGML_RESTRICT q3 = x[i].qs;
        const uint8_t * GGML_RESTRICT qh = x[i].hmask;
        const int8_t * GGML_RESTRICT q8 = y[i].qs;

        int8_t * scale = (int8_t *)utmp;
        int tmp;
        __asm__ __volatile__(
            "li %[tmp], 12\n\t"
            "th.vsetvli zero, %[tmp], e8, m1\n\t"
            "th.vlb.v v0, (%[s6b])\n\t"
            "th.vmv.v.v v2, v0\n\t"
            "li %[tmp], 2\n\t"
            "th.vsetvli zero, %[tmp], e64, m1\n\t"
            "th.vmv.v.x v9, %[sh]\n\t"
            "th.vslidedown.vi v1, v0, 1\n\t"
            "th.vslide1up.vx v8, v9, zero\n\t" // {0, 0, 4, 4}
            "th.vslideup.vi v0, v2, 1\n\t" // {aux[0], aux[1], aux[0], aux[1]}
            "li %[tmp], 4\n\t"
            "th.vsetvli zero, %[tmp], e32, m1\n\t"
            "th.vid.v v9\n\t"
            "th.vmv.x.s %[tmp], v1\n\t"
            "th.vsll.vi v9, v9, 1\n\t" // {0, 2, 4, 6}
            "th.vmv.v.x v1, %[tmp]\n\t" // {aux[2], aux[2], aux[2], aux[2]}
            "th.vsrl.vv v4, v1, v9\n\t"
            "th.vsrl.vv v2, v0, v8\n\t"
            "th.vand.vx v5, v4, %[kmask1]\n\t"
            "th.vand.vx v3, v2, %[kmask2]\n\t"
            "th.vsll.vi v6, v5, 4\n\t"
            "th.vor.vv v7, v6, v3\n\t"
            "li %[tmp], 16\n\t"
            "th.vsetvli zero, %[tmp], e8, m1\n\t"
            "th.vsub.vx v0, v7, %[c]\n\t"
            "th.vsb.v v0, (%[scale])"
            : [tmp] "=&r" (tmp)
            : [sh] "r" (0x0000000400000004), [s6b] "r" (x[i].scales), [c] "r" (32)
            , [scale] "r" (scale), [kmask1] "r" (kmask1), [kmask2] "r" (kmask2)
            : "memory"
            , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
            , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
            , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
            , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
        );

        uint8_t m = 1;
        int isum = 0;
        for (int j = 0; j < QK_K; j += 128) {
            __asm__ __volatile__(
                // fixme: use v0p7 mask layout directly
                "th.vsetvli zero, %[vl32], e8, m2\n\t"
                "th.vlb.v v8, (%[q3])\n\t"
                "th.vsrl.vi v10, v8, 2\n\t"
                "th.vsrl.vi v12, v8, 4\n\t"
                "th.vsrl.vi v14, v8, 6\n\t"
                "th.vand.vi v8, v8, 3\n\t"
                "th.vand.vi v10, v10, 3\n\t"
                "th.vand.vi v12, v12, 3\n\t"
                "th.vlb.v v2, (%[qh])\n\t"
                "th.vand.vx v4, v2, %[m]\n\t"
                "slli %[m], %[m], 1\n\t"
                "th.vmseq.vx v0, v4, zero\n\t"
                "th.vadd.vi v8, v8, -4, v0.t\n\t"
                "th.vand.vx v4, v2, %[m]\n\t"
                "slli %[m], %[m], 1\n\t"
                "th.vmseq.vx v0, v4, zero\n\t"
                "th.vadd.vi v10, v10, -4, v0.t\n\t"
                "th.vand.vx v4, v2, %[m]\n\t"
                "slli %[m], %[m], 1\n\t"
                "th.vmseq.vx v0, v4, zero\n\t"
                "th.vadd.vi v12, v12, -4, v0.t\n\t"
                "th.vand.vx v4, v2, %[m]\n\t"
                "slli %[m], %[m], 1\n\t"
                "th.vmseq.vx v0, v4, zero\n\t"
                "th.vadd.vi v14, v14, -4, v0.t\n\t"
                "th.vsetvli zero, %[vl128], e8, m8\n\t"
                "th.vlb.v v0, (%[q8])\n\t"
                "th.vsetvli zero, %[vl64], e8, m4\n\t"
                "th.vwmul.vv v16, v0, v8\n\t"
                "th.vwmul.vv v24, v4, v12\n\t"
                "li %[tmp], 16\n\t"
                "th.vsetvli zero, %[tmp], e16, m2\n\t"
                "th.vmv.v.x v0, zero\n\t"
                "th.vwredsum.vs v10, v16, v0\n\t"
                "th.vwredsum.vs v9, v18, v0\n\t"
                "th.vwredsum.vs v8, v20, v0\n\t"
                "th.vwredsum.vs v7, v22, v0\n\t"
                "th.vwredsum.vs v11, v24, v0\n\t"
                "th.vwredsum.vs v12, v26, v0\n\t"
                "th.vwredsum.vs v13, v28, v0\n\t"
                "th.vwredsum.vs v14, v30, v0\n\t"
                "li %[tmp], 4\n\t"
                "th.vsetvli zero, %[tmp], e32, m1\n\t"
                "th.vslideup.vi v10, v9, 1\n\t"
                "th.vslideup.vi v8, v7, 1\n\t"
                "th.vslideup.vi v11, v12, 1\n\t"
                "th.vslideup.vi v13, v14, 1\n\t"
                "th.vslideup.vi v10, v8, 2\n\t"
                "th.vslideup.vi v11, v13, 2\n\t"
                "li %[tmp], 8\n\t"
                "th.vsetvli zero, %[tmp], e32, m2\n\t"
                "th.vlb.v v12, (%[scale])\n\t"
                "th.vmul.vv v10, v10, v12\n\t"
                "th.vredsum.vs v0, v10, v0\n\t"
                "th.vmv.x.s %[tmp], v0\n\t"
                "add %[isum], %[isum], %[tmp]"
                : [tmp] "=&r" (tmp), [m] "+&r" (m), [isum] "+&r" (isum)
                : [vl128] "r" (128), [vl64] "r" (64), [vl32] "r" (32)
                , [q3] "r" (q3), [qh] "r" (qh), [scale] "r" (scale), [q8] "r" (q8)
                : "memory"
                , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
                , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
                , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
                , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
            );
            q3 += 32; q8 += 128; scale += 8;
        }

        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
        sumf += d * isum;
    }

    *s = sumf;

#elif defined __riscv_v

    uint32_t utmp[4];
    float sumf = 0;
    uint32_t aux[3];
    const int vector_length = __riscv_vlenb() * 8;

    switch (vector_length) {
    case 256:
        for (int i = 0; i < nb; ++i) {

            const uint8_t * GGML_RESTRICT q3 = x[i].qs;
            const uint8_t * GGML_RESTRICT qh = x[i].hmask;
            const int8_t * GGML_RESTRICT q8 = y[i].qs;

            memcpy(aux, x[i].scales, 12);
            utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
            utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
            utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
            utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);

            int8_t * scale = (int8_t *)utmp;
            for (int j = 0; j < 16; ++j) scale[j] -= 32;

            size_t vl = 32;
            uint8_t m = 1;

            vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
            vuint8m1_t vqh = __riscv_vle8_v_u8m1(qh, vl);

            int sum_t = 0;

            for (int j = 0; j < QK_K; j += 128) {

                vl = 32;

                // load Q3
                vuint8m1_t q3_x = __riscv_vle8_v_u8m1(q3, vl);

                vint8m1_t q3_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q3_x, 0x03, vl));
                vint8m1_t q3_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x2, vl), 0x03, vl));
                vint8m1_t q3_2 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x4, vl), 0x03, vl));
                vint8m1_t q3_3 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x6, vl), 0x03, vl));

                // compute mask for subtraction
                vuint8m1_t qh_m0 = __riscv_vand_vx_u8m1(vqh, m, vl);
                vbool8_t vmask_0 = __riscv_vmseq_vx_u8m1_b8(qh_m0, 0, vl);
                vint8m1_t q3_m0 = __riscv_vsub_vx_i8m1_mu(vmask_0, q3_0, q3_0, 0x4, vl);
                m <<= 1;

                vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl);
                vbool8_t vmask_1 = __riscv_vmseq_vx_u8m1_b8(qh_m1, 0, vl);
                vint8m1_t q3_m1 = __riscv_vsub_vx_i8m1_mu(vmask_1, q3_1, q3_1, 0x4, vl);
                m <<= 1;

                vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl);
                vbool8_t vmask_2 = __riscv_vmseq_vx_u8m1_b8(qh_m2, 0, vl);
                vint8m1_t q3_m2 = __riscv_vsub_vx_i8m1_mu(vmask_2, q3_2, q3_2, 0x4, vl);
                m <<= 1;

                vuint8m1_t qh_m3 = __riscv_vand_vx_u8m1(vqh, m, vl);
                vbool8_t vmask_3 = __riscv_vmseq_vx_u8m1_b8(qh_m3, 0, vl);
                vint8m1_t q3_m3 = __riscv_vsub_vx_i8m1_mu(vmask_3, q3_3, q3_3, 0x4, vl);
                m <<= 1;

                // load Q8 and take product with Q3
                vint16m2_t a0 = __riscv_vwmul_vv_i16m2(q3_m0, __riscv_vle8_v_i8m1(q8, vl), vl);
                vint16m2_t a1 = __riscv_vwmul_vv_i16m2(q3_m1, __riscv_vle8_v_i8m1(q8+32, vl), vl);
                vint16m2_t a2 = __riscv_vwmul_vv_i16m2(q3_m2, __riscv_vle8_v_i8m1(q8+64, vl), vl);
                vint16m2_t a3 = __riscv_vwmul_vv_i16m2(q3_m3, __riscv_vle8_v_i8m1(q8+96, vl), vl);

                vl = 16;

                // retrieve lane to multiply with scale
                vint32m2_t aux0_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 0), (scale[0]), vl);
                vint32m2_t aux0_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 1), (scale[1]), vl);
                vint32m2_t aux1_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 0), (scale[2]), vl);
                vint32m2_t aux1_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 1), (scale[3]), vl);
                vint32m2_t aux2_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 0), (scale[4]), vl);
                vint32m2_t aux2_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 1), (scale[5]), vl);
                vint32m2_t aux3_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 0), (scale[6]), vl);
                vint32m2_t aux3_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 1), (scale[7]), vl);

                vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux0_0, aux0_1, vl), vzero, vl);
                vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux1_0, aux1_1, vl), isum0, vl);
                vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux2_0, aux2_1, vl), isum1, vl);
                vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux3_0, aux3_1, vl), isum2, vl);

                sum_t += __riscv_vmv_x_s_i32m1_i32(isum3);

                q3 += 32; q8 += 128; scale += 8;
            }

            const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;

            sumf += d*sum_t;
        }
        break;
    case 128:
        for (int i = 0; i < nb; ++i) {
            const uint8_t * GGML_RESTRICT q3 = x[i].qs;
            const uint8_t * GGML_RESTRICT qh = x[i].hmask;
            const int8_t * GGML_RESTRICT q8 = y[i].qs;

            int8_t * scale = (int8_t *)utmp;
            int tmp, t1, t2, t3, t4, t5, t6, t7;
            __asm__ __volatile__(
                "vsetivli zero, 12, e8, m1\n\t"
                "vle8.v v0, (%[s6b])\n\t"
                "vmv1r.v v2, v0\n\t"
                "vsetivli zero, 2, e64, m1\n\t"
                "vmv.v.x v9, %[sh]\n\t"
                "vslidedown.vi v1, v0, 1\n\t"
                "vslide1up.vx v8, v9, zero\n\t" // {0, 0, 4, 4}
                "vslideup.vi v0, v2, 1\n\t" // {aux[0], aux[1], aux[0], aux[1]}
                "vsetivli zero, 4, e32, m1\n\t"
                "vid.v v9\n\t"
                "vmv.x.s %[tmp], v1\n\t"
                "vsll.vi v9, v9, 1\n\t" // {0, 2, 4, 6}
                "vmv.v.x v1, %[tmp]\n\t" // {aux[2], aux[2], aux[2], aux[2]}
                "vsrl.vv v4, v1, v9\n\t"
                "vsrl.vv v2, v0, v8\n\t"
                "vand.vx v5, v4, %[kmask1]\n\t"
                "vand.vx v3, v2, %[kmask2]\n\t"
                "vsll.vi v6, v5, 4\n\t"
                "vor.vv v7, v6, v3\n\t"
                "vsetivli zero, 16, e8, m1\n\t"
                "vsub.vx v0, v7, %[c]\n\t"
                "vse8.v v0, (%[scale])"
                : [tmp] "=&r" (tmp)
                : [sh] "r" (0x0000000400000004), [s6b] "r" (x[i].scales), [c] "r" (32)
                , [scale] "r" (scale), [kmask1] "r" (kmask1), [kmask2] "r" (kmask2)
                : "memory"
                , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
                , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
                , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
                , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
            );

            uint8_t m = 1;
            int isum = 0;
            for (int j = 0; j < QK_K; j += 128) {
                __asm__ __volatile__(
                    "lb zero, 31(%[q3])\n\t"
                    "vsetvli zero, %[vl32], e8, m2, ta, mu\n\t"
                    "vle8.v v8, (%[q3])\n\t"
                    "vsrl.vi v10, v8, 2\n\t"
                    "vsrl.vi v12, v8, 4\n\t"
                    "vsrl.vi v14, v8, 6\n\t"
                    "lb zero, 64(%[q8])\n\t"
                    "vand.vi v8, v8, 3\n\t"
                    "vand.vi v10, v10, 3\n\t"
                    "vand.vi v12, v12, 3\n\t"
                    "vle8.v v2, (%[qh])\n\t"
                    "lb zero, 127(%[q8])\n\t"
                    "vand.vx v4, v2, %[m]\n\t"
                    "slli %[m], %[m], 1\n\t"
                    "vmseq.vx v0, v4, zero\n\t"
                    "vadd.vi v8, v8, -4, v0.t\n\t"
                    "lb zero, 0(%[q8])\n\t"
                    "vand.vx v4, v2, %[m]\n\t"
                    "slli %[m], %[m], 1\n\t"
                    "vmseq.vx v0, v4, zero\n\t"
                    "vadd.vi v10, v10, -4, v0.t\n\t"
                    "vand.vx v4, v2, %[m]\n\t"
                    "slli %[m], %[m], 1\n\t"
                    "vmseq.vx v0, v4, zero\n\t"
                    "vadd.vi v12, v12, -4, v0.t\n\t"
                    "vand.vx v4, v2, %[m]\n\t"
                    "slli %[m], %[m], 1\n\t"
                    "vmseq.vx v0, v4, zero\n\t"
                    "vadd.vi v14, v14, -4, v0.t\n\t"
                    "vsetvli zero, %[vl128], e8, m8\n\t"
                    "vle8.v v0, (%[q8])\n\t"
                    "lb %[tmp], 0(%[scale])\n\t"
                    "lb %[t1], 1(%[scale])\n\t"
                    "lb %[t2], 2(%[scale])\n\t"
                    "lb %[t3], 3(%[scale])\n\t"
                    "vsetvli zero, %[vl64], e8, m4\n\t"
                    "vwmul.vv v16, v0, v8\n\t"
                    "vwmul.vv v24, v4, v12\n\t"
                    "vsetivli zero, 16, e16, m2\n\t"
                    "vmv.v.x v0, zero\n\t"
                    "vwredsum.vs v8, v16, v0\n\t"
                    "lb %[t4], 4(%[scale])\n\t"
                    "lb %[t5], 5(%[scale])\n\t"
                    "vwredsum.vs v9, v18, v0\n\t"
                    "vwredsum.vs v10, v20, v0\n\t"
                    "vwredsum.vs v11, v22, v0\n\t"
                    "vwredsum.vs v12, v24, v0\n\t"
                    "lb %[t6], 6(%[scale])\n\t"
                    "lb %[t7], 7(%[scale])\n\t"
                    "vwredsum.vs v13, v26, v0\n\t"
                    "vwredsum.vs v14, v28, v0\n\t"
                    "vwredsum.vs v15, v30, v0\n\t"
                    "vsetivli zero, 4, e32, m1\n\t"
                    "vmul.vx v0, v8, %[tmp]\n\t"
                    "vmul.vx v1, v9, %[t1]\n\t"
                    "vmacc.vx v0, %[t2], v10\n\t"
                    "vmacc.vx v1, %[t3], v11\n\t"
                    "vmacc.vx v0, %[t4], v12\n\t"
                    "vmacc.vx v1, %[t5], v13\n\t"
                    "vmacc.vx v0, %[t6], v14\n\t"
                    "vmacc.vx v1, %[t7], v15\n\t"
                    "vmv.x.s %[tmp], v0\n\t"
                    "vmv.x.s %[t1], v1\n\t"
                    "add %[isum], %[isum], %[tmp]\n\t"
                    "add %[isum], %[isum], %[t1]"
                    : [tmp] "=&r" (tmp), [t1] "=&r" (t1), [t2] "=&r" (t2), [t3] "=&r" (t3)
                    , [t4] "=&r" (t4), [t5] "=&r" (t5), [t6] "=&r" (t6), [t7] "=&r" (t7)
                    , [m] "+&r" (m), [isum] "+&r" (isum)
                    : [vl128] "r" (128), [vl64] "r" (64), [vl32] "r" (32)
                    , [q3] "r" (q3), [qh] "r" (qh), [scale] "r" (scale), [q8] "r" (q8)
                    : "memory"
                    , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
                    , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
                    , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
                    , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
                );
                q3 += 32; q8 += 128; scale += 8;
            }

            const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
            sumf += d * isum;
        }
        break;
    default:
        assert(false && "Unsupported vector length");
        break;
    }

    *s = sumf;

#else

    UNUSED(kmask1);
    UNUSED(kmask2);
    UNUSED(x);
    UNUSED(y);
    UNUSED(nb);

    ggml_vec_dot_q3_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}

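// Dot product of a Q4_K row with a Q8_K row. Each super-block holds 8 sub-blocks of
// 32 4-bit values, with 6-bit scales and mins packed into 12 bytes; the mins are
// folded against the Q8_K block sums before the nibble products are accumulated.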
void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(n % QK_K == 0);
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_q4_K * GGML_RESTRICT x = vx;
    const block_q8_K * GGML_RESTRICT y = vy;

    const int nb = n / QK_K;

    static const uint32_t kmask1 = 0x3f3f3f3f;
    static const uint32_t kmask2 = 0x0f0f0f0f;
    static const uint32_t kmask3 = 0x03030303;

    uint32_t utmp[4];

#if defined __riscv_xtheadvector

    const uint8_t * scales = (const uint8_t*)&utmp[0];
    const uint8_t * mins = (const uint8_t*)&utmp[2];

    float sumf = 0;

    for (int i = 0; i < nb; ++i) {
        const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
        const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);

        int tmp, tmp2, sumi;
        __asm__ __volatile__(
            "li %[t1], 12\n\t"
            "th.vsetvli zero, %[t1], e8, m1\n\t"
            "th.vlb.v v1, (%[s6b])\n\t" // {aux[0], aux[1], aux[2]}
            "li %[t1], 4\n\t"
            "th.vsetvli zero, %[t1], e32, m1\n\t"
            "th.vslidedown.vi v2, v1, 2\n\t"
            "th.vmv.v.v v3, v2\n\t"
            "th.vslideup.vi v2, v3, 1\n\t" // {aux[2], aux[2]}
            "li %[t1], 2\n\t"
            "th.vsetvli zero, %[t1], e32, m1\n\t"
            "th.vmv.v.i v4, 4\n\t"
            "th.vand.vx v8, v1, %[kmask1]\n\t"
            "th.vslide1up.vx v5, v4, zero\n\t" // {0, 4}
            "th.vsrl.vi v6, v1, 6\n\t"
            "th.vsrl.vv v7, v2, v5\n\t"
            "th.vand.vx v0, v6, %[kmask3]\n\t"
            "th.vand.vx v2, v7, %[kmask2]\n\t"
            "th.vsll.vi v6, v0, 4\n\t"
            "li %[t2], 8\n\t"
            "addi %[t1], %[utmp], 4\n\t"
            "th.vor.vv v1, v6, v2\n\t"
            "th.vssw.v v8, (%[utmp]), %[t2]\n\t"
            "th.vssw.v v1, (%[t1]), %[t2]\n\t"
            "th.vsetvli zero, zero, e32, m2\n\t" // vl == 8
            "th.vlw.v v2, (%[bsums])\n\t"
            "th.vsetvli zero, %[t2], e16, m1\n\t"
            "th.vnsrl.vi v0, v2, 0\n\t"
            "th.vnsrl.vi v1, v2, 16\n\t"
            "th.vadd.vv v2, v0, v1\n\t"
            "th.vlbu.v v4, (%[mins])\n\t"
            "th.vwmul.vv v6, v4, v2\n\t"
            "th.vmv.v.x v0, zero\n\t"
            "th.vsetvli zero, %[t2], e32, m2\n\t"
            "th.vredsum.vs v0, v6, v0\n\t"
            "th.vmv.x.s %[sumi], v0"
            : [t1] "=&r" (tmp), [t2] "=&r" (tmp2), [sumi] "=&r" (sumi)
            : [bsums] "r" (y[i].bsums), [mins] "r" (mins), [utmp] "r" (utmp)
            , [s6b] "r" (x[i].scales), [kmask1] "r" (kmask1)
            , [kmask2] "r" (kmask2), [kmask3] "r" (kmask3)
            : "memory"
            , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
            , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
            , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
            , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
        );
        sumf -= dmin * sumi;

        const uint8_t * GGML_RESTRICT q4 = x[i].qs;
        const int8_t * GGML_RESTRICT q8 = y[i].qs;

        sumi = 0;
        const uint8_t * scale = scales;

        for (int j = 0; j < QK_K/128; ++j) {
            int vl128 = 128, vl64 = 64, vl32 = 32;
            __asm__ __volatile__(
                "th.vsetvli zero, %[vl128], e8, m8\n\t"
                "th.vlb.v v8, (%[q8])\n\t"
                "th.vsetvli zero, %[vl64], e8, m4\n\t"
                "th.vlb.v v0, (%[q4])\n\t"
                "th.vsrl.vi v4, v0, 4\n\t"
                "th.vand.vi v0, v0, 0xF\n\t"
                "th.vsetvli zero, %[vl32], e8, m2\n\t"
                "th.vwmul.vv v28, v6, v14\n\t"
                "th.vwmul.vv v20, v4, v10\n\t"
                "th.vwmul.vv v24, v2, v12\n\t"
                "th.vwmul.vv v16, v0, v8\n\t"
                "li %[tmp], 4\n\t"
                "th.vsetvli zero, %[tmp], e32, m1\n\t"
                "th.vlbu.v v1, (%[scale])\n\t"
                "th.vmv.v.x v0, zero\n\t"
                "th.vsetvli zero, %[vl32], e16, m4\n\t"
                "th.vwredsum.vs v6, v24, v0\n\t"
                "th.vwredsum.vs v7, v28, v0\n\t"
                "th.vwredsum.vs v4, v16, v0\n\t"
                "th.vwredsum.vs v5, v20, v0\n\t"
                "th.vsetvli zero, %[tmp], e32, m1\n\t"
                "th.vslideup.vi v6, v7, 1\n\t"
                "th.vslideup.vi v4, v5, 1\n\t"
                "th.vslideup.vi v4, v6, 2\n\t"
                "th.vmul.vv v8, v4, v1\n\t"
                "th.vredsum.vs v0, v8, v0\n\t"
                "th.vmv.x.s %[tmp], v0\n\t"
                "add %[sumi], %[sumi], %[tmp]"
                : [tmp] "=&r" (tmp), [sumi] "+&r" (sumi)
                : [vl128] "r" (vl128), [vl64] "r" (vl64), [vl32] "r" (vl32)
                , [q4] "r" (q4), [q8] "r" (q8), [scale] "r" (scale)
                : "memory"
                , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
                , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
                , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
                , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
            );

            q4 += 64; q8 += 128; scale += 4;
        }

        sumf += d * sumi;
    }

    *s = sumf;

#elif defined __riscv_v

    const uint8_t * scales = (const uint8_t*)&utmp[0];
    const uint8_t * mins = (const uint8_t*)&utmp[2];

    float sumf = 0;
    const int vector_length = __riscv_vlenb() * 8;

    switch (vector_length) {
    case 256:
        for (int i = 0; i < nb; ++i) {

            size_t vl = 8;

            const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
            const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);

            vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl);
            vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl);
            vint16mf2_t q8sums = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl);

            memcpy(utmp, x[i].scales, 12);
            utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
            const uint32_t uaux = utmp[1] & kmask1;
            utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
            utmp[2] = uaux;
            utmp[0] &= kmask1;

            vuint8mf4_t mins8 = __riscv_vle8_v_u8mf4(mins, vl);
            vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl));
            vint32m1_t prod = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl);

            vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl);
            sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi);

            const uint8_t * GGML_RESTRICT q4 = x[i].qs;
            const int8_t * GGML_RESTRICT q8 = y[i].qs;

            vl = 32;

            int32_t sum_1 = 0;
            int32_t sum_2 = 0;

            vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1);

            for (int j = 0; j < QK_K/64; ++j) {
                // load Q4
                vuint8m1_t q4_x = __riscv_vle8_v_u8m1(q4, vl);

                // load Q8 and multiply it with lower Q4 nibble
                vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl);
                vint8m1_t q4_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q4_x, 0x0F, vl));
                vint16m2_t qv_0 = __riscv_vwmul_vv_i16m2(q4_0, q8_0, vl);
                vint16m1_t vs_0 = __riscv_vredsum_vs_i16m2_i16m1(qv_0, vzero, vl);

                sum_1 += __riscv_vmv_x_s_i16m1_i16(vs_0) * scales[2*j+0];

                // load Q8 and multiply it with upper Q4 nibble
                vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8+32, vl);
                vint8m1_t q4_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q4_x, 0x04, vl));
                vint16m2_t qv_1 = __riscv_vwmul_vv_i16m2(q4_1, q8_1, vl);
                vint16m1_t vs_1 = __riscv_vredsum_vs_i16m2_i16m1(qv_1, vzero, vl);

                sum_2 += __riscv_vmv_x_s_i16m1_i16(vs_1) * scales[2*j+1];

                q4 += 32; q8 += 64;
            }

            sumf += d*(sum_1 + sum_2);
        }
        break;
    case 128:
        for (int i = 0; i < nb; ++i) {
            const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
            const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);

            float ftmp, ft2;
            const uint8_t * GGML_RESTRICT q40;
            const uint8_t * GGML_RESTRICT q41;
            const uint8_t * GGML_RESTRICT q42;
            const uint8_t * GGML_RESTRICT q43;
            const int8_t * GGML_RESTRICT q80;
            const int8_t * GGML_RESTRICT q81;
            const int8_t * GGML_RESTRICT q82;
            const int8_t * GGML_RESTRICT q83;
            int s0, s1, s2, s3;

            __asm__ __volatile__(
                "li %[s1], 8\n\t"
                "vsetivli zero, 4, e32, m1, ta, ma\n\t"
                "vle32.v v1, (%[s6b])\n\t"
                "vslide1down.vx v1, v1, zero\n\t"
                "vmv.v.x v16, zero\n\t"
                "vslidedown.vi v2, v1, 2\n\t"
                "vmv1r.v v3, v2\n\t"
                "vslideup.vi v2, v3, 1\n\t" // {aux[2], aux[2]}
                "vsetivli zero, 2, e32, m1, ta, ma\n\t"
                "vmv.v.i v4, 4\n\t"
                "vand.vx v8, v1, %[kmask1]\n\t"
                "vslide1up.vx v5, v4, zero\n\t" // {0, 4}
                "vsrl.vi v6, v1, 6\n\t"
                "vsrl.vv v7, v2, v5\n\t"
                "vsse32.v v8, (%[utmp]), %[s1]\n\t"
                "vand.vx v0, v6, %[kmask3]\n\t"
                "vand.vx v2, v7, %[kmask2]\n\t"
                "vsll.vi v6, v0, 4\n\t"
                "addi %[s0], %[utmp], 4\n\t"
                "vor.vv v1, v6, v2\n\t"
                "vsse32.v v1, (%[s0]), %[s1]\n\t"
                "vsetivli zero, 8, e16, m1, ta, ma\n\t"
                "vle32.v v2, (%[bsums])\n\t"
                "vnsrl.wi v0, v2, 0\n\t"
                "vnsrl.wi v1, v2, 16\n\t"
                "vadd.vv v2, v0, v1\n\t"
                "vle8.v v3, (%[mins])\n\t"
                "vzext.vf2 v4, v3\n\t"
                "vwmul.vv v6, v4, v2\n\t"
                "vsetivli zero, 4, e32, m1, ta, ma\n\t"
                "vredsum.vs v0, v6, v16\n\t"
                "vredsum.vs v0, v7, v0\n\t"
                "vfcvt.f.x.v v0, v0\n\t"
                "vfmv.f.s %[ftmp], v0\n\t"
                "vsetivli zero, 16, e8, m1, ta, ma\n\t"
                "vle8.v v0, (%[xs])\n\t"
                "fnmsub.s %[sumf], %[dmin], %[ftmp], %[sumf]\n\t"
                "addi %[q40], %[xs], 64\n\t"
                "addi %[q41], %[xs], 16\n\t"
                "addi %[q42], %[xs], 32\n\t"
                "addi %[q43], %[xs], 48\n\t"
                "addi %[q80], %[ys], 64\n\t"
                "vle8.v v1, (%[q41])\n\t"
                "vle8.v v2, (%[q42])\n\t"
                "addi %[q81], %[ys], 16\n\t"
                "addi %[q41], %[q41], 64\n\t"
                "addi %[q82], %[ys], 32\n\t"
                "vle8.v v3, (%[q43])\n\t"
                "vle8.v v8, (%[ys])\n\t"
                "addi %[q42], %[q42], 64\n\t"
                "addi %[q83], %[ys], 48\n\t"
                "addi %[q43], %[q43], 64\n\t"
                "vsrl.vi v4, v0, 4\n\t"
                "vle8.v v9, (%[q81])\n\t"
                "vle8.v v10, (%[q82])\n\t"
                "vand.vi v0, v0, 0xF\n\t"
                "addi %[q81], %[q81], 64\n\t"
                "vsrl.vi v5, v1, 4\n\t"
                "addi %[q82], %[q82], 64\n\t"
                "vle8.v v11, (%[q83])\n\t"
                "vle8.v v12, (%[q80])\n\t"
                "vand.vi v1, v1, 0xF\n\t"
                "addi %[q83], %[q83], 64\n\t"
                "vsrl.vi v6, v2, 4\n\t"
                "addi %[q80], %[q80], 64\n\t"
                "vle8.v v13, (%[q81])\n\t"
                "vle8.v v14, (%[q82])\n\t"
                "vand.vi v2, v2, 0xF\n\t"
                "addi %[q81], %[q81], 64\n\t"
                "vsrl.vi v7, v3, 4\n\t"
                "addi %[q82], %[q82], 64\n\t"
                "vwmul.vv v16, v0, v8\n\t"
                "vle8.v v15, (%[q83])\n\t"
                "vle8.v v0, (%[q40])\n\t"
                "vand.vi v3, v3, 0xF\n\t"
                "addi %[q83], %[q83], 64\n\t"
                "vwmul.vv v24, v2, v12\n\t"
                "vwmul.vv v20, v4, v10\n\t"
                "vwmul.vv v28, v6, v14\n\t"
                "vwmacc.vv v16, v1, v9\n\t"
                "vle8.v v1, (%[q41])\n\t"
                "vle8.v v2, (%[q42])\n\t"
                "vwmacc.vv v24, v3, v13\n\t"
                "vwmacc.vv v20, v5, v11\n\t"
                "vwmacc.vv v28, v7, v15\n\t"
                "addi %[q40], %[q80], 64\n\t"
                "addi %[q41], %[q81], 64\n\t"
                "vle8.v v3, (%[q43])\n\t"
                "vle8.v v8, (%[q80])\n\t"
                "addi %[q42], %[q82], 64\n\t"
                "addi %[q43], %[q83], 64\n\t"
                "vsrl.vi v4, v0, 4\n\t"
                "vle8.v v9, (%[q81])\n\t"
                "vle8.v v10, (%[q82])\n\t"
                "vand.vi v0, v0, 0xF\n\t"
                "vsrl.vi v5, v1, 4\n\t"
                "vsrl.vi v7, v3, 4\n\t"
                "vand.vi v3, v3, 0xF\n\t"
                "vle8.v v11, (%[q83])\n\t"
                "vle8.v v12, (%[q40])\n\t"
                "vand.vi v1, v1, 0xF\n\t"
                "vsrl.vi v6, v2, 4\n\t"
                "vand.vi v2, v2, 0xF\n\t"
                "vwmul.vv v18, v0, v8\n\t"
                "vle8.v v13, (%[q41])\n\t"
                "vle8.v v14, (%[q42])\n\t"
                "vwmul.vv v26, v2, v12\n\t"
                "vwmul.vv v22, v4, v10\n\t"
                "vwmul.vv v30, v6, v14\n\t"
                "vwmacc.vv v18, v1, v9\n\t"
                "vle8.v v15, (%[q43])\n\t"
                "vwmacc.vv v26, v3, v13\n\t"
                "vwmacc.vv v22, v5, v11\n\t"
                "vwmacc.vv v30, v7, v15\n\t"
                "vmv.v.x v0, zero\n\t"
                "vsetivli zero, 16, e16, m2, ta, ma\n\t"
                "vwredsum.vs v4, v16, v0\n\t"
                "lbu %[s0], 0(%[scale])\n\t"
                "vwredsum.vs v5, v20, v0\n\t"
                "lbu %[s1], 1(%[scale])\n\t"
                "vwredsum.vs v6, v24, v0\n\t"
                "lbu %[s2], 2(%[scale])\n\t"
                "vwredsum.vs v7, v28, v0\n\t"
                "lbu %[s3], 3(%[scale])\n\t"
                "vwredsum.vs v8, v18, v0\n\t"
                "lbu %[q40], 4(%[scale])\n\t"
                "vwredsum.vs v9, v22, v0\n\t"
                "lbu %[q41], 5(%[scale])\n\t"
                "vwredsum.vs v10, v26, v0\n\t"
                "lbu %[q42], 6(%[scale])\n\t"
                "vwredsum.vs v11, v30, v0\n\t"
                "lbu %[q43], 7(%[scale])\n\t"
                "vsetivli zero, 4, e32, m1, ta, ma\n\t"
                "vmul.vx v0, v4, %[s0]\n\t"
                "vmul.vx v1, v8, %[q40]\n\t"
                "vmacc.vx v0, %[s1], v5\n\t"
                "vmacc.vx v1, %[q41], v9\n\t"
                "vmacc.vx v0, %[s2], v6\n\t"
                "vmacc.vx v1, %[q42], v10\n\t"
                "vmacc.vx v0, %[s3], v7\n\t"
                "vmacc.vx v1, %[q43], v11\n\t"
                "vfcvt.f.x.v v0, v0\n\t"
                "vfcvt.f.x.v v1, v1\n\t"
                "vfmv.f.s %[ft2], v0\n\t"
                "vfmv.f.s %[ftmp], v1\n\t"
                "fadd.s %[ft2], %[ft2], %[ftmp]\n\t"
                "fmadd.s %[sumf], %[d], %[ft2], %[sumf]"
                : [ftmp] "=&f" (ftmp), [sumf] "+&f" (sumf), [ft2] "=&f" (ft2)
                , [s0] "=&r" (s0), [s1] "=&r" (s1), [s2] "=&r" (s2), [s3] "=&r" (s3)
                , [q40] "=&r" (q40), [q41] "=&r" (q41), [q42] "=&r" (q42), [q43] "=&r" (q43)
                , [q80] "=&r" (q80), [q81] "=&r" (q81), [q82] "=&r" (q82), [q83] "=&r" (q83)
                : [d] "f" (d), [ys] "r" (y[i].qs), [xs] "r" (x[i].qs), [scale] "r" (scales)
                , [bsums] "r" (y[i].bsums), [mins] "r" (mins), [utmp] "r" (utmp)
                , [s6b] "r" (&x[i]), [kmask1] "r" (kmask1), [dmin] "f" (dmin)
                , [kmask2] "r" (kmask2), [kmask3] "r" (kmask3)
                : "memory"
                , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
                , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
                , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
                , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
            );
        }
        break;
    default:
        assert(false && "Unsupported vector length");
        break;
    }

    *s = sumf;

#else

    UNUSED(x);
    UNUSED(y);
    UNUSED(kmask1);
    UNUSED(kmask2);
    UNUSED(kmask3);
    UNUSED(nb);
    UNUSED(utmp);

    ggml_vec_dot_q4_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}

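// Dot product of a Q5_K row with a Q8_K row. Scale/min packing is the same as Q4_K;
// the fifth bit of each value comes from the qh bitfield, and 16 is added to the
// lanes whose high bit is set before the widening multiply.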
void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(n % QK_K == 0);
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_q5_K * GGML_RESTRICT x = vx;
    const block_q8_K * GGML_RESTRICT y = vy;

    const int nb = n / QK_K;

    static const uint32_t kmask1 = 0x3f3f3f3f;
    static const uint32_t kmask2 = 0x0f0f0f0f;
    static const uint32_t kmask3 = 0x03030303;

    uint32_t utmp[4];

#if defined __riscv_v

    const uint8_t * scales = (const uint8_t*)&utmp[0];
    const uint8_t * mins = (const uint8_t*)&utmp[2];

    float sumf = 0;
    float sums = 0.0f;

    size_t vl;

    for (int i = 0; i < nb; ++i) {

        vl = 8;

        const uint8_t * GGML_RESTRICT q5 = x[i].qs;
        const uint8_t * GGML_RESTRICT hm = x[i].qh;
        const int8_t * GGML_RESTRICT q8 = y[i].qs;

        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
        const float dmin = GGML_CPU_FP16_TO_FP32(x[i].dmin) * y[i].d;

        vint16m1_t q8sums_0 = __riscv_vlse16_v_i16m1(y[i].bsums, 4, vl);
        vint16m1_t q8sums_1 = __riscv_vlse16_v_i16m1(y[i].bsums+1, 4, vl);
        vint16m1_t q8sums = __riscv_vadd_vv_i16m1(q8sums_0, q8sums_1, vl);

        memcpy(utmp, x[i].scales, 12);
        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
        const uint32_t uaux = utmp[1] & kmask1;
        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
        utmp[2] = uaux;
        utmp[0] &= kmask1;

        vuint8mf2_t mins8 = __riscv_vle8_v_u8mf2(mins, vl);
        vint16m1_t v_mins = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(mins8, vl));
        vint32m2_t prod = __riscv_vwmul_vv_i32m2(q8sums, v_mins, vl);

        vint32m1_t sumi = __riscv_vredsum_vs_i32m2_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl);
        sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi);

        vl = 32;
        int32_t aux32 = 0;
        int is = 0;

        uint8_t m = 1;
        vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
        vuint8m2_t vqh = __riscv_vle8_v_u8m2(hm, vl);

        for (int j = 0; j < QK_K/64; ++j) {
            // load Q5 and Q8
            vuint8m2_t q5_x = __riscv_vle8_v_u8m2(q5, vl);
            vint8m2_t q8_y1 = __riscv_vle8_v_i8m2(q8, vl);
            vint8m2_t q8_y2 = __riscv_vle8_v_i8m2(q8+32, vl);

            // compute mask for addition
            vint8m2_t q5_a = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vand_vx_u8m2(q5_x, 0x0F, vl));
            vuint8m2_t qh_m1 = __riscv_vand_vx_u8m2(vqh, m, vl);
            vbool4_t vmask_1 = __riscv_vmsne_vx_u8m2_b4(qh_m1, 0, vl);
            vint8m2_t q5_m1 = __riscv_vadd_vx_i8m2_mu(vmask_1, q5_a, q5_a, 16, vl);
            m <<= 1;

            vint8m2_t q5_l = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vsrl_vx_u8m2(q5_x, 0x04, vl));
            vuint8m2_t qh_m2 = __riscv_vand_vx_u8m2(vqh, m, vl);
            vbool4_t vmask_2 = __riscv_vmsne_vx_u8m2_b4(qh_m2, 0, vl);
            vint8m2_t q5_m2 = __riscv_vadd_vx_i8m2_mu(vmask_2, q5_l, q5_l, 16, vl);
            m <<= 1;

            vint16m4_t v0 = __riscv_vwmul_vv_i16m4(q5_m1, q8_y1, vl);
            vint16m4_t v1 = __riscv_vwmul_vv_i16m4(q5_m2, q8_y2, vl);

            vint32m8_t vs1 = __riscv_vwmul_vx_i32m8(v0, scales[is++], vl);
            vint32m8_t vs2 = __riscv_vwmul_vx_i32m8(v1, scales[is++], vl);

            vint32m1_t vacc1 = __riscv_vredsum_vs_i32m8_i32m1(vs1, vzero, vl);
            vint32m1_t vacc2 = __riscv_vredsum_vs_i32m8_i32m1(vs2, vacc1, vl);

            aux32 += __riscv_vmv_x_s_i32m1_i32(vacc2);
            q5 += 32; q8 += 64;
        }

        sums += aux32 * d;
    }

    *s = sumf+sums;

#else

    UNUSED(x);
    UNUSED(y);
    UNUSED(kmask1);
    UNUSED(kmask2);
    UNUSED(kmask3);
    UNUSED(nb);
    UNUSED(utmp);

    ggml_vec_dot_q5_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}

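// Dot product of a Q6_K row with a Q8_K row. Each 6-bit value combines a low nibble
// from ql with two high bits from qh and is re-centered by -32; the 16 sub-block
// scales are stored as plain int8.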
1647void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
1648 assert(n % QK_K == 0);
1649 assert(nrc == 1);
1650 UNUSED(nrc);
1651 UNUSED(bx);
1652 UNUSED(by);
1653 UNUSED(bs);
1654
1655 const block_q6_K * GGML_RESTRICT x = vx;
1656 const block_q8_K * GGML_RESTRICT y = vy;
1657
1658 const int nb = n / QK_K;
1659
1660#if defined __riscv_xtheadvector
1661
1662 float sumf = 0;
1663
1664 for (int i = 0; i < nb; ++i) {
1665
1666 const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
1667
1668 const uint8_t * restrict q6 = x[i].ql;
1669 const uint8_t * restrict qh = x[i].qh;
1670 const int8_t * restrict q8 = y[i].qs;
1671
1672 const int8_t * restrict scale = x[i].scales;
1673
1674 int sum_t = 0;
1675 int t0;
1676
        for (int j = 0; j < QK_K/128; ++j) {
            __asm__ __volatile__(
                "th.vsetvli zero, %[vl32], e8, m2\n\t" // vl == 32
                "th.vlb.v v4, (%[qh])\n\t"
                "th.vsll.vi v0, v4, 4\n\t"
                "th.vsll.vi v2, v4, 2\n\t"
                "th.vsrl.vi v6, v4, 2\n\t"
                "th.vsetvli zero, %[vl64], e8, m4\n\t" // vl == 64
                "th.vlb.v v8, (%[q6])\n\t"
                "th.vsrl.vi v12, v8, 4\n\t"
                "th.vand.vi v8, v8, 0xF\n\t"
                "th.vsetvli zero, %[vl128], e8, m8\n\t" // vl == 128
                "th.vand.vx v0, v0, %[mask]\n\t"
                "th.vor.vv v8, v8, v0\n\t"
                "th.vlb.v v0, (%[q8])\n\t"
                "th.vsub.vx v8, v8, %[vl32]\n\t"
                "th.vsetvli zero, %[vl64], e8, m4\n\t" // vl == 64
                "th.vwmul.vv v16, v0, v8\n\t"
                "th.vwmul.vv v24, v4, v12\n\t"
                "li %[t0], 16\n\t"
                "th.vsetvli zero, %[t0], e16, m2\n\t" // vl == 16
                "th.vmv.v.x v0, zero\n\t"
                "th.vwredsum.vs v10, v16, v0\n\t"
                "th.vwredsum.vs v9, v18, v0\n\t"
                "th.vwredsum.vs v8, v20, v0\n\t"
                "th.vwredsum.vs v7, v22, v0\n\t"
                "th.vwredsum.vs v11, v24, v0\n\t"
                "th.vwredsum.vs v12, v26, v0\n\t"
                "th.vwredsum.vs v13, v28, v0\n\t"
                "th.vwredsum.vs v14, v30, v0\n\t"
                "li %[t0], 4\n\t"
                "th.vsetvli zero, %[t0], e32, m1\n\t" // vl == 4
                "th.vslideup.vi v10, v9, 1\n\t"
                "th.vslideup.vi v8, v7, 1\n\t"
                "th.vslideup.vi v11, v12, 1\n\t"
                "th.vslideup.vi v13, v14, 1\n\t"
                "th.vslideup.vi v10, v8, 2\n\t"
                "th.vslideup.vi v11, v13, 2\n\t"
                "li %[t0], 8\n\t"
                "th.vsetvli zero, %[t0], e32, m2\n\t" // vl == 8
                "th.vlb.v v4, (%[scale])\n\t"
                "th.vmul.vv v2, v4, v10\n\t"
                "th.vredsum.vs v0, v2, v0\n\t"
                "th.vmv.x.s %[t0], v0\n\t"
                "add %[sumi], %[sumi], %[t0]"
                : [sumi] "+&r" (sum_t), [t0] "=&r" (t0)
                : [qh] "r" (qh), [q6] "r" (q6), [q8] "r" (q8), [scale] "r" (scale)
                , [vl32] "r" (32), [vl64] "r" (64), [vl128] "r" (128)
                , [mask] "r" (0x30)
                : "memory"
                , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
                , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
                , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
                , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
            );
            q6 += 64; qh += 32; q8 += 128; scale += 8;
        }

        sumf += d * sum_t;
    }

    *s = sumf;

#elif defined(__riscv_v)

    float sumf = 0;
    const int vector_length = __riscv_vlenb() * 8;

    switch (vector_length) {
    case 256:
        for (int i = 0; i < nb; ++i) {

            const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;

            const uint8_t * GGML_RESTRICT q6 = x[i].ql;
            const uint8_t * GGML_RESTRICT qh = x[i].qh;
            const int8_t * GGML_RESTRICT q8 = y[i].qs;

            const int8_t * GGML_RESTRICT scale = x[i].scales;

            size_t vl;

            vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);

            int sum_t = 0;
            int is = 0;

            for (int j = 0; j < QK_K/128; ++j) {

                vl = 32;

                // load qh
                vuint8m1_t qh_x = __riscv_vle8_v_u8m1(qh, vl);

                // load Q6
                vuint8m1_t q6_0 = __riscv_vle8_v_u8m1(q6, vl);
                vuint8m1_t q6_1 = __riscv_vle8_v_u8m1(q6+32, vl);

                vuint8m1_t q6a_0 = __riscv_vand_vx_u8m1(q6_0, 0x0F, vl);
                vuint8m1_t q6a_1 = __riscv_vand_vx_u8m1(q6_1, 0x0F, vl);
                vuint8m1_t q6s_0 = __riscv_vsrl_vx_u8m1(q6_0, 0x04, vl);
                vuint8m1_t q6s_1 = __riscv_vsrl_vx_u8m1(q6_1, 0x04, vl);

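                // extract the four 2-bit high-bit groups from each qh byte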
                vuint8m1_t qh_0 = __riscv_vand_vx_u8m1(qh_x, 0x03, vl);
                vuint8m1_t qh_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x2, vl), 0x03, vl);
                vuint8m1_t qh_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x4, vl), 0x03, vl);
                vuint8m1_t qh_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x6, vl), 0x03, vl);

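                // splice the high bits above the low nibbles to rebuild the 6-bit values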
                vuint8m1_t qhi_0 = __riscv_vor_vv_u8m1(q6a_0, __riscv_vsll_vx_u8m1(qh_0, 0x04, vl), vl);
                vuint8m1_t qhi_1 = __riscv_vor_vv_u8m1(q6a_1, __riscv_vsll_vx_u8m1(qh_1, 0x04, vl), vl);
                vuint8m1_t qhi_2 = __riscv_vor_vv_u8m1(q6s_0, __riscv_vsll_vx_u8m1(qh_2, 0x04, vl), vl);
                vuint8m1_t qhi_3 = __riscv_vor_vv_u8m1(q6s_1, __riscv_vsll_vx_u8m1(qh_3, 0x04, vl), vl);

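                // remove the +32 storage offset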
                vint8m1_t a_0 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_0), 32, vl);
                vint8m1_t a_1 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_1), 32, vl);
                vint8m1_t a_2 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_2), 32, vl);
                vint8m1_t a_3 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_3), 32, vl);

                // load Q8 and take product
                vint16m2_t va_q_0 = __riscv_vwmul_vv_i16m2(a_0, __riscv_vle8_v_i8m1(q8, vl), vl);
                vint16m2_t va_q_1 = __riscv_vwmul_vv_i16m2(a_1, __riscv_vle8_v_i8m1(q8+32, vl), vl);
                vint16m2_t va_q_2 = __riscv_vwmul_vv_i16m2(a_2, __riscv_vle8_v_i8m1(q8+64, vl), vl);
                vint16m2_t va_q_3 = __riscv_vwmul_vv_i16m2(a_3, __riscv_vle8_v_i8m1(q8+96, vl), vl);

                vl = 16;

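                // scale each 16-value half by its int8 scale while widening to 32 bits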
                vint32m2_t vaux_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 0), scale[is+0], vl);
                vint32m2_t vaux_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 1), scale[is+1], vl);
                vint32m2_t vaux_2 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 0), scale[is+2], vl);
                vint32m2_t vaux_3 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 1), scale[is+3], vl);
                vint32m2_t vaux_4 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 0), scale[is+4], vl);
                vint32m2_t vaux_5 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 1), scale[is+5], vl);
                vint32m2_t vaux_6 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 0), scale[is+6], vl);
                vint32m2_t vaux_7 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 1), scale[is+7], vl);

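                // chained reductions: isum3 accumulates all eight scaled partial sums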
                vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_0, vaux_1, vl), vzero, vl);
                vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_2, vaux_3, vl), isum0, vl);
                vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_4, vaux_5, vl), isum1, vl);
                vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_6, vaux_7, vl), isum2, vl);

                sum_t += __riscv_vmv_x_s_i32m1_i32(isum3);

                q6 += 64; qh += 32; q8 += 128; is += 8;
            }

            sumf += d * sum_t;
        }
        break;
    case 128:
        for (int i = 0; i < nb; ++i) {

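            // prefetch the next super-block header while this one is processed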
            __builtin_prefetch(&x[i + 1].d, 0, 1);

            const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;

            const uint8_t * GGML_RESTRICT q6 = x[i].ql;
            const uint8_t * GGML_RESTRICT qh = x[i].qh;
            const int8_t * GGML_RESTRICT q8 = y[i].qs;

            const int8_t * GGML_RESTRICT scale = x[i].scales;

            const uint8_t * q6h; // second 32-byte half of the current ql chunk; set inside the asm
            float ftmp;

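            // One pass per 128 values. The assembly unpacks the eight int8
            // scales from a single 64-bit load (slli/srai pairs sign-extend
            // each byte) and interleaves that with the vector work; the
            // "lb zero, ..." instructions are software prefetches whose
            // results are discarded into x0.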
            for (int j = 0; j < QK_K/128; ++j) {
                __asm__ __volatile__(
                    "addi %[q6h], %[q6], 32\n\t"
                    "ld t0, 0(%[scale])\n\t"
                    "addi %[scale], %[scale], 8\n\t"
                    "slli t6, t0, 1 * 8\n\t"
                    "lb zero, 0(%[q6])\n\t"
                    "slli t5, t0, 2 * 8\n\t"
                    "slli t4, t0, 3 * 8\n\t"
                    "lb zero, 0(%[q6h])\n\t"
                    "slli t3, t0, 4 * 8\n\t"
                    "slli t2, t0, 5 * 8\n\t"
                    "lb zero, 0(%[qh])\n\t"
                    "lb zero, 31(%[q6h])\n\t"
                    "slli t1, t0, 6 * 8\n\t"
                    "srai a7, t0, 56\n\t"
                    "vsetvli zero, %[vl32], e8, m2\n\t"
                    "vle8.v v8, (%[q6])\n\t"
                    "srai t6, t6, 56\n\t"
                    "srai t5, t5, 56\n\t"
                    "srai t4, t4, 56\n\t"
                    "srai t3, t3, 56\n\t"
                    "vle8.v v10, (%[q6h])\n\t"
                    "addi %[q6], %[q6], 64\n\t"
                    "slli t0, t0, 7 * 8\n\t"
                    "srai t2, t2, 56\n\t"
                    "srai t1, t1, 56\n\t"
                    "srai t0, t0, 56\n\t"
                    "vle8.v v4, (%[qh])\n\t"
                    "vsrl.vi v12, v8, 4\n\t"
                    "vsrl.vi v14, v10, 4\n\t"
                    "lb zero, 0(%[q8])\n\t"
                    "vand.vi v8, v8, 0xF\n\t"
                    "vand.vi v10, v10, 0xF\n\t"
                    "lb zero, 32(%[q8])\n\t"
                    "vsll.vi v0, v4, 4\n\t"
                    "vsll.vi v2, v4, 2\n\t"
                    "lb zero, 64(%[q8])\n\t"
                    "vsrl.vi v6, v4, 2\n\t"
                    "vand.vx v0, v0, %[mask]\n\t"
                    "lb zero, 96(%[q8])\n\t"
                    "vand.vx v2, v2, %[mask]\n\t"
                    "vand.vx v4, v4, %[mask]\n\t"
                    "vand.vx v6, v6, %[mask]\n\t"
                    "vor.vv v8, v8, v0\n\t"
                    "lb zero, 127(%[q8])\n\t"
                    "vor.vv v10, v10, v2\n\t"
                    "vor.vv v12, v12, v4\n\t"
                    "vor.vv v14, v14, v6\n\t"
                    "vsetvli zero, %[vl128], e8, m8\n\t"
                    "vle8.v v0, (%[q8])\n\t"
                    "vsub.vx v8, v8, %[vl32]\n\t"
                    "vsetvli zero, %[vl64], e8, m4\n\t"
                    "vwmul.vv v16, v0, v8\n\t"
                    "vwmul.vv v24, v4, v12\n\t"
                    "vsetivli zero, 16, e16, m2\n\t"
                    "vmv.v.x v0, zero\n\t"
                    "vwredsum.vs v10, v16, v0\n\t"
                    "vwredsum.vs v9, v18, v0\n\t"
                    "vwredsum.vs v8, v20, v0\n\t"
                    "vwredsum.vs v7, v22, v0\n\t"
                    "vwredsum.vs v11, v24, v0\n\t"
                    "vwredsum.vs v12, v26, v0\n\t"
                    "vwredsum.vs v13, v28, v0\n\t"
                    "vwredsum.vs v14, v30, v0\n\t"
                    "vsetivli zero, 4, e32, m1\n\t"
                    "vmul.vx v0, v10, t0\n\t"
                    "vmul.vx v1, v9, t1\n\t"
                    "vmacc.vx v0, t2, v8\n\t"
                    "vmacc.vx v1, t3, v7\n\t"
                    "vmacc.vx v0, t4, v11\n\t"
                    "vmacc.vx v1, t5, v12\n\t"
                    "vmacc.vx v0, t6, v13\n\t"
                    "vmacc.vx v1, a7, v14\n\t"
                    "vadd.vv v0, v0, v1\n\t"
                    "vfcvt.f.x.v v0, v0\n\t"
                    "vfmv.f.s %[ftmp], v0\n\t"
                    "fmadd.s %[sumf], %[d], %[ftmp], %[sumf]"
                    : [q6] "+&r" (q6), [q6h] "=&r" (q6h)
                    , [scale] "+&r" (scale)
                    , [sumf] "+&f" (sumf), [ftmp] "=&f" (ftmp)
                    : [qh] "r" (qh), [q8] "r" (q8)
                    , [vl32] "r" (32), [vl64] "r" (64), [vl128] "r" (128)
                    , [mask] "r" (0x30), [d] "f" (d)
                    : "memory"
                    , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
                    , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
                    , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
                    , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
                    , "t0", "t1", "t2", "t3", "t4", "t5", "t6", "a7"
                    , "a6", "a5", "a4", "a3"
                );
                qh += 32; q8 += 128; // q6 and scale are advanced inside the asm block
            }
        }
        break;
    default:
        assert(false && "Unsupported vector length");
        break;
    }

    *s = sumf;

#else

    UNUSED(x);
    UNUSED(y);
    UNUSED(nb);

    ggml_vec_dot_q6_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}