Diffstat (limited to 'llama.cpp/ggml/src/ggml-cpu/spacemit')
-rw-r--r--  llama.cpp/ggml/src/ggml-cpu/spacemit/ime.cpp           1025
-rw-r--r--  llama.cpp/ggml/src/ggml-cpu/spacemit/ime.h               13
-rw-r--r--  llama.cpp/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp   3196
-rw-r--r--  llama.cpp/ggml/src/ggml-cpu/spacemit/ime_kernels.h        26
4 files changed, 4260 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-cpu/spacemit/ime.cpp b/llama.cpp/ggml/src/ggml-cpu/spacemit/ime.cpp
new file mode 100644
index 0000000..91fe192
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cpu/spacemit/ime.cpp
@@ -0,0 +1,1025 @@
+#define GGML_COMMON_IMPL_CPP
+#define GGML_COMMON_DECL_CPP
+
+#include "ime.h"
+
+#include "ggml-backend-impl.h"
+#include "ggml-common.h"
+#include "ggml-cpu.h"
+#include "ime_kernels.h"
+#include "traits.h"
+
+#include <algorithm>
+#include <cassert>
+#include <cmath>
+#include <cstdio> // for GGML_ASSERT
+#include <stdexcept>
+#include <thread>
+
+// clang-format off
+#if defined(__riscv)
+
+#if !defined(__riscv_v) || !defined(__riscv_v_intrinsic)
+#error "riscv v extension or v_intrinsic not enabled"
+#else
+#include <riscv_vector.h>
+#endif
+
+#if !defined(__riscv_zfh)
+#error "riscv zfh extension not enabled"
+#endif
+
+#if defined(RISCV64_SPACEMIT_IME1)
+#else
+#error "RISCV64_SPACEMIT_IME1 not defined"
+#endif
+
+#else
+
+#error "riscv not enabled in this build"
+
+#endif
+
+#if defined(__GNUC__)
+#pragma GCC diagnostic ignored "-Woverlength-strings"
+#pragma GCC diagnostic ignored "-Wcast-qual"
+#pragma GCC diagnostic ignored "-Wunused-parameter"
+#endif
+
+#if defined(RISCV64_SPACEMIT_IME1)
+#define QGEMM_STRIDEN_THREAD_ALIGN 16
+#else
+#define QGEMM_STRIDEN_THREAD_ALIGN 32
+#endif
+
+// clang-format on
+
+struct qnbitgemm_spacemit_ime_args {
+ const float * a_ptr = nullptr;
+ size_t lda = 0;
+ const std::byte * packed_quant_b_data = nullptr;
+ const float * quant_b_scale = nullptr;
+ const void * quant_b_zp = nullptr;
+ const float * quant_b_blksum = nullptr;
+ const float * bias = nullptr;
+ float * c_ptr = nullptr;
+ size_t ldc = 0;
+};
+
+constexpr size_t div_round_up(size_t up, size_t down) {
+ return (up + down - 1) / down;
+}
+
+constexpr size_t q8_blk_size(size_t blk_len) {
+ const size_t blk_size = sizeof(float) + blk_len * sizeof(int8_t);
+ // Currently, the strictest alignment requirement of a block is for a float.
+ // Ensure contiguous blocks are suitably aligned.
+ assert(blk_size % alignof(float) == 0);
+ return blk_size;
+}
+
+namespace ggml::cpu::riscv64_spacemit {
+
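+// Assumption baked into this backend: half of the hardware threads correspond to
+// IME-capable (AI) cores; the GEMM phase below is restricted to that many threads.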
+const int num_ai_cores = std::thread::hardware_concurrency() / 2;
+
+} // namespace ggml::cpu::riscv64_spacemit
+
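+// Compute one tile of the int8 x int4 GEMM: rows [m_start, m_start + m_count) of the
+// already-quantized activations against columns [n_start, n_start + n_count) of the
+// packed 4-bit weights, accumulating into c_ptr.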
+static void sqnbitgemm_spacemit_ime_i8i4(const size_t blk_len,
+ const size_t gemm_k,
+ const qnbitgemm_spacemit_ime_args * gemm_args,
+ void * const per_gemm_ws,
+ const size_t m_start,
+ const size_t m_count,
+ const size_t n_start,
+ const size_t n_count) {
+ constexpr size_t scale_stride = sizeof(uint16_t);
+ constexpr size_t blk_bitwidth = 4;
+
+ const size_t k_blks = div_round_up(gemm_k, blk_len);
+
+ const size_t lda = k_blks * q8_blk_size(blk_len);
+ const size_t ldc = gemm_args->ldc;
+ const size_t ldb = k_blks * (blk_len * blk_bitwidth / 8);
+ const std::byte * quant_a_ptr = static_cast<const std::byte *>(per_gemm_ws) + m_start * lda;
+
+ const size_t zero_point_stride = gemm_args->quant_b_zp != nullptr ? sizeof(uint8_t) : 0;
+ const size_t packed_b_stride = ldb + k_blks * (scale_stride + zero_point_stride);
+ const std::byte * packed_quant_b_data = gemm_args->packed_quant_b_data + n_start * packed_b_stride;
+
+ float * c_ptr = gemm_args->c_ptr + m_start * ldc + n_start;
+
+ size_t count_n = 0;
+ const size_t compute_block_count_n = m_count == 1 ? n_count : 16;
+ for (size_t n = 0; n < n_count; n += count_n) {
+ count_n = std::min(n_count - n, compute_block_count_n);
+
+ const std::byte * a_row = quant_a_ptr;
+ const std::byte * b_col = packed_quant_b_data + n * packed_b_stride;
+ const std::byte * b_col_zp = (zero_point_stride != 0) ? b_col : nullptr;
+ float * c_blk = c_ptr + n;
+
+ int32_t rows_remaining = m_count;
+
+ while (rows_remaining > 0) {
+ const auto rows_handled = sqnbitgemm_spacemit_ime::ime1::gemm_kernel_i8i4(
+ blk_len, a_row, b_col, nullptr, b_col_zp, c_blk, rows_remaining, count_n, gemm_k, k_blks, ldc, nullptr,
+ scale_stride);
+
+ c_blk += rows_handled * ldc;
+ a_row += rows_handled * lda;
+
+ rows_remaining -= rows_handled;
+ }
+ }
+}
+
+template <int K> constexpr int QK_0() {
+ if constexpr (K == 4) {
+ return QK4_0;
+ }
+ if constexpr (K == 8) {
+ return QK8_0;
+ }
+ return -1;
+}
+
+template <int K, int N> struct block {
+ ggml_half d[N]; // deltas for N qK_0 blocks
+ uint8_t qs[(QK_0<K>() * N * K) / 8]; // quants for N qK_0 blocks
+};
+
+template <int K, int N> struct block_with_zp {
+ ggml_half d[N]; // deltas for N qK_1 blocks
+ uint8_t zp[N]; // zero points for N qK_1 blocks
+ uint8_t qs[(QK_0<K>() * N * K) / 8]; // quants for N qK_1 blocks
+};
+
+// compile-time checks on the packed block sizes
+static_assert(sizeof(block<4, 16>) == 16 * sizeof(ggml_half) + QK4_0 * 8, "wrong block<4,16> size/padding");
+static_assert(sizeof(block_with_zp<4, 16>) == 16 * sizeof(ggml_half) + QK4_0 * 8 + 16 * sizeof(uint8_t),
+ "wrong block_with_zp<4,16> size/padding");
+static_assert(sizeof(block<8, 16>) == 16 * sizeof(ggml_half) + QK4_0 * 16, "wrong block<8,16> size/padding");
+
+using block_q4_0x16 = block<4, 16>;
+using block_q4_1x16 = block_with_zp<4, 16>;
+using block_q8_0x16 = block<8, 16>;
+
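+// Interleave one q4_0 block from each of 16 consecutive rows into a single
+// block_q4_0x16: the low nibbles (elements 0..15 of each row) are packed first,
+// the high nibbles (elements 16..31) second, as sketched in the src/dst comments.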
+static block_q4_0x16 make_block_q4_0x16(block_q4_0 * in, unsigned int blck_size_interleave) {
+ block_q4_0x16 out;
+ GGML_ASSERT(QK4_0 / blck_size_interleave == 2);
+
+ for (int i = 0; i < 16; i++) {
+ out.d[i] = in[i].d;
+ }
+
+ for (int i = 0; i < 16; i++) {
+        // [0, 15], in.qs & 0x0F
+ for (int j = 0; j < QK4_0 / 4; j++) {
+ //src [b0 b16] ......... [b8 b24] ......... [b15 b31]
+ //dst [b0 b8] ......... [b7 b15]
+ out.qs[i * QK4_0 / 4 + j] = (in[i].qs[j] & 0x0F) | ((in[i].qs[j + QK4_0 / 4] & 0x0F) << 4);
+ }
+ }
+
+ for (int i = 0; i < 16; i++) {
+        // [16, 31], in.qs & 0xF0
+ for (int j = 0; j < QK4_0 / 4; j++) {
+ //src [b0 b16] ......... [b8 b24] ......... [b15 b31]
+ //dst [b16 b24] ......... [b23 b31]
+ out.qs[4 * QK4_0 + i * QK4_0 / 4 + j] = ((in[i].qs[j] & 0xF0) >> 4) | (in[i].qs[j + QK4_0 / 4] & 0xF0);
+ }
+ }
+
+ return out;
+}
+
+static block_q4_1x16 make_block_q4_1x16(block_q4_1 * in, unsigned int blck_size_interleave) {
+ block_q4_1x16 out;
+ GGML_ASSERT(QK4_1 / blck_size_interleave == 2);
+
+ for (int i = 0; i < 16; i++) {
+ float d = GGML_FP16_TO_FP32(in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d);
+ float m = GGML_FP16_TO_FP32(in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.m);
+ float mid = -std::nearbyintf(m / d);
+ mid = std::min(15.0f, std::max(0.0f, mid));
+ out.d[i] = GGML_FP32_TO_FP16(d);
+ out.zp[i] = static_cast<uint8_t>(mid);
+ }
+
+ for (int i = 0; i < 16; i++) {
+        // [0, 15], in.qs & 0x0F
+ for (int j = 0; j < QK4_1 / 4; j++) {
+ //src [b0 b16] ......... [b8 b24] ......... [b15 b31]
+ //dst [b0 b8] ......... [b7 b15]
+ out.qs[i * QK4_1 / 4 + j] = (in[i].qs[j] & 0x0F) | ((in[i].qs[j + QK4_1 / 4] & 0x0F) << 4);
+ }
+ }
+
+ for (int i = 0; i < 16; i++) {
+        // [16, 31], in.qs & 0xF0
+ for (int j = 0; j < QK4_1 / 4; j++) {
+ //src [b0 b16] ......... [b8 b24] ......... [b15 b31]
+ //dst [b16 b24] ......... [b23 b31]
+ out.qs[4 * QK4_1 + i * QK4_1 / 4 + j] = ((in[i].qs[j] & 0xF0) >> 4) | (in[i].qs[j + QK4_1 / 4] & 0xF0);
+ }
+ }
+
+ return out;
+}
+
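+// Repack a q4_0 weight tensor from its row-major block layout into the 16-row
+// interleaved block_q4_0x16 layout used by the IME kernels; returns -1 when the
+// tensor shape is not a multiple of the 16-row / QK4_0-column tile.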
+static int repack_q4_0_to_q4_0_16_bl(struct ggml_tensor * t,
+ int interleave_block,
+ const void * GGML_RESTRICT data,
+ size_t data_size) {
+ GGML_ASSERT(t->type == GGML_TYPE_Q4_0);
+ GGML_ASSERT(interleave_block == 16);
+
+ constexpr int nrows_interleaved = 16;
+
+ block_q4_0x16 * dst = (block_q4_0x16 *) t->data;
+ const block_q4_0 * src = (const block_q4_0 *) data;
+ block_q4_0 dst_tmp[16];
+ int nrow = ggml_nrows(t);
+ int nblocks = t->ne[0] / QK4_0;
+
+ GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0));
+
+ if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_0 != 0) {
+ return -1;
+ }
+
+ for (int b = 0; b < nrow; b += nrows_interleaved) {
+ for (int64_t x = 0; x < nblocks; x++) {
+ for (int i = 0; i < nrows_interleaved; i++) {
+ dst_tmp[i] = src[x + i * nblocks];
+ }
+ *dst++ = make_block_q4_0x16(dst_tmp, interleave_block);
+ }
+ src += nrows_interleaved * nblocks;
+ }
+ return 0;
+
+ GGML_UNUSED(data_size);
+}
+
+static int repack_q4_1_to_q4_1_16_bl(struct ggml_tensor * t,
+ int interleave_block,
+ const void * GGML_RESTRICT data,
+ size_t data_size) {
+ GGML_ASSERT(t->type == GGML_TYPE_Q4_1);
+ GGML_ASSERT(interleave_block == 16);
+
+ constexpr int nrows_interleaved = 16;
+
+ block_q4_1x16 * dst = (block_q4_1x16 *) t->data;
+ const block_q4_1 * src = (const block_q4_1 *) data;
+ block_q4_1 dst_tmp[16];
+ int nrow = ggml_nrows(t);
+ int nblocks = t->ne[0] / QK4_1;
+
+ GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_1));
+
+ if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_1 != 0) {
+ return -1;
+ }
+
+ for (int b = 0; b < nrow; b += nrows_interleaved) {
+ for (int64_t x = 0; x < nblocks; x++) {
+ for (int i = 0; i < nrows_interleaved; i++) {
+ dst_tmp[i] = src[x + i * nblocks];
+ }
+ *dst++ = make_block_q4_1x16(dst_tmp, interleave_block);
+ }
+ src += nrows_interleaved * nblocks;
+ }
+ return 0;
+
+ GGML_UNUSED(data_size);
+}
+
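+// Unpack the 6-bit scale and min of sub-block j from the packed q4_K scales array
+// (same layout as the reference q4_K dequantization).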
+static inline void get_scale_min_k4(int j,
+ const uint8_t * GGML_RESTRICT q,
+ uint8_t * GGML_RESTRICT d,
+ uint8_t * GGML_RESTRICT m) {
+ if (j < 4) {
+ *d = q[j] & 63;
+ *m = q[j + 4] & 63;
+ } else {
+ *d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4);
+ *m = (q[j + 4] >> 4) | ((q[j - 0] >> 6) << 4);
+ }
+}
+
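+// Repack q4_K: each super-block is split into 8 q4_1-style sub-blocks (scale d*sc and
+// offset min*m, stored with a negated m so the q4_1 path can reuse it), then the rows
+// are interleaved 16 at a time like the q4_1 case.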
+static int repack_q4_k_to_q4_1_16_bl(struct ggml_tensor * t,
+ int interleave_block,
+ const void * GGML_RESTRICT data,
+ size_t data_size) {
+ GGML_ASSERT(t->type == GGML_TYPE_Q4_K);
+ GGML_ASSERT(interleave_block == 16);
+ GGML_ASSERT(QK_K / QK4_1 == 8);
+
+ constexpr int nrows_interleaved = 16;
+
+ block_q4_1x16 * dst = (block_q4_1x16 *) t->data;
+ const block_q4_K * src = (const block_q4_K *) data;
+ block_q4_1 dst_tmp[16];
+ int nrow = ggml_nrows(t);
+ int nblocks = t->ne[0] / QK_K;
+
+ if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK_K != 0) {
+ return -1;
+ }
+
+ for (int b = 0; b < nrow; b += nrows_interleaved) {
+ for (int64_t x = 0; x < nblocks; x++) {
+ for (int j = 0; j < 8; j++) {
+ for (int i = 0; i < nrows_interleaved; i++) {
+ uint8_t sc, m;
+ const float d = GGML_FP16_TO_FP32(src[x + i * nblocks].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d);
+ const float min =
+ GGML_FP16_TO_FP32(src[x + i * nblocks].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin);
+ get_scale_min_k4(j, src[x + i * nblocks].scales, &sc, &m);
+ const float d1 = d * sc;
+ const float m1 = min * m;
+
+ dst_tmp[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d = GGML_FP32_TO_FP16(d1);
+ dst_tmp[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.m = GGML_FP32_TO_FP16(-m1);
+ // src -> [b0, b32] [b1, b33] ... [b31, b63]
+ // dst -> [b0, b16] [b1, b17] ... [b15, b31] [b32, b48] [b33, b49] ... [b47, b63]
+ const uint8_t * q = src[x + i * nblocks].qs + (j / 2) * QK4_1;
+ if (j % 2 == 0) {
+ for (int ii = 0; ii < 16; ii++) {
+ dst_tmp[i].qs[ii] = (q[ii] & 0x0F) | ((q[ii + 16] & 0x0F) << 4);
+ }
+ } else {
+ for (int ii = 0; ii < 16; ii++) {
+ dst_tmp[i].qs[ii] = ((q[ii] & 0xF0) >> 4) | (q[ii + 16] & 0xF0);
+ }
+ }
+ }
+ *dst++ = make_block_q4_1x16(dst_tmp, interleave_block);
+ }
+ }
+ src += nrows_interleaved * nblocks;
+ }
+ return 0;
+
+ GGML_UNUSED(data_size);
+}
+
+namespace ggml::cpu::riscv64_spacemit {
+
+template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS>
+int repack(struct ggml_tensor *, const void *, size_t);
+
+template <> int repack<block_q4_0, 8, 16>(struct ggml_tensor * t, const void * data, size_t data_size) {
+ return repack_q4_0_to_q4_0_16_bl(t, 16, data, data_size);
+}
+
+template <> int repack<block_q4_1, 8, 16>(struct ggml_tensor * t, const void * data, size_t data_size) {
+ return repack_q4_1_to_q4_1_16_bl(t, 16, data, data_size);
+}
+
+template <> int repack<block_q4_K, 8, 16>(struct ggml_tensor * t, const void * data, size_t data_size) {
+ return repack_q4_k_to_q4_1_16_bl(t, 16, data, data_size);
+}
+
+class tensor_traits_base : public ggml::cpu::tensor_traits {
+ public:
+ virtual int repack(struct ggml_tensor * t, const void * data, size_t data_size) = 0;
+};
+
+template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> class tensor_traits : public tensor_traits_base {
+ bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
+ switch (op->op) {
+ case GGML_OP_MUL_MAT:
+ size = ggml_row_size(GGML_TYPE_Q8_0, ggml_nelements(op->src[1])) * 4;
+ size = ((size + QK4_0 - 1) / QK4_0) * (QK4_0 * sizeof(float) + sizeof(float));
+ return true;
+ default:
+ // GGML_ABORT("fatal error");
+ break;
+ }
+ return false;
+ }
+
+ bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) override {
+ switch (op->op) {
+ case GGML_OP_MUL_MAT:
+ if (op->src[0]->type == GGML_TYPE_Q4_0 || //
+ op->src[0]->type == GGML_TYPE_Q4_1 || //
+ op->src[0]->type == GGML_TYPE_Q4_K) {
+ forward_mul_mat_q4(params, op);
+ return true;
+ }
+ default:
+ // GGML_ABORT("fatal error");
+ break;
+ }
+ return false;
+ }
+
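+    // Two phases: every thread quantizes its share of src1 rows into q8 blocks in the
+    // shared workspace, then after the barrier only the first num_ai_cores threads run
+    // the i8i4 GEMM over [m, n] tiles.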
+ void forward_mul_mat_q4(ggml_compute_params * params, ggml_tensor * op) {
+ const ggml_tensor * src0 = op->src[0];
+ const ggml_tensor * src1 = op->src[1];
+ ggml_tensor * dst = op;
+
+ GGML_TENSOR_BINARY_OP_LOCALS
+
+ int ith = params->ith;
+ int nth = params->nth;
+
+ [[maybe_unused]] const enum ggml_type type = src0->type;
+
+ void * w_data = (void *) src0->data;
+ const float * feature = (const float *) src1->data;
+ float * output = (float *) dst->data;
+
+ const size_t batch_feature = ne12 * ne13;
+ [[maybe_unused]] const size_t batch_weight = ne02 * ne03;
+ const size_t gemm_m = ne11;
+ const size_t gemm_k = ne10;
+ const size_t gemm_n = ne01;
+
+ GGML_ASSERT(batch_weight == 1);
+
+ const size_t block_count_k = div_round_up(gemm_k, QK4_0);
+ const size_t per_gemm_workspace_size = gemm_m * block_count_k * q8_blk_size(QK4_0);
+ const size_t per_gemm_workspace_stride =
+ div_round_up(per_gemm_workspace_size, alignof(uint64_t)) * alignof(uint64_t);
+ const size_t gemm_workspace_size = batch_feature * per_gemm_workspace_stride;
+ const size_t desired_wsize = gemm_workspace_size + alignof(uint64_t) - 1;
+
+ if (ith == 0 && params->wsize < desired_wsize) {
+ throw std::runtime_error("wsize less than desired_wsize");
+ }
+
+ std::vector<qnbitgemm_spacemit_ime_args> qnbitgemm_args(batch_feature);
+
+ for (size_t i = 0; i < batch_feature; i++) {
+ qnbitgemm_args[i].a_ptr = feature + gemm_m * gemm_k * i;
+ qnbitgemm_args[i].lda = gemm_k;
+ qnbitgemm_args[i].packed_quant_b_data = (const std::byte *) w_data;
+ qnbitgemm_args[i].quant_b_scale = nullptr;
+
+ if constexpr (std::is_same_v<BLOC_TYPE, block_q4_0>) {
+ qnbitgemm_args[i].quant_b_zp = nullptr;
+ } else {
+ qnbitgemm_args[i].quant_b_zp = w_data;
+ }
+
+ qnbitgemm_args[i].bias = nullptr;
+ qnbitgemm_args[i].c_ptr = output + gemm_m * gemm_n * i;
+ qnbitgemm_args[i].ldc = gemm_n;
+ }
+
+ const uintptr_t ws_ptr = reinterpret_cast<uintptr_t>(params->wdata);
+ void * ws = reinterpret_cast<void *>((ws_ptr + alignof(uint64_t) - 1) & (~(alignof(uint64_t) - 1)));
+ const size_t quant_a_stride = block_count_k * q8_blk_size(QK4_0);
+
+ {
+ constexpr size_t block_size_m = 4;
+ size_t per_gemm_block_count_m = div_round_up(gemm_m, block_size_m);
+ int32_t task_count = batch_feature * per_gemm_block_count_m;
+ int32_t task_per_thread = (task_count + nth - 1) / nth;
+ int32_t start = ith * task_per_thread;
+ int32_t end = std::min((ith + 1) * task_per_thread, task_count);
+ for (int32_t compute_idx = start; compute_idx < end; compute_idx++) {
+ int32_t gemm_idx = compute_idx / per_gemm_block_count_m;
+ int32_t block_idx_in_gemm = compute_idx % per_gemm_block_count_m;
+ int32_t m_idx = block_idx_in_gemm * block_size_m;
+ const qnbitgemm_spacemit_ime_args & data = qnbitgemm_args[gemm_idx];
+ int32_t rows_tobe_handled = (gemm_m - m_idx) > block_size_m ? block_size_m : (gemm_m - m_idx);
+
+ if (rows_tobe_handled == block_size_m) {
+ const float * a_row_ptr = data.a_ptr + m_idx * data.lda;
+ std::byte * quant_a_row_ptr =
+ static_cast<std::byte *>(ws) + gemm_idx * per_gemm_workspace_stride + m_idx * quant_a_stride;
+ sqnbitgemm_spacemit_ime::ime1::quantize_a_4row_i8(QK4_0, a_row_ptr, gemm_k, quant_a_row_ptr);
+ } else {
+ while (rows_tobe_handled) {
+ const float * a_row_ptr = data.a_ptr + m_idx * data.lda;
+ std::byte * quant_a_row_ptr = static_cast<std::byte *>(ws) +
+ gemm_idx * per_gemm_workspace_stride + m_idx * quant_a_stride;
+ sqnbitgemm_spacemit_ime::ime1::quantize_a_row_i8(QK4_0, a_row_ptr, gemm_k, quant_a_row_ptr);
+ rows_tobe_handled -= 1;
+ m_idx += 1;
+ }
+ }
+ }
+ }
+
+ ggml_barrier(params->threadpool);
+
+ if (ith >= ggml::cpu::riscv64_spacemit::num_ai_cores) {
+ return;
+ }
+ nth = std::min(nth, int{ ggml::cpu::riscv64_spacemit::num_ai_cores });
+
+ size_t threads_per_gemm = nth / batch_feature;
+ constexpr size_t gemm_m_stride = 128;
+ size_t nc = gemm_n;
+ const size_t gemm_m_blocked = div_round_up(gemm_m, gemm_m_stride);
+ const size_t max_nc = div_round_up(gemm_n * gemm_m_blocked, threads_per_gemm);
+ if (max_nc < nc) {
+ nc = std::min(nc, div_round_up(max_nc, QGEMM_STRIDEN_THREAD_ALIGN) * QGEMM_STRIDEN_THREAD_ALIGN);
+ }
+ const size_t gemm_n_stride = nc;
+ const size_t thread_count_m = div_round_up(gemm_m, gemm_m_stride);
+ const size_t thread_count_n = div_round_up(gemm_n, gemm_n_stride);
+ threads_per_gemm = thread_count_m * thread_count_n;
+
+ {
+ int task_count = batch_feature * threads_per_gemm;
+ int task_per_thread = (task_count + nth - 1) / nth;
+ int start = ith * task_per_thread;
+ int end = std::min((ith + 1) * task_per_thread, task_count);
+ for (int compute_idx = start; compute_idx < end; compute_idx++) {
+ const auto gemm_i = compute_idx / threads_per_gemm;
+ const auto blk_i = compute_idx % threads_per_gemm;
+ const auto * data = &qnbitgemm_args[gemm_i];
+
+ const auto tid_n = blk_i / thread_count_m;
+ const auto tid_m = blk_i % thread_count_m;
+
+ const size_t m_start = tid_m * gemm_m_stride;
+ const size_t m_count = std::min(gemm_m - m_start, (size_t) gemm_m_stride);
+
+ const size_t n_start = tid_n * gemm_n_stride;
+ const size_t n_count = std::min(gemm_n - n_start, (size_t) gemm_n_stride);
+
+ void * per_gemm_ws = reinterpret_cast<std::byte *>(ws) + gemm_i * per_gemm_workspace_stride;
+
+ sqnbitgemm_spacemit_ime_i8i4(QK4_0, gemm_k, data, per_gemm_ws, m_start, m_count, n_start, n_count);
+ }
+ }
+ }
+
+ int repack(struct ggml_tensor * t, const void * data, size_t data_size) override {
+ GGML_LOG_DEBUG("%s: repack tensor %s with %s_%dx%d\n", __func__, t->name, ggml_type_name(t->type),
+ (int) NB_COLS, (int) INTER_SIZE);
+ return ggml::cpu::riscv64_spacemit::repack<BLOC_TYPE, INTER_SIZE, NB_COLS>(t, data, data_size);
+ }
+};
+
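+// F32 NORM / RMS_NORM implemented with RVV intrinsics. The gamma/beta pointers below
+// are always null here, so only the plain normalization path is exercised.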
+class tensor_traits_common : public tensor_traits_base {
+ bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
+ switch (op->op) {
+ case GGML_OP_NORM:
+ case GGML_OP_RMS_NORM:
+ size = 0;
+ return true;
+ default:
+ // GGML_ABORT("fatal error");
+ break;
+ }
+ return false;
+ }
+
+ bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) override {
+ switch (op->op) {
+ case GGML_OP_NORM:
+ forward_norm_f32(params, op);
+ return true;
+ case GGML_OP_RMS_NORM:
+ forward_rms_norm_f32(params, op);
+ return true;
+ default:
+ // GGML_ABORT("fatal error");
+ break;
+ }
+ return false;
+ }
+
+ void forward_norm_f32(ggml_compute_params * params, ggml_tensor * op) {
+ const ggml_tensor * src0 = op->src[0];
+ ggml_tensor * dst = op;
+ GGML_ASSERT(ggml_are_same_shape(src0, dst));
+ GGML_ASSERT(src0->nb[0] == sizeof(float));
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ GGML_TENSOR_UNARY_OP_LOCALS
+
+ float epsilon;
+ memcpy(&epsilon, dst->op_params, sizeof(float));
+
+ GGML_ASSERT(epsilon > 0.0f);
+
+ auto * input = (float *) src0->data;
+ auto * output = (float *) dst->data;
+
+ const auto hidden_size = ne00;
+ const auto task_count = ne01 * ne02 * ne03;
+ const auto task_per_thread = (task_count + nth - 1) / nth;
+
+ const auto task_begin = ith * task_per_thread;
+ const auto task_end = std::min((ith + 1) * task_per_thread, task_count);
+
+ for (auto task_idx = task_begin; task_idx < task_end; task_idx++) {
+ auto offset = task_idx * hidden_size;
+ auto * p_input = const_cast<float *>(input + offset);
+
+ auto * p_output = output + offset;
+ auto * p_temp_output = p_output;
+ auto * p_gamma_data = (const float *) nullptr;
+ auto * p_beta_data = (const float *) nullptr;
+ size_t gvl = __riscv_vsetvlmax_e32m4();
+ vfloat32m4_t sum = __riscv_vfmv_v_f_f32m4(0.f, gvl);
+ vfloat32m4_t sum_sq = __riscv_vfmv_v_f_f32m4(0.f, gvl);
+ int64_t length = hidden_size;
+ while (length > 0) {
+ gvl = __riscv_vsetvl_e32m4(length);
+ // load data
+ vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_input, gvl);
+
+ sum = __riscv_vfadd_vv_f32m4(sum, src_data, gvl);
+ sum_sq = __riscv_vfmacc_vv_f32m4(sum_sq, src_data, src_data, gvl);
+
+ __riscv_vse32_v_f32m4(p_temp_output, src_data, gvl);
+
+ p_input += gvl;
+ p_temp_output += gvl;
+ length -= gvl;
+ }
+
+ gvl = __riscv_vsetvlmax_e32m1();
+
+ float mean = 0.f;
+ vfloat32m1_t zero_v = __riscv_vfmv_v_f_f32m1(0.f, gvl);
+ vfloat32m1_t mean_v =
+ __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(sum, 0), __riscv_vget_v_f32m4_f32m1(sum, 1), gvl);
+ mean_v = __riscv_vfadd_vv_f32m1(mean_v, __riscv_vget_v_f32m4_f32m1(sum, 2), gvl);
+ mean_v = __riscv_vfadd_vv_f32m1(mean_v, __riscv_vget_v_f32m4_f32m1(sum, 3), gvl);
+ mean_v = __riscv_vfredusum_vs_f32m1_f32m1(mean_v, zero_v, gvl);
+ mean = __riscv_vfmv_f_s_f32m1_f32(mean_v);
+ mean /= hidden_size;
+
+ vfloat32m1_t mean_square_v = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(sum_sq, 0),
+ __riscv_vget_v_f32m4_f32m1(sum_sq, 1), gvl);
+ mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 2), gvl);
+ mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 3), gvl);
+ mean_square_v = __riscv_vfredusum_vs_f32m1_f32m1(mean_square_v, zero_v, gvl);
+
+ float mean_square = __riscv_vfmv_f_s_f32m1_f32(mean_square_v);
+ mean_square /= hidden_size;
+ mean_square = sqrt(mean_square - mean * mean + epsilon);
+
+ mean_square = 1.0f / mean_square;
+ length = hidden_size;
+ p_temp_output = p_output;
+
+ if (p_gamma_data == nullptr && p_beta_data == nullptr) {
+ while (length > 0) {
+ gvl = __riscv_vsetvl_e32m4(length);
+ vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl);
+ src_data = __riscv_vfsub_vf_f32m4(src_data, mean, gvl);
+ src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl);
+ __riscv_vse32_v_f32m4(p_output, src_data, gvl);
+ p_temp_output += gvl;
+ p_output += gvl;
+ length -= gvl;
+ }
+ } else if (p_beta_data == nullptr) {
+ while (length > 0) {
+ gvl = __riscv_vsetvl_e32m4(length);
+ vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl);
+ vfloat32m4_t gamma_data_v = __riscv_vle32_v_f32m4(p_gamma_data, gvl);
+ src_data = __riscv_vfsub_vf_f32m4(src_data, mean, gvl);
+ src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl);
+ src_data = __riscv_vfmul_vv_f32m4(src_data, gamma_data_v, gvl);
+ __riscv_vse32_v_f32m4(p_output, src_data, gvl);
+ p_temp_output += gvl;
+ p_output += gvl;
+ p_gamma_data += gvl;
+ length -= gvl;
+ }
+ } else if (p_gamma_data != nullptr) {
+ while (length > 0) {
+ gvl = __riscv_vsetvl_e32m4(length);
+ vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl);
+ vfloat32m4_t gamma_data_v = __riscv_vle32_v_f32m4(p_gamma_data, gvl);
+ src_data = __riscv_vfsub_vf_f32m4(src_data, mean, gvl);
+ src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl);
+ src_data = __riscv_vfmul_vv_f32m4(src_data, gamma_data_v, gvl);
+ vfloat32m4_t beta_data_v = __riscv_vle32_v_f32m4(p_beta_data, gvl);
+ src_data = __riscv_vfadd_vv_f32m4(src_data, beta_data_v, gvl);
+ p_beta_data += gvl;
+ __riscv_vse32_v_f32m4(p_output, src_data, gvl);
+ p_temp_output += gvl;
+ p_output += gvl;
+ p_gamma_data += gvl;
+ length -= gvl;
+ }
+ }
+ }
+ }
+
+ void forward_rms_norm_f32(ggml_compute_params * params, ggml_tensor * op) {
+ const ggml_tensor * src0 = op->src[0];
+ ggml_tensor * dst = op;
+ GGML_ASSERT(ggml_are_same_shape(src0, dst));
+ GGML_ASSERT(src0->nb[0] == sizeof(float));
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ GGML_TENSOR_UNARY_OP_LOCALS
+
+ float epsilon;
+ memcpy(&epsilon, dst->op_params, sizeof(float));
+
+ GGML_ASSERT(epsilon > 0.0f);
+
+ auto * input = (float *) src0->data;
+ auto * output = (float *) dst->data;
+
+ const auto hidden_size = ne00;
+ const auto task_count = ne01 * ne02 * ne03;
+ const auto task_per_thread = (task_count + nth - 1) / nth;
+
+ const auto task_begin = ith * task_per_thread;
+ const auto task_end = std::min((ith + 1) * task_per_thread, task_count);
+
+ for (auto task_idx = task_begin; task_idx < task_end; task_idx++) {
+ auto offset = task_idx * hidden_size;
+ auto * p_input = const_cast<float *>(input + offset);
+ auto * p_output = output + offset;
+ auto * p_temp_output = p_output;
+ auto * p_gamma_data = (const float *) nullptr;
+ auto * p_beta_data = (const float *) nullptr;
+
+ size_t gvl = __riscv_vsetvlmax_e32m4();
+ // vfloat32m4_t sum = __riscv_vfmv_v_f_f32m4(0.f, gvl);
+ vfloat32m4_t sum_sq = __riscv_vfmv_v_f_f32m4(0.f, gvl);
+ int64_t length = hidden_size;
+ while (length > 0) {
+ gvl = __riscv_vsetvl_e32m4(length);
+ // load data
+ vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_input, gvl);
+
+ sum_sq = __riscv_vfmacc_vv_f32m4(sum_sq, src_data, src_data, gvl);
+
+ __riscv_vse32_v_f32m4(p_temp_output, src_data, gvl);
+
+ p_input += gvl;
+ p_temp_output += gvl;
+ length -= gvl;
+ }
+
+ gvl = __riscv_vsetvlmax_e32m1();
+
+ // float mean = 0.f;
+ vfloat32m1_t zero_v = __riscv_vfmv_v_f_f32m1(0.f, gvl);
+
+ vfloat32m1_t mean_square_v = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(sum_sq, 0),
+ __riscv_vget_v_f32m4_f32m1(sum_sq, 1), gvl);
+ mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 2), gvl);
+ mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 3), gvl);
+ mean_square_v = __riscv_vfredusum_vs_f32m1_f32m1(mean_square_v, zero_v, gvl);
+
+ float mean_square = __riscv_vfmv_f_s_f32m1_f32(mean_square_v);
+ mean_square /= hidden_size;
+
+ mean_square = sqrt(mean_square + epsilon);
+
+ mean_square = 1.0f / mean_square;
+ length = hidden_size;
+ p_temp_output = p_output;
+
+ if (p_gamma_data == nullptr && p_beta_data == nullptr) {
+ while (length > 0) {
+ gvl = __riscv_vsetvl_e32m4(length);
+ vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl);
+ src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl);
+ __riscv_vse32_v_f32m4(p_output, src_data, gvl);
+ p_temp_output += gvl;
+ p_output += gvl;
+ length -= gvl;
+ }
+ } else if (p_beta_data == nullptr) {
+ while (length > 0) {
+ gvl = __riscv_vsetvl_e32m4(length);
+ vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl);
+ vfloat32m4_t gamma_data_v = __riscv_vle32_v_f32m4(p_gamma_data, gvl);
+ src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl);
+ src_data = __riscv_vfmul_vv_f32m4(src_data, gamma_data_v, gvl);
+ __riscv_vse32_v_f32m4(p_output, src_data, gvl);
+ p_temp_output += gvl;
+ p_output += gvl;
+ p_gamma_data += gvl;
+ length -= gvl;
+ }
+ } else if (p_gamma_data != nullptr) {
+ while (length > 0) {
+ gvl = __riscv_vsetvl_e32m4(length);
+ vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl);
+ vfloat32m4_t gamma_data_v = __riscv_vle32_v_f32m4(p_gamma_data, gvl);
+ src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl);
+ src_data = __riscv_vfmul_vv_f32m4(src_data, gamma_data_v, gvl);
+ vfloat32m4_t beta_data_v = __riscv_vle32_v_f32m4(p_beta_data, gvl);
+ src_data = __riscv_vfadd_vv_f32m4(src_data, beta_data_v, gvl);
+ p_beta_data += gvl;
+ __riscv_vse32_v_f32m4(p_output, src_data, gvl);
+ p_temp_output += gvl;
+ p_output += gvl;
+ p_gamma_data += gvl;
+ length -= gvl;
+ }
+ }
+ }
+ }
+
+ int repack(struct ggml_tensor * t, const void * data, size_t data_size) override {
+ memcpy(t->data, data, data_size);
+ return 0;
+ }
+};
+
+static const tensor_traits<block_q4_0, 8, 16> q4_0_16x8_q8_0;
+static const tensor_traits<block_q4_1, 8, 16> q4_1_16x8_q8_0;
+static const tensor_traits<block_q4_K, 8, 16> q4_k_16x8_q8_0;
+static const tensor_traits_common rvv_impl;
+
+} // namespace ggml::cpu::riscv64_spacemit
+
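+// Pick the repack traits for a tensor: the 4-bit types are repacked into the 16-row
+// interleaved layout when the row count is a multiple of 16; F32 tensors use the
+// plain RVV norm implementation.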
+static const ggml::cpu::tensor_traits * ggml_riscv64_spacemit_get_optimal_repack_type(const struct ggml_tensor * cur) {
+ if (cur->type == GGML_TYPE_Q4_0) {
+ if (cur->ne[1] % 16 == 0) {
+ return &ggml::cpu::riscv64_spacemit::q4_0_16x8_q8_0;
+ }
+ } else if (cur->type == GGML_TYPE_Q4_1) {
+ if (cur->ne[1] % 16 == 0) {
+ return &ggml::cpu::riscv64_spacemit::q4_1_16x8_q8_0;
+ }
+ } else if (cur->type == GGML_TYPE_Q4_K) {
+ if (cur->ne[1] % 16 == 0) {
+ return &ggml::cpu::riscv64_spacemit::q4_k_16x8_q8_0;
+ }
+ } else if (cur->type == GGML_TYPE_F32) {
+ return &ggml::cpu::riscv64_spacemit::rvv_impl;
+ }
+
+ return nullptr;
+}
+
+static enum ggml_status ggml_backend_riscv64_spacemit_buffer_init_tensor(ggml_backend_buffer_t buffer,
+ struct ggml_tensor * tensor) {
+ tensor->extra =
+ (void *) const_cast<ggml::cpu::tensor_traits *>(ggml_riscv64_spacemit_get_optimal_repack_type(tensor));
+
+ GGML_UNUSED(buffer);
+
+ return GGML_STATUS_SUCCESS;
+}
+
+static void ggml_backend_riscv64_spacemit_buffer_set_tensor(ggml_backend_buffer_t buffer,
+ struct ggml_tensor * tensor,
+ const void * data,
+ size_t offset,
+ size_t size) {
+ GGML_ASSERT(offset == 0);
+ GGML_ASSERT(size == ggml_nbytes(tensor));
+
+ auto tensor_traits = (ggml::cpu::riscv64_spacemit::tensor_traits_base *) tensor->extra;
+ if (tensor_traits) {
+ auto OK = tensor_traits->repack(tensor, data, size);
+ GGML_ASSERT(OK == 0);
+ }
+
+ GGML_UNUSED(buffer);
+}
+
+static const char * ggml_backend_cpu_riscv64_spacemit_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
+ return "CPU_RISCV64_SPACEMIT";
+
+ GGML_UNUSED(buft);
+}
+
+static ggml_backend_buffer_t ggml_backend_cpu_riscv64_spacemit_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,
+ size_t size) {
+ ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);
+
+ if (buffer == nullptr) {
+ return nullptr;
+ }
+
+ buffer->buft = buft;
+ buffer->iface.init_tensor = ggml_backend_riscv64_spacemit_buffer_init_tensor;
+ buffer->iface.set_tensor = ggml_backend_riscv64_spacemit_buffer_set_tensor;
+ buffer->iface.get_tensor = nullptr;
+ buffer->iface.cpy_tensor = nullptr;
+ return buffer;
+}
+
+static size_t ggml_backend_cpu_riscv64_spacemit_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
+ return 64;
+
+ GGML_UNUSED(buft);
+}
+
+static size_t ggml_backend_cpu_riscv64_spacemit_nbytes(ggml_backend_buffer_type_t buft,
+ const struct ggml_tensor * tensor) {
+ for (int i = 0; i < GGML_MAX_DIMS; ++i) {
+ if (tensor->ne[i] <= 0) {
+ return 0;
+ }
+ }
+
+ size_t nbytes;
+ const size_t blck_size = ggml_blck_size(tensor->type);
+ if (blck_size == 1) {
+ nbytes = ggml_type_size(tensor->type);
+ for (int i = 0; i < GGML_MAX_DIMS; ++i) {
+ nbytes += (tensor->ne[i] - 1) * tensor->nb[i];
+ }
+ } else {
+ nbytes = tensor->ne[0] * tensor->nb[0] / blck_size;
+ if (tensor->type == GGML_TYPE_Q4_K) {
+ GGML_ASSERT(nbytes % sizeof(block_q4_K) == 0);
+ nbytes = (nbytes / sizeof(block_q4_K)) * sizeof(block_q4_1) * 8;
+ for (int i = 1; i < GGML_MAX_DIMS; ++i) {
+ nbytes += (tensor->ne[i] - 1) * (tensor->nb[i] / sizeof(block_q4_K)) * sizeof(block_q4_1) * 8;
+ }
+ } else {
+ for (int i = 1; i < GGML_MAX_DIMS; ++i) {
+ nbytes += (tensor->ne[i] - 1) * tensor->nb[i];
+ }
+ }
+ }
+
+ GGML_UNUSED(buft);
+ return nbytes;
+}
+
+namespace ggml::cpu::riscv64_spacemit {
+
+class extra_buffer_type : ggml::cpu::extra_buffer_type {
+ bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
+ switch (op->op) {
+ case GGML_OP_MUL_MAT:
+ if (op->src[0]->buffer && (ggml_n_dims(op->src[0]) == 2) &&
+ op->src[0]->buffer->buft == ggml_backend_cpu_riscv64_spacemit_buffer_type() &&
+ ggml_riscv64_spacemit_get_optimal_repack_type(op->src[0])) {
+ if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
+ return false;
+ }
+ if (op->src[1]->type == GGML_TYPE_F32) {
+ return true;
+ }
+ }
+ break;
+ case GGML_OP_NORM:
+ case GGML_OP_RMS_NORM:
+ if (op->src[0]->type == GGML_TYPE_F32) {
+ return true;
+ }
+ break;
+ default:
+ // GGML_ABORT("fatal error");
+ break;
+ }
+ return false;
+ }
+
+ ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override {
+ switch (op->op) {
+ case GGML_OP_MUL_MAT:
+ if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_riscv64_spacemit_buffer_type()) {
+ return (ggml::cpu::tensor_traits *) op->src[0]->extra;
+ }
+ break;
+ case GGML_OP_NORM:
+ case GGML_OP_RMS_NORM:
+ return (ggml::cpu::tensor_traits *) (&ggml::cpu::riscv64_spacemit::rvv_impl);
+ default:
+ // GGML_ABORT("fatal error");
+ break;
+ }
+
+ return nullptr;
+ }
+};
+
+} // namespace ggml::cpu::riscv64_spacemit
+
+ggml_backend_buffer_type_t ggml_backend_cpu_riscv64_spacemit_buffer_type(void) {
+ static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type_riscv64_spacemit = {
+ /* .iface = */
+ {
+ /* .get_name = */ ggml_backend_cpu_riscv64_spacemit_buffer_type_get_name,
+ /* .alloc_buffer = */ ggml_backend_cpu_riscv64_spacemit_buffer_type_alloc_buffer,
+ /* .get_alignment = */ ggml_backend_cpu_riscv64_spacemit_buffer_type_get_alignment,
+ /* .get_max_size = */ nullptr,
+ /* .get_alloc_size = */ ggml_backend_cpu_riscv64_spacemit_nbytes,
+ /* .is_host = */ nullptr,
+ },
+ /* .device = */
+ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
+ /* .context = */
+ new ggml::cpu::riscv64_spacemit::extra_buffer_type(),
+ };
+
+ return &ggml_backend_cpu_buffer_type_riscv64_spacemit;
+}
diff --git a/llama.cpp/ggml/src/ggml-cpu/spacemit/ime.h b/llama.cpp/ggml/src/ggml-cpu/spacemit/ime.h
new file mode 100644
index 0000000..800d91a
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cpu/spacemit/ime.h
@@ -0,0 +1,13 @@
+#pragma once
+
+#include "ggml-alloc.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+ggml_backend_buffer_type_t ggml_backend_cpu_riscv64_spacemit_buffer_type(void);
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/llama.cpp/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp b/llama.cpp/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp
new file mode 100644
index 0000000..cbbb6cd
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp
@@ -0,0 +1,3196 @@
+#include "ggml.h"
+#include "ime_kernels.h"
+
+#include <algorithm>
+#include <cmath>
+
+// clang-format off
+#if defined(__GNUC__)
+#pragma GCC diagnostic ignored "-Woverlength-strings"
+#pragma GCC diagnostic ignored "-Wcast-qual"
+#pragma GCC diagnostic ignored "-Wunused-parameter"
+#endif
+// clang-format on
+namespace sqnbitgemm_spacemit_ime {
+
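+// QUANTIZEM4ROW_KERNEL: with a block of floats loaded in v0..v7 (e32, m8), reduce the
+// absolute maximum, derive the q8 scale (absmax / 127) and store it at a1, then scale
+// by the reciprocal, convert to integers and narrow in two vnclip passes down to int8
+// in v24..v31.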
+#define QUANTIZEM4ROW_KERNEL \
+ "vmv.s.x v16, zero \n\t" \
+ "vfabs.v v8, v0 \n\t" \
+ "vfredmax.vs v16, v8, v16 \n\t" \
+ "vfmv.f.s f10, v16 \n\t" \
+ "fmul.s f10, f10, %[RMAXREC] \n\t" \
+ "fsw f10, (a1) \n\t" \
+ "fdiv.s f11, %[FONE], f10 \n\t" \
+ "vfmul.vf v16, v0, f11 \n\t" \
+ "vfcvt.x.f.v v16, v16 \n\t" \
+ "vsetvli t0, zero, e16, mf2 \n\t" \
+ "vnclip.wx v16, v16, zero \n\t" \
+ "vnclip.wx v17, v17, zero \n\t" \
+ "vnclip.wx v18, v18, zero \n\t" \
+ "vnclip.wx v19, v19, zero \n\t" \
+ "vnclip.wx v20, v20, zero \n\t" \
+ "vnclip.wx v21, v21, zero \n\t" \
+ "vnclip.wx v22, v22, zero \n\t" \
+ "vnclip.wx v23, v23, zero \n\t" \
+ "vsetvli t0, zero, e8, mf4 \n\t" \
+ "vnclip.wx v24, v16, zero \n\t" \
+ "vnclip.wx v25, v17, zero \n\t" \
+ "vnclip.wx v26, v18, zero \n\t" \
+ "vnclip.wx v27, v19, zero \n\t" \
+ "vnclip.wx v28, v20, zero \n\t" \
+ "vnclip.wx v29, v21, zero \n\t" \
+ "vnclip.wx v30, v22, zero \n\t" \
+ "vnclip.wx v31, v23, zero \n\t"
+
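+// QUANTIZEM4ROW_STORE: write the narrowed int8 vectors v24..v31 to the destination,
+// advancing the pointer by 32 bytes between stores so the four rows' data stays
+// interleaved.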
+#define QUANTIZEM4ROW_STORE \
+ "addi t1, %[BlkLen], 0 \n\t" \
+ "vsetvli t0, t1, e8, mf4 \n\t" \
+ "vse8.v v24, (s1) \n\t" \
+ "addi s1, s1, 32 \n\t" \
+ "sub t1, t1, t0 \n\t" \
+ "vsetvli t0, t1, e8, mf4 \n\t" \
+ "vse8.v v25, (s1) \n\t" \
+ "addi s1, s1, 32 \n\t" \
+ "sub t1, t1, t0 \n\t" \
+ "vsetvli t0, t1, e8, mf4 \n\t" \
+ "vse8.v v26, (s1) \n\t" \
+ "addi s1, s1, 32 \n\t" \
+ "sub t1, t1, t0 \n\t" \
+ "vsetvli t0, t1, e8, mf4 \n\t" \
+ "vse8.v v27, (s1) \n\t" \
+ "addi s1, s1, 32 \n\t" \
+ "sub t1, t1, t0 \n\t" \
+ "vsetvli t0, t1, e8, mf4 \n\t" \
+ "vse8.v v28, (s1) \n\t" \
+ "addi s1, s1, 32 \n\t" \
+ "sub t1, t1, t0 \n\t" \
+ "vsetvli t0, t1, e8, mf4 \n\t" \
+ "vse8.v v29, (s1) \n\t" \
+ "addi s1, s1, 32 \n\t" \
+ "sub t1, t1, t0 \n\t" \
+ "vsetvli t0, t1, e8, mf4 \n\t" \
+ "vse8.v v30, (s1) \n\t" \
+ "addi s1, s1, 32 \n\t" \
+ "sub t1, t1, t0 \n\t" \
+ "vsetvli t0, t1, e8, mf4 \n\t" \
+ "vse8.v v31, (s1) \n\t"
+
+namespace ime1 {
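+// Quantize four rows of A at once into the interleaved per-block layout consumed by
+// the IME GEMM kernel: within each block, the four float scales come first, followed
+// by the rows' int8 data interleaved in 8-byte groups (see the offset/stride
+// computation below).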
+void quantize_a_4row_i8(size_t BlkLen, const float * A, size_t CountK, std::byte * QuantA) {
+ constexpr float range_max_reciprocal = 1.0f / ((1 << 7) - 1);
+ const float fone = 1.0f;
+
+ if (BlkLen == 16 || BlkLen == 32 || BlkLen == 64) {
+ for (size_t row_index = 0; row_index < 4; ++row_index) {
+ const float * SRC = A + row_index * CountK;
+ std::byte * DST = QuantA + row_index * sizeof(float);
+
+ const size_t offset = (4 - row_index) * 4 + row_index * 8;
+ const size_t stride = 4 * (sizeof(float) + BlkLen);
+ __asm__ volatile(
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "addi t2, %[CountK], 0 \n\t"
+ "addi a1, %[DST], 0 \n\t"
+ "blt t2, %[BlkLen], TAIL%= \n\t"
+
+ "LOOP%=: \n\t"
+ "vsetvli t0, %[BlkLen], e32, m8 \n\t"
+ "vle32.v v0, (%[SRC]) \n\t"
+ "sub t2, t2, t0 \n\t"
+ "slli t1, t0, 2 \n\t"
+ "add %[SRC], %[SRC], t1 \n\t"
+ "add s1, a1, %[OFFSET] \n\t"
+
+ QUANTIZEM4ROW_KERNEL QUANTIZEM4ROW_STORE
+
+ "add a1, a1, %[STRIDE] \n\t"
+ "bge t2, %[BlkLen], LOOP%= \n\t"
+
+ "TAIL%=: \n\t"
+ "blez t2, QUIT%= \n\t"
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vxor.vv v16, v16, v16 \n\t"
+ "vxor.vv v24, v24, v24 \n\t"
+ "vsetvli t0, t2, e32, m8 \n\t"
+ "vle32.v v0, (%[SRC]) \n\t"
+ "add s1, a1, %[OFFSET] \n\t"
+
+ QUANTIZEM4ROW_KERNEL
+
+ "addi t3, %[BlkLen], 0 \n\t"
+ "addi s2, s1, 0 \n\t"
+ "vsetvli t0, zero, e8, mf4 \n\t"
+ "vxor.vv v8, v8, v8 \n\t"
+ "SET_ZERO%=: \n\t"
+ "vse8.v v8, (s2) \n\t"
+ "addi s2, s2, 32 \n\t"
+ "addi t3, t3, -8 \n\t"
+ "bnez t3, SET_ZERO%= \n\t"
+
+ QUANTIZEM4ROW_STORE
+
+ "QUIT%=: \n\t"
+ : [SRC] "+r"(SRC)
+ : [DST] "r"(DST), [BlkLen] "r"(BlkLen), [OFFSET] "r"(offset), [STRIDE] "r"(stride),
+ [CountK] "r"(CountK), [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal)
+ : "cc", "t0", "t1", "t2", "t3", "a1", "s1", "s2", "f10", "f11");
+ }
+ } else if (BlkLen == 128) {
+ for (size_t row_index = 0; row_index < 4; ++row_index) {
+ const float * SRC = A + row_index * CountK;
+ std::byte * DST = QuantA + row_index * sizeof(float);
+
+ const size_t offset = (4 - row_index) * 4 + row_index * 8;
+ const size_t stride = 4 * (sizeof(float) + BlkLen);
+ __asm__ volatile(
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "li t6, 32 \n\t"
+ "addi t2, %[CountK], 0 \n\t"
+ "addi a1, %[DST], 0 \n\t"
+ "add s1, a1, %[OFFSET] \n\t"
+ "blt t2, %[BlkLen], TAIL%= \n\t"
+
+ "LOOP%=: \n\t"
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vle32.v v0, (%[SRC]) \n\t"
+ "addi %[SRC], %[SRC], 256 \n\t"
+ "vle32.v v8, (%[SRC]) \n\t"
+ "addi %[SRC], %[SRC], 256 \n\t"
+ "addi t2, t2, -128 \n\t"
+
+ "QUANTIZE%=: \n\t"
+ "add s1, a1, %[OFFSET] \n\t"
+ "vfabs.v v16, v0 \n\t"
+ "vfabs.v v24, v8 \n\t"
+ "vfmax.vv v16, v24, v16 \n\t"
+ "vfredmax.vs v24, v16, v24 \n\t"
+ "vfmv.f.s f10, v24 \n\t"
+ "fmul.s f10, f10, %[RMAXREC] \n\t"
+ "fsw f10, (a1) \n\t"
+ "fdiv.s f11, %[FONE], f10 \n\t"
+ "vfmul.vf v16, v0, f11 \n\t"
+ "vfmul.vf v24, v8, f11 \n\t"
+ "vfcvt.x.f.v v16, v16 \n\t"
+ "vfcvt.x.f.v v24, v24 \n\t"
+ "vsetvli t0, zero, e16, m4 \n\t"
+ "vnclip.wx v16, v16, zero \n\t"
+ "vnclip.wx v20, v24, zero \n\t"
+ "vsetvli t0, zero, e8, m4 \n\t"
+ "vnclip.wx v16, v16, zero \n\t"
+ "vsetvli t0, zero, e64, m4 \n\t"
+ "vsse64.v v16, (s1), t6 \n\t"
+ "add a1, a1, %[STRIDE] \n\t"
+ "bge t2, %[BlkLen], LOOP%= \n\t"
+
+ "TAIL%=: \n\t"
+ "blez t2, QUIT%= \n\t"
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vxor.vv v0, v0, v0 \n\t"
+ "vxor.vv v8, v8, v8 \n\t"
+ "vxor.vv v16, v16, v16 \n\t"
+ "vxor.vv v24, v24, v24 \n\t"
+ "vsetvli t0, t2, e32, m8 \n\t"
+ "sub t2, t2, t0 \n\t"
+ "vle32.v v0, (%[SRC]) \n\t"
+ "addi %[SRC], %[SRC], 256 \n\t"
+ "vsetvli t0, t2, e32, m8 \n\t"
+ "vle32.v v8, (%[SRC]) \n\t"
+ "sub t2, t2, t2 \n\t"
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "jal x0, QUANTIZE%= \n\t"
+
+ "QUIT%=: \n\t"
+ : [SRC] "+r"(SRC)
+ : [DST] "r"(DST), [BlkLen] "r"(BlkLen), [OFFSET] "r"(offset), [STRIDE] "r"(stride),
+ [CountK] "r"(CountK), [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal)
+ : "cc", "t0", "t1", "t2", "t6", "a1", "s1", "s2", "f10", "f11");
+ }
+ } else if (BlkLen == 256) {
+ for (size_t row_index = 0; row_index < 4; ++row_index) {
+ const float * SRC = A + row_index * CountK;
+ std::byte * DST = QuantA + row_index * sizeof(float);
+ const size_t offset = (4 - row_index) * 4 + row_index * 8;
+ const size_t stride = 4 * (sizeof(float) + BlkLen);
+ __asm__ volatile(
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "li t6, 32 \n\t"
+ "addi t2, %[CountK], 0 \n\t"
+ "addi a1, %[DST], 0 \n\t"
+ "add s1, a1, %[OFFSET] \n\t"
+ "blt t2, %[BlkLen], TAIL%= \n\t"
+
+ "LOOP%=: \n\t"
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vle32.v v0, (%[SRC]) \n\t"
+ "addi %[SRC], %[SRC], 256 \n\t"
+ "vle32.v v8, (%[SRC]) \n\t"
+ "addi %[SRC], %[SRC], 256 \n\t"
+ "vle32.v v16, (%[SRC]) \n\t"
+ "addi %[SRC], %[SRC], 256 \n\t"
+ "vle32.v v24, (%[SRC]) \n\t"
+ "addi %[SRC], %[SRC], -768 \n\t"
+ "addi t2, t2, -256 \n\t"
+ "vfabs.v v0, v0 \n\t"
+ "vfabs.v v8, v8 \n\t"
+ "vfabs.v v16, v16 \n\t"
+ "vfabs.v v24, v24 \n\t"
+ "vfmax.vv v8, v0, v8 \n\t"
+ "vfmax.vv v24, v24, v16 \n\t"
+ "vfmax.vv v8, v8, v24 \n\t"
+ "vfredmax.vs v24, v8, v24 \n\t"
+ "vfmv.f.s f10, v24 \n\t"
+ "vle32.v v0, (%[SRC]) \n\t"
+ "addi %[SRC], %[SRC], 256 \n\t"
+ "vle32.v v8, (%[SRC]) \n\t"
+ "addi %[SRC], %[SRC], 256 \n\t"
+ "vle32.v v16, (%[SRC]) \n\t"
+ "addi %[SRC], %[SRC], 256 \n\t"
+ "vle32.v v24, (%[SRC]) \n\t"
+ "addi %[SRC], %[SRC], 256 \n\t"
+
+ "QUANTIZE%=: \n\t"
+ "add s1, a1, %[OFFSET] \n\t"
+ "fmul.s f10, f10, %[RMAXREC] \n\t"
+ "fsw f10, (a1) \n\t"
+ "fdiv.s f11, %[FONE], f10 \n\t"
+ "vfmul.vf v0, v0, f11 \n\t"
+ "vfmul.vf v8, v8, f11 \n\t"
+ "vfmul.vf v16, v16, f11 \n\t"
+ "vfmul.vf v24, v24, f11 \n\t"
+ "vfcvt.x.f.v v0, v0 \n\t"
+ "vfcvt.x.f.v v8, v8 \n\t"
+ "vfcvt.x.f.v v16, v16 \n\t"
+ "vfcvt.x.f.v v24, v24 \n\t"
+ "vsetvli t0, zero, e16, m4 \n\t"
+ "vnclip.wx v0, v0, zero \n\t"
+ "vnclip.wx v4, v8, zero \n\t"
+ "vnclip.wx v8, v16, zero \n\t"
+ "vnclip.wx v12, v24, zero \n\t"
+ "vsetvli t0, zero, e8, m4 \n\t"
+ "vnclip.wx v0, v0, zero \n\t"
+ "vnclip.wx v4, v8, zero \n\t"
+ "vsetvli t0, zero, e64, m8 \n\t"
+ "vsse64.v v0, (s1), t6 \n\t"
+ "add a1, a1, %[STRIDE] \n\t"
+ "bge t2, %[BlkLen], LOOP%= \n\t"
+
+ "TAIL%=: \n\t"
+ "blez t2, QUIT%= \n\t"
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vxor.vv v0, v0, v0 \n\t"
+ "vxor.vv v8, v8, v8 \n\t"
+ "vxor.vv v16, v16, v16 \n\t"
+ "vxor.vv v24, v24, v24 \n\t"
+ "addi t1, t2, 0 \n\t"
+ "vsetvli t0, t1, e32, m8 \n\t"
+ "sub t1, t1, t0 \n\t"
+ "vle32.v v0, (%[SRC]) \n\t"
+ "addi %[SRC], %[SRC], 256 \n\t"
+ "vsetvli t0, t1, e32, m8 \n\t"
+ "sub t1, t1, t0 \n\t"
+ "vle32.v v8, (%[SRC]) \n\t"
+ "addi %[SRC], %[SRC], 256 \n\t"
+ "vsetvli t0, t1, e32, m8 \n\t"
+ "sub t1, t1, t0 \n\t"
+ "vle32.v v16, (%[SRC]) \n\t"
+ "addi %[SRC], %[SRC], 256 \n\t"
+ "vsetvli t0, t1, e32, m8 \n\t"
+ "vle32.v v24, (%[SRC]) \n\t"
+ "addi %[SRC], %[SRC], -768 \n\t"
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vfabs.v v0, v0 \n\t"
+ "vfabs.v v8, v8 \n\t"
+ "vfabs.v v16, v16 \n\t"
+ "vfabs.v v24, v24 \n\t"
+ "vfmax.vv v8, v0, v8 \n\t"
+ "vfmax.vv v24, v16, v24 \n\t"
+ "vfmax.vv v8, v8, v24 \n\t"
+ "vfredmax.vs v24, v8, v24 \n\t"
+ "vfmv.f.s f10, v24 \n\t"
+ "add s1, a1, %[OFFSET] \n\t"
+ "fmul.s f10, f10, %[RMAXREC] \n\t"
+ "fsw f10, (a1) \n\t"
+ "fdiv.s f11, %[FONE], f10 \n\t"
+ "vsetvli t0, zero, e64, m8 \n\t"
+ "vxor.vv v0, v0, v0 \n\t"
+ "vsse64.v v0, (s1), t6 \n\t"
+
+ "TAIL_LOOP%=: \n\t"
+ "vsetvli t0, zero, e32, m4 \n\t"
+ "vxor.vv v0, v0, v0 \n\t"
+ "vsetvli t0, t2, e32, m1 \n\t"
+ "sub t2, t2, t0 \n\t"
+ "vle32.v v0, (%[SRC]) \n\t"
+ "addi %[SRC], %[SRC], 32 \n\t"
+ "vfmul.vf v1, v0, f11 \n\t"
+ "vfcvt.x.f.v v2, v1 \n\t"
+ "vsetvli t0, zero, e16, mf2 \n\t"
+ "vnclip.wx v3, v2, zero \n\t"
+ "vsetvli t0, zero, e8, mf4 \n\t"
+ "vnclip.wx v3, v3, zero \n\t"
+ "vse8.v v3, (s1) \n\t"
+ "addi s1, s1, 32 \n\t"
+ "bnez t2, TAIL_LOOP%= \n\t"
+
+ "QUIT%=: \n\t"
+ : [SRC] "+r"(SRC)
+ : [DST] "r"(DST), [BlkLen] "r"(BlkLen), [OFFSET] "r"(offset), [STRIDE] "r"(stride),
+ [CountK] "r"(CountK), [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal)
+ : "cc", "t0", "t1", "t2", "t6", "a1", "s1", "s2", "f10", "f11");
+ }
+ }
+}
+
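+// Single-row variant: each block stores one float scale followed by BlkLen int8
+// values; the padding past CountK is zero-filled.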
+void quantize_a_row_i8(size_t BlkLen, const float * A, size_t CountK, std::byte * QuantA) {
+ const float * SRC = A;
+ std::byte * DST = QuantA;
+ constexpr float range_max_reciprocal = 1.0f / ((1 << 7) - 1);
+ const float fone = 1.0f;
+ std::byte * QuantA_offset = QuantA + CountK + 4 * ((CountK + BlkLen - 1) / BlkLen);
+ size_t offset = (CountK + BlkLen - 1) / BlkLen * BlkLen - CountK;
+
+ if (CountK <= BlkLen) {
+ float max_abs_A = 0.0f;
+ for (size_t k = 0; k < CountK; k++) {
+ max_abs_A = std::max(max_abs_A, fabsf(A[k]));
+ }
+ float scale_A = max_abs_A * range_max_reciprocal;
+
+ ((float *) QuantA)[0] = scale_A;
+
+ auto * QuantAData_offset = (int8_t *) (QuantA + sizeof(float));
+
+ for (size_t k = 0; k < CountK; k++) {
+ QuantAData_offset[k] =
+ (int8_t) std::clamp(roundf(A[k] / scale_A), (float) std::numeric_limits<int8_t>::lowest(),
+ (float) std::numeric_limits<int8_t>::max());
+ }
+ for (size_t k = CountK; k < BlkLen; k++) {
+ QuantAData_offset[k] = 0;
+ }
+
+ return;
+ }
+
+ if (BlkLen != 32 || BlkLen != 64 || BlkLen != 128) {
+ __asm__ volatile(
+ "vsetvli t0, zero, e8, m8 \n\t"
+ "vxor.vv v24, v24, v24 \n\t"
+ "LOOP%=: \n\t"
+ "vsetvli t0, %[CNT], e8, m8 \n\t"
+ "vse8.v v24, (%[DST]) \n\t"
+ "addi %[DST], %[DST], 128 \n\t"
+ "sub %[CNT], %[CNT], t0 \n\t"
+ "bnez %[CNT], LOOP%= \n\t"
+ : [DST] "+r"(QuantA_offset), [CNT] "+r"(offset)
+ :
+ : "cc", "t0");
+ }
+ if (BlkLen == 16) {
+ float buffer[64] = { 0.0f };
+ __asm__ volatile(
+ "addi t3, zero, 16*8 \n\t"
+ "addi t2, zero, 16 \n\t"
+ "blt %[K], t3, LOOP_K%= \n\t"
+ "blt %[K], t2, TAIL%= \n\t"
+ "LOOP_MAIN%=: \n\t"
+ "vsetvli t1, zero, e32, m2 \n\t"
+ "addi %[K], %[K], -128 \n\t"
+ "vle32.v v0, (%[SRC]) \n\t"
+ "addi %[SRC], %[SRC], 64 \n\t"
+ "vle32.v v2, (%[SRC]) \n\t"
+ "addi %[SRC], %[SRC], 64 \n\t"
+ "vle32.v v4, (%[SRC]) \n\t"
+ "addi %[SRC], %[SRC], 64 \n\t"
+ "vle32.v v6, (%[SRC]) \n\t"
+ "addi %[SRC], %[SRC], 64 \n\t"
+ "vle32.v v8, (%[SRC]) \n\t"
+ "addi %[SRC], %[SRC], 64 \n\t"
+ "vle32.v v10, (%[SRC]) \n\t"
+ "addi %[SRC], %[SRC], 64 \n\t"
+ "vle32.v v12, (%[SRC]) \n\t"
+ "addi %[SRC], %[SRC], 64 \n\t"
+ "vle32.v v14, (%[SRC]) \n\t"
+ "addi %[SRC], %[SRC], 64 \n\t"
+ "addi a1, %[BUFFER], 0 \n\t"
+ "vfabs.v v16, v0 \n\t"
+ "vfabs.v v18, v2 \n\t"
+ "vfabs.v v20, v4 \n\t"
+ "vfabs.v v22, v6 \n\t"
+ "vfabs.v v24, v8 \n\t"
+ "vfabs.v v26, v10 \n\t"
+ "vfabs.v v28, v12 \n\t"
+ "vfabs.v v30, v14 \n\t"
+ "vsetvli t0, zero, e32, m1 \n\t"
+ "vfmax.vv v16, v16, v17 \n\t"
+ "vfmax.vv v18, v18, v19 \n\t"
+ "vfmax.vv v20, v20, v21 \n\t"
+ "vfmax.vv v22, v22, v23 \n\t"
+ "vfmax.vv v24, v24, v25 \n\t"
+ "vfmax.vv v26, v26, v27 \n\t"
+ "vfmax.vv v28, v28, v29 \n\t"
+ "vfmax.vv v30, v30, v31 \n\t"
+ "vse32.v v16, (a1) \n\t"
+ "addi a1, a1, 32 \n\t"
+ "vse32.v v18, (a1) \n\t"
+ "addi a1, a1, 32 \n\t"
+ "vse32.v v20, (a1) \n\t"
+ "addi a1, a1, 32 \n\t"
+ "vse32.v v22, (a1) \n\t"
+ "addi a1, a1, 32 \n\t"
+ "vse32.v v24, (a1) \n\t"
+ "addi a1, a1, 32 \n\t"
+ "vse32.v v26, (a1) \n\t"
+ "addi a1, a1, 32 \n\t"
+ "vse32.v v28, (a1) \n\t"
+ "addi a1, a1, 32 \n\t"
+ "vse32.v v30, (a1) \n\t"
+ "addi a1, %[BUFFER], 0 \n\t"
+ "flw f0, (a1) \n\t"
+ "flw f1, 4(a1) \n\t"
+ "flw f2, 8(a1) \n\t"
+ "flw f3, 12(a1) \n\t"
+ "flw f4, 16(a1) \n\t"
+ "flw f5, 20(a1) \n\t"
+ "flw f6, 24(a1) \n\t"
+ "flw f7, 28(a1) \n\t"
+ "addi a1, a1, 32 \n\t"
+ "fmax.s f1, f0, f1 \n\t"
+ "fmax.s f3, f2, f3 \n\t"
+ "fmax.s f5, f4, f5 \n\t"
+ "fmax.s f7, f6, f7 \n\t"
+ "fmax.s f3, f1, f3 \n\t"
+ "fmax.s f7, f5, f7 \n\t"
+ "fmax.s f10, f3, f7 \n\t"
+ "fmul.s f10, f10, %[RMAXREC] \n\t"
+ "fsw f10, (%[DST]) \n\t"
+ "addi %[DST], %[DST], 20 \n\t"
+ "fdiv.s f10, %[FONE], f10 \n\t"
+ "flw f0, (a1) \n\t"
+ "flw f1, 4(a1) \n\t"
+ "flw f2, 8(a1) \n\t"
+ "flw f3, 12(a1) \n\t"
+ "flw f4, 16(a1) \n\t"
+ "flw f5, 20(a1) \n\t"
+ "flw f6, 24(a1) \n\t"
+ "flw f7, 28(a1) \n\t"
+ "addi a1, a1, 32 \n\t"
+ "fmax.s f1, f0, f1 \n\t"
+ "fmax.s f3, f2, f3 \n\t"
+ "fmax.s f5, f4, f5 \n\t"
+ "fmax.s f7, f6, f7 \n\t"
+ "fmax.s f3, f1, f3 \n\t"
+ "fmax.s f7, f5, f7 \n\t"
+ "fmax.s f11, f3, f7 \n\t"
+ "fmul.s f11, f11, %[RMAXREC] \n\t"
+ "fsw f11, (%[DST]) \n\t"
+ "addi %[DST], %[DST], 20 \n\t"
+ "fdiv.s f11, %[FONE], f11 \n\t"
+ "flw f0, (a1) \n\t"
+ "flw f1, 4(a1) \n\t"
+ "flw f2, 8(a1) \n\t"
+ "flw f3, 12(a1) \n\t"
+ "flw f4, 16(a1) \n\t"
+ "flw f5, 20(a1) \n\t"
+ "flw f6, 24(a1) \n\t"
+ "flw f7, 28(a1) \n\t"
+ "addi a1, a1, 32 \n\t"
+ "fmax.s f1, f0, f1 \n\t"
+ "fmax.s f3, f2, f3 \n\t"
+ "fmax.s f5, f4, f5 \n\t"
+ "fmax.s f7, f6, f7 \n\t"
+ "fmax.s f3, f1, f3 \n\t"
+ "fmax.s f7, f5, f7 \n\t"
+ "fmax.s f12, f3, f7 \n\t"
+ "fmul.s f12, f12, %[RMAXREC] \n\t"
+ "fsw f12, (%[DST]) \n\t"
+ "addi %[DST], %[DST], 20 \n\t"
+ "fdiv.s f12, %[FONE], f12 \n\t"
+ "flw f0, (a1) \n\t"
+ "flw f1, 4(a1) \n\t"
+ "flw f2, 8(a1) \n\t"
+ "flw f3, 12(a1) \n\t"
+ "flw f4, 16(a1) \n\t"
+ "flw f5, 20(a1) \n\t"
+ "flw f6, 24(a1) \n\t"
+ "flw f7, 28(a1) \n\t"
+ "addi a1, a1, 32 \n\t"
+ "fmax.s f1, f0, f1 \n\t"
+ "fmax.s f3, f2, f3 \n\t"
+ "fmax.s f5, f4, f5 \n\t"
+ "fmax.s f7, f6, f7 \n\t"
+ "fmax.s f3, f1, f3 \n\t"
+ "fmax.s f7, f5, f7 \n\t"
+ "fmax.s f13, f3, f7 \n\t"
+ "fmul.s f13, f13, %[RMAXREC] \n\t"
+ "fsw f13, (%[DST]) \n\t"
+ "addi %[DST], %[DST], 20 \n\t"
+ "fdiv.s f13, %[FONE], f13 \n\t"
+ "flw f0, (a1) \n\t"
+ "flw f1, 4(a1) \n\t"
+ "flw f2, 8(a1) \n\t"
+ "flw f3, 12(a1) \n\t"
+ "flw f4, 16(a1) \n\t"
+ "flw f5, 20(a1) \n\t"
+ "flw f6, 24(a1) \n\t"
+ "flw f7, 28(a1) \n\t"
+ "addi a1, a1, 32 \n\t"
+ "fmax.s f1, f0, f1 \n\t"
+ "fmax.s f3, f2, f3 \n\t"
+ "fmax.s f5, f4, f5 \n\t"
+ "fmax.s f7, f6, f7 \n\t"
+ "fmax.s f3, f1, f3 \n\t"
+ "fmax.s f7, f5, f7 \n\t"
+ "fmax.s f14, f3, f7 \n\t"
+ "fmul.s f14, f14, %[RMAXREC] \n\t"
+ "fsw f14, (%[DST]) \n\t"
+ "addi %[DST], %[DST], 20 \n\t"
+ "fdiv.s f14, %[FONE], f14 \n\t"
+ "flw f0, (a1) \n\t"
+ "flw f1, 4(a1) \n\t"
+ "flw f2, 8(a1) \n\t"
+ "flw f3, 12(a1) \n\t"
+ "flw f4, 16(a1) \n\t"
+ "flw f5, 20(a1) \n\t"
+ "flw f6, 24(a1) \n\t"
+ "flw f7, 28(a1) \n\t"
+ "addi a1, a1, 32 \n\t"
+ "fmax.s f1, f0, f1 \n\t"
+ "fmax.s f3, f2, f3 \n\t"
+ "fmax.s f5, f4, f5 \n\t"
+ "fmax.s f7, f6, f7 \n\t"
+ "fmax.s f3, f1, f3 \n\t"
+ "fmax.s f7, f5, f7 \n\t"
+ "fmax.s f15, f3, f7 \n\t"
+ "fmul.s f15, f15, %[RMAXREC] \n\t"
+ "fsw f15, (%[DST]) \n\t"
+ "addi %[DST], %[DST], 20 \n\t"
+ "fdiv.s f15, %[FONE], f15 \n\t"
+ "flw f0, (a1) \n\t"
+ "flw f1, 4(a1) \n\t"
+ "flw f2, 8(a1) \n\t"
+ "flw f3, 12(a1) \n\t"
+ "flw f4, 16(a1) \n\t"
+ "flw f5, 20(a1) \n\t"
+ "flw f6, 24(a1) \n\t"
+ "flw f7, 28(a1) \n\t"
+ "addi a1, a1, 32 \n\t"
+ "fmax.s f1, f0, f1 \n\t"
+ "fmax.s f3, f2, f3 \n\t"
+ "fmax.s f5, f4, f5 \n\t"
+ "fmax.s f7, f6, f7 \n\t"
+ "fmax.s f3, f1, f3 \n\t"
+ "fmax.s f7, f5, f7 \n\t"
+ "fmax.s f16, f3, f7 \n\t"
+ "fmul.s f16, f16, %[RMAXREC] \n\t"
+ "fsw f16, (%[DST]) \n\t"
+ "addi %[DST], %[DST], 20 \n\t"
+ "fdiv.s f16, %[FONE], f16 \n\t"
+ "flw f0, (a1) \n\t"
+ "flw f1, 4(a1) \n\t"
+ "flw f2, 8(a1) \n\t"
+ "flw f3, 12(a1) \n\t"
+ "flw f4, 16(a1) \n\t"
+ "flw f5, 20(a1) \n\t"
+ "flw f6, 24(a1) \n\t"
+ "flw f7, 28(a1) \n\t"
+ "addi a1, a1, 32 \n\t"
+ "fmax.s f1, f0, f1 \n\t"
+ "fmax.s f3, f2, f3 \n\t"
+ "fmax.s f5, f4, f5 \n\t"
+ "fmax.s f7, f6, f7 \n\t"
+ "fmax.s f3, f1, f3 \n\t"
+ "fmax.s f7, f5, f7 \n\t"
+ "fmax.s f17, f3, f7 \n\t"
+ "fmul.s f17, f17, %[RMAXREC] \n\t"
+ "fsw f17, (%[DST]) \n\t"
+ "addi %[DST], %[DST], -136 \n\t"
+ "fdiv.s f17, %[FONE], f17 \n\t"
+ "vsetvli t0, zero, e32, m2 \n\t"
+ "vfmul.vf v16, v0, f10 \n\t"
+ "vfmul.vf v18, v2, f11 \n\t"
+ "vfmul.vf v20, v4, f12 \n\t"
+ "vfmul.vf v22, v6, f13 \n\t"
+ "vfmul.vf v24, v8, f14 \n\t"
+ "vfmul.vf v26, v10, f15 \n\t"
+ "vfmul.vf v28, v12, f16 \n\t"
+ "vfmul.vf v30, v14, f17 \n\t"
+ "vfcvt.x.f.v v16, v16 \n\t"
+ "vfcvt.x.f.v v18, v18 \n\t"
+ "vfcvt.x.f.v v20, v20 \n\t"
+ "vfcvt.x.f.v v22, v22 \n\t"
+ "vfcvt.x.f.v v24, v24 \n\t"
+ "vfcvt.x.f.v v26, v26 \n\t"
+ "vfcvt.x.f.v v28, v28 \n\t"
+ "vfcvt.x.f.v v30, v30 \n\t"
+ "vsetvli t0, zero, e16, m1 \n\t"
+ "vnclip.wx v16, v16, zero \n\t"
+ "vnclip.wx v18, v18, zero \n\t"
+ "vnclip.wx v20, v20, zero \n\t"
+ "vnclip.wx v22, v22, zero \n\t"
+ "vnclip.wx v24, v24, zero \n\t"
+ "vnclip.wx v26, v26, zero \n\t"
+ "vnclip.wx v28, v28, zero \n\t"
+ "vnclip.wx v30, v30, zero \n\t"
+ "vsetvli t0, t1, e8, mf2 \n\t"
+ "vnclip.wx v16, v16, zero \n\t"
+ "vnclip.wx v18, v18, zero \n\t"
+ "vnclip.wx v20, v20, zero \n\t"
+ "vnclip.wx v22, v22, zero \n\t"
+ "vnclip.wx v24, v24, zero \n\t"
+ "vnclip.wx v26, v26, zero \n\t"
+ "vnclip.wx v28, v28, zero \n\t"
+ "vnclip.wx v30, v30, zero \n\t"
+ "vse8.v v16, (%[DST]) \n\t"
+ "addi %[DST], %[DST], 20 \n\t"
+ "vse8.v v18, (%[DST]) \n\t"
+ "addi %[DST], %[DST], 20 \n\t"
+ "vse8.v v20, (%[DST]) \n\t"
+ "addi %[DST], %[DST], 20 \n\t"
+ "vse8.v v22, (%[DST]) \n\t"
+ "addi %[DST], %[DST], 20 \n\t"
+ "vse8.v v24, (%[DST]) \n\t"
+ "addi %[DST], %[DST], 20 \n\t"
+ "vse8.v v26, (%[DST]) \n\t"
+ "addi %[DST], %[DST], 20 \n\t"
+ "vse8.v v28, (%[DST]) \n\t"
+ "addi %[DST], %[DST], 20 \n\t"
+ "vse8.v v30, (%[DST]) \n\t"
+ "addi %[DST], %[DST], 16 \n\t"
+ "bge %[K], t3, LOOP_MAIN%= \n\t"
+ "blt %[K], t2, TAIL%= \n\t"
+ "LOOP_K%=: \n\t"
+ "vsetvli t1, %[K], e32, m2 \n\t"
+ "vle32.v v0, (%[SRC]) \n\t"
+ "addi %[SRC], %[SRC], 64 \n\t"
+ "sub %[K], %[K], t1 \n\t"
+ "vfabs.v v16, v0 \n\t"
+ "vsetvli t0, zero, e32, m1 \n\t"
+ "vfmax.vv v16, v16, v17 \n\t"
+ "vse32.v v16, (%[BUFFER]) \n\t"
+ "flw f0, (%[BUFFER]) \n\t"
+ "flw f1, 4(%[BUFFER]) \n\t"
+ "flw f2, 8(%[BUFFER]) \n\t"
+ "flw f3, 12(%[BUFFER]) \n\t"
+ "flw f4, 16(%[BUFFER]) \n\t"
+ "flw f5, 20(%[BUFFER]) \n\t"
+ "flw f6, 24(%[BUFFER]) \n\t"
+ "flw f7, 28(%[BUFFER]) \n\t"
+ "fmax.s f1, f0, f1 \n\t"
+ "fmax.s f3, f2, f3 \n\t"
+ "fmax.s f5, f4, f5 \n\t"
+ "fmax.s f7, f6, f7 \n\t"
+ "fmax.s f3, f1, f3 \n\t"
+ "fmax.s f7, f5, f7 \n\t"
+ "fmax.s f10, f3, f7 \n\t"
+ "fmul.s f10, f10, %[RMAXREC] \n\t"
+ "fsw f10, (%[DST]) \n\t"
+ "addi %[DST], %[DST], 4 \n\t"
+ "fdiv.s f11, %[FONE], f10 \n\t"
+ "vsetvli t0, zero, e32, m2 \n\t"
+ "vfmul.vf v16, v0, f11 \n\t"
+ "vfcvt.x.f.v v16, v16 \n\t"
+ "vsetvli t0, zero, e16, m1 \n\t"
+ "vnclip.wx v16, v16, zero \n\t"
+ "vsetvli t0, t1, e8, mf2 \n\t"
+ "vnclip.wx v16, v16, zero \n\t"
+ "vse8.v v16, (%[DST]) \n\t"
+ "addi %[DST], %[DST], 16 \n\t"
+ "bge %[K], t2, LOOP_K%= \n\t"
+ "TAIL%=: \n\t"
+ "blez %[K], END%= \n\t"
+ "vsetvli t0, t3, e32, m2 \n\t"
+ "vxor.vv v16, v16, v16 \n\t"
+ "jal x0, LOOP_K%= \n\t"
+ "END%=: \n\t"
+ : [SRC] "+r"(SRC), [DST] "+r"(DST), [K] "+r"(CountK)
+ : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [BUFFER] "r"(buffer)
+ : "cc", "t3", "t2", "t1", "t0", "a1", "f0", "f1", "f2", "f3", "f4", "f5", "f6", "f7", "f10", "f11", "f12",
+ "f13", "f14", "f15", "f16", "f17");
+ } else if (BlkLen == 32) {
+ __asm__ volatile(
+ "addi t3, zero, 32*4 \n\t"
+ "addi t2, zero, 32 \n\t"
+
+ "addi a1, %[SRC], 0 \n\t"
+ "addi a2, %[SRC], 128 \n\t"
+ "addi a3, %[SRC], 256 \n\t"
+ "addi a4, %[SRC], 384 \n\t"
+
+ "addi s1, %[DST], 0 \n\t"
+ "addi s2, %[DST], 36 \n\t"
+ "addi s3, %[DST], 72 \n\t"
+ "addi s4, %[DST], 108 \n\t"
+ "blt %[K], t3, LOOP_K%= \n\t"
+ "blt %[K], t2, TAIL%= \n\t"
+
+ "LOOP_MAIN%=: \n\t"
+ "vsetvli t1, zero, e32, m4 \n\t"
+ "addi %[K], %[K], -128 \n\t"
+ "vle32.v v0, (a1) \n\t"
+ "addi a1, a1, 512 \n\t"
+ "vle32.v v4, (a2) \n\t"
+ "addi a2, a2, 512 \n\t"
+ "vle32.v v8, (a3) \n\t"
+ "addi a3, a3, 512 \n\t"
+ "vle32.v v12, (a4) \n\t"
+ "addi a4, a4, 512 \n\t"
+ "vfabs.v v16, v0 \n\t"
+ "vfabs.v v20, v4 \n\t"
+ "vfabs.v v24, v8 \n\t"
+ "vfabs.v v28, v12 \n\t"
+ "vsetvli t0, zero, e32, m2 \n\t"
+ "vfmax.vv v16, v16, v18 \n\t"
+ "vfmax.vv v20, v20, v22 \n\t"
+ "vfmax.vv v24, v24, v26 \n\t"
+ "vfmax.vv v28, v28, v30 \n\t"
+ "vsetvli t0, zero, e32, m1 \n\t"
+ "vfmax.vv v16, v16, v17 \n\t"
+ "vfmax.vv v20, v20, v21 \n\t"
+ "vfmax.vv v24, v24, v25 \n\t"
+ "vfmax.vv v28, v28, v29 \n\t"
+
+ "vfredmax.vs v17, v16, v17 \n\t"
+ "vfredmax.vs v21, v20, v21 \n\t"
+ "vfredmax.vs v25, v24, v25 \n\t"
+ "vfredmax.vs v29, v28, v29 \n\t"
+ "vfmv.f.s f10, v17 \n\t"
+ "vfmv.f.s f11, v21 \n\t"
+ "vfmv.f.s f12, v25 \n\t"
+ "vfmv.f.s f13, v29 \n\t"
+
+ "fmul.s f10, f10, %[RMAXREC] \n\t"
+ "fmul.s f11, f11, %[RMAXREC] \n\t"
+ "fmul.s f12, f12, %[RMAXREC] \n\t"
+ "fmul.s f13, f13, %[RMAXREC] \n\t"
+ "fsw f10, (s1) \n\t"
+ "addi s1, s1, 4 \n\t"
+
+ "fsw f11, (s2) \n\t"
+ "addi s2, s2, 4 \n\t"
+ "fsw f12, (s3) \n\t"
+ "addi s3, s3, 4 \n\t"
+ "fsw f13, (s4) \n\t"
+ "addi s4, s4, 4 \n\t"
+ "fdiv.s f10, %[FONE], f10 \n\t"
+ "fdiv.s f11, %[FONE], f11 \n\t"
+ "fdiv.s f12, %[FONE], f12 \n\t"
+ "fdiv.s f13, %[FONE], f13 \n\t"
+ "vsetvli t0, zero, e32, m4 \n\t"
+ "vfmul.vf v16, v0, f10 \n\t"
+ "vfmul.vf v20, v4, f11 \n\t"
+ "vfmul.vf v24, v8, f12 \n\t"
+ "vfmul.vf v28, v12, f13 \n\t"
+ "vfcvt.x.f.v v16, v16 \n\t"
+ "vfcvt.x.f.v v20, v20 \n\t"
+ "vfcvt.x.f.v v24, v24 \n\t"
+ "vfcvt.x.f.v v28, v28 \n\t"
+ "vsetvli t0, zero, e16, m2 \n\t"
+ "vnclip.wx v16, v16, zero \n\t"
+ "vnclip.wx v20, v20, zero \n\t"
+ "vnclip.wx v24, v24, zero \n\t"
+ "vnclip.wx v28, v28, zero \n\t"
+ "vsetvli t0, t1, e8, m1 \n\t"
+ "vnclip.wx v16, v16, zero \n\t"
+ "vnclip.wx v20, v20, zero \n\t"
+ "vnclip.wx v24, v24, zero \n\t"
+ "vnclip.wx v28, v28, zero \n\t"
+ "vse8.v v16, (s1) \n\t"
+ "addi s1, s1, 140 \n\t"
+ "vse8.v v20, (s2) \n\t"
+ "addi s2, s2, 140 \n\t"
+ "vse8.v v24, (s3) \n\t"
+ "addi s3, s3, 140 \n\t"
+ "vse8.v v28, (s4) \n\t"
+ "addi s4, s4, 140 \n\t"
+ "bge %[K], t3, LOOP_MAIN%= \n\t"
+ "blt %[K], t2, TAIL%= \n\t"
+ "LOOP_K%=: \n\t"
+ "vsetvli t1, %[K], e32, m4 \n\t"
+ "vle32.v v0, (a1) \n\t"
+ "addi a1, a1, 128 \n\t"
+ "sub %[K], %[K], t1 \n\t"
+ "vfabs.v v16, v0 \n\t"
+ "vsetvli t0, zero, e32, m2 \n\t"
+ "vfmax.vv v16, v16, v18 \n\t"
+ "vsetvli t0, zero, e32, m1 \n\t"
+ "vfmax.vv v16, v16, v17 \n\t"
+ "vfredmax.vs v17, v16, v17 \n\t"
+ "vfmv.f.s f10, v17 \n\t"
+
+ "fmul.s f10, f10, %[RMAXREC] \n\t"
+ "fsw f10, (s1) \n\t"
+ "addi s1, s1, 4 \n\t"
+ "fdiv.s f11, %[FONE], f10 \n\t"
+ "vsetvli t0, zero, e32, m4 \n\t"
+ "vfmul.vf v16, v0, f11 \n\t"
+ "vfcvt.x.f.v v16, v16 \n\t"
+ "vsetvli t0, zero, e16, m2 \n\t"
+ "vnclip.wx v16, v16, zero \n\t"
+ "vsetvli t0, zero, e8, m1 \n\t"
+ "vnclip.wx v16, v16, zero \n\t"
+ "vse8.v v16, (s1) \n\t"
+ "addi s1, s1, 32 \n\t"
+ "bge %[K], t2, LOOP_K%= \n\t"
+ "TAIL%=: \n\t"
+ "blez %[K], END%= \n\t"
+ "vsetvli t0, t3, e32, m4 \n\t"
+ "vxor.vv v0, v0, v0 \n\t"
+ "vxor.vv v16, v16, v16 \n\t"
+ "jal x0, LOOP_K%= \n\t"
+ "END%=: \n\t"
+ : [K] "+r"(CountK)
+ : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [SRC] "r"(SRC), [DST] "r"(DST)
+ : "cc", "t3", "t2", "t1", "t0", "a1", "a2", "a3", "a4", "s1", "s2", "s3", "s4", "f10", "f11", "f12", "f13");
+ } else if (BlkLen == 64) {
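+        // Two 64-element blocks per main-loop iteration; each output block is a
+        // 4-byte fp32 scale followed by 64 int8 values, so s1 and s2 start 68 bytes
+        // apart.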
+ __asm__ volatile(
+ "addi t3, zero, 64*2 \n\t"
+ "addi t2, zero, 64 \n\t"
+ "addi a1, %[SRC], 0 \n\t"
+ "addi a2, %[SRC], 256 \n\t"
+ "addi s1, %[DST], 0 \n\t"
+ "addi s2, %[DST], 68 \n\t"
+ "blt %[K], t3, LOOP_K%= \n\t"
+ "blt %[K], t2, TAIL%= \n\t"
+ "LOOP_MAIN%=: \n\t"
+ "vsetvli t1, zero, e32, m8 \n\t"
+ "addi %[K], %[K], -128 \n\t"
+ "vle32.v v0, (a1) \n\t"
+ "addi a1, a1, 512 \n\t"
+ "vle32.v v8, (a2) \n\t"
+ "addi a2, a2, 512 \n\t"
+ "vfabs.v v16, v0 \n\t"
+ "vfabs.v v24, v8 \n\t"
+ "vsetvli t0, zero, e32, m4 \n\t"
+ "vfmax.vv v16, v16, v20 \n\t"
+ "vfmax.vv v24, v24, v28 \n\t"
+ "vsetvli t0, zero, e32, m2 \n\t"
+ "vfmax.vv v16, v16, v18 \n\t"
+ "vfmax.vv v24, v24, v26 \n\t"
+ "vsetvli t0, zero, e32, m1 \n\t"
+ "vfmax.vv v16, v16, v17 \n\t"
+ "vfmax.vv v24, v24, v25 \n\t"
+ "vfredmax.vs v17, v16, v17 \n\t"
+ "vfredmax.vs v25, v24, v25 \n\t"
+ "vfmv.f.s f10, v17 \n\t"
+ "vfmv.f.s f11, v25 \n\t"
+ "fmul.s f10, f10, %[RMAXREC] \n\t"
+ "fmul.s f11, f11, %[RMAXREC] \n\t"
+ "fsw f10, (s1) \n\t"
+ "addi s1, s1, 4 \n\t"
+ "fsw f11, (s2) \n\t"
+ "addi s2, s2, 4 \n\t"
+ "fdiv.s f10, %[FONE], f10 \n\t"
+ "fdiv.s f11, %[FONE], f11 \n\t"
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vfmul.vf v16, v0, f10 \n\t"
+ "vfmul.vf v24, v8, f11 \n\t"
+ "vfcvt.x.f.v v16, v16 \n\t"
+ "vfcvt.x.f.v v24, v24 \n\t"
+ "vsetvli t0, zero, e16, m4 \n\t"
+ "vnclip.wx v16, v16, zero \n\t"
+ "vnclip.wx v24, v24, zero \n\t"
+ "vsetvli t0, t1, e8, m2 \n\t"
+ "vnclip.wx v16, v16, zero \n\t"
+ "vnclip.wx v24, v24, zero \n\t"
+ "vse8.v v16, (s1) \n\t"
+ "addi s1, s1, 132 \n\t"
+ "vse8.v v24, (s2) \n\t"
+ "addi s2, s2, 132 \n\t"
+ "bge %[K], t3, LOOP_MAIN%= \n\t"
+ "blt %[K], t2, TAIL%= \n\t"
+ "LOOP_K%=: \n\t"
+ "vsetvli t1, %[K], e32, m8 \n\t"
+ "vle32.v v0, (a1) \n\t"
+ "addi a1, a1, 256 \n\t"
+ "sub %[K], %[K], t1 \n\t"
+ "vfabs.v v16, v0 \n\t"
+ "vsetvli t0, zero, e32, m4 \n\t"
+ "vfmax.vv v16, v16, v20 \n\t"
+ "vsetvli t0, zero, e32, m2 \n\t"
+ "vfmax.vv v16, v16, v18 \n\t"
+ "vsetvli t0, zero, e32, m1 \n\t"
+ "vfmax.vv v16, v16, v17 \n\t"
+ "vfredmax.vs v17, v16, v17 \n\t"
+ "vfmv.f.s f10, v17 \n\t"
+ "fmul.s f10, f10, %[RMAXREC] \n\t"
+ "fsw f10, (s1) \n\t"
+ "addi s1, s1, 4 \n\t"
+ "fdiv.s f11, %[FONE], f10 \n\t"
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vfmul.vf v16, v0, f11 \n\t"
+ "vfcvt.x.f.v v16, v16 \n\t"
+ "vsetvli t0, zero, e16, m4 \n\t"
+ "vnclip.wx v16, v16, zero \n\t"
+ "vsetvli t0, zero, e8, m2 \n\t"
+ "vnclip.wx v16, v16, zero \n\t"
+ "vse8.v v16, (s1) \n\t"
+ "addi s1, s1, 64 \n\t"
+ "bge %[K], t2, LOOP_K%= \n\t"
+ "TAIL%=: \n\t"
+ "blez %[K], END%= \n\t"
+ "vsetvli t0, t3, e32, m8 \n\t"
+ "vxor.vv v0, v0, v0 \n\t"
+ "vxor.vv v16, v16, v16 \n\t"
+ "jal x0, LOOP_K%= \n\t"
+ "END%=: \n\t"
+ : [K] "+r"(CountK)
+ : [SRC] "r"(SRC), [DST] "r"(DST), [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal)
+ : "cc", "t3", "t2", "t1", "t0", "a1", "a2", "s1", "s2", "f10", "f11");
+ } else if (BlkLen == 128) {
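+        // One 128-element block per iteration: store the fp32 scale, then the 128
+        // int8 values; the tail reuses the QUANT sequence with zero-padded loads.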
+ __asm__ volatile(
+ "addi t2, zero, 128 \n\t"
+ "addi a1, %[SRC], 0 \n\t"
+ "addi a2, %[SRC], 256 \n\t"
+ "blt %[K], t2, TAIL%= \n\t"
+ "LOOP_K%=: \n\t"
+ "vsetvli t1, zero, e32, m8 \n\t"
+ "vle32.v v0, (a1) \n\t"
+ "addi a1, a1, 512 \n\t"
+ "vle32.v v8, (a2) \n\t"
+ "addi a2, a2, 512 \n\t"
+ "sub %[K], %[K], t2 \n\t"
+ "QUANT%=: \n\t"
+ "vfabs.v v16, v0 \n\t"
+ "vfabs.v v24, v8 \n\t"
+ "vfmax.vv v24, v16, v24 \n\t"
+ "vsetvli t1, zero, e32, m4 \n\t"
+ "vfmax.vv v28, v24, v28 \n\t"
+ "vsetvli t0, zero, e32, m2 \n\t"
+ "vfmax.vv v30, v28, v30 \n\t"
+ "vsetvli t0, zero, e32, m1 \n\t"
+ "vfmax.vv v30, v30, v31 \n\t"
+ "vfredmax.vs v31, v30, v31 \n\t"
+ "vfmv.f.s f10, v31 \n\t"
+ "fmul.s f10, f10, %[RMAXREC] \n\t"
+ "fsw f10, (%[DST]) \n\t"
+ "addi %[DST], %[DST], 4 \n\t"
+ "fdiv.s f11, %[FONE], f10 \n\t"
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vfmul.vf v16, v0, f11 \n\t"
+ "vfmul.vf v24, v8, f11 \n\t"
+ "vfcvt.x.f.v v16, v16 \n\t"
+ "vfcvt.x.f.v v24, v24 \n\t"
+ "vsetvli t0, zero, e16, m4 \n\t"
+ "vnclip.wx v16, v16, zero \n\t"
+ "vnclip.wx v20, v24, zero \n\t"
+ "vsetvli t0, zero, e8, m4 \n\t"
+ "vnclip.wx v16, v16, zero \n\t"
+ "vse8.v v16, (%[DST]) \n\t"
+ "addi %[DST], %[DST], 128 \n\t"
+ "bge %[K], t2, LOOP_K%= \n\t"
+ "TAIL%=: \n\t"
+ "blez %[K], END%= \n\t"
+ "vsetvli t1, zero, e32, m8 \n\t"
+ "vxor.vv v0, v0, v0 \n\t"
+ "vxor.vv v8, v8, v8 \n\t"
+ "vsetvli t0, %[K], e32, m8 \n\t"
+ "vle32.v v0, (a1) \n\t"
+ "sub %[K], %[K], t0 \n\t"
+ "vsetvli t0, %[K], e32, m8 \n\t"
+ "vle32.v v8, (a2) \n\t"
+ "sub %[K], %[K], t0 \n\t"
+ "vsetvli t1, zero, e32, m8 \n\t"
+ "jal x0, QUANT%= \n\t"
+ "END%=: \n\t"
+
+ : [DST] "+r"(DST), [K] "+r"(CountK)
+ : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [SRC] "r"(SRC)
+ : "cc", "t2", "t1", "t0", "a1", "a2", "f10", "f11");
+ } else {
+ float buffer[8] = { 0.0f };
+ size_t cnt = BlkLen / 256;
+
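+        // Fallback for larger block lengths (the loop counts assume BlkLen is a
+        // multiple of 256, cnt = BlkLen / 256): LOOP_CMP first scans the block for
+        // its absolute maximum in 256-element chunks, SRC is rewound, and LOOP_QUANT
+        // re-reads the block to emit the int8 values; LOOP_TAIL handles a final
+        // partial block the same way.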
+ __asm__ volatile(
+ "slli t3, %[BLK], 2 \n\t"
+ "blt %[K], %[BLK], LOOP_TAIL%= \n\t"
+ "LOOP_MAIN%=: \n\t"
+ "vsetvli t0, zero, e32, m1 \n\t"
+ "vxor.vv v31, v31, v31 \n\t"
+ "vse32.v v31, (%[BUFFER]) \n\t"
+ "addi t6, %[CNT], 0 \n\t"
+ "LOOP_CMP%=: \n\t"
+ "addi t6, t6, -1 \n\t"
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vle32.v v0, (%[SRC]) \n\t"
+ "addi %[SRC], %[SRC], 256 \n\t"
+ "vle32.v v8, (%[SRC]) \n\t"
+ "addi %[SRC], %[SRC], 256 \n\t"
+ "vle32.v v16, (%[SRC]) \n\t"
+ "addi %[SRC], %[SRC], 256 \n\t"
+ "vle32.v v24, (%[SRC]) \n\t"
+ "addi %[SRC], %[SRC], 256 \n\t"
+ "vfabs.v v0, v0 \n\t"
+ "vfabs.v v8, v8 \n\t"
+ "vfabs.v v16, v16 \n\t"
+ "vfabs.v v24, v24 \n\t"
+ "vfmax.vv v8, v0, v8 \n\t"
+ "vfmax.vv v16, v16, v24 \n\t"
+ "vfmax.vv v0, v0, v16 \n\t"
+ "vsetvli t0, zero, e32, m4 \n\t"
+ "vfmax.vv v0, v0, v4 \n\t"
+ "vsetvli t0, zero, e32, m2 \n\t"
+ "vfmax.vv v0, v0, v2 \n\t"
+ "vsetvli t0, zero, e32, m1 \n\t"
+ "vfmax.vv v0, v0, v1 \n\t"
+ "vle32.v v30, (%[BUFFER]) \n\t"
+ "vfmax.vv v31, v30, v0 \n\t"
+ "vse32.v v31, (%[BUFFER]) \n\t"
+ "bnez t6, LOOP_CMP%= \n\t"
+ "sub %[SRC], %[SRC], t3 \n\t"
+ "addi t6, %[CNT], 0 \n\t"
+ "flw f0, (%[BUFFER]) \n\t"
+ "flw f1, 4(%[BUFFER]) \n\t"
+ "flw f2, 8(%[BUFFER]) \n\t"
+ "flw f3, 12(%[BUFFER]) \n\t"
+ "flw f4, 16(%[BUFFER]) \n\t"
+ "flw f5, 20(%[BUFFER]) \n\t"
+ "flw f6, 24(%[BUFFER]) \n\t"
+ "flw f7, 28(%[BUFFER]) \n\t"
+ "fmax.s f1, f0, f1 \n\t"
+ "fmax.s f3, f2, f3 \n\t"
+ "fmax.s f5, f4, f5 \n\t"
+ "fmax.s f7, f6, f7 \n\t"
+ "fmax.s f3, f1, f3 \n\t"
+ "fmax.s f7, f5, f7 \n\t"
+ "fmax.s f10, f3, f7 \n\t"
+ "fmul.s f10, f10, %[RMAXREC] \n\t"
+ "fsw f10, (%[DST]) \n\t"
+ "addi %[DST], %[DST], 4 \n\t"
+ "fdiv.s f11, %[FONE], f10 \n\t"
+ "addi t6, %[CNT], 0 \n\t"
+ "LOOP_QUANT%=: \n\t"
+ "addi t6, t6, -1 \n\t"
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vle32.v v0, (%[SRC]) \n\t"
+ "addi %[SRC], %[SRC], 256 \n\t"
+ "vle32.v v8, (%[SRC]) \n\t"
+ "addi %[SRC], %[SRC], 256 \n\t"
+ "vle32.v v16, (%[SRC]) \n\t"
+ "addi %[SRC], %[SRC], 256 \n\t"
+ "vle32.v v24, (%[SRC]) \n\t"
+ "addi %[SRC], %[SRC], 256 \n\t"
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vfmul.vf v0, v0, f11 \n\t"
+ "vfmul.vf v8, v8, f11 \n\t"
+ "vfmul.vf v16, v16, f11 \n\t"
+ "vfmul.vf v24, v24, f11 \n\t"
+ "vfcvt.x.f.v v0, v0 \n\t"
+ "vfcvt.x.f.v v8, v8 \n\t"
+ "vfcvt.x.f.v v16, v16 \n\t"
+ "vfcvt.x.f.v v24, v24 \n\t"
+ "vsetvli t0, zero, e16, m4 \n\t"
+ "vnclip.wx v0, v0, zero \n\t"
+ "vnclip.wx v4, v8, zero \n\t"
+ "vnclip.wx v8, v16, zero \n\t"
+ "vnclip.wx v12, v24, zero \n\t"
+ "vsetvli t0, zero, e8, m4 \n\t"
+ "vnclip.wx v0, v0, zero \n\t"
+ "vnclip.wx v4, v8, zero \n\t"
+ "vse8.v v0, (%[DST]) \n\t"
+ "addi %[DST], %[DST], 128 \n\t"
+ "vse8.v v4, (%[DST]) \n\t"
+ "addi %[DST], %[DST], 128 \n\t"
+ "bnez t6, LOOP_QUANT%= \n\t"
+ "sub %[K], %[K], %[BLK] \n\t"
+ "bge %[K], %[BLK], LOOP_MAIN%= \n\t"
+ "blez %[K], END%= \n\t"
+ "LOOP_TAIL%=: \n\t"
+ "vsetvli t0, zero, e32, m1 \n\t"
+ "vxor.vv v31, v31, v31 \n\t"
+ "vse32.v v31, (%[BUFFER]) \n\t"
+ "addi t6, %[K], 0 \n\t"
+ "addi s1, %[SRC], 0 \n\t"
+ "TAIL_CMP%=: \n\t"
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vxor.vv v0, v0, v0 \n\t"
+ "vsetvli t0, t6, e32, m8 \n\t"
+ "vle32.v v0, (%[SRC]) \n\t"
+ "addi %[SRC], %[SRC], 256 \n\t"
+ "sub t6, t6, t0 \n\t"
+ "vfabs.v v0, v0 \n\t"
+ "vsetvli t0, zero, e32, m4 \n\t"
+ "vfmax.vv v0, v0, v4 \n\t"
+ "vsetvli t0, zero, e32, m2 \n\t"
+ "vfmax.vv v0, v0, v2 \n\t"
+ "vsetvli t0, zero, e32, m1 \n\t"
+ "vfmax.vv v0, v0, v1 \n\t"
+ "vle32.v v30, (%[BUFFER]) \n\t"
+ "vfmax.vv v31, v30, v0 \n\t"
+ "vse32.v v31, (%[BUFFER]) \n\t"
+ "bnez t6, TAIL_CMP%= \n\t"
+ "addi t6, %[K], 0 \n\t"
+ "flw f0, (%[BUFFER]) \n\t"
+ "flw f1, 4(%[BUFFER]) \n\t"
+ "flw f2, 8(%[BUFFER]) \n\t"
+ "flw f3, 12(%[BUFFER]) \n\t"
+ "flw f4, 16(%[BUFFER]) \n\t"
+ "flw f5, 20(%[BUFFER]) \n\t"
+ "flw f6, 24(%[BUFFER]) \n\t"
+ "flw f7, 28(%[BUFFER]) \n\t"
+ "fmax.s f1, f0, f1 \n\t"
+ "fmax.s f3, f2, f3 \n\t"
+ "fmax.s f5, f4, f5 \n\t"
+ "fmax.s f7, f6, f7 \n\t"
+ "fmax.s f3, f1, f3 \n\t"
+ "fmax.s f7, f5, f7 \n\t"
+ "fmax.s f10, f3, f7 \n\t"
+ "fmul.s f10, f10, %[RMAXREC] \n\t"
+ "fsw f10, (%[DST]) \n\t"
+ "addi %[DST], %[DST], 4 \n\t"
+ "fdiv.s f11, %[FONE], f10 \n\t"
+ "addi t6, %[K], 0 \n\t"
+ "TAIL_QUANT%=: \n\t"
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vxor.vv v0, v0, v0 \n\t"
+ "vsetvli t1, t6, e32, m8 \n\t"
+ "vle32.v v0, (s1) \n\t"
+ "addi s1, s1, 256 \n\t"
+ "sub t6, t6, t1 \n\t"
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vfmul.vf v0, v0, f11 \n\t"
+ "vfcvt.x.f.v v0, v0 \n\t"
+ "vsetvli t0, zero, e16, m4 \n\t"
+ "vnclip.wx v0, v0, zero \n\t"
+ "vsetvli t0, t1, e8, m2 \n\t"
+ "vnclip.wx v0, v0, zero \n\t"
+ "vse8.v v0, (%[DST]) \n\t"
+ "addi %[DST], %[DST], 64 \n\t"
+ "bnez t6, TAIL_QUANT%= \n\t"
+ "END%=: \n\t"
+ : [SRC] "+r"(SRC), [DST] "+r"(DST), [K] "+r"(CountK)
+ : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [BLK] "r"(BlkLen), [BUFFER] "r"(buffer),
+ [CNT] "r"(cnt)
+ : "cc", "t1", "t0", "t6", "s1", "f0", "f1", "f2", "f3", "f4", "f5", "f6");
+ }
+}
+
+} // namespace ime1
+
+namespace {
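+// Inline-assembly fragments shared by the GEMM kernels below. The numeric
+// suffixes appear to encode the tile shape a fragment covers (rows of A x
+// columns of B x elements consumed per step).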
+#define SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 \
+ "vmadot v16, v14, v0 \n\t" \
+ "vmadot v18, v14, v1 \n\t" \
+ "vmadot v20, v14, v2 \n\t" \
+ "vmadot v22, v14, v3 \n\t" \
+ "vmadot v16, v15, v4 \n\t" \
+ "vmadot v18, v15, v5 \n\t" \
+ "vmadot v20, v15, v6 \n\t" \
+ "vmadot v22, v15, v7 \n\t"
+
+#define SQ4BIT_KERNEL_ACC_1X4X4 \
+ "vfcvt.f.x.v v16, v16 \n\t" \
+ "vfcvt.f.x.v v18, v18 \n\t" \
+ "vfcvt.f.x.v v20, v20 \n\t" \
+ "vfcvt.f.x.v v22, v22 \n\t" \
+ "addi s2, s1, 16 \n\t" \
+ "addi s3, s1, 32 \n\t" \
+ "addi s4, s1, 48 \n\t" \
+ "addi s6, s5, 12 \n\t" \
+ "vfmacc.vv v28, v16, v24 \n\t" \
+ "vfmacc.vv v29, v18, v25 \n\t" \
+ "vfmacc.vv v30, v20, v26 \n\t" \
+ "vfmacc.vv v31, v22, v27 \n\t"
+
+#define SQ4BIT_KERNEL_ACC_F16_1X4X4 \
+ "vfcvt.f.x.v v16, v16 \n\t" \
+ "vfcvt.f.x.v v18, v18 \n\t" \
+ "vfcvt.f.x.v v20, v20 \n\t" \
+ "vfcvt.f.x.v v22, v22 \n\t" \
+ "addi s2, s1, 8 \n\t" \
+ "addi s3, s1, 16 \n\t" \
+ "addi s4, s1, 24 \n\t" \
+ "addi s6, s5, 12 \n\t" \
+ "vfmacc.vv v28, v16, v24 \n\t" \
+ "vfmacc.vv v29, v18, v25 \n\t" \
+ "vfmacc.vv v30, v20, v26 \n\t" \
+ "vfmacc.vv v31, v22, v27 \n\t"
+
+#define SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 \
+ "vle8.v v4, (s1) \n\t" \
+ "addi s1, s1, 128 \n\t" \
+ "vle8.v v5, (s2) \n\t" \
+ "addi s2, s2, 128 \n\t" \
+ "vle8.v v6, (s3) \n\t" \
+ "addi s3, s3, 128 \n\t" \
+ "vle8.v v7, (s4) \n\t" \
+ "addi s4, s4, 128 \n\t" \
+ "vsetvli t0, zero, e8, mf4 \n\t" \
+ "vle8.v v14, (s5) \n\t" \
+ "addi s5, s5, 16 \n\t" \
+ "vle8.v v15, (s6) \n\t" \
+ "addi s6, s6, 16 \n\t" \
+ "addi t5, t5, -1 \n\t" \
+ "vsetvli t0, zero, e8, m1 \n\t" \
+ "vand.vi v0, v4, 15 \n\t" \
+ "vand.vi v1, v5, 15 \n\t" \
+ "vand.vi v2, v6, 15 \n\t" \
+ "vand.vi v3, v7, 15 \n\t" \
+ "vsrl.vi v4, v4, 4 \n\t" \
+ "vsrl.vi v5, v5, 4 \n\t" \
+ "vsrl.vi v6, v6, 4 \n\t" \
+ "vsrl.vi v7, v7, 4 \n\t"
+
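+// Loads 16 per-column zero points from s7 and replicates them via vrgather
+// (index vector v13) into v8..v11, four columns per register, so they can be
+// subtracted from the unpacked 4-bit B values.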
+#define SQ4BIT_KERNEL_LOAD_ZP_16X1 \
+ "vsetvli t0, zero, e8, mf2 \n\t" \
+ "vle8.v v1, (s7) \n\t" \
+ "vsetvli t0, zero, e8, m1 \n\t" \
+ "vrgather.vv v8, v1, v13 \n\t" \
+ "vadd.vi v13, v13, 4 \n\t" \
+ "vrgather.vv v9, v1, v13 \n\t" \
+ "vadd.vi v13, v13, 4 \n\t" \
+ "vrgather.vv v10, v1, v13 \n\t" \
+ "vadd.vi v13, v13, 4 \n\t" \
+ "vrgather.vv v11, v1, v13 \n\t" \
+ "vadd.vi v13, v13, -12 \n\t"
+
+// used by the M4 kernels: loads 32 packed 4-bit bytes of B from each of
+// s1..s4 and splits them into low nibbles (v2..v5) and high nibbles (v6..v9)
+#define LOAD_B_16x8x2 \
+ "vsetvli t0, zero, e8, m1 \n\t" \
+ "vle8.v v6, (s1) \n\t" \
+ "addi s1, s1, 32*4 \n\t" \
+ "vle8.v v7, (s2) \n\t" \
+ "addi s2, s2, 32*4 \n\t" \
+ "vle8.v v8, (s3) \n\t" \
+ "addi s3, s3, 32*4 \n\t" \
+ "vle8.v v9, (s4) \n\t" \
+ "addi s4, s4, 32*4 \n\t" \
+ \
+ "vand.vi v2, v6, 15 \n\t" \
+ "vand.vi v3, v7, 15 \n\t" \
+ "vand.vi v4, v8, 15 \n\t" \
+ "vand.vi v5, v9, 15 \n\t" \
+ \
+ "vsrl.vi v6, v6, 4 \n\t" \
+ "vsrl.vi v7, v7, 4 \n\t" \
+ "vsrl.vi v8, v8, 4 \n\t" \
+ "vsrl.vi v9, v9, 4 \n\t"
+
+// [s2|s5, s3, s4, s6]
+#define LOAD_SCALE_4x16_FP16 \
+ "addi s2, s5, -8 \n\t" \
+ "addi s3, s5, 8 \n\t" \
+ "addi s4, s5, 16 \n\t" \
+ "addi s6, s5, 24 \n\t" \
+ "li t1, 0xf0 \n\t" \
+ "vmv.s.x v0, t1 \n\t" \
+ "vsetvli t0, zero, e16, mf4 \n\t" \
+ "vle16.v v9, (s5) \n\t" \
+ "vle16.v v11, (s3) \n\t" \
+ "vle16.v v13, (s4) \n\t" \
+ "vle16.v v15, (s6) \n\t" \
+ "vsetvli t0, zero, e16, mf2 \n\t" \
+ "vle16.v v9, (s2), v0.t \n\t" \
+ "vle16.v v11, (s5), v0.t \n\t" \
+ "vle16.v v13, (s3), v0.t \n\t" \
+ "vle16.v v15, (s4), v0.t \n\t" \
+ "vfwcvt.f.f.v v8, v9 \n\t" \
+ "vfwcvt.f.f.v v10, v11 \n\t" \
+ "vfwcvt.f.f.v v12, v13 \n\t" \
+ "vfwcvt.f.f.v v14, v15 \n\t" \
+ "vsetvli t0, zero, e32, m1 \n\t" \
+ "vmv.v.v v9, v8 \n\t" \
+ "vmv.v.v v11, v10 \n\t" \
+ "vmv.v.v v13, v12 \n\t" \
+ "vmv.v.v v15, v14 \n\t" \
+ "li t1, 0xf0 \n\t" \
+ "vmv.s.x v0, t1 \n\t" \
+ "vsetvli t0, zero, e32, mf2 \n\t" \
+ "vfmul.vf v8, v8, f1 \n\t" \
+ "vfmul.vf v10, v10, f1 \n\t" \
+ "vfmul.vf v12, v12, f1 \n\t" \
+ "vfmul.vf v14, v14, f1 \n\t" \
+ "vfmul.vf v9, v9, f3 \n\t" \
+ "vfmul.vf v11, v11, f3 \n\t" \
+ "vfmul.vf v13, v13, f3 \n\t" \
+ "vfmul.vf v15, v15, f3 \n\t" \
+ "vsetvli t0, zero, e32, m1 \n\t" \
+ "vfmul.vf v8, v8, f2, v0.t \n\t" \
+ "vfmul.vf v10, v10, f2, v0.t \n\t" \
+ "vfmul.vf v12, v12, f2, v0.t \n\t" \
+ "vfmul.vf v14, v14, f2, v0.t \n\t" \
+ "vfmul.vf v9, v9, f4, v0.t \n\t" \
+ "vfmul.vf v11, v11, f4, v0.t \n\t" \
+ "vfmul.vf v13, v13, f4, v0.t \n\t" \
+ "vfmul.vf v15, v15, f4, v0.t \n\t"
+
+// [s2|s5, s3, s4, s6]
+#define LOAD_SCALE_4x16 \
+ "addi s2, s5, -16 \n\t" \
+ "addi s3, s5, 16 \n\t" \
+ "addi s4, s5, 32 \n\t" \
+ "addi s6, s5, 48 \n\t" \
+ "li t1, 0xf0 \n\t" \
+ "vmv.s.x v0, t1 \n\t" \
+ "vsetvli t0, zero, e32, mf2 \n\t" \
+ "vle32.v v8, (s5) \n\t" \
+ "vle32.v v10, (s3) \n\t" \
+ "vle32.v v12, (s4) \n\t" \
+ "vle32.v v14, (s6) \n\t" \
+ "vsetvli t0, zero, e32, m1 \n\t" \
+ "vle32.v v8, (s2), v0.t \n\t" \
+ "vle32.v v10, (s5), v0.t \n\t" \
+ "vle32.v v12, (s3), v0.t \n\t" \
+ "vle32.v v14, (s4), v0.t \n\t" \
+ "vmv.v.v v9, v8 \n\t" \
+ "vmv.v.v v11, v10 \n\t" \
+ "vmv.v.v v13, v12 \n\t" \
+ "vmv.v.v v15, v14 \n\t" \
+ "vsetvli t0, zero, e32, mf2 \n\t" \
+ "vfmul.vf v8, v8, f1 \n\t" \
+ "vfmul.vf v10, v10, f1 \n\t" \
+ "vfmul.vf v12, v12, f1 \n\t" \
+ "vfmul.vf v14, v14, f1 \n\t" \
+ "vfmul.vf v9, v9, f3 \n\t" \
+ "vfmul.vf v11, v11, f3 \n\t" \
+ "vfmul.vf v13, v13, f3 \n\t" \
+ "vfmul.vf v15, v15, f3 \n\t" \
+ "vsetvli t0, zero, e32, m1 \n\t" \
+ "vfmul.vf v8, v8, f2, v0.t \n\t" \
+ "vfmul.vf v10, v10, f2, v0.t \n\t" \
+ "vfmul.vf v12, v12, f2, v0.t \n\t" \
+ "vfmul.vf v14, v14, f2, v0.t \n\t" \
+ "vfmul.vf v9, v9, f4, v0.t \n\t" \
+ "vfmul.vf v11, v11, f4, v0.t \n\t" \
+ "vfmul.vf v13, v13, f4, v0.t \n\t" \
+ "vfmul.vf v15, v15, f4, v0.t \n\t"
+
+// [s1|BIAS, s2, s3, s4]
+#define LOAD_BIAS \
+ "vsetvli t0, zero, e32, mf2 \n\t" \
+ "li t1, 0xf0 \n\t" \
+ "vmv.s.x v0, t1 \n\t" \
+ "addi s1, %[BIAS], -16 \n\t" \
+ "addi s2, %[BIAS], 16 \n\t" \
+ "addi s3, %[BIAS], 32 \n\t" \
+ "addi s4, %[BIAS], 48 \n\t" \
+ \
+ "vle32.v v24, (%[BIAS]) \n\t" \
+ "vle32.v v26, (s2) \n\t" \
+ "vle32.v v28, (s3) \n\t" \
+ "vle32.v v30, (s4) \n\t" \
+ "vsetvli t0, zero, e32, m1 \n\t" \
+ "vle32.v v24, (s1), v0.t \n\t" \
+ "vle32.v v26, (%[BIAS]), v0.t \n\t" \
+ "vle32.v v28, (s2), v0.t \n\t" \
+ "vle32.v v30, (s3), v0.t \n\t" \
+ "vmv.v.v v25, v24 \n\t" \
+ "vmv.v.v v27, v26 \n\t" \
+ "vmv.v.v v29, v28 \n\t" \
+ "vmv.v.v v31, v30 \n\t"
+
+#define SQ4BIT_KERNEL_COMP_4x16x16 \
+ "vmadot v16, v10, v2 \n\t" \
+ "vmadot v18, v10, v3 \n\t" \
+ "vmadot v20, v10, v4 \n\t" \
+ "vmadot v22, v10, v5 \n\t" \
+ "vmadot v16, v11, v6 \n\t" \
+ "vmadot v18, v11, v7 \n\t" \
+ "vmadot v20, v11, v8 \n\t" \
+ "vmadot v22, v11, v9 \n\t"
+
+#define SAVE_RESULT_4x16 \
+ "addi a1, %[C], 0 \n\t" \
+ "add a2, %[C], %[LDC] \n\t" \
+ "add a3, a2, %[LDC] \n\t" \
+ "add a4, a3, %[LDC] \n\t" \
+ "addi a2, a2, -16 \n\t" \
+ "addi a4, a4, -16 \n\t" \
+ "li t1, 0xf0 \n\t" \
+ "vmv.s.x v0, t1 \n\t" \
+ "vsetvli t0, zero, e32, mf2 \n\t" \
+ \
+ "vse32.v v24, (a1) \n\t" \
+ "addi a1, a1, 16 \n\t" \
+ "vse32.v v25, (a3) \n\t" \
+ "addi a3, a3, 16 \n\t" \
+ \
+ "vse32.v v26, (a1) \n\t" \
+ "addi a1, a1, 16 \n\t" \
+ "vse32.v v27, (a3) \n\t" \
+ "addi a3, a3, 16 \n\t" \
+ \
+ "vse32.v v28, (a1) \n\t" \
+ "addi a1, a1, 16 \n\t" \
+ "vse32.v v29, (a3) \n\t" \
+ "addi a3, a3, 16 \n\t" \
+ \
+ "vse32.v v30, (a1) \n\t" \
+ "vse32.v v31, (a3) \n\t" \
+ "vsetvli t0, zero, e32, m1 \n\t" \
+ \
+ "vse32.v v24, (a2), v0.t \n\t" \
+ "addi a2, a2, 16 \n\t" \
+ "vse32.v v25, (a4), v0.t \n\t" \
+ "addi a4, a4, 16 \n\t" \
+ \
+ "vse32.v v26, (a2), v0.t \n\t" \
+ "addi a2, a2, 16 \n\t" \
+ "vse32.v v27, (a4), v0.t \n\t" \
+ "addi a4, a4, 16 \n\t" \
+ \
+ "vse32.v v28, (a2), v0.t \n\t" \
+ "addi a2, a2, 16 \n\t" \
+ "vse32.v v29, (a4), v0.t \n\t" \
+ "addi a4, a4, 16 \n\t" \
+ \
+ "vse32.v v30, (a2), v0.t \n\t" \
+ "vse32.v v31, (a4), v0.t \n\t"
+
+#define SQ4BIT_KERNEL_LOAD_ZP_16X1_v2 \
+ "vsetvli t0, zero, e8, mf2 \n\t" \
+ "vle8.v v11, (s6) \n\t" \
+ "vsetvli t0, zero, e8, m1 \n\t" \
+ "vrgather.vv v12, v11, v1 \n\t" \
+ "vadd.vi v1, v1, 4 \n\t" \
+ "vrgather.vv v13, v11, v1 \n\t" \
+ "vadd.vi v1, v1, 4 \n\t" \
+ "vrgather.vv v14, v11, v1 \n\t" \
+ "vadd.vi v1, v1, 4 \n\t" \
+ "vrgather.vv v15, v11, v1 \n\t" \
+ "vadd.vi v1, v1, -12 \n\t"
+
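+// M=4 x N=16 tile kernel for 4-bit B whose per-block scales are stored as
+// fp16. A is int8-quantized with one fp32 scale per row and block; ragged
+// column tails (CountN % 16) are accumulated into the local tmp tile and
+// copied back to C after the main loop.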
+template <bool HasZeroPoint>
+void SQ4BitGemmM4Kernel_CompInt8_ScaleFp16_Impl(size_t BlkLen,
+ const std::byte * QuantA,
+ const std::byte * QuantBData,
+ const float * QuantBScale,
+ const std::byte * QuantBZeroPoint,
+ float * C,
+ size_t CountN,
+ size_t BlockCountK,
+ const float * Bias,
+ const size_t ldc) {
+ GGML_UNUSED(QuantBScale);
+ GGML_UNUSED(QuantBZeroPoint);
+ size_t LDC = ldc * sizeof(float);
+ const size_t INNER = BlkLen / 16;
+ float tmp[4 * 16];
+
+ if constexpr (HasZeroPoint) {
+ for (size_t n = 0; n < CountN; n += 16) {
+ size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n;
+ std::byte * QuantBDataPtr = (std::byte *) QuantBData + //
+ n * BlockCountK * BlkLen / 2 + // b data
+ n * BlockCountK * sizeof(uint8_t) + // zp
+ n * BlockCountK * sizeof(_Float16); // scale
+ float * CPtr = C + n;
+ if (NBLKS < 16) {
+ CPtr = tmp;
+ LDC = 16 * sizeof(float);
+ }
+ if (Bias != nullptr) {
+ const float * bias = Bias + n;
+ if (NBLKS < 16) {
+ __asm__ volatile(
+ "vsetvli t0, %[N], e32, m2 \n\t"
+ "vle32.v v0, (%[SRC]) \n\t"
+ "vse32.v v0, (%[DST]) \n\t"
+ :
+ : [SRC] "r"(bias), [DST] "r"(tmp), [N] "r"(NBLKS)
+ : "cc", "t0");
+ bias = tmp;
+ }
+ __asm__ volatile(LOAD_BIAS
+
+ "addi t3, %[BlockCountK], 0 \n\t"
+
+ "vsetvli t0, zero, e8, m1 \n\t"
+ "li s1, 24 \n\t"
+ "vmv.v.i v1, 3 \n\t"
+ "vsetvli t0, s1, e8, m1 \n\t"
+ "vmv.v.i v1, 2 \n\t"
+ "vsetvli t0, zero, e8, mf2 \n\t"
+ "vmv.v.i v1, 1 \n\t"
+ "vsetvli t0, zero, e8, mf4 \n\t"
+ "vmv.v.i v1, 0 \n\t"
+
+ "addi a1, %[A], 0 \n\t"
+ "addi s1, %[B], 0 \n\t"
+
+ "BLOCK_COUNTK_LOOP%=: \n\t"
+ // scale offset
+ "addi s5, s1, 0 \n\t"
+ // zp offset
+ "addi s6, s1, 32 \n\t"
+ "addi s1, s6, 16 \n\t"
+ "addi s2, s1, 32 \n\t"
+ "addi s3, s1, 32*2 \n\t"
+ "addi s4, s1, 32*3 \n\t"
+
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vxor.vv v16, v16, v16 \n\t"
+ // load a scale
+ "flw f1, (a1) \n\t"
+ "flw f2, 4(a1) \n\t"
+ "flw f3, 8(a1) \n\t"
+ "flw f4, 12(a1) \n\t"
+ "addi a1, a1, 16 \n\t"
+ "addi t2, %[INNER], 0 \n\t"
+
+ SQ4BIT_KERNEL_LOAD_ZP_16X1_v2
+
+ "BLOCK_INNER_LOOP%=: \n\t"
+
+ LOAD_B_16x8x2
+
+ "vle8.v v10, (a1) \n\t"
+ "addi a1, a1, 32 \n\t"
+ "vle8.v v11, (a1) \n\t"
+ "addi a1, a1, 32 \n\t"
+ "vsub.vv v2, v2, v12 \n\t"
+ "vsub.vv v6, v6, v12 \n\t"
+ "vsub.vv v3, v3, v13 \n\t"
+ "vsub.vv v7, v7, v13 \n\t"
+ "vsub.vv v4, v4, v14 \n\t"
+ "vsub.vv v8, v8, v14 \n\t"
+ "vsub.vv v5, v5, v15 \n\t"
+ "vsub.vv v9, v9, v15 \n\t"
+
+ SQ4BIT_KERNEL_COMP_4x16x16
+
+ "addi t2, t2, -1 \n\t"
+ "bnez t2, BLOCK_INNER_LOOP%= \n\t"
+
+ LOAD_SCALE_4x16_FP16
+
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vfcvt.f.x.v v16, v16 \n\t"
+ "vfmacc.vv v24, v16, v8 \n\t"
+ "addi t3, t3, -1 \n\t"
+ "bnez t3, BLOCK_COUNTK_LOOP%= \n\t"
+
+ "RESULT_SAVE%=: \n\t"
+
+ SAVE_RESULT_4x16
+
+ :
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
+ [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr), [BIAS] "r"(bias)
+ : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1",
+ "s2", "s3", "s4", "s5", "s6");
+
+ } else {
+ __asm__ volatile(
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vxor.vv v24, v24, v24 \n\t"
+ "addi t3, %[BlockCountK], 0 \n\t"
+ "vsetvli t0, zero, e8, m1 \n\t"
+ "li s1, 24 \n\t"
+ "vmv.v.i v1, 3 \n\t"
+ "vsetvli t0, s1, e8, m1 \n\t"
+ "vmv.v.i v1, 2 \n\t"
+ "vsetvli t0, zero, e8, mf2 \n\t"
+ "vmv.v.i v1, 1 \n\t"
+ "vsetvli t0, zero, e8, mf4 \n\t"
+ "vmv.v.i v1, 0 \n\t"
+ "addi a1, %[A], 0 \n\t"
+ "addi s1, %[B], 0 \n\t"
+ "BLOCK_COUNTK_LOOP%=: \n\t"
+ // scale offset
+ "addi s5, s1, 0 \n\t"
+ // zp offset
+ "addi s6, s1, 32 \n\t"
+ "addi s1, s6, 16 \n\t"
+ "addi s2, s1, 32 \n\t"
+ "addi s3, s1, 32*2 \n\t"
+ "addi s4, s1, 32*3 \n\t"
+
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vxor.vv v16, v16, v16 \n\t"
+ // load a scale
+ "flw f1, (a1) \n\t"
+ "flw f2, 4(a1) \n\t"
+ "flw f3, 8(a1) \n\t"
+ "flw f4, 12(a1) \n\t"
+ "addi a1, a1, 16 \n\t"
+ "addi t2, %[INNER], 0 \n\t"
+
+ SQ4BIT_KERNEL_LOAD_ZP_16X1_v2
+
+ "BLOCK_INNER_LOOP%=: \n\t"
+
+ LOAD_B_16x8x2
+
+ "vle8.v v10, (a1) \n\t"
+ "addi a1, a1, 32 \n\t"
+ "vle8.v v11, (a1) \n\t"
+ "addi a1, a1, 32 \n\t"
+ "vsub.vv v2, v2, v12 \n\t"
+ "vsub.vv v6, v6, v12 \n\t"
+ "vsub.vv v3, v3, v13 \n\t"
+ "vsub.vv v7, v7, v13 \n\t"
+ "vsub.vv v4, v4, v14 \n\t"
+ "vsub.vv v8, v8, v14 \n\t"
+ "vsub.vv v5, v5, v15 \n\t"
+ "vsub.vv v9, v9, v15 \n\t"
+
+ SQ4BIT_KERNEL_COMP_4x16x16
+
+ "addi t2, t2, -1 \n\t"
+ "bnez t2, BLOCK_INNER_LOOP%= \n\t"
+
+ LOAD_SCALE_4x16_FP16
+
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vfcvt.f.x.v v16, v16 \n\t"
+ "vfmacc.vv v24, v16, v8 \n\t"
+ "addi t3, t3, -1 \n\t"
+ "bnez t3, BLOCK_COUNTK_LOOP%= \n\t"
+
+ "RESULT_SAVE%=: \n\t"
+
+ SAVE_RESULT_4x16
+
+ :
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
+ [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr)
+ : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3",
+ "s4", "s5", "s6");
+ }
+ }
+ } else {
+ for (size_t n = 0; n < CountN; n += 16) {
+ size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n;
+ std::byte * QuantBDataPtr = (std::byte *) QuantBData + //
+ n * BlockCountK * BlkLen / 2 + // b data
+ n * BlockCountK * sizeof(_Float16); // scale
+ float * CPtr = C + n;
+ if (NBLKS < 16) {
+ CPtr = tmp;
+ LDC = 16 * sizeof(float);
+ }
+ if (Bias != nullptr) {
+ const float * bias = Bias + n;
+ if (NBLKS < 16) {
+ __asm__ volatile(
+ "vsetvli t0, %[N], e32, m2 \n\t"
+ "vle32.v v0, (%[SRC]) \n\t"
+ "vse32.v v0, (%[DST]) \n\t"
+ :
+ : [SRC] "r"(bias), [DST] "r"(tmp), [N] "r"(NBLKS)
+ : "cc", "t0");
+ bias = tmp;
+ }
+ __asm__ volatile(LOAD_BIAS
+
+ "addi t3, %[BlockCountK], 0 \n\t"
+ "addi a1, %[A], 0 \n\t"
+ "addi s1, %[B], 0 \n\t"
+ "BLOCK_COUNTK_LOOP%=: \n\t"
+ "addi s5, s1, 0 \n\t"
+ "addi s1, s5, 32 \n\t"
+ "addi s2, s1, 32 \n\t"
+ "addi s3, s1, 32*2 \n\t"
+ "addi s4, s1, 32*3 \n\t"
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vxor.vv v16, v16, v16 \n\t"
+ // load a scale
+ "flw f1, (a1) \n\t"
+ "flw f2, 4(a1) \n\t"
+ "flw f3, 8(a1) \n\t"
+ "flw f4, 12(a1) \n\t"
+ "addi a1, a1, 16 \n\t"
+ "addi t2, %[INNER], 0 \n\t"
+ "BLOCK_INNER_LOOP%=: \n\t"
+
+ LOAD_B_16x8x2
+
+ "vsetvli t0, zero, e8, m1 \n\t"
+ "vle8.v v10, (a1) \n\t"
+ "addi a1, a1, 32 \n\t"
+ "vle8.v v11, (a1) \n\t"
+ "addi a1, a1, 32 \n\t"
+ "vadd.vi v2, v2, -8 \n\t"
+ "vadd.vi v3, v3, -8 \n\t"
+ "vadd.vi v4, v4, -8 \n\t"
+ "vadd.vi v5, v5, -8 \n\t"
+ "vadd.vi v6, v6, -8 \n\t"
+ "vadd.vi v7, v7, -8 \n\t"
+ "vadd.vi v8, v8, -8 \n\t"
+ "vadd.vi v9, v9, -8 \n\t"
+
+ SQ4BIT_KERNEL_COMP_4x16x16
+
+ "addi t2, t2, -1 \n\t"
+ "bnez t2, BLOCK_INNER_LOOP%= \n\t"
+
+ LOAD_SCALE_4x16_FP16
+
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vfcvt.f.x.v v16, v16 \n\t"
+ "vfmacc.vv v24, v16, v8 \n\t"
+ "addi t3, t3, -1 \n\t"
+ "bnez t3, BLOCK_COUNTK_LOOP%= \n\t"
+ "RESULT_SAVE%=: \n\t"
+
+ SAVE_RESULT_4x16
+
+ :
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
+ [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr), [BIAS] "r"(bias)
+ : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1",
+ "s2", "s3", "s4", "s5", "s6");
+
+ } else {
+ __asm__ volatile(
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vxor.vv v24, v24, v24 \n\t"
+ "addi t3, %[BlockCountK], 0 \n\t"
+ "addi a1, %[A], 0 \n\t"
+ "addi s1, %[B], 0 \n\t"
+ "BLOCK_COUNTK_LOOP%=: \n\t"
+ "addi s5, s1, 0 \n\t"
+ "addi s1, s5, 32 \n\t"
+ "addi s2, s1, 32 \n\t"
+ "addi s3, s1, 32*2 \n\t"
+ "addi s4, s1, 32*3 \n\t"
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vxor.vv v16, v16, v16 \n\t"
+ // load a scale
+ "flw f1, (a1) \n\t"
+ "flw f2, 4(a1) \n\t"
+ "flw f3, 8(a1) \n\t"
+ "flw f4, 12(a1) \n\t"
+ "addi a1, a1, 16 \n\t"
+ "addi t2, %[INNER], 0 \n\t"
+ "BLOCK_INNER_LOOP%=: \n\t"
+
+ LOAD_B_16x8x2
+
+ "vsetvli t0, zero, e8, m1 \n\t"
+ "vle8.v v10, (a1) \n\t"
+ "addi a1, a1, 32 \n\t"
+ "vle8.v v11, (a1) \n\t"
+ "addi a1, a1, 32 \n\t"
+ "vadd.vi v2, v2, -8 \n\t"
+ "vadd.vi v3, v3, -8 \n\t"
+ "vadd.vi v4, v4, -8 \n\t"
+ "vadd.vi v5, v5, -8 \n\t"
+ "vadd.vi v6, v6, -8 \n\t"
+ "vadd.vi v7, v7, -8 \n\t"
+ "vadd.vi v8, v8, -8 \n\t"
+ "vadd.vi v9, v9, -8 \n\t"
+
+ SQ4BIT_KERNEL_COMP_4x16x16
+
+ "addi t2, t2, -1 \n\t"
+ "bnez t2, BLOCK_INNER_LOOP%= \n\t"
+
+ LOAD_SCALE_4x16_FP16
+
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vfcvt.f.x.v v16, v16 \n\t"
+ "vfmacc.vv v24, v16, v8 \n\t"
+ "addi t3, t3, -1 \n\t"
+ "bnez t3, BLOCK_COUNTK_LOOP%= \n\t"
+ "RESULT_SAVE%=: \n\t"
+
+ SAVE_RESULT_4x16
+
+ :
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
+ [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr)
+ : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3",
+ "s4", "s5", "s6");
+ }
+ }
+ }
+ if (CountN % 16 != 0) {
+        // store the output from tmp to C when NBLKS is less than 16.
+ float * CPtr = C + CountN / 16 * 16;
+ const size_t N = CountN % 16;
+ LDC = ldc * sizeof(float);
+ __asm__ volatile(
+ "vsetvli t0, %[N], e32, m2 \n\t"
+ "vle32.v v0, (%[SRC]) \n\t"
+ "addi s2, %[SRC], 64 \n\t"
+ "addi s3, %[SRC], 64*2 \n\t"
+ "addi s4, %[SRC], 64*3 \n\t"
+ "vle32.v v2, (s2) \n\t"
+ "vle32.v v4, (s3) \n\t"
+ "vle32.v v6, (s4) \n\t"
+ "add t2, %[DST], %[LDC] \n\t"
+ "add t3, t2, %[LDC] \n\t"
+ "add t4, t3, %[LDC] \n\t"
+ "vse32.v v0, (%[DST]) \n\t"
+ "vse32.v v2, (t2) \n\t"
+ "vse32.v v4, (t3) \n\t"
+ "vse32.v v6, (t4) \n\t"
+ :
+ : [N] "r"(N), [SRC] "r"(tmp), [DST] "r"(CPtr), [LDC] "r"(LDC)
+ : "cc", "t0", "t2", "t3", "t4", "s2", "s3", "s4");
+ }
+}
+
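+// Same tiling as the ScaleFp16 variant above, but the per-block B scales are
+// stored as fp32 (note the sizeof(float) term in the packed-B offset).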
+template <bool HasZeroPoint>
+void SQ4BitGemmM4Kernel_CompInt8_Impl(size_t BlkLen,
+ const std::byte * QuantA,
+ const std::byte * QuantBData,
+ const float * QuantBScale,
+ const std::byte * QuantBZeroPoint,
+ float * C,
+ size_t CountN,
+ size_t BlockCountK,
+ const float * Bias,
+ const size_t ldc) {
+ GGML_UNUSED(QuantBScale);
+ GGML_UNUSED(QuantBZeroPoint);
+ size_t LDC = ldc * sizeof(float);
+ const size_t INNER = BlkLen / 16;
+ float tmp[4 * 16];
+
+ if constexpr (HasZeroPoint) {
+ for (size_t n = 0; n < CountN; n += 16) {
+ size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n;
+ std::byte * QuantBDataPtr = (std::byte *) QuantBData + //
+ n * BlockCountK * BlkLen / 2 + // b data
+ n * BlockCountK * sizeof(uint8_t) + // zp
+ n * BlockCountK * sizeof(float); // scale
+ float * CPtr = C + n;
+ if (NBLKS < 16) {
+ CPtr = tmp;
+ LDC = 16 * sizeof(float);
+ }
+ if (Bias != nullptr) {
+ const float * bias = Bias + n;
+ if (NBLKS < 16) {
+ __asm__ volatile(
+ "vsetvli t0, %[N], e32, m2 \n\t"
+ "vle32.v v0, (%[SRC]) \n\t"
+ "vse32.v v0, (%[DST]) \n\t"
+ :
+ : [SRC] "r"(bias), [DST] "r"(tmp), [N] "r"(NBLKS)
+ : "cc", "t0");
+ bias = tmp;
+ }
+
+ __asm__ volatile(LOAD_BIAS
+ "addi t3, %[BlockCountK], 0 \n\t"
+ "vsetvli t0, zero, e8, m1 \n\t"
+ "li s1, 24 \n\t"
+ "vmv.v.i v1, 3 \n\t"
+ "vsetvli t0, s1, e8, m1 \n\t"
+ "vmv.v.i v1, 2 \n\t"
+ "vsetvli t0, zero, e8, mf2 \n\t"
+ "vmv.v.i v1, 1 \n\t"
+ "vsetvli t0, zero, e8, mf4 \n\t"
+ "vmv.v.i v1, 0 \n\t"
+ "addi a1, %[A], 0 \n\t"
+ "addi s1, %[B], 0 \n\t"
+ "BLOCK_COUNTK_LOOP%=: \n\t"
+ // scale offset
+ "addi s5, s1, 0 \n\t"
+ // zp offset
+ "addi s6, s1, 64 \n\t"
+ "addi s1, s6, 16 \n\t"
+ "addi s2, s1, 32 \n\t"
+ "addi s3, s1, 32*2 \n\t"
+ "addi s4, s1, 32*3 \n\t"
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vxor.vv v16, v16, v16 \n\t"
+ // load a scale
+ "flw f1, (a1) \n\t"
+ "flw f2, 4(a1) \n\t"
+ "flw f3, 8(a1) \n\t"
+ "flw f4, 12(a1) \n\t"
+ "addi a1, a1, 16 \n\t"
+ "addi t2, %[INNER], 0 \n\t"
+
+ SQ4BIT_KERNEL_LOAD_ZP_16X1_v2
+
+ "BLOCK_INNER_LOOP%=: \n\t"
+
+ LOAD_B_16x8x2
+
+ "vle8.v v10, (a1) \n\t"
+ "addi a1, a1, 32 \n\t"
+ "vle8.v v11, (a1) \n\t"
+ "addi a1, a1, 32 \n\t"
+ "vsub.vv v2, v2, v12 \n\t"
+ "vsub.vv v6, v6, v12 \n\t"
+ "vsub.vv v3, v3, v13 \n\t"
+ "vsub.vv v7, v7, v13 \n\t"
+ "vsub.vv v4, v4, v14 \n\t"
+ "vsub.vv v8, v8, v14 \n\t"
+ "vsub.vv v5, v5, v15 \n\t"
+ "vsub.vv v9, v9, v15 \n\t"
+
+ SQ4BIT_KERNEL_COMP_4x16x16
+
+ "addi t2, t2, -1 \n\t"
+ "bnez t2, BLOCK_INNER_LOOP%= \n\t"
+
+ LOAD_SCALE_4x16
+
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vfcvt.f.x.v v16, v16 \n\t"
+ "vfmacc.vv v24, v16, v8 \n\t"
+ "addi t3, t3, -1 \n\t"
+ "bnez t3, BLOCK_COUNTK_LOOP%= \n\t"
+
+ "RESULT_SAVE%=: \n\t"
+
+ SAVE_RESULT_4x16
+
+ :
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
+ [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr), [BIAS] "r"(bias)
+ : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1",
+ "s2", "s3", "s4", "s5", "s6");
+
+ } else {
+ __asm__ volatile(
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vxor.vv v24, v24, v24 \n\t"
+ "addi t3, %[BlockCountK], 0 \n\t"
+ "vsetvli t0, zero, e8, m1 \n\t"
+ "li s1, 24 \n\t"
+ "vmv.v.i v1, 3 \n\t"
+ "vsetvli t0, s1, e8, m1 \n\t"
+ "vmv.v.i v1, 2 \n\t"
+ "vsetvli t0, zero, e8, mf2 \n\t"
+ "vmv.v.i v1, 1 \n\t"
+ "vsetvli t0, zero, e8, mf4 \n\t"
+ "vmv.v.i v1, 0 \n\t"
+ "addi a1, %[A], 0 \n\t"
+ "addi s1, %[B], 0 \n\t"
+ "BLOCK_COUNTK_LOOP%=: \n\t"
+ // scale offset
+ "addi s5, s1, 0 \n\t"
+ // zp offset
+ "addi s6, s1, 64 \n\t"
+ "addi s1, s6, 16 \n\t"
+ "addi s2, s1, 32 \n\t"
+ "addi s3, s1, 32*2 \n\t"
+ "addi s4, s1, 32*3 \n\t"
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vxor.vv v16, v16, v16 \n\t"
+                // load a scale
+ "flw f1, (a1) \n\t"
+ "flw f2, 4(a1) \n\t"
+ "flw f3, 8(a1) \n\t"
+ "flw f4, 12(a1) \n\t"
+ "addi a1, a1, 16 \n\t"
+ "addi t2, %[INNER], 0 \n\t"
+
+ SQ4BIT_KERNEL_LOAD_ZP_16X1_v2
+
+ "BLOCK_INNER_LOOP%=: \n\t"
+
+ LOAD_B_16x8x2
+
+ "vle8.v v10, (a1) \n\t"
+ "addi a1, a1, 32 \n\t"
+ "vle8.v v11, (a1) \n\t"
+ "addi a1, a1, 32 \n\t"
+ "vsub.vv v2, v2, v12 \n\t"
+ "vsub.vv v6, v6, v12 \n\t"
+ "vsub.vv v3, v3, v13 \n\t"
+ "vsub.vv v7, v7, v13 \n\t"
+ "vsub.vv v4, v4, v14 \n\t"
+ "vsub.vv v8, v8, v14 \n\t"
+ "vsub.vv v5, v5, v15 \n\t"
+ "vsub.vv v9, v9, v15 \n\t"
+
+ SQ4BIT_KERNEL_COMP_4x16x16
+
+ "addi t2, t2, -1 \n\t"
+ "bnez t2, BLOCK_INNER_LOOP%= \n\t"
+
+ LOAD_SCALE_4x16
+
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vfcvt.f.x.v v16, v16 \n\t"
+ "vfmacc.vv v24, v16, v8 \n\t"
+ "addi t3, t3, -1 \n\t"
+ "bnez t3, BLOCK_COUNTK_LOOP%= \n\t"
+
+ "RESULT_SAVE%=: \n\t"
+
+ SAVE_RESULT_4x16
+
+ :
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
+ [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr)
+ : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3",
+ "s4", "s5", "s6");
+ }
+ }
+ } else {
+ for (size_t n = 0; n < CountN; n += 16) {
+ size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n;
+ std::byte * QuantBDataPtr = (std::byte *) QuantBData + //
+ n * BlockCountK * BlkLen / 2 + // b data
+ n * BlockCountK * sizeof(float); // scale
+ float * CPtr = C + n;
+ if (NBLKS < 16) {
+ CPtr = tmp;
+ LDC = 16 * sizeof(float);
+ }
+ if (Bias != nullptr) {
+ const float * bias = Bias + n;
+ if (NBLKS < 16) {
+ __asm__ volatile(
+ "vsetvli t0, %[N], e32, m2 \n\t"
+ "vle32.v v0, (%[SRC]) \n\t"
+ "vse32.v v0, (%[DST]) \n\t"
+ :
+ : [SRC] "r"(bias), [DST] "r"(tmp), [N] "r"(NBLKS)
+ : "cc", "t0");
+ bias = tmp;
+ }
+ __asm__ volatile(LOAD_BIAS
+ "addi t3, %[BlockCountK], 0 \n\t"
+ "addi a1, %[A], 0 \n\t"
+ "addi s1, %[B], 0 \n\t"
+ "BLOCK_COUNTK_LOOP%=: \n\t"
+ "addi s5, s1, 0 \n\t"
+ "addi s1, s5, 64 \n\t"
+ "addi s2, s1, 32 \n\t"
+ "addi s3, s1, 32*2 \n\t"
+ "addi s4, s1, 32*3 \n\t"
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vxor.vv v16, v16, v16 \n\t"
+ // load a scale
+ "flw f1, (a1) \n\t"
+ "flw f2, 4(a1) \n\t"
+ "flw f3, 8(a1) \n\t"
+ "flw f4, 12(a1) \n\t"
+ "addi a1, a1, 16 \n\t"
+ "addi t2, %[INNER], 0 \n\t"
+ "BLOCK_INNER_LOOP%=: \n\t"
+
+ LOAD_B_16x8x2
+
+ "vsetvli t0, zero, e8, m1 \n\t"
+ "vle8.v v10, (a1) \n\t"
+ "addi a1, a1, 32 \n\t"
+ "vle8.v v11, (a1) \n\t"
+ "addi a1, a1, 32 \n\t"
+ "vadd.vi v2, v2, -8 \n\t"
+ "vadd.vi v3, v3, -8 \n\t"
+ "vadd.vi v4, v4, -8 \n\t"
+ "vadd.vi v5, v5, -8 \n\t"
+ "vadd.vi v6, v6, -8 \n\t"
+ "vadd.vi v7, v7, -8 \n\t"
+ "vadd.vi v8, v8, -8 \n\t"
+ "vadd.vi v9, v9, -8 \n\t"
+
+ SQ4BIT_KERNEL_COMP_4x16x16
+
+ "addi t2, t2, -1 \n\t"
+ "bnez t2, BLOCK_INNER_LOOP%= \n\t"
+
+ LOAD_SCALE_4x16
+
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vfcvt.f.x.v v16, v16 \n\t"
+ "vfmacc.vv v24, v16, v8 \n\t"
+ "addi t3, t3, -1 \n\t"
+ "bnez t3, BLOCK_COUNTK_LOOP%= \n\t"
+
+ "RESULT_SAVE%=: \n\t"
+
+ SAVE_RESULT_4x16
+
+ :
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
+ [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr), [BIAS] "r"(bias)
+ : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1",
+ "s2", "s3", "s4", "s5", "s6");
+
+ } else {
+ __asm__ volatile(
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vxor.vv v24, v24, v24 \n\t"
+ "addi t3, %[BlockCountK], 0 \n\t"
+ "addi a1, %[A], 0 \n\t"
+ "addi s1, %[B], 0 \n\t"
+ "BLOCK_COUNTK_LOOP%=: \n\t"
+ "addi s5, s1, 0 \n\t"
+ "addi s1, s5, 64 \n\t"
+ "addi s2, s1, 32 \n\t"
+ "addi s3, s1, 32*2 \n\t"
+ "addi s4, s1, 32*3 \n\t"
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vxor.vv v16, v16, v16 \n\t"
+ // load a scale
+ "flw f1, (a1) \n\t"
+ "flw f2, 4(a1) \n\t"
+ "flw f3, 8(a1) \n\t"
+ "flw f4, 12(a1) \n\t"
+ "addi a1, a1, 16 \n\t"
+ "addi t2, %[INNER], 0 \n\t"
+ "BLOCK_INNER_LOOP%=: \n\t"
+
+ LOAD_B_16x8x2
+
+ "vsetvli t0, zero, e8, m1 \n\t"
+ "vle8.v v10, (a1) \n\t"
+
+ "addi a1, a1, 32 \n\t"
+ "vle8.v v11, (a1) \n\t"
+ "addi a1, a1, 32 \n\t"
+ "vadd.vi v2, v2, -8 \n\t"
+ "vadd.vi v3, v3, -8 \n\t"
+ "vadd.vi v4, v4, -8 \n\t"
+ "vadd.vi v5, v5, -8 \n\t"
+ "vadd.vi v6, v6, -8 \n\t"
+ "vadd.vi v7, v7, -8 \n\t"
+ "vadd.vi v8, v8, -8 \n\t"
+ "vadd.vi v9, v9, -8 \n\t"
+
+ SQ4BIT_KERNEL_COMP_4x16x16
+
+ "addi t2, t2, -1 \n\t"
+ "bnez t2, BLOCK_INNER_LOOP%= \n\t"
+
+ LOAD_SCALE_4x16
+
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vfcvt.f.x.v v16, v16 \n\t"
+ "vfmacc.vv v24, v16, v8 \n\t"
+ "addi t3, t3, -1 \n\t"
+ "bnez t3, BLOCK_COUNTK_LOOP%= \n\t"
+
+ "RESULT_SAVE%=: \n\t"
+
+ SAVE_RESULT_4x16
+
+ :
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
+ [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr)
+ : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3",
+ "s4", "s5", "s6");
+ }
+ }
+ }
+ if (CountN % 16 != 0) {
+        // store the output from tmp to C when NBLKS is less than 16.
+ float * CPtr = C + CountN / 16 * 16;
+ const size_t N = CountN % 16;
+ LDC = ldc * sizeof(float);
+ __asm__ volatile(
+ "vsetvli t0, %[N], e32, m2 \n\t"
+ "vle32.v v0, (%[SRC]) \n\t"
+ "addi s2, %[SRC], 64 \n\t"
+ "addi s3, %[SRC], 64*2 \n\t"
+ "addi s4, %[SRC], 64*3 \n\t"
+ "vle32.v v2, (s2) \n\t"
+ "vle32.v v4, (s3) \n\t"
+ "vle32.v v6, (s4) \n\t"
+ "add t2, %[DST], %[LDC] \n\t"
+ "add t3, t2, %[LDC] \n\t"
+ "add t4, t3, %[LDC] \n\t"
+ "vse32.v v0, (%[DST]) \n\t"
+ "vse32.v v2, (t2) \n\t"
+ "vse32.v v4, (t3) \n\t"
+ "vse32.v v6, (t4) \n\t"
+ :
+ : [N] "r"(N), [SRC] "r"(tmp), [DST] "r"(CPtr), [LDC] "r"(LDC)
+ : "cc", "t0", "t2", "t3", "t4", "s2", "s3", "s4");
+ }
+}
+
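+// M=1 (single row of A) kernel with fp16 per-block B scales; results go
+// straight to C, and ST_TAIL issues length-limited stores when fewer than
+// 16 columns remain.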
+template <bool HasZeroPoint>
+void SQ4BitGemmM1Kernel_CompInt8_ScaleFp16_Impl(size_t BlkLen,
+ const std::byte * QuantA,
+ const std::byte * QuantBData,
+ const float * QuantBScale,
+ const std::byte * QuantBZeroPoint,
+ float * C,
+ size_t CountN,
+ size_t BlockCountK,
+ const float * Bias) {
+ GGML_UNUSED(QuantBScale);
+ GGML_UNUSED(QuantBZeroPoint);
+ size_t INNER = BlkLen / 16;
+
+ if constexpr (HasZeroPoint) {
+ for (size_t n = 0; n < CountN; n += 16) {
+ size_t nblks = (CountN - n) > 16 ? 16 : CountN - n;
+ std::byte * QuantBDataPtr = (std::byte *) QuantBData + //
+ n * BlockCountK * BlkLen / 2 + // b data
+ n * BlockCountK * sizeof(uint8_t) + // zp
+ n * BlockCountK * sizeof(_Float16); // scale
+ float * CPtr = C + n;
+ size_t cnt = BlockCountK;
+ if (Bias != nullptr) {
+ const float * bias = Bias + n;
+ __asm__ volatile(
+ "addi t3, %[NBLKS], 0 \n\t"
+ "vsetvli t0, zero, e8, m1 \n\t"
+
+ "vmv.v.i v13, 3 \n\t"
+ "li s1, 24 \n\t"
+ "vsetvli t0, s1, e8, m1 \n\t"
+ "vmv.v.i v13, 2 \n\t"
+ "vsetvli t0, zero, e8, mf2 \n\t"
+ "vmv.v.i v13, 1 \n\t"
+ "vsetvli t0, zero, e8, mf4 \n\t"
+ "vmv.v.i v13, 0 \n\t"
+ "addi s1, %[B], 0 \n\t"
+ "addi s2, %[B], 8 \n\t"
+ "addi s3, %[B], 16 \n\t"
+ "addi s4, %[B], 24 \n\t"
+ // zp offset
+ "addi s7, %[B], 32 \n\t"
+ // a offset
+ "addi s5, %[A], 0 \n\t"
+ "addi s6, %[A], 12 \n\t"
+
+ "vsetvli t0, t3, e32, mf2 \n\t"
+ "vle32.v v28, (%[BIAS]) \n\t"
+ "sub t3, t3, t0 \n\t"
+ "addi %[BIAS], %[BIAS], 16 \n\t"
+ "vsetvli t0, t3, e32, mf2 \n\t"
+ "vle32.v v29, (%[BIAS]) \n\t"
+ "sub t3, t3, t0 \n\t"
+ "addi %[BIAS], %[BIAS], 16 \n\t"
+ "vsetvli t0, t3, e32, mf2 \n\t"
+ "vle32.v v30, (%[BIAS]) \n\t"
+ "sub t3, t3, t0 \n\t"
+ "addi %[BIAS], %[BIAS], 16 \n\t"
+ "vsetvli t0, t3, e32, mf2 \n\t"
+ "vle32.v v31, (%[BIAS]) \n\t"
+
+ "LOOP_K%=: \n\t"
+ "vsetvli t0, zero, e16, mf4 \n\t"
+
+ "vle16.v v4, (s1) \n\t"
+ "addi s1, s1, 48 \n\t"
+ "vle16.v v5, (s2) \n\t"
+ "addi s2, s2, 72 \n\t"
+ "vle16.v v6, (s3) \n\t"
+ "addi s3, s3, 96 \n\t"
+ "vle16.v v7, (s4) \n\t"
+ "addi s4, s4, 120 \n\t"
+ "flw f1, (s5) \n\t"
+ "addi s5, s5, 4 \n\t"
+ "vfwcvt.f.f.v v8, v4 \n\t"
+ "vfwcvt.f.f.v v9, v5 \n\t"
+ "vfwcvt.f.f.v v10, v6 \n\t"
+ "vfwcvt.f.f.v v11, v7 \n\t"
+
+ "vsetvli t0, zero, e32, mf2 \n\t"
+ "addi t5, %[INNER], 0 \n\t"
+ "vxor.vv v16, v16, v16 \n\t"
+ "vxor.vv v18, v18, v18 \n\t"
+ "vxor.vv v20, v20, v20 \n\t"
+ "vxor.vv v22, v22, v22 \n\t"
+ "vfmul.vf v24, v8, f1 \n\t"
+ "vfmul.vf v25, v9, f1 \n\t"
+ "vfmul.vf v26, v10, f1 \n\t"
+ "vfmul.vf v27, v11, f1 \n\t"
+ "addi %[CNT], %[CNT], -1 \n\t"
+
+ SQ4BIT_KERNEL_LOAD_ZP_16X1
+
+ "LOOP_INNER%=: \n\t"
+
+ SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4
+
+ "vsub.vv v0, v0, v8 \n\t"
+ "vsub.vv v4, v4, v8 \n\t"
+ "vsub.vv v1, v1, v9 \n\t"
+ "vsub.vv v5, v5, v9 \n\t"
+ "vsub.vv v2, v2, v10 \n\t"
+ "vsub.vv v6, v6, v10 \n\t"
+ "vsub.vv v3, v3, v11 \n\t"
+ "vsub.vv v7, v7, v11 \n\t"
+
+ SQ4BIT_KERNEL_COMP_1x8x2_4X8X4
+
+ "bnez t5, LOOP_INNER%= \n\t"
+ "vsetvli t0, zero, e32, mf2 \n\t"
+
+ SQ4BIT_KERNEL_ACC_F16_1X4X4
+ "addi s7, s1, 32 \n\t"
+
+ "bnez %[CNT], LOOP_K%= \n\t"
+ "addi t3, zero, 16 \n\t"
+ "addi s1, %[C], 16 \n\t"
+ "addi s2, %[C], 32 \n\t"
+ "addi s3, %[C], 48 \n\t"
+ "blt %[NBLKS], t3, ST_TAIL%= \n\t"
+ "vse32.v v28, (%[C]) \n\t"
+ "vse32.v v29, (s1) \n\t"
+ "vse32.v v30, (s2) \n\t"
+ "vse32.v v31, (s3) \n\t"
+ "jal x0, END%= \n\t"
+
+ "ST_TAIL%=: \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v28, (%[C]) \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v29, (s1) \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v30, (s2) \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v31, (s3) \n\t"
+ "END%=: \n\t"
+
+ : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks), [BIAS] "+r"(bias)
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
+ : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7");
+ } else {
+ __asm__ volatile(
+ "vsetvli t0, zero, e32, m4 \n\t"
+ "vxor.vv v28, v28, v28 \n\t"
+
+ "vsetvli t0, zero, e8, m1 \n\t"
+ "vmv.v.i v13, 3 \n\t"
+ "li s1, 24 \n\t"
+ "vsetvli t0, s1, e8, m1 \n\t"
+ "vmv.v.i v13, 2 \n\t"
+ "vsetvli t0, zero, e8, mf2 \n\t"
+ "vmv.v.i v13, 1 \n\t"
+ "vsetvli t0, zero, e8, mf4 \n\t"
+ "vmv.v.i v13, 0 \n\t"
+
+ "addi s1, %[B], 0 \n\t"
+ "addi s2, %[B], 8 \n\t"
+ "addi s3, %[B], 16 \n\t"
+ "addi s4, %[B], 24 \n\t"
+
+ "addi s7, %[B], 32 \n\t"
+
+ "addi s5, %[A], 0 \n\t"
+ "addi s6, %[A], 12 \n\t"
+ "LOOP_K%=: \n\t"
+ "vsetvli t0, zero, e16, mf4 \n\t"
+ "vle16.v v4, (s1) \n\t"
+ "addi s1, s1, 48 \n\t"
+ "vle16.v v5, (s2) \n\t"
+ "addi s2, s2, 72 \n\t"
+ "vle16.v v6, (s3) \n\t"
+ "addi s3, s3, 96 \n\t"
+ "vle16.v v7, (s4) \n\t"
+ "addi s4, s4, 120 \n\t"
+ "flw f1, (s5) \n\t"
+ "addi s5, s5, 4 \n\t"
+
+ "vfwcvt.f.f.v v8, v4 \n\t"
+ "vfwcvt.f.f.v v9, v5 \n\t"
+ "vfwcvt.f.f.v v10, v6 \n\t"
+ "vfwcvt.f.f.v v11, v7 \n\t"
+ "vsetvli t0, zero, e32, mf2 \n\t"
+
+ "addi t5, %[INNER], 0 \n\t"
+ "vxor.vv v16, v16, v16 \n\t"
+ "vxor.vv v18, v18, v18 \n\t"
+ "vxor.vv v20, v20, v20 \n\t"
+ "vxor.vv v22, v22, v22 \n\t"
+ "vfmul.vf v24, v8, f1 \n\t"
+ "vfmul.vf v25, v9, f1 \n\t"
+ "vfmul.vf v26, v10, f1 \n\t"
+ "vfmul.vf v27, v11, f1 \n\t"
+ "addi %[CNT], %[CNT], -1 \n\t"
+
+ SQ4BIT_KERNEL_LOAD_ZP_16X1
+
+ "LOOP_INNER%=: \n\t"
+
+ SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4
+
+ "vsub.vv v0, v0, v8 \n\t"
+ "vsub.vv v4, v4, v8 \n\t"
+ "vsub.vv v1, v1, v9 \n\t"
+ "vsub.vv v5, v5, v9 \n\t"
+ "vsub.vv v2, v2, v10 \n\t"
+ "vsub.vv v6, v6, v10 \n\t"
+ "vsub.vv v3, v3, v11 \n\t"
+ "vsub.vv v7, v7, v11 \n\t"
+
+ SQ4BIT_KERNEL_COMP_1x8x2_4X8X4
+
+ "bnez t5, LOOP_INNER%= \n\t"
+ "vsetvli t0, zero, e32, mf2 \n\t"
+
+ SQ4BIT_KERNEL_ACC_F16_1X4X4
+ "addi s7, s1, 32 \n\t"
+
+ "bnez %[CNT], LOOP_K%= \n\t"
+ "addi t3, zero, 16 \n\t"
+ "addi s1, %[C], 16 \n\t"
+ "addi s2, %[C], 32 \n\t"
+ "addi s3, %[C], 48 \n\t"
+ "blt %[NBLKS], t3, ST_TAIL%= \n\t"
+ "vse32.v v28, (%[C]) \n\t"
+ "vse32.v v29, (s1) \n\t"
+ "vse32.v v30, (s2) \n\t"
+ "vse32.v v31, (s3) \n\t"
+ "jal x0, END%= \n\t"
+
+ "ST_TAIL%=: \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v28, (%[C]) \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v29, (s1) \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v30, (s2) \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v31, (s3) \n\t"
+ "END%=: \n\t"
+
+ : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks)
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
+ : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7");
+ }
+ }
+ } else {
+ for (size_t n = 0; n < CountN; n += 16) {
+ size_t nblks = (CountN - n) > 16 ? 16 : CountN - n;
+ std::byte * QuantBDataPtr = (std::byte *) QuantBData + //
+ n * BlockCountK * BlkLen / 2 + // b data
+ n * BlockCountK * sizeof(_Float16); // scale
+ float * CPtr = C + n;
+ size_t cnt = BlockCountK;
+ if (Bias != nullptr) {
+ const float * bias = Bias + n;
+ __asm__ volatile(
+ "addi t3, %[NBLKS], 0 \n\t"
+ "addi s1, %[B], 0 \n\t"
+ "addi s2, %[B], 8 \n\t"
+ "addi s3, %[B], 16 \n\t"
+ "addi s4, %[B], 24 \n\t"
+ "addi s5, %[A], 0 \n\t"
+ "addi s6, %[A], 12 \n\t"
+ "vsetvli t0, t3, e32, mf2 \n\t"
+ "vle32.v v28, (%[BIAS]) \n\t"
+ "sub t3, t3, t0 \n\t"
+ "addi %[BIAS], %[BIAS], 16 \n\t"
+ "vsetvli t0, t3, e32, mf2 \n\t"
+ "vle32.v v29, (%[BIAS]) \n\t"
+ "sub t3, t3, t0 \n\t"
+ "addi %[BIAS], %[BIAS], 16 \n\t"
+ "vsetvli t0, t3, e32, mf2 \n\t"
+ "vle32.v v30, (%[BIAS]) \n\t"
+ "sub t3, t3, t0 \n\t"
+ "addi %[BIAS], %[BIAS], 16 \n\t"
+ "vsetvli t0, t3, e32, mf2 \n\t"
+ "vle32.v v31, (%[BIAS]) \n\t"
+
+ "LOOP_K%=: \n\t"
+ "vsetvli t0, zero, e16, mf4 \n\t"
+
+ "vle16.v v4, (s1) \n\t"
+ "addi s1, s1, 32 \n\t"
+ "vle16.v v5, (s2) \n\t"
+ "addi s2, s2, 56 \n\t"
+ "vle16.v v6, (s3) \n\t"
+ "addi s3, s3, 80 \n\t"
+ "vle16.v v7, (s4) \n\t"
+ "addi s4, s4, 104 \n\t"
+ "flw f1, (s5) \n\t"
+ "addi s5, s5, 4 \n\t"
+ "vfwcvt.f.f.v v8, v4 \n\t"
+ "vfwcvt.f.f.v v9, v5 \n\t"
+ "vfwcvt.f.f.v v10, v6 \n\t"
+ "vfwcvt.f.f.v v11, v7 \n\t"
+
+ "vsetvli t0, zero, e32, mf2 \n\t"
+ "addi t5, %[INNER], 0 \n\t"
+ "vxor.vv v16, v16, v16 \n\t"
+ "vxor.vv v18, v18, v18 \n\t"
+ "vxor.vv v20, v20, v20 \n\t"
+ "vxor.vv v22, v22, v22 \n\t"
+ "vfmul.vf v24, v8, f1 \n\t"
+ "vfmul.vf v25, v9, f1 \n\t"
+ "vfmul.vf v26, v10, f1 \n\t"
+ "vfmul.vf v27, v11, f1 \n\t"
+ "addi %[CNT], %[CNT], -1 \n\t"
+ "vsetvli t0, zero, e8, m1 \n\t"
+ "LOOP_INNER%=: \n\t"
+
+ SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4
+
+ "vadd.vi v0, v0, -8 \n\t"
+ "vadd.vi v1, v1, -8 \n\t"
+ "vadd.vi v2, v2, -8 \n\t"
+ "vadd.vi v3, v3, -8 \n\t"
+ "vadd.vi v4, v4, -8 \n\t"
+ "vadd.vi v5, v5, -8 \n\t"
+ "vadd.vi v6, v6, -8 \n\t"
+ "vadd.vi v7, v7, -8 \n\t"
+
+ SQ4BIT_KERNEL_COMP_1x8x2_4X8X4
+
+ "bnez t5, LOOP_INNER%= \n\t"
+ "vsetvli t0, zero, e32, mf2 \n\t"
+
+ SQ4BIT_KERNEL_ACC_F16_1X4X4
+
+ "bnez %[CNT], LOOP_K%= \n\t"
+ "addi t3, zero, 16 \n\t"
+ "addi s1, %[C], 16 \n\t"
+ "addi s2, %[C], 32 \n\t"
+ "addi s3, %[C], 48 \n\t"
+ "blt %[NBLKS], t3, ST_TAIL%= \n\t"
+ "vse32.v v28, (%[C]) \n\t"
+ "vse32.v v29, (s1) \n\t"
+ "vse32.v v30, (s2) \n\t"
+ "vse32.v v31, (s3) \n\t"
+ "jal x0, END%= \n\t"
+
+ "ST_TAIL%=: \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v28, (%[C]) \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v29, (s1) \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v30, (s2) \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v31, (s3) \n\t"
+ "END%=: \n\t"
+
+ : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks), [BIAS] "+r"(bias)
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
+ : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6");
+ } else {
+ __asm__ volatile(
+ "vsetvli t0, zero, e32, m4 \n\t"
+ "vxor.vv v28, v28, v28 \n\t"
+ "addi s1, %[B], 0 \n\t"
+ "addi s2, %[B], 8 \n\t"
+ "addi s3, %[B], 16 \n\t"
+ "addi s4, %[B], 24 \n\t"
+
+ "addi s5, %[A], 0 \n\t"
+ "addi s6, %[A], 12 \n\t"
+ "LOOP_K%=: \n\t"
+ "vsetvli t0, zero, e16, mf4 \n\t"
+ "vle16.v v4, (s1) \n\t"
+ "addi s1, s1, 32 \n\t"
+ "vle16.v v5, (s2) \n\t"
+ "addi s2, s2, 56 \n\t"
+ "vle16.v v6, (s3) \n\t"
+ "addi s3, s3, 80 \n\t"
+ "vle16.v v7, (s4) \n\t"
+ "addi s4, s4, 104 \n\t"
+ "flw f1, (s5) \n\t"
+ "addi s5, s5, 4 \n\t"
+
+ "vfwcvt.f.f.v v8, v4 \n\t"
+ "vfwcvt.f.f.v v9, v5 \n\t"
+ "vfwcvt.f.f.v v10, v6 \n\t"
+ "vfwcvt.f.f.v v11, v7 \n\t"
+ "vsetvli t0, zero, e32, mf2 \n\t"
+
+ "addi t5, %[INNER], 0 \n\t"
+ "vxor.vv v16, v16, v16 \n\t"
+ "vxor.vv v18, v18, v18 \n\t"
+ "vxor.vv v20, v20, v20 \n\t"
+ "vxor.vv v22, v22, v22 \n\t"
+ "vfmul.vf v24, v8, f1 \n\t"
+ "vfmul.vf v25, v9, f1 \n\t"
+ "vfmul.vf v26, v10, f1 \n\t"
+ "vfmul.vf v27, v11, f1 \n\t"
+ "addi %[CNT], %[CNT], -1 \n\t"
+ "vsetvli t0, zero, e8, m1 \n\t"
+ "LOOP_INNER%=: \n\t"
+
+ SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4
+
+ "vadd.vi v0, v0, -8 \n\t"
+ "vadd.vi v1, v1, -8 \n\t"
+ "vadd.vi v2, v2, -8 \n\t"
+ "vadd.vi v3, v3, -8 \n\t"
+ "vadd.vi v4, v4, -8 \n\t"
+ "vadd.vi v5, v5, -8 \n\t"
+ "vadd.vi v6, v6, -8 \n\t"
+ "vadd.vi v7, v7, -8 \n\t"
+
+ SQ4BIT_KERNEL_COMP_1x8x2_4X8X4
+
+ "bnez t5, LOOP_INNER%= \n\t"
+ "vsetvli t0, zero, e32, mf2 \n\t"
+
+ SQ4BIT_KERNEL_ACC_F16_1X4X4
+
+ "bnez %[CNT], LOOP_K%= \n\t"
+ "addi t3, zero, 16 \n\t"
+ "addi s1, %[C], 16 \n\t"
+ "addi s2, %[C], 32 \n\t"
+ "addi s3, %[C], 48 \n\t"
+ "blt %[NBLKS], t3, ST_TAIL%= \n\t"
+ "vse32.v v28, (%[C]) \n\t"
+ "vse32.v v29, (s1) \n\t"
+ "vse32.v v30, (s2) \n\t"
+ "vse32.v v31, (s3) \n\t"
+ "jal x0, END%= \n\t"
+
+ "ST_TAIL%=: \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v28, (%[C]) \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v29, (s1) \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v30, (s2) \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v31, (s3) \n\t"
+ "END%=: \n\t"
+
+ : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks)
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
+ : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6");
+ }
+ }
+ }
+}
+
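+// M=1 kernel with the same structure as the ScaleFp16 variant above, but with
+// fp32 per-block B scales.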
+template <bool HasZeroPoint>
+void SQ4BitGemmM1Kernel_CompInt8_Impl(size_t BlkLen,
+ const std::byte * QuantA,
+ const std::byte * QuantBData,
+ const float * QuantBScale,
+ const std::byte * QuantBZeroPoint,
+ float * C,
+ size_t CountN,
+ size_t BlockCountK,
+ const float * Bias) {
+ GGML_UNUSED(QuantBScale);
+ GGML_UNUSED(QuantBZeroPoint);
+ const size_t INNER = BlkLen / 16;
+ if constexpr (HasZeroPoint) {
+ for (size_t n = 0; n < CountN; n += 16) {
+ size_t nblks = (CountN - n) > 16 ? 16 : CountN - n;
+ std::byte * QuantBDataPtr = (std::byte *) QuantBData + //
+ n * BlockCountK * BlkLen / 2 + // b data
+ n * BlockCountK * sizeof(uint8_t) + // zp
+ n * BlockCountK * sizeof(float); // scale
+ float * CPtr = C + n;
+ size_t cnt = BlockCountK;
+ if (Bias != nullptr) {
+ const float * bias = Bias + n;
+ __asm__ volatile(
+ "addi t3, %[NBLKS], 0 \n\t"
+ "vsetvli t0, zero, e8, m1 \n\t"
+ "vmv.v.i v13, 3 \n\t"
+ "li s1, 24 \n\t"
+ "vsetvli t0, s1, e8, m1 \n\t"
+ "vmv.v.i v13, 2 \n\t"
+ "vsetvli t0, zero, e8, mf2 \n\t"
+ "vmv.v.i v13, 1 \n\t"
+ "vsetvli t0, zero, e8, mf4 \n\t"
+ "vmv.v.i v13, 0 \n\t"
+ "vsetvli t0, zero, e32, m4 \n\t"
+ "vxor.vv v28, v28, v28 \n\t"
+
+                    // scale offset: scale0.0, scale1.0, ..., scale15.0
+ "addi s1, %[B], 0 \n\t"
+ "addi s2, %[B], 16 \n\t"
+ "addi s3, %[B], 32 \n\t"
+ "addi s4, %[B], 48 \n\t"
+ // zp offset
+ "addi s7, %[B], 64 \n\t"
+ // a offset
+ "addi s5, %[A], 0 \n\t"
+ "addi s6, %[A], 12 \n\t"
+
+ "vsetvli t0, t3, e32, mf2 \n\t"
+ "vle32.v v28, (%[BIAS]) \n\t"
+ "sub t3, t3, t0 \n\t"
+ "addi %[BIAS], %[BIAS], 16 \n\t"
+ "vsetvli t0, t3, e32, mf2 \n\t"
+ "vle32.v v29, (%[BIAS]) \n\t"
+ "sub t3, t3, t0 \n\t"
+ "addi %[BIAS], %[BIAS], 16 \n\t"
+ "vsetvli t0, t3, e32, mf2 \n\t"
+ "vle32.v v30, (%[BIAS]) \n\t"
+ "sub t3, t3, t0 \n\t"
+ "addi %[BIAS], %[BIAS], 16 \n\t"
+ "vsetvli t0, t3, e32, mf2 \n\t"
+ "vle32.v v31, (%[BIAS]) \n\t"
+ "vsetvli t0, zero, e32, mf2 \n\t"
+ "LOOP_K%=: \n\t"
+
+ // load scale
+ "vle32.v v8, (s1) \n\t"
+ "addi s1, s1, 80 \n\t"
+ "vle32.v v9, (s2) \n\t"
+ "addi s2, s2, 96 \n\t"
+ "vle32.v v10, (s3) \n\t"
+ "addi s3, s3, 112 \n\t"
+ "vle32.v v11, (s4) \n\t"
+ "addi s4, s4, 128 \n\t"
+
+ // load a scale
+ "flw f1, (s5) \n\t"
+ "addi s5, s5, 4 \n\t"
+
+ "addi t5, %[INNER], 0 \n\t"
+ "vxor.vv v16, v16, v16 \n\t"
+ "vxor.vv v18, v18, v18 \n\t"
+ "vxor.vv v20, v20, v20 \n\t"
+ "vxor.vv v22, v22, v22 \n\t"
+
+ // a scale * b scale
+ "vfmul.vf v24, v8, f1 \n\t"
+ "vfmul.vf v25, v9, f1 \n\t"
+ "vfmul.vf v26, v10, f1 \n\t"
+ "vfmul.vf v27, v11, f1 \n\t"
+ "addi %[CNT], %[CNT], -1 \n\t"
+
+ SQ4BIT_KERNEL_LOAD_ZP_16X1
+
+ "LOOP_INNER%=: \n\t"
+
+ SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4
+
+ "vsub.vv v0, v0, v8 \n\t"
+ "vsub.vv v4, v4, v8 \n\t"
+ "vsub.vv v1, v1, v9 \n\t"
+ "vsub.vv v5, v5, v9 \n\t"
+ "vsub.vv v2, v2, v10 \n\t"
+ "vsub.vv v6, v6, v10 \n\t"
+ "vsub.vv v3, v3, v11 \n\t"
+ "vsub.vv v7, v7, v11 \n\t"
+
+ SQ4BIT_KERNEL_COMP_1x8x2_4X8X4
+
+ "bnez t5, LOOP_INNER%= \n\t"
+ "vsetvli t0, zero, e32, mf2 \n\t"
+
+ SQ4BIT_KERNEL_ACC_1X4X4
+ "addi s7, s1, 64 \n\t"
+
+ "bnez %[CNT], LOOP_K%= \n\t"
+
+ "addi t3, zero, 16 \n\t"
+ "addi s1, %[C], 16 \n\t"
+ "addi s2, %[C], 32 \n\t"
+ "addi s3, %[C], 48 \n\t"
+ "blt %[NBLKS], t3, ST_TAIL%= \n\t"
+ "vse32.v v28, (%[C]) \n\t"
+ "vse32.v v29, (s1) \n\t"
+ "vse32.v v30, (s2) \n\t"
+ "vse32.v v31, (s3) \n\t"
+ "jal x0, END%= \n\t"
+
+ "ST_TAIL%=: \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v28, (%[C]) \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v29, (s1) \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v30, (s2) \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v31, (s3) \n\t"
+ "END%=: \n\t"
+
+ : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks), [BIAS] "+r"(bias)
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
+ : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7");
+ } else {
+ __asm__ volatile(
+ "vsetvli t0, zero, e32, m4 \n\t"
+ "vxor.vv v28, v28, v28 \n\t"
+
+ "vsetvli t0, zero, e8, m1 \n\t"
+ "vmv.v.i v13, 3 \n\t"
+ "li s1, 24 \n\t"
+ "vsetvli t0, s1, e8, m1 \n\t"
+ "vmv.v.i v13, 2 \n\t"
+ "vsetvli t0, zero, e8, mf2 \n\t"
+ "vmv.v.i v13, 1 \n\t"
+ "vsetvli t0, zero, e8, mf4 \n\t"
+ "vmv.v.i v13, 0 \n\t"
+ "addi s1, %[B], 0 \n\t"
+ "addi s2, %[B], 16 \n\t"
+ "addi s3, %[B], 32 \n\t"
+ "addi s4, %[B], 48 \n\t"
+
+ "addi s7, %[B], 64 \n\t"
+
+ "addi s5, %[A], 0 \n\t"
+ "addi s6, %[A], 12 \n\t"
+ "vsetvli t0, zero, e32, mf2 \n\t"
+
+ "LOOP_K%=: \n\t"
+ "vle32.v v8, (s1) \n\t"
+ "addi s1, s1, 80 \n\t"
+ "vle32.v v9, (s2) \n\t"
+ "addi s2, s2, 96 \n\t"
+ "vle32.v v10, (s3) \n\t"
+ "addi s3, s3, 112 \n\t"
+ "vle32.v v11, (s4) \n\t"
+ "addi s4, s4, 128 \n\t"
+
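+ // load a scale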
+ "flw f1, (s5) \n\t"
+ "addi s5, s5, 4 \n\t"
+
+ "addi t5, %[INNER], 0 \n\t"
+ "vxor.vv v16, v16, v16 \n\t"
+ "vxor.vv v18, v18, v18 \n\t"
+ "vxor.vv v20, v20, v20 \n\t"
+ "vxor.vv v22, v22, v22 \n\t"
+
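+ // a scale * b scale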
+ "vfmul.vf v24, v8, f1 \n\t"
+ "vfmul.vf v25, v9, f1 \n\t"
+ "vfmul.vf v26, v10, f1 \n\t"
+ "vfmul.vf v27, v11, f1 \n\t"
+ "addi %[CNT], %[CNT], -1 \n\t"
+
+ SQ4BIT_KERNEL_LOAD_ZP_16X1
+
+ "LOOP_INNER%=: \n\t"
+
+ SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4
+
+ "vsub.vv v0, v0, v8 \n\t"
+ "vsub.vv v4, v4, v8 \n\t"
+ "vsub.vv v1, v1, v9 \n\t"
+ "vsub.vv v5, v5, v9 \n\t"
+ "vsub.vv v2, v2, v10 \n\t"
+ "vsub.vv v6, v6, v10 \n\t"
+ "vsub.vv v3, v3, v11 \n\t"
+ "vsub.vv v7, v7, v11 \n\t"
+
+ SQ4BIT_KERNEL_COMP_1x8x2_4X8X4
+
+ "bnez t5, LOOP_INNER%= \n\t"
+ "vsetvli t0, zero, e32, mf2 \n\t"
+
+ SQ4BIT_KERNEL_ACC_1X4X4
+ "addi s7, s1, 64 \n\t"
+
+ "bnez %[CNT], LOOP_K%= \n\t"
+
+ "addi t3, zero, 16 \n\t"
+ "addi s1, %[C], 16 \n\t"
+ "addi s2, %[C], 32 \n\t"
+ "addi s3, %[C], 48 \n\t"
+ "blt %[NBLKS], t3, ST_TAIL%= \n\t"
+ "vse32.v v28, (%[C]) \n\t"
+ "vse32.v v29, (s1) \n\t"
+ "vse32.v v30, (s2) \n\t"
+ "vse32.v v31, (s3) \n\t"
+ "jal x0, END%= \n\t"
+
+ "ST_TAIL%=: \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v28, (%[C]) \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v29, (s1) \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v30, (s2) \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v31, (s3) \n\t"
+ "END%=: \n\t"
+
+ : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks)
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
+ : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7");
+ }
+ }
+ } else {
+ for (size_t n = 0; n < CountN; n += 16) {
+ size_t nblks = (CountN - n) > 16 ? 16 : CountN - n;
+ std::byte * QuantBDataPtr = (std::byte *) QuantBData + //
+ n * BlockCountK * BlkLen / 2 + // b data
+ n * BlockCountK * sizeof(float); // scale
+ float * CPtr = C + n;
+ size_t cnt = BlockCountK;
+ if (Bias != nullptr) {
+ const float * bias = Bias + n;
+ __asm__ volatile(
+ "addi t3, %[NBLKS], 0 \n\t"
+ "addi s1, %[B], 0 \n\t"
+ "addi s2, %[B], 16 \n\t"
+ "addi s3, %[B], 32 \n\t"
+ "addi s4, %[B], 48 \n\t"
+ "addi s5, %[A], 0 \n\t"
+ "addi s6, %[A], 12 \n\t"
+ "vsetvli t0, t3, e32, mf2 \n\t"
+ "vle32.v v28, (%[BIAS]) \n\t"
+ "sub t3, t3, t0 \n\t"
+ "addi %[BIAS], %[BIAS], 16 \n\t"
+ "vsetvli t0, t3, e32, mf2 \n\t"
+ "vle32.v v29, (%[BIAS]) \n\t"
+ "sub t3, t3, t0 \n\t"
+ "addi %[BIAS], %[BIAS], 16 \n\t"
+ "vsetvli t0, t3, e32, mf2 \n\t"
+ "vle32.v v30, (%[BIAS]) \n\t"
+ "sub t3, t3, t0 \n\t"
+ "addi %[BIAS], %[BIAS], 16 \n\t"
+ "vsetvli t0, t3, e32, mf2 \n\t"
+ "vle32.v v31, (%[BIAS]) \n\t"
+ "vsetvli t0, zero, e32, mf2 \n\t"
+ "LOOP_K%=: \n\t"
+ "vle32.v v8, (s1) \n\t"
+ "addi s1, s1, 64 \n\t"
+ "vle32.v v9, (s2) \n\t"
+ "addi s2, s2, 80 \n\t"
+ "vle32.v v10, (s3) \n\t"
+ "addi s3, s3, 96 \n\t"
+ "vle32.v v11, (s4) \n\t"
+ "addi s4, s4, 112 \n\t"
+ "flw f1, (s5) \n\t"
+ "addi s5, s5, 4 \n\t"
+
+ "addi t5, %[INNER], 0 \n\t"
+ "vxor.vv v16, v16, v16 \n\t"
+ "vxor.vv v18, v18, v18 \n\t"
+ "vxor.vv v20, v20, v20 \n\t"
+ "vxor.vv v22, v22, v22 \n\t"
+ "vfmul.vf v24, v8, f1 \n\t"
+ "vfmul.vf v25, v9, f1 \n\t"
+ "vfmul.vf v26, v10, f1 \n\t"
+ "vfmul.vf v27, v11, f1 \n\t"
+ "addi %[CNT], %[CNT], -1 \n\t"
+ "vsetvli t0, zero, e8, m1 \n\t"
+ "LOOP_INNER%=: \n\t"
+
+ SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4
+
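+ // no zero-point: recenter unsigned 4-bit weights from [0,15] to signed [-8,7]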
+ "vadd.vi v0, v0, -8 \n\t"
+ "vadd.vi v1, v1, -8 \n\t"
+ "vadd.vi v2, v2, -8 \n\t"
+ "vadd.vi v3, v3, -8 \n\t"
+ "vadd.vi v4, v4, -8 \n\t"
+ "vadd.vi v5, v5, -8 \n\t"
+ "vadd.vi v6, v6, -8 \n\t"
+ "vadd.vi v7, v7, -8 \n\t"
+
+ SQ4BIT_KERNEL_COMP_1x8x2_4X8X4
+
+ "bnez t5, LOOP_INNER%= \n\t"
+ "vsetvli t0, zero, e32, mf2 \n\t"
+
+ SQ4BIT_KERNEL_ACC_1X4X4
+
+ "bnez %[CNT], LOOP_K%= \n\t"
+ "addi t3, zero, 16 \n\t"
+ "addi s1, %[C], 16 \n\t"
+ "addi s2, %[C], 32 \n\t"
+ "addi s3, %[C], 48 \n\t"
+ "blt %[NBLKS], t3, ST_TAIL%= \n\t"
+ "vse32.v v28, (%[C]) \n\t"
+ "vse32.v v29, (s1) \n\t"
+ "vse32.v v30, (s2) \n\t"
+ "vse32.v v31, (s3) \n\t"
+ "jal x0, END%= \n\t"
+
+ "ST_TAIL%=: \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v28, (%[C]) \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v29, (s1) \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v30, (s2) \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v31, (s3) \n\t"
+ "END%=: \n\t"
+
+ : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks), [BIAS] "+r"(bias)
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
+ : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6");
+ } else {
+ __asm__ volatile(
+ "vsetvli t0, zero, e32, m4 \n\t"
+ "vxor.vv v28, v28, v28 \n\t"
+ "addi s1, %[B], 0 \n\t"
+ "addi s2, %[B], 16 \n\t"
+ "addi s3, %[B], 32 \n\t"
+ "addi s4, %[B], 48 \n\t"
+
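+ // a offset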
+ "addi s5, %[A], 0 \n\t"
+ "addi s6, %[A], 12 \n\t"
+ "vsetvli t0, zero, e32, mf2 \n\t"
+ "LOOP_K%=: \n\t"
+ "vle32.v v8, (s1) \n\t"
+ "addi s1, s1, 64 \n\t"
+ "vle32.v v9, (s2) \n\t"
+ "addi s2, s2, 80 \n\t"
+ "vle32.v v10, (s3) \n\t"
+ "addi s3, s3, 96 \n\t"
+ "vle32.v v11, (s4) \n\t"
+ "addi s4, s4, 112 \n\t"
+ "flw f1, (s5) \n\t"
+ "addi s5, s5, 4 \n\t"
+
+ "addi t5, %[INNER], 0 \n\t"
+ "vxor.vv v16, v16, v16 \n\t"
+ "vxor.vv v18, v18, v18 \n\t"
+ "vxor.vv v20, v20, v20 \n\t"
+ "vxor.vv v22, v22, v22 \n\t"
+ "vfmul.vf v24, v8, f1 \n\t"
+ "vfmul.vf v25, v9, f1 \n\t"
+ "vfmul.vf v26, v10, f1 \n\t"
+ "vfmul.vf v27, v11, f1 \n\t"
+ "addi %[CNT], %[CNT], -1 \n\t"
+ "vsetvli t0, zero, e8, m1 \n\t"
+ "LOOP_INNER%=: \n\t"
+
+ SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4
+
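+ // no zero-point: recenter unsigned 4-bit weights from [0,15] to signed [-8,7]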
+ "vadd.vi v0, v0, -8 \n\t"
+ "vadd.vi v1, v1, -8 \n\t"
+ "vadd.vi v2, v2, -8 \n\t"
+ "vadd.vi v3, v3, -8 \n\t"
+ "vadd.vi v4, v4, -8 \n\t"
+ "vadd.vi v5, v5, -8 \n\t"
+ "vadd.vi v6, v6, -8 \n\t"
+ "vadd.vi v7, v7, -8 \n\t"
+
+ SQ4BIT_KERNEL_COMP_1x8x2_4X8X4
+
+ "bnez t5, LOOP_INNER%= \n\t"
+ "vsetvli t0, zero, e32, mf2 \n\t"
+
+ SQ4BIT_KERNEL_ACC_1X4X4
+
+ "bnez %[CNT], LOOP_K%= \n\t"
+ "addi t3, zero, 16 \n\t"
+ "addi s1, %[C], 16 \n\t"
+ "addi s2, %[C], 32 \n\t"
+ "addi s3, %[C], 48 \n\t"
+ "blt %[NBLKS], t3, ST_TAIL%= \n\t"
+ "vse32.v v28, (%[C]) \n\t"
+ "vse32.v v29, (s1) \n\t"
+ "vse32.v v30, (s2) \n\t"
+ "vse32.v v31, (s3) \n\t"
+ "jal x0, END%= \n\t"
+
+ "ST_TAIL%=: \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v28, (%[C]) \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v29, (s1) \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v30, (s2) \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v31, (s3) \n\t"
+ "END%=: \n\t"
+
+ : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks)
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
+ : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6");
+ }
+ }
+ }
+}
+
+template <bool HasZeroPoint>
+inline void SQ4BitGemmM4Kernel_CompInt8_DispatchOnBlkLen(size_t BlkLen,
+ const std::byte * QuantA,
+ const std::byte * QuantBData,
+ const float * QuantBScale,
+ const std::byte * QuantBZeroPoint,
+ float * C,
+ size_t CountM,
+ size_t CountN,
+ size_t BlockStrideQuantB,
+ const float * Bias,
+ const size_t ldc,
+ const size_t scalestride) {
+ if (scalestride == 4) {
+ SQ4BitGemmM4Kernel_CompInt8_Impl<HasZeroPoint>(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, C,
+ CountN, BlockStrideQuantB, Bias, ldc);
+
+ } else if (scalestride == 2) {
+ SQ4BitGemmM4Kernel_CompInt8_ScaleFp16_Impl<HasZeroPoint>(
+ BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, C, CountN, BlockStrideQuantB, Bias, ldc);
+ }
+}
+
+template <bool HasZeroPoint>
+inline void SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen(size_t BlkLen,
+ const std::byte * QuantA,
+ const std::byte * QuantBData,
+ const float * QuantBScale,
+ const std::byte * QuantBZeroPoint,
+ float * C,
+ size_t CountM,
+ size_t CountN,
+ size_t BlockStrideQuantB,
+ const float * Bias,
+ const size_t ldc,
+ const size_t scalestride) {
+ if (scalestride == 4) {
+ SQ4BitGemmM1Kernel_CompInt8_Impl<HasZeroPoint>(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, C,
+ CountN, BlockStrideQuantB, Bias);
+ } else if (scalestride == 2) {
+ SQ4BitGemmM1Kernel_CompInt8_ScaleFp16_Impl<HasZeroPoint>(BlkLen, QuantA, QuantBData, QuantBScale,
+ QuantBZeroPoint, C, CountN, BlockStrideQuantB, Bias);
+ }
+}
+
+} // namespace
+
+namespace ime1 {
+size_t gemm_kernel_i8i4(size_t BlkLen,
+ const std::byte * QuantA,
+ const std::byte * QuantBData,
+ const float * QuantBScale,
+ const std::byte * QuantBZeroPoint,
+ float * C,
+ size_t CountM,
+ size_t CountN,
+ size_t CountK,
+ size_t BlockCountK,
+ size_t ldc,
+ const float * Bias,
+ const size_t ScaleStride) {
+ GGML_UNUSED(CountK);
+ if (CountM >= 4) {
+ if (QuantBZeroPoint != nullptr) {
+ SQ4BitGemmM4Kernel_CompInt8_DispatchOnBlkLen<true>(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint,
+ C, CountM, CountN, BlockCountK, Bias, ldc, ScaleStride);
+ } else {
+ SQ4BitGemmM4Kernel_CompInt8_DispatchOnBlkLen<false>(BlkLen, QuantA, QuantBData, QuantBScale,
+ QuantBZeroPoint, C, CountM, CountN, BlockCountK, Bias,
+ ldc, ScaleStride);
+ }
+ return 4;
+ } else {
+ if (QuantBZeroPoint != nullptr) {
+ SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen<true>(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint,
+ C, CountM, CountN, BlockCountK, Bias, ldc, ScaleStride);
+ } else {
+ SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen<false>(BlkLen, QuantA, QuantBData, QuantBScale,
+ QuantBZeroPoint, C, CountM, CountN, BlockCountK, Bias,
+ ldc, ScaleStride);
+ }
+ return 1;
+ }
+}
+} // namespace ime1
+} // namespace sqnbitgemm_spacemit_ime
diff --git a/llama.cpp/ggml/src/ggml-cpu/spacemit/ime_kernels.h b/llama.cpp/ggml/src/ggml-cpu/spacemit/ime_kernels.h
new file mode 100644
index 0000000..7570634
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cpu/spacemit/ime_kernels.h
@@ -0,0 +1,26 @@
+#pragma once
+
+#include <cstddef>
+
+namespace sqnbitgemm_spacemit_ime {
+namespace ime1 {
+size_t gemm_kernel_i8i4(size_t blk_len,
+ const std::byte * quant_a_ptr,
+ const std::byte * quant_b_data,
+ const float * quant_b_scale,
+ const std::byte * quant_b_zp,
+ float * c_ptr,
+ size_t count_m,
+ size_t count_n,
+ size_t count_k,
+ size_t block_count_k,
+ size_t ldc,
+ const float * bias,
+ const size_t scale_stride);
+
+void quantize_a_row_i8(size_t blk_len, const float * a_ptr, size_t count_k, std::byte * quant_a_ptr);
+
+void quantize_a_4row_i8(size_t blk_len, const float * a_ptr, size_t count_k, std::byte * quant_a_ptr);
+
+} // namespace ime1
+} // namespace sqnbitgemm_spacemit_ime