#define GGML_COMMON_IMPL_CPP
#define GGML_COMMON_DECL_CPP

#include "ime.h"

#include "ggml-backend-impl.h"
#include "ggml-common.h"
#include "ggml-cpu.h"
#include "ime_kernels.h"
#include "traits.h"

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef> // for std::byte
#include <cstdio>  // for GGML_ASSERT
#include <cstring> // for memcpy
#include <stdexcept>
#include <thread>
#include <type_traits>
#include <vector>
// clang-format off
#if defined(__riscv)

#if !defined(__riscv_v) || !defined(__riscv_v_intrinsic)
#error "RISC-V V extension or V intrinsics not enabled"
#else
#include <riscv_vector.h>
#endif

#if !defined(__riscv_zfh)
#error "RISC-V Zfh extension not enabled"
#endif

#if !defined(RISCV64_SPACEMIT_IME1)
#error "RISCV64_SPACEMIT_IME1 not defined"
#endif

#else

#error "RISC-V not enabled in this build"

#endif

#if defined(__GNUC__)
#pragma GCC diagnostic ignored "-Woverlength-strings"
#pragma GCC diagnostic ignored "-Wcast-qual"
#pragma GCC diagnostic ignored "-Wunused-parameter"
#endif

#if defined(RISCV64_SPACEMIT_IME1)
#define QGEMM_STRIDEN_THREAD_ALIGN 16
#else
#define QGEMM_STRIDEN_THREAD_ALIGN 32
#endif

// clang-format on

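// Arguments describing one GEMM: activations A (quantized to Q8 blocks in a
// shared workspace at run time), pre-packed quantized weights B (quants,
// scales and optional zero points), and the f32 output C.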
struct qnbitgemm_spacemit_ime_args {
    const float * a_ptr = nullptr;
    size_t lda = 0;
    const std::byte * packed_quant_b_data = nullptr;
    const float * quant_b_scale = nullptr;
    const void * quant_b_zp = nullptr;
    const float * quant_b_blksum = nullptr;
    const float * bias = nullptr;
    float * c_ptr = nullptr;
    size_t ldc = 0;
};

constexpr size_t div_round_up(size_t up, size_t down) {
    return (up + down - 1) / down;
}

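// Size in bytes of one Q8 activation block: a float scale followed by blk_len
// int8 quants. For blk_len == QK4_0 (32) that is 4 + 32 = 36 bytes.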
constexpr size_t q8_blk_size(size_t blk_len) {
    const size_t blk_size = sizeof(float) + blk_len * sizeof(int8_t);
    // Currently, the strictest alignment requirement of a block is for a float.
    // Ensure contiguous blocks are suitably aligned.
    assert(blk_size % alignof(float) == 0);
    return blk_size;
}

namespace ggml::cpu::riscv64_spacemit {

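// Assumption: on SpacemiT hardware the IME-capable "AI" cores account for half
// of the harts the OS reports, so only the first half of the worker threads
// run the GEMM kernels below.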
const int num_ai_cores = std::thread::hardware_concurrency() / 2;

}  // namespace ggml::cpu::riscv64_spacemit

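// Compute the C[m_start : m_start + m_count, n_start : n_start + n_count] tile
// of one GEMM. The activations were already quantized into the per-GEMM
// workspace; columns of B are processed in tiles of up to 16, and the kernel
// reports how many rows it handled per invocation.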
static void sqnbitgemm_spacemit_ime_i8i4(const size_t blk_len,
                                         const size_t gemm_k,
                                         const qnbitgemm_spacemit_ime_args * gemm_args,
                                         void * const per_gemm_ws,
                                         const size_t m_start,
                                         const size_t m_count,
                                         const size_t n_start,
                                         const size_t n_count) {
    constexpr size_t scale_stride = sizeof(uint16_t);
    constexpr size_t blk_bitwidth = 4;

    const size_t k_blks = div_round_up(gemm_k, blk_len);

    const size_t lda = k_blks * q8_blk_size(blk_len);
    const size_t ldc = gemm_args->ldc;
    const size_t ldb = k_blks * (blk_len * blk_bitwidth / 8);
    const std::byte * quant_a_ptr = static_cast<const std::byte *>(per_gemm_ws) + m_start * lda;

    const size_t zero_point_stride = gemm_args->quant_b_zp != nullptr ? sizeof(uint8_t) : 0;
    const size_t packed_b_stride = ldb + k_blks * (scale_stride + zero_point_stride);
    const std::byte * packed_quant_b_data = gemm_args->packed_quant_b_data + n_start * packed_b_stride;

    float * c_ptr = gemm_args->c_ptr + m_start * ldc + n_start;

    size_t count_n = 0;
    const size_t compute_block_count_n = m_count == 1 ? n_count : 16;
    for (size_t n = 0; n < n_count; n += count_n) {
        count_n = std::min(n_count - n, compute_block_count_n);

        const std::byte * a_row = quant_a_ptr;
        const std::byte * b_col = packed_quant_b_data + n * packed_b_stride;
        const std::byte * b_col_zp = (zero_point_stride != 0) ? b_col : nullptr;
        float * c_blk = c_ptr + n;

        int32_t rows_remaining = m_count;

        while (rows_remaining > 0) {
            const auto rows_handled = sqnbitgemm_spacemit_ime::ime1::gemm_kernel_i8i4(
                blk_len, a_row, b_col, nullptr, b_col_zp, c_blk, rows_remaining, count_n, gemm_k, k_blks, ldc, nullptr,
                scale_stride);

            c_blk += rows_handled * ldc;
            a_row += rows_handled * lda;

            rows_remaining -= rows_handled;
        }
    }
}

template <int K> constexpr int QK_0() {
    if constexpr (K == 4) {
        return QK4_0;
    }
    if constexpr (K == 8) {
        return QK8_0;
    }
    return -1;
}

template <int K, int N> struct block {
    ggml_half d[N];                         // deltas for N qK_0 blocks
    uint8_t   qs[(QK_0<K>() * N * K) / 8];  // quants for N qK_0 blocks
};

template <int K, int N> struct block_with_zp {
    ggml_half d[N];                         // deltas for N qK_1 blocks
    uint8_t   zp[N];                        // zero points for N qK_1 blocks
    uint8_t   qs[(QK_0<K>() * N * K) / 8];  // quants for N qK_1 blocks
};

// compile-time size checks
static_assert(sizeof(block<4, 16>) == 16 * sizeof(ggml_half) + QK4_0 * 8, "wrong block<4,16> size/padding");
static_assert(sizeof(block_with_zp<4, 16>) == 16 * sizeof(ggml_half) + QK4_0 * 8 + 16 * sizeof(uint8_t),
              "wrong block_with_zp<4,16> size/padding");
static_assert(sizeof(block<8, 16>) == 16 * sizeof(ggml_half) + QK4_0 * 16, "wrong block<8,16> size/padding");

using block_q4_0x16 = block<4, 16>;
using block_q4_1x16 = block_with_zp<4, 16>;
using block_q8_0x16 = block<8, 16>;

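// Interleave 16 q4_0 blocks: all 16 scales first, then the low nibbles
// (quants 0..15 of every block) followed by the high nibbles (quants 16..31),
// so the kernel can stream 16 rows at once. The src/dst diagrams below track
// where individual quants end up.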
static block_q4_0x16 make_block_q4_0x16(block_q4_0 * in, unsigned int blck_size_interleave) {
    block_q4_0x16 out;
    GGML_ASSERT(QK4_0 / blck_size_interleave == 2);

    for (int i = 0; i < 16; i++) {
        out.d[i] = in[i].d;
    }

    for (int i = 0; i < 16; i++) {
        // quants [0, 15] of each block: in.qs & 0x0F
        for (int j = 0; j < QK4_0 / 4; j++) {
            // src [b0 b16] ......... [b8 b24] ......... [b15 b31]
            // dst [b0 b8] ......... [b7 b15]
            out.qs[i * QK4_0 / 4 + j] = (in[i].qs[j] & 0x0F) | ((in[i].qs[j + QK4_0 / 4] & 0x0F) << 4);
        }
    }

    for (int i = 0; i < 16; i++) {
        // quants [16, 31] of each block: in.qs & 0xF0
        for (int j = 0; j < QK4_0 / 4; j++) {
            // src [b0 b16] ......... [b8 b24] ......... [b15 b31]
            // dst [b16 b24] ......... [b23 b31]
            out.qs[4 * QK4_0 + i * QK4_0 / 4 + j] = ((in[i].qs[j] & 0xF0) >> 4) | (in[i].qs[j + QK4_0 / 4] & 0xF0);
        }
    }

    return out;
}

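// Same interleave for q4_1 blocks. Each block's (d, m) pair is converted to a
// (scale, zero point) pair: with m ~= -zp * d, the affine form d * q + m
// becomes d * (q - zp), which is the form the IME kernel consumes.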
static block_q4_1x16 make_block_q4_1x16(block_q4_1 * in, unsigned int blck_size_interleave) {
    block_q4_1x16 out;
    GGML_ASSERT(QK4_1 / blck_size_interleave == 2);

    for (int i = 0; i < 16; i++) {
        float d = GGML_FP16_TO_FP32(in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d);
        float m = GGML_FP16_TO_FP32(in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.m);
        float mid = -std::nearbyintf(m / d);
        mid = std::min(15.0f, std::max(0.0f, mid));
        out.d[i] = GGML_FP32_TO_FP16(d);
        out.zp[i] = static_cast<uint8_t>(mid);
    }

    for (int i = 0; i < 16; i++) {
        // quants [0, 15] of each block: in.qs & 0x0F
        for (int j = 0; j < QK4_1 / 4; j++) {
            // src [b0 b16] ......... [b8 b24] ......... [b15 b31]
            // dst [b0 b8] ......... [b7 b15]
            out.qs[i * QK4_1 / 4 + j] = (in[i].qs[j] & 0x0F) | ((in[i].qs[j + QK4_1 / 4] & 0x0F) << 4);
        }
    }

    for (int i = 0; i < 16; i++) {
        // quants [16, 31] of each block: in.qs & 0xF0
        for (int j = 0; j < QK4_1 / 4; j++) {
            // src [b0 b16] ......... [b8 b24] ......... [b15 b31]
            // dst [b16 b24] ......... [b23 b31]
            out.qs[4 * QK4_1 + i * QK4_1 / 4 + j] = ((in[i].qs[j] & 0xF0) >> 4) | (in[i].qs[j + QK4_1 / 4] & 0xF0);
        }
    }

    return out;
}

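// Repack a whole q4_0 weight tensor: take 16 consecutive rows at a time, walk
// their blocks column by column, and emit one block_q4_0x16 per group. Fails
// (returns -1) unless rows are a multiple of 16 and columns a multiple of QK4_0.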
static int repack_q4_0_to_q4_0_16_bl(struct ggml_tensor * t,
                                     int interleave_block,
                                     const void * GGML_RESTRICT data,
                                     size_t data_size) {
    GGML_ASSERT(t->type == GGML_TYPE_Q4_0);
    GGML_ASSERT(interleave_block == 16);

    constexpr int nrows_interleaved = 16;

    block_q4_0x16 * dst = (block_q4_0x16 *) t->data;
    const block_q4_0 * src = (const block_q4_0 *) data;
    block_q4_0 dst_tmp[16];
    int nrow = ggml_nrows(t);
    int nblocks = t->ne[0] / QK4_0;

    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0));

    if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_0 != 0) {
        return -1;
    }

    for (int b = 0; b < nrow; b += nrows_interleaved) {
        for (int64_t x = 0; x < nblocks; x++) {
            for (int i = 0; i < nrows_interleaved; i++) {
                dst_tmp[i] = src[x + i * nblocks];
            }
            *dst++ = make_block_q4_0x16(dst_tmp, interleave_block);
        }
        src += nrows_interleaved * nblocks;
    }
    return 0;

    GGML_UNUSED(data_size);
}

static int repack_q4_1_to_q4_1_16_bl(struct ggml_tensor * t,
                                     int interleave_block,
                                     const void * GGML_RESTRICT data,
                                     size_t data_size) {
    GGML_ASSERT(t->type == GGML_TYPE_Q4_1);
    GGML_ASSERT(interleave_block == 16);

    constexpr int nrows_interleaved = 16;

    block_q4_1x16 * dst = (block_q4_1x16 *) t->data;
    const block_q4_1 * src = (const block_q4_1 *) data;
    block_q4_1 dst_tmp[16];
    int nrow = ggml_nrows(t);
    int nblocks = t->ne[0] / QK4_1;

    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_1));

    if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_1 != 0) {
        return -1;
    }

    for (int b = 0; b < nrow; b += nrows_interleaved) {
        for (int64_t x = 0; x < nblocks; x++) {
            for (int i = 0; i < nrows_interleaved; i++) {
                dst_tmp[i] = src[x + i * nblocks];
            }
            *dst++ = make_block_q4_1x16(dst_tmp, interleave_block);
        }
        src += nrows_interleaved * nblocks;
    }
    return 0;

    GGML_UNUSED(data_size);
}

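// Q4_K packs eight 6-bit (scale, min) pairs into 12 bytes: pairs 0..3 sit in
// the low 6 bits of q[0..7], and pairs 4..7 reassemble their upper 2 bits
// from the top bits of those same bytes.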
static inline void get_scale_min_k4(int j,
                                    const uint8_t * GGML_RESTRICT q,
                                    uint8_t * GGML_RESTRICT d,
                                    uint8_t * GGML_RESTRICT m) {
    if (j < 4) {
        *d = q[j] & 63;
        *m = q[j + 4] & 63;
    } else {
        *d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4);
        *m = (q[j + 4] >> 4) | ((q[j - 0] >> 6) << 4);
    }
}

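// Repack q4_K by expanding each 256-quant super-block into 8 q4_1 blocks
// (folding in the per-sub-block scales and mins), then applying the same
// 16-row interleave as the q4_1 path. m is negated because q4_K subtracts
// its minimum while q4_1 adds it.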
static int repack_q4_k_to_q4_1_16_bl(struct ggml_tensor * t,
                                     int interleave_block,
                                     const void * GGML_RESTRICT data,
                                     size_t data_size) {
    GGML_ASSERT(t->type == GGML_TYPE_Q4_K);
    GGML_ASSERT(interleave_block == 16);
    GGML_ASSERT(QK_K / QK4_1 == 8);

    constexpr int nrows_interleaved = 16;

    block_q4_1x16 * dst = (block_q4_1x16 *) t->data;
    const block_q4_K * src = (const block_q4_K *) data;
    block_q4_1 dst_tmp[16];
    int nrow = ggml_nrows(t);
    int nblocks = t->ne[0] / QK_K;

    if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK_K != 0) {
        return -1;
    }

    for (int b = 0; b < nrow; b += nrows_interleaved) {
        for (int64_t x = 0; x < nblocks; x++) {
            for (int j = 0; j < 8; j++) {
                for (int i = 0; i < nrows_interleaved; i++) {
                    uint8_t sc, m;
                    const float d = GGML_FP16_TO_FP32(src[x + i * nblocks].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d);
                    const float min =
                        GGML_FP16_TO_FP32(src[x + i * nblocks].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin);
                    get_scale_min_k4(j, src[x + i * nblocks].scales, &sc, &m);
                    const float d1 = d * sc;
                    const float m1 = min * m;

                    dst_tmp[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d = GGML_FP32_TO_FP16(d1);
                    dst_tmp[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.m = GGML_FP32_TO_FP16(-m1);
                    // src -> [b0, b32] [b1, b33] ... [b31, b63]
                    // dst -> [b0, b16] [b1, b17] ... [b15, b31] [b32, b48] [b33, b49] ... [b47, b63]
                    const uint8_t * q = src[x + i * nblocks].qs + (j / 2) * QK4_1;
                    if (j % 2 == 0) {
                        for (int ii = 0; ii < 16; ii++) {
                            dst_tmp[i].qs[ii] = (q[ii] & 0x0F) | ((q[ii + 16] & 0x0F) << 4);
                        }
                    } else {
                        for (int ii = 0; ii < 16; ii++) {
                            dst_tmp[i].qs[ii] = ((q[ii] & 0xF0) >> 4) | (q[ii + 16] & 0xF0);
                        }
                    }
                }
                *dst++ = make_block_q4_1x16(dst_tmp, interleave_block);
            }
        }
        src += nrows_interleaved * nblocks;
    }
    return 0;

    GGML_UNUSED(data_size);
}

namespace ggml::cpu::riscv64_spacemit {

template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS>
int repack(struct ggml_tensor *, const void *, size_t);

template <> int repack<block_q4_0, 8, 16>(struct ggml_tensor * t, const void * data, size_t data_size) {
    return repack_q4_0_to_q4_0_16_bl(t, 16, data, data_size);
}

template <> int repack<block_q4_1, 8, 16>(struct ggml_tensor * t, const void * data, size_t data_size) {
    return repack_q4_1_to_q4_1_16_bl(t, 16, data, data_size);
}

template <> int repack<block_q4_K, 8, 16>(struct ggml_tensor * t, const void * data, size_t data_size) {
    return repack_q4_k_to_q4_1_16_bl(t, 16, data, data_size);
}

class tensor_traits_base : public ggml::cpu::tensor_traits {
  public:
    virtual int repack(struct ggml_tensor * t, const void * data, size_t data_size) = 0;
};

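// Traits for the repacked 4-bit weight types: reserve a Q8 quantization
// workspace for MUL_MAT, quantize the activations once per GEMM, then spread
// the (M, N) tile grid over the AI cores.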
template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> class tensor_traits : public tensor_traits_base {
    bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
        switch (op->op) {
            case GGML_OP_MUL_MAT:
                // conservative upper bound for the Q8 quantization workspace;
                // forward_mul_mat_q4 recomputes and checks the exact size
                size = ggml_row_size(GGML_TYPE_Q8_0, ggml_nelements(op->src[1])) * 4;
                size = ((size + QK4_0 - 1) / QK4_0) * (QK4_0 * sizeof(float) + sizeof(float));
                return true;
            default:
                // GGML_ABORT("fatal error");
                break;
        }
        return false;
    }

    bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) override {
        switch (op->op) {
            case GGML_OP_MUL_MAT:
                if (op->src[0]->type == GGML_TYPE_Q4_0 ||  //
                    op->src[0]->type == GGML_TYPE_Q4_1 ||  //
                    op->src[0]->type == GGML_TYPE_Q4_K) {
                    forward_mul_mat_q4(params, op);
                    return true;
                }
            default:
                // GGML_ABORT("fatal error");
                break;
        }
        return false;
    }

    void forward_mul_mat_q4(ggml_compute_params * params, ggml_tensor * op) {
        const ggml_tensor * src0 = op->src[0];
        const ggml_tensor * src1 = op->src[1];
        ggml_tensor * dst = op;

        GGML_TENSOR_BINARY_OP_LOCALS

        int ith = params->ith;
        int nth = params->nth;

        [[maybe_unused]] const enum ggml_type type = src0->type;

        void * w_data = (void *) src0->data;
        const float * feature = (const float *) src1->data;
        float * output = (float *) dst->data;

        const size_t batch_feature = ne12 * ne13;
        [[maybe_unused]] const size_t batch_weight = ne02 * ne03;
        const size_t gemm_m = ne11;
        const size_t gemm_k = ne10;
        const size_t gemm_n = ne01;

        GGML_ASSERT(batch_weight == 1);

        const size_t block_count_k = div_round_up(gemm_k, QK4_0);
        const size_t per_gemm_workspace_size = gemm_m * block_count_k * q8_blk_size(QK4_0);
        const size_t per_gemm_workspace_stride =
            div_round_up(per_gemm_workspace_size, alignof(uint64_t)) * alignof(uint64_t);
        const size_t gemm_workspace_size = batch_feature * per_gemm_workspace_stride;
        const size_t desired_wsize = gemm_workspace_size + alignof(uint64_t) - 1;

        if (ith == 0 && params->wsize < desired_wsize) {
            throw std::runtime_error("wsize less than desired_wsize");
        }

        std::vector<qnbitgemm_spacemit_ime_args> qnbitgemm_args(batch_feature);

        for (size_t i = 0; i < batch_feature; i++) {
            qnbitgemm_args[i].a_ptr = feature + gemm_m * gemm_k * i;
            qnbitgemm_args[i].lda = gemm_k;
            qnbitgemm_args[i].packed_quant_b_data = (const std::byte *) w_data;
            qnbitgemm_args[i].quant_b_scale = nullptr;

            if constexpr (std::is_same_v<BLOC_TYPE, block_q4_0>) {
                qnbitgemm_args[i].quant_b_zp = nullptr;
            } else {
                qnbitgemm_args[i].quant_b_zp = w_data;
            }

            qnbitgemm_args[i].bias = nullptr;
            qnbitgemm_args[i].c_ptr = output + gemm_m * gemm_n * i;
            qnbitgemm_args[i].ldc = gemm_n;
        }

        const uintptr_t ws_ptr = reinterpret_cast<uintptr_t>(params->wdata);
        void * ws = reinterpret_cast<void *>((ws_ptr + alignof(uint64_t) - 1) & (~(alignof(uint64_t) - 1)));
        const size_t quant_a_stride = block_count_k * q8_blk_size(QK4_0);

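        // Phase 1: all threads cooperatively quantize the activation rows into
        // Q8 blocks in the shared workspace, four rows per task when a full
        // block of rows is available.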
        {
            constexpr size_t block_size_m = 4;
            size_t per_gemm_block_count_m = div_round_up(gemm_m, block_size_m);
            int32_t task_count = batch_feature * per_gemm_block_count_m;
            int32_t task_per_thread = (task_count + nth - 1) / nth;
            int32_t start = ith * task_per_thread;
            int32_t end = std::min((ith + 1) * task_per_thread, task_count);
            for (int32_t compute_idx = start; compute_idx < end; compute_idx++) {
                int32_t gemm_idx = compute_idx / per_gemm_block_count_m;
                int32_t block_idx_in_gemm = compute_idx % per_gemm_block_count_m;
                int32_t m_idx = block_idx_in_gemm * block_size_m;
                const qnbitgemm_spacemit_ime_args & data = qnbitgemm_args[gemm_idx];
                int32_t rows_tobe_handled = (gemm_m - m_idx) > block_size_m ? block_size_m : (gemm_m - m_idx);

                if (rows_tobe_handled == block_size_m) {
                    const float * a_row_ptr = data.a_ptr + m_idx * data.lda;
                    std::byte * quant_a_row_ptr =
                        static_cast<std::byte *>(ws) + gemm_idx * per_gemm_workspace_stride + m_idx * quant_a_stride;
                    sqnbitgemm_spacemit_ime::ime1::quantize_a_4row_i8(QK4_0, a_row_ptr, gemm_k, quant_a_row_ptr);
                } else {
                    while (rows_tobe_handled) {
                        const float * a_row_ptr = data.a_ptr + m_idx * data.lda;
                        std::byte * quant_a_row_ptr = static_cast<std::byte *>(ws) +
                                                      gemm_idx * per_gemm_workspace_stride + m_idx * quant_a_stride;
                        sqnbitgemm_spacemit_ime::ime1::quantize_a_row_i8(QK4_0, a_row_ptr, gemm_k, quant_a_row_ptr);
                        rows_tobe_handled -= 1;
                        m_idx += 1;
                    }
                }
            }
        }

        ggml_barrier(params->threadpool);

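        // Phase 2: only the AI cores (assumed to be the first half of the
        // threads) run the GEMM kernels. M is split into 128-row strips and N
        // into QGEMM_STRIDEN_THREAD_ALIGN-aligned strips to form the tile grid.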
        if (ith >= ggml::cpu::riscv64_spacemit::num_ai_cores) {
            return;
        }
        nth = std::min(nth, int{ ggml::cpu::riscv64_spacemit::num_ai_cores });

        size_t threads_per_gemm = nth / batch_feature;
        constexpr size_t gemm_m_stride = 128;
        size_t nc = gemm_n;
        const size_t gemm_m_blocked = div_round_up(gemm_m, gemm_m_stride);
        const size_t max_nc = div_round_up(gemm_n * gemm_m_blocked, threads_per_gemm);
        if (max_nc < nc) {
            nc = std::min(nc, div_round_up(max_nc, QGEMM_STRIDEN_THREAD_ALIGN) * QGEMM_STRIDEN_THREAD_ALIGN);
        }
        const size_t gemm_n_stride = nc;
        const size_t thread_count_m = div_round_up(gemm_m, gemm_m_stride);
        const size_t thread_count_n = div_round_up(gemm_n, gemm_n_stride);
        threads_per_gemm = thread_count_m * thread_count_n;

        {
            int task_count = batch_feature * threads_per_gemm;
            int task_per_thread = (task_count + nth - 1) / nth;
            int start = ith * task_per_thread;
            int end = std::min((ith + 1) * task_per_thread, task_count);
            for (int compute_idx = start; compute_idx < end; compute_idx++) {
                const auto gemm_i = compute_idx / threads_per_gemm;
                const auto blk_i = compute_idx % threads_per_gemm;
                const auto * data = &qnbitgemm_args[gemm_i];

                const auto tid_n = blk_i / thread_count_m;
                const auto tid_m = blk_i % thread_count_m;

                const size_t m_start = tid_m * gemm_m_stride;
                const size_t m_count = std::min(gemm_m - m_start, (size_t) gemm_m_stride);

                const size_t n_start = tid_n * gemm_n_stride;
                const size_t n_count = std::min(gemm_n - n_start, (size_t) gemm_n_stride);

                void * per_gemm_ws = reinterpret_cast<std::byte *>(ws) + gemm_i * per_gemm_workspace_stride;

                sqnbitgemm_spacemit_ime_i8i4(QK4_0, gemm_k, data, per_gemm_ws, m_start, m_count, n_start, n_count);
            }
        }
    }

    int repack(struct ggml_tensor * t, const void * data, size_t data_size) override {
        GGML_LOG_DEBUG("%s: repack tensor %s with %s_%dx%d\n", __func__, t->name, ggml_type_name(t->type),
                       (int) NB_COLS, (int) INTER_SIZE);
        return ggml::cpu::riscv64_spacemit::repack<BLOC_TYPE, INTER_SIZE, NB_COLS>(t, data, data_size);
    }
};

class tensor_traits_common : public tensor_traits_base {
    bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
        switch (op->op) {
            case GGML_OP_NORM:
            case GGML_OP_RMS_NORM:
                size = 0;
                return true;
            default:
                // GGML_ABORT("fatal error");
                break;
        }
        return false;
    }

    bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) override {
        switch (op->op) {
            case GGML_OP_NORM:
                forward_norm_f32(params, op);
                return true;
            case GGML_OP_RMS_NORM:
                forward_rms_norm_f32(params, op);
                return true;
            default:
                // GGML_ABORT("fatal error");
                break;
        }
        return false;
    }

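    // Layer norm over the last dimension with RVV intrinsics: a first pass
    // streams the row through m4 registers accumulating sum and sum of squares
    // (staging the input into the output buffer along the way), a reduction
    // yields mean and inverse standard deviation, and a second pass normalizes.
    // p_gamma_data / p_beta_data are always null here, so only the plain path
    // is currently exercised; the affine branches are kept for future use.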
    void forward_norm_f32(ggml_compute_params * params, ggml_tensor * op) {
        const ggml_tensor * src0 = op->src[0];
        ggml_tensor * dst = op;
        GGML_ASSERT(ggml_are_same_shape(src0, dst));
        GGML_ASSERT(src0->nb[0] == sizeof(float));

        const int ith = params->ith;
        const int nth = params->nth;

        GGML_TENSOR_UNARY_OP_LOCALS

        float epsilon;
        memcpy(&epsilon, dst->op_params, sizeof(float));

        GGML_ASSERT(epsilon > 0.0f);

        auto * input = (float *) src0->data;
        auto * output = (float *) dst->data;

        const auto hidden_size = ne00;
        const auto task_count = ne01 * ne02 * ne03;
        const auto task_per_thread = (task_count + nth - 1) / nth;

        const auto task_begin = ith * task_per_thread;
        const auto task_end = std::min((ith + 1) * task_per_thread, task_count);

        for (auto task_idx = task_begin; task_idx < task_end; task_idx++) {
            auto offset = task_idx * hidden_size;
            auto * p_input = const_cast<float *>(input + offset);

            auto * p_output = output + offset;
            auto * p_temp_output = p_output;
            auto * p_gamma_data = (const float *) nullptr;
            auto * p_beta_data = (const float *) nullptr;
            size_t gvl = __riscv_vsetvlmax_e32m4();
            vfloat32m4_t sum = __riscv_vfmv_v_f_f32m4(0.f, gvl);
            vfloat32m4_t sum_sq = __riscv_vfmv_v_f_f32m4(0.f, gvl);
            int64_t length = hidden_size;
            while (length > 0) {
                gvl = __riscv_vsetvl_e32m4(length);
                // load data
                vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_input, gvl);

                sum = __riscv_vfadd_vv_f32m4(sum, src_data, gvl);
                sum_sq = __riscv_vfmacc_vv_f32m4(sum_sq, src_data, src_data, gvl);

                __riscv_vse32_v_f32m4(p_temp_output, src_data, gvl);

                p_input += gvl;
                p_temp_output += gvl;
                length -= gvl;
            }

            gvl = __riscv_vsetvlmax_e32m1();

            float mean = 0.f;
            vfloat32m1_t zero_v = __riscv_vfmv_v_f_f32m1(0.f, gvl);
            vfloat32m1_t mean_v =
                __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(sum, 0), __riscv_vget_v_f32m4_f32m1(sum, 1), gvl);
            mean_v = __riscv_vfadd_vv_f32m1(mean_v, __riscv_vget_v_f32m4_f32m1(sum, 2), gvl);
            mean_v = __riscv_vfadd_vv_f32m1(mean_v, __riscv_vget_v_f32m4_f32m1(sum, 3), gvl);
            mean_v = __riscv_vfredusum_vs_f32m1_f32m1(mean_v, zero_v, gvl);
            mean = __riscv_vfmv_f_s_f32m1_f32(mean_v);
            mean /= hidden_size;

            vfloat32m1_t mean_square_v = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(sum_sq, 0),
                                                                __riscv_vget_v_f32m4_f32m1(sum_sq, 1), gvl);
            mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 2), gvl);
            mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 3), gvl);
            mean_square_v = __riscv_vfredusum_vs_f32m1_f32m1(mean_square_v, zero_v, gvl);

            float mean_square = __riscv_vfmv_f_s_f32m1_f32(mean_square_v);
            mean_square /= hidden_size;
            mean_square = sqrt(mean_square - mean * mean + epsilon);

            mean_square = 1.0f / mean_square;
            length = hidden_size;
            p_temp_output = p_output;

            if (p_gamma_data == nullptr && p_beta_data == nullptr) {
                while (length > 0) {
                    gvl = __riscv_vsetvl_e32m4(length);
                    vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl);
                    src_data = __riscv_vfsub_vf_f32m4(src_data, mean, gvl);
                    src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl);
                    __riscv_vse32_v_f32m4(p_output, src_data, gvl);
                    p_temp_output += gvl;
                    p_output += gvl;
                    length -= gvl;
                }
            } else if (p_beta_data == nullptr) {
                while (length > 0) {
                    gvl = __riscv_vsetvl_e32m4(length);
                    vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl);
                    vfloat32m4_t gamma_data_v = __riscv_vle32_v_f32m4(p_gamma_data, gvl);
                    src_data = __riscv_vfsub_vf_f32m4(src_data, mean, gvl);
                    src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl);
                    src_data = __riscv_vfmul_vv_f32m4(src_data, gamma_data_v, gvl);
                    __riscv_vse32_v_f32m4(p_output, src_data, gvl);
                    p_temp_output += gvl;
                    p_output += gvl;
                    p_gamma_data += gvl;
                    length -= gvl;
                }
            } else if (p_gamma_data != nullptr) {
                while (length > 0) {
                    gvl = __riscv_vsetvl_e32m4(length);
                    vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl);
                    vfloat32m4_t gamma_data_v = __riscv_vle32_v_f32m4(p_gamma_data, gvl);
                    src_data = __riscv_vfsub_vf_f32m4(src_data, mean, gvl);
                    src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl);
                    src_data = __riscv_vfmul_vv_f32m4(src_data, gamma_data_v, gvl);
                    vfloat32m4_t beta_data_v = __riscv_vle32_v_f32m4(p_beta_data, gvl);
                    src_data = __riscv_vfadd_vv_f32m4(src_data, beta_data_v, gvl);
                    p_beta_data += gvl;
                    __riscv_vse32_v_f32m4(p_output, src_data, gvl);
                    p_temp_output += gvl;
                    p_output += gvl;
                    p_gamma_data += gvl;
                    length -= gvl;
                }
            }
        }
    }

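    // RMS norm: same two-pass structure as forward_norm_f32 but without mean
    // subtraction; each element is scaled by 1 / sqrt(mean(x^2) + eps).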
    void forward_rms_norm_f32(ggml_compute_params * params, ggml_tensor * op) {
        const ggml_tensor * src0 = op->src[0];
        ggml_tensor * dst = op;
        GGML_ASSERT(ggml_are_same_shape(src0, dst));
        GGML_ASSERT(src0->nb[0] == sizeof(float));

        const int ith = params->ith;
        const int nth = params->nth;

        GGML_TENSOR_UNARY_OP_LOCALS

        float epsilon;
        memcpy(&epsilon, dst->op_params, sizeof(float));

        GGML_ASSERT(epsilon > 0.0f);

        auto * input = (float *) src0->data;
        auto * output = (float *) dst->data;

        const auto hidden_size = ne00;
        const auto task_count = ne01 * ne02 * ne03;
        const auto task_per_thread = (task_count + nth - 1) / nth;

        const auto task_begin = ith * task_per_thread;
        const auto task_end = std::min((ith + 1) * task_per_thread, task_count);

        for (auto task_idx = task_begin; task_idx < task_end; task_idx++) {
            auto offset = task_idx * hidden_size;
            auto * p_input = const_cast<float *>(input + offset);
            auto * p_output = output + offset;
            auto * p_temp_output = p_output;
            auto * p_gamma_data = (const float *) nullptr;
            auto * p_beta_data = (const float *) nullptr;

            size_t gvl = __riscv_vsetvlmax_e32m4();
            // vfloat32m4_t sum = __riscv_vfmv_v_f_f32m4(0.f, gvl);
            vfloat32m4_t sum_sq = __riscv_vfmv_v_f_f32m4(0.f, gvl);
            int64_t length = hidden_size;
            while (length > 0) {
                gvl = __riscv_vsetvl_e32m4(length);
                // load data
                vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_input, gvl);

                sum_sq = __riscv_vfmacc_vv_f32m4(sum_sq, src_data, src_data, gvl);

                __riscv_vse32_v_f32m4(p_temp_output, src_data, gvl);

                p_input += gvl;
                p_temp_output += gvl;
                length -= gvl;
            }

            gvl = __riscv_vsetvlmax_e32m1();

            // float mean = 0.f;
            vfloat32m1_t zero_v = __riscv_vfmv_v_f_f32m1(0.f, gvl);

            vfloat32m1_t mean_square_v = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(sum_sq, 0),
                                                                __riscv_vget_v_f32m4_f32m1(sum_sq, 1), gvl);
            mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 2), gvl);
            mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 3), gvl);
            mean_square_v = __riscv_vfredusum_vs_f32m1_f32m1(mean_square_v, zero_v, gvl);

            float mean_square = __riscv_vfmv_f_s_f32m1_f32(mean_square_v);
            mean_square /= hidden_size;

            mean_square = sqrt(mean_square + epsilon);

            mean_square = 1.0f / mean_square;
            length = hidden_size;
            p_temp_output = p_output;

            if (p_gamma_data == nullptr && p_beta_data == nullptr) {
                while (length > 0) {
                    gvl = __riscv_vsetvl_e32m4(length);
                    vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl);
                    src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl);
                    __riscv_vse32_v_f32m4(p_output, src_data, gvl);
                    p_temp_output += gvl;
                    p_output += gvl;
                    length -= gvl;
                }
            } else if (p_beta_data == nullptr) {
                while (length > 0) {
                    gvl = __riscv_vsetvl_e32m4(length);
                    vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl);
                    vfloat32m4_t gamma_data_v = __riscv_vle32_v_f32m4(p_gamma_data, gvl);
                    src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl);
                    src_data = __riscv_vfmul_vv_f32m4(src_data, gamma_data_v, gvl);
                    __riscv_vse32_v_f32m4(p_output, src_data, gvl);
                    p_temp_output += gvl;
                    p_output += gvl;
                    p_gamma_data += gvl;
                    length -= gvl;
                }
            } else if (p_gamma_data != nullptr) {
                while (length > 0) {
                    gvl = __riscv_vsetvl_e32m4(length);
                    vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl);
                    vfloat32m4_t gamma_data_v = __riscv_vle32_v_f32m4(p_gamma_data, gvl);
                    src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl);
                    src_data = __riscv_vfmul_vv_f32m4(src_data, gamma_data_v, gvl);
                    vfloat32m4_t beta_data_v = __riscv_vle32_v_f32m4(p_beta_data, gvl);
                    src_data = __riscv_vfadd_vv_f32m4(src_data, beta_data_v, gvl);
                    p_beta_data += gvl;
                    __riscv_vse32_v_f32m4(p_output, src_data, gvl);
                    p_temp_output += gvl;
                    p_output += gvl;
                    p_gamma_data += gvl;
                    length -= gvl;
                }
            }
        }
    }

    int repack(struct ggml_tensor * t, const void * data, size_t data_size) override {
        memcpy(t->data, data, data_size);
        return 0;
    }
};

static const tensor_traits<block_q4_0, 8, 16> q4_0_16x8_q8_0;
static const tensor_traits<block_q4_1, 8, 16> q4_1_16x8_q8_0;
static const tensor_traits<block_q4_K, 8, 16> q4_k_16x8_q8_0;
static const tensor_traits_common rvv_impl;

}  // namespace ggml::cpu::riscv64_spacemit

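// Pick the repack layout for a tensor: the three 4-bit weight types map onto
// the 16-row interleaved layouts (when the row count is divisible by 16), and
// f32 tensors get the RVV norm implementation.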
static const ggml::cpu::tensor_traits * ggml_riscv64_spacemit_get_optimal_repack_type(const struct ggml_tensor * cur) {
    if (cur->type == GGML_TYPE_Q4_0) {
        if (cur->ne[1] % 16 == 0) {
            return &ggml::cpu::riscv64_spacemit::q4_0_16x8_q8_0;
        }
    } else if (cur->type == GGML_TYPE_Q4_1) {
        if (cur->ne[1] % 16 == 0) {
            return &ggml::cpu::riscv64_spacemit::q4_1_16x8_q8_0;
        }
    } else if (cur->type == GGML_TYPE_Q4_K) {
        if (cur->ne[1] % 16 == 0) {
            return &ggml::cpu::riscv64_spacemit::q4_k_16x8_q8_0;
        }
    } else if (cur->type == GGML_TYPE_F32) {
        return &ggml::cpu::riscv64_spacemit::rvv_impl;
    }

    return nullptr;
}

static enum ggml_status ggml_backend_riscv64_spacemit_buffer_init_tensor(ggml_backend_buffer_t buffer,
                                                                         struct ggml_tensor * tensor) {
    tensor->extra =
        (void *) const_cast<ggml::cpu::tensor_traits *>(ggml_riscv64_spacemit_get_optimal_repack_type(tensor));

    GGML_UNUSED(buffer);

    return GGML_STATUS_SUCCESS;
}

static void ggml_backend_riscv64_spacemit_buffer_set_tensor(ggml_backend_buffer_t buffer,
                                                            struct ggml_tensor * tensor,
                                                            const void * data,
                                                            size_t offset,
                                                            size_t size) {
    GGML_ASSERT(offset == 0);
    GGML_ASSERT(size == ggml_nbytes(tensor));

    auto tensor_traits = (ggml::cpu::riscv64_spacemit::tensor_traits_base *) tensor->extra;
    if (tensor_traits) {
        auto OK = tensor_traits->repack(tensor, data, size);
        GGML_ASSERT(OK == 0);
    }

    GGML_UNUSED(buffer);
}

static const char * ggml_backend_cpu_riscv64_spacemit_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
    return "CPU_RISCV64_SPACEMIT";

    GGML_UNUSED(buft);
}

static ggml_backend_buffer_t ggml_backend_cpu_riscv64_spacemit_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,
                                                                                        size_t size) {
    ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);

    if (buffer == nullptr) {
        return nullptr;
    }

    buffer->buft = buft;
    buffer->iface.init_tensor = ggml_backend_riscv64_spacemit_buffer_init_tensor;
    buffer->iface.set_tensor = ggml_backend_riscv64_spacemit_buffer_set_tensor;
    buffer->iface.get_tensor = nullptr;
    buffer->iface.cpy_tensor = nullptr;
    return buffer;
}

static size_t ggml_backend_cpu_riscv64_spacemit_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
    return 64;

    GGML_UNUSED(buft);
}

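// Allocation size for repacked tensors. A q4_K super-block is stored as
// 8 block_q4_1 blocks after repacking, so its footprint is scaled accordingly;
// all other types keep their canonical ggml size.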
static size_t ggml_backend_cpu_riscv64_spacemit_nbytes(ggml_backend_buffer_type_t buft,
                                                       const struct ggml_tensor * tensor) {
    for (int i = 0; i < GGML_MAX_DIMS; ++i) {
        if (tensor->ne[i] <= 0) {
            return 0;
        }
    }

    size_t nbytes;
    const size_t blck_size = ggml_blck_size(tensor->type);
    if (blck_size == 1) {
        nbytes = ggml_type_size(tensor->type);
        for (int i = 0; i < GGML_MAX_DIMS; ++i) {
            nbytes += (tensor->ne[i] - 1) * tensor->nb[i];
        }
    } else {
        nbytes = tensor->ne[0] * tensor->nb[0] / blck_size;
        if (tensor->type == GGML_TYPE_Q4_K) {
            GGML_ASSERT(nbytes % sizeof(block_q4_K) == 0);
            nbytes = (nbytes / sizeof(block_q4_K)) * sizeof(block_q4_1) * 8;
            for (int i = 1; i < GGML_MAX_DIMS; ++i) {
                nbytes += (tensor->ne[i] - 1) * (tensor->nb[i] / sizeof(block_q4_K)) * sizeof(block_q4_1) * 8;
            }
        } else {
            for (int i = 1; i < GGML_MAX_DIMS; ++i) {
                nbytes += (tensor->ne[i] - 1) * tensor->nb[i];
            }
        }
    }

    GGML_UNUSED(buft);
    return nbytes;
}

namespace ggml::cpu::riscv64_spacemit {

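// Extra buffer type hook: advertises which ops this backend path can take over
// (2-D MUL_MAT on repacked weights with f32 activations, and f32 NORM /
// RMS_NORM) and returns the matching tensor_traits for scheduling.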
class extra_buffer_type : ggml::cpu::extra_buffer_type {
    bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
        switch (op->op) {
            case GGML_OP_MUL_MAT:
                if (op->src[0]->buffer && (ggml_n_dims(op->src[0]) == 2) &&
                    op->src[0]->buffer->buft == ggml_backend_cpu_riscv64_spacemit_buffer_type() &&
                    ggml_riscv64_spacemit_get_optimal_repack_type(op->src[0])) {
                    if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
                        return false;
                    }
                    if (op->src[1]->type == GGML_TYPE_F32) {
                        return true;
                    }
                }
                break;
            case GGML_OP_NORM:
            case GGML_OP_RMS_NORM:
                if (op->src[0]->type == GGML_TYPE_F32) {
                    return true;
                }
                break;
            default:
                // GGML_ABORT("fatal error");
                break;
        }
        return false;
    }

    ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override {
        switch (op->op) {
            case GGML_OP_MUL_MAT:
                if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_riscv64_spacemit_buffer_type()) {
                    return (ggml::cpu::tensor_traits *) op->src[0]->extra;
                }
                break;
            case GGML_OP_NORM:
            case GGML_OP_RMS_NORM:
                return (ggml::cpu::tensor_traits *) (&ggml::cpu::riscv64_spacemit::rvv_impl);
            default:
                // GGML_ABORT("fatal error");
                break;
        }

        return nullptr;
    }
};

}  // namespace ggml::cpu::riscv64_spacemit

ggml_backend_buffer_type_t ggml_backend_cpu_riscv64_spacemit_buffer_type(void) {
    static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type_riscv64_spacemit = {
        /* .iface   = */
        {
            /* .get_name       = */ ggml_backend_cpu_riscv64_spacemit_buffer_type_get_name,
            /* .alloc_buffer   = */ ggml_backend_cpu_riscv64_spacemit_buffer_type_alloc_buffer,
            /* .get_alignment  = */ ggml_backend_cpu_riscv64_spacemit_buffer_type_get_alignment,
            /* .get_max_size   = */ nullptr,
            /* .get_alloc_size = */ ggml_backend_cpu_riscv64_spacemit_nbytes,
            /* .is_host        = */ nullptr,
        },
        /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
        /* .context = */ new ggml::cpu::riscv64_spacemit::extra_buffer_type(),
    };

    return &ggml_backend_cpu_buffer_type_riscv64_spacemit;
}