#define GGML_COMMON_IMPL_CPP
#define GGML_COMMON_DECL_CPP

#include "ime.h"

#include "ggml-backend-impl.h"
#include "ggml-common.h"
#include "ggml-cpu.h"
#include "ime_kernels.h"
#include "traits.h"

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdio>  // for GGML_ASSERT
#include <stdexcept>
#include <thread>

// clang-format off
#if defined(__riscv)

#if !defined(__riscv_v) || !defined(__riscv_v_intrinsic)
#error "riscv v extension or v_intrinsic not enabled"
#else
#include <riscv_vector.h>
#endif

#if !defined(__riscv_zfh)
#error "riscv zfh extension not enabled"
#endif

#if defined(RISCV64_SPACEMIT_IME1)
#else
#error "RISCV64_SPACEMIT_IME1 not defined"
#endif

#else

#error "riscv not enabled in this build"

#endif

#if defined(__GNUC__)
#pragma GCC diagnostic ignored "-Woverlength-strings"
#pragma GCC diagnostic ignored "-Wcast-qual"
#pragma GCC diagnostic ignored "-Wunused-parameter"
#endif

#if defined(RISCV64_SPACEMIT_IME1)
#define QGEMM_STRIDEN_THREAD_ALIGN 16
#else
#define QGEMM_STRIDEN_THREAD_ALIGN 32
#endif

// clang-format on

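// Per-call arguments for one q8 (activations) x q4 (packed weights) GEMM.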
struct qnbitgemm_spacemit_ime_args {
    const float *     a_ptr               = nullptr;
    size_t            lda                 = 0;
    const std::byte * packed_quant_b_data = nullptr;
    const float *     quant_b_scale       = nullptr;
    const void *      quant_b_zp          = nullptr;
    const float *     quant_b_blksum      = nullptr;
    const float *     bias                = nullptr;
    float *           c_ptr               = nullptr;
    size_t            ldc                 = 0;
};

constexpr size_t div_round_up(size_t up, size_t down) {
    return (up + down - 1) / down;
}

constexpr size_t q8_blk_size(size_t blk_len) {
    const size_t blk_size = sizeof(float) + blk_len * sizeof(int8_t);
    // Currently, the strictest alignment requirement of a block is for a float.
    // Ensure contiguous blocks are suitably aligned.
    assert(blk_size % alignof(float) == 0);
    return blk_size;
}

namespace ggml::cpu::riscv64_spacemit {

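// Only the first half of the hardware threads are used to run the IME GEMM kernels
// (see the early-out in forward_mul_mat_q4).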
const int num_ai_cores = std::thread::hardware_concurrency() / 2;

}  // namespace ggml::cpu::riscv64_spacemit

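// Compute the output tile C[m_start : m_start + m_count, n_start : n_start + n_count] of a single
// GEMM, reading q8-quantized activations from per_gemm_ws and packed 4-bit weights (with per-block
// scales and optional zero points) from gemm_args.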
static void sqnbitgemm_spacemit_ime_i8i4(const size_t                        blk_len,
                                         const size_t                        gemm_k,
                                         const qnbitgemm_spacemit_ime_args * gemm_args,
                                         void * const                        per_gemm_ws,
                                         const size_t                        m_start,
                                         const size_t                        m_count,
                                         const size_t                        n_start,
                                         const size_t                        n_count) {
    constexpr size_t scale_stride = sizeof(uint16_t);
    constexpr size_t blk_bitwidth = 4;

    const size_t k_blks = div_round_up(gemm_k, blk_len);

    const size_t      lda         = k_blks * q8_blk_size(blk_len);
    const size_t      ldc         = gemm_args->ldc;
    const size_t      ldb         = k_blks * (blk_len * blk_bitwidth / 8);
    const std::byte * quant_a_ptr = static_cast<const std::byte *>(per_gemm_ws) + m_start * lda;

    const size_t      zero_point_stride   = gemm_args->quant_b_zp != nullptr ? sizeof(uint8_t) : 0;
    const size_t      packed_b_stride     = ldb + k_blks * (scale_stride + zero_point_stride);
    const std::byte * packed_quant_b_data = gemm_args->packed_quant_b_data + n_start * packed_b_stride;

    float * c_ptr = gemm_args->c_ptr + m_start * ldc + n_start;

    size_t       count_n               = 0;
    const size_t compute_block_count_n = m_count == 1 ? n_count : 16;
    for (size_t n = 0; n < n_count; n += count_n) {
        count_n = std::min(n_count - n, compute_block_count_n);

        const std::byte * a_row    = quant_a_ptr;
        const std::byte * b_col    = packed_quant_b_data + n * packed_b_stride;
        const std::byte * b_col_zp = (zero_point_stride != 0) ? b_col : nullptr;
        float *           c_blk    = c_ptr + n;

        int32_t rows_remaining = m_count;

        while (rows_remaining > 0) {
            const auto rows_handled = sqnbitgemm_spacemit_ime::ime1::gemm_kernel_i8i4(
                blk_len, a_row, b_col, nullptr, b_col_zp, c_blk, rows_remaining, count_n, gemm_k, k_blks, ldc, nullptr,
                scale_stride);

            c_blk += rows_handled * ldc;
            a_row += rows_handled * lda;

            rows_remaining -= rows_handled;
        }
    }
}

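// Block length (weights per quantization block) for the 4-bit and 8-bit block types below.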
template <int K> constexpr int QK_0() {
    if constexpr (K == 4) {
        return QK4_0;
    }
    if constexpr (K == 8) {
        return QK8_0;
    }
    return -1;
}

template <int K, int N> struct block {
    ggml_half d[N];                         // deltas for N qK_0 blocks
    uint8_t   qs[(QK_0<K>() * N * K) / 8];  // quants for N qK_0 blocks
};

template <int K, int N> struct block_with_zp {
    ggml_half d[N];                         // deltas for N qK_1 blocks
    uint8_t   zp[N];                        // zero points for N qK_1 blocks
    uint8_t   qs[(QK_0<K>() * N * K) / 8];  // quants for N qK_1 blocks
};

// compile-time checks on the interleaved block sizes
static_assert(sizeof(block<4, 16>) == 16 * sizeof(ggml_half) + QK4_0 * 8, "wrong block<4,16> size/padding");
static_assert(sizeof(block_with_zp<4, 16>) == 16 * sizeof(ggml_half) + QK4_0 * 8 + 16 * sizeof(uint8_t),
              "wrong block_with_zp<4,16> size/padding");
static_assert(sizeof(block<8, 16>) == 16 * sizeof(ggml_half) + QK4_0 * 16, "wrong block<8,16> size/padding");

using block_q4_0x16 = block<4, 16>;
using block_q4_1x16 = block_with_zp<4, 16>;
using block_q8_0x16 = block<8, 16>;

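// Interleave 16 consecutive q4_0 blocks (one per row) into a single block_q4_0x16,
// regrouping the nibbles so that weights [0, 15] and [16, 31] of each row are stored contiguously.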
static block_q4_0x16 make_block_q4_0x16(block_q4_0 * in, unsigned int blck_size_interleave) {
    block_q4_0x16 out;
    GGML_ASSERT(QK4_0 / blck_size_interleave == 2);

    for (int i = 0; i < 16; i++) {
        out.d[i] = in[i].d;
    }

    for (int i = 0; i < 16; i++) {
        // weights [0, 15]: packed from the low nibbles of in[i].qs (& 0x0F)
        for (int j = 0; j < QK4_0 / 4; j++) {
            //src [b0 b16] ......... [b8 b24] ......... [b15 b31]
            //dst [b0 b8] ......... [b7 b15]
            out.qs[i * QK4_0 / 4 + j] = (in[i].qs[j] & 0x0F) | ((in[i].qs[j + QK4_0 / 4] & 0x0F) << 4);
        }
    }

    for (int i = 0; i < 16; i++) {
        // weights [16, 31]: packed from the high nibbles of in[i].qs (& 0xF0)
        for (int j = 0; j < QK4_0 / 4; j++) {
            //src [b0 b16] ......... [b8 b24] ......... [b15 b31]
            //dst [b16 b24] ......... [b23 b31]
            out.qs[4 * QK4_0 + i * QK4_0 / 4 + j] = ((in[i].qs[j] & 0xF0) >> 4) | (in[i].qs[j + QK4_0 / 4] & 0xF0);
        }
    }

    return out;
}

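// Interleave 16 consecutive q4_1 blocks into a block_q4_1x16, converting each block's
// (delta, min) pair into the (delta, unsigned zero point) form expected by the IME kernels.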
static block_q4_1x16 make_block_q4_1x16(block_q4_1 * in, unsigned int blck_size_interleave) {
    block_q4_1x16 out;
    GGML_ASSERT(QK4_1 / blck_size_interleave == 2);

    for (int i = 0; i < 16; i++) {
        float d   = GGML_FP16_TO_FP32(in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d);
        float m   = GGML_FP16_TO_FP32(in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.m);
        float mid = -std::nearbyintf(m / d);
        mid       = std::min(15.0f, std::max(0.0f, mid));
        out.d[i]  = GGML_FP32_TO_FP16(d);
        out.zp[i] = static_cast<uint8_t>(mid);
    }

    for (int i = 0; i < 16; i++) {
        // weights [0, 15]: packed from the low nibbles of in[i].qs (& 0x0F)
        for (int j = 0; j < QK4_1 / 4; j++) {
            //src [b0 b16] ......... [b8 b24] ......... [b15 b31]
            //dst [b0 b8] ......... [b7 b15]
            out.qs[i * QK4_1 / 4 + j] = (in[i].qs[j] & 0x0F) | ((in[i].qs[j + QK4_1 / 4] & 0x0F) << 4);
        }
    }

    for (int i = 0; i < 16; i++) {
        // weights [16, 31]: packed from the high nibbles of in[i].qs (& 0xF0)
        for (int j = 0; j < QK4_1 / 4; j++) {
            //src [b0 b16] ......... [b8 b24] ......... [b15 b31]
            //dst [b16 b24] ......... [b23 b31]
            out.qs[4 * QK4_1 + i * QK4_1 / 4 + j] = ((in[i].qs[j] & 0xF0) >> 4) | (in[i].qs[j + QK4_1 / 4] & 0xF0);
        }
    }

    return out;
}

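// Repack a q4_0 tensor into the interleaved block_q4_0x16 layout, 16 rows at a time.
// Returns -1 if the tensor dimensions are not multiples of the interleave sizes.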
static int repack_q4_0_to_q4_0_16_bl(struct ggml_tensor *       t,
                                     int                        interleave_block,
                                     const void * GGML_RESTRICT data,
                                     size_t                     data_size) {
    GGML_ASSERT(t->type == GGML_TYPE_Q4_0);
    GGML_ASSERT(interleave_block == 16);

    constexpr int nrows_interleaved = 16;

    block_q4_0x16 *    dst = (block_q4_0x16 *) t->data;
    const block_q4_0 * src = (const block_q4_0 *) data;
    block_q4_0         dst_tmp[16];
    int                nrow    = ggml_nrows(t);
    int                nblocks = t->ne[0] / QK4_0;

    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0));

    if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_0 != 0) {
        return -1;
    }

    for (int b = 0; b < nrow; b += nrows_interleaved) {
        for (int64_t x = 0; x < nblocks; x++) {
            for (int i = 0; i < nrows_interleaved; i++) {
                dst_tmp[i] = src[x + i * nblocks];
            }
            *dst++ = make_block_q4_0x16(dst_tmp, interleave_block);
        }
        src += nrows_interleaved * nblocks;
    }
    return 0;

    GGML_UNUSED(data_size);
}

static int repack_q4_1_to_q4_1_16_bl(struct ggml_tensor *       t,
                                     int                        interleave_block,
                                     const void * GGML_RESTRICT data,
                                     size_t                     data_size) {
    GGML_ASSERT(t->type == GGML_TYPE_Q4_1);
    GGML_ASSERT(interleave_block == 16);

    constexpr int nrows_interleaved = 16;

    block_q4_1x16 *    dst = (block_q4_1x16 *) t->data;
    const block_q4_1 * src = (const block_q4_1 *) data;
    block_q4_1         dst_tmp[16];
    int                nrow    = ggml_nrows(t);
    int                nblocks = t->ne[0] / QK4_1;

    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_1));

    if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_1 != 0) {
        return -1;
    }

    for (int b = 0; b < nrow; b += nrows_interleaved) {
        for (int64_t x = 0; x < nblocks; x++) {
            for (int i = 0; i < nrows_interleaved; i++) {
                dst_tmp[i] = src[x + i * nblocks];
            }
            *dst++ = make_block_q4_1x16(dst_tmp, interleave_block);
        }
        src += nrows_interleaved * nblocks;
    }
    return 0;

    GGML_UNUSED(data_size);
}

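// Unpack the 6-bit scale and min of sub-block j from the packed q4_K scales array.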
static inline void get_scale_min_k4(int                           j,
                                    const uint8_t * GGML_RESTRICT q,
                                    uint8_t * GGML_RESTRICT       d,
                                    uint8_t * GGML_RESTRICT       m) {
    if (j < 4) {
        *d = q[j] & 63;
        *m = q[j + 4] & 63;
    } else {
        *d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4);
        *m = (q[j + 4] >> 4) | ((q[j - 0] >> 6) << 4);
    }
}

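// Repack a q4_K tensor into the block_q4_1x16 layout: each super-block is split into
// 8 q4_1-style blocks (per-sub-block scale and negated min), which are then interleaved 16 rows at a time.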
static int repack_q4_k_to_q4_1_16_bl(struct ggml_tensor *       t,
                                     int                        interleave_block,
                                     const void * GGML_RESTRICT data,
                                     size_t                     data_size) {
    GGML_ASSERT(t->type == GGML_TYPE_Q4_K);
    GGML_ASSERT(interleave_block == 16);
    GGML_ASSERT(QK_K / QK4_1 == 8);

    constexpr int nrows_interleaved = 16;

    block_q4_1x16 *    dst = (block_q4_1x16 *) t->data;
    const block_q4_K * src = (const block_q4_K *) data;
    block_q4_1         dst_tmp[16];
    int                nrow    = ggml_nrows(t);
    int                nblocks = t->ne[0] / QK_K;

    if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK_K != 0) {
        return -1;
    }

    for (int b = 0; b < nrow; b += nrows_interleaved) {
        for (int64_t x = 0; x < nblocks; x++) {
            for (int j = 0; j < 8; j++) {
                for (int i = 0; i < nrows_interleaved; i++) {
                    uint8_t     sc, m;
                    const float d = GGML_FP16_TO_FP32(src[x + i * nblocks].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d);
                    const float min =
                        GGML_FP16_TO_FP32(src[x + i * nblocks].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin);
                    get_scale_min_k4(j, src[x + i * nblocks].scales, &sc, &m);
                    const float d1 = d * sc;
                    const float m1 = min * m;

                    dst_tmp[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d = GGML_FP32_TO_FP16(d1);
                    dst_tmp[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.m = GGML_FP32_TO_FP16(-m1);
                    // src -> [b0, b32] [b1, b33] ... [b31, b63]
                    // dst -> [b0, b16] [b1, b17] ... [b15, b31] [b32, b48] [b33, b49] ... [b47, b63]
                    const uint8_t * q                                  = src[x + i * nblocks].qs + (j / 2) * QK4_1;
                    if (j % 2 == 0) {
                        for (int ii = 0; ii < 16; ii++) {
                            dst_tmp[i].qs[ii] = (q[ii] & 0x0F) | ((q[ii + 16] & 0x0F) << 4);
                        }
                    } else {
                        for (int ii = 0; ii < 16; ii++) {
                            dst_tmp[i].qs[ii] = ((q[ii] & 0xF0) >> 4) | (q[ii + 16] & 0xF0);
                        }
                    }
                }
                *dst++ = make_block_q4_1x16(dst_tmp, interleave_block);
            }
        }
        src += nrows_interleaved * nblocks;
    }
    return 0;

    GGML_UNUSED(data_size);
}

namespace ggml::cpu::riscv64_spacemit {

template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS>
int repack(struct ggml_tensor *, const void *, size_t);

template <> int repack<block_q4_0, 8, 16>(struct ggml_tensor * t, const void * data, size_t data_size) {
    return repack_q4_0_to_q4_0_16_bl(t, 16, data, data_size);
}

template <> int repack<block_q4_1, 8, 16>(struct ggml_tensor * t, const void * data, size_t data_size) {
    return repack_q4_1_to_q4_1_16_bl(t, 16, data, data_size);
}

template <> int repack<block_q4_K, 8, 16>(struct ggml_tensor * t, const void * data, size_t data_size) {
    return repack_q4_k_to_q4_1_16_bl(t, 16, data, data_size);
}

class tensor_traits_base : public ggml::cpu::tensor_traits {
  public:
    virtual int repack(struct ggml_tensor * t, const void * data, size_t data_size) = 0;
};

template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> class tensor_traits : public tensor_traits_base {
    bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
        switch (op->op) {
            case GGML_OP_MUL_MAT:
                size = ggml_row_size(GGML_TYPE_Q8_0, ggml_nelements(op->src[1])) * 4;
                size = ((size + QK4_0 - 1) / QK4_0) * (QK4_0 * sizeof(float) + sizeof(float));
                return true;
            default:
                // GGML_ABORT("fatal error");
                break;
        }
        return false;
    }

    bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) override {
        switch (op->op) {
            case GGML_OP_MUL_MAT:
                if (op->src[0]->type == GGML_TYPE_Q4_0 ||  //
                    op->src[0]->type == GGML_TYPE_Q4_1 ||  //
                    op->src[0]->type == GGML_TYPE_Q4_K) {
                    forward_mul_mat_q4(params, op);
                    return true;
                }
            default:
                // GGML_ABORT("fatal error");
                break;
        }
        return false;
    }

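    // MUL_MAT for repacked q4 weights, in two phases:
    //   1. all threads quantize their share of the f32 activation rows to q8 blocks in the shared workspace,
    //   2. after a barrier, the threads mapped to IME-capable cores run the tiled i8 x i4 GEMM.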
    void forward_mul_mat_q4(ggml_compute_params * params, ggml_tensor * op) {
        const ggml_tensor * src0 = op->src[0];
        const ggml_tensor * src1 = op->src[1];
        ggml_tensor *       dst  = op;

        GGML_TENSOR_BINARY_OP_LOCALS

        int ith = params->ith;
        int nth = params->nth;

        [[maybe_unused]] const enum ggml_type type = src0->type;

        void *        w_data  = (void *) src0->data;
        const float * feature = (const float *) src1->data;
        float *       output  = (float *) dst->data;

        const size_t                  batch_feature = ne12 * ne13;
        [[maybe_unused]] const size_t batch_weight  = ne02 * ne03;
        const size_t                  gemm_m        = ne11;
        const size_t                  gemm_k        = ne10;
        const size_t                  gemm_n        = ne01;

        GGML_ASSERT(batch_weight == 1);

        const size_t block_count_k           = div_round_up(gemm_k, QK4_0);
        const size_t per_gemm_workspace_size = gemm_m * block_count_k * q8_blk_size(QK4_0);
        const size_t per_gemm_workspace_stride =
            div_round_up(per_gemm_workspace_size, alignof(uint64_t)) * alignof(uint64_t);
        const size_t gemm_workspace_size = batch_feature * per_gemm_workspace_stride;
        const size_t desired_wsize       = gemm_workspace_size + alignof(uint64_t) - 1;

        if (ith == 0 && params->wsize < desired_wsize) {
            throw std::runtime_error("wsize less than desired_wsize");
        }

        std::vector<qnbitgemm_spacemit_ime_args> qnbitgemm_args(batch_feature);

        for (size_t i = 0; i < batch_feature; i++) {
            qnbitgemm_args[i].a_ptr               = feature + gemm_m * gemm_k * i;
            qnbitgemm_args[i].lda                 = gemm_k;
            qnbitgemm_args[i].packed_quant_b_data = (const std::byte *) w_data;
            qnbitgemm_args[i].quant_b_scale       = nullptr;

            if constexpr (std::is_same_v<BLOC_TYPE, block_q4_0>) {
                qnbitgemm_args[i].quant_b_zp = nullptr;
            } else {
                qnbitgemm_args[i].quant_b_zp = w_data;
            }

            qnbitgemm_args[i].bias  = nullptr;
            qnbitgemm_args[i].c_ptr = output + gemm_m * gemm_n * i;
            qnbitgemm_args[i].ldc   = gemm_n;
        }

        const uintptr_t ws_ptr = reinterpret_cast<uintptr_t>(params->wdata);
        void *          ws = reinterpret_cast<void *>((ws_ptr + alignof(uint64_t) - 1) & (~(alignof(uint64_t) - 1)));
        const size_t    quant_a_stride = block_count_k * q8_blk_size(QK4_0);

        {
            constexpr size_t block_size_m           = 4;
            size_t           per_gemm_block_count_m = div_round_up(gemm_m, block_size_m);
            int32_t          task_count             = batch_feature * per_gemm_block_count_m;
            int32_t          task_per_thread        = (task_count + nth - 1) / nth;
            int32_t          start                  = ith * task_per_thread;
            int32_t          end                    = std::min((ith + 1) * task_per_thread, task_count);
            for (int32_t compute_idx = start; compute_idx < end; compute_idx++) {
                int32_t                             gemm_idx = compute_idx / per_gemm_block_count_m;
                int32_t                             block_idx_in_gemm = compute_idx % per_gemm_block_count_m;
                int32_t                             m_idx    = block_idx_in_gemm * block_size_m;
                const qnbitgemm_spacemit_ime_args & data     = qnbitgemm_args[gemm_idx];
                int32_t rows_tobe_handled = (gemm_m - m_idx) > block_size_m ? block_size_m : (gemm_m - m_idx);

                if (rows_tobe_handled == block_size_m) {
                    const float * a_row_ptr = data.a_ptr + m_idx * data.lda;
                    std::byte *   quant_a_row_ptr =
                        static_cast<std::byte *>(ws) + gemm_idx * per_gemm_workspace_stride + m_idx * quant_a_stride;
                    sqnbitgemm_spacemit_ime::ime1::quantize_a_4row_i8(QK4_0, a_row_ptr, gemm_k, quant_a_row_ptr);
                } else {
                    while (rows_tobe_handled) {
                        const float * a_row_ptr       = data.a_ptr + m_idx * data.lda;
                        std::byte *   quant_a_row_ptr = static_cast<std::byte *>(ws) +
                                                      gemm_idx * per_gemm_workspace_stride + m_idx * quant_a_stride;
                        sqnbitgemm_spacemit_ime::ime1::quantize_a_row_i8(QK4_0, a_row_ptr, gemm_k, quant_a_row_ptr);
                        rows_tobe_handled -= 1;
                        m_idx += 1;
                    }
                }
            }
        }

        ggml_barrier(params->threadpool);

        if (ith >= ggml::cpu::riscv64_spacemit::num_ai_cores) {
            return;
        }
        nth = std::min(nth, int{ ggml::cpu::riscv64_spacemit::num_ai_cores });

        size_t           threads_per_gemm = nth / batch_feature;
        constexpr size_t gemm_m_stride    = 128;
        size_t           nc               = gemm_n;
        const size_t     gemm_m_blocked   = div_round_up(gemm_m, gemm_m_stride);
        const size_t     max_nc           = div_round_up(gemm_n * gemm_m_blocked, threads_per_gemm);
        if (max_nc < nc) {
            nc = std::min(nc, div_round_up(max_nc, QGEMM_STRIDEN_THREAD_ALIGN) * QGEMM_STRIDEN_THREAD_ALIGN);
        }
        const size_t gemm_n_stride  = nc;
        const size_t thread_count_m = div_round_up(gemm_m, gemm_m_stride);
        const size_t thread_count_n = div_round_up(gemm_n, gemm_n_stride);
        threads_per_gemm            = thread_count_m * thread_count_n;

        {
            int task_count      = batch_feature * threads_per_gemm;
            int task_per_thread = (task_count + nth - 1) / nth;
            int start           = ith * task_per_thread;
            int end             = std::min((ith + 1) * task_per_thread, task_count);
            for (int compute_idx = start; compute_idx < end; compute_idx++) {
                const auto   gemm_i = compute_idx / threads_per_gemm;
                const auto   blk_i  = compute_idx % threads_per_gemm;
                const auto * data   = &qnbitgemm_args[gemm_i];

                const auto tid_n = blk_i / thread_count_m;
                const auto tid_m = blk_i % thread_count_m;

                const size_t m_start = tid_m * gemm_m_stride;
                const size_t m_count = std::min(gemm_m - m_start, (size_t) gemm_m_stride);

                const size_t n_start = tid_n * gemm_n_stride;
                const size_t n_count = std::min(gemm_n - n_start, (size_t) gemm_n_stride);

                void * per_gemm_ws = reinterpret_cast<std::byte *>(ws) + gemm_i * per_gemm_workspace_stride;

                sqnbitgemm_spacemit_ime_i8i4(QK4_0, gemm_k, data, per_gemm_ws, m_start, m_count, n_start, n_count);
            }
        }
    }

    int repack(struct ggml_tensor * t, const void * data, size_t data_size) override {
        GGML_LOG_DEBUG("%s: repack tensor %s with %s_%dx%d\n", __func__, t->name, ggml_type_name(t->type),
                       (int) NB_COLS, (int) INTER_SIZE);
        return ggml::cpu::riscv64_spacemit::repack<BLOC_TYPE, INTER_SIZE, NB_COLS>(t, data, data_size);
    }
};

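// F32 NORM / RMS_NORM implemented with RVV intrinsics; tensors of this kind are stored
// unmodified (repack is a plain memcpy).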
class tensor_traits_common : public tensor_traits_base {
    bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
        switch (op->op) {
            case GGML_OP_NORM:
            case GGML_OP_RMS_NORM:
                size = 0;
                return true;
            default:
                // GGML_ABORT("fatal error");
                break;
        }
        return false;
    }

    bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) override {
        switch (op->op) {
            case GGML_OP_NORM:
                forward_norm_f32(params, op);
                return true;
            case GGML_OP_RMS_NORM:
                forward_rms_norm_f32(params, op);
                return true;
            default:
                // GGML_ABORT("fatal error");
                break;
        }
        return false;
    }

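    // Layer normalization over the last dimension: accumulate sum and sum of squares with m4
    // vectors, reduce to scalars, then normalize in place (gamma/beta are not used here).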
    void forward_norm_f32(ggml_compute_params * params, ggml_tensor * op) {
        const ggml_tensor * src0 = op->src[0];
        ggml_tensor *       dst  = op;
        GGML_ASSERT(ggml_are_same_shape(src0, dst));
        GGML_ASSERT(src0->nb[0] == sizeof(float));

        const int ith = params->ith;
        const int nth = params->nth;

        GGML_TENSOR_UNARY_OP_LOCALS

        float epsilon;
        memcpy(&epsilon, dst->op_params, sizeof(float));

        GGML_ASSERT(epsilon > 0.0f);

        auto * input  = (float *) src0->data;
        auto * output = (float *) dst->data;

        const auto hidden_size     = ne00;
        const auto task_count      = ne01 * ne02 * ne03;
        const auto task_per_thread = (task_count + nth - 1) / nth;

        const auto task_begin = ith * task_per_thread;
        const auto task_end   = std::min((ith + 1) * task_per_thread, task_count);

        for (auto task_idx = task_begin; task_idx < task_end; task_idx++) {
            auto   offset  = task_idx * hidden_size;
            auto * p_input = const_cast<float *>(input + offset);

            auto *       p_output      = output + offset;
            auto *       p_temp_output = p_output;
            auto *       p_gamma_data  = (const float *) nullptr;
            auto *       p_beta_data   = (const float *) nullptr;
            size_t       gvl           = __riscv_vsetvlmax_e32m4();
            vfloat32m4_t sum           = __riscv_vfmv_v_f_f32m4(0.f, gvl);
            vfloat32m4_t sum_sq        = __riscv_vfmv_v_f_f32m4(0.f, gvl);
            int64_t      length        = hidden_size;
            while (length > 0) {
                gvl                   = __riscv_vsetvl_e32m4(length);
                // load data
                vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_input, gvl);

                sum    = __riscv_vfadd_vv_f32m4(sum, src_data, gvl);
                sum_sq = __riscv_vfmacc_vv_f32m4(sum_sq, src_data, src_data, gvl);

                __riscv_vse32_v_f32m4(p_temp_output, src_data, gvl);

                p_input += gvl;
                p_temp_output += gvl;
                length -= gvl;
            }

            gvl = __riscv_vsetvlmax_e32m1();

            float        mean   = 0.f;
            vfloat32m1_t zero_v = __riscv_vfmv_v_f_f32m1(0.f, gvl);
            vfloat32m1_t mean_v =
                __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(sum, 0), __riscv_vget_v_f32m4_f32m1(sum, 1), gvl);
            mean_v = __riscv_vfadd_vv_f32m1(mean_v, __riscv_vget_v_f32m4_f32m1(sum, 2), gvl);
            mean_v = __riscv_vfadd_vv_f32m1(mean_v, __riscv_vget_v_f32m4_f32m1(sum, 3), gvl);
            mean_v = __riscv_vfredusum_vs_f32m1_f32m1(mean_v, zero_v, gvl);
            mean   = __riscv_vfmv_f_s_f32m1_f32(mean_v);
            mean /= hidden_size;

            vfloat32m1_t mean_square_v = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(sum_sq, 0),
                                                                __riscv_vget_v_f32m4_f32m1(sum_sq, 1), gvl);
            mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 2), gvl);
            mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 3), gvl);
            mean_square_v = __riscv_vfredusum_vs_f32m1_f32m1(mean_square_v, zero_v, gvl);

            float mean_square = __riscv_vfmv_f_s_f32m1_f32(mean_square_v);
            mean_square /= hidden_size;
            mean_square = sqrt(mean_square - mean * mean + epsilon);

            mean_square   = 1.0f / mean_square;
            length        = hidden_size;
            p_temp_output = p_output;

            if (p_gamma_data == nullptr && p_beta_data == nullptr) {
                while (length > 0) {
                    gvl                   = __riscv_vsetvl_e32m4(length);
                    vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl);
                    src_data              = __riscv_vfsub_vf_f32m4(src_data, mean, gvl);
                    src_data              = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl);
                    __riscv_vse32_v_f32m4(p_output, src_data, gvl);
                    p_temp_output += gvl;
                    p_output += gvl;
                    length -= gvl;
                }
            } else if (p_beta_data == nullptr) {
                while (length > 0) {
                    gvl                       = __riscv_vsetvl_e32m4(length);
                    vfloat32m4_t src_data     = __riscv_vle32_v_f32m4(p_temp_output, gvl);
                    vfloat32m4_t gamma_data_v = __riscv_vle32_v_f32m4(p_gamma_data, gvl);
                    src_data                  = __riscv_vfsub_vf_f32m4(src_data, mean, gvl);
                    src_data                  = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl);
                    src_data                  = __riscv_vfmul_vv_f32m4(src_data, gamma_data_v, gvl);
                    __riscv_vse32_v_f32m4(p_output, src_data, gvl);
                    p_temp_output += gvl;
                    p_output += gvl;
                    p_gamma_data += gvl;
                    length -= gvl;
                }
            } else if (p_gamma_data != nullptr) {
                while (length > 0) {
                    gvl                       = __riscv_vsetvl_e32m4(length);
                    vfloat32m4_t src_data     = __riscv_vle32_v_f32m4(p_temp_output, gvl);
                    vfloat32m4_t gamma_data_v = __riscv_vle32_v_f32m4(p_gamma_data, gvl);
                    src_data                  = __riscv_vfsub_vf_f32m4(src_data, mean, gvl);
                    src_data                  = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl);
                    src_data                  = __riscv_vfmul_vv_f32m4(src_data, gamma_data_v, gvl);
                    vfloat32m4_t beta_data_v  = __riscv_vle32_v_f32m4(p_beta_data, gvl);
                    src_data                  = __riscv_vfadd_vv_f32m4(src_data, beta_data_v, gvl);
                    p_beta_data += gvl;
                    __riscv_vse32_v_f32m4(p_output, src_data, gvl);
                    p_temp_output += gvl;
                    p_output += gvl;
                    p_gamma_data += gvl;
                    length -= gvl;
                }
            }
        }
    }

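    // RMS normalization over the last dimension: same structure as forward_norm_f32, but only
    // the sum of squares is accumulated and the mean is not subtracted.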
    void forward_rms_norm_f32(ggml_compute_params * params, ggml_tensor * op) {
        const ggml_tensor * src0 = op->src[0];
        ggml_tensor *       dst  = op;
        GGML_ASSERT(ggml_are_same_shape(src0, dst));
        GGML_ASSERT(src0->nb[0] == sizeof(float));

        const int ith = params->ith;
        const int nth = params->nth;

        GGML_TENSOR_UNARY_OP_LOCALS

        float epsilon;
        memcpy(&epsilon, dst->op_params, sizeof(float));

        GGML_ASSERT(epsilon > 0.0f);

        auto * input  = (float *) src0->data;
        auto * output = (float *) dst->data;

        const auto hidden_size     = ne00;
        const auto task_count      = ne01 * ne02 * ne03;
        const auto task_per_thread = (task_count + nth - 1) / nth;

        const auto task_begin = ith * task_per_thread;
        const auto task_end   = std::min((ith + 1) * task_per_thread, task_count);

        for (auto task_idx = task_begin; task_idx < task_end; task_idx++) {
            auto   offset        = task_idx * hidden_size;
            auto * p_input       = const_cast<float *>(input + offset);
            auto * p_output      = output + offset;
            auto * p_temp_output = p_output;
            auto * p_gamma_data  = (const float *) nullptr;
            auto * p_beta_data   = (const float *) nullptr;

            size_t       gvl    = __riscv_vsetvlmax_e32m4();
            // vfloat32m4_t sum = __riscv_vfmv_v_f_f32m4(0.f, gvl);
            vfloat32m4_t sum_sq = __riscv_vfmv_v_f_f32m4(0.f, gvl);
            int64_t      length = hidden_size;
            while (length > 0) {
                gvl                   = __riscv_vsetvl_e32m4(length);
                // load data
                vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_input, gvl);

                sum_sq = __riscv_vfmacc_vv_f32m4(sum_sq, src_data, src_data, gvl);

                __riscv_vse32_v_f32m4(p_temp_output, src_data, gvl);

                p_input += gvl;
                p_temp_output += gvl;
                length -= gvl;
            }

            gvl = __riscv_vsetvlmax_e32m1();

            // float mean = 0.f;
            vfloat32m1_t zero_v = __riscv_vfmv_v_f_f32m1(0.f, gvl);

            vfloat32m1_t mean_square_v = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(sum_sq, 0),
                                                                __riscv_vget_v_f32m4_f32m1(sum_sq, 1), gvl);
            mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 2), gvl);
            mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 3), gvl);
            mean_square_v = __riscv_vfredusum_vs_f32m1_f32m1(mean_square_v, zero_v, gvl);

            float mean_square = __riscv_vfmv_f_s_f32m1_f32(mean_square_v);
            mean_square /= hidden_size;

            mean_square = sqrt(mean_square + epsilon);

            mean_square   = 1.0f / mean_square;
            length        = hidden_size;
            p_temp_output = p_output;

            if (p_gamma_data == nullptr && p_beta_data == nullptr) {
                while (length > 0) {
                    gvl                   = __riscv_vsetvl_e32m4(length);
                    vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl);
                    src_data              = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl);
                    __riscv_vse32_v_f32m4(p_output, src_data, gvl);
                    p_temp_output += gvl;
                    p_output += gvl;
                    length -= gvl;
                }
            } else if (p_beta_data == nullptr) {
                while (length > 0) {
                    gvl                       = __riscv_vsetvl_e32m4(length);
                    vfloat32m4_t src_data     = __riscv_vle32_v_f32m4(p_temp_output, gvl);
                    vfloat32m4_t gamma_data_v = __riscv_vle32_v_f32m4(p_gamma_data, gvl);
                    src_data                  = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl);
                    src_data                  = __riscv_vfmul_vv_f32m4(src_data, gamma_data_v, gvl);
                    __riscv_vse32_v_f32m4(p_output, src_data, gvl);
                    p_temp_output += gvl;
                    p_output += gvl;
                    p_gamma_data += gvl;
                    length -= gvl;
                }
            } else if (p_gamma_data != nullptr) {
                while (length > 0) {
                    gvl                       = __riscv_vsetvl_e32m4(length);
                    vfloat32m4_t src_data     = __riscv_vle32_v_f32m4(p_temp_output, gvl);
                    vfloat32m4_t gamma_data_v = __riscv_vle32_v_f32m4(p_gamma_data, gvl);
                    src_data                  = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl);
                    src_data                  = __riscv_vfmul_vv_f32m4(src_data, gamma_data_v, gvl);
                    vfloat32m4_t beta_data_v  = __riscv_vle32_v_f32m4(p_beta_data, gvl);
                    src_data                  = __riscv_vfadd_vv_f32m4(src_data, beta_data_v, gvl);
                    p_beta_data += gvl;
                    __riscv_vse32_v_f32m4(p_output, src_data, gvl);
                    p_temp_output += gvl;
                    p_output += gvl;
                    p_gamma_data += gvl;
                    length -= gvl;
                }
            }
        }
    }

    int repack(struct ggml_tensor * t, const void * data, size_t data_size) override {
        memcpy(t->data, data, data_size);
        return 0;
    }
};

static const tensor_traits<block_q4_0, 8, 16> q4_0_16x8_q8_0;
static const tensor_traits<block_q4_1, 8, 16> q4_1_16x8_q8_0;
static const tensor_traits<block_q4_K, 8, 16> q4_k_16x8_q8_0;
static const tensor_traits_common             rvv_impl;

}  // namespace ggml::cpu::riscv64_spacemit

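// Pick the tensor_traits used for repacking: q4_0/q4_1/q4_K tensors whose row count is a
// multiple of 16 get the interleaved GEMM traits, F32 tensors get the RVV norm implementation,
// anything else falls back to the default CPU path.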
static const ggml::cpu::tensor_traits * ggml_riscv64_spacemit_get_optimal_repack_type(const struct ggml_tensor * cur) {
    if (cur->type == GGML_TYPE_Q4_0) {
        if (cur->ne[1] % 16 == 0) {
            return &ggml::cpu::riscv64_spacemit::q4_0_16x8_q8_0;
        }
    } else if (cur->type == GGML_TYPE_Q4_1) {
        if (cur->ne[1] % 16 == 0) {
            return &ggml::cpu::riscv64_spacemit::q4_1_16x8_q8_0;
        }
    } else if (cur->type == GGML_TYPE_Q4_K) {
        if (cur->ne[1] % 16 == 0) {
            return &ggml::cpu::riscv64_spacemit::q4_k_16x8_q8_0;
        }
    } else if (cur->type == GGML_TYPE_F32) {
        return &ggml::cpu::riscv64_spacemit::rvv_impl;
    }

    return nullptr;
}

static enum ggml_status ggml_backend_riscv64_spacemit_buffer_init_tensor(ggml_backend_buffer_t buffer,
                                                                         struct ggml_tensor *  tensor) {
    tensor->extra =
        (void *) const_cast<ggml::cpu::tensor_traits *>(ggml_riscv64_spacemit_get_optimal_repack_type(tensor));

    GGML_UNUSED(buffer);

    return GGML_STATUS_SUCCESS;
}

static void ggml_backend_riscv64_spacemit_buffer_set_tensor(ggml_backend_buffer_t buffer,
                                                            struct ggml_tensor *  tensor,
                                                            const void *          data,
                                                            size_t                offset,
                                                            size_t                size) {
    GGML_ASSERT(offset == 0);
    GGML_ASSERT(size == ggml_nbytes(tensor));

    auto tensor_traits = (ggml::cpu::riscv64_spacemit::tensor_traits_base *) tensor->extra;
    if (tensor_traits) {
        auto OK = tensor_traits->repack(tensor, data, size);
        GGML_ASSERT(OK == 0);
    }

    GGML_UNUSED(buffer);
}

static const char * ggml_backend_cpu_riscv64_spacemit_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
    return "CPU_RISCV64_SPACEMIT";

    GGML_UNUSED(buft);
}

static ggml_backend_buffer_t ggml_backend_cpu_riscv64_spacemit_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,
                                                                                        size_t size) {
    ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);

    if (buffer == nullptr) {
        return nullptr;
    }

    buffer->buft              = buft;
    buffer->iface.init_tensor = ggml_backend_riscv64_spacemit_buffer_init_tensor;
    buffer->iface.set_tensor  = ggml_backend_riscv64_spacemit_buffer_set_tensor;
    buffer->iface.get_tensor  = nullptr;
    buffer->iface.cpy_tensor  = nullptr;
    return buffer;
}

static size_t ggml_backend_cpu_riscv64_spacemit_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
    return 64;

    GGML_UNUSED(buft);
}

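// Allocation size for tensors in this buffer type; q4_K tensors need extra room because
// repacking expands each super-block into 8 q4_1-sized blocks.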
static size_t ggml_backend_cpu_riscv64_spacemit_nbytes(ggml_backend_buffer_type_t buft,
                                                       const struct ggml_tensor * tensor) {
    for (int i = 0; i < GGML_MAX_DIMS; ++i) {
        if (tensor->ne[i] <= 0) {
            return 0;
        }
    }

    size_t       nbytes;
    const size_t blck_size = ggml_blck_size(tensor->type);
    if (blck_size == 1) {
        nbytes = ggml_type_size(tensor->type);
        for (int i = 0; i < GGML_MAX_DIMS; ++i) {
            nbytes += (tensor->ne[i] - 1) * tensor->nb[i];
        }
    } else {
        nbytes = tensor->ne[0] * tensor->nb[0] / blck_size;
        if (tensor->type == GGML_TYPE_Q4_K) {
            GGML_ASSERT(nbytes % sizeof(block_q4_K) == 0);
            nbytes = (nbytes / sizeof(block_q4_K)) * sizeof(block_q4_1) * 8;
            for (int i = 1; i < GGML_MAX_DIMS; ++i) {
                nbytes += (tensor->ne[i] - 1) * (tensor->nb[i] / sizeof(block_q4_K)) * sizeof(block_q4_1) * 8;
            }
        } else {
            for (int i = 1; i < GGML_MAX_DIMS; ++i) {
                nbytes += (tensor->ne[i] - 1) * tensor->nb[i];
            }
        }
    }

    GGML_UNUSED(buft);
    return nbytes;
}

namespace ggml::cpu::riscv64_spacemit {

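// Extra buffer type hooks: report which ops this backend can take over and return the
// matching tensor_traits for them.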
class extra_buffer_type : ggml::cpu::extra_buffer_type {
    bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
        switch (op->op) {
            case GGML_OP_MUL_MAT:
                if (op->src[0]->buffer && (ggml_n_dims(op->src[0]) == 2) &&
                    op->src[0]->buffer->buft == ggml_backend_cpu_riscv64_spacemit_buffer_type() &&
                    ggml_riscv64_spacemit_get_optimal_repack_type(op->src[0])) {
                    if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
                        return false;
                    }
                    if (op->src[1]->type == GGML_TYPE_F32) {
                        return true;
                    }
                }
                break;
            case GGML_OP_NORM:
            case GGML_OP_RMS_NORM:
                if (op->src[0]->type == GGML_TYPE_F32) {
                    return true;
                }
                break;
            default:
                // GGML_ABORT("fatal error");
                break;
        }
        return false;
    }

    ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override {
        switch (op->op) {
            case GGML_OP_MUL_MAT:
                if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_riscv64_spacemit_buffer_type()) {
                    return (ggml::cpu::tensor_traits *) op->src[0]->extra;
                }
                break;
            case GGML_OP_NORM:
            case GGML_OP_RMS_NORM:
                return (ggml::cpu::tensor_traits *) (&ggml::cpu::riscv64_spacemit::rvv_impl);
            default:
                // GGML_ABORT("fatal error");
                break;
        }

        return nullptr;
    }
};

}  // namespace ggml::cpu::riscv64_spacemit

ggml_backend_buffer_type_t ggml_backend_cpu_riscv64_spacemit_buffer_type(void) {
    static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type_riscv64_spacemit = {
        /* .iface    = */
        {
            /* .get_name         = */ ggml_backend_cpu_riscv64_spacemit_buffer_type_get_name,
            /* .alloc_buffer     = */ ggml_backend_cpu_riscv64_spacemit_buffer_type_alloc_buffer,
            /* .get_alignment    = */ ggml_backend_cpu_riscv64_spacemit_buffer_type_get_alignment,
            /* .get_max_size     = */ nullptr,
            /* .get_alloc_size   = */ ggml_backend_cpu_riscv64_spacemit_nbytes,
            /* .is_host          = */ nullptr,
        },
        /* .device  = */
        ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
        /* .context = */
        new ggml::cpu::riscv64_spacemit::extra_buffer_type(),
    };

    return &ggml_backend_cpu_buffer_type_riscv64_spacemit;
}