// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
// SPDX-License-Identifier: MIT
//
#include <arm_neon.h>
#include <assert.h>
#include <atomic>
#include <cfloat>
#include <cmath>
#include <algorithm>
#include <stdexcept>
#include <stdint.h>
#include <stdlib.h>  // getenv, atoi
#include <string.h>
#include <string>
#include <vector>
#if defined(__linux__)
#include <asm/hwcap.h>
#include <sys/auxv.h>
#elif defined(__APPLE__)
#include <string_view>
#include <sys/sysctl.h>
#include <sys/types.h>
#elif defined(_WIN32)
#include <windows.h>
#include <excpt.h>
#endif

#include "kleidiai.h"

#include "ggml-cpu.h"
#include "ggml-impl.h"
#include "ggml-backend-impl.h"
#include "ggml-threading.h"
#include "traits.h"

#include "kernels.h"

#include "kai_common.h"

#define GGML_COMMON_DECL_CPP
#include "ggml-common.h"
struct ggml_kleidiai_context {
    cpu_feature features;
    ggml_kleidiai_kernels * kernels_q4;
    ggml_kleidiai_kernels * kernels_q8;
} static ctx = { CPU_FEATURE_NONE, NULL, NULL };

static const char * cpu_feature_to_string(cpu_feature f) {
    if (f == CPU_FEATURE_NONE) {
        return "NONE";
    } else if ((f & CPU_FEATURE_SME) == CPU_FEATURE_SME) {
        return "SME";
    } else if ((f & CPU_FEATURE_SVE) == CPU_FEATURE_SVE) {
        return "SVE";
    } else if ((f & CPU_FEATURE_I8MM) == CPU_FEATURE_I8MM) {
        return "I8MM";
    } else if ((f & CPU_FEATURE_DOTPROD) == CPU_FEATURE_DOTPROD) {
        return "DOTPROD";
    } else {
        return "UNKNOWN";
    }
}
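
// One-time initialization of the global context: detect CPU features and pick
// the Q4_0 / Q8_0 kernel tables. Guarded by the ggml critical section so that
// concurrent callers initialize it exactly once.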
static void init_kleidiai_context(void) {

    ggml_critical_section_start();
    static bool initialized = false;

    if (!initialized) {
        initialized = true;
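        // SME kernels are opt-in: they are only considered when the
        // GGML_KLEIDIAI_SME environment variable is set to a non-zero value.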
        const char *env_var = getenv("GGML_KLEIDIAI_SME");
        int sme_enabled = 0;

        ctx.features = (ggml_cpu_has_dotprod() ? CPU_FEATURE_DOTPROD : CPU_FEATURE_NONE) |
                       (ggml_cpu_has_matmul_int8() ? CPU_FEATURE_I8MM : CPU_FEATURE_NONE) |
                       ((ggml_cpu_has_sve() && ggml_cpu_get_sve_cnt() == QK8_0) ? CPU_FEATURE_SVE : CPU_FEATURE_NONE);

        if (env_var) {
            sme_enabled = atoi(env_var);
        }

        if (sme_enabled != 0) {
            ctx.features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE;
        }
        ctx.kernels_q4 = ggml_kleidiai_select_kernels_q4_0(ctx.features);
        ctx.kernels_q8 = ggml_kleidiai_select_kernels_q8_0(ctx.features);
#ifndef NDEBUG
        if (ctx.kernels_q4) {
            GGML_LOG_DEBUG("kleidiai: using q4 kernel with CPU feature %s\n", cpu_feature_to_string(ctx.kernels_q4->required_cpu));
        }
        if (ctx.kernels_q8) {
            GGML_LOG_DEBUG("kleidiai: using q8 kernel with CPU feature %s\n", cpu_feature_to_string(ctx.kernels_q8->required_cpu));
        }
#endif
    }
    ggml_critical_section_end();
}

static inline int64_t ggml_ne(const ggml_tensor * tensor, int dim) {
    GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS);
    return tensor->ne[dim];
}

namespace ggml::cpu::kleidiai {

static size_t round_down(size_t x, size_t y) {
    return y == 0 ? x : x - (x % y);
}

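// Convert an FP16 RHS stored as n rows of k values into a dense FP32 k x n
// matrix (transpose while widening f16 -> f32), so it can be fed to the FP32
// RHS packing routine.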
static void transpose_f32kxn_f16nxk(size_t n, size_t k, float * dst, const uint16_t * src, size_t rhs_stride) {
    size_t src_stride = rhs_stride / sizeof(uint16_t);
    size_t dst_stride = n;

    for (size_t k_idx = 0; k_idx < k; ++k_idx) {
        for (size_t n_idx = 0; n_idx < n; ++n_idx) {
            uint16_t v = *(src + k_idx + n_idx * src_stride);
            *(dst + n_idx + k_idx * dst_stride) = kai_cast_f32_f16(v);
        }
    }
}

class tensor_traits : public ggml::cpu::tensor_traits {
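    // Report the scratch (wdata) size needed for a MUL_MAT: the LHS packing
    // buffer for the quantized paths, plus the packed RHS, a transposed f32
    // copy of the RHS and a zero bias vector for the FP16 path.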
    bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
        if (op->op != GGML_OP_MUL_MAT) {
            return false;
        }
        ggml_kleidiai_kernels * kernels = ggml_kleidiai_select_kernels(ctx.features, op);
        if (!kernels) {
            return false;
        }
        bool is_gemv = op->src[1]->ne[1] == 1;
        kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
        lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;

        size_t k = op->src[0]->ne[0];
        size_t n = op->src[0]->ne[1];
        size_t m = op->src[1]->ne[1];

        size_t mr = kernel->get_mr();
        size_t kr = kernel->get_kr();
        size_t sr = kernel->get_sr();

        if (kernels->rhs_type == GGML_TYPE_Q4_0) {
            if (!lhs_info->packed_size_ex) return false;
            size = lhs_info->packed_size_ex(m, k, QK4_0, mr, kr, sr);
        } else if (kernels->rhs_type == GGML_TYPE_Q8_0) {
            if (!lhs_info->packed_size_ex) return false;
            size = lhs_info->packed_size_ex(m, k, QK8_0, mr, kr, sr);
        } else if (kernels->rhs_type == GGML_TYPE_F16) {
            if (!lhs_info->packed_size_ex || !kernels->rhs_info.packed_size_ex) return false;
            const int64_t lhs_batch_size0 = op->src[1]->ne[2];
            const int64_t rhs_batch_size0 = op->src[0]->ne[2];
            const int64_t r = lhs_batch_size0 / rhs_batch_size0;
            size = lhs_info->packed_size_ex(m * r, k, 0, mr, kr, sr) +
                   kernels->rhs_info.packed_size_ex(n, k, kernel->get_nr(), kernel->get_kr(), 0) +
                   k * n * sizeof(float) + n * sizeof(float);
        } else {
            return false;
        }

        return true;
    }

    bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * dst) override {
        if (dst->op == GGML_OP_MUL_MAT) {
            if (dst->src[0]->type == GGML_TYPE_Q4_0) {
                return compute_forward_q4_0(params, dst);
            } else if (dst->src[0]->type == GGML_TYPE_Q8_0) {
                return compute_forward_q8_0(params, dst);
            } else if (dst->src[0]->type == GGML_TYPE_F16) {
                return compute_forward_fp16(params, dst);
            }
        } else if (dst->op == GGML_OP_GET_ROWS) {
            if (dst->src[0]->type == GGML_TYPE_Q4_0 || dst->src[0]->type == GGML_TYPE_Q8_0) {
                return compute_forward_get_rows(params, dst);
            }
        }
        return false;
    }

    bool compute_forward_fp16(ggml_compute_params * params, struct ggml_tensor * dst) {
        const ggml_tensor * src0 = dst->src[0];
        const ggml_tensor * src1 = dst->src[1];

        GGML_TENSOR_BINARY_OP_LOCALS

        ggml_kleidiai_kernels * kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
        if (!kernels) {
            return false;
        }

        const bool is_gemv = src1->ne[1] == 1;
        kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
        lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
        GGML_ASSERT(kernel);
        // Bail out if any of the kernel/packing entry points used below is missing.
        if (!kernels->rhs_info.pack_func_ex || !kernels->rhs_info.packed_size_ex ||
            !lhs_info->packed_size_ex || !lhs_info->get_packed_offset_ex || !lhs_info->pack_func_ex ||
            !kernel->get_rhs_packed_offset_ex || !kernel->run_kernel_ex || !kernel->get_dst_offset) {
            return false;
        }

        const int nth = params->nth;
        const int ith = params->ith;

        const int64_t lhs_batch_size0 = ne12;
        const int64_t rhs_batch_size0 = ne02;
        const int64_t batch_size = lhs_batch_size0;

        GGML_ASSERT(rhs_batch_size0 > 0);
        GGML_ASSERT(lhs_batch_size0 % rhs_batch_size0 == 0);
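        // Each RHS (weight) batch is shared by r consecutive LHS batches
        // (broadcast over dimension 2).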
        const int64_t r = lhs_batch_size0 / rhs_batch_size0;

        const int64_t m_group = ne11;
        const int64_t m = m_group;
        const int64_t n = ne01;
        const int64_t k = ne00;

        const size_t lhs_stride = src1->nb[1];
        const size_t rhs_stride = src0->nb[1];
        const size_t dst_stride = dst->nb[1];

        const int64_t mr = (int64_t) kernel->get_mr();
        const int64_t nr = (int64_t) kernel->get_nr();
        const int64_t kr = (int64_t) kernel->get_kr();
        const int64_t sr = (int64_t) kernel->get_sr();

        const size_t lhs_packed_size = lhs_info->packed_size_ex(m, k, 0, mr, kr, sr);
        const size_t rhs_packed_size = kernels->rhs_info.packed_size_ex(n, k, nr, kr, 0);
        const size_t kxn_size        = k * n * sizeof(float);
        const size_t bias_size       = n * sizeof(float);

        const size_t wsize_required = lhs_packed_size + rhs_packed_size + kxn_size + bias_size;
        GGML_ASSERT(wsize_required <= params->wsize);

        uint8_t * lhs_packed = static_cast<uint8_t *>(params->wdata);
        uint8_t * rhs_packed = lhs_packed + lhs_packed_size;
        uint8_t * rhs_kxn    = rhs_packed + rhs_packed_size;
        uint8_t * bias       = rhs_kxn + kxn_size;

        for (int64_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
            const int64_t rhs_batch_idx = batch_idx / r;
            const uint8_t * rhs_batch_base = static_cast<const uint8_t *>(src0->data) + rhs_batch_idx * src0->nb[2];
            uint8_t * dst_batch_base = static_cast<uint8_t *>(dst->data) + batch_idx * dst->nb[2];

            // LHS packing (threaded over m, honoring mr alignment and KV groups)
            {
                const int64_t m_roundup_mr = kai_roundup(m, mr);
                const int64_t num_threads = KAI_MIN(m_roundup_mr / mr, nth);

                if (ith < num_threads) {
                    const int64_t num_m_per_thread0 = round_down((size_t)(m_roundup_mr / num_threads), (size_t)mr);
                    const int64_t num_m_per_threadN_1 = m - (num_threads - 1) * num_m_per_thread0;

                    const int64_t m_start = ith * num_m_per_thread0;
                    const int64_t m_count = (ith == num_threads - 1) ? num_m_per_threadN_1 : num_m_per_thread0;

                    // Base packed offset (aligned) and per-row stride in bytes
                    const size_t base_packed_off = lhs_info->get_packed_offset_ex(m_start, k, 0, mr, kr, sr);
                    const size_t next_block_off  = lhs_info->get_packed_offset_ex(m_start + mr, k, 0, mr, kr, sr);
                    const size_t row_stride_bytes = (next_block_off - base_packed_off) / (size_t)mr;

                    int64_t remaining = m_count;
                    int64_t cur       = m_start;

                    while (remaining > 0) {
                        const int64_t row_in_group = cur;
                        const int64_t avail        = m_group - row_in_group;
                        const int64_t take         = std::min(avail, remaining);

                        const uint8_t * lhs_batch_base = static_cast<const uint8_t *>(src1->data) + batch_idx * src1->nb[2];
                        const void * src_ptr = lhs_batch_base + (size_t)row_in_group * lhs_stride;
                        const size_t dst_off = base_packed_off + (size_t)(cur - m_start) * row_stride_bytes;
                        void * dst_ptr = lhs_packed + dst_off;

                        lhs_info->pack_func_ex(take, k, 0, mr, kr, sr, 0, src_ptr, lhs_stride, dst_ptr);

                        cur       += take;
                        remaining -= take;
                    }
                }
            }

            // RHS packing (single thread), then synchronize
            if (ith == 0) {
                memset(bias, 0, (size_t)n * sizeof(float));
                transpose_f32kxn_f16nxk((size_t)n, (size_t)k,
                                        reinterpret_cast<float *>(rhs_kxn),
                                        reinterpret_cast<const uint16_t *>(rhs_batch_base),
                                        rhs_stride);

                kernels->rhs_info.pack_func_ex(1, n, k, nr, kr, sr, 0, n * sizeof(float),
                                               rhs_kxn, bias, nullptr, rhs_packed, 0, nullptr);
            }

            ggml_barrier(params->threadpool);

            // Matmul (threaded over n)
            {
                const int64_t n_step = (int64_t) kernel->get_n_step();
                int64_t num_threads_n = KAI_MIN(n / n_step, nth);
                if (num_threads_n <= 0) {
                    num_threads_n = 1;
                }

                if (ith < num_threads_n) {
                    const int64_t num_n_per_thread0 = round_down((size_t)(n / num_threads_n), (size_t)n_step);
                    const int64_t num_n_per_threadN_1 = n - (num_threads_n - 1) * num_n_per_thread0;

                    const int64_t n_start      = ith * num_n_per_thread0;
                    const int64_t n_to_process = (ith == num_threads_n - 1) ? num_n_per_threadN_1 : num_n_per_thread0;

                    // LHS packed base at row 0 (consistent with packing above)
                    const size_t lhs_packed_offset0 = lhs_info->get_packed_offset_ex(0, k, 0, mr, kr, sr);
                    const size_t rhs_packed_offset  = kernel->get_rhs_packed_offset_ex(n_start, k, 0);
                    const size_t dst_offset         = kernel->get_dst_offset((size_t)0, (size_t)n_start, dst_stride);

                    const void * lhs_ptr = lhs_packed + lhs_packed_offset0;
                    const void * rhs_ptr = rhs_packed + rhs_packed_offset;
                    float * dst_ptr = reinterpret_cast<float *>(dst_batch_base + dst_offset);

                    kernel->run_kernel_ex(m, n_to_process, k, 0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride, sizeof(float), -FLT_MAX, FLT_MAX);
                }
            }

            if (batch_idx != batch_size - 1) {
                ggml_barrier(params->threadpool);
            }
        }

        return true;
    }

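    // MUL_MAT with Q4_0 weights that were repacked at set-tensor time: each
    // thread packs its share of the f32 activations, then the GEMM/GEMV kernel
    // runs over a per-thread slice of the N dimension.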
    bool compute_forward_q4_0(struct ggml_compute_params * params, struct ggml_tensor * dst) {
        GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0);

        const ggml_tensor * src0 = dst->src[0];
        const ggml_tensor * src1 = dst->src[1];

        GGML_TENSOR_BINARY_OP_LOCALS

        ggml_kleidiai_kernels * kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
        if (!kernels) {
            return false;
        }

        bool is_gemv = src1->ne[1] == 1;
        kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
        lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;

        GGML_ASSERT(kernel);
        if (!lhs_info->get_packed_offset_ex || !lhs_info->pack_func_ex ||
            !kernel->get_rhs_packed_offset_ex || !kernel->run_kernel_ex || !kernel->get_dst_offset) {
            return false;
        }

        const int ith = params->ith;
        const int nth_raw = params->nth;
        const int nth = nth_raw > 0 ? nth_raw : 1;

        const size_t k = ne00;
        const size_t m = ne11;
        const size_t n = ne01;

        size_t mr = kernel->get_mr();
        size_t kr = kernel->get_kr();
        size_t sr = kernel->get_sr();

        const uint8_t * lhs = static_cast<const uint8_t *>(src1->data);
        uint8_t * lhs_packed = static_cast<uint8_t *>(params->wdata);
        const uint8_t * rhs_packed = static_cast<const uint8_t *>(src0->data);

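        // Split the N dimension across threads in multiples of the kernel's n_step.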
        const size_t n_step = kernel->get_n_step();
        const size_t num_n_per_thread = kai_roundup(kai_roundup(n, nth) / nth, n_step);
        const size_t n_start = ith * num_n_per_thread;

        size_t n_to_process = 0;
        if (n_start < n) {
            n_to_process = num_n_per_thread;
            if ((n_start + n_to_process) > n) {
                n_to_process = n - n_start;
            }
        }

        // Calculate the number of rows (M) to be processed per thread
        const size_t num_m_per_thread = kai_roundup(m, mr * nth) / nth;
        const size_t m_start = ith * num_m_per_thread;
        size_t m_to_process = num_m_per_thread;
        if ((m_start + m_to_process) > m) {
            m_to_process = m - m_start;
        }

        if (m_start < m) {
            // Transform LHS
            const size_t src_stride = src1->nb[1];
            const float * src_ptr = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(m_start, dst->src[1]->nb[1]));
            const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(m_start, k, QK4_0, mr, kr, sr);
            void * lhs_packed_ptr = static_cast<void *>(lhs_packed + lhs_packed_offset);

            // Pack this thread's chunk with m_idx_start = 0 and a per-thread output pointer
            lhs_info->pack_func_ex(m_to_process, k, QK4_0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr);
        }

        ggml_barrier(params->threadpool);

        // Perform the operation
        const size_t dst_stride = dst->nb[1];
        const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(0, k, QK4_0, mr, kr, sr);
        const size_t rhs_packed_offset = kernel->get_rhs_packed_offset_ex(n_start, k, QK4_0);
        const size_t dst_offset = kernel->get_dst_offset(0, n_start, dst_stride);
        const void * rhs_ptr = static_cast<const void *>(rhs_packed + rhs_packed_offset);
        const void * lhs_ptr = static_cast<const void *>(lhs_packed + lhs_packed_offset);
        float * dst_ptr = reinterpret_cast<float *>(static_cast<uint8_t *>(dst->data) + dst_offset);

        if (n_to_process > 0) {
            kernel->run_kernel_ex(m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride,
                                  sizeof(float), -FLT_MAX, FLT_MAX);
        }

        return true;
    }

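    // Same scheme as the Q4_0 path, but for Q8_0 weights: the LHS packing and
    // kernel calls pass a block length of 0 (per-row quantization parameters).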
    bool compute_forward_q8_0(struct ggml_compute_params * params, struct ggml_tensor * dst) {
        GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q8_0);

        const ggml_tensor * src0 = dst->src[0];
        const ggml_tensor * src1 = dst->src[1];

        GGML_TENSOR_BINARY_OP_LOCALS

        ggml_kleidiai_kernels * kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
        if (!kernels) {
            return false;
        }

        bool is_gemv = src1->ne[1] == 1;
        kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
        lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;

        if (!kernel || !lhs_info->get_packed_offset_ex || !lhs_info->pack_func_ex ||
            !kernel->get_rhs_packed_offset_ex || !kernel->run_kernel_ex || !kernel->get_dst_offset) {
            return false;
        }

        const int ith = params->ith;
        const int nth_raw = params->nth;
        const int nth = nth_raw > 0 ? nth_raw : 1;

        const size_t k = ne00;
        const size_t m = ne11;
        const size_t n = ne01;

        size_t mr = kernel->get_mr();
        size_t kr = kernel->get_kr();
        size_t sr = kernel->get_sr();

        const uint8_t * lhs = static_cast<const uint8_t *>(src1->data);
        uint8_t * lhs_packed = static_cast<uint8_t *>(params->wdata);
        const uint8_t * rhs_packed = static_cast<const uint8_t *>(src0->data);

        const size_t n_step = kernel->get_n_step();
        const size_t num_n_per_thread = kai_roundup(kai_roundup(n, nth) / nth, n_step);
        const size_t n_start = ith * num_n_per_thread;

        size_t n_to_process = 0;
        if (n_start < n) {
            n_to_process = num_n_per_thread;
            if ((n_start + n_to_process) > n) {
                n_to_process = n - n_start;
            }
        }

        const size_t num_m_per_thread = kai_roundup(m, mr * nth) / nth;
        const size_t m_start = ith * num_m_per_thread;
        size_t m_to_process = num_m_per_thread;
        if ((m_start + m_to_process) > m) {
            m_to_process = m - m_start;
        }

        if (m_start < m) {
            const size_t src_stride = src1->nb[1];
            const float * src_ptr = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(m_start, dst->src[1]->nb[1]));
            const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(m_start, k, 0, mr, kr, sr);
            void * lhs_packed_ptr = static_cast<void *>(lhs_packed + lhs_packed_offset);

            lhs_info->pack_func_ex(m_to_process, k, 0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr);
        }

        ggml_barrier(params->threadpool);

        const size_t dst_stride = dst->nb[1];
        const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(0, k, 0, mr, kr, sr);
        const size_t rhs_packed_offset = kernel->get_rhs_packed_offset_ex(n_start, k, 0);
        const size_t dst_offset = kernel->get_dst_offset(0, n_start, dst_stride);
        const void * rhs_ptr = static_cast<const void *>(rhs_packed + rhs_packed_offset);
        const void * lhs_ptr = static_cast<const void *>(lhs_packed + lhs_packed_offset);
        float * dst_ptr = reinterpret_cast<float *>(static_cast<uint8_t *>(dst->data) + dst_offset);

        if (n_to_process > 0) {
            kernel->run_kernel_ex(m, n_to_process, k, 0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride,
                                  sizeof(float), -FLT_MAX, FLT_MAX);
        }

        return true;
    }

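    // GET_ROWS on repacked Q4_0/Q8_0 weights: rows no longer have the standard
    // ggml layout, so dequantize them back to f32 directly from the packed RHS buffer.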
    bool compute_forward_get_rows(struct ggml_compute_params * params, struct ggml_tensor * dst) {
        const ggml_tensor * src0 = dst->src[0];
        const ggml_tensor * src1 = dst->src[1];

        GGML_TENSOR_BINARY_OP_LOCALS

        ggml_kleidiai_kernels * kernels = nullptr;
        size_t block_len = 0;
        size_t num_bytes_multiplier = 0;

        if (dst->src[0]->type == GGML_TYPE_Q4_0) {
            if (!ctx.kernels_q4) {
                return false;
            }
            kernels = ctx.kernels_q4;
            block_len = QK4_0;
            num_bytes_multiplier = sizeof(uint16_t);
        } else if (dst->src[0]->type == GGML_TYPE_Q8_0) {
            if (!ctx.kernels_q8) {
                return false;
            }
            kernels = ctx.kernels_q8;
            block_len = QK8_0;
            num_bytes_multiplier = sizeof(float);
        } else {
            return false;
        }

        rhs_packing_info * rhs_info = &kernels->rhs_info;
        kernel_info * kernel = &kernels->gemm;
        if (!rhs_info->to_float || !kernel->get_nr) {
            return false;
        }

        const int64_t nc = ne00;
        const int64_t nr = ggml_nelements(src1);

        const size_t block_rows = kernel->get_nr();
        const size_t kr = kernel->get_kr();

        const size_t packed_stride = rhs_info->packed_stride(nc, block_rows, kr, block_len);

        const int ith = params->ith;
        const int nth = params->nth;

        const int dr = (nr + nth - 1) / nth;
        const int ir0 = dr * ith;
        const int ir1 = MIN(ir0 + dr, nr);

        for (int64_t i = ir0; i < ir1; ++i) {
            GGML_ASSERT(src1->type == GGML_TYPE_I32);
            int64_t row_idx = ((const int32_t *)src1->data)[i];
            GGML_ASSERT(row_idx >= 0 && row_idx < src0->ne[1]);

            float * out = (float *)((char *)dst->data + i * nb1);
            rhs_info->to_float(src0->data, row_idx, nc, out, block_rows, packed_stride, kr, block_len, num_bytes_multiplier);
        }

        return true;
    }

public:
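    // Repack the ggml-layout weight data into the packed RHS layout expected by
    // the selected GEMM kernel. Q4_0 rows are packed directly; Q8_0 rows are
    // first dequantized and requantized to int8 with a single f32 scale per
    // output channel before packing.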
    int repack(struct ggml_tensor * tensor, const void * data, size_t data_size) {
        const size_t n = tensor->ne[1];
        const size_t k = tensor->ne[0];

        if (tensor->type == GGML_TYPE_Q4_0) {
            if (!ctx.kernels_q4) {
                return -1;
            }
            size_t nr = ctx.kernels_q4->gemm.get_nr();
            size_t kr = ctx.kernels_q4->gemm.get_kr();
            size_t sr = ctx.kernels_q4->gemm.get_sr();

            struct kai_rhs_pack_qs4cxs1s0_param params;
            params.lhs_zero_point = 1;
            params.rhs_zero_point = 8;
            ctx.kernels_q4->rhs_info.pack_func_ex(1, n, k, nr, kr, sr, QK4_0, 0,
                                                  static_cast<const uint8_t *>(data),
                                                  nullptr, nullptr, tensor->data, 0, &params);
            GGML_UNUSED(data_size);
            return 0;
        } else if (tensor->type == GGML_TYPE_Q8_0) {
            if (!ctx.kernels_q8) {
                return -1;
            }

            const size_t row_stride = tensor->nb[1];
            const size_t k_blocks = (k + QK8_0 - 1) / QK8_0;

            std::vector<int8_t> qdata(n * k, 0);
            std::vector<float> scales(n, 0.0f);

            for (size_t row = 0; row < n; ++row) {
                const auto * row_blocks = reinterpret_cast<const block_q8_0 *>(
                    static_cast<const uint8_t *>(data) + row * row_stride);

                // First pass: find the maximum absolute dequantized value of the row.
                float max_abs = 0.0f;
                for (size_t block = 0; block < k_blocks; ++block) {
                    const block_q8_0 & blk = row_blocks[block];
                    const float d = GGML_FP16_TO_FP32(blk.d);
                    for (size_t l = 0; l < QK8_0; ++l) {
                        const size_t linear_idx = block * QK8_0 + l;
                        if (linear_idx >= k) {
                            break;
                        }
                        const float value = d * blk.qs[l];
                        max_abs = std::max(max_abs, std::fabs(value));
                    }
                }

                float scale = max_abs > 0.0f ? max_abs / 127.0f : 0.0f;
                scales[row] = scale;
                const float inv_scale = scale > 0.0f ? 1.0f / scale : 0.0f;

                // Second pass: requantize the row to int8 with the per-row scale.
                for (size_t block = 0; block < k_blocks; ++block) {
                    const block_q8_0 & blk = row_blocks[block];
                    const float d = GGML_FP16_TO_FP32(blk.d);
                    for (size_t l = 0; l < QK8_0; ++l) {
                        const size_t linear_idx = block * QK8_0 + l;
                        if (linear_idx >= k) {
                            break;
                        }
                        const float value = d * blk.qs[l];
                        int32_t q = scale > 0.0f ? static_cast<int32_t>(std::lround(value * inv_scale)) : 0;
                        q = std::clamp(q, -127, 127);
                        qdata[row * k + linear_idx] = static_cast<int8_t>(q);
                    }
                }
            }

            size_t nr = ctx.kernels_q8->gemm.get_nr();
            size_t kr = ctx.kernels_q8->gemm.get_kr();
            size_t sr = ctx.kernels_q8->gemm.get_sr();

            struct kai_rhs_pack_qsi8cx_params params;
            params.lhs_zero_point = 1;
            params.scale_multiplier = 1.0f;

            ctx.kernels_q8->rhs_info.pack_func_ex(1, n, k, nr, kr, sr, 0, 0,
                                                  qdata.data(), nullptr, scales.data(),
                                                  tensor->data, 0, &params);
            GGML_UNUSED(data_size);
            return 0;
        }

        GGML_UNUSED(data_size);
        return -1;
    }
};

static ggml::cpu::tensor_traits * get_tensor_traits(ggml_backend_buffer_t, struct ggml_tensor *) {
    static tensor_traits traits;
    return &traits;
}
}  // namespace ggml::cpu::kleidiai

static enum ggml_status ggml_backend_cpu_kleidiai_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
    tensor->extra = (void *) ggml::cpu::kleidiai::get_tensor_traits(buffer, tensor);

    return GGML_STATUS_SUCCESS;
    GGML_UNUSED(buffer);
}

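// set_tensor does not copy the data verbatim: the incoming host buffer is
// repacked into the KleidiAI layout as it is written into the tensor.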
static void ggml_backend_cpu_kleidiai_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor,
                                                        const void * data, size_t offset, size_t size) {
    GGML_ASSERT(offset == 0);
    GGML_ASSERT(size == ggml_nbytes(tensor));

    auto tensor_traits = (ggml::cpu::kleidiai::tensor_traits *) tensor->extra;
    auto OK = tensor_traits->repack(tensor, data, size);

    GGML_ASSERT(OK == 0);
    GGML_UNUSED(buffer);
}

static const char * ggml_backend_cpu_kleidiai_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
    return "CPU_KLEIDIAI";

    GGML_UNUSED(buft);
}

static ggml_backend_buffer_t ggml_backend_cpu_kleidiai_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
    ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);

    if (buffer == nullptr) {
        return nullptr;
    }

    buffer->buft              = buft;
    buffer->iface.init_tensor = ggml_backend_cpu_kleidiai_buffer_init_tensor;
    buffer->iface.set_tensor  = ggml_backend_cpu_kleidiai_buffer_set_tensor;
    buffer->iface.get_tensor  = nullptr;
    buffer->iface.cpy_tensor  = nullptr;
    return buffer;
}

static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
    return TENSOR_ALIGNMENT;

    GGML_UNUSED(buft);
}

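// The packed RHS layout can require more bytes than the raw ggml tensor, so the
// buffer must be sized for whichever of the two is larger.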
static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
    GGML_UNUSED(buft);

    const size_t n = tensor->ne[1];
    const size_t k = tensor->ne[0];

    ggml_kleidiai_kernels * kernels = nullptr;
    size_t block_len = 0;

    if (tensor->type == GGML_TYPE_Q4_0) {
        GGML_ASSERT(ctx.kernels_q4);
        kernels = ctx.kernels_q4;
        block_len = QK4_0;
    } else if (tensor->type == GGML_TYPE_Q8_0) {
        GGML_ASSERT(ctx.kernels_q8);
        kernels = ctx.kernels_q8;
        block_len = QK8_0;
    } else {
        return 0;
    }

    const size_t nr = kernels->gemm.get_nr();
    const size_t kr = kernels->gemm.get_kr();
    const size_t packed = kernels->rhs_info.packed_size_ex(n, k, nr, kr, block_len);
    const size_t raw = ggml_nbytes(tensor);

    return packed > raw ? packed : raw;
}

namespace ggml::cpu::kleidiai {
class extra_buffer_type : ggml::cpu::extra_buffer_type {
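    // Accept MUL_MAT / GET_ROWS on 2-D Q4_0 or Q8_0 weights stored in a
    // KleidiAI buffer, with host-resident F32/I32 src1 and no extra batch dims.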
    bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
        if ((op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_GET_ROWS) &&
            (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_Q8_0) &&
            op->src[0]->buffer &&
            (ggml_n_dims(op->src[0]) == 2) &&
            op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
            if (((op->src[0]->type == GGML_TYPE_Q4_0) ? ctx.kernels_q4 : ctx.kernels_q8) == nullptr) {
                return false;
            }
            if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
                return false;
            }
            if ((op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_I32) &&
                ggml_ne(op->src[1], 2) == 1 && ggml_ne(op->src[1], 3) == 1) {
                return true;
            }
        }
        return false;
    }

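    // Weights stored in a KleidiAI buffer carry their traits in tensor->extra;
    // other tensors (e.g. F16 weights) fall back to the shared traits object
    // when a kernel is available for the op and the batch dims are contiguous.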
    ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override {
        if (op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_GET_ROWS) {
            if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
                return (ggml::cpu::tensor_traits *) op->src[0]->extra;
            } else if (ggml_kleidiai_select_kernels(ctx.features, op) && op->src[1]->ne[1] > 1) {
                if ((op->src[0]->nb[1] * op->src[0]->ne[1] != op->src[0]->nb[2]) ||
                    (op->src[1]->nb[1] * op->src[1]->ne[1] != op->src[1]->nb[2])) {
                    return nullptr;
                }

                return ggml::cpu::kleidiai::get_tensor_traits(NULL, NULL);
            }
        }
        return nullptr;
    }
};
}  // namespace ggml::cpu::kleidiai

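// Public entry point: returns the singleton KleidiAI buffer type and makes sure
// the kernel tables have been initialized.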
ggml_backend_buffer_type_t ggml_backend_cpu_kleidiai_buffer_type(void) {
    static ggml::cpu::kleidiai::extra_buffer_type ctx;
    static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type_kleidiai = {
        /* .iface    = */ {
            /* .get_name         = */ ggml_backend_cpu_kleidiai_buffer_type_get_name,
            /* .alloc_buffer     = */ ggml_backend_cpu_kleidiai_buffer_type_alloc_buffer,
            /* .get_alignment    = */ ggml_backend_cpu_kleidiai_buffer_type_get_alignment,
            /* .get_max_size     = */ nullptr, // defaults to SIZE_MAX
            /* .get_alloc_size   = */ ggml_backend_cpu_kleidiai_buffer_type_get_alloc_size,
            /* .is_host          = */ nullptr,
        },
        /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
        /* .context = */ &ctx,
    };

    init_kleidiai_context();

    return &ggml_backend_cpu_buffer_type_kleidiai;
}