summaryrefslogtreecommitdiff
path: root/llama.cpp/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llama.cpp/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp')
-rw-r--r--llama.cpp/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp3196
1 files changed, 3196 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp b/llama.cpp/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp
new file mode 100644
index 0000000..cbbb6cd
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp
@@ -0,0 +1,3196 @@
1#include "ggml.h"
2#include "ime_kernels.h"
3
4#include <algorithm>
5#include <cmath>
6
7// clang-format off
8#if defined(__GNUC__)
9#pragma GCC diagnostic ignored "-Woverlength-strings"
10#pragma GCC diagnostic ignored "-Wcast-qual"
11#pragma GCC diagnostic ignored "-Wunused-parameter"
12#endif
13// clang-format on
14namespace sqnbitgemm_spacemit_ime {
15
16#define QUANTIZEM4ROW_KERNEL \
17 "vmv.s.x v16, zero \n\t" \
18 "vfabs.v v8, v0 \n\t" \
19 "vfredmax.vs v16, v8, v16 \n\t" \
20 "vfmv.f.s f10, v16 \n\t" \
21 "fmul.s f10, f10, %[RMAXREC] \n\t" \
22 "fsw f10, (a1) \n\t" \
23 "fdiv.s f11, %[FONE], f10 \n\t" \
24 "vfmul.vf v16, v0, f11 \n\t" \
25 "vfcvt.x.f.v v16, v16 \n\t" \
26 "vsetvli t0, zero, e16, mf2 \n\t" \
27 "vnclip.wx v16, v16, zero \n\t" \
28 "vnclip.wx v17, v17, zero \n\t" \
29 "vnclip.wx v18, v18, zero \n\t" \
30 "vnclip.wx v19, v19, zero \n\t" \
31 "vnclip.wx v20, v20, zero \n\t" \
32 "vnclip.wx v21, v21, zero \n\t" \
33 "vnclip.wx v22, v22, zero \n\t" \
34 "vnclip.wx v23, v23, zero \n\t" \
35 "vsetvli t0, zero, e8, mf4 \n\t" \
36 "vnclip.wx v24, v16, zero \n\t" \
37 "vnclip.wx v25, v17, zero \n\t" \
38 "vnclip.wx v26, v18, zero \n\t" \
39 "vnclip.wx v27, v19, zero \n\t" \
40 "vnclip.wx v28, v20, zero \n\t" \
41 "vnclip.wx v29, v21, zero \n\t" \
42 "vnclip.wx v30, v22, zero \n\t" \
43 "vnclip.wx v31, v23, zero \n\t"
44
45#define QUANTIZEM4ROW_STORE \
46 "addi t1, %[BlkLen], 0 \n\t" \
47 "vsetvli t0, t1, e8, mf4 \n\t" \
48 "vse8.v v24, (s1) \n\t" \
49 "addi s1, s1, 32 \n\t" \
50 "sub t1, t1, t0 \n\t" \
51 "vsetvli t0, t1, e8, mf4 \n\t" \
52 "vse8.v v25, (s1) \n\t" \
53 "addi s1, s1, 32 \n\t" \
54 "sub t1, t1, t0 \n\t" \
55 "vsetvli t0, t1, e8, mf4 \n\t" \
56 "vse8.v v26, (s1) \n\t" \
57 "addi s1, s1, 32 \n\t" \
58 "sub t1, t1, t0 \n\t" \
59 "vsetvli t0, t1, e8, mf4 \n\t" \
60 "vse8.v v27, (s1) \n\t" \
61 "addi s1, s1, 32 \n\t" \
62 "sub t1, t1, t0 \n\t" \
63 "vsetvli t0, t1, e8, mf4 \n\t" \
64 "vse8.v v28, (s1) \n\t" \
65 "addi s1, s1, 32 \n\t" \
66 "sub t1, t1, t0 \n\t" \
67 "vsetvli t0, t1, e8, mf4 \n\t" \
68 "vse8.v v29, (s1) \n\t" \
69 "addi s1, s1, 32 \n\t" \
70 "sub t1, t1, t0 \n\t" \
71 "vsetvli t0, t1, e8, mf4 \n\t" \
72 "vse8.v v30, (s1) \n\t" \
73 "addi s1, s1, 32 \n\t" \
74 "sub t1, t1, t0 \n\t" \
75 "vsetvli t0, t1, e8, mf4 \n\t" \
76 "vse8.v v31, (s1) \n\t"
77
78namespace ime1 {
// Quantize 4 rows of A (each CountK floats, row stride CountK) into the
// interleaved 4-row int8 activation layout expected by the ime1 GEMM kernels.
// For each BlkLen-sized block of each row, an fp32 scale (max|x| / 127) is
// stored, followed by the int8 data; the per-row placement is controlled by
//   offset = (4 - row_index) * 4 + row_index * 8   (scale/data interleave)
//   stride = 4 * (sizeof(float) + BlkLen)          (one 4-row block group)
// NOTE(review): the layout above is inferred from the offset/stride math —
// confirm against the packing consumer. Three asm specializations exist for
// BlkLen in {16,32,64}, 128 and 256; other BlkLen values are silently
// ignored (no else branch) — presumably guarded by the caller; verify.
// NOTE(review): the asm stores through a1/s1 into QuantA, but QuantA is only
// an input operand and there is no "memory" clobber — confirm the compiler
// cannot reorder/elide around these asm blocks in this build.
void quantize_a_4row_i8(size_t BlkLen, const float * A, size_t CountK, std::byte * QuantA) {
    // 1/127: maps the block's max magnitude onto the int8 range.
    constexpr float range_max_reciprocal = 1.0f / ((1 << 7) - 1);
    const float fone = 1.0f;

    if (BlkLen == 16 || BlkLen == 32 || BlkLen == 64) {
        // One vector group (e32/m8) holds a whole block; process one block
        // per LOOP iteration, then zero-pad and quantize the ragged tail.
        for (size_t row_index = 0; row_index < 4; ++row_index) {
            const float * SRC = A + row_index * CountK;
            std::byte * DST = QuantA + row_index * sizeof(float);

            const size_t offset = (4 - row_index) * 4 + row_index * 8;
            const size_t stride = 4 * (sizeof(float) + BlkLen);
            __asm__ volatile(
                // t2 = remaining elements, a1 = scale cursor for this row.
                "vsetvli t0, zero, e32, m8 \n\t"
                "addi t2, %[CountK], 0 \n\t"
                "addi a1, %[DST], 0 \n\t"
                "blt t2, %[BlkLen], TAIL%= \n\t"

                // Full blocks: load BlkLen floats, advance SRC by 4*t0 bytes,
                // s1 = data destination for this block.
                "LOOP%=: \n\t"
                "vsetvli t0, %[BlkLen], e32, m8 \n\t"
                "vle32.v v0, (%[SRC]) \n\t"
                "sub t2, t2, t0 \n\t"
                "slli t1, t0, 2 \n\t"
                "add %[SRC], %[SRC], t1 \n\t"
                "add s1, a1, %[OFFSET] \n\t"

                QUANTIZEM4ROW_KERNEL QUANTIZEM4ROW_STORE

                "add a1, a1, %[STRIDE] \n\t"
                "bge t2, %[BlkLen], LOOP%= \n\t"

                // Tail: clear the working registers, load the partial block,
                // quantize it, zero-fill the destination slots (SET_ZERO
                // walks BlkLen/8 stores of 32 bytes), then store the data.
                "TAIL%=: \n\t"
                "blez t2, QUIT%= \n\t"
                "vsetvli t0, zero, e32, m8 \n\t"
                "vxor.vv v16, v16, v16 \n\t"
                "vxor.vv v24, v24, v24 \n\t"
                "vsetvli t0, t2, e32, m8 \n\t"
                "vle32.v v0, (%[SRC]) \n\t"
                "add s1, a1, %[OFFSET] \n\t"

                QUANTIZEM4ROW_KERNEL

                "addi t3, %[BlkLen], 0 \n\t"
                "addi s2, s1, 0 \n\t"
                "vsetvli t0, zero, e8, mf4 \n\t"
                "vxor.vv v8, v8, v8 \n\t"
                "SET_ZERO%=: \n\t"
                "vse8.v v8, (s2) \n\t"
                "addi s2, s2, 32 \n\t"
                "addi t3, t3, -8 \n\t"
                "bnez t3, SET_ZERO%= \n\t"

                QUANTIZEM4ROW_STORE

                "QUIT%=: \n\t"
                : [SRC] "+r"(SRC)
                : [DST] "r"(DST), [BlkLen] "r"(BlkLen), [OFFSET] "r"(offset), [STRIDE] "r"(stride),
                  [CountK] "r"(CountK), [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal)
                : "cc", "t0", "t1", "t2", "t3", "a1", "s1", "s2", "f10", "f11");
        }
    } else if (BlkLen == 128) {
        // A 128-float block spans two e32/m8 groups (v0..v7 and v8..v15).
        // The int8 result is written with a strided 64-bit store
        // (vsse64, stride t6 = 32) to produce the interleaved layout.
        for (size_t row_index = 0; row_index < 4; ++row_index) {
            const float * SRC = A + row_index * CountK;
            std::byte * DST = QuantA + row_index * sizeof(float);

            const size_t offset = (4 - row_index) * 4 + row_index * 8;
            const size_t stride = 4 * (sizeof(float) + BlkLen);
            __asm__ volatile(
                "vsetvli t0, zero, e32, m8 \n\t"
                "li t6, 32 \n\t"
                "addi t2, %[CountK], 0 \n\t"
                "addi a1, %[DST], 0 \n\t"
                "add s1, a1, %[OFFSET] \n\t"
                "blt t2, %[BlkLen], TAIL%= \n\t"

                // Load 128 floats (2 x 256 bytes) per iteration.
                "LOOP%=: \n\t"
                "vsetvli t0, zero, e32, m8 \n\t"
                "vle32.v v0, (%[SRC]) \n\t"
                "addi %[SRC], %[SRC], 256 \n\t"
                "vle32.v v8, (%[SRC]) \n\t"
                "addi %[SRC], %[SRC], 256 \n\t"
                "addi t2, t2, -128 \n\t"

                // max|v0..v15| -> f10; scale, convert, narrow i32->i16->i8,
                // then strided store of the packed bytes.
                "QUANTIZE%=: \n\t"
                "add s1, a1, %[OFFSET] \n\t"
                "vfabs.v v16, v0 \n\t"
                "vfabs.v v24, v8 \n\t"
                "vfmax.vv v16, v24, v16 \n\t"
                "vfredmax.vs v24, v16, v24 \n\t"
                "vfmv.f.s f10, v24 \n\t"
                "fmul.s f10, f10, %[RMAXREC] \n\t"
                "fsw f10, (a1) \n\t"
                "fdiv.s f11, %[FONE], f10 \n\t"
                "vfmul.vf v16, v0, f11 \n\t"
                "vfmul.vf v24, v8, f11 \n\t"
                "vfcvt.x.f.v v16, v16 \n\t"
                "vfcvt.x.f.v v24, v24 \n\t"
                "vsetvli t0, zero, e16, m4 \n\t"
                "vnclip.wx v16, v16, zero \n\t"
                "vnclip.wx v20, v24, zero \n\t"
                "vsetvli t0, zero, e8, m4 \n\t"
                "vnclip.wx v16, v16, zero \n\t"
                "vsetvli t0, zero, e64, m4 \n\t"
                "vsse64.v v16, (s1), t6 \n\t"
                "add a1, a1, %[STRIDE] \n\t"
                "bge t2, %[BlkLen], LOOP%= \n\t"

                // Tail: zero all inputs, load the partial block (up to two
                // groups), then jump back into the shared QUANTIZE path.
                "TAIL%=: \n\t"
                "blez t2, QUIT%= \n\t"
                "vsetvli t0, zero, e32, m8 \n\t"
                "vxor.vv v0, v0, v0 \n\t"
                "vxor.vv v8, v8, v8 \n\t"
                "vxor.vv v16, v16, v16 \n\t"
                "vxor.vv v24, v24, v24 \n\t"
                "vsetvli t0, t2, e32, m8 \n\t"
                "sub t2, t2, t0 \n\t"
                "vle32.v v0, (%[SRC]) \n\t"
                "addi %[SRC], %[SRC], 256 \n\t"
                "vsetvli t0, t2, e32, m8 \n\t"
                "vle32.v v8, (%[SRC]) \n\t"
                "sub t2, t2, t2 \n\t"
                "vsetvli t0, zero, e32, m8 \n\t"
                "jal x0, QUANTIZE%= \n\t"

                "QUIT%=: \n\t"
                : [SRC] "+r"(SRC)
                : [DST] "r"(DST), [BlkLen] "r"(BlkLen), [OFFSET] "r"(offset), [STRIDE] "r"(stride),
                  [CountK] "r"(CountK), [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal)
                : "cc", "t0", "t1", "t2", "t6", "a1", "s1", "s2", "f10", "f11");
        }
    } else if (BlkLen == 256) {
        // A 256-float block needs all 32 vector registers twice: pass 1
        // loads v0..v31 just to compute the block max (then rewinds SRC by
        // 768 to re-read), pass 2 reloads and quantizes with the scale.
        for (size_t row_index = 0; row_index < 4; ++row_index) {
            const float * SRC = A + row_index * CountK;
            std::byte * DST = QuantA + row_index * sizeof(float);
            const size_t offset = (4 - row_index) * 4 + row_index * 8;
            const size_t stride = 4 * (sizeof(float) + BlkLen);
            __asm__ volatile(
                "vsetvli t0, zero, e32, m8 \n\t"
                "li t6, 32 \n\t"
                "addi t2, %[CountK], 0 \n\t"
                "addi a1, %[DST], 0 \n\t"
                "add s1, a1, %[OFFSET] \n\t"
                "blt t2, %[BlkLen], TAIL%= \n\t"

                // Pass 1: reduce |x| across the four m8 groups into f10,
                // rewinding SRC (-768 after +3*256) for the reload below.
                "LOOP%=: \n\t"
                "vsetvli t0, zero, e32, m8 \n\t"
                "vle32.v v0, (%[SRC]) \n\t"
                "addi %[SRC], %[SRC], 256 \n\t"
                "vle32.v v8, (%[SRC]) \n\t"
                "addi %[SRC], %[SRC], 256 \n\t"
                "vle32.v v16, (%[SRC]) \n\t"
                "addi %[SRC], %[SRC], 256 \n\t"
                "vle32.v v24, (%[SRC]) \n\t"
                "addi %[SRC], %[SRC], -768 \n\t"
                "addi t2, t2, -256 \n\t"
                "vfabs.v v0, v0 \n\t"
                "vfabs.v v8, v8 \n\t"
                "vfabs.v v16, v16 \n\t"
                "vfabs.v v24, v24 \n\t"
                "vfmax.vv v8, v0, v8 \n\t"
                "vfmax.vv v24, v24, v16 \n\t"
                "vfmax.vv v8, v8, v24 \n\t"
                "vfredmax.vs v24, v8, v24 \n\t"
                "vfmv.f.s f10, v24 \n\t"
                // Pass 2: reload the same 256 floats for quantization.
                "vle32.v v0, (%[SRC]) \n\t"
                "addi %[SRC], %[SRC], 256 \n\t"
                "vle32.v v8, (%[SRC]) \n\t"
                "addi %[SRC], %[SRC], 256 \n\t"
                "vle32.v v16, (%[SRC]) \n\t"
                "addi %[SRC], %[SRC], 256 \n\t"
                "vle32.v v24, (%[SRC]) \n\t"
                "addi %[SRC], %[SRC], 256 \n\t"

                "QUANTIZE%=: \n\t"
                "add s1, a1, %[OFFSET] \n\t"
                "fmul.s f10, f10, %[RMAXREC] \n\t"
                "fsw f10, (a1) \n\t"
                "fdiv.s f11, %[FONE], f10 \n\t"
                "vfmul.vf v0, v0, f11 \n\t"
                "vfmul.vf v8, v8, f11 \n\t"
                "vfmul.vf v16, v16, f11 \n\t"
                "vfmul.vf v24, v24, f11 \n\t"
                "vfcvt.x.f.v v0, v0 \n\t"
                "vfcvt.x.f.v v8, v8 \n\t"
                "vfcvt.x.f.v v16, v16 \n\t"
                "vfcvt.x.f.v v24, v24 \n\t"
                "vsetvli t0, zero, e16, m4 \n\t"
                "vnclip.wx v0, v0, zero \n\t"
                "vnclip.wx v4, v8, zero \n\t"
                "vnclip.wx v8, v16, zero \n\t"
                "vnclip.wx v12, v24, zero \n\t"
                "vsetvli t0, zero, e8, m4 \n\t"
                "vnclip.wx v0, v0, zero \n\t"
                "vnclip.wx v4, v8, zero \n\t"
                "vsetvli t0, zero, e64, m8 \n\t"
                "vsse64.v v0, (s1), t6 \n\t"
                "add a1, a1, %[STRIDE] \n\t"
                "bge t2, %[BlkLen], LOOP%= \n\t"

                // Tail: zero the registers, load the remaining (<256)
                // elements group by group, compute the scale, zero-fill the
                // destination via the strided store, then quantize the
                // remainder 8 floats at a time in TAIL_LOOP.
                "TAIL%=: \n\t"
                "blez t2, QUIT%= \n\t"
                "vsetvli t0, zero, e32, m8 \n\t"
                "vxor.vv v0, v0, v0 \n\t"
                "vxor.vv v8, v8, v8 \n\t"
                "vxor.vv v16, v16, v16 \n\t"
                "vxor.vv v24, v24, v24 \n\t"
                "addi t1, t2, 0 \n\t"
                "vsetvli t0, t1, e32, m8 \n\t"
                "sub t1, t1, t0 \n\t"
                "vle32.v v0, (%[SRC]) \n\t"
                "addi %[SRC], %[SRC], 256 \n\t"
                "vsetvli t0, t1, e32, m8 \n\t"
                "sub t1, t1, t0 \n\t"
                "vle32.v v8, (%[SRC]) \n\t"
                "addi %[SRC], %[SRC], 256 \n\t"
                "vsetvli t0, t1, e32, m8 \n\t"
                "sub t1, t1, t0 \n\t"
                "vle32.v v16, (%[SRC]) \n\t"
                "addi %[SRC], %[SRC], 256 \n\t"
                "vsetvli t0, t1, e32, m8 \n\t"
                "vle32.v v24, (%[SRC]) \n\t"
                "addi %[SRC], %[SRC], -768 \n\t"
                "vsetvli t0, zero, e32, m8 \n\t"
                "vfabs.v v0, v0 \n\t"
                "vfabs.v v8, v8 \n\t"
                "vfabs.v v16, v16 \n\t"
                "vfabs.v v24, v24 \n\t"
                "vfmax.vv v8, v0, v8 \n\t"
                "vfmax.vv v24, v16, v24 \n\t"
                "vfmax.vv v8, v8, v24 \n\t"
                "vfredmax.vs v24, v8, v24 \n\t"
                "vfmv.f.s f10, v24 \n\t"
                "add s1, a1, %[OFFSET] \n\t"
                "fmul.s f10, f10, %[RMAXREC] \n\t"
                "fsw f10, (a1) \n\t"
                "fdiv.s f11, %[FONE], f10 \n\t"
                "vsetvli t0, zero, e64, m8 \n\t"
                "vxor.vv v0, v0, v0 \n\t"
                "vsse64.v v0, (s1), t6 \n\t"

                "TAIL_LOOP%=: \n\t"
                "vsetvli t0, zero, e32, m4 \n\t"
                "vxor.vv v0, v0, v0 \n\t"
                "vsetvli t0, t2, e32, m1 \n\t"
                "sub t2, t2, t0 \n\t"
                "vle32.v v0, (%[SRC]) \n\t"
                "addi %[SRC], %[SRC], 32 \n\t"
                "vfmul.vf v1, v0, f11 \n\t"
                "vfcvt.x.f.v v2, v1 \n\t"
                "vsetvli t0, zero, e16, mf2 \n\t"
                "vnclip.wx v3, v2, zero \n\t"
                "vsetvli t0, zero, e8, mf4 \n\t"
                "vnclip.wx v3, v3, zero \n\t"
                "vse8.v v3, (s1) \n\t"
                "addi s1, s1, 32 \n\t"
                "bnez t2, TAIL_LOOP%= \n\t"

                "QUIT%=: \n\t"
                : [SRC] "+r"(SRC)
                : [DST] "r"(DST), [BlkLen] "r"(BlkLen), [OFFSET] "r"(offset), [STRIDE] "r"(stride),
                  [CountK] "r"(CountK), [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal)
                : "cc", "t0", "t1", "t2", "t6", "a1", "s1", "s2", "f10", "f11");
        }
    }
}
343
344void quantize_a_row_i8(size_t BlkLen, const float * A, size_t CountK, std::byte * QuantA) {
345 const float * SRC = A;
346 std::byte * DST = QuantA;
347 constexpr float range_max_reciprocal = 1.0f / ((1 << 7) - 1);
348 const float fone = 1.0f;
349 std::byte * QuantA_offset = QuantA + CountK + 4 * ((CountK + BlkLen - 1) / BlkLen);
350 size_t offset = (CountK + BlkLen - 1) / BlkLen * BlkLen - CountK;
351
352 if (CountK <= BlkLen) {
353 float max_abs_A = 0.0f;
354 for (size_t k = 0; k < CountK; k++) {
355 max_abs_A = std::max(max_abs_A, fabsf(A[k]));
356 }
357 float scale_A = max_abs_A * range_max_reciprocal;
358
359 ((float *) QuantA)[0] = scale_A;
360
361 auto * QuantAData_offset = (int8_t *) (QuantA + sizeof(float));
362
363 for (size_t k = 0; k < CountK; k++) {
364 QuantAData_offset[k] =
365 (int8_t) std::clamp(roundf(A[k] / scale_A), (float) std::numeric_limits<int8_t>::lowest(),
366 (float) std::numeric_limits<int8_t>::max());
367 }
368 for (size_t k = CountK; k < BlkLen; k++) {
369 QuantAData_offset[k] = 0;
370 }
371
372 return;
373 }
374
375 if (BlkLen != 32 || BlkLen != 64 || BlkLen != 128) {
376 __asm__ volatile(
377 "vsetvli t0, zero, e8, m8 \n\t"
378 "vxor.vv v24, v24, v24 \n\t"
379 "LOOP%=: \n\t"
380 "vsetvli t0, %[CNT], e8, m8 \n\t"
381 "vse8.v v24, (%[DST]) \n\t"
382 "addi %[DST], %[DST], 128 \n\t"
383 "sub %[CNT], %[CNT], t0 \n\t"
384 "bnez %[CNT], LOOP%= \n\t"
385 : [DST] "+r"(QuantA_offset), [CNT] "+r"(offset)
386 :
387 : "cc", "t0");
388 }
389 if (BlkLen == 16) {
390 float buffer[64] = { 0.0f };
391 __asm__ volatile(
392 "addi t3, zero, 16*8 \n\t"
393 "addi t2, zero, 16 \n\t"
394 "blt %[K], t3, LOOP_K%= \n\t"
395 "blt %[K], t2, TAIL%= \n\t"
396 "LOOP_MAIN%=: \n\t"
397 "vsetvli t1, zero, e32, m2 \n\t"
398 "addi %[K], %[K], -128 \n\t"
399 "vle32.v v0, (%[SRC]) \n\t"
400 "addi %[SRC], %[SRC], 64 \n\t"
401 "vle32.v v2, (%[SRC]) \n\t"
402 "addi %[SRC], %[SRC], 64 \n\t"
403 "vle32.v v4, (%[SRC]) \n\t"
404 "addi %[SRC], %[SRC], 64 \n\t"
405 "vle32.v v6, (%[SRC]) \n\t"
406 "addi %[SRC], %[SRC], 64 \n\t"
407 "vle32.v v8, (%[SRC]) \n\t"
408 "addi %[SRC], %[SRC], 64 \n\t"
409 "vle32.v v10, (%[SRC]) \n\t"
410 "addi %[SRC], %[SRC], 64 \n\t"
411 "vle32.v v12, (%[SRC]) \n\t"
412 "addi %[SRC], %[SRC], 64 \n\t"
413 "vle32.v v14, (%[SRC]) \n\t"
414 "addi %[SRC], %[SRC], 64 \n\t"
415 "addi a1, %[BUFFER], 0 \n\t"
416 "vfabs.v v16, v0 \n\t"
417 "vfabs.v v18, v2 \n\t"
418 "vfabs.v v20, v4 \n\t"
419 "vfabs.v v22, v6 \n\t"
420 "vfabs.v v24, v8 \n\t"
421 "vfabs.v v26, v10 \n\t"
422 "vfabs.v v28, v12 \n\t"
423 "vfabs.v v30, v14 \n\t"
424 "vsetvli t0, zero, e32, m1 \n\t"
425 "vfmax.vv v16, v16, v17 \n\t"
426 "vfmax.vv v18, v18, v19 \n\t"
427 "vfmax.vv v20, v20, v21 \n\t"
428 "vfmax.vv v22, v22, v23 \n\t"
429 "vfmax.vv v24, v24, v25 \n\t"
430 "vfmax.vv v26, v26, v27 \n\t"
431 "vfmax.vv v28, v28, v29 \n\t"
432 "vfmax.vv v30, v30, v31 \n\t"
433 "vse32.v v16, (a1) \n\t"
434 "addi a1, a1, 32 \n\t"
435 "vse32.v v18, (a1) \n\t"
436 "addi a1, a1, 32 \n\t"
437 "vse32.v v20, (a1) \n\t"
438 "addi a1, a1, 32 \n\t"
439 "vse32.v v22, (a1) \n\t"
440 "addi a1, a1, 32 \n\t"
441 "vse32.v v24, (a1) \n\t"
442 "addi a1, a1, 32 \n\t"
443 "vse32.v v26, (a1) \n\t"
444 "addi a1, a1, 32 \n\t"
445 "vse32.v v28, (a1) \n\t"
446 "addi a1, a1, 32 \n\t"
447 "vse32.v v30, (a1) \n\t"
448 "addi a1, %[BUFFER], 0 \n\t"
449 "flw f0, (a1) \n\t"
450 "flw f1, 4(a1) \n\t"
451 "flw f2, 8(a1) \n\t"
452 "flw f3, 12(a1) \n\t"
453 "flw f4, 16(a1) \n\t"
454 "flw f5, 20(a1) \n\t"
455 "flw f6, 24(a1) \n\t"
456 "flw f7, 28(a1) \n\t"
457 "addi a1, a1, 32 \n\t"
458 "fmax.s f1, f0, f1 \n\t"
459 "fmax.s f3, f2, f3 \n\t"
460 "fmax.s f5, f4, f5 \n\t"
461 "fmax.s f7, f6, f7 \n\t"
462 "fmax.s f3, f1, f3 \n\t"
463 "fmax.s f7, f5, f7 \n\t"
464 "fmax.s f10, f3, f7 \n\t"
465 "fmul.s f10, f10, %[RMAXREC] \n\t"
466 "fsw f10, (%[DST]) \n\t"
467 "addi %[DST], %[DST], 20 \n\t"
468 "fdiv.s f10, %[FONE], f10 \n\t"
469 "flw f0, (a1) \n\t"
470 "flw f1, 4(a1) \n\t"
471 "flw f2, 8(a1) \n\t"
472 "flw f3, 12(a1) \n\t"
473 "flw f4, 16(a1) \n\t"
474 "flw f5, 20(a1) \n\t"
475 "flw f6, 24(a1) \n\t"
476 "flw f7, 28(a1) \n\t"
477 "addi a1, a1, 32 \n\t"
478 "fmax.s f1, f0, f1 \n\t"
479 "fmax.s f3, f2, f3 \n\t"
480 "fmax.s f5, f4, f5 \n\t"
481 "fmax.s f7, f6, f7 \n\t"
482 "fmax.s f3, f1, f3 \n\t"
483 "fmax.s f7, f5, f7 \n\t"
484 "fmax.s f11, f3, f7 \n\t"
485 "fmul.s f11, f11, %[RMAXREC] \n\t"
486 "fsw f11, (%[DST]) \n\t"
487 "addi %[DST], %[DST], 20 \n\t"
488 "fdiv.s f11, %[FONE], f11 \n\t"
489 "flw f0, (a1) \n\t"
490 "flw f1, 4(a1) \n\t"
491 "flw f2, 8(a1) \n\t"
492 "flw f3, 12(a1) \n\t"
493 "flw f4, 16(a1) \n\t"
494 "flw f5, 20(a1) \n\t"
495 "flw f6, 24(a1) \n\t"
496 "flw f7, 28(a1) \n\t"
497 "addi a1, a1, 32 \n\t"
498 "fmax.s f1, f0, f1 \n\t"
499 "fmax.s f3, f2, f3 \n\t"
500 "fmax.s f5, f4, f5 \n\t"
501 "fmax.s f7, f6, f7 \n\t"
502 "fmax.s f3, f1, f3 \n\t"
503 "fmax.s f7, f5, f7 \n\t"
504 "fmax.s f12, f3, f7 \n\t"
505 "fmul.s f12, f12, %[RMAXREC] \n\t"
506 "fsw f12, (%[DST]) \n\t"
507 "addi %[DST], %[DST], 20 \n\t"
508 "fdiv.s f12, %[FONE], f12 \n\t"
509 "flw f0, (a1) \n\t"
510 "flw f1, 4(a1) \n\t"
511 "flw f2, 8(a1) \n\t"
512 "flw f3, 12(a1) \n\t"
513 "flw f4, 16(a1) \n\t"
514 "flw f5, 20(a1) \n\t"
515 "flw f6, 24(a1) \n\t"
516 "flw f7, 28(a1) \n\t"
517 "addi a1, a1, 32 \n\t"
518 "fmax.s f1, f0, f1 \n\t"
519 "fmax.s f3, f2, f3 \n\t"
520 "fmax.s f5, f4, f5 \n\t"
521 "fmax.s f7, f6, f7 \n\t"
522 "fmax.s f3, f1, f3 \n\t"
523 "fmax.s f7, f5, f7 \n\t"
524 "fmax.s f13, f3, f7 \n\t"
525 "fmul.s f13, f13, %[RMAXREC] \n\t"
526 "fsw f13, (%[DST]) \n\t"
527 "addi %[DST], %[DST], 20 \n\t"
528 "fdiv.s f13, %[FONE], f13 \n\t"
529 "flw f0, (a1) \n\t"
530 "flw f1, 4(a1) \n\t"
531 "flw f2, 8(a1) \n\t"
532 "flw f3, 12(a1) \n\t"
533 "flw f4, 16(a1) \n\t"
534 "flw f5, 20(a1) \n\t"
535 "flw f6, 24(a1) \n\t"
536 "flw f7, 28(a1) \n\t"
537 "addi a1, a1, 32 \n\t"
538 "fmax.s f1, f0, f1 \n\t"
539 "fmax.s f3, f2, f3 \n\t"
540 "fmax.s f5, f4, f5 \n\t"
541 "fmax.s f7, f6, f7 \n\t"
542 "fmax.s f3, f1, f3 \n\t"
543 "fmax.s f7, f5, f7 \n\t"
544 "fmax.s f14, f3, f7 \n\t"
545 "fmul.s f14, f14, %[RMAXREC] \n\t"
546 "fsw f14, (%[DST]) \n\t"
547 "addi %[DST], %[DST], 20 \n\t"
548 "fdiv.s f14, %[FONE], f14 \n\t"
549 "flw f0, (a1) \n\t"
550 "flw f1, 4(a1) \n\t"
551 "flw f2, 8(a1) \n\t"
552 "flw f3, 12(a1) \n\t"
553 "flw f4, 16(a1) \n\t"
554 "flw f5, 20(a1) \n\t"
555 "flw f6, 24(a1) \n\t"
556 "flw f7, 28(a1) \n\t"
557 "addi a1, a1, 32 \n\t"
558 "fmax.s f1, f0, f1 \n\t"
559 "fmax.s f3, f2, f3 \n\t"
560 "fmax.s f5, f4, f5 \n\t"
561 "fmax.s f7, f6, f7 \n\t"
562 "fmax.s f3, f1, f3 \n\t"
563 "fmax.s f7, f5, f7 \n\t"
564 "fmax.s f15, f3, f7 \n\t"
565 "fmul.s f15, f15, %[RMAXREC] \n\t"
566 "fsw f15, (%[DST]) \n\t"
567 "addi %[DST], %[DST], 20 \n\t"
568 "fdiv.s f15, %[FONE], f15 \n\t"
569 "flw f0, (a1) \n\t"
570 "flw f1, 4(a1) \n\t"
571 "flw f2, 8(a1) \n\t"
572 "flw f3, 12(a1) \n\t"
573 "flw f4, 16(a1) \n\t"
574 "flw f5, 20(a1) \n\t"
575 "flw f6, 24(a1) \n\t"
576 "flw f7, 28(a1) \n\t"
577 "addi a1, a1, 32 \n\t"
578 "fmax.s f1, f0, f1 \n\t"
579 "fmax.s f3, f2, f3 \n\t"
580 "fmax.s f5, f4, f5 \n\t"
581 "fmax.s f7, f6, f7 \n\t"
582 "fmax.s f3, f1, f3 \n\t"
583 "fmax.s f7, f5, f7 \n\t"
584 "fmax.s f16, f3, f7 \n\t"
585 "fmul.s f16, f16, %[RMAXREC] \n\t"
586 "fsw f16, (%[DST]) \n\t"
587 "addi %[DST], %[DST], 20 \n\t"
588 "fdiv.s f16, %[FONE], f16 \n\t"
589 "flw f0, (a1) \n\t"
590 "flw f1, 4(a1) \n\t"
591 "flw f2, 8(a1) \n\t"
592 "flw f3, 12(a1) \n\t"
593 "flw f4, 16(a1) \n\t"
594 "flw f5, 20(a1) \n\t"
595 "flw f6, 24(a1) \n\t"
596 "flw f7, 28(a1) \n\t"
597 "addi a1, a1, 32 \n\t"
598 "fmax.s f1, f0, f1 \n\t"
599 "fmax.s f3, f2, f3 \n\t"
600 "fmax.s f5, f4, f5 \n\t"
601 "fmax.s f7, f6, f7 \n\t"
602 "fmax.s f3, f1, f3 \n\t"
603 "fmax.s f7, f5, f7 \n\t"
604 "fmax.s f17, f3, f7 \n\t"
605 "fmul.s f17, f17, %[RMAXREC] \n\t"
606 "fsw f17, (%[DST]) \n\t"
607 "addi %[DST], %[DST], -136 \n\t"
608 "fdiv.s f17, %[FONE], f17 \n\t"
609 "vsetvli t0, zero, e32, m2 \n\t"
610 "vfmul.vf v16, v0, f10 \n\t"
611 "vfmul.vf v18, v2, f11 \n\t"
612 "vfmul.vf v20, v4, f12 \n\t"
613 "vfmul.vf v22, v6, f13 \n\t"
614 "vfmul.vf v24, v8, f14 \n\t"
615 "vfmul.vf v26, v10, f15 \n\t"
616 "vfmul.vf v28, v12, f16 \n\t"
617 "vfmul.vf v30, v14, f17 \n\t"
618 "vfcvt.x.f.v v16, v16 \n\t"
619 "vfcvt.x.f.v v18, v18 \n\t"
620 "vfcvt.x.f.v v20, v20 \n\t"
621 "vfcvt.x.f.v v22, v22 \n\t"
622 "vfcvt.x.f.v v24, v24 \n\t"
623 "vfcvt.x.f.v v26, v26 \n\t"
624 "vfcvt.x.f.v v28, v28 \n\t"
625 "vfcvt.x.f.v v30, v30 \n\t"
626 "vsetvli t0, zero, e16, m1 \n\t"
627 "vnclip.wx v16, v16, zero \n\t"
628 "vnclip.wx v18, v18, zero \n\t"
629 "vnclip.wx v20, v20, zero \n\t"
630 "vnclip.wx v22, v22, zero \n\t"
631 "vnclip.wx v24, v24, zero \n\t"
632 "vnclip.wx v26, v26, zero \n\t"
633 "vnclip.wx v28, v28, zero \n\t"
634 "vnclip.wx v30, v30, zero \n\t"
635 "vsetvli t0, t1, e8, mf2 \n\t"
636 "vnclip.wx v16, v16, zero \n\t"
637 "vnclip.wx v18, v18, zero \n\t"
638 "vnclip.wx v20, v20, zero \n\t"
639 "vnclip.wx v22, v22, zero \n\t"
640 "vnclip.wx v24, v24, zero \n\t"
641 "vnclip.wx v26, v26, zero \n\t"
642 "vnclip.wx v28, v28, zero \n\t"
643 "vnclip.wx v30, v30, zero \n\t"
644 "vse8.v v16, (%[DST]) \n\t"
645 "addi %[DST], %[DST], 20 \n\t"
646 "vse8.v v18, (%[DST]) \n\t"
647 "addi %[DST], %[DST], 20 \n\t"
648 "vse8.v v20, (%[DST]) \n\t"
649 "addi %[DST], %[DST], 20 \n\t"
650 "vse8.v v22, (%[DST]) \n\t"
651 "addi %[DST], %[DST], 20 \n\t"
652 "vse8.v v24, (%[DST]) \n\t"
653 "addi %[DST], %[DST], 20 \n\t"
654 "vse8.v v26, (%[DST]) \n\t"
655 "addi %[DST], %[DST], 20 \n\t"
656 "vse8.v v28, (%[DST]) \n\t"
657 "addi %[DST], %[DST], 20 \n\t"
658 "vse8.v v30, (%[DST]) \n\t"
659 "addi %[DST], %[DST], 16 \n\t"
660 "bge %[K], t3, LOOP_MAIN%= \n\t"
661 "blt %[K], t2, TAIL%= \n\t"
662 "LOOP_K%=: \n\t"
663 "vsetvli t1, %[K], e32, m2 \n\t"
664 "vle32.v v0, (%[SRC]) \n\t"
665 "addi %[SRC], %[SRC], 64 \n\t"
666 "sub %[K], %[K], t1 \n\t"
667 "vfabs.v v16, v0 \n\t"
668 "vsetvli t0, zero, e32, m1 \n\t"
669 "vfmax.vv v16, v16, v17 \n\t"
670 "vse32.v v16, (%[BUFFER]) \n\t"
671 "flw f0, (%[BUFFER]) \n\t"
672 "flw f1, 4(%[BUFFER]) \n\t"
673 "flw f2, 8(%[BUFFER]) \n\t"
674 "flw f3, 12(%[BUFFER]) \n\t"
675 "flw f4, 16(%[BUFFER]) \n\t"
676 "flw f5, 20(%[BUFFER]) \n\t"
677 "flw f6, 24(%[BUFFER]) \n\t"
678 "flw f7, 28(%[BUFFER]) \n\t"
679 "fmax.s f1, f0, f1 \n\t"
680 "fmax.s f3, f2, f3 \n\t"
681 "fmax.s f5, f4, f5 \n\t"
682 "fmax.s f7, f6, f7 \n\t"
683 "fmax.s f3, f1, f3 \n\t"
684 "fmax.s f7, f5, f7 \n\t"
685 "fmax.s f10, f3, f7 \n\t"
686 "fmul.s f10, f10, %[RMAXREC] \n\t"
687 "fsw f10, (%[DST]) \n\t"
688 "addi %[DST], %[DST], 4 \n\t"
689 "fdiv.s f11, %[FONE], f10 \n\t"
690 "vsetvli t0, zero, e32, m2 \n\t"
691 "vfmul.vf v16, v0, f11 \n\t"
692 "vfcvt.x.f.v v16, v16 \n\t"
693 "vsetvli t0, zero, e16, m1 \n\t"
694 "vnclip.wx v16, v16, zero \n\t"
695 "vsetvli t0, t1, e8, mf2 \n\t"
696 "vnclip.wx v16, v16, zero \n\t"
697 "vse8.v v16, (%[DST]) \n\t"
698 "addi %[DST], %[DST], 16 \n\t"
699 "bge %[K], t2, LOOP_K%= \n\t"
700 "TAIL%=: \n\t"
701 "blez %[K], END%= \n\t"
702 "vsetvli t0, t3, e32, m2 \n\t"
703 "vxor.vv v16, v16, v16 \n\t"
704 "jal x0, LOOP_K%= \n\t"
705 "END%=: \n\t"
706 : [SRC] "+r"(SRC), [DST] "+r"(DST), [K] "+r"(CountK)
707 : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [BUFFER] "r"(buffer)
708 : "cc", "t3", "t2", "t1", "t0", "a1", "f0", "f1", "f2", "f3", "f4", "f5", "f6", "f7", "f10", "f11", "f12",
709 "f13", "f14", "f15", "f16", "f17");
710 } else if (BlkLen == 32) {
711 __asm__ volatile(
712 "addi t3, zero, 32*4 \n\t"
713 "addi t2, zero, 32 \n\t"
714
715 "addi a1, %[SRC], 0 \n\t"
716 "addi a2, %[SRC], 128 \n\t"
717 "addi a3, %[SRC], 256 \n\t"
718 "addi a4, %[SRC], 384 \n\t"
719
720 "addi s1, %[DST], 0 \n\t"
721 "addi s2, %[DST], 36 \n\t"
722 "addi s3, %[DST], 72 \n\t"
723 "addi s4, %[DST], 108 \n\t"
724 "blt %[K], t3, LOOP_K%= \n\t"
725 "blt %[K], t2, TAIL%= \n\t"
726
727 "LOOP_MAIN%=: \n\t"
728 "vsetvli t1, zero, e32, m4 \n\t"
729 "addi %[K], %[K], -128 \n\t"
730 "vle32.v v0, (a1) \n\t"
731 "addi a1, a1, 512 \n\t"
732 "vle32.v v4, (a2) \n\t"
733 "addi a2, a2, 512 \n\t"
734 "vle32.v v8, (a3) \n\t"
735 "addi a3, a3, 512 \n\t"
736 "vle32.v v12, (a4) \n\t"
737 "addi a4, a4, 512 \n\t"
738 "vfabs.v v16, v0 \n\t"
739 "vfabs.v v20, v4 \n\t"
740 "vfabs.v v24, v8 \n\t"
741 "vfabs.v v28, v12 \n\t"
742 "vsetvli t0, zero, e32, m2 \n\t"
743 "vfmax.vv v16, v16, v18 \n\t"
744 "vfmax.vv v20, v20, v22 \n\t"
745 "vfmax.vv v24, v24, v26 \n\t"
746 "vfmax.vv v28, v28, v30 \n\t"
747 "vsetvli t0, zero, e32, m1 \n\t"
748 "vfmax.vv v16, v16, v17 \n\t"
749 "vfmax.vv v20, v20, v21 \n\t"
750 "vfmax.vv v24, v24, v25 \n\t"
751 "vfmax.vv v28, v28, v29 \n\t"
752
753 "vfredmax.vs v17, v16, v17 \n\t"
754 "vfredmax.vs v21, v20, v21 \n\t"
755 "vfredmax.vs v25, v24, v25 \n\t"
756 "vfredmax.vs v29, v28, v29 \n\t"
757 "vfmv.f.s f10, v17 \n\t"
758 "vfmv.f.s f11, v21 \n\t"
759 "vfmv.f.s f12, v25 \n\t"
760 "vfmv.f.s f13, v29 \n\t"
761
762 "fmul.s f10, f10, %[RMAXREC] \n\t"
763 "fmul.s f11, f11, %[RMAXREC] \n\t"
764 "fmul.s f12, f12, %[RMAXREC] \n\t"
765 "fmul.s f13, f13, %[RMAXREC] \n\t"
766 "fsw f10, (s1) \n\t"
767 "addi s1, s1, 4 \n\t"
768
769 "fsw f11, (s2) \n\t"
770 "addi s2, s2, 4 \n\t"
771 "fsw f12, (s3) \n\t"
772 "addi s3, s3, 4 \n\t"
773 "fsw f13, (s4) \n\t"
774 "addi s4, s4, 4 \n\t"
775 "fdiv.s f10, %[FONE], f10 \n\t"
776 "fdiv.s f11, %[FONE], f11 \n\t"
777 "fdiv.s f12, %[FONE], f12 \n\t"
778 "fdiv.s f13, %[FONE], f13 \n\t"
779 "vsetvli t0, zero, e32, m4 \n\t"
780 "vfmul.vf v16, v0, f10 \n\t"
781 "vfmul.vf v20, v4, f11 \n\t"
782 "vfmul.vf v24, v8, f12 \n\t"
783 "vfmul.vf v28, v12, f13 \n\t"
784 "vfcvt.x.f.v v16, v16 \n\t"
785 "vfcvt.x.f.v v20, v20 \n\t"
786 "vfcvt.x.f.v v24, v24 \n\t"
787 "vfcvt.x.f.v v28, v28 \n\t"
788 "vsetvli t0, zero, e16, m2 \n\t"
789 "vnclip.wx v16, v16, zero \n\t"
790 "vnclip.wx v20, v20, zero \n\t"
791 "vnclip.wx v24, v24, zero \n\t"
792 "vnclip.wx v28, v28, zero \n\t"
793 "vsetvli t0, t1, e8, m1 \n\t"
794 "vnclip.wx v16, v16, zero \n\t"
795 "vnclip.wx v20, v20, zero \n\t"
796 "vnclip.wx v24, v24, zero \n\t"
797 "vnclip.wx v28, v28, zero \n\t"
798 "vse8.v v16, (s1) \n\t"
799 "addi s1, s1, 140 \n\t"
800 "vse8.v v20, (s2) \n\t"
801 "addi s2, s2, 140 \n\t"
802 "vse8.v v24, (s3) \n\t"
803 "addi s3, s3, 140 \n\t"
804 "vse8.v v28, (s4) \n\t"
805 "addi s4, s4, 140 \n\t"
806 "bge %[K], t3, LOOP_MAIN%= \n\t"
807 "blt %[K], t2, TAIL%= \n\t"
808 "LOOP_K%=: \n\t"
809 "vsetvli t1, %[K], e32, m4 \n\t"
810 "vle32.v v0, (a1) \n\t"
811 "addi a1, a1, 128 \n\t"
812 "sub %[K], %[K], t1 \n\t"
813 "vfabs.v v16, v0 \n\t"
814 "vsetvli t0, zero, e32, m2 \n\t"
815 "vfmax.vv v16, v16, v18 \n\t"
816 "vsetvli t0, zero, e32, m1 \n\t"
817 "vfmax.vv v16, v16, v17 \n\t"
818 "vfredmax.vs v17, v16, v17 \n\t"
819 "vfmv.f.s f10, v17 \n\t"
820
821 "fmul.s f10, f10, %[RMAXREC] \n\t"
822 "fsw f10, (s1) \n\t"
823 "addi s1, s1, 4 \n\t"
824 "fdiv.s f11, %[FONE], f10 \n\t"
825 "vsetvli t0, zero, e32, m4 \n\t"
826 "vfmul.vf v16, v0, f11 \n\t"
827 "vfcvt.x.f.v v16, v16 \n\t"
828 "vsetvli t0, zero, e16, m2 \n\t"
829 "vnclip.wx v16, v16, zero \n\t"
830 "vsetvli t0, zero, e8, m1 \n\t"
831 "vnclip.wx v16, v16, zero \n\t"
832 "vse8.v v16, (s1) \n\t"
833 "addi s1, s1, 32 \n\t"
834 "bge %[K], t2, LOOP_K%= \n\t"
835 "TAIL%=: \n\t"
836 "blez %[K], END%= \n\t"
837 "vsetvli t0, t3, e32, m4 \n\t"
838 "vxor.vv v0, v0, v0 \n\t"
839 "vxor.vv v16, v16, v16 \n\t"
840 "jal x0, LOOP_K%= \n\t"
841 "END%=: \n\t"
842 : [K] "+r"(CountK)
843 : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [SRC] "r"(SRC), [DST] "r"(DST)
844 : "cc", "t3", "t2", "t1", "t0", "a1", "a2", "a3", "a4", "s1", "s2", "s3", "s4", "f10", "f11", "f12", "f13");
845 } else if (BlkLen == 64) {
846 __asm__ volatile(
847 "addi t3, zero, 64*2 \n\t"
848 "addi t2, zero, 64 \n\t"
849 "addi a1, %[SRC], 0 \n\t"
850 "addi a2, %[SRC], 256 \n\t"
851 "addi s1, %[DST], 0 \n\t"
852 "addi s2, %[DST], 68 \n\t"
853 "blt %[K], t3, LOOP_K%= \n\t"
854 "blt %[K], t2, TAIL%= \n\t"
855 "LOOP_MAIN%=: \n\t"
856 "vsetvli t1, zero, e32, m8 \n\t"
857 "addi %[K], %[K], -128 \n\t"
858 "vle32.v v0, (a1) \n\t"
859 "addi a1, a1, 512 \n\t"
860 "vle32.v v8, (a2) \n\t"
861 "addi a2, a2, 512 \n\t"
862 "vfabs.v v16, v0 \n\t"
863 "vfabs.v v24, v8 \n\t"
864 "vsetvli t0, zero, e32, m4 \n\t"
865 "vfmax.vv v16, v16, v20 \n\t"
866 "vfmax.vv v24, v24, v28 \n\t"
867 "vsetvli t0, zero, e32, m2 \n\t"
868 "vfmax.vv v16, v16, v18 \n\t"
869 "vfmax.vv v24, v24, v26 \n\t"
870 "vsetvli t0, zero, e32, m1 \n\t"
871 "vfmax.vv v16, v16, v17 \n\t"
872 "vfmax.vv v24, v24, v25 \n\t"
873 "vfredmax.vs v17, v16, v17 \n\t"
874 "vfredmax.vs v25, v24, v25 \n\t"
875 "vfmv.f.s f10, v17 \n\t"
876 "vfmv.f.s f11, v25 \n\t"
877 "fmul.s f10, f10, %[RMAXREC] \n\t"
878 "fmul.s f11, f11, %[RMAXREC] \n\t"
879 "fsw f10, (s1) \n\t"
880 "addi s1, s1, 4 \n\t"
881 "fsw f11, (s2) \n\t"
882 "addi s2, s2, 4 \n\t"
883 "fdiv.s f10, %[FONE], f10 \n\t"
884 "fdiv.s f11, %[FONE], f11 \n\t"
885 "vsetvli t0, zero, e32, m8 \n\t"
886 "vfmul.vf v16, v0, f10 \n\t"
887 "vfmul.vf v24, v8, f11 \n\t"
888 "vfcvt.x.f.v v16, v16 \n\t"
889 "vfcvt.x.f.v v24, v24 \n\t"
890 "vsetvli t0, zero, e16, m4 \n\t"
891 "vnclip.wx v16, v16, zero \n\t"
892 "vnclip.wx v24, v24, zero \n\t"
893 "vsetvli t0, t1, e8, m2 \n\t"
894 "vnclip.wx v16, v16, zero \n\t"
895 "vnclip.wx v24, v24, zero \n\t"
896 "vse8.v v16, (s1) \n\t"
897 "addi s1, s1, 132 \n\t"
898 "vse8.v v24, (s2) \n\t"
899 "addi s2, s2, 132 \n\t"
900 "bge %[K], t3, LOOP_MAIN%= \n\t"
901 "blt %[K], t2, TAIL%= \n\t"
902 "LOOP_K%=: \n\t"
903 "vsetvli t1, %[K], e32, m8 \n\t"
904 "vle32.v v0, (a1) \n\t"
905 "addi a1, a1, 256 \n\t"
906 "sub %[K], %[K], t1 \n\t"
907 "vfabs.v v16, v0 \n\t"
908 "vsetvli t0, zero, e32, m4 \n\t"
909 "vfmax.vv v16, v16, v20 \n\t"
910 "vsetvli t0, zero, e32, m2 \n\t"
911 "vfmax.vv v16, v16, v18 \n\t"
912 "vsetvli t0, zero, e32, m1 \n\t"
913 "vfmax.vv v16, v16, v17 \n\t"
914 "vfredmax.vs v17, v16, v17 \n\t"
915 "vfmv.f.s f10, v17 \n\t"
916 "fmul.s f10, f10, %[RMAXREC] \n\t"
917 "fsw f10, (s1) \n\t"
918 "addi s1, s1, 4 \n\t"
919 "fdiv.s f11, %[FONE], f10 \n\t"
920 "vsetvli t0, zero, e32, m8 \n\t"
921 "vfmul.vf v16, v0, f11 \n\t"
922 "vfcvt.x.f.v v16, v16 \n\t"
923 "vsetvli t0, zero, e16, m4 \n\t"
924 "vnclip.wx v16, v16, zero \n\t"
925 "vsetvli t0, zero, e8, m2 \n\t"
926 "vnclip.wx v16, v16, zero \n\t"
927 "vse8.v v16, (s1) \n\t"
928 "addi s1, s1, 64 \n\t"
929 "bge %[K], t2, LOOP_K%= \n\t"
930 "TAIL%=: \n\t"
931 "blez %[K], END%= \n\t"
932 "vsetvli t0, t3, e32, m8 \n\t"
933 "vxor.vv v0, v0, v0 \n\t"
934 "vxor.vv v16, v16, v16 \n\t"
935 "jal x0, LOOP_K%= \n\t"
936 "END%=: \n\t"
937 : [K] "+r"(CountK)
938 : [SRC] "r"(SRC), [DST] "r"(DST), [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal)
939 : "cc", "t3", "t2", "t1", "t0", "a1", "a2", "s1", "s2", "f10", "f11");
940 } else if (BlkLen == 128) {
941 __asm__ volatile(
942 "addi t2, zero, 128 \n\t"
943 "addi a1, %[SRC], 0 \n\t"
944 "addi a2, %[SRC], 256 \n\t"
945 "blt %[K], t2, TAIL%= \n\t"
946 "LOOP_K%=: \n\t"
947 "vsetvli t1, zero, e32, m8 \n\t"
948 "vle32.v v0, (a1) \n\t"
949 "addi a1, a1, 512 \n\t"
950 "vle32.v v8, (a2) \n\t"
951 "addi a2, a2, 512 \n\t"
952 "sub %[K], %[K], t2 \n\t"
953 "QUANT%=: \n\t"
954 "vfabs.v v16, v0 \n\t"
955 "vfabs.v v24, v8 \n\t"
956 "vfmax.vv v24, v16, v24 \n\t"
957 "vsetvli t1, zero, e32, m4 \n\t"
958 "vfmax.vv v28, v24, v28 \n\t"
959 "vsetvli t0, zero, e32, m2 \n\t"
960 "vfmax.vv v30, v28, v30 \n\t"
961 "vsetvli t0, zero, e32, m1 \n\t"
962 "vfmax.vv v30, v30, v31 \n\t"
963 "vfredmax.vs v31, v30, v31 \n\t"
964 "vfmv.f.s f10, v31 \n\t"
965 "fmul.s f10, f10, %[RMAXREC] \n\t"
966 "fsw f10, (%[DST]) \n\t"
967 "addi %[DST], %[DST], 4 \n\t"
968 "fdiv.s f11, %[FONE], f10 \n\t"
969 "vsetvli t0, zero, e32, m8 \n\t"
970 "vfmul.vf v16, v0, f11 \n\t"
971 "vfmul.vf v24, v8, f11 \n\t"
972 "vfcvt.x.f.v v16, v16 \n\t"
973 "vfcvt.x.f.v v24, v24 \n\t"
974 "vsetvli t0, zero, e16, m4 \n\t"
975 "vnclip.wx v16, v16, zero \n\t"
976 "vnclip.wx v20, v24, zero \n\t"
977 "vsetvli t0, zero, e8, m4 \n\t"
978 "vnclip.wx v16, v16, zero \n\t"
979 "vse8.v v16, (%[DST]) \n\t"
980 "addi %[DST], %[DST], 128 \n\t"
981 "bge %[K], t2, LOOP_K%= \n\t"
982 "TAIL%=: \n\t"
983 "blez %[K], END%= \n\t"
984 "vsetvli t1, zero, e32, m8 \n\t"
985 "vxor.vv v0, v0, v0 \n\t"
986 "vxor.vv v8, v8, v8 \n\t"
987 "vsetvli t0, %[K], e32, m8 \n\t"
988 "vle32.v v0, (a1) \n\t"
989 "sub %[K], %[K], t0 \n\t"
990 "vsetvli t0, %[K], e32, m8 \n\t"
991 "vle32.v v8, (a2) \n\t"
992 "sub %[K], %[K], t0 \n\t"
993 "vsetvli t1, zero, e32, m8 \n\t"
994 "jal x0, QUANT%= \n\t"
995 "END%=: \n\t"
996
997 : [DST] "+r"(DST), [K] "+r"(CountK)
998 : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [SRC] "r"(SRC)
999 : "cc", "t2", "t1", "t0", "a1", "a2", "f10", "f11");
1000 } else {
1001 float buffer[8] = { 0.0f };
1002 size_t cnt = BlkLen / 256;
1003
1004 __asm__ volatile(
1005 "slli t3, %[BLK], 2 \n\t"
1006 "blt %[K], %[BLK], LOOP_TAIL%= \n\t"
1007 "LOOP_MAIN%=: \n\t"
1008 "vsetvli t0, zero, e32, m1 \n\t"
1009 "vxor.vv v31, v31, v31 \n\t"
1010 "vse32.v v31, (%[BUFFER]) \n\t"
1011 "addi t6, %[CNT], 0 \n\t"
1012 "LOOP_CMP%=: \n\t"
1013 "addi t6, t6, -1 \n\t"
1014 "vsetvli t0, zero, e32, m8 \n\t"
1015 "vle32.v v0, (%[SRC]) \n\t"
1016 "addi %[SRC], %[SRC], 256 \n\t"
1017 "vle32.v v8, (%[SRC]) \n\t"
1018 "addi %[SRC], %[SRC], 256 \n\t"
1019 "vle32.v v16, (%[SRC]) \n\t"
1020 "addi %[SRC], %[SRC], 256 \n\t"
1021 "vle32.v v24, (%[SRC]) \n\t"
1022 "addi %[SRC], %[SRC], 256 \n\t"
1023 "vfabs.v v0, v0 \n\t"
1024 "vfabs.v v8, v8 \n\t"
1025 "vfabs.v v16, v16 \n\t"
1026 "vfabs.v v24, v24 \n\t"
1027 "vfmax.vv v8, v0, v8 \n\t"
1028 "vfmax.vv v16, v16, v24 \n\t"
1029 "vfmax.vv v0, v0, v16 \n\t"
1030 "vsetvli t0, zero, e32, m4 \n\t"
1031 "vfmax.vv v0, v0, v4 \n\t"
1032 "vsetvli t0, zero, e32, m2 \n\t"
1033 "vfmax.vv v0, v0, v2 \n\t"
1034 "vsetvli t0, zero, e32, m1 \n\t"
1035 "vfmax.vv v0, v0, v1 \n\t"
1036 "vle32.v v30, (%[BUFFER]) \n\t"
1037 "vfmax.vv v31, v30, v0 \n\t"
1038 "vse32.v v31, (%[BUFFER]) \n\t"
1039 "bnez t6, LOOP_CMP%= \n\t"
1040 "sub %[SRC], %[SRC], t3 \n\t"
1041 "addi t6, %[CNT], 0 \n\t"
1042 "flw f0, (%[BUFFER]) \n\t"
1043 "flw f1, 4(%[BUFFER]) \n\t"
1044 "flw f2, 8(%[BUFFER]) \n\t"
1045 "flw f3, 12(%[BUFFER]) \n\t"
1046 "flw f4, 16(%[BUFFER]) \n\t"
1047 "flw f5, 20(%[BUFFER]) \n\t"
1048 "flw f6, 24(%[BUFFER]) \n\t"
1049 "flw f7, 28(%[BUFFER]) \n\t"
1050 "fmax.s f1, f0, f1 \n\t"
1051 "fmax.s f3, f2, f3 \n\t"
1052 "fmax.s f5, f4, f5 \n\t"
1053 "fmax.s f7, f6, f7 \n\t"
1054 "fmax.s f3, f1, f3 \n\t"
1055 "fmax.s f7, f5, f7 \n\t"
1056 "fmax.s f10, f3, f7 \n\t"
1057 "fmul.s f10, f10, %[RMAXREC] \n\t"
1058 "fsw f10, (%[DST]) \n\t"
1059 "addi %[DST], %[DST], 4 \n\t"
1060 "fdiv.s f11, %[FONE], f10 \n\t"
1061 "addi t6, %[CNT], 0 \n\t"
1062 "LOOP_QUANT%=: \n\t"
1063 "addi t6, t6, -1 \n\t"
1064 "vsetvli t0, zero, e32, m8 \n\t"
1065 "vle32.v v0, (%[SRC]) \n\t"
1066 "addi %[SRC], %[SRC], 256 \n\t"
1067 "vle32.v v8, (%[SRC]) \n\t"
1068 "addi %[SRC], %[SRC], 256 \n\t"
1069 "vle32.v v16, (%[SRC]) \n\t"
1070 "addi %[SRC], %[SRC], 256 \n\t"
1071 "vle32.v v24, (%[SRC]) \n\t"
1072 "addi %[SRC], %[SRC], 256 \n\t"
1073 "vsetvli t0, zero, e32, m8 \n\t"
1074 "vfmul.vf v0, v0, f11 \n\t"
1075 "vfmul.vf v8, v8, f11 \n\t"
1076 "vfmul.vf v16, v16, f11 \n\t"
1077 "vfmul.vf v24, v24, f11 \n\t"
1078 "vfcvt.x.f.v v0, v0 \n\t"
1079 "vfcvt.x.f.v v8, v8 \n\t"
1080 "vfcvt.x.f.v v16, v16 \n\t"
1081 "vfcvt.x.f.v v24, v24 \n\t"
1082 "vsetvli t0, zero, e16, m4 \n\t"
1083 "vnclip.wx v0, v0, zero \n\t"
1084 "vnclip.wx v4, v8, zero \n\t"
1085 "vnclip.wx v8, v16, zero \n\t"
1086 "vnclip.wx v12, v24, zero \n\t"
1087 "vsetvli t0, zero, e8, m4 \n\t"
1088 "vnclip.wx v0, v0, zero \n\t"
1089 "vnclip.wx v4, v8, zero \n\t"
1090 "vse8.v v0, (%[DST]) \n\t"
1091 "addi %[DST], %[DST], 128 \n\t"
1092 "vse8.v v4, (%[DST]) \n\t"
1093 "addi %[DST], %[DST], 128 \n\t"
1094 "bnez t6, LOOP_QUANT%= \n\t"
1095 "sub %[K], %[K], %[BLK] \n\t"
1096 "bge %[K], %[BLK], LOOP_MAIN%= \n\t"
1097 "blez %[K], END%= \n\t"
1098 "LOOP_TAIL%=: \n\t"
1099 "vsetvli t0, zero, e32, m1 \n\t"
1100 "vxor.vv v31, v31, v31 \n\t"
1101 "vse32.v v31, (%[BUFFER]) \n\t"
1102 "addi t6, %[K], 0 \n\t"
1103 "addi s1, %[SRC], 0 \n\t"
1104 "TAIL_CMP%=: \n\t"
1105 "vsetvli t0, zero, e32, m8 \n\t"
1106 "vxor.vv v0, v0, v0 \n\t"
1107 "vsetvli t0, t6, e32, m8 \n\t"
1108 "vle32.v v0, (%[SRC]) \n\t"
1109 "addi %[SRC], %[SRC], 256 \n\t"
1110 "sub t6, t6, t0 \n\t"
1111 "vfabs.v v0, v0 \n\t"
1112 "vsetvli t0, zero, e32, m4 \n\t"
1113 "vfmax.vv v0, v0, v4 \n\t"
1114 "vsetvli t0, zero, e32, m2 \n\t"
1115 "vfmax.vv v0, v0, v2 \n\t"
1116 "vsetvli t0, zero, e32, m1 \n\t"
1117 "vfmax.vv v0, v0, v1 \n\t"
1118 "vle32.v v30, (%[BUFFER]) \n\t"
1119 "vfmax.vv v31, v30, v0 \n\t"
1120 "vse32.v v31, (%[BUFFER]) \n\t"
1121 "bnez t6, TAIL_CMP%= \n\t"
1122 "addi t6, %[K], 0 \n\t"
1123 "flw f0, (%[BUFFER]) \n\t"
1124 "flw f1, 4(%[BUFFER]) \n\t"
1125 "flw f2, 8(%[BUFFER]) \n\t"
1126 "flw f3, 12(%[BUFFER]) \n\t"
1127 "flw f4, 16(%[BUFFER]) \n\t"
1128 "flw f5, 20(%[BUFFER]) \n\t"
1129 "flw f6, 24(%[BUFFER]) \n\t"
1130 "flw f7, 28(%[BUFFER]) \n\t"
1131 "fmax.s f1, f0, f1 \n\t"
1132 "fmax.s f3, f2, f3 \n\t"
1133 "fmax.s f5, f4, f5 \n\t"
1134 "fmax.s f7, f6, f7 \n\t"
1135 "fmax.s f3, f1, f3 \n\t"
1136 "fmax.s f7, f5, f7 \n\t"
1137 "fmax.s f10, f3, f7 \n\t"
1138 "fmul.s f10, f10, %[RMAXREC] \n\t"
1139 "fsw f10, (%[DST]) \n\t"
1140 "addi %[DST], %[DST], 4 \n\t"
1141 "fdiv.s f11, %[FONE], f10 \n\t"
1142 "addi t6, %[K], 0 \n\t"
1143 "TAIL_QUANT%=: \n\t"
1144 "vsetvli t0, zero, e32, m8 \n\t"
1145 "vxor.vv v0, v0, v0 \n\t"
1146 "vsetvli t1, t6, e32, m8 \n\t"
1147 "vle32.v v0, (s1) \n\t"
1148 "addi s1, s1, 256 \n\t"
1149 "sub t6, t6, t1 \n\t"
1150 "vsetvli t0, zero, e32, m8 \n\t"
1151 "vfmul.vf v0, v0, f11 \n\t"
1152 "vfcvt.x.f.v v0, v0 \n\t"
1153 "vsetvli t0, zero, e16, m4 \n\t"
1154 "vnclip.wx v0, v0, zero \n\t"
1155 "vsetvli t0, t1, e8, m2 \n\t"
1156 "vnclip.wx v0, v0, zero \n\t"
1157 "vse8.v v0, (%[DST]) \n\t"
1158 "addi %[DST], %[DST], 64 \n\t"
1159 "bnez t6, TAIL_QUANT%= \n\t"
1160 "END%=: \n\t"
1161 : [SRC] "+r"(SRC), [DST] "+r"(DST), [K] "+r"(CountK)
1162 : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [BLK] "r"(BlkLen), [BUFFER] "r"(buffer),
1163 [CNT] "r"(cnt)
1164 : "cc", "t1", "t0", "t6", "s1", "f0", "f1", "f2", "f3", "f4", "f5", "f6");
1165 }
1166}
1167
1168} // namespace ime1
1169
1170namespace {
// Int8 dot-product micro-step for the 1x8x2 tiling: eight `vmadot` ops
// (SpacemiT IME custom matrix dot-accumulate instruction — TODO confirm exact
// semantics against the IME ISA documentation) accumulate quantized-A vectors
// v14/v15 against the eight unpacked B nibble vectors v0-v7 into the integer
// accumulator register pairs v16/v18/v20/v22.
#define SQ4BIT_KERNEL_COMP_1x8x2_4X8X4        \
    "vmadot v16, v14, v0             \n\t"    \
    "vmadot v18, v14, v1             \n\t"    \
    "vmadot v20, v14, v2             \n\t"    \
    "vmadot v22, v14, v3             \n\t"    \
    "vmadot v16, v15, v4             \n\t"    \
    "vmadot v18, v15, v5             \n\t"    \
    "vmadot v20, v15, v6             \n\t"    \
    "vmadot v22, v15, v7             \n\t"
1180
// Accumulation step (fp32-scale layout): converts the int32 accumulators
// v16/v18/v20/v22 to float, advances the B-data pointers s2/s3/s4 (16-byte
// steps from s1 — fp32 scales occupy 16 bytes per 4 columns) and s6 = s5+12,
// then multiply-accumulates with the combined scales v24-v27 into the running
// fp32 output tiles v28-v31.
#define SQ4BIT_KERNEL_ACC_1X4X4               \
    "vfcvt.f.x.v v16, v16            \n\t"    \
    "vfcvt.f.x.v v18, v18            \n\t"    \
    "vfcvt.f.x.v v20, v20            \n\t"    \
    "vfcvt.f.x.v v22, v22            \n\t"    \
    "addi s2, s1, 16                 \n\t"    \
    "addi s3, s1, 32                 \n\t"    \
    "addi s4, s1, 48                 \n\t"    \
    "addi s6, s5, 12                 \n\t"    \
    "vfmacc.vv v28, v16, v24         \n\t"    \
    "vfmacc.vv v29, v18, v25         \n\t"    \
    "vfmacc.vv v30, v20, v26         \n\t"    \
    "vfmacc.vv v31, v22, v27         \n\t"
1194
// Accumulation step for the _Float16-scale layout: identical to
// SQ4BIT_KERNEL_ACC_1X4X4 except the pointer steps from s1 are 8 bytes
// (fp16 scales are half the size of fp32 ones).
#define SQ4BIT_KERNEL_ACC_F16_1X4X4           \
    "vfcvt.f.x.v v16, v16            \n\t"    \
    "vfcvt.f.x.v v18, v18            \n\t"    \
    "vfcvt.f.x.v v20, v20            \n\t"    \
    "vfcvt.f.x.v v22, v22            \n\t"    \
    "addi s2, s1, 8                  \n\t"    \
    "addi s3, s1, 16                 \n\t"    \
    "addi s4, s1, 24                 \n\t"    \
    "addi s6, s5, 12                 \n\t"    \
    "vfmacc.vv v28, v16, v24         \n\t"    \
    "vfmacc.vv v29, v18, v25         \n\t"    \
    "vfmacc.vv v30, v20, v26         \n\t"    \
    "vfmacc.vv v31, v22, v27         \n\t"
1208
// Load step for the 1x8x2 tiling: loads four packed-nibble B vectors from
// s1-s4 (advancing each by 128 bytes), loads two quantized-A vectors v14/v15
// from s5/s6 (16-byte steps, e8/mf4), decrements the loop counter t5, then
// unpacks B nibbles: low nibbles (vand 15) into v0-v3, high nibbles (vsrl 4)
// into v4-v7.
#define SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4        \
    "vle8.v v4, (s1)                 \n\t"    \
    "addi s1, s1, 128                \n\t"    \
    "vle8.v v5, (s2)                 \n\t"    \
    "addi s2, s2, 128                \n\t"    \
    "vle8.v v6, (s3)                 \n\t"    \
    "addi s3, s3, 128                \n\t"    \
    "vle8.v v7, (s4)                 \n\t"    \
    "addi s4, s4, 128                \n\t"    \
    "vsetvli t0, zero, e8, mf4       \n\t"    \
    "vle8.v v14, (s5)                \n\t"    \
    "addi s5, s5, 16                 \n\t"    \
    "vle8.v v15, (s6)                \n\t"    \
    "addi s6, s6, 16                 \n\t"    \
    "addi t5, t5, -1                 \n\t"    \
    "vsetvli t0, zero, e8, m1        \n\t"    \
    "vand.vi v0, v4, 15              \n\t"    \
    "vand.vi v1, v5, 15              \n\t"    \
    "vand.vi v2, v6, 15              \n\t"    \
    "vand.vi v3, v7, 15              \n\t"    \
    "vsrl.vi v4, v4, 4               \n\t"    \
    "vsrl.vi v5, v5, 4               \n\t"    \
    "vsrl.vi v6, v6, 4               \n\t"    \
    "vsrl.vi v7, v7, 4               \n\t"
1233
// Zero-point load: reads 16 packed zero-point bytes from s7 into v1, then
// replicates them across v8-v11 with vrgather using the index vector v13.
// v13 is advanced by 4 three times and finally rewound by 12, so it ends
// with its original contents for the next invocation.
#define SQ4BIT_KERNEL_LOAD_ZP_16X1            \
    "vsetvli t0, zero, e8, mf2       \n\t"    \
    "vle8.v v1, (s7)                 \n\t"    \
    "vsetvli t0, zero, e8, m1        \n\t"    \
    "vrgather.vv v8, v1, v13         \n\t"    \
    "vadd.vi v13, v13, 4             \n\t"    \
    "vrgather.vv v9, v1, v13         \n\t"    \
    "vadd.vi v13, v13, 4             \n\t"    \
    "vrgather.vv v10, v1, v13        \n\t"    \
    "vadd.vi v13, v13, 4             \n\t"    \
    "vrgather.vv v11, v1, v13        \n\t"    \
    "vadd.vi v13, v13, -12           \n\t"
1246
// Used by the M4 kernels: loads four packed-nibble B vectors from s1-s4
// (advancing each by 128 bytes = 32*4), then unpacks the nibbles — low
// nibbles (vand 15) into v2-v5, high nibbles (vsrl 4) into v6-v9.
#define LOAD_B_16x8x2                         \
    "vsetvli t0, zero, e8, m1        \n\t"    \
    "vle8.v v6, (s1)                 \n\t"    \
    "addi s1, s1, 32*4               \n\t"    \
    "vle8.v v7, (s2)                 \n\t"    \
    "addi s2, s2, 32*4               \n\t"    \
    "vle8.v v8, (s3)                 \n\t"    \
    "addi s3, s3, 32*4               \n\t"    \
    "vle8.v v9, (s4)                 \n\t"    \
    "addi s4, s4, 32*4               \n\t"    \
                                              \
    "vand.vi v2, v6, 15              \n\t"    \
    "vand.vi v3, v7, 15              \n\t"    \
    "vand.vi v4, v8, 15              \n\t"    \
    "vand.vi v5, v9, 15              \n\t"    \
                                              \
    "vsrl.vi v6, v6, 4               \n\t"    \
    "vsrl.vi v7, v7, 4               \n\t"    \
    "vsrl.vi v8, v8, 4               \n\t"    \
    "vsrl.vi v9, v9, 4               \n\t"
1268
// Combined-scale build, fp16 B scales; pointer layout [s2|s5, s3, s4, s6].
// Loads 16 _Float16 B scales (split across s2/s5/s3/s4/s6 with the 0xf0
// mask selecting the upper half-register), widens them to fp32, duplicates
// each pair, then multiplies by the four per-row A scales f1-f4 so that
// v8-v15 hold the per-row, per-column combined scales.
#define LOAD_SCALE_4x16_FP16                  \
    "addi s2, s5, -8                 \n\t"    \
    "addi s3, s5, 8                  \n\t"    \
    "addi s4, s5, 16                 \n\t"    \
    "addi s6, s5, 24                 \n\t"    \
    "li t1, 0xf0                     \n\t"    \
    "vmv.s.x v0, t1                  \n\t"    \
    "vsetvli t0, zero, e16, mf4      \n\t"    \
    "vle16.v v9, (s5)                \n\t"    \
    "vle16.v v11, (s3)               \n\t"    \
    "vle16.v v13, (s4)               \n\t"    \
    "vle16.v v15, (s6)               \n\t"    \
    "vsetvli t0, zero, e16, mf2      \n\t"    \
    "vle16.v v9, (s2), v0.t          \n\t"    \
    "vle16.v v11, (s5), v0.t         \n\t"    \
    "vle16.v v13, (s3), v0.t         \n\t"    \
    "vle16.v v15, (s4), v0.t         \n\t"    \
    "vfwcvt.f.f.v v8, v9             \n\t"    \
    "vfwcvt.f.f.v v10, v11           \n\t"    \
    "vfwcvt.f.f.v v12, v13           \n\t"    \
    "vfwcvt.f.f.v v14, v15           \n\t"    \
    "vsetvli t0, zero, e32, m1       \n\t"    \
    "vmv.v.v v9, v8                  \n\t"    \
    "vmv.v.v v11, v10                \n\t"    \
    "vmv.v.v v13, v12                \n\t"    \
    "vmv.v.v v15, v14                \n\t"    \
    "li t1, 0xf0                     \n\t"    \
    "vmv.s.x v0, t1                  \n\t"    \
    "vsetvli t0, zero, e32, mf2      \n\t"    \
    "vfmul.vf v8, v8, f1             \n\t"    \
    "vfmul.vf v10, v10, f1           \n\t"    \
    "vfmul.vf v12, v12, f1           \n\t"    \
    "vfmul.vf v14, v14, f1           \n\t"    \
    "vfmul.vf v9, v9, f3             \n\t"    \
    "vfmul.vf v11, v11, f3           \n\t"    \
    "vfmul.vf v13, v13, f3           \n\t"    \
    "vfmul.vf v15, v15, f3           \n\t"    \
    "vsetvli t0, zero, e32, m1       \n\t"    \
    "vfmul.vf v8, v8, f2, v0.t       \n\t"    \
    "vfmul.vf v10, v10, f2, v0.t     \n\t"    \
    "vfmul.vf v12, v12, f2, v0.t     \n\t"    \
    "vfmul.vf v14, v14, f2, v0.t     \n\t"    \
    "vfmul.vf v9, v9, f4, v0.t       \n\t"    \
    "vfmul.vf v11, v11, f4, v0.t     \n\t"    \
    "vfmul.vf v13, v13, f4, v0.t     \n\t"    \
    "vfmul.vf v15, v15, f4, v0.t     \n\t"
1316
// Combined-scale build, fp32 B scales; pointer layout [s2|s5, s3, s4, s6].
// Same structure as LOAD_SCALE_4x16_FP16 but loads fp32 scales directly
// (16-byte offsets instead of 8-byte, no widening conversion): v8-v15 end
// up holding the per-row, per-column combined A*B scales.
#define LOAD_SCALE_4x16                       \
    "addi s2, s5, -16                \n\t"    \
    "addi s3, s5, 16                 \n\t"    \
    "addi s4, s5, 32                 \n\t"    \
    "addi s6, s5, 48                 \n\t"    \
    "li t1, 0xf0                     \n\t"    \
    "vmv.s.x v0, t1                  \n\t"    \
    "vsetvli t0, zero, e32, mf2      \n\t"    \
    "vle32.v v8, (s5)                \n\t"    \
    "vle32.v v10, (s3)               \n\t"    \
    "vle32.v v12, (s4)               \n\t"    \
    "vle32.v v14, (s6)               \n\t"    \
    "vsetvli t0, zero, e32, m1       \n\t"    \
    "vle32.v v8, (s2), v0.t          \n\t"    \
    "vle32.v v10, (s5), v0.t         \n\t"    \
    "vle32.v v12, (s3), v0.t         \n\t"    \
    "vle32.v v14, (s4), v0.t         \n\t"    \
    "vmv.v.v v9, v8                  \n\t"    \
    "vmv.v.v v11, v10                \n\t"    \
    "vmv.v.v v13, v12                \n\t"    \
    "vmv.v.v v15, v14                \n\t"    \
    "vsetvli t0, zero, e32, mf2      \n\t"    \
    "vfmul.vf v8, v8, f1             \n\t"    \
    "vfmul.vf v10, v10, f1           \n\t"    \
    "vfmul.vf v12, v12, f1           \n\t"    \
    "vfmul.vf v14, v14, f1           \n\t"    \
    "vfmul.vf v9, v9, f3             \n\t"    \
    "vfmul.vf v11, v11, f3           \n\t"    \
    "vfmul.vf v13, v13, f3           \n\t"    \
    "vfmul.vf v15, v15, f3           \n\t"    \
    "vsetvli t0, zero, e32, m1       \n\t"    \
    "vfmul.vf v8, v8, f2, v0.t       \n\t"    \
    "vfmul.vf v10, v10, f2, v0.t     \n\t"    \
    "vfmul.vf v12, v12, f2, v0.t     \n\t"    \
    "vfmul.vf v14, v14, f2, v0.t     \n\t"    \
    "vfmul.vf v9, v9, f4, v0.t       \n\t"    \
    "vfmul.vf v11, v11, f4, v0.t     \n\t"    \
    "vfmul.vf v13, v13, f4, v0.t     \n\t"    \
    "vfmul.vf v15, v15, f4, v0.t     \n\t"
1357
// Bias preload; pointer layout [s1|BIAS, s2, s3, s4].
// Loads 16 bias floats into the accumulator seeds v24/v26/v28/v30 (lower
// half via unmasked mf2 loads, upper half via 0xf0-masked m1 loads from the
// -16-shifted pointers), then duplicates each into v25/v27/v29/v31 so all
// four output rows start from the same bias vector.
#define LOAD_BIAS                             \
    "vsetvli t0, zero, e32, mf2      \n\t"    \
    "li t1, 0xf0                     \n\t"    \
    "vmv.s.x v0, t1                  \n\t"    \
    "addi s1, %[BIAS], -16           \n\t"    \
    "addi s2, %[BIAS], 16            \n\t"    \
    "addi s3, %[BIAS], 32            \n\t"    \
    "addi s4, %[BIAS], 48            \n\t"    \
                                              \
    "vle32.v v24, (%[BIAS])          \n\t"    \
    "vle32.v v26, (s2)               \n\t"    \
    "vle32.v v28, (s3)               \n\t"    \
    "vle32.v v30, (s4)               \n\t"    \
    "vsetvli t0, zero, e32, m1       \n\t"    \
    "vle32.v v24, (s1), v0.t         \n\t"    \
    "vle32.v v26, (%[BIAS]), v0.t    \n\t"    \
    "vle32.v v28, (s2), v0.t         \n\t"    \
    "vle32.v v30, (s3), v0.t         \n\t"    \
    "vmv.v.v v25, v24                \n\t"    \
    "vmv.v.v v27, v26                \n\t"    \
    "vmv.v.v v29, v28                \n\t"    \
    "vmv.v.v v31, v30                \n\t"
1381
// 4x16x16 int8 compute step: eight `vmadot` ops (SpacemiT IME custom matrix
// dot-accumulate — TODO confirm exact semantics against the IME ISA doc)
// accumulate the quantized-A vectors v10/v11 against the eight unpacked B
// nibble vectors v2-v9 into the int32 accumulators v16/v18/v20/v22.
#define SQ4BIT_KERNEL_COMP_4x16x16            \
    "vmadot v16, v10, v2             \n\t"    \
    "vmadot v18, v10, v3             \n\t"    \
    "vmadot v20, v10, v4             \n\t"    \
    "vmadot v22, v10, v5             \n\t"    \
    "vmadot v16, v11, v6             \n\t"    \
    "vmadot v18, v11, v7             \n\t"    \
    "vmadot v20, v11, v8             \n\t"    \
    "vmadot v22, v11, v9             \n\t"
1391
// Store the finished 4x16 fp32 tile held in v24-v31 to C.
// Rows 0/2 are addressed by a1/a3 and stored with plain mf2 stores; rows 1/3
// use a2/a4 (pre-shifted by -16 bytes) with the 0xf0-masked m1 stores so the
// upper half-register lands at the right offset. %[LDC] is the row stride
// in bytes.
#define SAVE_RESULT_4x16                      \
    "addi a1, %[C], 0                \n\t"    \
    "add a2, %[C], %[LDC]            \n\t"    \
    "add a3, a2, %[LDC]              \n\t"    \
    "add a4, a3, %[LDC]              \n\t"    \
    "addi a2, a2, -16                \n\t"    \
    "addi a4, a4, -16                \n\t"    \
    "li t1, 0xf0                     \n\t"    \
    "vmv.s.x v0, t1                  \n\t"    \
    "vsetvli t0, zero, e32, mf2      \n\t"    \
                                              \
    "vse32.v v24, (a1)               \n\t"    \
    "addi a1, a1, 16                 \n\t"    \
    "vse32.v v25, (a3)               \n\t"    \
    "addi a3, a3, 16                 \n\t"    \
                                              \
    "vse32.v v26, (a1)               \n\t"    \
    "addi a1, a1, 16                 \n\t"    \
    "vse32.v v27, (a3)               \n\t"    \
    "addi a3, a3, 16                 \n\t"    \
                                              \
    "vse32.v v28, (a1)               \n\t"    \
    "addi a1, a1, 16                 \n\t"    \
    "vse32.v v29, (a3)               \n\t"    \
    "addi a3, a3, 16                 \n\t"    \
                                              \
    "vse32.v v30, (a1)               \n\t"    \
    "vse32.v v31, (a3)               \n\t"    \
    "vsetvli t0, zero, e32, m1       \n\t"    \
                                              \
    "vse32.v v24, (a2), v0.t         \n\t"    \
    "addi a2, a2, 16                 \n\t"    \
    "vse32.v v25, (a4), v0.t         \n\t"    \
    "addi a4, a4, 16                 \n\t"    \
                                              \
    "vse32.v v26, (a2), v0.t         \n\t"    \
    "addi a2, a2, 16                 \n\t"    \
    "vse32.v v27, (a4), v0.t         \n\t"    \
    "addi a4, a4, 16                 \n\t"    \
                                              \
    "vse32.v v28, (a2), v0.t         \n\t"    \
    "addi a2, a2, 16                 \n\t"    \
    "vse32.v v29, (a4), v0.t         \n\t"    \
    "addi a4, a4, 16                 \n\t"    \
                                              \
    "vse32.v v30, (a2), v0.t         \n\t"    \
    "vse32.v v31, (a4), v0.t         \n\t"
1439
// Zero-point load, v2 register mapping: reads 16 packed zero-point bytes
// from s6 into v11 and replicates them across v12-v15 with vrgather using
// the index vector v1 (advanced by 4 three times, then rewound by 12 so v1
// is restored for the next block).
#define SQ4BIT_KERNEL_LOAD_ZP_16X1_v2         \
    "vsetvli t0, zero, e8, mf2       \n\t"    \
    "vle8.v v11, (s6)                \n\t"    \
    "vsetvli t0, zero, e8, m1        \n\t"    \
    "vrgather.vv v12, v11, v1        \n\t"    \
    "vadd.vi v1, v1, 4               \n\t"    \
    "vrgather.vv v13, v11, v1        \n\t"    \
    "vadd.vi v1, v1, 4               \n\t"    \
    "vrgather.vv v14, v11, v1        \n\t"    \
    "vadd.vi v1, v1, 4               \n\t"    \
    "vrgather.vv v15, v11, v1        \n\t"    \
    "vadd.vi v1, v1, -12             \n\t"
1452
// Compute C = A * B for 4 rows of int8-quantized A against 4-bit-quantized B
// whose per-block scales are stored as _Float16, processing 16 columns of B
// per outer iteration.
//
// Packed-B layout per 16-column block (inferred from the pointer arithmetic
// below — TODO confirm against the packing routine): with zero points,
// [fp16 scales: 32 B][zero points: 16 B][nibble data], otherwise
// [fp16 scales: 32 B][nibble data]. QuantA interleaves 4 fp32 row scales
// (16 B) ahead of each block of int8 data (see the `flw f1..f4` loads).
//
// Parameters:
//   BlkLen          - elements per quantization block; INNER = BlkLen / 16
//                     inner compute iterations per block.
//   QuantA          - packed, quantized 4-row activation panel.
//   QuantBData      - packed B (data + scales [+ zero points], see above).
//   QuantBScale /
//   QuantBZeroPoint - unused; scales/zps live inside QuantBData here.
//   C, ldc          - output tile base and row stride (in elements).
//   CountN          - number of B columns to produce.
//   BlockCountK     - number of quantization blocks along K.
//   Bias            - optional per-column bias (may be nullptr).
//
// Column remainders (CountN % 16) are computed into the stack buffer `tmp`
// (with a fixed 16-float row stride) and flushed to C at the end.
template <bool HasZeroPoint>
void SQ4BitGemmM4Kernel_CompInt8_ScaleFp16_Impl(size_t BlkLen,
                                                const std::byte * QuantA,
                                                const std::byte * QuantBData,
                                                const float * QuantBScale,
                                                const std::byte * QuantBZeroPoint,
                                                float * C,
                                                size_t CountN,
                                                size_t BlockCountK,
                                                const float * Bias,
                                                const size_t ldc) {
    GGML_UNUSED(QuantBScale);
    GGML_UNUSED(QuantBZeroPoint);
    size_t LDC = ldc * sizeof(float);     // row stride in bytes for the asm stores
    const size_t INNER = BlkLen / 16;     // inner compute iterations per quant block
    float tmp[4 * 16];                    // staging tile for partial (<16 col) outputs

    if constexpr (HasZeroPoint) {
        for (size_t n = 0; n < CountN; n += 16) {
            size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n;
            std::byte * QuantBDataPtr = (std::byte *) QuantBData +                 //
                                        n * BlockCountK * BlkLen / 2 +             // b data
                                        n * BlockCountK * sizeof(uint8_t) +        // zp
                                        n * BlockCountK * sizeof(_Float16);        // scale
            float * CPtr = C + n;
            if (NBLKS < 16) {
                // Partial tile: render into tmp and flush to C after the loop.
                CPtr = tmp;
                LDC  = 16 * sizeof(float);
            }
            if (Bias != nullptr) {
                const float * bias = Bias + n;
                if (NBLKS < 16) {
                    // Copy the partial bias slice into tmp so LOAD_BIAS can
                    // safely read a full 16 floats.
                    __asm__ volatile(
                        "vsetvli t0, %[N], e32, m2          \n\t"
                        "vle32.v v0, (%[SRC])               \n\t"
                        "vse32.v v0, (%[DST])               \n\t"
                        :
                        : [SRC] "r"(bias), [DST] "r"(tmp), [N] "r"(NBLKS)
                        : "cc", "t0");
                    bias = tmp;
                }
                __asm__ volatile(LOAD_BIAS

                                 "addi t3, %[BlockCountK], 0         \n\t"

                                 // Build the vrgather index vector v1 = 0,1,2,3 per group.
                                 "vsetvli t0, zero, e8, m1           \n\t"
                                 "li s1, 24                          \n\t"
                                 "vmv.v.i v1, 3                      \n\t"
                                 "vsetvli t0, s1, e8, m1             \n\t"
                                 "vmv.v.i v1, 2                      \n\t"
                                 "vsetvli t0, zero, e8, mf2          \n\t"
                                 "vmv.v.i v1, 1                      \n\t"
                                 "vsetvli t0, zero, e8, mf4          \n\t"
                                 "vmv.v.i v1, 0                      \n\t"

                                 "addi a1, %[A], 0                   \n\t"
                                 "addi s1, %[B], 0                   \n\t"

                                 "BLOCK_COUNTK_LOOP%=:               \n\t"
                                 // scale offset
                                 "addi s5, s1, 0                     \n\t"
                                 // zp offset
                                 "addi s6, s1, 32                    \n\t"
                                 "addi s1, s6, 16                    \n\t"
                                 "addi s2, s1, 32                    \n\t"
                                 "addi s3, s1, 32*2                  \n\t"
                                 "addi s4, s1, 32*3                  \n\t"

                                 "vsetvli t0, zero, e32, m8          \n\t"
                                 "vxor.vv v16, v16, v16              \n\t"
                                 // load a scale
                                 "flw f1, (a1)                       \n\t"
                                 "flw f2, 4(a1)                      \n\t"
                                 "flw f3, 8(a1)                      \n\t"
                                 "flw f4, 12(a1)                     \n\t"
                                 "addi a1, a1, 16                    \n\t"
                                 "addi t2, %[INNER], 0               \n\t"

                                 SQ4BIT_KERNEL_LOAD_ZP_16X1_v2

                                 "BLOCK_INNER_LOOP%=:                \n\t"

                                 LOAD_B_16x8x2

                                 "vle8.v v10, (a1)                   \n\t"
                                 "addi a1, a1, 32                    \n\t"
                                 "vle8.v v11, (a1)                   \n\t"
                                 "addi a1, a1, 32                    \n\t"
                                 // Subtract the broadcast zero points from both nibble halves.
                                 "vsub.vv v2, v2, v12                \n\t"
                                 "vsub.vv v6, v6, v12                \n\t"
                                 "vsub.vv v3, v3, v13                \n\t"
                                 "vsub.vv v7, v7, v13                \n\t"
                                 "vsub.vv v4, v4, v14                \n\t"
                                 "vsub.vv v8, v8, v14                \n\t"
                                 "vsub.vv v5, v5, v15                \n\t"
                                 "vsub.vv v9, v9, v15                \n\t"

                                 SQ4BIT_KERNEL_COMP_4x16x16

                                 "addi t2, t2, -1                    \n\t"
                                 "bnez t2, BLOCK_INNER_LOOP%=        \n\t"

                                 LOAD_SCALE_4x16_FP16

                                 "vsetvli t0, zero, e32, m8          \n\t"
                                 "vfcvt.f.x.v v16, v16               \n\t"
                                 "vfmacc.vv v24, v16, v8             \n\t"
                                 "addi t3, t3, -1                    \n\t"
                                 "bnez t3, BLOCK_COUNTK_LOOP%=       \n\t"

                                 "RESULT_SAVE%=:                     \n\t"

                                 SAVE_RESULT_4x16

                                 :
                                 : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
                                   [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr), [BIAS] "r"(bias)
                                 : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1",
                                   "s2", "s3", "s4", "s5", "s6");

            } else {
                // No bias: accumulators v24-v31 start at zero instead of LOAD_BIAS.
                __asm__ volatile(
                    "vsetvli t0, zero, e32, m8          \n\t"
                    "vxor.vv v24, v24, v24              \n\t"
                    "addi t3, %[BlockCountK], 0         \n\t"
                    "vsetvli t0, zero, e8, m1           \n\t"
                    "li s1, 24                          \n\t"
                    "vmv.v.i v1, 3                      \n\t"
                    "vsetvli t0, s1, e8, m1             \n\t"
                    "vmv.v.i v1, 2                      \n\t"
                    "vsetvli t0, zero, e8, mf2          \n\t"
                    "vmv.v.i v1, 1                      \n\t"
                    "vsetvli t0, zero, e8, mf4          \n\t"
                    "vmv.v.i v1, 0                      \n\t"
                    "addi a1, %[A], 0                   \n\t"
                    "addi s1, %[B], 0                   \n\t"
                    "BLOCK_COUNTK_LOOP%=:               \n\t"
                    // scale offset
                    "addi s5, s1, 0                     \n\t"
                    // zp offset
                    "addi s6, s1, 32                    \n\t"
                    "addi s1, s6, 16                    \n\t"
                    "addi s2, s1, 32                    \n\t"
                    "addi s3, s1, 32*2                  \n\t"
                    "addi s4, s1, 32*3                  \n\t"

                    "vsetvli t0, zero, e32, m8          \n\t"
                    "vxor.vv v16, v16, v16              \n\t"
                    // load a scale
                    "flw f1, (a1)                       \n\t"
                    "flw f2, 4(a1)                      \n\t"
                    "flw f3, 8(a1)                      \n\t"
                    "flw f4, 12(a1)                     \n\t"
                    "addi a1, a1, 16                    \n\t"
                    "addi t2, %[INNER], 0               \n\t"

                    SQ4BIT_KERNEL_LOAD_ZP_16X1_v2

                    "BLOCK_INNER_LOOP%=:                \n\t"

                    LOAD_B_16x8x2

                    "vle8.v v10, (a1)                   \n\t"
                    "addi a1, a1, 32                    \n\t"
                    "vle8.v v11, (a1)                   \n\t"
                    "addi a1, a1, 32                    \n\t"
                    // Subtract the broadcast zero points from both nibble halves.
                    "vsub.vv v2, v2, v12                \n\t"
                    "vsub.vv v6, v6, v12                \n\t"
                    "vsub.vv v3, v3, v13                \n\t"
                    "vsub.vv v7, v7, v13                \n\t"
                    "vsub.vv v4, v4, v14                \n\t"
                    "vsub.vv v8, v8, v14                \n\t"
                    "vsub.vv v5, v5, v15                \n\t"
                    "vsub.vv v9, v9, v15                \n\t"

                    SQ4BIT_KERNEL_COMP_4x16x16

                    "addi t2, t2, -1                    \n\t"
                    "bnez t2, BLOCK_INNER_LOOP%=        \n\t"

                    LOAD_SCALE_4x16_FP16

                    "vsetvli t0, zero, e32, m8          \n\t"
                    "vfcvt.f.x.v v16, v16               \n\t"
                    "vfmacc.vv v24, v16, v8             \n\t"
                    "addi t3, t3, -1                    \n\t"
                    "bnez t3, BLOCK_COUNTK_LOOP%=       \n\t"

                    "RESULT_SAVE%=:                     \n\t"

                    SAVE_RESULT_4x16

                    :
                    : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
                      [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr)
                    : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3",
                      "s4", "s5", "s6");
            }
        }
    } else {
        // No zero points: B blocks are [fp16 scales: 32 B][nibble data] and the
        // unpacked nibbles are re-centered with a constant -8 instead of vsub.
        for (size_t n = 0; n < CountN; n += 16) {
            size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n;
            std::byte * QuantBDataPtr = (std::byte *) QuantBData +                 //
                                        n * BlockCountK * BlkLen / 2 +             // b data
                                        n * BlockCountK * sizeof(_Float16);        // scale
            float * CPtr = C + n;
            if (NBLKS < 16) {
                // Partial tile: render into tmp and flush to C after the loop.
                CPtr = tmp;
                LDC  = 16 * sizeof(float);
            }
            if (Bias != nullptr) {
                const float * bias = Bias + n;
                if (NBLKS < 16) {
                    // Copy the partial bias slice into tmp so LOAD_BIAS can
                    // safely read a full 16 floats.
                    __asm__ volatile(
                        "vsetvli t0, %[N], e32, m2          \n\t"
                        "vle32.v v0, (%[SRC])               \n\t"
                        "vse32.v v0, (%[DST])               \n\t"
                        :
                        : [SRC] "r"(bias), [DST] "r"(tmp), [N] "r"(NBLKS)
                        : "cc", "t0");
                    bias = tmp;
                }
                __asm__ volatile(LOAD_BIAS

                                 "addi t3, %[BlockCountK], 0         \n\t"
                                 "addi a1, %[A], 0                   \n\t"
                                 "addi s1, %[B], 0                   \n\t"
                                 "BLOCK_COUNTK_LOOP%=:               \n\t"
                                 "addi s5, s1, 0                     \n\t"
                                 "addi s1, s5, 32                    \n\t"
                                 "addi s2, s1, 32                    \n\t"
                                 "addi s3, s1, 32*2                  \n\t"
                                 "addi s4, s1, 32*3                  \n\t"
                                 "vsetvli t0, zero, e32, m8          \n\t"
                                 "vxor.vv v16, v16, v16              \n\t"
                                 // load a scale
                                 "flw f1, (a1)                       \n\t"
                                 "flw f2, 4(a1)                      \n\t"
                                 "flw f3, 8(a1)                      \n\t"
                                 "flw f4, 12(a1)                     \n\t"
                                 "addi a1, a1, 16                    \n\t"
                                 "addi t2, %[INNER], 0               \n\t"
                                 "BLOCK_INNER_LOOP%=:                \n\t"

                                 LOAD_B_16x8x2

                                 "vsetvli t0, zero, e8, m1           \n\t"
                                 "vle8.v v10, (a1)                   \n\t"
                                 "addi a1, a1, 32                    \n\t"
                                 "vle8.v v11, (a1)                   \n\t"
                                 "addi a1, a1, 32                    \n\t"
                                 // Re-center unsigned nibbles to signed [-8, 7].
                                 "vadd.vi v2, v2, -8                 \n\t"
                                 "vadd.vi v3, v3, -8                 \n\t"
                                 "vadd.vi v4, v4, -8                 \n\t"
                                 "vadd.vi v5, v5, -8                 \n\t"
                                 "vadd.vi v6, v6, -8                 \n\t"
                                 "vadd.vi v7, v7, -8                 \n\t"
                                 "vadd.vi v8, v8, -8                 \n\t"
                                 "vadd.vi v9, v9, -8                 \n\t"

                                 SQ4BIT_KERNEL_COMP_4x16x16

                                 "addi t2, t2, -1                    \n\t"
                                 "bnez t2, BLOCK_INNER_LOOP%=        \n\t"

                                 LOAD_SCALE_4x16_FP16

                                 "vsetvli t0, zero, e32, m8          \n\t"
                                 "vfcvt.f.x.v v16, v16               \n\t"
                                 "vfmacc.vv v24, v16, v8             \n\t"
                                 "addi t3, t3, -1                    \n\t"
                                 "bnez t3, BLOCK_COUNTK_LOOP%=       \n\t"
                                 "RESULT_SAVE%=:                     \n\t"

                                 SAVE_RESULT_4x16

                                 :
                                 : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
                                   [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr), [BIAS] "r"(bias)
                                 : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1",
                                   "s2", "s3", "s4", "s5", "s6");

            } else {
                // No bias: accumulators v24-v31 start at zero instead of LOAD_BIAS.
                __asm__ volatile(
                    "vsetvli t0, zero, e32, m8          \n\t"
                    "vxor.vv v24, v24, v24              \n\t"
                    "addi t3, %[BlockCountK], 0         \n\t"
                    "addi a1, %[A], 0                   \n\t"
                    "addi s1, %[B], 0                   \n\t"
                    "BLOCK_COUNTK_LOOP%=:               \n\t"
                    "addi s5, s1, 0                     \n\t"
                    "addi s1, s5, 32                    \n\t"
                    "addi s2, s1, 32                    \n\t"
                    "addi s3, s1, 32*2                  \n\t"
                    "addi s4, s1, 32*3                  \n\t"
                    "vsetvli t0, zero, e32, m8          \n\t"
                    "vxor.vv v16, v16, v16              \n\t"
                    // load a scale
                    "flw f1, (a1)                       \n\t"
                    "flw f2, 4(a1)                      \n\t"
                    "flw f3, 8(a1)                      \n\t"
                    "flw f4, 12(a1)                     \n\t"
                    "addi a1, a1, 16                    \n\t"
                    "addi t2, %[INNER], 0               \n\t"
                    "BLOCK_INNER_LOOP%=:                \n\t"

                    LOAD_B_16x8x2

                    "vsetvli t0, zero, e8, m1           \n\t"
                    "vle8.v v10, (a1)                   \n\t"
                    "addi a1, a1, 32                    \n\t"
                    "vle8.v v11, (a1)                   \n\t"
                    "addi a1, a1, 32                    \n\t"
                    // Re-center unsigned nibbles to signed [-8, 7].
                    "vadd.vi v2, v2, -8                 \n\t"
                    "vadd.vi v3, v3, -8                 \n\t"
                    "vadd.vi v4, v4, -8                 \n\t"
                    "vadd.vi v5, v5, -8                 \n\t"
                    "vadd.vi v6, v6, -8                 \n\t"
                    "vadd.vi v7, v7, -8                 \n\t"
                    "vadd.vi v8, v8, -8                 \n\t"
                    "vadd.vi v9, v9, -8                 \n\t"

                    SQ4BIT_KERNEL_COMP_4x16x16

                    "addi t2, t2, -1                    \n\t"
                    "bnez t2, BLOCK_INNER_LOOP%=        \n\t"

                    LOAD_SCALE_4x16_FP16

                    "vsetvli t0, zero, e32, m8          \n\t"
                    "vfcvt.f.x.v v16, v16               \n\t"
                    "vfmacc.vv v24, v16, v8             \n\t"
                    "addi t3, t3, -1                    \n\t"
                    "bnez t3, BLOCK_COUNTK_LOOP%=       \n\t"
                    "RESULT_SAVE%=:                     \n\t"

                    SAVE_RESULT_4x16

                    :
                    : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
                      [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr)
                    : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3",
                      "s4", "s5", "s6");
            }
        }
    }
    if (CountN % 16 != 0) {
        // store output from tmp to C when NBLKS less than 16.
        // tmp rows are 16 floats (64 bytes) apart; only the first N columns
        // of each of the 4 rows are written out with the real C stride.
        float * CPtr = C + CountN / 16 * 16;
        const size_t N = CountN % 16;
        LDC = ldc * sizeof(float);
        __asm__ volatile(
            "vsetvli t0, %[N], e32, m2          \n\t"
            "vle32.v v0, (%[SRC])               \n\t"
            "addi s2, %[SRC], 64                \n\t"
            "addi s3, %[SRC], 64*2              \n\t"
            "addi s4, %[SRC], 64*3              \n\t"
            "vle32.v v2, (s2)                   \n\t"
            "vle32.v v4, (s3)                   \n\t"
            "vle32.v v6, (s4)                   \n\t"
            "add t2, %[DST], %[LDC]             \n\t"
            "add t3, t2, %[LDC]                 \n\t"
            "add t4, t3, %[LDC]                 \n\t"
            "vse32.v v0, (%[DST])               \n\t"
            "vse32.v v2, (t2)                   \n\t"
            "vse32.v v4, (t3)                   \n\t"
            "vse32.v v6, (t4)                   \n\t"
            :
            : [N] "r"(N), [SRC] "r"(tmp), [DST] "r"(CPtr), [LDC] "r"(LDC)
            : "cc", "t0", "t2", "t3", "t4", "s2", "s3", "s4");
    }
}
1825
1826template <bool HasZeroPoint>
1827void SQ4BitGemmM4Kernel_CompInt8_Impl(size_t BlkLen,
1828 const std::byte * QuantA,
1829 const std::byte * QuantBData,
1830 const float * QuantBScale,
1831 const std::byte * QuantBZeroPoint,
1832 float * C,
1833 size_t CountN,
1834 size_t BlockCountK,
1835 const float * Bias,
1836 const size_t ldc) {
1837 GGML_UNUSED(QuantBScale);
1838 GGML_UNUSED(QuantBZeroPoint);
1839 size_t LDC = ldc * sizeof(float);
1840 const size_t INNER = BlkLen / 16;
1841 float tmp[4 * 16];
1842
1843 if constexpr (HasZeroPoint) {
1844 for (size_t n = 0; n < CountN; n += 16) {
1845 size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n;
1846 std::byte * QuantBDataPtr = (std::byte *) QuantBData + //
1847 n * BlockCountK * BlkLen / 2 + // b data
1848 n * BlockCountK * sizeof(uint8_t) + // zp
1849 n * BlockCountK * sizeof(float); // scale
1850 float * CPtr = C + n;
1851 if (NBLKS < 16) {
1852 CPtr = tmp;
1853 LDC = 16 * sizeof(float);
1854 }
1855 if (Bias != nullptr) {
1856 const float * bias = Bias + n;
1857 if (NBLKS < 16) {
1858 __asm__ volatile(
1859 "vsetvli t0, %[N], e32, m2 \n\t"
1860 "vle32.v v0, (%[SRC]) \n\t"
1861 "vse32.v v0, (%[DST]) \n\t"
1862 :
1863 : [SRC] "r"(bias), [DST] "r"(tmp), [N] "r"(NBLKS)
1864 : "cc", "t0");
1865 bias = tmp;
1866 }
1867
1868 __asm__ volatile(LOAD_BIAS
1869 "addi t3, %[BlockCountK], 0 \n\t"
1870 "vsetvli t0, zero, e8, m1 \n\t"
1871 "li s1, 24 \n\t"
1872 "vmv.v.i v1, 3 \n\t"
1873 "vsetvli t0, s1, e8, m1 \n\t"
1874 "vmv.v.i v1, 2 \n\t"
1875 "vsetvli t0, zero, e8, mf2 \n\t"
1876 "vmv.v.i v1, 1 \n\t"
1877 "vsetvli t0, zero, e8, mf4 \n\t"
1878 "vmv.v.i v1, 0 \n\t"
1879 "addi a1, %[A], 0 \n\t"
1880 "addi s1, %[B], 0 \n\t"
1881 "BLOCK_COUNTK_LOOP%=: \n\t"
1882 // scale offset
1883 "addi s5, s1, 0 \n\t"
1884 // zp offset
1885 "addi s6, s1, 64 \n\t"
1886 "addi s1, s6, 16 \n\t"
1887 "addi s2, s1, 32 \n\t"
1888 "addi s3, s1, 32*2 \n\t"
1889 "addi s4, s1, 32*3 \n\t"
1890 "vsetvli t0, zero, e32, m8 \n\t"
1891 "vxor.vv v16, v16, v16 \n\t"
1892 // load a scale
1893 "flw f1, (a1) \n\t"
1894 "flw f2, 4(a1) \n\t"
1895 "flw f3, 8(a1) \n\t"
1896 "flw f4, 12(a1) \n\t"
1897 "addi a1, a1, 16 \n\t"
1898 "addi t2, %[INNER], 0 \n\t"
1899
1900 SQ4BIT_KERNEL_LOAD_ZP_16X1_v2
1901
1902 "BLOCK_INNER_LOOP%=: \n\t"
1903
1904 LOAD_B_16x8x2
1905
1906 "vle8.v v10, (a1) \n\t"
1907 "addi a1, a1, 32 \n\t"
1908 "vle8.v v11, (a1) \n\t"
1909 "addi a1, a1, 32 \n\t"
1910 "vsub.vv v2, v2, v12 \n\t"
1911 "vsub.vv v6, v6, v12 \n\t"
1912 "vsub.vv v3, v3, v13 \n\t"
1913 "vsub.vv v7, v7, v13 \n\t"
1914 "vsub.vv v4, v4, v14 \n\t"
1915 "vsub.vv v8, v8, v14 \n\t"
1916 "vsub.vv v5, v5, v15 \n\t"
1917 "vsub.vv v9, v9, v15 \n\t"
1918
1919 SQ4BIT_KERNEL_COMP_4x16x16
1920
1921 "addi t2, t2, -1 \n\t"
1922 "bnez t2, BLOCK_INNER_LOOP%= \n\t"
1923
1924 LOAD_SCALE_4x16
1925
1926 "vsetvli t0, zero, e32, m8 \n\t"
1927 "vfcvt.f.x.v v16, v16 \n\t"
1928 "vfmacc.vv v24, v16, v8 \n\t"
1929 "addi t3, t3, -1 \n\t"
1930 "bnez t3, BLOCK_COUNTK_LOOP%= \n\t"
1931
1932 "RESULT_SAVE%=: \n\t"
1933
1934 SAVE_RESULT_4x16
1935
1936 :
1937 : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
1938 [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr), [BIAS] "r"(bias)
1939 : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1",
1940 "s2", "s3", "s4", "s5", "s6");
1941
1942 } else {
1943 __asm__ volatile(
1944 "vsetvli t0, zero, e32, m8 \n\t"
1945 "vxor.vv v24, v24, v24 \n\t"
1946 "addi t3, %[BlockCountK], 0 \n\t"
1947 "vsetvli t0, zero, e8, m1 \n\t"
1948 "li s1, 24 \n\t"
1949 "vmv.v.i v1, 3 \n\t"
1950 "vsetvli t0, s1, e8, m1 \n\t"
1951 "vmv.v.i v1, 2 \n\t"
1952 "vsetvli t0, zero, e8, mf2 \n\t"
1953 "vmv.v.i v1, 1 \n\t"
1954 "vsetvli t0, zero, e8, mf4 \n\t"
1955 "vmv.v.i v1, 0 \n\t"
1956 "addi a1, %[A], 0 \n\t"
1957 "addi s1, %[B], 0 \n\t"
1958 "BLOCK_COUNTK_LOOP%=: \n\t"
1959 // scale offset
1960 "addi s5, s1, 0 \n\t"
1961 // zp offset
1962 "addi s6, s1, 64 \n\t"
1963 "addi s1, s6, 16 \n\t"
1964 "addi s2, s1, 32 \n\t"
1965 "addi s3, s1, 32*2 \n\t"
1966 "addi s4, s1, 32*3 \n\t"
1967 "vsetvli t0, zero, e32, m8 \n\t"
1968 "vxor.vv v16, v16, v16 \n\t"
1969 // load a scale
1970 // load a scale
1971 "flw f1, (a1) \n\t"
1972 "flw f2, 4(a1) \n\t"
1973 "flw f3, 8(a1) \n\t"
1974 "flw f4, 12(a1) \n\t"
1975 "addi a1, a1, 16 \n\t"
1976 "addi t2, %[INNER], 0 \n\t"
1977
1978 SQ4BIT_KERNEL_LOAD_ZP_16X1_v2
1979
1980 "BLOCK_INNER_LOOP%=: \n\t"
1981
1982 LOAD_B_16x8x2
1983
1984 "vle8.v v10, (a1) \n\t"
1985 "addi a1, a1, 32 \n\t"
1986 "vle8.v v11, (a1) \n\t"
1987 "addi a1, a1, 32 \n\t"
1988 "vsub.vv v2, v2, v12 \n\t"
1989 "vsub.vv v6, v6, v12 \n\t"
1990 "vsub.vv v3, v3, v13 \n\t"
1991 "vsub.vv v7, v7, v13 \n\t"
1992 "vsub.vv v4, v4, v14 \n\t"
1993 "vsub.vv v8, v8, v14 \n\t"
1994 "vsub.vv v5, v5, v15 \n\t"
1995 "vsub.vv v9, v9, v15 \n\t"
1996
1997 SQ4BIT_KERNEL_COMP_4x16x16
1998
1999 "addi t2, t2, -1 \n\t"
2000 "bnez t2, BLOCK_INNER_LOOP%= \n\t"
2001
2002 LOAD_SCALE_4x16
2003
2004 "vsetvli t0, zero, e32, m8 \n\t"
2005 "vfcvt.f.x.v v16, v16 \n\t"
2006 "vfmacc.vv v24, v16, v8 \n\t"
2007 "addi t3, t3, -1 \n\t"
2008 "bnez t3, BLOCK_COUNTK_LOOP%= \n\t"
2009
2010 "RESULT_SAVE%=: \n\t"
2011
2012 SAVE_RESULT_4x16
2013
2014 :
2015 : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
2016 [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr)
2017 : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3",
2018 "s4", "s5", "s6");
2019 }
2020 }
2021 } else {
2022 for (size_t n = 0; n < CountN; n += 16) {
2023 size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n;
2024 std::byte * QuantBDataPtr = (std::byte *) QuantBData + //
2025 n * BlockCountK * BlkLen / 2 + // b data
2026 n * BlockCountK * sizeof(float); // scale
2027 float * CPtr = C + n;
2028 if (NBLKS < 16) {
2029 CPtr = tmp;
2030 LDC = 16 * sizeof(float);
2031 }
2032 if (Bias != nullptr) {
2033 const float * bias = Bias + n;
2034 if (NBLKS < 16) {
2035 __asm__ volatile(
2036 "vsetvli t0, %[N], e32, m2 \n\t"
2037 "vle32.v v0, (%[SRC]) \n\t"
2038 "vse32.v v0, (%[DST]) \n\t"
2039 :
2040 : [SRC] "r"(bias), [DST] "r"(tmp), [N] "r"(NBLKS)
2041 : "cc", "t0");
2042 bias = tmp;
2043 }
2044 __asm__ volatile(LOAD_BIAS
2045 "addi t3, %[BlockCountK], 0 \n\t"
2046 "addi a1, %[A], 0 \n\t"
2047 "addi s1, %[B], 0 \n\t"
2048 "BLOCK_COUNTK_LOOP%=: \n\t"
2049 "addi s5, s1, 0 \n\t"
2050 "addi s1, s5, 64 \n\t"
2051 "addi s2, s1, 32 \n\t"
2052 "addi s3, s1, 32*2 \n\t"
2053 "addi s4, s1, 32*3 \n\t"
2054 "vsetvli t0, zero, e32, m8 \n\t"
2055 "vxor.vv v16, v16, v16 \n\t"
2056 // load a scale
2057 "flw f1, (a1) \n\t"
2058 "flw f2, 4(a1) \n\t"
2059 "flw f3, 8(a1) \n\t"
2060 "flw f4, 12(a1) \n\t"
2061 "addi a1, a1, 16 \n\t"
2062 "addi t2, %[INNER], 0 \n\t"
2063 "BLOCK_INNER_LOOP%=: \n\t"
2064
2065 LOAD_B_16x8x2
2066
2067 "vsetvli t0, zero, e8, m1 \n\t"
2068 "vle8.v v10, (a1) \n\t"
2069 "addi a1, a1, 32 \n\t"
2070 "vle8.v v11, (a1) \n\t"
2071 "addi a1, a1, 32 \n\t"
2072 "vadd.vi v2, v2, -8 \n\t"
2073 "vadd.vi v3, v3, -8 \n\t"
2074 "vadd.vi v4, v4, -8 \n\t"
2075 "vadd.vi v5, v5, -8 \n\t"
2076 "vadd.vi v6, v6, -8 \n\t"
2077 "vadd.vi v7, v7, -8 \n\t"
2078 "vadd.vi v8, v8, -8 \n\t"
2079 "vadd.vi v9, v9, -8 \n\t"
2080
2081 SQ4BIT_KERNEL_COMP_4x16x16
2082
2083 "addi t2, t2, -1 \n\t"
2084 "bnez t2, BLOCK_INNER_LOOP%= \n\t"
2085
2086 LOAD_SCALE_4x16
2087
2088 "vsetvli t0, zero, e32, m8 \n\t"
2089 "vfcvt.f.x.v v16, v16 \n\t"
2090 "vfmacc.vv v24, v16, v8 \n\t"
2091 "addi t3, t3, -1 \n\t"
2092 "bnez t3, BLOCK_COUNTK_LOOP%= \n\t"
2093
2094 "RESULT_SAVE%=: \n\t"
2095
2096 SAVE_RESULT_4x16
2097
2098 :
2099 : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
2100 [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr), [BIAS] "r"(bias)
2101 : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1",
2102 "s2", "s3", "s4", "s5", "s6");
2103
2104 } else {
2105 __asm__ volatile(
2106 "vsetvli t0, zero, e32, m8 \n\t"
2107 "vxor.vv v24, v24, v24 \n\t"
2108 "addi t3, %[BlockCountK], 0 \n\t"
2109 "addi a1, %[A], 0 \n\t"
2110 "addi s1, %[B], 0 \n\t"
2111 "BLOCK_COUNTK_LOOP%=: \n\t"
2112 "addi s5, s1, 0 \n\t"
2113 "addi s1, s5, 64 \n\t"
2114 "addi s2, s1, 32 \n\t"
2115 "addi s3, s1, 32*2 \n\t"
2116 "addi s4, s1, 32*3 \n\t"
2117 "vsetvli t0, zero, e32, m8 \n\t"
2118 "vxor.vv v16, v16, v16 \n\t"
2119 // load a scale
2120 "flw f1, (a1) \n\t"
2121 "flw f2, 4(a1) \n\t"
2122 "flw f3, 8(a1) \n\t"
2123 "flw f4, 12(a1) \n\t"
2124 "addi a1, a1, 16 \n\t"
2125 "addi t2, %[INNER], 0 \n\t"
2126 "BLOCK_INNER_LOOP%=: \n\t"
2127
2128 LOAD_B_16x8x2
2129
2130 "vsetvli t0, zero, e8, m1 \n\t"
2131 "vle8.v v10, (a1) \n\t"
2132
2133 "addi a1, a1, 32 \n\t"
2134 "vle8.v v11, (a1) \n\t"
2135 "addi a1, a1, 32 \n\t"
2136 "vadd.vi v2, v2, -8 \n\t"
2137 "vadd.vi v3, v3, -8 \n\t"
2138 "vadd.vi v4, v4, -8 \n\t"
2139 "vadd.vi v5, v5, -8 \n\t"
2140 "vadd.vi v6, v6, -8 \n\t"
2141 "vadd.vi v7, v7, -8 \n\t"
2142 "vadd.vi v8, v8, -8 \n\t"
2143 "vadd.vi v9, v9, -8 \n\t"
2144
2145 SQ4BIT_KERNEL_COMP_4x16x16
2146
2147 "addi t2, t2, -1 \n\t"
2148 "bnez t2, BLOCK_INNER_LOOP%= \n\t"
2149
2150 LOAD_SCALE_4x16
2151
2152 "vsetvli t0, zero, e32, m8 \n\t"
2153 "vfcvt.f.x.v v16, v16 \n\t"
2154 "vfmacc.vv v24, v16, v8 \n\t"
2155 "addi t3, t3, -1 \n\t"
2156 "bnez t3, BLOCK_COUNTK_LOOP%= \n\t"
2157
2158 "RESULT_SAVE%=: \n\t"
2159
2160 SAVE_RESULT_4x16
2161
2162 :
2163 : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
2164 [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr)
2165 : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3",
2166 "s4", "s5", "s6");
2167 }
2168 }
2169 }
2170 if (CountN % 16 != 0) {
        // store output from tmp to C when NBLKS is less than 16.
2172 float * CPtr = C + CountN / 16 * 16;
2173 const size_t N = CountN % 16;
2174 LDC = ldc * sizeof(float);
2175 __asm__ volatile(
2176 "vsetvli t0, %[N], e32, m2 \n\t"
2177 "vle32.v v0, (%[SRC]) \n\t"
2178 "addi s2, %[SRC], 64 \n\t"
2179 "addi s3, %[SRC], 64*2 \n\t"
2180 "addi s4, %[SRC], 64*3 \n\t"
2181 "vle32.v v2, (s2) \n\t"
2182 "vle32.v v4, (s3) \n\t"
2183 "vle32.v v6, (s4) \n\t"
2184 "add t2, %[DST], %[LDC] \n\t"
2185 "add t3, t2, %[LDC] \n\t"
2186 "add t4, t3, %[LDC] \n\t"
2187 "vse32.v v0, (%[DST]) \n\t"
2188 "vse32.v v2, (t2) \n\t"
2189 "vse32.v v4, (t3) \n\t"
2190 "vse32.v v6, (t4) \n\t"
2191 :
2192 : [N] "r"(N), [SRC] "r"(tmp), [DST] "r"(CPtr), [LDC] "r"(LDC)
2193 : "cc", "t0", "t2", "t3", "t4", "s2", "s3", "s4");
2194 }
2195}
2196
// Single-row (M = 1) x 16-column micro-kernel: int8-quantized A times 4-bit
// quantized B, for the packed-B layout whose per-block scales are stored as
// _Float16 (scale stride == 2 bytes).  Results (plus optional bias) are
// written to C, with a masked tail store for a final partial 16-column strip.
//
// RVV inline assembly; the SQ4BIT_KERNEL_* macros (defined earlier in this
// file) do the 4-bit unpack, int8 dot-product accumulation and the
// fp16-scale accumulate step.
//
// Parameters:
//   BlkLen          elements per quantization block of B; INNER = BlkLen/16
//                   inner iterations are run per block.
//   QuantA          quantized A row; the s5 (+4 per block) / s6 (+12)
//                   offsets imply a per-block fp32 scale followed by int8
//                   data — TODO confirm against the A packing routine.
//   QuantBData      packed B strip: 4-bit data with per-block zero points
//                   (when HasZeroPoint) and _Float16 scales interleaved per
//                   16-column strip (see QuantBDataPtr stride computation).
//   QuantBScale / QuantBZeroPoint  unused here: scales and zero points are
//                   packed inside QuantBData.
//   C               output row, CountN floats.
//   CountN          number of columns, processed in strips of 16.
//   BlockCountK     number of quantization blocks along K.
//   Bias            optional per-column bias (may be nullptr).
template <bool HasZeroPoint>
void SQ4BitGemmM1Kernel_CompInt8_ScaleFp16_Impl(size_t BlkLen,
                                                const std::byte * QuantA,
                                                const std::byte * QuantBData,
                                                const float * QuantBScale,
                                                const std::byte * QuantBZeroPoint,
                                                float * C,
                                                size_t CountN,
                                                size_t BlockCountK,
                                                const float * Bias) {
    GGML_UNUSED(QuantBScale);
    GGML_UNUSED(QuantBZeroPoint);
    size_t INNER = BlkLen / 16;

    if constexpr (HasZeroPoint) {
        for (size_t n = 0; n < CountN; n += 16) {
            // Columns handled by this strip (16, or the tail remainder).
            size_t nblks = (CountN - n) > 16 ? 16 : CountN - n;
            // Per-strip base pointer: data + zero points + fp16 scales of the
            // preceding n columns.
            std::byte * QuantBDataPtr = (std::byte *) QuantBData +            //
                                        n * BlockCountK * BlkLen / 2 +        // b data
                                        n * BlockCountK * sizeof(uint8_t) +   // zp
                                        n * BlockCountK * sizeof(_Float16);   // scale
            float * CPtr = C + n;
            size_t cnt = BlockCountK;
            if (Bias != nullptr) {
                // Accumulators v28..v31 are pre-loaded with the bias values
                // (masked by nblks) instead of being zeroed.
                const float * bias = Bias + n;
                __asm__ volatile(
                    "addi t3, %[NBLKS], 0 \n\t"
                    // v13 holds the per-lane shift-group index (0..3) used by
                    // the zero-point load macro below.
                    "vsetvli t0, zero, e8, m1 \n\t"

                    "vmv.v.i v13, 3 \n\t"
                    "li s1, 24 \n\t"
                    "vsetvli t0, s1, e8, m1 \n\t"
                    "vmv.v.i v13, 2 \n\t"
                    "vsetvli t0, zero, e8, mf2 \n\t"
                    "vmv.v.i v13, 1 \n\t"
                    "vsetvli t0, zero, e8, mf4 \n\t"
                    "vmv.v.i v13, 0 \n\t"
                    // fp16 scale pointers, 4 groups of 4 columns.
                    "addi s1, %[B], 0 \n\t"
                    "addi s2, %[B], 8 \n\t"
                    "addi s3, %[B], 16 \n\t"
                    "addi s4, %[B], 24 \n\t"
                    // zp offset
                    "addi s7, %[B], 32 \n\t"
                    // a offset
                    "addi s5, %[A], 0 \n\t"
                    "addi s6, %[A], 12 \n\t"

                    // Masked bias load into v28..v31 (t3 counts remaining
                    // columns, 4 lanes per register).
                    "vsetvli t0, t3, e32, mf2 \n\t"
                    "vle32.v v28, (%[BIAS]) \n\t"
                    "sub t3, t3, t0 \n\t"
                    "addi %[BIAS], %[BIAS], 16 \n\t"
                    "vsetvli t0, t3, e32, mf2 \n\t"
                    "vle32.v v29, (%[BIAS]) \n\t"
                    "sub t3, t3, t0 \n\t"
                    "addi %[BIAS], %[BIAS], 16 \n\t"
                    "vsetvli t0, t3, e32, mf2 \n\t"
                    "vle32.v v30, (%[BIAS]) \n\t"
                    "sub t3, t3, t0 \n\t"
                    "addi %[BIAS], %[BIAS], 16 \n\t"
                    "vsetvli t0, t3, e32, mf2 \n\t"
                    "vle32.v v31, (%[BIAS]) \n\t"

                    "LOOP_K%=: \n\t"
                    // Load 4x4 fp16 B scales and widen to fp32 (v8..v11).
                    "vsetvli t0, zero, e16, mf4 \n\t"

                    "vle16.v v4, (s1) \n\t"
                    "addi s1, s1, 48 \n\t"
                    "vle16.v v5, (s2) \n\t"
                    "addi s2, s2, 72 \n\t"
                    "vle16.v v6, (s3) \n\t"
                    "addi s3, s3, 96 \n\t"
                    "vle16.v v7, (s4) \n\t"
                    "addi s4, s4, 120 \n\t"
                    "flw f1, (s5) \n\t"
                    "addi s5, s5, 4 \n\t"
                    "vfwcvt.f.f.v v8, v4 \n\t"
                    "vfwcvt.f.f.v v9, v5 \n\t"
                    "vfwcvt.f.f.v v10, v6 \n\t"
                    "vfwcvt.f.f.v v11, v7 \n\t"

                    "vsetvli t0, zero, e32, mf2 \n\t"
                    "addi t5, %[INNER], 0 \n\t"
                    // Clear the int32 partial sums for this block.
                    "vxor.vv v16, v16, v16 \n\t"
                    "vxor.vv v18, v18, v18 \n\t"
                    "vxor.vv v20, v20, v20 \n\t"
                    "vxor.vv v22, v22, v22 \n\t"
                    // Combined scale = a_scale (f1) * b_scale.
                    "vfmul.vf v24, v8, f1 \n\t"
                    "vfmul.vf v25, v9, f1 \n\t"
                    "vfmul.vf v26, v10, f1 \n\t"
                    "vfmul.vf v27, v11, f1 \n\t"
                    "addi %[CNT], %[CNT], -1 \n\t"

                    SQ4BIT_KERNEL_LOAD_ZP_16X1

                    "LOOP_INNER%=: \n\t"

                    SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4

                    // Subtract the per-column zero points (v8..v11).
                    "vsub.vv v0, v0, v8 \n\t"
                    "vsub.vv v4, v4, v8 \n\t"
                    "vsub.vv v1, v1, v9 \n\t"
                    "vsub.vv v5, v5, v9 \n\t"
                    "vsub.vv v2, v2, v10 \n\t"
                    "vsub.vv v6, v6, v10 \n\t"
                    "vsub.vv v3, v3, v11 \n\t"
                    "vsub.vv v7, v7, v11 \n\t"

                    SQ4BIT_KERNEL_COMP_1x8x2_4X8X4

                    "bnez t5, LOOP_INNER%= \n\t"
                    "vsetvli t0, zero, e32, mf2 \n\t"

                    SQ4BIT_KERNEL_ACC_F16_1X4X4
                    // Next block's zero points live 32 bytes past the scales.
                    "addi s7, s1, 32 \n\t"

                    "bnez %[CNT], LOOP_K%= \n\t"
                    // Full strip: plain stores; partial strip: masked stores.
                    "addi t3, zero, 16 \n\t"
                    "addi s1, %[C], 16 \n\t"
                    "addi s2, %[C], 32 \n\t"
                    "addi s3, %[C], 48 \n\t"
                    "blt %[NBLKS], t3, ST_TAIL%= \n\t"
                    "vse32.v v28, (%[C]) \n\t"
                    "vse32.v v29, (s1) \n\t"
                    "vse32.v v30, (s2) \n\t"
                    "vse32.v v31, (s3) \n\t"
                    "jal x0, END%= \n\t"

                    "ST_TAIL%=: \n\t"
                    "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
                    "sub %[NBLKS], %[NBLKS], t0 \n\t"
                    "vse32.v v28, (%[C]) \n\t"
                    "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
                    "sub %[NBLKS], %[NBLKS], t0 \n\t"
                    "vse32.v v29, (s1) \n\t"
                    "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
                    "sub %[NBLKS], %[NBLKS], t0 \n\t"
                    "vse32.v v30, (s2) \n\t"
                    "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
                    "sub %[NBLKS], %[NBLKS], t0 \n\t"
                    "vse32.v v31, (s3) \n\t"
                    "END%=: \n\t"

                    : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks), [BIAS] "+r"(bias)
                    : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
                    : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7");
            } else {
                // No bias: zero v28..v31 up front, otherwise identical to the
                // bias variant above.
                __asm__ volatile(
                    "vsetvli t0, zero, e32, m4 \n\t"
                    "vxor.vv v28, v28, v28 \n\t"

                    "vsetvli t0, zero, e8, m1 \n\t"
                    "vmv.v.i v13, 3 \n\t"
                    "li s1, 24 \n\t"
                    "vsetvli t0, s1, e8, m1 \n\t"
                    "vmv.v.i v13, 2 \n\t"
                    "vsetvli t0, zero, e8, mf2 \n\t"
                    "vmv.v.i v13, 1 \n\t"
                    "vsetvli t0, zero, e8, mf4 \n\t"
                    "vmv.v.i v13, 0 \n\t"

                    "addi s1, %[B], 0 \n\t"
                    "addi s2, %[B], 8 \n\t"
                    "addi s3, %[B], 16 \n\t"
                    "addi s4, %[B], 24 \n\t"

                    "addi s7, %[B], 32 \n\t"

                    "addi s5, %[A], 0 \n\t"
                    "addi s6, %[A], 12 \n\t"
                    "LOOP_K%=: \n\t"
                    "vsetvli t0, zero, e16, mf4 \n\t"
                    "vle16.v v4, (s1) \n\t"
                    "addi s1, s1, 48 \n\t"
                    "vle16.v v5, (s2) \n\t"
                    "addi s2, s2, 72 \n\t"
                    "vle16.v v6, (s3) \n\t"
                    "addi s3, s3, 96 \n\t"
                    "vle16.v v7, (s4) \n\t"
                    "addi s4, s4, 120 \n\t"
                    "flw f1, (s5) \n\t"
                    "addi s5, s5, 4 \n\t"

                    "vfwcvt.f.f.v v8, v4 \n\t"
                    "vfwcvt.f.f.v v9, v5 \n\t"
                    "vfwcvt.f.f.v v10, v6 \n\t"
                    "vfwcvt.f.f.v v11, v7 \n\t"
                    "vsetvli t0, zero, e32, mf2 \n\t"

                    "addi t5, %[INNER], 0 \n\t"
                    "vxor.vv v16, v16, v16 \n\t"
                    "vxor.vv v18, v18, v18 \n\t"
                    "vxor.vv v20, v20, v20 \n\t"
                    "vxor.vv v22, v22, v22 \n\t"
                    "vfmul.vf v24, v8, f1 \n\t"
                    "vfmul.vf v25, v9, f1 \n\t"
                    "vfmul.vf v26, v10, f1 \n\t"
                    "vfmul.vf v27, v11, f1 \n\t"
                    "addi %[CNT], %[CNT], -1 \n\t"

                    SQ4BIT_KERNEL_LOAD_ZP_16X1

                    "LOOP_INNER%=: \n\t"

                    SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4

                    "vsub.vv v0, v0, v8 \n\t"
                    "vsub.vv v4, v4, v8 \n\t"
                    "vsub.vv v1, v1, v9 \n\t"
                    "vsub.vv v5, v5, v9 \n\t"
                    "vsub.vv v2, v2, v10 \n\t"
                    "vsub.vv v6, v6, v10 \n\t"
                    "vsub.vv v3, v3, v11 \n\t"
                    "vsub.vv v7, v7, v11 \n\t"

                    SQ4BIT_KERNEL_COMP_1x8x2_4X8X4

                    "bnez t5, LOOP_INNER%= \n\t"
                    "vsetvli t0, zero, e32, mf2 \n\t"

                    SQ4BIT_KERNEL_ACC_F16_1X4X4
                    "addi s7, s1, 32 \n\t"

                    "bnez %[CNT], LOOP_K%= \n\t"
                    "addi t3, zero, 16 \n\t"
                    "addi s1, %[C], 16 \n\t"
                    "addi s2, %[C], 32 \n\t"
                    "addi s3, %[C], 48 \n\t"
                    "blt %[NBLKS], t3, ST_TAIL%= \n\t"
                    "vse32.v v28, (%[C]) \n\t"
                    "vse32.v v29, (s1) \n\t"
                    "vse32.v v30, (s2) \n\t"
                    "vse32.v v31, (s3) \n\t"
                    "jal x0, END%= \n\t"

                    "ST_TAIL%=: \n\t"
                    "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
                    "sub %[NBLKS], %[NBLKS], t0 \n\t"
                    "vse32.v v28, (%[C]) \n\t"
                    "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
                    "sub %[NBLKS], %[NBLKS], t0 \n\t"
                    "vse32.v v29, (s1) \n\t"
                    "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
                    "sub %[NBLKS], %[NBLKS], t0 \n\t"
                    "vse32.v v30, (s2) \n\t"
                    "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
                    "sub %[NBLKS], %[NBLKS], t0 \n\t"
                    "vse32.v v31, (s3) \n\t"
                    "END%=: \n\t"

                    : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks)
                    : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
                    : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7");
            }
        }
    } else {
        // No zero points: packed B is data + fp16 scales only; nibbles are
        // re-centered by subtracting 8 instead of a stored zero point.
        for (size_t n = 0; n < CountN; n += 16) {
            size_t nblks = (CountN - n) > 16 ? 16 : CountN - n;
            std::byte * QuantBDataPtr = (std::byte *) QuantBData +            //
                                        n * BlockCountK * BlkLen / 2 +        // b data
                                        n * BlockCountK * sizeof(_Float16);   // scale
            float * CPtr = C + n;
            size_t cnt = BlockCountK;
            if (Bias != nullptr) {
                const float * bias = Bias + n;
                __asm__ volatile(
                    "addi t3, %[NBLKS], 0 \n\t"
                    "addi s1, %[B], 0 \n\t"
                    "addi s2, %[B], 8 \n\t"
                    "addi s3, %[B], 16 \n\t"
                    "addi s4, %[B], 24 \n\t"
                    "addi s5, %[A], 0 \n\t"
                    "addi s6, %[A], 12 \n\t"
                    // Masked bias pre-load into the accumulators v28..v31.
                    "vsetvli t0, t3, e32, mf2 \n\t"
                    "vle32.v v28, (%[BIAS]) \n\t"
                    "sub t3, t3, t0 \n\t"
                    "addi %[BIAS], %[BIAS], 16 \n\t"
                    "vsetvli t0, t3, e32, mf2 \n\t"
                    "vle32.v v29, (%[BIAS]) \n\t"
                    "sub t3, t3, t0 \n\t"
                    "addi %[BIAS], %[BIAS], 16 \n\t"
                    "vsetvli t0, t3, e32, mf2 \n\t"
                    "vle32.v v30, (%[BIAS]) \n\t"
                    "sub t3, t3, t0 \n\t"
                    "addi %[BIAS], %[BIAS], 16 \n\t"
                    "vsetvli t0, t3, e32, mf2 \n\t"
                    "vle32.v v31, (%[BIAS]) \n\t"

                    "LOOP_K%=: \n\t"
                    "vsetvli t0, zero, e16, mf4 \n\t"

                    // Strides differ from the zero-point variant (no zp
                    // bytes interleaved): 32/56/80/104 instead of 48/72/96/120.
                    "vle16.v v4, (s1) \n\t"
                    "addi s1, s1, 32 \n\t"
                    "vle16.v v5, (s2) \n\t"
                    "addi s2, s2, 56 \n\t"
                    "vle16.v v6, (s3) \n\t"
                    "addi s3, s3, 80 \n\t"
                    "vle16.v v7, (s4) \n\t"
                    "addi s4, s4, 104 \n\t"
                    "flw f1, (s5) \n\t"
                    "addi s5, s5, 4 \n\t"
                    "vfwcvt.f.f.v v8, v4 \n\t"
                    "vfwcvt.f.f.v v9, v5 \n\t"
                    "vfwcvt.f.f.v v10, v6 \n\t"
                    "vfwcvt.f.f.v v11, v7 \n\t"

                    "vsetvli t0, zero, e32, mf2 \n\t"
                    "addi t5, %[INNER], 0 \n\t"
                    "vxor.vv v16, v16, v16 \n\t"
                    "vxor.vv v18, v18, v18 \n\t"
                    "vxor.vv v20, v20, v20 \n\t"
                    "vxor.vv v22, v22, v22 \n\t"
                    "vfmul.vf v24, v8, f1 \n\t"
                    "vfmul.vf v25, v9, f1 \n\t"
                    "vfmul.vf v26, v10, f1 \n\t"
                    "vfmul.vf v27, v11, f1 \n\t"
                    "addi %[CNT], %[CNT], -1 \n\t"
                    "vsetvli t0, zero, e8, m1 \n\t"
                    "LOOP_INNER%=: \n\t"

                    SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4

                    // Re-center unsigned nibbles to signed [-8, 7].
                    "vadd.vi v0, v0, -8 \n\t"
                    "vadd.vi v1, v1, -8 \n\t"
                    "vadd.vi v2, v2, -8 \n\t"
                    "vadd.vi v3, v3, -8 \n\t"
                    "vadd.vi v4, v4, -8 \n\t"
                    "vadd.vi v5, v5, -8 \n\t"
                    "vadd.vi v6, v6, -8 \n\t"
                    "vadd.vi v7, v7, -8 \n\t"

                    SQ4BIT_KERNEL_COMP_1x8x2_4X8X4

                    "bnez t5, LOOP_INNER%= \n\t"
                    "vsetvli t0, zero, e32, mf2 \n\t"

                    SQ4BIT_KERNEL_ACC_F16_1X4X4

                    "bnez %[CNT], LOOP_K%= \n\t"
                    "addi t3, zero, 16 \n\t"
                    "addi s1, %[C], 16 \n\t"
                    "addi s2, %[C], 32 \n\t"
                    "addi s3, %[C], 48 \n\t"
                    "blt %[NBLKS], t3, ST_TAIL%= \n\t"
                    "vse32.v v28, (%[C]) \n\t"
                    "vse32.v v29, (s1) \n\t"
                    "vse32.v v30, (s2) \n\t"
                    "vse32.v v31, (s3) \n\t"
                    "jal x0, END%= \n\t"

                    "ST_TAIL%=: \n\t"
                    "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
                    "sub %[NBLKS], %[NBLKS], t0 \n\t"
                    "vse32.v v28, (%[C]) \n\t"
                    "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
                    "sub %[NBLKS], %[NBLKS], t0 \n\t"
                    "vse32.v v29, (s1) \n\t"
                    "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
                    "sub %[NBLKS], %[NBLKS], t0 \n\t"
                    "vse32.v v30, (s2) \n\t"
                    "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
                    "sub %[NBLKS], %[NBLKS], t0 \n\t"
                    "vse32.v v31, (s3) \n\t"
                    "END%=: \n\t"

                    : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks), [BIAS] "+r"(bias)
                    : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
                    : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6");
            } else {
                __asm__ volatile(
                    "vsetvli t0, zero, e32, m4 \n\t"
                    "vxor.vv v28, v28, v28 \n\t"
                    "addi s1, %[B], 0 \n\t"
                    "addi s2, %[B], 8 \n\t"
                    "addi s3, %[B], 16 \n\t"
                    "addi s4, %[B], 24 \n\t"

                    "addi s5, %[A], 0 \n\t"
                    "addi s6, %[A], 12 \n\t"
                    "LOOP_K%=: \n\t"
                    "vsetvli t0, zero, e16, mf4 \n\t"
                    "vle16.v v4, (s1) \n\t"
                    "addi s1, s1, 32 \n\t"
                    "vle16.v v5, (s2) \n\t"
                    "addi s2, s2, 56 \n\t"
                    "vle16.v v6, (s3) \n\t"
                    "addi s3, s3, 80 \n\t"
                    "vle16.v v7, (s4) \n\t"
                    "addi s4, s4, 104 \n\t"
                    "flw f1, (s5) \n\t"
                    "addi s5, s5, 4 \n\t"

                    "vfwcvt.f.f.v v8, v4 \n\t"
                    "vfwcvt.f.f.v v9, v5 \n\t"
                    "vfwcvt.f.f.v v10, v6 \n\t"
                    "vfwcvt.f.f.v v11, v7 \n\t"
                    "vsetvli t0, zero, e32, mf2 \n\t"

                    "addi t5, %[INNER], 0 \n\t"
                    "vxor.vv v16, v16, v16 \n\t"
                    "vxor.vv v18, v18, v18 \n\t"
                    "vxor.vv v20, v20, v20 \n\t"
                    "vxor.vv v22, v22, v22 \n\t"
                    "vfmul.vf v24, v8, f1 \n\t"
                    "vfmul.vf v25, v9, f1 \n\t"
                    "vfmul.vf v26, v10, f1 \n\t"
                    "vfmul.vf v27, v11, f1 \n\t"
                    "addi %[CNT], %[CNT], -1 \n\t"
                    "vsetvli t0, zero, e8, m1 \n\t"
                    "LOOP_INNER%=: \n\t"

                    SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4

                    "vadd.vi v0, v0, -8 \n\t"
                    "vadd.vi v1, v1, -8 \n\t"
                    "vadd.vi v2, v2, -8 \n\t"
                    "vadd.vi v3, v3, -8 \n\t"
                    "vadd.vi v4, v4, -8 \n\t"
                    "vadd.vi v5, v5, -8 \n\t"
                    "vadd.vi v6, v6, -8 \n\t"
                    "vadd.vi v7, v7, -8 \n\t"

                    SQ4BIT_KERNEL_COMP_1x8x2_4X8X4

                    "bnez t5, LOOP_INNER%= \n\t"
                    "vsetvli t0, zero, e32, mf2 \n\t"

                    SQ4BIT_KERNEL_ACC_F16_1X4X4

                    "bnez %[CNT], LOOP_K%= \n\t"
                    "addi t3, zero, 16 \n\t"
                    "addi s1, %[C], 16 \n\t"
                    "addi s2, %[C], 32 \n\t"
                    "addi s3, %[C], 48 \n\t"
                    "blt %[NBLKS], t3, ST_TAIL%= \n\t"
                    "vse32.v v28, (%[C]) \n\t"
                    "vse32.v v29, (s1) \n\t"
                    "vse32.v v30, (s2) \n\t"
                    "vse32.v v31, (s3) \n\t"
                    "jal x0, END%= \n\t"

                    "ST_TAIL%=: \n\t"
                    "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
                    "sub %[NBLKS], %[NBLKS], t0 \n\t"
                    "vse32.v v28, (%[C]) \n\t"
                    "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
                    "sub %[NBLKS], %[NBLKS], t0 \n\t"
                    "vse32.v v29, (s1) \n\t"
                    "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
                    "sub %[NBLKS], %[NBLKS], t0 \n\t"
                    "vse32.v v30, (s2) \n\t"
                    "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
                    "sub %[NBLKS], %[NBLKS], t0 \n\t"
                    "vse32.v v31, (s3) \n\t"
                    "END%=: \n\t"

                    : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks)
                    : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
                    : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6");
            }
        }
    }
}
2659
// Single-row (M = 1) x 16-column micro-kernel: int8-quantized A times 4-bit
// quantized B, for the packed-B layout whose per-block scales are stored as
// fp32 (scale stride == 4 bytes).  Structure mirrors the ScaleFp16 variant
// above: the only differences are the scale loads (vle32 instead of
// vle16 + widen) and the resulting pointer strides/offsets.
//
// Parameters: see SQ4BitGemmM1Kernel_CompInt8_ScaleFp16_Impl; QuantBScale
// and QuantBZeroPoint are unused because scales/zero points are packed
// inside QuantBData.
template <bool HasZeroPoint>
void SQ4BitGemmM1Kernel_CompInt8_Impl(size_t BlkLen,
                                      const std::byte * QuantA,
                                      const std::byte * QuantBData,
                                      const float * QuantBScale,
                                      const std::byte * QuantBZeroPoint,
                                      float * C,
                                      size_t CountN,
                                      size_t BlockCountK,
                                      const float * Bias) {
    GGML_UNUSED(QuantBScale);
    GGML_UNUSED(QuantBZeroPoint);
    const size_t INNER = BlkLen / 16;
    if constexpr (HasZeroPoint) {
        for (size_t n = 0; n < CountN; n += 16) {
            // Columns handled by this strip (16, or the tail remainder).
            size_t nblks = (CountN - n) > 16 ? 16 : CountN - n;
            // Per-strip base pointer: data + zero points + fp32 scales of the
            // preceding n columns.
            std::byte * QuantBDataPtr = (std::byte *) QuantBData +            //
                                        n * BlockCountK * BlkLen / 2 +        // b data
                                        n * BlockCountK * sizeof(uint8_t) +   // zp
                                        n * BlockCountK * sizeof(float);      // scale
            float * CPtr = C + n;
            size_t cnt = BlockCountK;
            if (Bias != nullptr) {
                const float * bias = Bias + n;
                __asm__ volatile(
                    "addi t3, %[NBLKS], 0 \n\t"
                    // v13 holds the per-lane shift-group index (0..3) used by
                    // the zero-point load macro below.
                    "vsetvli t0, zero, e8, m1 \n\t"
                    "vmv.v.i v13, 3 \n\t"
                    "li s1, 24 \n\t"
                    "vsetvli t0, s1, e8, m1 \n\t"
                    "vmv.v.i v13, 2 \n\t"
                    "vsetvli t0, zero, e8, mf2 \n\t"
                    "vmv.v.i v13, 1 \n\t"
                    "vsetvli t0, zero, e8, mf4 \n\t"
                    "vmv.v.i v13, 0 \n\t"
                    "vsetvli t0, zero, e32, m4 \n\t"
                    "vxor.vv v28, v28, v28 \n\t"

                    // scale offset, scale0.0, scale1.0, scale2.0, scale3.0....scale15.0
                    "addi s1, %[B], 0 \n\t"
                    "addi s2, %[B], 16 \n\t"
                    "addi s3, %[B], 32 \n\t"
                    "addi s4, %[B], 48 \n\t"
                    // zp offset
                    "addi s7, %[B], 64 \n\t"
                    // a offset
                    "addi s5, %[A], 0 \n\t"
                    "addi s6, %[A], 12 \n\t"

                    // Masked bias pre-load into the accumulators v28..v31.
                    "vsetvli t0, t3, e32, mf2 \n\t"
                    "vle32.v v28, (%[BIAS]) \n\t"
                    "sub t3, t3, t0 \n\t"
                    "addi %[BIAS], %[BIAS], 16 \n\t"
                    "vsetvli t0, t3, e32, mf2 \n\t"
                    "vle32.v v29, (%[BIAS]) \n\t"
                    "sub t3, t3, t0 \n\t"
                    "addi %[BIAS], %[BIAS], 16 \n\t"
                    "vsetvli t0, t3, e32, mf2 \n\t"
                    "vle32.v v30, (%[BIAS]) \n\t"
                    "sub t3, t3, t0 \n\t"
                    "addi %[BIAS], %[BIAS], 16 \n\t"
                    "vsetvli t0, t3, e32, mf2 \n\t"
                    "vle32.v v31, (%[BIAS]) \n\t"
                    "vsetvli t0, zero, e32, mf2 \n\t"
                    "LOOP_K%=: \n\t"

                    // load scale
                    "vle32.v v8, (s1) \n\t"
                    "addi s1, s1, 80 \n\t"
                    "vle32.v v9, (s2) \n\t"
                    "addi s2, s2, 96 \n\t"
                    "vle32.v v10, (s3) \n\t"
                    "addi s3, s3, 112 \n\t"
                    "vle32.v v11, (s4) \n\t"
                    "addi s4, s4, 128 \n\t"

                    // load a scale
                    "flw f1, (s5) \n\t"
                    "addi s5, s5, 4 \n\t"

                    "addi t5, %[INNER], 0 \n\t"
                    // Clear the int32 partial sums for this block.
                    "vxor.vv v16, v16, v16 \n\t"
                    "vxor.vv v18, v18, v18 \n\t"
                    "vxor.vv v20, v20, v20 \n\t"
                    "vxor.vv v22, v22, v22 \n\t"

                    // a scale * b scale
                    "vfmul.vf v24, v8, f1 \n\t"
                    "vfmul.vf v25, v9, f1 \n\t"
                    "vfmul.vf v26, v10, f1 \n\t"
                    "vfmul.vf v27, v11, f1 \n\t"
                    "addi %[CNT], %[CNT], -1 \n\t"

                    SQ4BIT_KERNEL_LOAD_ZP_16X1

                    "LOOP_INNER%=: \n\t"

                    SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4

                    // Subtract the per-column zero points (v8..v11).
                    "vsub.vv v0, v0, v8 \n\t"
                    "vsub.vv v4, v4, v8 \n\t"
                    "vsub.vv v1, v1, v9 \n\t"
                    "vsub.vv v5, v5, v9 \n\t"
                    "vsub.vv v2, v2, v10 \n\t"
                    "vsub.vv v6, v6, v10 \n\t"
                    "vsub.vv v3, v3, v11 \n\t"
                    "vsub.vv v7, v7, v11 \n\t"

                    SQ4BIT_KERNEL_COMP_1x8x2_4X8X4

                    "bnez t5, LOOP_INNER%= \n\t"
                    "vsetvli t0, zero, e32, mf2 \n\t"

                    SQ4BIT_KERNEL_ACC_1X4X4
                    // Next block's zero points live 64 bytes past the scales.
                    "addi s7, s1, 64 \n\t"

                    "bnez %[CNT], LOOP_K%= \n\t"

                    // Full strip: plain stores; partial strip: masked stores.
                    "addi t3, zero, 16 \n\t"
                    "addi s1, %[C], 16 \n\t"
                    "addi s2, %[C], 32 \n\t"
                    "addi s3, %[C], 48 \n\t"
                    "blt %[NBLKS], t3, ST_TAIL%= \n\t"
                    "vse32.v v28, (%[C]) \n\t"
                    "vse32.v v29, (s1) \n\t"
                    "vse32.v v30, (s2) \n\t"
                    "vse32.v v31, (s3) \n\t"
                    "jal x0, END%= \n\t"

                    "ST_TAIL%=: \n\t"
                    "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
                    "sub %[NBLKS], %[NBLKS], t0 \n\t"
                    "vse32.v v28, (%[C]) \n\t"
                    "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
                    "sub %[NBLKS], %[NBLKS], t0 \n\t"
                    "vse32.v v29, (s1) \n\t"
                    "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
                    "sub %[NBLKS], %[NBLKS], t0 \n\t"
                    "vse32.v v30, (s2) \n\t"
                    "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
                    "sub %[NBLKS], %[NBLKS], t0 \n\t"
                    "vse32.v v31, (s3) \n\t"
                    "END%=: \n\t"

                    : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks), [BIAS] "+r"(bias)
                    : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
                    : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7");
            } else {
                // No bias: zero v28..v31 up front, otherwise identical to the
                // bias variant above.
                __asm__ volatile(
                    "vsetvli t0, zero, e32, m4 \n\t"
                    "vxor.vv v28, v28, v28 \n\t"

                    "vsetvli t0, zero, e8, m1 \n\t"
                    "vmv.v.i v13, 3 \n\t"
                    "li s1, 24 \n\t"
                    "vsetvli t0, s1, e8, m1 \n\t"
                    "vmv.v.i v13, 2 \n\t"
                    "vsetvli t0, zero, e8, mf2 \n\t"
                    "vmv.v.i v13, 1 \n\t"
                    "vsetvli t0, zero, e8, mf4 \n\t"
                    "vmv.v.i v13, 0 \n\t"
                    "addi s1, %[B], 0 \n\t"
                    "addi s2, %[B], 16 \n\t"
                    "addi s3, %[B], 32 \n\t"
                    "addi s4, %[B], 48 \n\t"

                    "addi s7, %[B], 64 \n\t"

                    "addi s5, %[A], 0 \n\t"
                    "addi s6, %[A], 12 \n\t"
                    "vsetvli t0, zero, e32, mf2 \n\t"

                    "LOOP_K%=: \n\t"
                    "vle32.v v8, (s1) \n\t"
                    "addi s1, s1, 80 \n\t"
                    "vle32.v v9, (s2) \n\t"
                    "addi s2, s2, 96 \n\t"
                    "vle32.v v10, (s3) \n\t"
                    "addi s3, s3, 112 \n\t"
                    "vle32.v v11, (s4) \n\t"
                    "addi s4, s4, 128 \n\t"

                    "flw f1, (s5) \n\t"
                    "addi s5, s5, 4 \n\t"

                    "addi t5, %[INNER], 0 \n\t"
                    "vxor.vv v16, v16, v16 \n\t"
                    "vxor.vv v18, v18, v18 \n\t"
                    "vxor.vv v20, v20, v20 \n\t"
                    "vxor.vv v22, v22, v22 \n\t"

                    "vfmul.vf v24, v8, f1 \n\t"
                    "vfmul.vf v25, v9, f1 \n\t"
                    "vfmul.vf v26, v10, f1 \n\t"
                    "vfmul.vf v27, v11, f1 \n\t"
                    "addi %[CNT], %[CNT], -1 \n\t"

                    SQ4BIT_KERNEL_LOAD_ZP_16X1

                    "LOOP_INNER%=: \n\t"

                    SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4

                    "vsub.vv v0, v0, v8 \n\t"
                    "vsub.vv v4, v4, v8 \n\t"
                    "vsub.vv v1, v1, v9 \n\t"
                    "vsub.vv v5, v5, v9 \n\t"
                    "vsub.vv v2, v2, v10 \n\t"
                    "vsub.vv v6, v6, v10 \n\t"
                    "vsub.vv v3, v3, v11 \n\t"
                    "vsub.vv v7, v7, v11 \n\t"

                    SQ4BIT_KERNEL_COMP_1x8x2_4X8X4

                    "bnez t5, LOOP_INNER%= \n\t"
                    "vsetvli t0, zero, e32, mf2 \n\t"

                    SQ4BIT_KERNEL_ACC_1X4X4
                    "addi s7, s1, 64 \n\t"

                    "bnez %[CNT], LOOP_K%= \n\t"

                    "addi t3, zero, 16 \n\t"
                    "addi s1, %[C], 16 \n\t"
                    "addi s2, %[C], 32 \n\t"
                    "addi s3, %[C], 48 \n\t"
                    "blt %[NBLKS], t3, ST_TAIL%= \n\t"
                    "vse32.v v28, (%[C]) \n\t"
                    "vse32.v v29, (s1) \n\t"
                    "vse32.v v30, (s2) \n\t"
                    "vse32.v v31, (s3) \n\t"
                    "jal x0, END%= \n\t"

                    "ST_TAIL%=: \n\t"
                    "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
                    "sub %[NBLKS], %[NBLKS], t0 \n\t"
                    "vse32.v v28, (%[C]) \n\t"
                    "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
                    "sub %[NBLKS], %[NBLKS], t0 \n\t"
                    "vse32.v v29, (s1) \n\t"
                    "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
                    "sub %[NBLKS], %[NBLKS], t0 \n\t"
                    "vse32.v v30, (s2) \n\t"
                    "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
                    "sub %[NBLKS], %[NBLKS], t0 \n\t"
                    "vse32.v v31, (s3) \n\t"
                    "END%=: \n\t"

                    : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks)
                    : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
                    : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7");
            }
        }
    } else {
        // No zero points: packed B is data + fp32 scales only; nibbles are
        // re-centered by subtracting 8 instead of a stored zero point.
        for (size_t n = 0; n < CountN; n += 16) {
            size_t nblks = (CountN - n) > 16 ? 16 : CountN - n;
            std::byte * QuantBDataPtr = (std::byte *) QuantBData +            //
                                        n * BlockCountK * BlkLen / 2 +        // b data
                                        n * BlockCountK * sizeof(float);      // scale
            float * CPtr = C + n;
            size_t cnt = BlockCountK;
            if (Bias != nullptr) {
                const float * bias = Bias + n;
                __asm__ volatile(
                    "addi t3, %[NBLKS], 0 \n\t"
                    "addi s1, %[B], 0 \n\t"
                    "addi s2, %[B], 16 \n\t"
                    "addi s3, %[B], 32 \n\t"
                    "addi s4, %[B], 48 \n\t"
                    "addi s5, %[A], 0 \n\t"
                    "addi s6, %[A], 12 \n\t"
                    // Masked bias pre-load into the accumulators v28..v31.
                    "vsetvli t0, t3, e32, mf2 \n\t"
                    "vle32.v v28, (%[BIAS]) \n\t"
                    "sub t3, t3, t0 \n\t"
                    "addi %[BIAS], %[BIAS], 16 \n\t"
                    "vsetvli t0, t3, e32, mf2 \n\t"
                    "vle32.v v29, (%[BIAS]) \n\t"
                    "sub t3, t3, t0 \n\t"
                    "addi %[BIAS], %[BIAS], 16 \n\t"
                    "vsetvli t0, t3, e32, mf2 \n\t"
                    "vle32.v v30, (%[BIAS]) \n\t"
                    "sub t3, t3, t0 \n\t"
                    "addi %[BIAS], %[BIAS], 16 \n\t"
                    "vsetvli t0, t3, e32, mf2 \n\t"
                    "vle32.v v31, (%[BIAS]) \n\t"
                    "vsetvli t0, zero, e32, mf2 \n\t"
                    "LOOP_K%=: \n\t"
                    // Strides differ from the zero-point variant (no zp
                    // bytes interleaved): 64/80/96/112 instead of 80/96/112/128.
                    "vle32.v v8, (s1) \n\t"
                    "addi s1, s1, 64 \n\t"
                    "vle32.v v9, (s2) \n\t"
                    "addi s2, s2, 80 \n\t"
                    "vle32.v v10, (s3) \n\t"
                    "addi s3, s3, 96 \n\t"
                    "vle32.v v11, (s4) \n\t"
                    "addi s4, s4, 112 \n\t"
                    "flw f1, (s5) \n\t"
                    "addi s5, s5, 4 \n\t"

                    "addi t5, %[INNER], 0 \n\t"
                    "vxor.vv v16, v16, v16 \n\t"
                    "vxor.vv v18, v18, v18 \n\t"
                    "vxor.vv v20, v20, v20 \n\t"
                    "vxor.vv v22, v22, v22 \n\t"
                    "vfmul.vf v24, v8, f1 \n\t"
                    "vfmul.vf v25, v9, f1 \n\t"
                    "vfmul.vf v26, v10, f1 \n\t"
                    "vfmul.vf v27, v11, f1 \n\t"
                    "addi %[CNT], %[CNT], -1 \n\t"
                    "vsetvli t0, zero, e8, m1 \n\t"
                    "LOOP_INNER%=: \n\t"

                    SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4

                    // Re-center unsigned nibbles to signed [-8, 7].
                    "vadd.vi v0, v0, -8 \n\t"
                    "vadd.vi v1, v1, -8 \n\t"
                    "vadd.vi v2, v2, -8 \n\t"
                    "vadd.vi v3, v3, -8 \n\t"
                    "vadd.vi v4, v4, -8 \n\t"
                    "vadd.vi v5, v5, -8 \n\t"
                    "vadd.vi v6, v6, -8 \n\t"
                    "vadd.vi v7, v7, -8 \n\t"

                    SQ4BIT_KERNEL_COMP_1x8x2_4X8X4

                    "bnez t5, LOOP_INNER%= \n\t"
                    "vsetvli t0, zero, e32, mf2 \n\t"

                    SQ4BIT_KERNEL_ACC_1X4X4

                    "bnez %[CNT], LOOP_K%= \n\t"
                    "addi t3, zero, 16 \n\t"
                    "addi s1, %[C], 16 \n\t"
                    "addi s2, %[C], 32 \n\t"
                    "addi s3, %[C], 48 \n\t"
                    "blt %[NBLKS], t3, ST_TAIL%= \n\t"
                    "vse32.v v28, (%[C]) \n\t"
                    "vse32.v v29, (s1) \n\t"
                    "vse32.v v30, (s2) \n\t"
                    "vse32.v v31, (s3) \n\t"
                    "jal x0, END%= \n\t"

                    "ST_TAIL%=: \n\t"
                    "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
                    "sub %[NBLKS], %[NBLKS], t0 \n\t"
                    "vse32.v v28, (%[C]) \n\t"
                    "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
                    "sub %[NBLKS], %[NBLKS], t0 \n\t"
                    "vse32.v v29, (s1) \n\t"
                    "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
                    "sub %[NBLKS], %[NBLKS], t0 \n\t"
                    "vse32.v v30, (s2) \n\t"
                    "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
                    "sub %[NBLKS], %[NBLKS], t0 \n\t"
                    "vse32.v v31, (s3) \n\t"
                    "END%=: \n\t"

                    : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks), [BIAS] "+r"(bias)
                    : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
                    : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6");
            } else {
                __asm__ volatile(
                    "vsetvli t0, zero, e32, m4 \n\t"
                    "vxor.vv v28, v28, v28 \n\t"
                    "addi s1, %[B], 0 \n\t"
                    "addi s2, %[B], 16 \n\t"
                    "addi s3, %[B], 32 \n\t"
                    "addi s4, %[B], 48 \n\t"

                    "addi s5, %[A], 0 \n\t"
                    "addi s6, %[A], 12 \n\t"
                    "vsetvli t0, zero, e32, mf2 \n\t"
                    "LOOP_K%=: \n\t"
                    "vle32.v v8, (s1) \n\t"
                    "addi s1, s1, 64 \n\t"
                    "vle32.v v9, (s2) \n\t"
                    "addi s2, s2, 80 \n\t"
                    "vle32.v v10, (s3) \n\t"
                    "addi s3, s3, 96 \n\t"
                    "vle32.v v11, (s4) \n\t"
                    "addi s4, s4, 112 \n\t"
                    "flw f1, (s5) \n\t"
                    "addi s5, s5, 4 \n\t"

                    "addi t5, %[INNER], 0 \n\t"
                    "vxor.vv v16, v16, v16 \n\t"
                    "vxor.vv v18, v18, v18 \n\t"
                    "vxor.vv v20, v20, v20 \n\t"
                    "vxor.vv v22, v22, v22 \n\t"
                    "vfmul.vf v24, v8, f1 \n\t"
                    "vfmul.vf v25, v9, f1 \n\t"
                    "vfmul.vf v26, v10, f1 \n\t"
                    "vfmul.vf v27, v11, f1 \n\t"
                    "addi %[CNT], %[CNT], -1 \n\t"
                    "vsetvli t0, zero, e8, m1 \n\t"
                    "LOOP_INNER%=: \n\t"

                    SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4

                    "vadd.vi v0, v0, -8 \n\t"
                    "vadd.vi v1, v1, -8 \n\t"
                    "vadd.vi v2, v2, -8 \n\t"
                    "vadd.vi v3, v3, -8 \n\t"
                    "vadd.vi v4, v4, -8 \n\t"
                    "vadd.vi v5, v5, -8 \n\t"
                    "vadd.vi v6, v6, -8 \n\t"
                    "vadd.vi v7, v7, -8 \n\t"

                    SQ4BIT_KERNEL_COMP_1x8x2_4X8X4

                    "bnez t5, LOOP_INNER%= \n\t"
                    "vsetvli t0, zero, e32, mf2 \n\t"

                    SQ4BIT_KERNEL_ACC_1X4X4

                    "bnez %[CNT], LOOP_K%= \n\t"
                    "addi t3, zero, 16 \n\t"
                    "addi s1, %[C], 16 \n\t"
                    "addi s2, %[C], 32 \n\t"
                    "addi s3, %[C], 48 \n\t"
                    "blt %[NBLKS], t3, ST_TAIL%= \n\t"
                    "vse32.v v28, (%[C]) \n\t"
                    "vse32.v v29, (s1) \n\t"
                    "vse32.v v30, (s2) \n\t"
                    "vse32.v v31, (s3) \n\t"
                    "jal x0, END%= \n\t"

                    "ST_TAIL%=: \n\t"
                    "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
                    "sub %[NBLKS], %[NBLKS], t0 \n\t"
                    "vse32.v v28, (%[C]) \n\t"
                    "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
                    "sub %[NBLKS], %[NBLKS], t0 \n\t"
                    "vse32.v v29, (s1) \n\t"
                    "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
                    "sub %[NBLKS], %[NBLKS], t0 \n\t"
                    "vse32.v v30, (s2) \n\t"
                    "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
                    "sub %[NBLKS], %[NBLKS], t0 \n\t"
                    "vse32.v v31, (s3) \n\t"
                    "END%=: \n\t"

                    : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks)
                    : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
                    : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6");
            }
        }
    }
}
3108
3109template <bool HasZeroPoint>
3110inline void SQ4BitGemmM4Kernel_CompInt8_DispatchOnBlkLen(size_t BlkLen,
3111 const std::byte * QuantA,
3112 const std::byte * QuantBData,
3113 const float * QuantBScale,
3114 const std::byte * QuantBZeroPoint,
3115 float * C,
3116 size_t CountM,
3117 size_t CountN,
3118 size_t BlockStrideQuantB,
3119 const float * Bias,
3120 const size_t ldc,
3121 const size_t scalestride) {
3122 if (scalestride == 4) {
3123 SQ4BitGemmM4Kernel_CompInt8_Impl<HasZeroPoint>(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, C,
3124 CountN, BlockStrideQuantB, Bias, ldc);
3125
3126 } else if (scalestride == 2) {
3127 SQ4BitGemmM4Kernel_CompInt8_ScaleFp16_Impl<HasZeroPoint>(
3128 BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, C, CountN, BlockStrideQuantB, Bias, ldc);
3129 }
3130}
3131
3132template <bool HasZeroPoint>
3133inline void SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen(size_t BlkLen,
3134 const std::byte * QuantA,
3135 const std::byte * QuantBData,
3136 const float * QuantBScale,
3137 const std::byte * QuantBZeroPoint,
3138 float * C,
3139 size_t CountM,
3140 size_t CountN,
3141 size_t BlockStrideQuantB,
3142 const float * Bias,
3143 const size_t ldc,
3144 const size_t scalestride) {
3145 if (scalestride == 4) {
3146 SQ4BitGemmM1Kernel_CompInt8_Impl<HasZeroPoint>(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, C,
3147 CountN, BlockStrideQuantB, Bias);
3148 } else if (scalestride == 2) {
3149 SQ4BitGemmM1Kernel_CompInt8_ScaleFp16_Impl<HasZeroPoint>(BlkLen, QuantA, QuantBData, QuantBScale,
3150 QuantBZeroPoint, C, CountN, BlockStrideQuantB, Bias);
3151 }
3152}
3153
3154} // namespace
3155
3156namespace ime1 {
3157size_t gemm_kernel_i8i4(size_t BlkLen,
3158 const std::byte * QuantA,
3159 const std::byte * QuantBData,
3160 const float * QuantBScale,
3161 const std::byte * QuantBZeroPoint,
3162 float * C,
3163 size_t CountM,
3164 size_t CountN,
3165 size_t CountK,
3166 size_t BlockCountK,
3167 size_t ldc,
3168 const float * Bias,
3169 const size_t ScaleStride) {
3170 GGML_UNUSED(CountM);
3171 GGML_UNUSED(CountK);
3172 GGML_UNUSED(ldc);
3173 if (CountM >= 4) {
3174 if (QuantBZeroPoint != nullptr) {
3175 SQ4BitGemmM4Kernel_CompInt8_DispatchOnBlkLen<true>(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint,
3176 C, CountM, CountN, BlockCountK, Bias, ldc, ScaleStride);
3177 } else {
3178 SQ4BitGemmM4Kernel_CompInt8_DispatchOnBlkLen<false>(BlkLen, QuantA, QuantBData, QuantBScale,
3179 QuantBZeroPoint, C, CountM, CountN, BlockCountK, Bias,
3180 ldc, ScaleStride);
3181 }
3182 return 4;
3183 } else {
3184 if (QuantBZeroPoint != nullptr) {
3185 SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen<true>(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint,
3186 C, CountM, CountN, BlockCountK, Bias, ldc, ScaleStride);
3187 } else {
3188 SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen<false>(BlkLen, QuantA, QuantBData, QuantBScale,
3189 QuantBZeroPoint, C, CountM, CountN, BlockCountK, Bias,
3190 ldc, ScaleStride);
3191 }
3192 return 1;
3193 }
3194}
3195} // namespace ime1
3196} // namespace sqnbitgemm_spacemit_ime