// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
// SPDX-License-Identifier: MIT
//
#include <arm_neon.h>
#include <assert.h>
#include <atomic>
#include <cfloat>
#include <cmath>
#include <algorithm>
#include <stdexcept>
#include <stdint.h>
#include <string.h>
#include <string>
#include <vector>
#if defined(__linux__)
#include <asm/hwcap.h>
#include <sys/auxv.h>
#elif defined(__APPLE__)
#include <string_view>
#include <sys/sysctl.h>
#include <sys/types.h>
#elif defined(_WIN32)
#include <windows.h>
#include <excpt.h>
#endif

#include "kleidiai.h"

#include "ggml-cpu.h"
#include "ggml-impl.h"
#include "ggml-backend-impl.h"
#include "ggml-threading.h"
#include "traits.h"

#include "kernels.h"

#include "kai_common.h"

#define GGML_COMMON_DECL_CPP
#include "ggml-common.h"

struct ggml_kleidiai_context {
    cpu_feature features;
    ggml_kleidiai_kernels * kernels_q4;
    ggml_kleidiai_kernels * kernels_q8;
} static ctx = { CPU_FEATURE_NONE, NULL, NULL };

static const char* cpu_feature_to_string(cpu_feature f) {
    if (f == CPU_FEATURE_NONE) {
        return "NONE";
    } else if ((f & CPU_FEATURE_SME) == CPU_FEATURE_SME) {
        return "SME";
    } else if ((f & CPU_FEATURE_SVE) == CPU_FEATURE_SVE) {
        return "SVE";
    } else if ((f & CPU_FEATURE_I8MM) == CPU_FEATURE_I8MM) {
        return "I8MM";
    } else if ((f & CPU_FEATURE_DOTPROD) == CPU_FEATURE_DOTPROD) {
        return "DOTPROD";
    } else {
        return "UNKNOWN";
    }
}

static void init_kleidiai_context(void) {
    ggml_critical_section_start();
    static bool initialized = false;

    if (!initialized) {
        initialized = true;
        const char *env_var = getenv("GGML_KLEIDIAI_SME");
        int sme_enabled = 0;

        ctx.features  = (ggml_cpu_has_dotprod()     ? CPU_FEATURE_DOTPROD : CPU_FEATURE_NONE) |
                        (ggml_cpu_has_matmul_int8() ? CPU_FEATURE_I8MM    : CPU_FEATURE_NONE) |
                        ((ggml_cpu_has_sve() && ggml_cpu_get_sve_cnt() == QK8_0) ? CPU_FEATURE_SVE : CPU_FEATURE_NONE);

        if (env_var) {
            sme_enabled = atoi(env_var);
        }

        if (sme_enabled != 0) {
            ctx.features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE;
        }
        ctx.kernels_q4 = ggml_kleidiai_select_kernels_q4_0(ctx.features);
        ctx.kernels_q8 = ggml_kleidiai_select_kernels_q8_0(ctx.features);
#ifndef NDEBUG
        if (ctx.kernels_q4) {
            GGML_LOG_DEBUG("kleidiai: using q4 kernel with CPU feature %s\n", cpu_feature_to_string(ctx.kernels_q4->required_cpu));
        }
        if (ctx.kernels_q8) {
            GGML_LOG_DEBUG("kleidiai: using q8 kernel with CPU feature %s\n", cpu_feature_to_string(ctx.kernels_q8->required_cpu));
        }
#endif
    }
    ggml_critical_section_end();
}
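// Illustrative usage note for init_kleidiai_context() above: SME kernels are
// opt-in via the environment, e.g. GGML_KLEIDIAI_SME=1. Any non-zero integer
// enables SME when ggml_cpu_has_sme() reports support; leaving the variable
// unset (or setting it to 0) keeps SME disabled.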

static inline int64_t ggml_ne(const ggml_tensor * tensor, int dim) {
    GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS);
    return tensor->ne[dim];
}

namespace ggml::cpu::kleidiai {

static size_t round_down(size_t x, size_t y) {
    return y == 0 ? x : x - (x % y);
}
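// Worked example: round_down(13, 4) == 12 and round_down(16, 4) == 16, while
// round_down(13, 0) == 13 (a zero multiple leaves the value unchanged).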

static void transpose_f32kxn_f16nxk(size_t n, size_t k, float * dst, const uint16_t * src, size_t rhs_stride) {
    size_t src_stride = rhs_stride / sizeof(uint16_t);
    size_t dst_stride = n;

    for (size_t k_idx = 0; k_idx < k; ++k_idx) {
        for (size_t n_idx = 0; n_idx < n; ++n_idx) {
            uint16_t v = *(src + k_idx + n_idx * src_stride);
            *(dst + n_idx + k_idx * dst_stride) = kai_cast_f32_f16(v);
        }
    }
}
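// Layout sketch: the source holds n rows of k fp16 values with a row stride of
// rhs_stride bytes; the destination becomes a dense k x n fp32 matrix, i.e.
// src[n_idx * src_stride + k_idx] maps to dst[k_idx * n + n_idx].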

class tensor_traits : public ggml::cpu::tensor_traits {
    bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
        if (op->op != GGML_OP_MUL_MAT) {
            return false;
        }
        ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, op);
        if (!kernels) {
            return false;
        }
        bool is_gemv = op->src[1]->ne[1] == 1;
        kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
        lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;

        size_t k = op->src[0]->ne[0];
        size_t n = op->src[0]->ne[1];
        size_t m = op->src[1]->ne[1];

        size_t mr = kernel->get_mr();
        size_t kr = kernel->get_kr();
        size_t sr = kernel->get_sr();

        if (kernels->rhs_type == GGML_TYPE_Q4_0) {
            if (!lhs_info->packed_size_ex) return false;
            size = lhs_info->packed_size_ex(m, k, QK4_0, mr, kr, sr);
        } else if (kernels->rhs_type == GGML_TYPE_Q8_0) {
            if (!lhs_info->packed_size_ex) return false;
            size = lhs_info->packed_size_ex(m, k, QK8_0, mr, kr, sr);
        } else if (kernels->rhs_type == GGML_TYPE_F16) {
            if (!lhs_info->packed_size_ex || !kernels->rhs_info.packed_size_ex) return false;
            const int64_t lhs_batch_size0 = op->src[1]->ne[2];
            const int64_t rhs_batch_size0 = op->src[0]->ne[2];
            const int64_t r = lhs_batch_size0 / rhs_batch_size0;
            size = lhs_info->packed_size_ex(m * r, k, 0, mr, kr, sr) +
                   kernels->rhs_info.packed_size_ex(n, k, kernel->get_nr(), kernel->get_kr(), 0) +
                   k * n * sizeof(float) + n * sizeof(float);
        } else {
            return false;
        }

        return true;
    }
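    // Note: for the F16 path the size computed above is consumed by
    // compute_forward_fp16() as one contiguous scratch buffer laid out as
    // [ packed LHS | packed RHS | transposed k x n fp32 RHS | n fp32 bias (zeroed) ].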

    bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * dst) override {
        if (dst->op == GGML_OP_MUL_MAT) {
            if (dst->src[0]->type == GGML_TYPE_Q4_0) {
                return compute_forward_q4_0(params, dst);
            } else if (dst->src[0]->type == GGML_TYPE_Q8_0) {
                return compute_forward_q8_0(params, dst);
            } else if (dst->src[0]->type == GGML_TYPE_F16) {
                return compute_forward_fp16(params, dst);
            }
        } else if (dst->op == GGML_OP_GET_ROWS) {
            if (dst->src[0]->type == GGML_TYPE_Q4_0 || dst->src[0]->type == GGML_TYPE_Q8_0) {
                return compute_forward_get_rows(params, dst);
            }
        }
        return false;
    }

    bool compute_forward_fp16(ggml_compute_params * params, struct ggml_tensor * dst) {
        const ggml_tensor * src0 = dst->src[0];
        const ggml_tensor * src1 = dst->src[1];

        GGML_TENSOR_BINARY_OP_LOCALS

        ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
        if (!kernels) {
            return false;
        }

        const bool is_gemv = src1->ne[1] == 1;
        kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
        lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
        GGML_ASSERT(kernel);
        if (!lhs_info->pack_func_ex || !lhs_info->get_packed_offset_ex || !lhs_info->packed_size_ex ||
            !kernels->rhs_info.pack_func_ex || !kernels->rhs_info.packed_size_ex ||
            !kernel->get_lhs_offset_ex || !kernel->get_rhs_packed_offset_ex ||
            !kernel->get_dst_offset || !kernel->run_kernel_ex) {
            return false;
        }

        const int nth = params->nth;
        const int ith = params->ith;

        const int64_t lhs_batch_size0 = ne12;
        const int64_t rhs_batch_size0 = ne02;
        const int64_t batch_size      = lhs_batch_size0;

        GGML_ASSERT(rhs_batch_size0 > 0);
        GGML_ASSERT(lhs_batch_size0 % rhs_batch_size0 == 0);
        const int64_t r = lhs_batch_size0 / rhs_batch_size0;

        const int64_t m_group = ne11;
        const int64_t m       = m_group;
        const int64_t n       = ne01;
        const int64_t k       = ne00;

        const size_t lhs_stride = src1->nb[1];
        const size_t rhs_stride = src0->nb[1];
        const size_t dst_stride = dst->nb[1];

        const int64_t mr = (int64_t) kernel->get_mr();
        const int64_t nr = (int64_t) kernel->get_nr();
        const int64_t kr = (int64_t) kernel->get_kr();
        const int64_t sr = (int64_t) kernel->get_sr();

        const size_t lhs_packed_size = lhs_info->packed_size_ex(m, k, 0, mr, kr, sr);
        const size_t rhs_packed_size = kernels->rhs_info.packed_size_ex(n, k, nr, kr, 0);
        const size_t kxn_size        = k * n * sizeof(float);
        const size_t bias_size       = n * sizeof(float);

        const size_t wsize_required = lhs_packed_size + rhs_packed_size + kxn_size + bias_size;
        GGML_ASSERT(wsize_required <= params->wsize);

        uint8_t * lhs_packed = static_cast<uint8_t *>(params->wdata);
        uint8_t * rhs_packed = lhs_packed + lhs_packed_size;
        uint8_t * rhs_kxn    = rhs_packed + rhs_packed_size;
        uint8_t * bias       = rhs_kxn + kxn_size;

        for (int64_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
            const int64_t rhs_batch_idx = batch_idx / r;
            const uint8_t * rhs_batch_base = static_cast<const uint8_t *>(src0->data) + rhs_batch_idx * src0->nb[2];
            uint8_t * dst_batch_base = static_cast<uint8_t *>(dst->data) + batch_idx * dst->nb[2];

            // LHS packing (threaded over m, honoring mr alignment and KV groups)
            {
                const int64_t m_roundup_mr = kai_roundup(m, mr);
                const int64_t num_threads  = KAI_MIN(m_roundup_mr / mr, nth);

                if (ith < num_threads) {
                    const int64_t num_m_per_thread0   = round_down((size_t)(m_roundup_mr / num_threads), (size_t)mr);
                    const int64_t num_m_per_threadN_1 = m - (num_threads - 1) * num_m_per_thread0;

                    const int64_t m_start = ith * num_m_per_thread0;
                    const int64_t m_count = (ith == num_threads - 1) ? num_m_per_threadN_1 : num_m_per_thread0;

                    // Base packed offset (aligned) and per-row stride in bytes
                    const size_t base_packed_off  = lhs_info->get_packed_offset_ex(m_start, k, 0, mr, kr, sr);
                    const size_t next_block_off   = lhs_info->get_packed_offset_ex(m_start + mr, k, 0, mr, kr, sr);
                    const size_t row_stride_bytes = (next_block_off - base_packed_off) / (size_t)mr;

                    int64_t remaining = m_count;
                    int64_t cur       = m_start;

                    while (remaining > 0) {
                        const int64_t row_in_group = cur;
                        const int64_t avail        = m_group - row_in_group;
                        const int64_t take         = std::min(avail, remaining);

                        const uint8_t * lhs_batch_base = static_cast<const uint8_t *>(src1->data) + batch_idx * src1->nb[2];
                        const void * src_ptr = lhs_batch_base + (size_t)row_in_group * lhs_stride;
                        const size_t dst_off = base_packed_off + (size_t)(cur - m_start) * row_stride_bytes;
                        void * dst_ptr       = lhs_packed + dst_off;

                        lhs_info->pack_func_ex(take, k, 0, mr, kr, sr, 0, src_ptr, lhs_stride, dst_ptr);

                        cur       += take;
                        remaining -= take;
                    }
                }
            }
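            // Rough worked example of the split above (hypothetical sizes):
            // m = 5, mr = 4, nth = 8 -> m_roundup_mr = 8 and num_threads = 2,
            // so thread 0 packs rows 0..3 and thread 1 packs row 4.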

            // RHS packing (single thread), then synchronize
            if (ith == 0) {
                memset(bias, 0, (size_t)n * sizeof(float));
                transpose_f32kxn_f16nxk((size_t)n, (size_t)k,
                                        reinterpret_cast<float *>(rhs_kxn),
                                        reinterpret_cast<const uint16_t *>(rhs_batch_base),
                                        rhs_stride);

                kernels->rhs_info.pack_func_ex(1, n, k, nr, kr, sr, 0, n * sizeof(float),
                                               rhs_kxn, bias, nullptr, rhs_packed, 0, nullptr);
            }

            ggml_barrier(params->threadpool);

            // Matmul (threaded over n)
            {
                const int64_t n_step  = (int64_t) kernel->get_n_step();
                int64_t num_threads_n = KAI_MIN(n / n_step, nth);
                if (num_threads_n <= 0) {
                    num_threads_n = 1;
                }

                if (ith < num_threads_n) {
                    const int64_t num_n_per_thread0   = round_down((size_t)(n / num_threads_n), (size_t)n_step);
                    const int64_t num_n_per_threadN_1 = n - (num_threads_n - 1) * num_n_per_thread0;

                    const int64_t n_start      = ith * num_n_per_thread0;
                    const int64_t n_to_process = (ith == num_threads_n - 1) ? num_n_per_threadN_1 : num_n_per_thread0;

                    // LHS packed base at row 0 (consistent with packing above)
                    const size_t lhs_packed_offset0 = lhs_info->get_packed_offset_ex(0, k, 0, mr, kr, sr);
                    const size_t rhs_packed_offset  = kernel->get_rhs_packed_offset_ex(n_start, k, 0);
                    const size_t dst_offset         = kernel->get_dst_offset((size_t)0, (size_t)n_start, dst_stride);

                    const void * lhs_ptr = lhs_packed + lhs_packed_offset0;
                    const void * rhs_ptr = rhs_packed + rhs_packed_offset;
                    float * dst_ptr      = reinterpret_cast<float *>(dst_batch_base + dst_offset);

                    kernel->run_kernel_ex(m, n_to_process, k, 0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride, sizeof(float), -FLT_MAX, FLT_MAX);
                }
            }

            if (batch_idx != batch_size - 1) {
                ggml_barrier(params->threadpool);
            }
        }

        return true;
    }

    bool compute_forward_q4_0(struct ggml_compute_params * params, struct ggml_tensor * dst) {
        GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0);

        const ggml_tensor * src0 = dst->src[0];
        const ggml_tensor * src1 = dst->src[1];

        GGML_TENSOR_BINARY_OP_LOCALS

        ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
        if (!kernels) {
            return false;
        }

        bool is_gemv = src1->ne[1] == 1;
        kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
        lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;

        GGML_ASSERT(kernel);
        if (!lhs_info->get_packed_offset_ex || !lhs_info->pack_func_ex ||
            !kernel->get_rhs_packed_offset_ex || !kernel->run_kernel_ex || !kernel->get_dst_offset) {
            return false;
        }

        const int ith = params->ith;
        const int nth_raw = params->nth;
        const int nth = nth_raw > 0 ? nth_raw : 1;

        const size_t k = ne00;
        const size_t m = ne11;
        const size_t n = ne01;

        size_t mr = kernel->get_mr();
        size_t kr = kernel->get_kr();
        size_t sr = kernel->get_sr();

        const uint8_t * lhs        = static_cast<const uint8_t *>(src1->data);
        uint8_t * lhs_packed       = static_cast<uint8_t *>(params->wdata);
        const uint8_t * rhs_packed = static_cast<const uint8_t *>(src0->data);

        const size_t n_step = kernel->get_n_step();
        const size_t num_n_per_thread = kai_roundup(kai_roundup(n, nth) / nth, n_step);
        const size_t n_start = ith * num_n_per_thread;

        size_t n_to_process = 0;
        if (n_start < n) {
            n_to_process = num_n_per_thread;
            if ((n_start + n_to_process) > n) {
                n_to_process = n - n_start;
            }
        }
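        // Rough worked example of the column split above (hypothetical sizes):
        // n = 32, n_step = 8, nth = 3 -> num_n_per_thread = 16, so thread 0
        // handles columns 0..15, thread 1 handles 16..31 and thread 2 is idle.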

        // Calculate the number of rows to be processed per thread
        const size_t num_m_per_thread = kai_roundup(m, mr * nth) / nth;
        const size_t m_start = ith * num_m_per_thread;
        size_t m_to_process = num_m_per_thread;
        if ((m_start + m_to_process) > m) {
            m_to_process = m - m_start;
        }
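        // Likewise for the row split (hypothetical sizes): m = 10, mr = 4,
        // nth = 3 -> num_m_per_thread = 4, so the threads pack rows [0,4),
        // [4,8) and [8,10) respectively.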

        if (m_start < m) {
            // Transform LHS
            const size_t src_stride        = src1->nb[1];
            const float * src_ptr          = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(m_start, dst->src[1]->nb[1]));
            const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(m_start, k, QK4_0, mr, kr, sr);
            void * lhs_packed_ptr          = static_cast<void *>(lhs_packed + lhs_packed_offset);

            // Pack this thread's chunk with m_idx_start = 0 and per-thread output pointer
            lhs_info->pack_func_ex(m_to_process, k, QK4_0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr);
        }

        ggml_barrier(params->threadpool);

        // Perform the operation
        const size_t dst_stride        = dst->nb[1];
        const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(0, k, QK4_0, mr, kr, sr);
        const size_t rhs_packed_offset = kernel->get_rhs_packed_offset_ex(n_start, k, QK4_0);
        const size_t dst_offset        = kernel->get_dst_offset(0, n_start, dst_stride);
        const void * rhs_ptr           = static_cast<const void *>(rhs_packed + rhs_packed_offset);
        const void * lhs_ptr           = static_cast<const void *>(lhs_packed + lhs_packed_offset);
        float * dst_ptr                = reinterpret_cast<float *>(static_cast<uint8_t *>(dst->data) + dst_offset);

        if (n_to_process > 0) {
            kernel->run_kernel_ex(m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride,
                                  sizeof(float), -FLT_MAX, FLT_MAX);
        }

        return true;
    }

    bool compute_forward_q8_0(struct ggml_compute_params * params, struct ggml_tensor * dst) {
        GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q8_0);

        const ggml_tensor * src0 = dst->src[0];
        const ggml_tensor * src1 = dst->src[1];

        GGML_TENSOR_BINARY_OP_LOCALS

        ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
        if (!kernels) {
            return false;
        }

        bool is_gemv = src1->ne[1] == 1;
        kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
        lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;

        if (!kernel || !lhs_info->get_packed_offset_ex || !lhs_info->pack_func_ex ||
            !kernel->get_rhs_packed_offset_ex || !kernel->run_kernel_ex || !kernel->get_dst_offset) {
            return false;
        }

        const int ith = params->ith;
        const int nth_raw = params->nth;
        const int nth = nth_raw > 0 ? nth_raw : 1;

        const size_t k = ne00;
        const size_t m = ne11;
        const size_t n = ne01;

        size_t mr = kernel->get_mr();
        size_t kr = kernel->get_kr();
        size_t sr = kernel->get_sr();

        const uint8_t * lhs        = static_cast<const uint8_t *>(src1->data);
        uint8_t * lhs_packed       = static_cast<uint8_t *>(params->wdata);
        const uint8_t * rhs_packed = static_cast<const uint8_t *>(src0->data);

        const size_t n_step = kernel->get_n_step();
        const size_t num_n_per_thread = kai_roundup(kai_roundup(n, nth) / nth, n_step);
        const size_t n_start = ith * num_n_per_thread;

        size_t n_to_process = 0;
        if (n_start < n) {
            n_to_process = num_n_per_thread;
            if ((n_start + n_to_process) > n) {
                n_to_process = n - n_start;
            }
        }

        const size_t num_m_per_thread = kai_roundup(m, mr * nth) / nth;
        const size_t m_start = ith * num_m_per_thread;
        size_t m_to_process = num_m_per_thread;
        if ((m_start + m_to_process) > m) {
            m_to_process = m - m_start;
        }

        if (m_start < m) {
            const size_t src_stride        = src1->nb[1];
            const float * src_ptr          = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(m_start, dst->src[1]->nb[1]));
            const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(m_start, k, 0, mr, kr, sr);
            void * lhs_packed_ptr          = static_cast<void *>(lhs_packed + lhs_packed_offset);

            lhs_info->pack_func_ex(m_to_process, k, 0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr);
        }

        ggml_barrier(params->threadpool);

        const size_t dst_stride        = dst->nb[1];
        const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(0, k, 0, mr, kr, sr);
        const size_t rhs_packed_offset = kernel->get_rhs_packed_offset_ex(n_start, k, 0);
        const size_t dst_offset        = kernel->get_dst_offset(0, n_start, dst_stride);
        const void * rhs_ptr           = static_cast<const void *>(rhs_packed + rhs_packed_offset);
        const void * lhs_ptr           = static_cast<const void *>(lhs_packed + lhs_packed_offset);
        float * dst_ptr                = reinterpret_cast<float *>(static_cast<uint8_t *>(dst->data) + dst_offset);

        if (n_to_process > 0) {
            kernel->run_kernel_ex(m, n_to_process, k, 0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride,
                                  sizeof(float), -FLT_MAX, FLT_MAX);
        }

        return true;
    }

    bool compute_forward_get_rows(struct ggml_compute_params * params, struct ggml_tensor * dst) {
        const ggml_tensor * src0 = dst->src[0];
        const ggml_tensor * src1 = dst->src[1];

        GGML_TENSOR_BINARY_OP_LOCALS

        ggml_kleidiai_kernels * kernels = nullptr;
        size_t block_len = 0;
        size_t num_bytes_multiplier = 0;

        if (dst->src[0]->type == GGML_TYPE_Q4_0) {
            if (!ctx.kernels_q4) {
                return false;
            }
            kernels = ctx.kernels_q4;
            block_len = QK4_0;
            num_bytes_multiplier = sizeof(uint16_t);
        } else if (dst->src[0]->type == GGML_TYPE_Q8_0) {
            if (!ctx.kernels_q8) {
                return false;
            }
            kernels = ctx.kernels_q8;
            block_len = QK8_0;
            num_bytes_multiplier = sizeof(float);
        } else {
            return false;
        }

        rhs_packing_info * rhs_info = &kernels->rhs_info;
        kernel_info * kernel        = &kernels->gemm;
        if (!rhs_info->to_float || !kernel->get_nr) {
            return false;
        }

        const int64_t nc     = ne00;
        const int64_t nr     = ggml_nelements(src1);

        const size_t block_rows = kernel->get_nr();
        const size_t kr         = kernel->get_kr();

        const size_t packed_stride = rhs_info->packed_stride(nc, block_rows, kr, block_len);

        const int ith = params->ith;
        const int nth = params->nth;

        const int dr = (nr + nth - 1) / nth;
        const int ir0 = dr * ith;
        const int ir1 = MIN(ir0 + dr, nr);
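        // e.g. (hypothetical sizes) nr = 10, nth = 4 -> dr = 3 and the
        // per-thread row ranges are [0,3), [3,6), [6,9) and [9,10).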

        for (int64_t i = ir0; i < ir1; ++i) {
            GGML_ASSERT(src1->type == GGML_TYPE_I32);
            int64_t row_idx = ((const int32_t *)src1->data)[i];
            GGML_ASSERT(row_idx >= 0 && row_idx < src0->ne[1]);

            float *out = (float *)((char *)dst->data + i * nb1);
            rhs_info->to_float(src0->data, row_idx, nc, out, block_rows, packed_stride, kr, block_len, num_bytes_multiplier);
        }

        return true;
    }

public:
    int repack(struct ggml_tensor * tensor, const void * data, size_t data_size) {
        const size_t n = tensor->ne[1];
        const size_t k = tensor->ne[0];

        if (tensor->type == GGML_TYPE_Q4_0) {
            if (!ctx.kernels_q4) {
                return -1;
            }
            size_t nr = ctx.kernels_q4->gemm.get_nr();
            size_t kr = ctx.kernels_q4->gemm.get_kr();
            size_t sr = ctx.kernels_q4->gemm.get_sr();

            struct kai_rhs_pack_qs4cxs1s0_param params;
            params.lhs_zero_point = 1;
            params.rhs_zero_point = 8;
            ctx.kernels_q4->rhs_info.pack_func_ex(1, n, k, nr, kr, sr, QK4_0, 0,
                                                  static_cast<const uint8_t *>(data),
                                                  nullptr, nullptr, tensor->data, 0, &params);
            GGML_UNUSED(data_size);
            return 0;
        } else if (tensor->type == GGML_TYPE_Q8_0) {
            if (!ctx.kernels_q8) {
                return -1;
            }

            const size_t row_stride = tensor->nb[1];
            const size_t k_blocks   = (k + QK8_0 - 1) / QK8_0;

            std::vector<int8_t> qdata(n * k, 0);
            std::vector<float> scales(n, 0.0f);

            for (size_t row = 0; row < n; ++row) {
                const auto * row_blocks = reinterpret_cast<const block_q8_0 *>(
                    static_cast<const uint8_t *>(data) + row * row_stride);

                float max_abs = 0.0f;
                for (size_t block = 0; block < k_blocks; ++block) {
                    const block_q8_0 & blk = row_blocks[block];
                    const float d = GGML_FP16_TO_FP32(blk.d);
                    for (size_t l = 0; l < QK8_0; ++l) {
                        const size_t linear_idx = block * QK8_0 + l;
                        if (linear_idx >= k) {
                            break;
                        }
                        const float value = d * blk.qs[l];
                        max_abs = std::max(max_abs, std::fabs(value));
                    }
                }

                float scale = max_abs > 0.0f ? max_abs / 127.0f : 0.0f;
                scales[row] = scale;
                const float inv_scale = scale > 0.0f ? 1.0f / scale : 0.0f;

                for (size_t block = 0; block < k_blocks; ++block) {
                    const block_q8_0 & blk = row_blocks[block];
                    const float d = GGML_FP16_TO_FP32(blk.d);
                    for (size_t l = 0; l < QK8_0; ++l) {
                        const size_t linear_idx = block * QK8_0 + l;
                        if (linear_idx >= k) {
                            break;
                        }
                        const float value = d * blk.qs[l];
                        int32_t q = scale > 0.0f ? static_cast<int32_t>(std::lround(value * inv_scale)) : 0;
                        q = std::clamp(q, -127, 127);
                        qdata[row * k + linear_idx] = static_cast<int8_t>(q);
                    }
                }
            }
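            // At this point each row has been dequantized from block_q8_0 (one
            // fp16 scale per QK8_0 values) and requantized with a single per-row
            // scale of max|v| / 127, giving the flat int8 data plus per-row fp32
            // scales that the qsi8cx RHS packing below consumes. Hypothetical
            // numbers: max|v| = 2.54 -> scale = 0.02, so a value of 1.27 maps to
            // q = lround(1.27 / 0.02) = 64.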

            size_t nr = ctx.kernels_q8->gemm.get_nr();
            size_t kr = ctx.kernels_q8->gemm.get_kr();
            size_t sr = ctx.kernels_q8->gemm.get_sr();

            struct kai_rhs_pack_qsi8cx_params params;
            params.lhs_zero_point = 1;
            params.scale_multiplier = 1.0f;

            ctx.kernels_q8->rhs_info.pack_func_ex(1, n, k, nr, kr, sr, 0, 0,
                                                  qdata.data(), nullptr, scales.data(),
                                                  tensor->data, 0, &params);
            GGML_UNUSED(data_size);
            return 0;
        }

        GGML_UNUSED(data_size);
        return -1;
    }
};

static ggml::cpu::tensor_traits * get_tensor_traits(ggml_backend_buffer_t, struct ggml_tensor *) {
    static tensor_traits traits;
    return &traits;
}
}  // namespace ggml::cpu::kleidiai

static enum ggml_status ggml_backend_cpu_kleidiai_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
    tensor->extra = (void *) ggml::cpu::kleidiai::get_tensor_traits(buffer, tensor);

    return GGML_STATUS_SUCCESS;
    GGML_UNUSED(buffer);
}

static void ggml_backend_cpu_kleidiai_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor,
                                                        const void * data, size_t offset, size_t size) {
    GGML_ASSERT(offset == 0);
    GGML_ASSERT(size == ggml_nbytes(tensor));

    auto tensor_traits = (ggml::cpu::kleidiai::tensor_traits *) tensor->extra;
    auto OK            = tensor_traits->repack(tensor, data, size);

    GGML_ASSERT(OK == 0);
    GGML_UNUSED(buffer);
}

static const char * ggml_backend_cpu_kleidiai_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
    return "CPU_KLEIDIAI";

    GGML_UNUSED(buft);
}

static ggml_backend_buffer_t ggml_backend_cpu_kleidiai_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
    ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);

    if (buffer == nullptr) {
        return nullptr;
    }

    buffer->buft              = buft;
    buffer->iface.init_tensor = ggml_backend_cpu_kleidiai_buffer_init_tensor;
    buffer->iface.set_tensor  = ggml_backend_cpu_kleidiai_buffer_set_tensor;
    buffer->iface.get_tensor  = nullptr;
    buffer->iface.cpy_tensor  = nullptr;
    return buffer;
}

static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
    return TENSOR_ALIGNMENT;

    GGML_UNUSED(buft);
}

static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
    GGML_UNUSED(buft);

    const size_t n = tensor->ne[1];
    const size_t k = tensor->ne[0];

    ggml_kleidiai_kernels * kernels = nullptr;
    size_t block_len = 0;

    if (tensor->type == GGML_TYPE_Q4_0) {
        GGML_ASSERT(ctx.kernels_q4);
        kernels = ctx.kernels_q4;
        block_len = QK4_0;
    } else if (tensor->type == GGML_TYPE_Q8_0) {
        GGML_ASSERT(ctx.kernels_q8);
        kernels = ctx.kernels_q8;
        block_len = QK8_0;
    } else {
        return 0;
    }

    const size_t nr = kernels->gemm.get_nr();
    const size_t kr = kernels->gemm.get_kr();
    const size_t packed = kernels->rhs_info.packed_size_ex(n, k, nr, kr, block_len);
    const size_t raw    = ggml_nbytes(tensor);

    return packed > raw ? packed : raw;
}
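// Returning max(packed, raw) keeps the allocation large enough for either the
// KleidiAI packed layout written by repack() or a plain copy of the original
// tensor data.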

namespace ggml::cpu::kleidiai {
class extra_buffer_type : ggml::cpu::extra_buffer_type {
    bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
        if ((op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_GET_ROWS) &&
            (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_Q8_0) &&
            op->src[0]->buffer &&
            (ggml_n_dims(op->src[0]) == 2) &&
            op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
            if (((op->src[0]->type == GGML_TYPE_Q4_0) ? ctx.kernels_q4 : ctx.kernels_q8) == nullptr) {
                return false;
            }
            if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
                return false;
            }
            if ((op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_I32) &&
                ggml_ne(op->src[1], 2) == 1 && ggml_ne(op->src[1], 3) == 1) {
                return true;
            }
        }
        return false;
    }

    ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override {
        if (op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_GET_ROWS) {
            if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
                return (ggml::cpu::tensor_traits *) op->src[0]->extra;
            } else if (ggml_kleidiai_select_kernels(ctx.features, op) && op->src[1]->ne[1] > 1) {
                if ((op->src[0]->nb[1] * op->src[0]->ne[1] != op->src[0]->nb[2]) ||
                    (op->src[1]->nb[1] * op->src[1]->ne[1] != op->src[1]->nb[2])) {
                    return nullptr;
                }

                return ggml::cpu::kleidiai::get_tensor_traits(NULL, NULL);
            }
        }
        return nullptr;
    }
};
}  // namespace ggml::cpu::kleidiai

ggml_backend_buffer_type_t ggml_backend_cpu_kleidiai_buffer_type(void) {
    static ggml::cpu::kleidiai::extra_buffer_type ctx;
    static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type_kleidiai = {
        /* .iface    = */ {
                           /* .get_name         = */ ggml_backend_cpu_kleidiai_buffer_type_get_name,
                           /* .alloc_buffer     = */ ggml_backend_cpu_kleidiai_buffer_type_alloc_buffer,
                           /* .get_alignment    = */ ggml_backend_cpu_kleidiai_buffer_type_get_alignment,
                           /* .get_max_size     = */ nullptr,  // defaults to SIZE_MAX
                           /* .get_alloc_size   = */ ggml_backend_cpu_kleidiai_buffer_type_get_alloc_size,
                           /* .is_host          = */ nullptr,
                           },
        /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
        /* .context = */ &ctx,
    };

    init_kleidiai_context();

    return &ggml_backend_cpu_buffer_type_kleidiai;
}