1// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
 2// SPDX-License-Identifier: MIT
 3//
 4
 5#pragma once
 6
 7#include "ggml.h"
 8
 9enum cpu_feature {
10    CPU_FEATURE_NONE    = 0,
11    CPU_FEATURE_DOTPROD = 1,
12    CPU_FEATURE_I8MM    = 2,
13    CPU_FEATURE_SVE     = 4,
14    CPU_FEATURE_SME     = 8
15};
16
17inline cpu_feature& operator|=(cpu_feature& lhs, cpu_feature rhs) {
18    lhs = static_cast<cpu_feature>(lhs | rhs);
19    return lhs;
20}
21inline cpu_feature operator|(cpu_feature lhs, cpu_feature rhs) {
22    return static_cast<cpu_feature>(static_cast<int>(lhs) | static_cast<int>(rhs));
23}
24
// Function-pointer table describing one compute kernel. The struct holds no
// state of its own; every member is a pointer to a free function.
// NOTE(review): the member names (m/n step, mr/nr/kr/sr, packed lhs/rhs)
// suggest this mirrors a KleidiAI micro-kernel interface -- confirm against
// the code that fills in these tables.
struct kernel_info {
    // Queries that take no arguments and return a size_t.
    size_t (*get_m_step)(void);
    size_t (*get_n_step)(void);
    size_t (*get_mr)(void);
    size_t (*get_nr)(void);
    size_t (*get_kr)(void);
    size_t (*get_sr)(void);

    // Destination queries: an offset for (m_idx, n_idx, stride) and a size for (m, n).
    size_t (*get_dst_offset)(size_t m_idx, size_t n_idx, size_t stride);
    size_t (*get_dst_size)(size_t m, size_t n);

    // Offset queries for the two inputs, parameterised by k and bl.
    // NOTE(review): `bl` is presumably a quantization block length -- confirm.
    size_t (*get_lhs_offset_ex)(size_t m_idx, size_t k, size_t bl);
    size_t (*get_rhs_packed_offset_ex)(size_t n_idx, size_t k, size_t bl);

    // Kernel entry point: reads `lhs_packed` and `rhs_packed` (both const) and
    // receives a writable `dst` together with its row/column strides and a
    // [clamp_min, clamp_max] pair.
    void (*run_kernel_ex)(
        size_t       m,
        size_t       n,
        size_t       k,
        size_t       bl,
        const void * lhs_packed,
        const void * rhs_packed,
        void *       dst,
        size_t       dst_stride_row,
        size_t       dst_stride_col,
        float        clamp_min,
        float        clamp_max);
};
46
// Function-pointer table for LHS packing: two offset queries, one size query
// and the packing routine itself.
// NOTE(review): exact semantics of mr/kr/sr/bl are defined by the functions
// assigned to these pointers elsewhere -- confirm there.
struct lhs_packing_info {
    // Offset computed from a row index and the LHS stride.
    size_t (*get_offset)(size_t m_idx, size_t lhs_stride);

    // Offset into the packed buffer for a given row index.
    size_t (*get_packed_offset_ex)(
        size_t m_idx, size_t k, size_t bl,
        size_t mr, size_t kr, size_t sr);

    // Size of the packed buffer for an m x k input.
    // NOTE(review): presumably in bytes -- confirm.
    size_t (*packed_size_ex)(
        size_t m, size_t k, size_t bl,
        size_t mr, size_t kr, size_t sr);

    // Packing routine: takes a read-only `lhs` with its stride and a writable
    // `lhs_packed` output pointer.
    void (*pack_func_ex)(
        size_t m, size_t k, size_t bl,
        size_t mr, size_t kr, size_t sr,
        size_t m_idx_start,
        const void * lhs, size_t lhs_stride,
        void * lhs_packed);
};
57
// Function-pointer table for RHS packing: stride/size queries, a conversion
// routine that writes floats, and the packing routine.
// NOTE(review): `packed_stride` and `packed_stride_ex` have identical
// signatures; whether both are required should be checked against callers.
struct rhs_packing_info {
    size_t (*packed_stride)(size_t k, size_t nr, size_t kr, size_t bl);

    // Takes read-only packed data and a row index, and receives a writable
    // float buffer `out`.
    // NOTE(review): presumably converts one packed row of `nc` elements to
    // float -- confirm against the implementation.
    void (*to_float)(
        const void * packed_data, int32_t row_idx, int64_t nc,
        float * out,
        size_t nr_pack, size_t packed_row_stride,
        size_t kr, size_t bl,
        size_t num_bytes_multiplier);

    size_t (*packed_size_ex)(size_t n, size_t k, size_t nr, size_t kr, size_t bl);

    size_t (*packed_stride_ex)(size_t k, size_t nr, size_t kr, size_t bl);

    // Packing routine: `rhs`, `bias`, `scale` and `params` are read-only
    // inputs; `rhs_packed` is the writable output pointer.
    void (*pack_func_ex)(
        size_t num_groups, size_t n, size_t k,
        size_t nr, size_t kr, size_t sr, size_t bl,
        size_t rhs_stride,
        const void * rhs, const void * bias, const void * scale,
        void * rhs_packed,
        size_t extra_bytes,
        const void * params);
};
72
73struct ggml_kleidiai_kernels {
74    kernel_info      gemm;
75    lhs_packing_info gemm_lhs_info;
76
77    kernel_info      gemv;
78    lhs_packing_info gemv_lhs_info;
79
80    rhs_packing_info rhs_info;
81
82    cpu_feature required_cpu;
83    ggml_type lhs_type;
84    ggml_type rhs_type;
85    ggml_type op_type;
86};
87
88ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, const ggml_tensor * tensor);
89ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features);
90ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q8_0(cpu_feature features);