1// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
2// SPDX-License-Identifier: MIT
3//
4
5// KleidiAI micro-kernels
6#include "kai_matmul_clamp_f32_qsi8d32p_qsi4c32p_interface.h"
7#include "kai_matmul_clamp_f32_qai8dxp_qsi8cxp_interface.h"
8#include "kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h"
9#include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.h"
10#include "kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.h"
11#include "kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.h"
12#include "kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.h"
13#include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.h"
14#include "kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.h"
15#include "kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.h"
16#include "kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot.h"
17#include "kai_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod.h"
18#include "kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod.h"
19#include "kai_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod.h"
20#include "kai_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.h"
21#include "kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm.h"
22#include "kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod.h"
23
24#include "kai_lhs_pack_bf16p2vlx2_f32_sme.h"
25#include "kai_lhs_quant_pack_qsi8d32p_f32.h"
26#include "kai_lhs_quant_pack_qsi8d32p4x8sb_f32_neon.h"
27#include "kai_lhs_quant_pack_qsi8d32p_f32_neon.h"
28#include "kai_lhs_quant_pack_qai8dxp_f32.h"
29
30#include "kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.h"
31#include "kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.h"
32#include "kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.h"
33#include "kai_rhs_pack_nxk_qsi8cxp_qsi8cx_neon.h"
34
35#include "kai_common.h"
36
37#include "simd-mappings.h"
38
39#define GGML_COMMON_DECL_CPP
40#include "ggml-common.h"
41
42#include "kernels.h"
43
// Number of elements in a true array. Must NOT be used on a pointer or on a
// decayed array parameter, where it silently yields a wrong count.
// The argument is parenthesized at every use so compound expressions expand
// as a single operand.
#define NELEMS(x) (sizeof(x) / sizeof((x)[0]))
45
46template<size_t(*Fn)(size_t,size_t,size_t)>
47static inline size_t kernel_offs_fn3(size_t a, size_t b, size_t c) {
48 return Fn(a, b, c);
49}
50
51template<size_t(*Fn)(size_t,size_t)>
52static inline size_t kernel_offs_fn2(size_t a, size_t b, size_t) {
53 return Fn(a, b);
54}
55
56template<void(*Fn)(size_t,size_t,size_t,size_t,const void*,const void*,float*,size_t,size_t,float,float)>
57static inline void kernel_run_fn11(size_t m, size_t n, size_t k, size_t bl,
58 const void* lhs, const void* rhs, void* dst,
59 size_t dst_stride_row, size_t dst_stride_col,
60 float clamp_min, float clamp_max) {
61 Fn(m, n, k, bl, lhs, rhs, static_cast<float*>(dst), dst_stride_row, dst_stride_col, clamp_min, clamp_max);
62}
63
64template<void(*Fn)(size_t,size_t,size_t,const void*,const void*,void*,size_t,size_t,float,float)>
65static inline void kernel_run_fn10(size_t m, size_t n, size_t k, size_t /*bl*/,
66 const void* lhs, const void* rhs, void* dst,
67 size_t dst_stride_row, size_t dst_stride_col,
68 float clamp_min, float clamp_max) {
69 Fn(m, n, k, lhs, rhs, dst, dst_stride_row, dst_stride_col, clamp_min, clamp_max);
70}
71
72template<void(*Fn)(size_t,size_t,size_t,const void*,const void*,float*,size_t,size_t,float,float)>
73static inline void kernel_run_float_fn10(size_t m, size_t n, size_t k, size_t /*bl*/,
74 const void* lhs, const void* rhs, void* dst,
75 size_t dst_stride_row, size_t dst_stride_col,
76 float clamp_min, float clamp_max) {
77 Fn(m, n, k, lhs, rhs, static_cast<float*>(dst), dst_stride_row, dst_stride_col, clamp_min, clamp_max);
78}
79
80template<size_t(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t)>
81static inline size_t lhs_ps_fn6(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) {
82 return Fn(m, k, bl, mr, kr, sr);
83}
84
85template<size_t(*Fn)(size_t,size_t,size_t,size_t,size_t)>
86static inline size_t lhs_ps_fn5(size_t m, size_t k, size_t /*bl*/, size_t mr, size_t kr, size_t sr) {
87 return Fn(m, k, mr, kr, sr);
88}
89
90template<size_t(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t)>
91static inline size_t lhs_offs_fn6(size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) {
92 return Fn(m_idx, k, bl, mr, kr, sr);
93}
94
95template<size_t(*Fn)(size_t,size_t,size_t,size_t,size_t)>
96static inline size_t lhs_offs_fn5(size_t m_idx, size_t k, size_t /*bl*/, size_t mr, size_t kr, size_t sr) {
97 return Fn(m_idx, k, mr, kr, sr);
98}
99
100template<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,size_t,const float*,size_t,void*)>
101static inline void lhs_pack_float_fn10(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr,
102 size_t m_idx_start, const void* lhs, size_t lhs_stride, void* lhs_packed) {
103 Fn(m, k, bl, mr, kr, sr, m_idx_start, static_cast<const float*>(lhs), lhs_stride, lhs_packed);
104}
105
106template<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,size_t,const void*,size_t,void*)>
107static inline void lhs_pack_void_fn10(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr,
108 size_t m_idx_start, const void* lhs, size_t lhs_stride, void* lhs_packed) {
109 Fn(m, k, bl, mr, kr, sr, m_idx_start, lhs, lhs_stride, lhs_packed);
110}
111
112template<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,const void*,size_t,void*)>
113static inline void lhs_pack_void_fn9(size_t m, size_t k, size_t /*bl*/, size_t mr, size_t kr, size_t sr,
114 size_t m_idx_start, const void* lhs, size_t lhs_stride, void* lhs_packed) {
115 Fn(m, k, mr, kr, sr, m_idx_start, lhs, lhs_stride, lhs_packed);
116}
117
118template<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,const float*,size_t,void*)>
119static inline void lhs_pack_float_fn9_no_bl(size_t m, size_t k, size_t /*bl*/, size_t mr, size_t kr, size_t sr,
120 size_t m_idx_start, const void * lhs, size_t lhs_stride, void * lhs_packed) {
121 Fn(m, k, mr, kr, sr, m_idx_start, static_cast<const float*>(lhs), lhs_stride, lhs_packed);
122}
123
124template<size_t(*Fn)(size_t,size_t,size_t,size_t,size_t)>
125static inline size_t rhs_ps_fn5(size_t n, size_t k, size_t nr, size_t kr, size_t bl) {
126 return Fn(n, k, nr, kr, bl);
127}
128
129template<size_t(*Fn)(size_t,size_t)>
130static inline size_t rhs_ps_fn2(size_t n, size_t k, size_t /*nr*/, size_t /*kr*/, size_t /*bl*/) {
131 return Fn(n, k);
132}
133
134template<size_t(*Fn)(size_t,size_t,size_t,size_t)>
135static inline size_t rhs_stride_fn4(size_t k, size_t nr, size_t kr, size_t bl) {
136 return Fn(k, nr, kr, bl);
137}
138
139template<size_t(*Fn)(size_t)>
140static inline size_t rhs_stride_fn1(size_t k, size_t /*nr*/, size_t /*kr*/, size_t /*bl*/) {
141 return Fn(k);
142}
143
144template<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,size_t,const uint8_t*,const float*,void*,size_t,const struct kai_rhs_pack_qs4cxs1s0_param*)>
145static inline void rhs_pack_fn12(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl,
146 size_t /*rhs_stride*/, const void* rhs, const void* bias, const void* /*scale*/,
147 void* rhs_packed, size_t extra_bytes, const void* params) {
148 Fn(num_groups, n, k, nr, kr, sr, bl,
149 static_cast<const uint8_t*>(rhs),
150 static_cast<const float*>(bias),
151 rhs_packed, extra_bytes,
152 static_cast<const kai_rhs_pack_qs4cxs1s0_param*>(params));
153}
154
155template<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,const int8_t*,const float*,const float*,void*,size_t,const struct kai_rhs_pack_qsi8cx_params*)>
156static inline void rhs_pack_scale_fn12(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t /*bl*/,
157 size_t /*rhs_stride*/, const void* rhs, const void* bias, const void* scale,
158 void* rhs_packed, size_t extra_bytes, const void* params) {
159 Fn(num_groups, n, k, nr, kr, sr,
160 static_cast<const int8_t*>(rhs),
161 static_cast<const float*>(bias),
162 static_cast<const float*>(scale),
163 rhs_packed, extra_bytes,
164 static_cast<const kai_rhs_pack_qsi8cx_params*>(params));
165}
166
167template<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,size_t,const void*,const void*,const void*,void*,size_t,const void*)>
168static inline void rhs_pack_fn13(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t /*bl*/,
169 size_t rhs_stride, const void* rhs, const void* bias, const void* scale,
170 void* rhs_packed, size_t extra_bytes, const void* params) {
171 Fn(num_groups, n, k, nr, kr, sr, rhs_stride, rhs, bias, scale, rhs_packed, extra_bytes, params);
172}
173
// Constants used by the to_float helpers to unpack 4-bit quantized weights.
// All four share `static const` so their linkage and style are uniform.
static const size_t INT4_PER_BYTE   = 2;  // 4-bit values stored per byte
static const size_t INT4_BITS       = 4;  // width of one 4-bit value, in bits
static const int    Q4_0_ZERO_POINT = 8;  // offset subtracted from each unsigned nibble
static const size_t INT4_PER_UINT16 = 4;  // 4-bit values stored per uint16_t
178
179static void dequantize_row_qsi4c32pscalef16(
180 const void *packed_data,
181 int32_t row_idx,
182 int64_t nc,
183 float *out,
184 size_t nr_pack,
185 size_t packed_row_stride,
186 size_t kr,
187 size_t bl,
188 size_t num_bytes_multiplier
189) {
190 size_t group_idx = row_idx / nr_pack;
191 size_t row_in_group = row_idx % nr_pack;
192 const uint8_t *packed_group = (const uint8_t *)packed_data + group_idx * packed_row_stride;
193 size_t num_blocks = nc / bl;
194 const uint8_t *block_ptr = packed_group;
195
196 for (size_t b = 0; b < num_blocks; ++b) {
197 uint16_t scale_f16 = *((const uint16_t *)(block_ptr + row_in_group * num_bytes_multiplier));
198 float scale = GGML_CPU_FP16_TO_FP32(scale_f16);
199
200 const uint8_t *segment_ptr = block_ptr + nr_pack * num_bytes_multiplier;
201 size_t num_segments = bl / kr;
202 size_t num_bytes_per_segment = kr / INT4_PER_BYTE;
203
204 for (size_t s = 0; s < num_segments; ++s) {
205 const uint8_t *seg_base = segment_ptr + s * nr_pack * num_bytes_per_segment;
206 const uint8_t *qbytes = seg_base + row_in_group * num_bytes_per_segment;
207 for (size_t k = 0; k < num_bytes_per_segment; ++k) {
208 uint8_t byte = qbytes[k] ^ 0x88;
209 int x0 = (byte & 0x0F) - Q4_0_ZERO_POINT;
210 int x1 = (byte >> INT4_BITS) - Q4_0_ZERO_POINT;
211 out[b * bl + s * num_bytes_per_segment + k] = x0 * scale;
212 out[b * bl + s * num_bytes_per_segment + k + bl/2] = x1 * scale;
213 }
214 }
215 block_ptr += nr_pack * num_bytes_multiplier + num_segments * nr_pack * num_bytes_per_segment;
216 }
217}
218
219static void dequantize_row_qsi4c32ps1s0scalef16(
220 const void *packed_data,
221 int32_t row_idx,
222 int64_t k,
223 float *out,
224 size_t nr,
225 size_t packed_row_stride,
226 size_t kr,
227 size_t bl,
228 size_t num_bytes_multiplier
229) {
230 const size_t num_blocks = k / bl;
231 const size_t bl4 = bl / INT4_PER_UINT16;
232
233 size_t group_idx = row_idx / nr;
234 size_t row_in_group = row_idx % nr;
235
236 const uint8_t *packed_group = (const uint8_t *)packed_data + group_idx * packed_row_stride;
237 const uint16_t *qdata = (const uint16_t *)packed_group;
238 const uint16_t *scales = (const uint16_t *)(packed_group + packed_row_stride - (nr * num_blocks * num_bytes_multiplier));
239
240 for (size_t block_idx = 0; block_idx < num_blocks; ++block_idx) {
241 uint16_t scale_f16 = scales[row_in_group + block_idx * nr];
242 float scale = GGML_CPU_FP16_TO_FP32(scale_f16);
243
244 for (size_t bl4_idx = 0; bl4_idx < bl4; ++bl4_idx) {
245 uint16_t q = qdata[(block_idx * bl4 + bl4_idx) * nr + row_in_group];
246
247 for (size_t qidx = 0; qidx < INT4_PER_UINT16; ++qidx) {
248 int v = ((q >> (qidx * 4)) & 0xF) - Q4_0_ZERO_POINT;
249 out[block_idx * bl + bl4_idx * INT4_BITS + qidx] = v * scale;
250 }
251 }
252 }
253 GGML_UNUSED(kr);
254}
255
256static void dequantize_row_qsi8cxp(
257 const void *packed_data,
258 int32_t row_idx,
259 int64_t k,
260 float *out,
261 size_t nr,
262 size_t packed_row_stride,
263 size_t kr,
264 size_t bl,
265 size_t num_bytes_multiplier
266) {
267 GGML_UNUSED(bl);
268 GGML_UNUSED(num_bytes_multiplier);
269
270 const size_t k_internal = ((size_t) k + QK8_0 - 1) / QK8_0 * QK8_0;
271 const size_t group_idx = row_idx / nr;
272 const size_t row_in_group = row_idx % nr;
273
274 const uint8_t * group_ptr = static_cast<const uint8_t *>(packed_data) + group_idx * packed_row_stride;
275 const int8_t * data_base = reinterpret_cast<const int8_t *>(group_ptr);
276
277 const size_t num_blocks = k_internal / kr;
278
279 for (size_t block = 0; block < num_blocks; ++block) {
280 const int8_t * block_ptr = data_base + (block * nr + row_in_group) * kr;
281 for (size_t i = 0; i < kr; ++i) {
282 const size_t k_idx = block * kr + i;
283 if (k_idx < (size_t) k) {
284 out[k_idx] = static_cast<float>(block_ptr[i]);
285 }
286 }
287 }
288
289 const uint8_t * sums_ptr = group_ptr + nr * k_internal;
290 GGML_UNUSED(sums_ptr);
291
292 const float * scale_ptr = reinterpret_cast<const float *>(sums_ptr + nr * sizeof(int32_t));
293 const float scale = scale_ptr[row_in_group];
294
295 if (scale == 0.0f) {
296 for (size_t i = 0; i < (size_t) k; ++i) {
297 out[i] = 0.0f;
298 }
299 return;
300 }
301
302 for (size_t i = 0; i < (size_t) k; ++i) {
303 out[i] *= scale;
304 }
305}
306
307static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
308#if defined(__ARM_FEATURE_SME)
309 {
310 /* SME GEMM */
311 /* .kern_info = */ {
312 /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
313 /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
314 /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
315 /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
316 /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
317 /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
318 /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
319 /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
320 /* .get_lhs_offset_ex = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa>,
321 /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa>,
322 /* .run_kernel_ex = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa>,
323 },
324
325 /* .gemm_lhs_info = */ {
326 /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32_neon,
327 /* .get_packed_offset_ex = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32_neon>,
328 /* .packed_size_ex = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32_neon>,
329 /* .pack_func_ex = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p_f32_neon>,
330 },
331 /* SME GEMV */
332 /* .kern_info = */ {
333 /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
334 /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
335 /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
336 /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
337 /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
338 /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
339 /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
340 /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
341 /* .get_lhs_offset_ex = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot>,
342 /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot>,
343 /* .run_kernel_ex = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot>,
344 },
345 /* .gemv_lhs_info = */ {
346 /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32_neon,
347 /* .get_packed_offset_ex = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32_neon>,
348 /* .packed_size_ex = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32_neon>,
349 /* .pack_func_ex = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p_f32_neon>,
350 },
351 /* .rhs_info = */ {
352 /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
353 /* .to_float = */ dequantize_row_qsi4c32ps1s0scalef16,
354 /* .packed_size_ex = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon>,
355 /* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon>,
356 /* .pack_func_ex = */ &rhs_pack_fn12<kai_run_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon>,
357 },
358 /* .required_cpu = */ CPU_FEATURE_SME,
359 /* .lhs_type = */ GGML_TYPE_F32,
360 /* .rhs_type = */ GGML_TYPE_Q4_0,
361 /* .op_type = */ GGML_TYPE_F32,
362 },
363 {
364 /* SME GEMM */
365 /* .kern_info = */ {
366 /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
367 /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
368 /* .get_mr = */ kai_get_mr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
369 /* .get_nr = */ kai_get_nr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
370 /* .get_kr = */ kai_get_kr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
371 /* .get_sr = */ kai_get_sr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
372 /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
373 /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
374 /* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa>,
375 /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa>,
376 /* .run_kernel_ex = */ &kernel_run_fn10<kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa>,
377 },
378 /* .gemm_lhs_info = */ {
379 /* .get_offset = */ kai_get_lhs_offset_lhs_pack_bf16p2vlx2_f32_sme,
380 /* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_pack_bf16p2vlx2_f32_sme>,
381 /* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme>,
382 /* .pack_func_ex = */ &lhs_pack_void_fn9<kai_run_lhs_pack_bf16p2vlx2_f32_sme>,
383 },
384 /* SME GEMV */
385 /* .kern_info = */ {
386 /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
387 /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
388 /* .get_mr = */ kai_get_mr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
389 /* .get_nr = */ kai_get_nr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
390 /* .get_kr = */ kai_get_kr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
391 /* .get_sr = */ kai_get_sr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
392 /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
393 /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
394 /* .get_lhs_offset_ex = */ nullptr,
395 /* .get_rhs_packed_offset_ex = */ nullptr,
396 /* .run_kernel_ex = */ nullptr,
397 },
398 /* .gemv_lhs_info = */ {
399 /* .get_offset = */ kai_get_lhs_offset_lhs_pack_bf16p2vlx2_f32_sme,
400 /* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_pack_bf16p2vlx2_f32_sme>,
401 /* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme>,
402 /* .pack_func_ex = */ &lhs_pack_void_fn9<kai_run_lhs_pack_bf16p2vlx2_f32_sme>,
403 },
404 /* .rhs_info = */ {
405 /* .packed_stride = */ nullptr,
406 /* .to_float = */ nullptr,
407 /* .packed_size_ex = */ &rhs_ps_fn2<kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme>,
408 /* .packed_stride_ex = */ &rhs_stride_fn1<kai_get_rhs_packed_stride_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme>,
409 /* .pack_func_ex = */ &rhs_pack_fn13<kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme>,
410 },
411 /* .required_cpu = */ CPU_FEATURE_SME,
412 /* .lhs_type = */ GGML_TYPE_F32,
413 /* .rhs_type = */ GGML_TYPE_F16,
414 /* .op_type = */ GGML_TYPE_F32,
415 },
416#endif
417#if defined(__APPLE__)
418#if defined(__ARM_FEATURE_DOTPROD)
419 {
420 /* DOTPROD GEMM */
421 /* .kern_info = */ {
422 /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
423 /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
424 /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
425 /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
426 /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
427 /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
428 /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
429 /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
430 /* .get_lhs_offset_ex = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod>,
431 /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod>,
432 /* .run_kernel_ex = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod>,
433 },
434 /* .gemm_lhs_info = */ {
435 /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
436 /* .get_packed_offset_ex = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32>,
437 /* .packed_size_ex = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32>,
438 /* .pack_func_ex = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p_f32>,
439 },
440 /* DOTPROD GEMV */
441 /* .kern_info = */ {
442 /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
443 /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
444 /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
445 /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
446 /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
447 /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
448 /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
449 /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
450 /* .get_lhs_offset_ex = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod>,
451 /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod>,
452 /* .run_kernel_ex = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod>,
453 },
454 /* .gemv_lhs_info = */ {
455 /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
456 /* .get_packed_offset_ex = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32>,
457 /* .packed_size_ex = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32>,
458 /* .pack_func_ex = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p_f32>,
459 },
460 /* .rhs_info = */ {
461 /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
462 /* .to_float = */ dequantize_row_qsi4c32pscalef16,
463 /* .packed_size_ex = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
464 /* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
465 /* .pack_func_ex = */ &rhs_pack_fn12<kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
466 },
467 /* .required_cpu = */ CPU_FEATURE_DOTPROD,
468 /* .lhs_type = */ GGML_TYPE_F32,
469 /* .rhs_type = */ GGML_TYPE_Q4_0,
470 /* .op_type = */ GGML_TYPE_F32,
471 },
472#endif
473#if defined(__ARM_FEATURE_MATMUL_INT8)
474 {
475 /* i8mm GEMM */
476 /* .kern_info = */ {
477 /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
478 /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
479 /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
480 /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
481 /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
482 /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
483 /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
484 /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
485 /* .get_lhs_offset_ex = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm>,
486 /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm>,
487 /* .run_kernel_ex = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm>,
488 },
489 /* .gemm_lhs_info = */ {
490 /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
491 /* .get_packed_offset_ex = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon>,
492 /* .packed_size_ex = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p4x8sb_f32_neon>,
493 /* .pack_func_ex = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p4x8sb_f32_neon>,
494 },
495 /* i8mm GEMV */
496 /* .kern_info = */ {
497 /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
498 /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
499 /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
500 /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
501 /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
502 /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
503 /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
504 /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
505 /* .get_lhs_offset_ex = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod>,
506 /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod>,
507 /* .run_kernel_ex = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod>,
508 },
509 /* .gemv_lhs_info = */ {
510 /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
511 /* .get_packed_offset_ex = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32>,
512 /* .packed_size_ex = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32>,
513 /* .pack_func_ex = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p_f32>,
514 },
515 /* .rhs_info = */ {
516 /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
517 /* .to_float = */ dequantize_row_qsi4c32pscalef16,
518 /* .packed_size_ex = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
519 /* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
520 /* .pack_func_ex = */ &rhs_pack_fn12<kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
521 },
522 /* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
523 /* .lhs_type = */ GGML_TYPE_F32,
524 /* .rhs_type = */ GGML_TYPE_Q4_0,
525 /* .op_type = */ GGML_TYPE_F32,
526 },
527#endif
528#else
529#if defined(__ARM_FEATURE_SVE)
530 {
531 /* SVE i8mm GEMM */
532 /* .kern_info = */ {
533 /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm,
534 /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm,
535 /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm,
536 /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm,
537 /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm,
538 /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm,
539 /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm,
540 /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm,
541 /* .get_lhs_offset_ex = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm>,
542 /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm>,
543 /* .run_kernel_ex = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm>,
544 },
545 /* .gemm_lhs_info = */ {
546 /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
547 /* .get_packed_offset_ex = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon>,
548 /* .packed_size_ex = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p4x8sb_f32_neon>,
549 /* .pack_func_ex = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p4x8sb_f32_neon>,
550 },
551 /* SVE dotprod GEMV */
552 /* .kern_info = */ {
553 /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod,
554 /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod,
555 /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod,
556 /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod,
557 /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod,
558 /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod,
559 /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod,
560 /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod,
561 /* .get_lhs_offset_ex = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod>,
562 /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod>,
563 /* .run_kernel_ex = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod>,
564 },
565 /* .gemv_lhs_info = */ {
566 /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
567 /* .get_packed_offset_ex = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32>,
568 /* .packed_size_ex = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32>,
569 /* .pack_func_ex = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p_f32>,
570 },
571 /* .rhs_info = */ {
572 /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
573 /* .to_float = */ dequantize_row_qsi4c32pscalef16,
574 /* .packed_size_ex = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
575 /* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
576 /* .pack_func_ex = */ &rhs_pack_fn12<kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
577 },
578 /* .required_cpu = */ CPU_FEATURE_SVE | CPU_FEATURE_I8MM | CPU_FEATURE_DOTPROD,
579 /* .lhs_type = */ GGML_TYPE_F32,
580 /* .rhs_type = */ GGML_TYPE_Q4_0,
581 /* .op_type = */ GGML_TYPE_F32,
582 },
583#endif
584#if defined(__ARM_FEATURE_MATMUL_INT8)
585 {
586 /* i8mm GEMM */
587 /* .kern_info = */ {
588 /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
589 /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
590 /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
591 /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
592 /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
593 /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
594 /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
595 /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
596 /* .get_lhs_offset_ex = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm>,
597 /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm>,
598 /* .run_kernel_ex = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm>,
599 },
600 /* .gemm_lhs_info = */ {
601 /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
602 /* .get_packed_offset_ex = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon>,
603 /* .packed_size_ex = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p4x8sb_f32_neon>,
604 /* .pack_func_ex = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p4x8sb_f32_neon>,
605 },
606 /* i8mm GEMV */
607 /* .kern_info = */ {
608 /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
609 /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
610 /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
611 /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
612 /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
613 /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
614 /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
615 /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
616 /* .get_lhs_offset_ex = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod>,
617 /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod>,
618 /* .run_kernel_ex = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod>,
619 },
620 /* .gemv_lhs_info = */ {
621 /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
622 /* .get_packed_offset_ex = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32>,
623 /* .packed_size_ex = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32>,
624 /* .pack_func_ex = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p_f32>,
625 },
626 /* .rhs_info = */ {
627 /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
628 /* .to_float = */ dequantize_row_qsi4c32pscalef16,
629 /* .packed_size_ex = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
630 /* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
631 /* .pack_func_ex = */ &rhs_pack_fn12<kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
632 },
633 /* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
634 /* .lhs_type = */ GGML_TYPE_F32,
635 /* .rhs_type = */ GGML_TYPE_Q4_0,
636 /* .op_type = */ GGML_TYPE_F32,
637 },
638#endif // __ARM_FEATURE_MATMUL_INT8
639#if defined(__ARM_FEATURE_DOTPROD)
640 {
641 /* DOTPROD GEMM */
642 /* .kern_info = */ {
643 /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
644 /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
645 /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
646 /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
647 /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
648 /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
649 /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
650 /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
651 /* .get_lhs_offset_ex = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod>,
652 /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod>,
653 /* .run_kernel_ex = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod>,
654 },
655 /* .gemm_lhs_info = */ {
656 /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
657 /* .get_packed_offset_ex = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32>,
658 /* .packed_size_ex = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32>,
659 /* .pack_func_ex = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p_f32>,
660 },
661 /* DOTPROD GEMV */
662 /* .kern_info = */ {
663 /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
664 /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
665 /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
666 /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
667 /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
668 /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
669 /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
670 /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
671 /* .get_lhs_offset_ex = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod>,
672 /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod>,
673 /* .run_kernel_ex = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod>,
674 },
675 /* .gemv_lhs_info = */ {
676 /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
677 /* .get_packed_offset_ex = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32>,
678 /* .packed_size_ex = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32>,
679 /* .pack_func_ex = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p_f32>,
680 },
681 /* .rhs_info = */ {
682 /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
683 /* .to_float = */ dequantize_row_qsi4c32pscalef16,
684 /* .packed_size_ex = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
685 /* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
686 /* .pack_func_ex = */ &rhs_pack_fn12<kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
687 },
688 /* .required_cpu = */ CPU_FEATURE_DOTPROD,
689 /* .lhs_type = */ GGML_TYPE_F32,
690 /* .rhs_type = */ GGML_TYPE_Q4_0,
691 /* .op_type = */ GGML_TYPE_F32,
692 },
693#endif
694#endif
695 { /* Sentinel */ }
696};
697
// Kernel table for Q8_0 weights (rhs) with F32 activations (lhs) and F32 output.
//
// Each entry bundles:
//   - a GEMM micro-kernel and the LHS packing routine it expects,
//   - a GEMV micro-kernel and the LHS packing routine it expects,
//   - the RHS (weight) packing/dequantization routines shared by both,
//   - the CPU features required to run the entry, plus its lhs/rhs/op types.
//
// Entries are listed in priority order (SME, then I8MM, then DOTPROD). The
// selectors ggml_kleidiai_select_kernels / ggml_kleidiai_select_kernels_q8_0
// return the FIRST entry whose required_cpu bits are all present, so the order
// of the entries below is significant.
//
// The table is terminated by an empty sentinel entry, which the selectors skip
// by iterating only up to NELEMS(...) - 1.
static ggml_kleidiai_kernels gemm_gemv_kernels_q8[] = {
#if defined(__ARM_FEATURE_SME)
    {
        /* SME GEMM */
        /* .kern_info = */ {
            /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
            /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
            /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
            /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
            /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
            /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
            /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
            /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
            /* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa>,
            /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa>,
            /* .run_kernel_ex = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa>,
        },
        /* .gemm_lhs_info = */ {
            /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
            /* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
            /* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
            /* .pack_func_ex = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
        },
        /* SME GEMV */
        /* .kern_info = */ {
            /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
            /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
            /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
            /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
            /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
            /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
            /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
            /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
            /* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot>,
            /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot>,
            /* .run_kernel_ex = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot>,
        },
        /* .gemv_lhs_info = */ {
            /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
            /* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
            /* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
            /* .pack_func_ex = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
        },
        /* .rhs_info = */ {
            /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon,
            /* .to_float = */ dequantize_row_qsi8cxp,
            /* .packed_size_ex = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
            /* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
            /* .pack_func_ex = */ &rhs_pack_scale_fn12<kai_run_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
        },
        // NOTE(review): both micro-kernels of this entry are SME2 variants
        // (`sme2_mopa`, `sme2_dot`); confirm that runtime detection of
        // CPU_FEATURE_SME implies SME2 support.
        /* .required_cpu = */ CPU_FEATURE_SME,
        /* .lhs_type = */ GGML_TYPE_F32,
        /* .rhs_type = */ GGML_TYPE_Q8_0,
        /* .op_type = */ GGML_TYPE_F32,
    },
#endif
#if defined(__ARM_FEATURE_MATMUL_INT8)
    {
        /* I8MM GEMM */
        /* .kern_info = */ {
            /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
            /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
            /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
            /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
            /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
            /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
            /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
            /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
            /* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm>,
            /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm>,
            /* .run_kernel_ex = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm>,
        },
        /* .gemm_lhs_info = */ {
            /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
            /* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
            /* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
            /* .pack_func_ex = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
        },
        /* I8MM GEMV (dotprod fallback) */
        // There is no i8mm-specific GEMV kernel in this entry; the dotprod 1x8
        // kernel is used instead, which is why required_cpu below also
        // includes CPU_FEATURE_DOTPROD.
        /* .kern_info = */ {
            /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
            /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
            /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
            /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
            /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
            /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
            /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
            /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
            /* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod>,
            /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod>,
            /* .run_kernel_ex = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod>,
        },
        /* .gemv_lhs_info = */ {
            /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
            /* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
            /* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
            /* .pack_func_ex = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
        },
        /* .rhs_info = */ {
            /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon,
            /* .to_float = */ dequantize_row_qsi8cxp,
            /* .packed_size_ex = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
            /* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
            /* .pack_func_ex = */ &rhs_pack_scale_fn12<kai_run_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
        },
        /* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
        /* .lhs_type = */ GGML_TYPE_F32,
        /* .rhs_type = */ GGML_TYPE_Q8_0,
        /* .op_type = */ GGML_TYPE_F32,
    },
#endif
#if defined(__ARM_FEATURE_DOTPROD)
    {
        /* DOTPROD GEMM */
        /* .kern_info = */ {
            /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
            /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
            /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
            /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
            /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
            /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
            /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
            /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
            /* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod>,
            /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod>,
            /* .run_kernel_ex = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod>,
        },
        /* .gemm_lhs_info = */ {
            /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
            /* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
            /* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
            /* .pack_func_ex = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
        },
        /* DOTPROD GEMV */
        /* .kern_info = */ {
            /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
            /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
            /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
            /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
            /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
            /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
            /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
            /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
            /* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod>,
            /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod>,
            /* .run_kernel_ex = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod>,
        },
        /* .gemv_lhs_info = */ {
            /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
            /* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
            /* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
            /* .pack_func_ex = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
        },
        /* .rhs_info = */ {
            /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon,
            /* .to_float = */ dequantize_row_qsi8cxp,
            /* .packed_size_ex = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
            /* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
            /* .pack_func_ex = */ &rhs_pack_scale_fn12<kai_run_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
        },
        /* .required_cpu = */ CPU_FEATURE_DOTPROD,
        /* .lhs_type = */ GGML_TYPE_F32,
        /* .rhs_type = */ GGML_TYPE_Q8_0,
        /* .op_type = */ GGML_TYPE_F32,
    },
#endif
    // Terminating sentinel: never selected (selectors stop at NELEMS - 1) and
    // keeps the array non-empty when none of the feature macros are defined.
    { /* Sentinel */ }
};
866
867ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, const ggml_tensor * tensor) {
868 ggml_kleidiai_kernels * kernel = nullptr;
869
870 if (tensor->op == GGML_OP_MUL_MAT && tensor->src[0] != nullptr && tensor->src[1] != nullptr) {
871#if defined(__ARM_FEATURE_SME) || \
872 defined(__ARM_FEATURE_DOTPROD) || \
873 defined(__ARM_FEATURE_MATMUL_INT8) || \
874 defined(__ARM_FEATURE_SVE)
875 auto try_table = [&](auto & table) {
876 for (size_t i = 0; i < NELEMS(table) - 1; ++i) {
877 if ((cpu_features & table[i].required_cpu) == table[i].required_cpu &&
878 table[i].lhs_type == tensor->src[1]->type &&
879 table[i].rhs_type == tensor->src[0]->type &&
880 table[i].op_type == tensor->type) {
881 kernel = &table[i];
882 return true;
883 }
884 }
885 return false;
886 };
887
888 if (tensor->src[0]->type == GGML_TYPE_Q8_0) {
889 try_table(gemm_gemv_kernels_q8);
890 } else {
891 try_table(gemm_gemv_kernels);
892 }
893#else
894 GGML_UNUSED(gemm_gemv_kernels);
895 GGML_UNUSED(gemm_gemv_kernels_q8);
896 GGML_UNUSED(cpu_features);
897#endif
898 }
899
900 return kernel;
901}
902
903ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features) {
904 ggml_kleidiai_kernels * kernels = nullptr;
905
906#if defined(__ARM_FEATURE_SME) || \
907 defined(__ARM_FEATURE_DOTPROD) || \
908 defined(__ARM_FEATURE_MATMUL_INT8) || \
909 defined(__ARM_FEATURE_SVE)
910 for (size_t i = 0; i < NELEMS(gemm_gemv_kernels) - 1; ++i) {
911 if ((features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu) {
912 kernels = &gemm_gemv_kernels[i];
913 break;
914 }
915 }
916#else
917 GGML_UNUSED(features);
918#endif
919
920 return kernels;
921}
922
923ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q8_0(cpu_feature features) {
924 ggml_kleidiai_kernels * kernels = nullptr;
925
926#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8)
927 for (size_t i = 0; i < NELEMS(gemm_gemv_kernels_q8) - 1; ++i) {
928 if ((features & gemm_gemv_kernels_q8[i].required_cpu) == gemm_gemv_kernels_q8[i].required_cpu) {
929 kernels = &gemm_gemv_kernels_q8[i];
930 break;
931 }
932 }
933#else
934 GGML_UNUSED(features);
935#endif
936
937 return kernels;
938}