1#define GGML_COMMON_IMPL_CPP
   2#define GGML_COMMON_DECL_CPP
   3#include "ggml-common.h"
   4#include "ggml-backend-impl.h"
   5
   6#include "ggml-impl.h"
   7#include "ggml-cpu.h"
   8#include "ggml-cpu-impl.h"
   9#include "simd-mappings.h"
  10#include "traits.h"
  11
  12#include <cmath>
  13#include <cstring>
  14#include <cassert>
  15#include <cstdlib> // for qsort
  16#include <cstdio>  // for GGML_ASSERT
  17
  18#define GGML_CPU_CLANG_WORKAROUND
  19#include "../../repack.h"
  20
  21#if defined(__GNUC__)
  22#pragma GCC diagnostic ignored "-Woverlength-strings"
  23#endif
  24
  25#define UNUSED GGML_UNUSED
  26
  27#if defined(__aarch64__) && defined(__ARM_NEON) && (defined(__ARM_FEATURE_MATMUL_INT8) || defined(__ARM_FEATURE_DOTPROD))
  28// Helper for decoding scales and mins of Q4_K and Q5_K block formats
  29static inline void decode_q_Kx8_6bit_scales(const uint8_t * scales_in, int16x8_t * out_mins, int8_t * out_scales) {
  30    constexpr uint32_t kmask1 = 0x3f3f3f3f;
  31    constexpr uint32_t kmask2 = 0x0f0f0f0f;
  32    constexpr uint32_t kmask3 = 0x03030303;
  33    constexpr uint8_t  scales_size = 12;
  34
  35    uint32_t sm[3];
  36    memcpy(sm, scales_in, scales_size);
  37
  38    const uint32_t   mins_0_3 = sm[1] & kmask1;
  39    const uint32_t   mins_4_7 = ((sm[2] >> 4) & kmask2) | (((sm[1] >> 6) & kmask3) << 4);
  40    const uint32x2_t mins_u32 = { mins_0_3, mins_4_7 };
  41
  42    *out_mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins_u32)));
  43
  44    uint32_t scales_u32[2];
  45    scales_u32[0] = sm[0] & kmask1;
  46    scales_u32[1] = (sm[2] & kmask2) | (((sm[0] >> 6) & kmask3) << 4);
  47    memcpy(out_scales, scales_u32, 8);
  48}
  49#endif
  50
  51void ggml_quantize_mat_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
  52    assert(QK8_0 == 32);
  53    assert(k % QK8_0 == 0);
  54    const int nb = k / QK8_0;
  55
  56    block_q8_0x4 * GGML_RESTRICT y = (block_q8_0x4 *) vy;
  57
  58#if defined(__ARM_NEON)
  59    float32x4_t srcv[4][8];
  60    float id[4];
  61
  62    for (int i = 0; i < nb; i++) {
  63        float32x4_t asrcv[8];
  64        float32x4_t amaxv[8];
  65
  66        for (int row_iter = 0; row_iter < 4; row_iter++) {
  67            for (int j = 0; j < 8; j++) srcv[row_iter][j] = vld1q_f32(x + row_iter * k + i * 32 + 4 * j);
  68            for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[row_iter][j]);
  69
  70            for (int j = 0; j < 4; j++) amaxv[2 * j] = vmaxq_f32(asrcv[2 * j], asrcv[2 * j + 1]);
  71            for (int j = 0; j < 2; j++) amaxv[4 * j] = vmaxq_f32(amaxv[4 * j], amaxv[4 * j + 2]);
  72            for (int j = 0; j < 1; j++) amaxv[8 * j] = vmaxq_f32(amaxv[8 * j], amaxv[8 * j + 4]);
  73
  74            const float amax = vmaxvq_f32(amaxv[0]);
  75
  76            const float d = amax / ((1 << 7) - 1);
  77            id[row_iter] = d ? 1.0f / d : 0.0f;
  78
  79            y[i].d[row_iter] = GGML_CPU_FP32_TO_FP16(d);
  80        }
  81
  82        for (int j = 0; j < 8; j++) {
  83            float32x4_t v = vmulq_n_f32(srcv[0][j], id[0]);
  84            int32x4_t vi = vcvtnq_s32_f32(v);
  85            y[i].qs[16 * j + 0] = vgetq_lane_s32(vi, 0);
  86            y[i].qs[16 * j + 1] = vgetq_lane_s32(vi, 1);
  87            y[i].qs[16 * j + 2] = vgetq_lane_s32(vi, 2);
  88            y[i].qs[16 * j + 3] = vgetq_lane_s32(vi, 3);
  89
  90            v = vmulq_n_f32(srcv[1][j], id[1]);
  91            vi = vcvtnq_s32_f32(v);
  92            y[i].qs[16 * j + 4] = vgetq_lane_s32(vi, 0);
  93            y[i].qs[16 * j + 5] = vgetq_lane_s32(vi, 1);
  94            y[i].qs[16 * j + 6] = vgetq_lane_s32(vi, 2);
  95            y[i].qs[16 * j + 7] = vgetq_lane_s32(vi, 3);
  96
  97            v = vmulq_n_f32(srcv[2][j], id[2]);
  98            vi = vcvtnq_s32_f32(v);
  99            y[i].qs[16 * j + 8] = vgetq_lane_s32(vi, 0);
 100            y[i].qs[16 * j + 9] = vgetq_lane_s32(vi, 1);
 101            y[i].qs[16 * j + 10] = vgetq_lane_s32(vi, 2);
 102            y[i].qs[16 * j + 11] = vgetq_lane_s32(vi, 3);
 103
 104            v = vmulq_n_f32(srcv[3][j], id[3]);
 105            vi = vcvtnq_s32_f32(v);
 106            y[i].qs[16 * j + 12] = vgetq_lane_s32(vi, 0);
 107            y[i].qs[16 * j + 13] = vgetq_lane_s32(vi, 1);
 108            y[i].qs[16 * j + 14] = vgetq_lane_s32(vi, 2);
 109            y[i].qs[16 * j + 15] = vgetq_lane_s32(vi, 3);
 110        }
 111    }
 112#else
 113    UNUSED(nb);
 114    UNUSED(y);
 115    ggml_quantize_mat_q8_0_4x4_generic(x, vy, k);
 116#endif
 117}
 118
 119void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
 120    assert(QK8_0 == 32);
 121    assert(k % QK8_0 == 0);
 122    const int nb = k / QK8_0;
 123
 124    block_q8_0x4 * GGML_RESTRICT y = (block_q8_0x4 *) vy;
 125
 126#if defined(__ARM_NEON)
 127    float32x4_t srcv[4][8];
 128    float id[4];
 129
 130    for (int i = 0; i < nb; i++) {
 131        float32x4_t asrcv[8];
 132        float32x4_t amaxv[8];
 133
 134        for (int row_iter = 0; row_iter < 4; row_iter++) {
 135            for (int j = 0; j < 8; j++) srcv[row_iter][j] = vld1q_f32(x + row_iter * k + i * 32 + 4 * j);
 136            for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[row_iter][j]);
 137
 138            for (int j = 0; j < 4; j++) amaxv[2 * j] = vmaxq_f32(asrcv[2 * j], asrcv[2 * j + 1]);
 139            for (int j = 0; j < 2; j++) amaxv[4 * j] = vmaxq_f32(amaxv[4 * j], amaxv[4 * j + 2]);
 140            for (int j = 0; j < 1; j++) amaxv[8 * j] = vmaxq_f32(amaxv[8 * j], amaxv[8 * j + 4]);
 141
 142            const float amax = vmaxvq_f32(amaxv[0]);
 143
 144            const float d = amax / ((1 << 7) - 1);
 145            id[row_iter] = d ? 1.0f / d : 0.0f;
 146
 147            y[i].d[row_iter] = GGML_CPU_FP32_TO_FP16(d);
 148        }
 149
 150        for (int j = 0; j < 4; j++) {
 151            float32x4_t v = vmulq_n_f32(srcv[0][2 * j], id[0]);
 152            int32x4_t vi = vcvtnq_s32_f32(v);
 153            y[i].qs[32 * j + 0] = vgetq_lane_s32(vi, 0);
 154            y[i].qs[32 * j + 1] = vgetq_lane_s32(vi, 1);
 155            y[i].qs[32 * j + 2] = vgetq_lane_s32(vi, 2);
 156            y[i].qs[32 * j + 3] = vgetq_lane_s32(vi, 3);
 157            v = vmulq_n_f32(srcv[0][2 * j + 1], id[0]);
 158            vi = vcvtnq_s32_f32(v);
 159            y[i].qs[32 * j + 4] = vgetq_lane_s32(vi, 0);
 160            y[i].qs[32 * j + 5] = vgetq_lane_s32(vi, 1);
 161            y[i].qs[32 * j + 6] = vgetq_lane_s32(vi, 2);
 162            y[i].qs[32 * j + 7] = vgetq_lane_s32(vi, 3);
 163
 164            v = vmulq_n_f32(srcv[1][2 * j], id[1]);
 165            vi = vcvtnq_s32_f32(v);
 166            y[i].qs[32 * j + 8] = vgetq_lane_s32(vi, 0);
 167            y[i].qs[32 * j + 9] = vgetq_lane_s32(vi, 1);
 168            y[i].qs[32 * j + 10] = vgetq_lane_s32(vi, 2);
 169            y[i].qs[32 * j + 11] = vgetq_lane_s32(vi, 3);
 170            v = vmulq_n_f32(srcv[1][2 * j + 1], id[1]);
 171            vi = vcvtnq_s32_f32(v);
 172            y[i].qs[32 * j + 12] = vgetq_lane_s32(vi, 0);
 173            y[i].qs[32 * j + 13] = vgetq_lane_s32(vi, 1);
 174            y[i].qs[32 * j + 14] = vgetq_lane_s32(vi, 2);
 175            y[i].qs[32 * j + 15] = vgetq_lane_s32(vi, 3);
 176
 177            v = vmulq_n_f32(srcv[2][2 * j], id[2]);
 178            vi = vcvtnq_s32_f32(v);
 179            y[i].qs[32 * j + 16] = vgetq_lane_s32(vi, 0);
 180            y[i].qs[32 * j + 17] = vgetq_lane_s32(vi, 1);
 181            y[i].qs[32 * j + 18] = vgetq_lane_s32(vi, 2);
 182            y[i].qs[32 * j + 19] = vgetq_lane_s32(vi, 3);
 183            v = vmulq_n_f32(srcv[2][2 * j + 1], id[2]);
 184            vi = vcvtnq_s32_f32(v);
 185            y[i].qs[32 * j + 20] = vgetq_lane_s32(vi, 0);
 186            y[i].qs[32 * j + 21] = vgetq_lane_s32(vi, 1);
 187            y[i].qs[32 * j + 22] = vgetq_lane_s32(vi, 2);
 188            y[i].qs[32 * j + 23] = vgetq_lane_s32(vi, 3);
 189
 190            v = vmulq_n_f32(srcv[3][2 * j], id[3]);
 191            vi = vcvtnq_s32_f32(v);
 192            y[i].qs[32 * j + 24] = vgetq_lane_s32(vi, 0);
 193            y[i].qs[32 * j + 25] = vgetq_lane_s32(vi, 1);
 194            y[i].qs[32 * j + 26] = vgetq_lane_s32(vi, 2);
 195            y[i].qs[32 * j + 27] = vgetq_lane_s32(vi, 3);
 196            v = vmulq_n_f32(srcv[3][2 * j + 1], id[3]);
 197            vi = vcvtnq_s32_f32(v);
 198            y[i].qs[32 * j + 28] = vgetq_lane_s32(vi, 0);
 199            y[i].qs[32 * j + 29] = vgetq_lane_s32(vi, 1);
 200            y[i].qs[32 * j + 30] = vgetq_lane_s32(vi, 2);
 201            y[i].qs[32 * j + 31] = vgetq_lane_s32(vi, 3);
 202        }
 203    }
 204
 205#else
 206    UNUSED(nb);
 207    UNUSED(y);
 208    ggml_quantize_mat_q8_0_4x8_generic(x, vy, k);
 209#endif
 210}
 211
 212void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
 213    const int qk = QK8_0;
 214    const int nb = n / qk;
 215    const int ncols_interleaved = 4;
 216    const int blocklen = 4;
 217
 218    assert (n % qk == 0);
 219    assert (nc % ncols_interleaved == 0);
 220
 221    UNUSED(s);
 222    UNUSED(bs);
 223    UNUSED(vx);
 224    UNUSED(vy);
 225    UNUSED(nr);
 226    UNUSED(nc);
 227    UNUSED(nb);
 228    UNUSED(ncols_interleaved);
 229    UNUSED(blocklen);
 230
 231#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
 232    const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx;
 233
 234    for (int c = 0; c < nc; c += ncols_interleaved) {
 235        const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
 236        float32x4_t acc = vdupq_n_f32(0);
 237        for (int b = 0; b < nb; b++) {
 238            int8x16_t b0 = vld1q_s8((const int8_t *) b_ptr->qs);
 239            int8x16_t b1 = vld1q_s8((const int8_t *) b_ptr->qs + 16);
 240            int8x16_t b2 = vld1q_s8((const int8_t *) b_ptr->qs + 32);
 241            int8x16_t b3 = vld1q_s8((const int8_t *) b_ptr->qs + 48);
 242            float16x4_t bd = vld1_f16((const __fp16 *) b_ptr->d);
 243
 244            int8x16_t a0 = vld1q_s8(a_ptr->qs);
 245            int8x16_t a1 = vld1q_s8(a_ptr->qs + qk/2);
 246            float16x4_t ad = vld1_dup_f16((const __fp16 *) &a_ptr->d);
 247
 248            int32x4_t ret = vdupq_n_s32(0);
 249
 250            ret = vdotq_laneq_s32(ret, b0 << 4, a0, 0);
 251            ret = vdotq_laneq_s32(ret, b1 << 4, a0, 1);
 252            ret = vdotq_laneq_s32(ret, b2 << 4, a0, 2);
 253            ret = vdotq_laneq_s32(ret, b3 << 4, a0, 3);
 254
 255            ret = vdotq_laneq_s32(ret, b0 & 0xf0U, a1, 0);
 256            ret = vdotq_laneq_s32(ret, b1 & 0xf0U, a1, 1);
 257            ret = vdotq_laneq_s32(ret, b2 & 0xf0U, a1, 2);
 258            ret = vdotq_laneq_s32(ret, b3 & 0xf0U, a1, 3);
 259
 260            acc = vfmaq_f32(acc, vcvtq_n_f32_s32(ret, 4),
 261                            vmulq_f32(vcvt_f32_f16(ad), vcvt_f32_f16(bd)));
 262            a_ptr++;
 263            b_ptr++;
 264        }
 265        vst1q_f32(s, acc);
 266        s += ncols_interleaved;
 267    }
 268    return;
 269#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
 270    ggml_gemv_q4_0_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
 271}
 272
 273void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
 274    const int qk = QK8_0;
 275    const int nb = n / qk;
 276    const int ncols_interleaved = 4;
 277    const int blocklen = 8;
 278
 279    assert (n % qk == 0);
 280    assert (nc % ncols_interleaved == 0);
 281
 282    UNUSED(s);
 283    UNUSED(bs);
 284    UNUSED(vx);
 285    UNUSED(vy);
 286    UNUSED(nr);
 287    UNUSED(nc);
 288    UNUSED(nb);
 289    UNUSED(ncols_interleaved);
 290    UNUSED(blocklen);
 291
 292#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
 293    const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx;
 294
 295    for (int c = 0; c < nc; c += ncols_interleaved) {
 296        const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
 297        float32x4_t acc = vdupq_n_f32(0);
 298        for (int b = 0; b < nb; b++) {
 299            int8x16_t b0 = vld1q_s8((const int8_t *) b_ptr->qs);
 300            int8x16_t b1 = vld1q_s8((const int8_t *) b_ptr->qs + 16);
 301            int8x16_t b2 = vld1q_s8((const int8_t *) b_ptr->qs + 32);
 302            int8x16_t b3 = vld1q_s8((const int8_t *) b_ptr->qs + 48);
 303            float16x4_t bd = vld1_f16((const __fp16 *) b_ptr->d);
 304
 305            int8x16_t a0 = (int8x16_t) vld1q_dup_s64((const int64_t *) a_ptr->qs);
 306            int8x16_t a1 = (int8x16_t) vld1q_dup_s64((const int64_t *) a_ptr->qs + 1);
 307            int8x16_t a2 = (int8x16_t) vld1q_dup_s64((const int64_t *) a_ptr->qs + 2);
 308            int8x16_t a3 = (int8x16_t) vld1q_dup_s64((const int64_t *) a_ptr->qs + 3);
 309            float16x4_t ad = vld1_dup_f16((const __fp16 *) &a_ptr->d);
 310
 311            int32x4_t ret0 = vdupq_n_s32(0);
 312            int32x4_t ret1 = vdupq_n_s32(0);
 313
 314            ret0 = vdotq_s32(ret0, b0 << 4, a0);
 315            ret1 = vdotq_s32(ret1, b1 << 4, a0);
 316            ret0 = vdotq_s32(ret0, b2 << 4, a1);
 317            ret1 = vdotq_s32(ret1, b3 << 4, a1);
 318
 319            ret0 = vdotq_s32(ret0, b0 & 0xf0U, a2);
 320            ret1 = vdotq_s32(ret1, b1 & 0xf0U, a2);
 321            ret0 = vdotq_s32(ret0, b2 & 0xf0U, a3);
 322            ret1 = vdotq_s32(ret1, b3 & 0xf0U, a3);
 323
 324            int32x4_t ret = vpaddq_s32(ret0, ret1);
 325
 326            acc = vfmaq_f32(acc, vcvtq_n_f32_s32(ret, 4),
 327                    vmulq_f32(vcvt_f32_f16(ad), vcvt_f32_f16(bd)));
 328            a_ptr++;
 329            b_ptr++;
 330        }
 331        vst1q_f32(s, acc);
 332        s += ncols_interleaved;
 333    }
 334    return;
 335#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
 336    ggml_gemv_q4_0_4x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
 337}
 338
 339void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
 340    const int qk = QK8_0;
 341    const int nb = n / qk;
 342    const int ncols_interleaved = 8;
 343    const int blocklen = 8;
 344
 345    assert (n % qk == 0);
 346    assert (nc % ncols_interleaved == 0);
 347
 348    UNUSED(s);
 349    UNUSED(bs);
 350    UNUSED(vx);
 351    UNUSED(vy);
 352    UNUSED(nr);
 353    UNUSED(nc);
 354    UNUSED(nb);
 355    UNUSED(ncols_interleaved);
 356    UNUSED(blocklen);
 357
 358#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__)
 359#if defined(__ARM_FEATURE_SVE)
 360    if (ggml_cpu_get_sve_cnt() == QK8_0) {
 361        const void * b_ptr = vx;
 362        const void * a_ptr = vy;
 363        float * res_ptr = s;
 364
 365        __asm__ __volatile__(
 366            "ptrue p0.b\n"
 367            "add %x[b_ptr], %x[b_ptr], #0x10\n"
 368            "1:"  // Column loop
 369            "add x22, %x[a_ptr], #0x2\n"
 370            "mov z31.b, #0x0\n"
 371            "mov x21, %x[nb]\n"
 372            "2:"  // Block loop
 373            "ld1b { z30.b }, p0/Z, [%x[b_ptr]]\n"
 374            "ld1b { z29.b }, p0/Z, [%x[b_ptr], #1, MUL VL]\n"
 375            "mov z28.s, #0x0\n"
 376            "mov z27.s, #0x0\n"
 377            "ld1rd { z26.d }, p0/Z, [x22]\n"
 378            "ld1b { z25.b }, p0/Z, [%x[b_ptr], #2, MUL VL]\n"
 379            "sub x20, x22, #0x2\n"
 380            "sub x21, x21, #0x1\n"
 381            "ld1b { z24.b }, p0/Z, [%x[b_ptr], #3, MUL VL]\n"
 382            "ld1rd { z23.d }, p0/Z, [x22, #8]\n"
 383            "lsl z22.b, z30.b, #0x4\n"
 384            "lsl z16.b, z29.b, #0x4\n"
 385            "and z30.b, z30.b, #0xf0\n"
 386            "and z29.b, z29.b, #0xf0\n"
 387            "ld1rd { z21.d }, p0/Z, [x22, #16]\n"
 388            "ld1rd { z20.d }, p0/Z, [x22, #24]\n"
 389            "lsl z19.b, z25.b, #0x4\n"
 390            "and z25.b, z25.b, #0xf0\n"
 391            "ld1rh { z17.h }, p0/Z, [x20]\n"
 392            "ld1h { z18.s }, p0/Z, [%x[b_ptr], #-1, MUL VL]\n"
 393            "sdot z28.s, z22.b, z26.b\n"
 394            "sdot z27.s, z16.b, z26.b\n"
 395            "lsl z16.b, z24.b, #0x4\n"
 396            "add x22, x22, #0x22\n"
 397            "and z24.b, z24.b, #0xf0\n"
 398            "add %x[b_ptr], %x[b_ptr], #0x90\n"
 399            "fcvt z17.s, p0/m, z17.h\n"
 400            "fcvt z18.s, p0/m, z18.h\n"
 401            "sdot z28.s, z19.b, z23.b\n"
 402            "sdot z27.s, z16.b, z23.b\n"
 403            "fmul z18.s, z18.s, z17.s\n"
 404            "sdot z28.s, z30.b, z21.b\n"
 405            "sdot z27.s, z29.b, z21.b\n"
 406            "sdot z28.s, z25.b, z20.b\n"
 407            "sdot z27.s, z24.b, z20.b\n"
 408            "uzp1 z17.s, z28.s, z27.s\n"
 409            "uzp2 z16.s, z28.s, z27.s\n"
 410            "add z17.s, z17.s, z16.s\n"
 411            "asr z17.s, z17.s, #0x4\n"
 412            "scvtf z17.s, p0/m, z17.s\n"
 413            "fmla z31.s, p0/M, z17.s, z18.s\n"
 414            "cbnz x21, 2b\n"
 415            "sub %x[nc], %x[nc], #0x8\n"
 416            "st1w { z31.s }, p0, [%x[res_ptr]]\n"
 417            "add %x[res_ptr], %x[res_ptr], #0x20\n"
 418            "cbnz %x[nc], 1b\n"
 419            : [b_ptr] "+&r" (b_ptr), [res_ptr] "+&r" (res_ptr), [nc] "+&r" (nc)
 420            : [a_ptr] "r" (a_ptr), [nb] "r" (nb)
 421            : "memory", "p0", "x20", "x21", "x22", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31"
 422        );
 423        return;
 424    }
 425#endif // #if defined(__ARM_FEATURE_SVE)
 426
 427#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__)
 428    ggml_gemv_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
 429}
 430
 431void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
 432    const int qk = QK8_0;
 433    const int nb = n / qk;
 434    const int ncols_interleaved = 4;
 435    const int blocklen = 4;
 436
 437    assert (n % qk == 0);
 438    assert (nc % ncols_interleaved == 0);
 439
 440    UNUSED(s);
 441    UNUSED(bs);
 442    UNUSED(vx);
 443    UNUSED(vy);
 444    UNUSED(nr);
 445    UNUSED(nc);
 446    UNUSED(nb);
 447    UNUSED(ncols_interleaved);
 448    UNUSED(blocklen);
 449
 450#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
 451    const int8x16_t kvalues = vld1q_s8(kvalues_iq4nl);
 452    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
 453    float * res_ptr = s;
 454
 455    for (int x = 0; x < nc / ncols_interleaved; x++) {
 456        const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb);
 457
 458        float32x4_t sumf = vdupq_n_f32(0);
 459        for (int l = 0; l < nb; l++) {
 460            uint8x16_t b_0 = vld1q_u8(b_ptr[l].qs + 0);
 461            uint8x16_t b_1 = vld1q_u8(b_ptr[l].qs + 16);
 462            uint8x16_t b_2 = vld1q_u8(b_ptr[l].qs + 32);
 463            uint8x16_t b_3 = vld1q_u8(b_ptr[l].qs + 48);
 464
 465            int8x16_t b_0_hi = vqtbl1q_s8(kvalues, b_0 >> 4);
 466            int8x16_t b_0_lo = vqtbl1q_s8(kvalues, b_0 & 0x0F);
 467            int8x16_t b_1_hi = vqtbl1q_s8(kvalues, b_1 >> 4);
 468            int8x16_t b_1_lo = vqtbl1q_s8(kvalues, b_1 & 0x0F);
 469            int8x16_t b_2_hi = vqtbl1q_s8(kvalues, b_2 >> 4);
 470            int8x16_t b_2_lo = vqtbl1q_s8(kvalues, b_2 & 0x0F);
 471            int8x16_t b_3_hi = vqtbl1q_s8(kvalues, b_3 >> 4);
 472            int8x16_t b_3_lo = vqtbl1q_s8(kvalues, b_3 & 0x0F);
 473
 474            int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 0);
 475            int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16);
 476
 477            int32x4_t sumi = vdupq_n_s32(0);
 478            sumi = vdotq_laneq_s32(sumi, b_0_lo, a_0, 0);
 479            sumi = vdotq_laneq_s32(sumi, b_0_hi, a_1, 0);
 480            sumi = vdotq_laneq_s32(sumi, b_1_lo, a_0, 1);
 481            sumi = vdotq_laneq_s32(sumi, b_1_hi, a_1, 1);
 482            sumi = vdotq_laneq_s32(sumi, b_2_lo, a_0, 2);
 483            sumi = vdotq_laneq_s32(sumi, b_2_hi, a_1, 2);
 484            sumi = vdotq_laneq_s32(sumi, b_3_lo, a_0, 3);
 485            sumi = vdotq_laneq_s32(sumi, b_3_hi, a_1, 3);
 486
 487            float32x4_t a_d = vcvt_f32_f16(vld1_dup_f16((const float16_t *)&a_ptr[l].d));
 488            float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *)b_ptr[l].d));
 489            float32x4_t d = a_d * b_d;
 490
 491            sumf = vmlaq_f32(sumf, d, vcvtq_f32_s32(sumi));
 492        }
 493
 494        vst1q_f32(res_ptr + x * 4, sumf);
 495    }
 496    return;
 497#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
 498    ggml_gemv_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
 499}
 500
 501void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
 502    constexpr int qk = QK_K;
 503    const int     nb = n / qk;
 504
 505    constexpr int ncols_interleaved = 8;
 506    constexpr int blocklen          = 8;
 507
 508    assert(n % qk == 0);
 509    assert(nc % ncols_interleaved == 0);
 510
 511    UNUSED(nb);
 512    UNUSED(ncols_interleaved);
 513    UNUSED(blocklen);
 514
 515#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
 516    constexpr int    col_groups = ncols_interleaved / 4; // 0123 and 4567
 517    const uint8x16_t m4b        = vdupq_n_u8(0x0f);
 518
 519    // 1x8 tile = 2 x 4
 520    float32x4_t acc_f32[col_groups];
 521
 522    const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy;
 523
 524    for (int x = 0; x < nc / ncols_interleaved; x++) {
 525        const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);
 526
 527        for (int i = 0; i < col_groups; i++) {
 528            acc_f32[i] = vdupq_n_f32(0);
 529        }
 530
 531        for (int b = 0; b < nb; b++) {
 532            float32x4_t q4_d_0        = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d));      // d0 d1 d2 d3
 533            float32x4_t q4_d_1        = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d + 4));  // d4 d5 d6 d7
 534            float32x4_t q8_d          = vdupq_n_f32(q8_ptr[b].d);
 535            float32x4_t sb_scale_0123 = vmulq_f32(q4_d_0, q8_d);
 536            float32x4_t sb_scale_4567 = vmulq_f32(q4_d_1, q8_d);
 537            float32x4_t q4_dmin_0     = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin));      // dmin 0..3
 538            float32x4_t q4_dmin_1     = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin + 4));  // dmin 4..7
 539            float32x4_t sb_min_0123   = vmulq_f32(q4_dmin_0, q8_d);
 540            float32x4_t sb_min_4567   = vmulq_f32(q4_dmin_1, q8_d);
 541
 542            // interleaved bias_acc: [0]->r0 0123, [1]->r0 4567
 543            int32x4_t bias_acc[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };
 544            int32x4_t acc_lo[col_groups];
 545            int32x4_t acc_hi[col_groups];
 546
 547            // Each bsum is 16 elements, pairwise add leaves us with the 8 bsums of the entire block
 548            const int16x8_t bsums = vpaddq_s16(vld1q_s16(q8_ptr[b].bsums), vld1q_s16(q8_ptr[b].bsums + 8));
 549            int16_t         bsums_arr[8];
 550            vst1q_s16(bsums_arr, bsums);
 551            for (int sb = 0; sb < QK_K / 64; sb++) {
 552                for (int i = 0; i < col_groups; i++) {
 553                    acc_lo[i] = vdupq_n_s32(0);
 554                    acc_hi[i] = vdupq_n_s32(0);
 555                }
 556                // Need scales for the low and high nibbles
 557                // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
 558                int16x8_t q4sb_mins[2];
 559                int16x8_t q4sb_scales[2];
 560                for (int i = 0; i < 2; i++) {
 561                    int8_t    aux_q4sb[8];
 562                    const int offset = sb * 24 + i * 12;
 563                    decode_q_Kx8_6bit_scales(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
 564                    q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb));
 565                }
 566
 567                int8x16_t q8_qs[64 / 16];
 568                for (int i = 0; i < 64 / 16; i++) {
 569                    q8_qs[i] = vld1q_s8(q8_ptr[b].qs + sb * 64 + i * 16);
 570                }
 571
 572                for (int c = 0; c < col_groups; c++) {
 573                    uint8x16_t q4_cols[8];
 574                    for (int i = 0; i < 8; i++) {
 575                        q4_cols[i] = vld1q_u8(q4_ptr[b].qs + sb * QK_K + i * 32 + 16 * c);
 576                    }
 577
 578                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[0], m4b)), q8_qs[0], 0);
 579                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[1], m4b)), q8_qs[0], 1);
 580                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[2], m4b)), q8_qs[0], 2);
 581                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[3], m4b)), q8_qs[0], 3);
 582                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[4], m4b)), q8_qs[1], 0);
 583                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[5], m4b)), q8_qs[1], 1);
 584                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[6], m4b)), q8_qs[1], 2);
 585                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[7], m4b)), q8_qs[1], 3);
 586
 587                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[0], 4)), q8_qs[2], 0);
 588                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[1], 4)), q8_qs[2], 1);
 589                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[2], 4)), q8_qs[2], 2);
 590                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[3], 4)), q8_qs[2], 3);
 591                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[4], 4)), q8_qs[3], 0);
 592                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[5], 4)), q8_qs[3], 1);
 593                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[6], 4)), q8_qs[3], 2);
 594                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[7], 4)), q8_qs[3], 3);
 595                }
 596
 597                // Scales
 598                // row c0123 blk0 and blk1
 599                const int16x4_t   sc_0123_lo = vget_low_s16(q4sb_scales[0]);
 600                const int16x4_t   sc_0123_hi = vget_low_s16(q4sb_scales[1]);
 601                const float32x4_t sumf_0123  = vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_0123_lo), acc_lo[0]),
 602                                                                       vmulq_s32(vmovl_s16(sc_0123_hi), acc_hi[0])));
 603                acc_f32[0]                   = vfmaq_f32(acc_f32[0], sb_scale_0123, sumf_0123);
 604                // row c4567 blk0 and blk1
 605                const int16x4_t   sc_4567_lo = vget_high_s16(q4sb_scales[0]);
 606                const int16x4_t   sc_4567_hi = vget_high_s16(q4sb_scales[1]);
 607                const float32x4_t sumf_4567  = vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_4567_lo), acc_lo[1]),
 608                                                                       vmulq_s32(vmovl_s16(sc_4567_hi), acc_hi[1])));
 609                acc_f32[1]                   = vfmaq_f32(acc_f32[1], sb_scale_4567, sumf_4567);
 610
 611                // Bias Correction
 612                const int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[2 * sb + 0]);
 613                const int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[2 * sb + 1]);
 614
 615                bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_lo, vget_low_s16(q4sb_mins[0]));
 616                bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_hi, vget_low_s16(q4sb_mins[1]));
 617                bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_lo, vget_high_s16(q4sb_mins[0]));
 618                bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_hi, vget_high_s16(q4sb_mins[1]));
 619            }  // for sb
 620
 621            acc_f32[0] = vmlsq_f32(acc_f32[0], vcvtq_f32_s32(bias_acc[0]), sb_min_0123);
 622            acc_f32[1] = vmlsq_f32(acc_f32[1], vcvtq_f32_s32(bias_acc[1]), sb_min_4567);
 623        }  // for b
 624
 625        int base = x * ncols_interleaved;
 626        vst1q_f32(s + base, acc_f32[0]);
 627        vst1q_f32(s + base + 4, acc_f32[1]);
 628    }  // for x
 629    return;
 630#endif  // #if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
 631    ggml_gemv_q4_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);
 632}
 633
 634void ggml_gemv_q4_K_8x8_q8_K(int                        n,
 635                             float * GGML_RESTRICT      s,
 636                             size_t                     bs,
 637                             const void * GGML_RESTRICT vx,
 638                             const void * GGML_RESTRICT vy,
 639                             int                        nr,
 640                             int                        nc) {
 641    constexpr int qk = QK_K;
 642    const int     nb = n / qk;
 643
 644    constexpr int ncols_interleaved = 8;
 645    constexpr int blocklen          = 8;
 646
 647    assert(n % qk == 0);
 648    assert(nc % ncols_interleaved == 0);
 649
 650    UNUSED(nb);
 651    UNUSED(ncols_interleaved);
 652    UNUSED(blocklen);
 653
 654#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
 655    constexpr int    col_pairs = ncols_interleaved / 2;
 656    const uint8x16_t m4b       = vdupq_n_u8(0x0f);
 657
 658    // 1x8 tile = 2 x 4
 659    float32x4_t acc_f32[ncols_interleaved / 4];
 660
 661    const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy;
 662
 663    for (int x = 0; x < nc / ncols_interleaved; x++) {
 664        const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);
 665
 666        for (int i = 0; i < ncols_interleaved / 4; i++) {
 667            acc_f32[i] = vdupq_n_f32(0);
 668        }
 669
 670        for (int b = 0; b < nb; b++) {
 671            float32x4_t q4_d_0     = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d));      // d0 d1 d2 d3
 672            float32x4_t q4_d_1     = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d + 4));  // d4 d5 d6 d7
 673            float32x4_t q8_d       = vdupq_n_f32(q8_ptr[b].d);
 674            float32x4_t sb_scale_0 = vmulq_f32(q4_d_0, q8_d);
 675            float32x4_t sb_scale_1 = vmulq_f32(q4_d_1, q8_d);
 676            float32x4_t q4_dmin_0  = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin));      // dmin 0..3
 677            float32x4_t q4_dmin_1  = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin + 4));  // dmin 4..7
 678            float32x4_t sb_min_0   = vmulq_f32(q4_dmin_0, q8_d);
 679            float32x4_t sb_min_1   = vmulq_f32(q4_dmin_1, q8_d);
 680
 681            // interleaved bias_acc: [0]->r0 0123, [1]->r0 4567
 682            int32x4_t bias_acc[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };
 683            // 2 sb each iteration
 684            int32x4_t acc_lo[col_pairs];
 685            int32x4_t acc_hi[col_pairs];
 686
 687            // Each bsum is 16 elements, pairwise add leaves us with the 8 bsums of the entire block
 688            const int16x8_t bsums = vpaddq_s16(vld1q_s16(q8_ptr[b].bsums), vld1q_s16(q8_ptr[b].bsums + 8));
 689            int16_t         bsums_arr[8];
 690            vst1q_s16(bsums_arr, bsums);
 691            for (int sb = 0; sb < QK_K / 64; sb++) {
 692                for (int i = 0; i < col_pairs; i++) {
 693                    acc_lo[i] = vdupq_n_s32(0);
 694                    acc_hi[i] = vdupq_n_s32(0);
 695                }
 696                // Need scales for the low and high nibbles
 697                // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
 698                int16x8_t q4sb_mins[2];  // int16 as its needed for bias_acc later
 699                int16x8_t q4sb_scales[2];
 700                for (int i = 0; i < 2; i++) {
 701                    int8_t    aux_q4sb[8];
 702                    const int offset = sb * 24 + i * 12;
 703                    decode_q_Kx8_6bit_scales(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
 704                    q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb));
 705                }
 706
 707                const uint8_t * q4_base = q4_ptr[b].qs + sb * QK_K;
 708
 709                // Load the 64 quants from q8K duplicated to use vecdots with the interelaved columns
 710                // but still need the qs to use the low and hi bits from q4
 711                const int8_t * q8_base = q8_ptr[b].qs + sb * 64;
 712                int8x16_t      q8_qs[8];
 713                for (int i = 0; i < 8; i++) {
 714                    q8_qs[i] = (int8x16_t) vld1q_dup_s64((const int64_t *) (q8_base + i * 8));
 715                }
 716
 717                // Q4s columns iterated in pairs (01, 23, 45, 67)
 718                for (int cp = 0; cp < col_pairs; cp++) {
 719                    uint8x16_t q4_qs_cp_0 = vld1q_u8(q4_base + 16 * cp);
 720                    uint8x16_t q4_qs_cp_1 = vld1q_u8(q4_base + 16 * cp + 64);
 721                    uint8x16_t q4_qs_cp_2 = vld1q_u8(q4_base + 16 * cp + 128);
 722                    uint8x16_t q4_qs_cp_3 = vld1q_u8(q4_base + 16 * cp + 192);
 723
 724                    acc_lo[cp] =
 725                        ggml_vdotq_s32(acc_lo[cp], vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_0, m4b)), q8_qs[0]);  // 0 .. 7
 726                    acc_lo[cp] =
 727                        ggml_vdotq_s32(acc_lo[cp], vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_1, m4b)), q8_qs[1]);  // 8 ..15
 728                    acc_lo[cp] =
 729                        ggml_vdotq_s32(acc_lo[cp], vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_2, m4b)), q8_qs[2]);  // 16..23
 730                    acc_lo[cp] =
 731                        ggml_vdotq_s32(acc_lo[cp], vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_3, m4b)), q8_qs[3]);  // 24..31
 732
 733                    acc_hi[cp] =
 734                        ggml_vdotq_s32(acc_hi[cp], vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_0, 4)), q8_qs[4]);  // 32..39
 735                    acc_hi[cp] =
 736                        ggml_vdotq_s32(acc_hi[cp], vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_1, 4)), q8_qs[5]);  // 40..47
 737                    acc_hi[cp] =
 738                        ggml_vdotq_s32(acc_hi[cp], vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_2, 4)), q8_qs[6]);  // 48..55
 739                    acc_hi[cp] =
 740                        ggml_vdotq_s32(acc_hi[cp], vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_3, 4)), q8_qs[7]);  // 56..63
 741                }
 742
 743                // Iterates over a pair of column pairs (4 columns) to use a single 128 register
 744                // p = 0 -> 0123  p2 -> 4567
 745                for (int i = 0, p = 0; p < col_pairs; i++, p += 2) {
 746                    int16x4_t   group_scales_lo = p == 0 ? vget_low_s16(q4sb_scales[0]) : vget_high_s16(q4sb_scales[0]);
 747                    int16x4_t   group_scales_hi = p == 0 ? vget_low_s16(q4sb_scales[1]) : vget_high_s16(q4sb_scales[1]);
 748                    float32x4_t sb_scale        = p == 0 ? sb_scale_0 : sb_scale_1;
 749
 750                    // 0123 or 4567
 751                    float32x4_t sumf_0 =
 752                        vcvtq_f32_s32(vmulq_s32(vmovl_s16(group_scales_lo), vpaddq_s32(acc_lo[p], acc_lo[p + 1])));
 753                    acc_f32[i] = vfmaq_f32(acc_f32[i], sb_scale, sumf_0);
 754
 755                    float32x4_t sumf_1 =
 756                        vcvtq_f32_s32(vmulq_s32(vmovl_s16(group_scales_hi), vpaddq_s32(acc_hi[p], acc_hi[p + 1])));
 757                    acc_f32[i] = vfmaq_f32(acc_f32[i], sb_scale, sumf_1);
 758                }
 759
 760                // Multiply Acc bsum + mins
 761                // Each pair of subblocks share the same bsums
 762                // Load scalar bsum → broadcast to a vector (vdupq_n_s16(s)).
 763                int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[2 * sb + 0]);
 764                int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[2 * sb + 1]);
 765
 766                // cols 0-3 bias
 767                bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_lo, vget_low_s16(q4sb_mins[0]));
 768                bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_hi, vget_low_s16(q4sb_mins[1]));
 769
 770                // cols 4-7 bias
 771                bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_lo, vget_high_s16(q4sb_mins[0]));
 772                bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_hi, vget_high_s16(q4sb_mins[1]));
 773            }  // for sb
 774
 775            acc_f32[0] = vmlsq_f32(acc_f32[0], vcvtq_f32_s32(bias_acc[0]), sb_min_0);
 776            acc_f32[1] = vmlsq_f32(acc_f32[1], vcvtq_f32_s32(bias_acc[1]), sb_min_1);
 777        }  // for b
 778
 779        int base = x * ncols_interleaved;
 780        vst1q_f32(s + base, acc_f32[0]);
 781        vst1q_f32(s + base + 4, acc_f32[1]);
 782    }  // for x
 783    return;
 784#endif  // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
 785    ggml_gemv_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
 786}
 787
 788void ggml_gemv_q5_K_8x8_q8_K(int                        n,
 789                             float * GGML_RESTRICT      s,
 790                             size_t                     bs,
 791                             const void * GGML_RESTRICT vx,
 792                             const void * GGML_RESTRICT vy,
 793                             int                        nr,
 794                             int                        nc) {
 795    constexpr int qk = QK_K;
 796    const int     nb = n / qk;
 797
 798    constexpr int ncols_interleaved = 8;
 799    constexpr int blocklen          = 8;
 800
 801    assert(n % qk == 0);
 802    assert(nc % ncols_interleaved == 0);
 803
 804    UNUSED(nb);
 805    UNUSED(ncols_interleaved);
 806    UNUSED(blocklen);
 807
 808#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
 809    constexpr int    col_pairs = ncols_interleaved / 2;
 810    const uint8x16_t m4b       = vdupq_n_u8(0x0f);
 811    const uint8x16_t mone      = vdupq_n_u8(1);
 812    const uint8x16_t mtwo      = vdupq_n_u8(2);
 813
 814    // 1x8 tile = 2 x 4
 815    float32x4_t acc_f32[ncols_interleaved / 4];
 816
 817    const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy;
 818
 819    for (int x = 0; x < nc / ncols_interleaved; x++) {
 820        const block_q5_Kx8 * GGML_RESTRICT q5_ptr = (const block_q5_Kx8 *) vx + (x * nb);
 821
 822        for (int i = 0; i < ncols_interleaved / 4; i++) {
 823            acc_f32[i] = vdupq_n_f32(0);
 824        }
 825
 826        for (int b = 0; b < nb; b++) {
 827            float32x4_t q5_d_0     = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d));      // d0 d1 d2 d3
 828            float32x4_t q5_d_1     = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d + 4));  // d4 d5 d6 d7
 829            float32x4_t q8_d       = vdupq_n_f32(q8_ptr[b].d);
 830            float32x4_t sb_scale_0 = vmulq_f32(q5_d_0, q8_d);
 831            float32x4_t sb_scale_1 = vmulq_f32(q5_d_1, q8_d);
 832            float32x4_t q5_dmin_0  = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin));      // dmin 0..3
 833            float32x4_t q5_dmin_1  = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin + 4));  // dmin 4..7
 834            float32x4_t sb_min_0   = vmulq_f32(q5_dmin_0, q8_d);
 835            float32x4_t sb_min_1   = vmulq_f32(q5_dmin_1, q8_d);
 836
 837            // 2 sb each iteration
 838            int32x4_t acc_lo[col_pairs];
 839            int32x4_t acc_hi[col_pairs];
 840
 841            // Each bsum is 16 elements, pairwise add leaves us with the 8 bsums of the entire block
 842            const int16x8_t bsums = vpaddq_s16(vld1q_s16(q8_ptr[b].bsums), vld1q_s16(q8_ptr[b].bsums + 8));
 843            int16_t         bsums_arr[8];
 844            vst1q_s16(bsums_arr, bsums);
 845
 846            // Load qh once per block and shift after each subblock
 847            const uint8_t * qh_base = q5_ptr[b].qh;
 848            uint8x16_t      qh[col_pairs][4];
 849            for (int cp = 0; cp < col_pairs; cp++) {
 850                qh[cp][0] = vld1q_u8(qh_base + 16 * cp);
 851                qh[cp][1] = vld1q_u8(qh_base + 16 * cp + 64);
 852                qh[cp][2] = vld1q_u8(qh_base + 16 * cp + 128);
 853                qh[cp][3] = vld1q_u8(qh_base + 16 * cp + 192);
 854            }
 855
 856            for (int sb = 0; sb < QK_K / 64; sb++) {
 857                for (int i = 0; i < col_pairs; i++) {
 858                    acc_lo[i] = vdupq_n_s32(0);
 859                    acc_hi[i] = vdupq_n_s32(0);
 860                }
 861                // Need scales for the low and high nibbles
 862                // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
 863                int16x8_t q5sb_mins[2];  // int16 as its needed for bias_acc later
 864                int16x8_t q5sb_scales[2];
 865                for (int i = 0; i < 2; i++) {
 866                    int8_t    aux_q5sb[8];
 867                    const int offset = sb * 24 + i * 12;
 868                    decode_q_Kx8_6bit_scales(&q5_ptr[b].scales[offset], &q5sb_mins[i], aux_q5sb);
 869                    q5sb_scales[i] = vmovl_s8(vld1_s8(aux_q5sb));
 870                }
 871
 872                const uint8_t * qs_base = q5_ptr[b].qs + sb * QK_K;
 873
 874                // Load the 64 quants from q8K duplicated to use vecdots with the interleaved columns
 875                const int8_t * q8_base = q8_ptr[b].qs + sb * 64;
 876                int8x16_t      q8_qs[8];
 877                for (int i = 0; i < 8; i++) {
 878                    q8_qs[i] = (int8x16_t) vld1q_dup_s64((const int64_t *) (q8_base + i * 8));
 879                }
 880
 881                // Q5s column pair loop unrolled
 882                {
 883                    // Cols 01
 884                    uint8x16_t qs_0 = vld1q_u8(qs_base);
 885                    uint8x16_t qs_1 = vld1q_u8(qs_base + 64);
 886                    uint8x16_t qs_2 = vld1q_u8(qs_base + 128);
 887                    uint8x16_t qs_3 = vld1q_u8(qs_base + 192);
 888
 889                    uint8x16_t hbit_lo_0 = vandq_u8(qh[0][0], mone);
 890                    uint8x16_t hbit_lo_1 = vandq_u8(qh[0][1], mone);
 891                    uint8x16_t hbit_lo_2 = vandq_u8(qh[0][2], mone);
 892                    uint8x16_t hbit_lo_3 = vandq_u8(qh[0][3], mone);
 893                    uint8x16_t hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[0][0], mtwo), 3);
 894                    uint8x16_t hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[0][1], mtwo), 3);
 895                    uint8x16_t hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[0][2], mtwo), 3);
 896                    uint8x16_t hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[0][3], mtwo), 3);
 897
 898                    qh[0][0] = vshrq_n_u8(qh[0][0], 2);
 899                    qh[0][1] = vshrq_n_u8(qh[0][1], 2);
 900                    qh[0][2] = vshrq_n_u8(qh[0][2], 2);
 901                    qh[0][3] = vshrq_n_u8(qh[0][3], 2);
 902
 903                    acc_lo[0] = ggml_vdotq_s32(
 904                        acc_lo[0], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_0, m4b), hbit_lo_0, 4)), q8_qs[0]);
 905                    acc_lo[0] = ggml_vdotq_s32(
 906                        acc_lo[0], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_1, m4b), hbit_lo_1, 4)), q8_qs[1]);
 907                    acc_lo[0] = ggml_vdotq_s32(
 908                        acc_lo[0], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_2, m4b), hbit_lo_2, 4)), q8_qs[2]);
 909                    acc_lo[0] = ggml_vdotq_s32(
 910                        acc_lo[0], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_3, m4b), hbit_lo_3, 4)), q8_qs[3]);
 911                    acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)),
 912                                               q8_qs[4]);
 913                    acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)),
 914                                               q8_qs[5]);
 915                    acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)),
 916                                               q8_qs[6]);
 917                    acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)),
 918                                               q8_qs[7]);
 919
 920                    // Cols 23
 921                    qs_0 = vld1q_u8(qs_base + 16);
 922                    qs_1 = vld1q_u8(qs_base + 80);
 923                    qs_2 = vld1q_u8(qs_base + 144);
 924                    qs_3 = vld1q_u8(qs_base + 208);
 925
 926                    hbit_lo_0 = vandq_u8(qh[1][0], mone);
 927                    hbit_lo_1 = vandq_u8(qh[1][1], mone);
 928                    hbit_lo_2 = vandq_u8(qh[1][2], mone);
 929                    hbit_lo_3 = vandq_u8(qh[1][3], mone);
 930                    hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[1][0], mtwo), 3);
 931                    hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[1][1], mtwo), 3);
 932                    hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[1][2], mtwo), 3);
 933                    hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[1][3], mtwo), 3);
 934
 935                    qh[1][0] = vshrq_n_u8(qh[1][0], 2);
 936                    qh[1][1] = vshrq_n_u8(qh[1][1], 2);
 937                    qh[1][2] = vshrq_n_u8(qh[1][2], 2);
 938                    qh[1][3] = vshrq_n_u8(qh[1][3], 2);
 939
 940                    acc_lo[1] = ggml_vdotq_s32(
 941                        acc_lo[1], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_0, m4b), hbit_lo_0, 4)), q8_qs[0]);
 942                    acc_lo[1] = ggml_vdotq_s32(
 943                        acc_lo[1], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_1, m4b), hbit_lo_1, 4)), q8_qs[1]);
 944                    acc_lo[1] = ggml_vdotq_s32(
 945                        acc_lo[1], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_2, m4b), hbit_lo_2, 4)), q8_qs[2]);
 946                    acc_lo[1] = ggml_vdotq_s32(
 947                        acc_lo[1], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_3, m4b), hbit_lo_3, 4)), q8_qs[3]);
 948                    acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)),
 949                                               q8_qs[4]);
 950                    acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)),
 951                                               q8_qs[5]);
 952                    acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)),
 953                                               q8_qs[6]);
 954                    acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)),
 955                                               q8_qs[7]);
 956
 957                    // Cols 45
 958                    qs_0 = vld1q_u8(qs_base + 32);
 959                    qs_1 = vld1q_u8(qs_base + 96);
 960                    qs_2 = vld1q_u8(qs_base + 160);
 961                    qs_3 = vld1q_u8(qs_base + 224);
 962
 963                    hbit_lo_0 = vandq_u8(qh[2][0], mone);
 964                    hbit_lo_1 = vandq_u8(qh[2][1], mone);
 965                    hbit_lo_2 = vandq_u8(qh[2][2], mone);
 966                    hbit_lo_3 = vandq_u8(qh[2][3], mone);
 967                    hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[2][0], mtwo), 3);
 968                    hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[2][1], mtwo), 3);
 969                    hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[2][2], mtwo), 3);
 970                    hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[2][3], mtwo), 3);
 971
 972                    qh[2][0] = vshrq_n_u8(qh[2][0], 2);
 973                    qh[2][1] = vshrq_n_u8(qh[2][1], 2);
 974                    qh[2][2] = vshrq_n_u8(qh[2][2], 2);
 975                    qh[2][3] = vshrq_n_u8(qh[2][3], 2);
 976
 977                    acc_lo[2] = ggml_vdotq_s32(
 978                        acc_lo[2], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_0, m4b), hbit_lo_0, 4)), q8_qs[0]);
 979                    acc_lo[2] = ggml_vdotq_s32(
 980                        acc_lo[2], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_1, m4b), hbit_lo_1, 4)), q8_qs[1]);
 981                    acc_lo[2] = ggml_vdotq_s32(
 982                        acc_lo[2], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_2, m4b), hbit_lo_2, 4)), q8_qs[2]);
 983                    acc_lo[2] = ggml_vdotq_s32(
 984                        acc_lo[2], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_3, m4b), hbit_lo_3, 4)), q8_qs[3]);
 985                    acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)),
 986                                               q8_qs[4]);
 987                    acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)),
 988                                               q8_qs[5]);
 989                    acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)),
 990                                               q8_qs[6]);
 991                    acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)),
 992                                               q8_qs[7]);
 993
 994                    // Cols 45
 995                    qs_0 = vld1q_u8(qs_base + 48);
 996                    qs_1 = vld1q_u8(qs_base + 112);
 997                    qs_2 = vld1q_u8(qs_base + 176);
 998                    qs_3 = vld1q_u8(qs_base + 240);
 999
1000                    hbit_lo_0 = vandq_u8(qh[3][0], mone);
1001                    hbit_lo_1 = vandq_u8(qh[3][1], mone);
1002                    hbit_lo_2 = vandq_u8(qh[3][2], mone);
1003                    hbit_lo_3 = vandq_u8(qh[3][3], mone);
1004                    hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[3][0], mtwo), 3);
1005                    hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[3][1], mtwo), 3);
1006                    hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[3][2], mtwo), 3);
1007                    hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[3][3], mtwo), 3);
1008
1009                    qh[3][0] = vshrq_n_u8(qh[3][0], 2);
1010                    qh[3][1] = vshrq_n_u8(qh[3][1], 2);
1011                    qh[3][2] = vshrq_n_u8(qh[3][2], 2);
1012                    qh[3][3] = vshrq_n_u8(qh[3][3], 2);
1013
1014                    acc_lo[3] = ggml_vdotq_s32(
1015                        acc_lo[3], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_0, m4b), hbit_lo_0, 4)), q8_qs[0]);
1016                    acc_lo[3] = ggml_vdotq_s32(
1017                        acc_lo[3], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_1, m4b), hbit_lo_1, 4)), q8_qs[1]);
1018                    acc_lo[3] = ggml_vdotq_s32(
1019                        acc_lo[3], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_2, m4b), hbit_lo_2, 4)), q8_qs[2]);
1020                    acc_lo[3] = ggml_vdotq_s32(
1021                        acc_lo[3], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_3, m4b), hbit_lo_3, 4)), q8_qs[3]);
1022                    acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)),
1023                                               q8_qs[4]);
1024                    acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)),
1025                                               q8_qs[5]);
1026                    acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)),
1027                                               q8_qs[6]);
1028                    acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)),
1029                                               q8_qs[7]);
1030                }
1031
1032                // Prepare bsum vectors for bias computation
1033                // Each pair of subblocks share the same bsums
1034                int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[2 * sb + 0]);
1035                int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[2 * sb + 1]);
1036
1037                // Iterates over a pair of column pairs (4 columns) to use a single 128 register
1038                // p = 0 -> 0123  p2 -> 4567
1039                for (int i = 0, p = 0; p < col_pairs; i++, p += 2) {
1040                    int16x4_t   group_scales_lo = p == 0 ? vget_low_s16(q5sb_scales[0]) : vget_high_s16(q5sb_scales[0]);
1041                    int16x4_t   group_scales_hi = p == 0 ? vget_low_s16(q5sb_scales[1]) : vget_high_s16(q5sb_scales[1]);
1042                    int16x4_t   group_mins_lo   = p == 0 ? vget_low_s16(q5sb_mins[0]) : vget_high_s16(q5sb_mins[0]);
1043                    int16x4_t   group_mins_hi   = p == 0 ? vget_low_s16(q5sb_mins[1]) : vget_high_s16(q5sb_mins[1]);
1044                    float32x4_t sb_scale        = p == 0 ? sb_scale_0 : sb_scale_1;
1045                    float32x4_t sb_min          = p == 0 ? sb_min_0 : sb_min_1;
1046
1047                    // 0123 or 4567
1048                    float32x4_t sumf_0 =
1049                        vcvtq_f32_s32(vmulq_s32(vmovl_s16(group_scales_lo), vpaddq_s32(acc_lo[p], acc_lo[p + 1])));
1050                    acc_f32[i] = vfmaq_f32(acc_f32[i], sb_scale, sumf_0);
1051
1052                    float32x4_t sumf_1 =
1053                        vcvtq_f32_s32(vmulq_s32(vmovl_s16(group_scales_hi), vpaddq_s32(acc_hi[p], acc_hi[p + 1])));
1054                    acc_f32[i] = vfmaq_f32(acc_f32[i], sb_scale, sumf_1);
1055
1056                    // FUSED BIAS: Compute and subtract bias immediately
1057                    // bias = (bsums_lo * mins_lo + bsums_hi * mins_hi) * sb_min
1058                    int32x4_t bias       = vmull_s16(bsums_vec_lo, group_mins_lo);
1059                    bias                 = vmlal_s16(bias, bsums_vec_hi, group_mins_hi);
1060                    float32x4_t bias_f32 = vcvtq_f32_s32(bias);
1061                    acc_f32[i]           = vmlsq_f32(acc_f32[i], sb_min, bias_f32);
1062                }
1063            }  // for sb
1064        }  // for b
1065
1066        int base = x * ncols_interleaved;
1067        vst1q_f32(s + base, acc_f32[0]);
1068        vst1q_f32(s + base + 4, acc_f32[1]);
1069    }  // for x
1070    return;
1071#endif  // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
1072    ggml_gemv_q5_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
1073}
1074
1075void ggml_gemv_q6_K_8x4_q8_K(int                        n,
1076                             float * GGML_RESTRICT      s,
1077                             size_t                     bs,
1078                             const void * GGML_RESTRICT vx,
1079                             const void * GGML_RESTRICT vy,
1080                             int                        nr,
1081                             int                        nc) {
1082    constexpr int qk = QK_K;
1083    const int     nb = n / qk;
1084
1085    constexpr int ncols_interleaved = 8;
1086    constexpr int blocklen          = 4;
1087
1088    assert(n % qk == 0);
1089    assert(nc % ncols_interleaved == 0);
1090
1091    UNUSED(nb);
1092    UNUSED(ncols_interleaved);
1093    UNUSED(blocklen);
1094
1095#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
1096    constexpr int    col_groups = ncols_interleaved / 4;
1097    const uint8x16_t m4b        = vdupq_n_u8(0x0f);
1098    const uint8x16_t mask_lo    = vdupq_n_u8(0x03);
1099    const uint8x16_t mask_hi    = vdupq_n_u8(0x30);
1100
1101    // 1x8 tile = 2 x 4
1102    float32x4_t acc_f32[2];
1103
1104    const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy;
1105
1106    for (int x = 0; x < nc / ncols_interleaved; x++) {
1107        const block_q6_Kx8 * GGML_RESTRICT q6_ptr = (const block_q6_Kx8 *) vx + (x * nb);
1108
1109        for (int i = 0; i < col_groups; i++) {
1110            acc_f32[i] = vdupq_n_f32(0);
1111        }
1112
1113        for (int b = 0; b < nb; b++) {
1114            float32x4_t q6_d_0     = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d));      // d0 d1 d2 d3
1115            float32x4_t q6_d_1     = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d + 4));  // d4 d5 d6 d7
1116            float32x4_t q8_d       = vdupq_n_f32(q8_ptr[b].d);
1117            float32x4_t sb_scale_0 = vmulq_f32(q6_d_0, q8_d);
1118            float32x4_t sb_scale_1 = vmulq_f32(q6_d_1, q8_d);
1119
1120            int32x4_t acc[col_groups];
1121            for (int i = 0; i < col_groups; i++) {
1122                acc[i] = vdupq_n_s32(0);
1123            }
1124
1125            // Load all 16 scales once and widen to int16 (Q6_K has 16 scales per block)
1126            // Reused for bias and dequantization later
1127            int16_t q6_scales[16 * 8];
1128            for (int i = 0; i < 16; i++) {
1129                int16x8_t scales = vmovl_s8(vld1_s8(q6_ptr[b].scales + i * 8));
1130                vst1q_s16(q6_scales + i * 8, scales);
1131            }
1132
1133            // Compute bias per column using q8 bsums and preloaded scales to skip the -32 shift
1134            int32x4_t bias_lo = vdupq_n_s32(0);
1135            int32x4_t bias_hi = vdupq_n_s32(0);
1136
1137            // Load bsums in chunks of 4 to process with vectorized operations
1138            for (int i = 0; i < 16; i += 4) {
1139                int16x4_t bsums_vec   = vld1_s16(q8_ptr[b].bsums + i);
1140                int16x4_t scales_lo_0 = vld1_s16(q6_scales + (i + 0) * 8);
1141                int16x4_t scales_hi_0 = vld1_s16(q6_scales + (i + 0) * 8 + 4);
1142                int16x4_t scales_lo_1 = vld1_s16(q6_scales + (i + 1) * 8);
1143                int16x4_t scales_hi_1 = vld1_s16(q6_scales + (i + 1) * 8 + 4);
1144                int16x4_t scales_lo_2 = vld1_s16(q6_scales + (i + 2) * 8);
1145                int16x4_t scales_hi_2 = vld1_s16(q6_scales + (i + 2) * 8 + 4);
1146                int16x4_t scales_lo_3 = vld1_s16(q6_scales + (i + 3) * 8);
1147                int16x4_t scales_hi_3 = vld1_s16(q6_scales + (i + 3) * 8 + 4);
1148
1149                bias_lo = vmlal_lane_s16(bias_lo, scales_lo_0, bsums_vec, 0);
1150                bias_hi = vmlal_lane_s16(bias_hi, scales_hi_0, bsums_vec, 0);
1151                bias_lo = vmlal_lane_s16(bias_lo, scales_lo_1, bsums_vec, 1);
1152                bias_hi = vmlal_lane_s16(bias_hi, scales_hi_1, bsums_vec, 1);
1153                bias_lo = vmlal_lane_s16(bias_lo, scales_lo_2, bsums_vec, 2);
1154                bias_hi = vmlal_lane_s16(bias_hi, scales_hi_2, bsums_vec, 2);
1155                bias_lo = vmlal_lane_s16(bias_lo, scales_lo_3, bsums_vec, 3);
1156                bias_hi = vmlal_lane_s16(bias_hi, scales_hi_3, bsums_vec, 3);
1157            }
1158            bias_lo = vshlq_n_s32(bias_lo, 5);
1159            bias_hi = vshlq_n_s32(bias_hi, 5);
1160
1161            // Process two 128-value halves per superblock
1162            for (int half = 0; half < 2; half++) {
1163                const uint8_t * ql_base = q6_ptr[b].ql + half * 512;
1164                const uint8_t * qh_base = q6_ptr[b].qh + half * 256;
1165
1166                // A subblock (sb) is a set of weights that share the scale
1167                // Since q6_K scales are per 16 elements
1168                // num sbs -> 256 elements / (16 elements/scale * 2 elements/byte * 2 halves)
1169                for (int sb = 0; sb < QK_K / 64; sb++) {
1170                    const int8_t * q8_base_l = q8_ptr[b].qs + half * 128 + sb * 16;
1171                    const int8_t * q8_base_h = q8_base_l + 64;
1172
1173                    // Load and duplicate q8 values (each register covers four interleaved columns of q6)
1174                    int8x16_t q8_l[4];
1175                    int8x16_t q8_h[4];
1176                    for (int i = 0; i < 4; i++) {
1177                        q8_l[i] = (int8x16_t) vld1q_dup_s32((const int32_t *) (q8_base_l + i * 4));
1178                        q8_h[i] = (int8x16_t) vld1q_dup_s32((const int32_t *) (q8_base_h + i * 4));
1179                    }
1180
1181                    const int ql_off_base = sb * QK_K / 2;
1182                    const int qh_off_base = ql_off_base & 255;  // wraps after 256 bytes
1183
1184                    // Load 4 vectors at once (64 bytes each for ql_0, ql_1, qh_0, qh_1)
1185                    uint8x16x4_t q6_ql_0 = vld1q_u8_x4(ql_base + ql_off_base);
1186                    uint8x16x4_t q6_ql_1 = vld1q_u8_x4(ql_base + ql_off_base + 64);
1187                    uint8x16x4_t q6_qh_0 = vld1q_u8_x4(qh_base + qh_off_base);
1188                    uint8x16x4_t q6_qh_1 = vld1q_u8_x4(qh_base + qh_off_base + 64);
1189
1190                    // Adjust qh for subblocks 2 and 3 (shift right by 2)
1191                    if (sb > 1) {
1192                        q6_qh_0.val[0] = vshrq_n_u8(q6_qh_0.val[0], 2);
1193                        q6_qh_0.val[1] = vshrq_n_u8(q6_qh_0.val[1], 2);
1194                        q6_qh_0.val[2] = vshrq_n_u8(q6_qh_0.val[2], 2);
1195                        q6_qh_0.val[3] = vshrq_n_u8(q6_qh_0.val[3], 2);
1196                        q6_qh_1.val[0] = vshrq_n_u8(q6_qh_1.val[0], 2);
1197                        q6_qh_1.val[1] = vshrq_n_u8(q6_qh_1.val[1], 2);
1198                        q6_qh_1.val[2] = vshrq_n_u8(q6_qh_1.val[2], 2);
1199                        q6_qh_1.val[3] = vshrq_n_u8(q6_qh_1.val[3], 2);
1200                    }
1201
1202                    const uint8x16_t q6_ql[8] = { q6_ql_0.val[0], q6_ql_0.val[1], q6_ql_0.val[2], q6_ql_0.val[3],
1203                                                  q6_ql_1.val[0], q6_ql_1.val[1], q6_ql_1.val[2], q6_ql_1.val[3] };
1204                    const uint8x16_t q6_qh[8] = { q6_qh_0.val[0], q6_qh_0.val[1], q6_qh_0.val[2], q6_qh_0.val[3],
1205                                                  q6_qh_1.val[0], q6_qh_1.val[1], q6_qh_1.val[2], q6_qh_1.val[3] };
1206
1207                    // Process column groups (0-3, 4-7)
1208                    for (int g = 0; g < col_groups; g++) {
1209                        int32x4_t sb_acc_l = vdupq_n_s32(0);
1210                        int32x4_t sb_acc_h = vdupq_n_s32(0);
1211
1212                        for (int chunk = 0; chunk < 4; chunk++) {
1213                            const int idx = chunk * 2 + g;
1214
1215                            const uint8x16_t q6_qs_l = q6_ql[idx];
1216                            const uint8x16_t q6_qs_h = q6_qh[idx];
1217
1218                            // Extract high 2 bits for upper nibble reconstruction
1219                            const uint8x16_t q6_qs_hh = vandq_u8(q6_qs_h, mask_hi);
1220
1221                            // q6 = (low4 | high2<<4), without -32 bias (handled via bsums)
1222                            const int8x16_t q6_l =
1223                                vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_qs_l, m4b), vandq_u8(q6_qs_h, mask_lo), 4));
1224                            const int8x16_t q6_h = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_qs_l, 4), q6_qs_hh));
1225
1226                            sb_acc_l = vdotq_s32(sb_acc_l, q6_l, q8_l[chunk]);
1227                            sb_acc_h = vdotq_s32(sb_acc_h, q6_h, q8_h[chunk]);
1228                        }
1229
1230                        const int scale_idx_l = half * 8 + sb;
1231                        const int scale_idx_h = half * 8 + sb + 4;
1232
1233                        const int32x4_t scale_vec_l = vmovl_s16(vld1_s16(q6_scales + scale_idx_l * 8 + g * 4));
1234                        const int32x4_t scale_vec_h = vmovl_s16(vld1_s16(q6_scales + scale_idx_h * 8 + g * 4));
1235
1236                        acc[g] = vmlaq_s32(acc[g], sb_acc_l, scale_vec_l);
1237                        acc[g] = vmlaq_s32(acc[g], sb_acc_h, scale_vec_h);
1238                    }
1239                }
1240            }  // for half
1241
1242            // Bias correction
1243            acc[0] = vsubq_s32(acc[0], bias_lo);
1244            acc[1] = vsubq_s32(acc[1], bias_hi);
1245
1246            // Apply superblock scale (no mins for q6_K)
1247            // acc[g] has [c0, c1, c2, c3]
1248            float32x4_t w_0123 = vmulq_f32(vcvtq_f32_s32(acc[0]), sb_scale_0);
1249            float32x4_t w_4567 = vmulq_f32(vcvtq_f32_s32(acc[1]), sb_scale_1);
1250
1251            acc_f32[0] = vaddq_f32(acc_f32[0], w_0123);
1252            acc_f32[1] = vaddq_f32(acc_f32[1], w_4567);
1253        }  // for b
1254
1255        int base = x * ncols_interleaved;
1256        vst1q_f32(s + base, acc_f32[0]);
1257        vst1q_f32(s + base + 4, acc_f32[1]);
1258    }  // for x
1259    return;
1260#endif  // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
1261    ggml_gemv_q6_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);
1262}
1263
1264void ggml_gemv_q6_K_8x8_q8_K(int                        n,
1265                             float * GGML_RESTRICT      s,
1266                             size_t                     bs,
1267                             const void * GGML_RESTRICT vx,
1268                             const void * GGML_RESTRICT vy,
1269                             int                        nr,
1270                             int                        nc) {
1271    constexpr int qk = QK_K;
1272    const int     nb = n / qk;
1273
1274    constexpr int ncols_interleaved = 8;
1275    constexpr int blocklen          = 8;
1276
1277    assert(n % qk == 0);
1278    assert(nc % ncols_interleaved == 0);
1279
1280    UNUSED(nb);
1281    UNUSED(ncols_interleaved);
1282    UNUSED(blocklen);
1283
1284#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
1285    constexpr int    col_pairs = ncols_interleaved / 2;
1286    const uint8x16_t m4b       = vdupq_n_u8(0x0f);
1287    const uint8x16_t mask_lo   = vdupq_n_u8(0x03);
1288    const uint8x16_t mask_hi   = vdupq_n_u8(0x30);
1289
1290    // 1x8 tile = 2 x 4
1291    float32x4_t acc_f32[2];
1292
1293    const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy;
1294
1295    for (int x = 0; x < nc / ncols_interleaved; x++) {
1296        const block_q6_Kx8 * GGML_RESTRICT q6_ptr = (const block_q6_Kx8 *) vx + (x * nb);
1297
1298        acc_f32[0] = vdupq_n_f32(0);
1299        acc_f32[1] = vdupq_n_f32(0);
1300
1301        for (int b = 0; b < nb; b++) {
1302            float32x4_t q6_d_0     = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d));      // d0 d1 d2 d3
1303            float32x4_t q6_d_1     = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d + 4));  // d4 d5 d6 d7
1304            float32x4_t q8_d       = vdupq_n_f32(q8_ptr[b].d);
1305            float32x4_t sb_scale_0 = vmulq_f32(q6_d_0, q8_d);
1306            float32x4_t sb_scale_1 = vmulq_f32(q6_d_1, q8_d);
1307
1308            int32x2_t acc[col_pairs];
1309            for (int i = 0; i < col_pairs; i++) {
1310                acc[i] = vdup_n_s32(0);
1311            }
1312
1313            // Load all 16 scales once and widen to int16 (Q6_K has 16 scales per block)
1314            // Reused for bias and dequantization later
1315            int16_t q6_scales[16 * 8];
1316            for (int i = 0; i < 16; i++) {
1317                int16x8_t scales = vmovl_s8(vld1_s8(q6_ptr[b].scales + i * 8));
1318                vst1q_s16(q6_scales + i * 8, scales);
1319            }
1320
1321            // Compute bias per column using q8 bsums and preloaded scales to skip the -32 shift
1322            int32x4_t bias_lo = vdupq_n_s32(0);
1323            int32x4_t bias_hi = vdupq_n_s32(0);
1324
1325            // Load bsums in chunks of 4 to process with vectorized operations
1326            for (int i = 0; i < 16; i += 4) {
1327                int16x4_t bsums_vec   = vld1_s16(q8_ptr[b].bsums + i);
1328                int16x4_t scales_lo_0 = vld1_s16(q6_scales + (i + 0) * 8);
1329                int16x4_t scales_hi_0 = vld1_s16(q6_scales + (i + 0) * 8 + 4);
1330                int16x4_t scales_lo_1 = vld1_s16(q6_scales + (i + 1) * 8);
1331                int16x4_t scales_hi_1 = vld1_s16(q6_scales + (i + 1) * 8 + 4);
1332                int16x4_t scales_lo_2 = vld1_s16(q6_scales + (i + 2) * 8);
1333                int16x4_t scales_hi_2 = vld1_s16(q6_scales + (i + 2) * 8 + 4);
1334                int16x4_t scales_lo_3 = vld1_s16(q6_scales + (i + 3) * 8);
1335                int16x4_t scales_hi_3 = vld1_s16(q6_scales + (i + 3) * 8 + 4);
1336
1337                bias_lo = vmlal_lane_s16(bias_lo, scales_lo_0, bsums_vec, 0);
1338                bias_hi = vmlal_lane_s16(bias_hi, scales_hi_0, bsums_vec, 0);
1339                bias_lo = vmlal_lane_s16(bias_lo, scales_lo_1, bsums_vec, 1);
1340                bias_hi = vmlal_lane_s16(bias_hi, scales_hi_1, bsums_vec, 1);
1341                bias_lo = vmlal_lane_s16(bias_lo, scales_lo_2, bsums_vec, 2);
1342                bias_hi = vmlal_lane_s16(bias_hi, scales_hi_2, bsums_vec, 2);
1343                bias_lo = vmlal_lane_s16(bias_lo, scales_lo_3, bsums_vec, 3);
1344                bias_hi = vmlal_lane_s16(bias_hi, scales_hi_3, bsums_vec, 3);
1345            }
1346            bias_lo = vshlq_n_s32(bias_lo, 5);
1347            bias_hi = vshlq_n_s32(bias_hi, 5);
1348
1349            // Process two 128-value halves per superblock
1350            for (int half = 0; half < 2; half++) {
1351                const uint8_t * ql_base = q6_ptr[b].ql + half * 512;
1352                const uint8_t * qh_base = q6_ptr[b].qh + half * 256;
1353
1354                // A subblock (sb) is a set of weights that share the scale
1355                // Since q6_K scales are per 16 elements
1356                // num sbs -> 256 elements / (16 elements/scale * 2 elements/byte * 2 halves)
1357                for (int sb = 0; sb < QK_K / 64; sb++) {
1358                    const int8_t * q8_base_l = q8_ptr[b].qs + half * 128 + sb * 16;
1359                    const int8_t * q8_base_h = q8_base_l + 64;
1360
1361                    // Load and duplicate q8 values (each register covers two interleaved columns of q6)
1362                    int8x16_t q8_l[2];
1363                    int8x16_t q8_h[2];
1364                    for (int i = 0; i < 2; i++) {
1365                        q8_l[i] = (int8x16_t) vld1q_dup_s64((const int64_t *) (q8_base_l + i * 8));
1366                        q8_h[i] = (int8x16_t) vld1q_dup_s64((const int64_t *) (q8_base_h + i * 8));
1367                    }
1368
1369                    const int ql_off_base = sb * QK_K / 2;
1370                    const int qh_off_base = ql_off_base & 255;  // wraps after 256 bytes
1371
1372                    // Load 4 vectors at once (64 bytes each for ql_0, ql_1, qh_0, qh_1)
1373                    uint8x16x4_t q6_ql_0 = vld1q_u8_x4(ql_base + ql_off_base);
1374                    uint8x16x4_t q6_ql_1 = vld1q_u8_x4(ql_base + ql_off_base + 64);
1375                    uint8x16x4_t q6_qh_0 = vld1q_u8_x4(qh_base + qh_off_base);
1376                    uint8x16x4_t q6_qh_1 = vld1q_u8_x4(qh_base + qh_off_base + 64);
1377
1378                    // Adjust qh for subblocks 2 and 3 (shift right by 2)
1379                    if (sb > 1) {
1380                        q6_qh_0.val[0] = vshrq_n_u8(q6_qh_0.val[0], 2);
1381                        q6_qh_0.val[1] = vshrq_n_u8(q6_qh_0.val[1], 2);
1382                        q6_qh_0.val[2] = vshrq_n_u8(q6_qh_0.val[2], 2);
1383                        q6_qh_0.val[3] = vshrq_n_u8(q6_qh_0.val[3], 2);
1384                        q6_qh_1.val[0] = vshrq_n_u8(q6_qh_1.val[0], 2);
1385                        q6_qh_1.val[1] = vshrq_n_u8(q6_qh_1.val[1], 2);
1386                        q6_qh_1.val[2] = vshrq_n_u8(q6_qh_1.val[2], 2);
1387                        q6_qh_1.val[3] = vshrq_n_u8(q6_qh_1.val[3], 2);
1388                    }
1389
1390                    // Process column pairs (0-1, 2-3, 4-5, 6-7)
1391                    for (int cp = 0; cp < col_pairs; cp++) {
1392                        const uint8x16_t q6_qs_cp_0_l = q6_ql_0.val[cp];
1393                        const uint8x16_t q6_qs_cp_1_l = q6_ql_1.val[cp];
1394                        const uint8x16_t q6_qs_cp_0_h = q6_qh_0.val[cp];
1395                        const uint8x16_t q6_qs_cp_1_h = q6_qh_1.val[cp];
1396
1397                        // Extract high 2 bits for upper nibble reconstruction
1398                        const uint8x16_t q6_qs_cp_0_hh = vandq_u8(q6_qs_cp_0_h, mask_hi);
1399                        const uint8x16_t q6_qs_cp_1_hh = vandq_u8(q6_qs_cp_1_h, mask_hi);
1400
1401                        // q6 = (low4 | high2<<4), without -32 bias (handled via bsums)
1402                        const int8x16_t q6_l0 = vreinterpretq_s8_u8(
1403                            vsliq_n_u8(vandq_u8(q6_qs_cp_0_l, m4b), vandq_u8(q6_qs_cp_0_h, mask_lo), 4));
1404                        const int8x16_t q6_l1 = vreinterpretq_s8_u8(
1405                            vsliq_n_u8(vandq_u8(q6_qs_cp_1_l, m4b), vandq_u8(q6_qs_cp_1_h, mask_lo), 4));
1406                        const int8x16_t q6_h0 =
1407                            vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_qs_cp_0_l, 4), q6_qs_cp_0_hh));
1408                        const int8x16_t q6_h1 =
1409                            vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_qs_cp_1_l, 4), q6_qs_cp_1_hh));
1410
1411                        int32x4_t sb_acc_l = vdupq_n_s32(0);
1412                        sb_acc_l           = vdotq_s32(sb_acc_l, q6_l0, q8_l[0]);
1413                        sb_acc_l           = vdotq_s32(sb_acc_l, q6_l1, q8_l[1]);
1414
1415                        int32x4_t sb_acc_h = vdupq_n_s32(0);
1416                        sb_acc_h           = vdotq_s32(sb_acc_h, q6_h0, q8_h[0]);
1417                        sb_acc_h           = vdotq_s32(sb_acc_h, q6_h1, q8_h[1]);
1418
1419                        // Pairwise add to get per-column sums: [col0, col1]
1420                        int32x2_t sum_l = vpadd_s32(vget_low_s32(sb_acc_l), vget_high_s32(sb_acc_l));
1421                        int32x2_t sum_h = vpadd_s32(vget_low_s32(sb_acc_h), vget_high_s32(sb_acc_h));
1422
1423                        const int scale_idx_l = half * 8 + sb;
1424                        const int scale_idx_h = half * 8 + sb + 4;
1425
1426                        // Access scales using array indexing (scales are interleaved by column)
1427                        const int32x2_t scale_vec_l = { (int32_t) q6_scales[scale_idx_l * 8 + cp * 2],
1428                                                        (int32_t) q6_scales[scale_idx_l * 8 + cp * 2 + 1] };
1429                        const int32x2_t scale_vec_h = { (int32_t) q6_scales[scale_idx_h * 8 + cp * 2],
1430                                                        (int32_t) q6_scales[scale_idx_h * 8 + cp * 2 + 1] };
1431
1432                        // Accumulate scaled results
1433                        acc[cp] = vmla_s32(acc[cp], sum_l, scale_vec_l);
1434                        acc[cp] = vmla_s32(acc[cp], sum_h, scale_vec_h);
1435                    }
1436                }
1437            }  // for half
1438
1439            // Bias correction
1440            acc[0] = vsub_s32(acc[0], vget_low_s32(bias_lo));
1441            acc[1] = vsub_s32(acc[1], vget_high_s32(bias_lo));
1442            acc[2] = vsub_s32(acc[2], vget_low_s32(bias_hi));
1443            acc[3] = vsub_s32(acc[3], vget_high_s32(bias_hi));
1444
1445            // Apply superblock scale (no mins for q6_K)
1446            // acc[cp] has [c0, c1]
1447            float32x2_t w_01 = vmul_f32(vcvt_f32_s32(acc[0]), vget_low_f32(sb_scale_0));
1448            float32x2_t w_23 = vmul_f32(vcvt_f32_s32(acc[1]), vget_high_f32(sb_scale_0));
1449            float32x2_t w_45 = vmul_f32(vcvt_f32_s32(acc[2]), vget_low_f32(sb_scale_1));
1450            float32x2_t w_67 = vmul_f32(vcvt_f32_s32(acc[3]), vget_high_f32(sb_scale_1));
1451
1452            acc_f32[0] = vaddq_f32(acc_f32[0], vcombine_f32(w_01, w_23));
1453            acc_f32[1] = vaddq_f32(acc_f32[1], vcombine_f32(w_45, w_67));
1454        }  // for b
1455
1456        int base = x * ncols_interleaved;
1457        vst1q_f32(s + base, acc_f32[0]);
1458        vst1q_f32(s + base + 4, acc_f32[1]);
1459    }  // for x
1460    return;
1461#endif  // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
1462    ggml_gemv_q6_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
1463}
1464
1465void ggml_gemv_q8_0_4x4_q8_0(int                        n,
1466                             float * GGML_RESTRICT      s,
1467                             size_t                     bs,
1468                             const void * GGML_RESTRICT vx,
1469                             const void * GGML_RESTRICT vy,
1470                             int                        nr,
1471                             int                        nc) {
1472    const int qk                = QK8_0;
1473    const int nb                = n / qk;
1474    const int ncols_interleaved = 4;
1475    const int blocklen          = 4;
1476
1477    assert(n % qk == 0);
1478    assert(nc % ncols_interleaved == 0);
1479
1480    UNUSED(nb);
1481    UNUSED(ncols_interleaved);
1482    UNUSED(blocklen);
1483
1484#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
1485    const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx;
1486
1487    for (int c = 0; c < nc; c += ncols_interleaved) {
1488        const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
1489        float32x4_t        acc   = vdupq_n_f32(0);
1490        for (int b = 0; b < nb; b++) {
1491            int8x16x4_t b_low  = vld1q_s8_x4((const int8_t *) b_ptr->qs);
1492            int8x16x4_t b_high = vld1q_s8_x4((const int8_t *) b_ptr->qs + 64);
1493            float16x4_t bd     = vld1_f16((const __fp16 *) b_ptr->d);
1494
1495            int8x16x2_t a  = vld1q_s8_x2(a_ptr->qs);
1496            float16x4_t ad = vld1_dup_f16((const __fp16 *) &a_ptr->d);
1497
1498            int32x4_t ret = vdupq_n_s32(0);
1499
1500            ret = vdotq_laneq_s32(ret, b_low.val[0], a.val[0], 0);
1501            ret = vdotq_laneq_s32(ret, b_low.val[1], a.val[0], 1);
1502            ret = vdotq_laneq_s32(ret, b_low.val[2], a.val[0], 2);
1503            ret = vdotq_laneq_s32(ret, b_low.val[3], a.val[0], 3);
1504
1505            ret = vdotq_laneq_s32(ret, b_high.val[0], a.val[1], 0);
1506            ret = vdotq_laneq_s32(ret, b_high.val[1], a.val[1], 1);
1507            ret = vdotq_laneq_s32(ret, b_high.val[2], a.val[1], 2);
1508            ret = vdotq_laneq_s32(ret, b_high.val[3], a.val[1], 3);
1509
1510            acc = vfmaq_f32(acc, vcvtq_f32_s32(ret), vmulq_f32(vcvt_f32_f16(ad), vcvt_f32_f16(bd)));
1511            a_ptr++;
1512            b_ptr++;
1513        }
1514        vst1q_f32(s, acc);
1515        s += ncols_interleaved;
1516    }
1517    return;
1518
1519#endif  // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
1520    ggml_gemv_q8_0_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
1521}
1522
1523void ggml_gemv_q8_0_4x8_q8_0(int                        n,
1524                             float * GGML_RESTRICT      s,
1525                             size_t                     bs,
1526                             const void * GGML_RESTRICT vx,
1527                             const void * GGML_RESTRICT vy,
1528                             int                        nr,
1529                             int                        nc) {
1530    const int qk                = QK8_0;
1531    const int nb                = n / qk;
1532    const int ncols_interleaved = 4;
1533    const int blocklen          = 8;
1534
1535    assert(n % qk == 0);
1536    assert(nc % ncols_interleaved == 0);
1537
1538    UNUSED(nb);
1539    UNUSED(ncols_interleaved);
1540    UNUSED(blocklen);
1541
1542#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
1543    const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx;
1544
1545    for (int c = 0; c < nc; c += ncols_interleaved) {
1546        const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
1547        float32x4_t        acc   = vdupq_n_f32(0);
1548
1549        for (int b = 0; b < nb; b++) {
1550            int8x16x4_t b_low  = vld1q_s8_x4((const int8_t *) b_ptr->qs);
1551            int8x16x4_t b_high = vld1q_s8_x4((const int8_t *) b_ptr->qs + 64);
1552            float16x4_t bd     = vld1_f16((const __fp16 *) b_ptr->d);
1553
1554            int8x8x4_t  a_chunks = vld1_s8_x4(a_ptr->qs);
1555            int8x16_t   a0       = vcombine_s8(a_chunks.val[0], a_chunks.val[0]);
1556            int8x16_t   a1       = vcombine_s8(a_chunks.val[1], a_chunks.val[1]);
1557            int8x16_t   a2       = vcombine_s8(a_chunks.val[2], a_chunks.val[2]);
1558            int8x16_t   a3       = vcombine_s8(a_chunks.val[3], a_chunks.val[3]);
1559            float16x4_t ad       = vld1_dup_f16((const __fp16 *) &a_ptr->d);
1560
1561            int32x4_t ret0 = vdupq_n_s32(0);
1562            int32x4_t ret1 = vdupq_n_s32(0);
1563
1564            // 0..7
1565            ret0 = vdotq_s32(ret0, b_low.val[0], a0);
1566            ret1 = vdotq_s32(ret1, b_low.val[1], a0);
1567            // 8..15
1568            ret0 = vdotq_s32(ret0, b_low.val[2], a1);
1569            ret1 = vdotq_s32(ret1, b_low.val[3], a1);
1570            // 16..23
1571            ret0 = vdotq_s32(ret0, b_high.val[0], a2);
1572            ret1 = vdotq_s32(ret1, b_high.val[1], a2);
1573            // 24..31
1574            ret0 = vdotq_s32(ret0, b_high.val[2], a3);
1575            ret1 = vdotq_s32(ret1, b_high.val[3], a3);
1576
1577            int32x4_t ret = vpaddq_s32(ret0, ret1);
1578
1579            acc = vfmaq_f32(acc, vcvtq_f32_s32(ret), vmulq_f32(vcvt_f32_f16(ad), vcvt_f32_f16(bd)));
1580            a_ptr++;
1581            b_ptr++;
1582        }
1583        vst1q_f32(s, acc);
1584        s += ncols_interleaved;
1585    }
1586    return;
1587
1588#endif  // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
1589    ggml_gemv_q8_0_4x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
1590}
1591
1592void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
1593    const int qk = QK8_0;
1594    const int nb = n / qk;
1595    const int ncols_interleaved = 4;
1596    const int blocklen = 4;
1597
1598    assert (n % qk == 0);
1599    assert (nr % 4 == 0);
1600    assert (nc % ncols_interleaved == 0);
1601
1602    UNUSED(s);
1603    UNUSED(bs);
1604    UNUSED(vx);
1605    UNUSED(vy);
1606    UNUSED(nr);
1607    UNUSED(nc);
1608    UNUSED(nb);
1609    UNUSED(ncols_interleaved);
1610    UNUSED(blocklen);
1611
1612#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
1613    const void * b_ptr = vx;
1614    const void * a_ptr = vy;
1615    float * res_ptr = s;
1616    size_t res_stride = bs * sizeof(float);
1617
1618    __asm__ __volatile__(
1619        "mov x10, %x[nr]\n"
1620        "mov x9, #0x88\n"
1621        "cmp x10, #0x10\n"
1622        "mul x9, %x[nb], x9\n"
1623        "blt 4f\n"
1624        "1:"  // Row loop
1625        "add x28, %x[b_ptr], #0x8\n"
1626        "mov x27, %x[nc]\n"
1627        "add x26, %x[res_ptr], %x[res_stride], LSL #4\n"
1628        "2:"  // Column loop
1629        "add x25, %x[a_ptr], #0x8\n"
1630        "movi v15.16b, #0x0\n"
1631        "movi v19.16b, #0x0\n"
1632        "mov x24, %x[nb]\n"
1633        "add x23, x25, x9\n"
1634        "movi v18.16b, #0x0\n"
1635        "movi v14.16b, #0x0\n"
1636        "add x22, x23, x9\n"
1637        "movi v11.16b, #0x0\n"
1638        "movi v13.16b, #0x0\n"
1639        "add x21, x22, x9\n"
1640        "movi v23.16b, #0x0\n"
1641        "movi v16.16b, #0x0\n"
1642        "movi v25.16b, #0x0\n"
1643        "movi v7.16b, #0x0\n"
1644        "movi v0.16b, #0x0\n"
1645        "movi v4.16b, #0x0\n"
1646        "movi v5.16b, #0x0\n"
1647        "movi v21.16b, #0x0\n"
1648        "movi v8.16b, #0x0\n"
1649        "movi v1.16b, #0x0\n"
1650        "3:"  // Block loop
1651        "ldr q3, [x28, #0x0]\n"
1652        "ldr q31, [x25, #0x0]\n"
1653        "movi v28.16b, #0x4\n"
1654        "movi v10.4s, #0x0\n"
1655        "ldr q22, [x28, #0x10]\n"
1656        "ldr q6, [x25, #0x10]\n"
1657        "movi v29.4s, #0x0\n"
1658        "movi v9.4s, #0x0\n"
1659        "ldr q27, [x28, #0x20]\n"
1660        "ldr q30, [x28, #0x30]\n"
1661        "movi v20.4s, #0x0\n"
1662        "movi v24.16b, #0xf0\n"
1663        "ldr d2, [x25, #-0x8]\n"
1664        "ldr d26, [x23, #-0x8]\n"
1665        "sshl v12.16b, v3.16b, v28.16b\n"
1666        "sub x20, x28, #0x8\n"
1667        "ldr d17, [x20, #0x0]\n"
1668        "and v3.16b, v3.16b, v24.16b\n"
1669        "subs x24, x24, #0x1\n"
1670        "add x28, x28, #0x48\n"
1671        ".inst 0x4f9fe18a  // sdot v10.4s, v12.16b, v31.4b[0]\n"
1672        ".inst 0x4fbfe19d  // sdot v29.4s, v12.16b, v31.4b[1]\n"
1673        ".inst 0x4f9fe989  // sdot v9.4s, v12.16b, v31.4b[2]\n"
1674        ".inst 0x4fbfe994  // sdot v20.4s, v12.16b, v31.4b[3]\n"
1675        "sshl v31.16b, v22.16b, v28.16b\n"
1676        "and v22.16b, v22.16b, v24.16b\n"
1677        "fcvtl v17.4s, v17.4h\n"
1678        "fcvtl v2.4s, v2.4h\n"
1679        "fcvtl v26.4s, v26.4h\n"
1680        ".inst 0x4f86e3ea  // sdot v10.4s, v31.16b, v6.4b[0]\n"
1681        ".inst 0x4fa6e3fd  // sdot v29.4s, v31.16b, v6.4b[1]\n"
1682        ".inst 0x4f86ebe9  // sdot v9.4s, v31.16b, v6.4b[2]\n"
1683        ".inst 0x4fa6ebf4  // sdot v20.4s, v31.16b, v6.4b[3]\n"
1684        "sshl v6.16b, v27.16b, v28.16b\n"
1685        "sshl v28.16b, v30.16b, v28.16b\n"
1686        "and v27.16b, v27.16b, v24.16b\n"
1687        "and v30.16b, v30.16b, v24.16b\n"
1688        "ldr q24, [x25, #0x20]\n"
1689        ".inst 0x4f98e0ca  // sdot v10.4s, v6.16b, v24.4b[0]\n"
1690        ".inst 0x4fb8e0dd  // sdot v29.4s, v6.16b, v24.4b[1]\n"
1691        ".inst 0x4f98e8c9  // sdot v9.4s, v6.16b, v24.4b[2]\n"
1692        ".inst 0x4fb8e8d4  // sdot v20.4s, v6.16b, v24.4b[3]\n"
1693        "ldr q24, [x25, #0x30]\n"
1694        ".inst 0x4f98e38a  // sdot v10.4s, v28.16b, v24.4b[0]\n"
1695        ".inst 0x4fb8e39d  // sdot v29.4s, v28.16b, v24.4b[1]\n"
1696        ".inst 0x4f98eb89  // sdot v9.4s, v28.16b, v24.4b[2]\n"
1697        ".inst 0x4fb8eb94  // sdot v20.4s, v28.16b, v24.4b[3]\n"
1698        "ldr q24, [x25, #0x40]\n"
1699        ".inst 0x4f98e06a  // sdot v10.4s, v3.16b, v24.4b[0]\n"
1700        ".inst 0x4fb8e07d  // sdot v29.4s, v3.16b, v24.4b[1]\n"
1701        ".inst 0x4f98e869  // sdot v9.4s, v3.16b, v24.4b[2]\n"
1702        ".inst 0x4fb8e874  // sdot v20.4s, v3.16b, v24.4b[3]\n"
1703        "ldr q24, [x25, #0x50]\n"
1704        ".inst 0x4f98e2ca  // sdot v10.4s, v22.16b, v24.4b[0]\n"
1705        ".inst 0x4fb8e2dd  // sdot v29.4s, v22.16b, v24.4b[1]\n"
1706        ".inst 0x4f98eac9  // sdot v9.4s, v22.16b, v24.4b[2]\n"
1707        ".inst 0x4fb8ead4  // sdot v20.4s, v22.16b, v24.4b[3]\n"
1708        "ldr q24, [x25, #0x60]\n"
1709        ".inst 0x4f98e36a  // sdot v10.4s, v27.16b, v24.4b[0]\n"
1710        ".inst 0x4fb8e37d  // sdot v29.4s, v27.16b, v24.4b[1]\n"
1711        ".inst 0x4f98eb69  // sdot v9.4s, v27.16b, v24.4b[2]\n"
1712        ".inst 0x4fb8eb74  // sdot v20.4s, v27.16b, v24.4b[3]\n"
1713        "ldr q24, [x25, #0x70]\n"
1714        "add x25, x25, #0x88\n"
1715        ".inst 0x4f98e3ca  // sdot v10.4s, v30.16b, v24.4b[0]\n"
1716        ".inst 0x4fb8e3dd  // sdot v29.4s, v30.16b, v24.4b[1]\n"
1717        ".inst 0x4f98ebc9  // sdot v9.4s, v30.16b, v24.4b[2]\n"
1718        ".inst 0x4fb8ebd4  // sdot v20.4s, v30.16b, v24.4b[3]\n"
1719        "fmul v24.4s, v17.4s, v2.s[0]\n"
1720        "scvtf v10.4s, v10.4s, #0x4\n"
1721        "scvtf v29.4s, v29.4s, #0x4\n"
1722        "scvtf v9.4s, v9.4s, #0x4\n"
1723        "scvtf v20.4s, v20.4s, #0x4\n"
1724        "fmla v15.4s, v10.4s, v24.4s\n"
1725        "ldr q24, [x23, #0x0]\n"
1726        "fmul v10.4s, v17.4s, v2.s[1]\n"
1727        "fmla v19.4s, v29.4s, v10.4s\n"
1728        "ldr q10, [x23, #0x10]\n"
1729        "fmul v29.4s, v17.4s, v2.s[2]\n"
1730        "fmul v2.4s, v17.4s, v2.s[3]\n"
1731        "fmla v18.4s, v9.4s, v29.4s\n"
1732        "movi v9.4s, #0x0\n"
1733        "movi v29.4s, #0x0\n"
1734        ".inst 0x4f98e189  // sdot v9.4s, v12.16b, v24.4b[0]\n"
1735        ".inst 0x4fb8e19d  // sdot v29.4s, v12.16b, v24.4b[1]\n"
1736        "fmla v14.4s, v20.4s, v2.4s\n"
1737        "movi v20.4s, #0x0\n"
1738        "movi v2.4s, #0x0\n"
1739        ".inst 0x4f98e994  // sdot v20.4s, v12.16b, v24.4b[2]\n"
1740        ".inst 0x4fb8e982  // sdot v2.4s, v12.16b, v24.4b[3]\n"
1741        "ldr q24, [x23, #0x20]\n"
1742        ".inst 0x4f8ae3e9  // sdot v9.4s, v31.16b, v10.4b[0]\n"
1743        ".inst 0x4faae3fd  // sdot v29.4s, v31.16b, v10.4b[1]\n"
1744        ".inst 0x4f8aebf4  // sdot v20.4s, v31.16b, v10.4b[2]\n"
1745        ".inst 0x4faaebe2  // sdot v2.4s, v31.16b, v10.4b[3]\n"
1746        "ldr q10, [x23, #0x30]\n"
1747        ".inst 0x4f98e0c9  // sdot v9.4s, v6.16b, v24.4b[0]\n"
1748        ".inst 0x4fb8e0dd  // sdot v29.4s, v6.16b, v24.4b[1]\n"
1749        ".inst 0x4f98e8d4  // sdot v20.4s, v6.16b, v24.4b[2]\n"
1750        ".inst 0x4fb8e8c2  // sdot v2.4s, v6.16b, v24.4b[3]\n"
1751        "ldr q24, [x23, #0x40]\n"
1752        ".inst 0x4f8ae389  // sdot v9.4s, v28.16b, v10.4b[0]\n"
1753        ".inst 0x4faae39d  // sdot v29.4s, v28.16b, v10.4b[1]\n"
1754        ".inst 0x4f8aeb94  // sdot v20.4s, v28.16b, v10.4b[2]\n"
1755        ".inst 0x4faaeb82  // sdot v2.4s, v28.16b, v10.4b[3]\n"
1756        "ldr q10, [x23, #0x50]\n"
1757        ".inst 0x4f98e069  // sdot v9.4s, v3.16b, v24.4b[0]\n"
1758        ".inst 0x4fb8e07d  // sdot v29.4s, v3.16b, v24.4b[1]\n"
1759        ".inst 0x4f98e874  // sdot v20.4s, v3.16b, v24.4b[2]\n"
1760        ".inst 0x4fb8e862  // sdot v2.4s, v3.16b, v24.4b[3]\n"
1761        "ldr q24, [x23, #0x60]\n"
1762        ".inst 0x4f8ae2c9  // sdot v9.4s, v22.16b, v10.4b[0]\n"
1763        ".inst 0x4faae2dd  // sdot v29.4s, v22.16b, v10.4b[1]\n"
1764        ".inst 0x4f8aead4  // sdot v20.4s, v22.16b, v10.4b[2]\n"
1765        ".inst 0x4faaeac2  // sdot v2.4s, v22.16b, v10.4b[3]\n"
1766        "ldr q10, [x23, #0x70]\n"
1767        "add x23, x23, #0x88\n"
1768        ".inst 0x4f98e369  // sdot v9.4s, v27.16b, v24.4b[0]\n"
1769        ".inst 0x4fb8e37d  // sdot v29.4s, v27.16b, v24.4b[1]\n"
1770        ".inst 0x4f98eb74  // sdot v20.4s, v27.16b, v24.4b[2]\n"
1771        ".inst 0x4fb8eb62  // sdot v2.4s, v27.16b, v24.4b[3]\n"
1772        "ldr q24, [x22, #0x0]\n"
1773        ".inst 0x4f8ae3c9  // sdot v9.4s, v30.16b, v10.4b[0]\n"
1774        ".inst 0x4faae3dd  // sdot v29.4s, v30.16b, v10.4b[1]\n"
1775        ".inst 0x4f8aebd4  // sdot v20.4s, v30.16b, v10.4b[2]\n"
1776        ".inst 0x4faaebc2  // sdot v2.4s, v30.16b, v10.4b[3]\n"
1777        "fmul v10.4s, v17.4s, v26.s[0]\n"
1778        "scvtf v9.4s, v9.4s, #0x4\n"
1779        "scvtf v29.4s, v29.4s, #0x4\n"
1780        "scvtf v20.4s, v20.4s, #0x4\n"
1781        "scvtf v2.4s, v2.4s, #0x4\n"
1782        "fmla v11.4s, v9.4s, v10.4s\n"
1783        "ldr q9, [x22, #0x10]\n"
1784        "fmul v10.4s, v17.4s, v26.s[1]\n"
1785        "fmla v13.4s, v29.4s, v10.4s\n"
1786        "ldr d29, [x22, #-0x8]\n"
1787        "fmul v10.4s, v17.4s, v26.s[2]\n"
1788        "fmul v26.4s, v17.4s, v26.s[3]\n"
1789        "fcvtl v29.4s, v29.4h\n"
1790        "fmla v23.4s, v20.4s, v10.4s\n"
1791        "movi v20.4s, #0x0\n"
1792        "movi v10.4s, #0x0\n"
1793        "fmla v16.4s, v2.4s, v26.4s\n"
1794        "movi v26.4s, #0x0\n"
1795        "movi v2.4s, #0x0\n"
1796        ".inst 0x4f98e194  // sdot v20.4s, v12.16b, v24.4b[0]\n"
1797        ".inst 0x4fb8e18a  // sdot v10.4s, v12.16b, v24.4b[1]\n"
1798        ".inst 0x4f98e99a  // sdot v26.4s, v12.16b, v24.4b[2]\n"
1799        ".inst 0x4fb8e982  // sdot v2.4s, v12.16b, v24.4b[3]\n"
1800        "ldr q24, [x22, #0x20]\n"
1801        ".inst 0x4f89e3f4  // sdot v20.4s, v31.16b, v9.4b[0]\n"
1802        ".inst 0x4fa9e3ea  // sdot v10.4s, v31.16b, v9.4b[1]\n"
1803        ".inst 0x4f89ebfa  // sdot v26.4s, v31.16b, v9.4b[2]\n"
1804        ".inst 0x4fa9ebe2  // sdot v2.4s, v31.16b, v9.4b[3]\n"
1805        "ldr q9, [x22, #0x30]\n"
1806        ".inst 0x4f98e0d4  // sdot v20.4s, v6.16b, v24.4b[0]\n"
1807        ".inst 0x4fb8e0ca  // sdot v10.4s, v6.16b, v24.4b[1]\n"
1808        ".inst 0x4f98e8da  // sdot v26.4s, v6.16b, v24.4b[2]\n"
1809        ".inst 0x4fb8e8c2  // sdot v2.4s, v6.16b, v24.4b[3]\n"
1810        "ldr q24, [x22, #0x40]\n"
1811        ".inst 0x4f89e394  // sdot v20.4s, v28.16b, v9.4b[0]\n"
1812        ".inst 0x4fa9e38a  // sdot v10.4s, v28.16b, v9.4b[1]\n"
1813        ".inst 0x4f89eb9a  // sdot v26.4s, v28.16b, v9.4b[2]\n"
1814        ".inst 0x4fa9eb82  // sdot v2.4s, v28.16b, v9.4b[3]\n"
1815        "ldr q9, [x22, #0x50]\n"
1816        ".inst 0x4f98e074  // sdot v20.4s, v3.16b, v24.4b[0]\n"
1817        ".inst 0x4fb8e06a  // sdot v10.4s, v3.16b, v24.4b[1]\n"
1818        ".inst 0x4f98e87a  // sdot v26.4s, v3.16b, v24.4b[2]\n"
1819        ".inst 0x4fb8e862  // sdot v2.4s, v3.16b, v24.4b[3]\n"
1820        "ldr q24, [x22, #0x60]\n"
1821        ".inst 0x4f89e2d4  // sdot v20.4s, v22.16b, v9.4b[0]\n"
1822        ".inst 0x4fa9e2ca  // sdot v10.4s, v22.16b, v9.4b[1]\n"
1823        ".inst 0x4f89eada  // sdot v26.4s, v22.16b, v9.4b[2]\n"
1824        ".inst 0x4fa9eac2  // sdot v2.4s, v22.16b, v9.4b[3]\n"
1825        "ldr q9, [x22, #0x70]\n"
1826        "add x22, x22, #0x88\n"
1827        ".inst 0x4f98e374  // sdot v20.4s, v27.16b, v24.4b[0]\n"
1828        ".inst 0x4fb8e36a  // sdot v10.4s, v27.16b, v24.4b[1]\n"
1829        ".inst 0x4f98eb7a  // sdot v26.4s, v27.16b, v24.4b[2]\n"
1830        ".inst 0x4fb8eb62  // sdot v2.4s, v27.16b, v24.4b[3]\n"
1831        "ldr q24, [x21, #0x0]\n"
1832        ".inst 0x4f89e3d4  // sdot v20.4s, v30.16b, v9.4b[0]\n"
1833        ".inst 0x4fa9e3ca  // sdot v10.4s, v30.16b, v9.4b[1]\n"
1834        ".inst 0x4f89ebda  // sdot v26.4s, v30.16b, v9.4b[2]\n"
1835        ".inst 0x4fa9ebc2  // sdot v2.4s, v30.16b, v9.4b[3]\n"
1836        "fmul v9.4s, v17.4s, v29.s[0]\n"
1837        "scvtf v20.4s, v20.4s, #0x4\n"
1838        "scvtf v10.4s, v10.4s, #0x4\n"
1839        "scvtf v26.4s, v26.4s, #0x4\n"
1840        "scvtf v2.4s, v2.4s, #0x4\n"
1841        "fmla v25.4s, v20.4s, v9.4s\n"
1842        "ldr q9, [x21, #0x10]\n"
1843        "fmul v20.4s, v17.4s, v29.s[1]\n"
1844        "fmla v7.4s, v10.4s, v20.4s\n"
1845        "ldr d20, [x21, #-0x8]\n"
1846        "fmul v10.4s, v17.4s, v29.s[2]\n"
1847        "fmul v29.4s, v17.4s, v29.s[3]\n"
1848        "fcvtl v20.4s, v20.4h\n"
1849        "fmla v0.4s, v26.4s, v10.4s\n"
1850        "movi v26.4s, #0x0\n"
1851        "movi v10.4s, #0x0\n"
1852        "fmla v4.4s, v2.4s, v29.4s\n"
1853        "movi v2.4s, #0x0\n"
1854        "movi v29.4s, #0x0\n"
1855        ".inst 0x4f98e19a  // sdot v26.4s, v12.16b, v24.4b[0]\n"
1856        ".inst 0x4fb8e18a  // sdot v10.4s, v12.16b, v24.4b[1]\n"
1857        ".inst 0x4f98e982  // sdot v2.4s, v12.16b, v24.4b[2]\n"
1858        ".inst 0x4fb8e99d  // sdot v29.4s, v12.16b, v24.4b[3]\n"
1859        "ldr q12, [x21, #0x20]\n"
1860        "fmul v24.4s, v17.4s, v20.s[0]\n"
1861        ".inst 0x4f89e3fa  // sdot v26.4s, v31.16b, v9.4b[0]\n"
1862        ".inst 0x4fa9e3ea  // sdot v10.4s, v31.16b, v9.4b[1]\n"
1863        ".inst 0x4f89ebe2  // sdot v2.4s, v31.16b, v9.4b[2]\n"
1864        ".inst 0x4fa9ebfd  // sdot v29.4s, v31.16b, v9.4b[3]\n"
1865        "ldr q9, [x21, #0x30]\n"
1866        "fmul v31.4s, v17.4s, v20.s[1]\n"
1867        ".inst 0x4f8ce0da  // sdot v26.4s, v6.16b, v12.4b[0]\n"
1868        ".inst 0x4face0ca  // sdot v10.4s, v6.16b, v12.4b[1]\n"
1869        ".inst 0x4f8ce8c2  // sdot v2.4s, v6.16b, v12.4b[2]\n"
1870        ".inst 0x4face8dd  // sdot v29.4s, v6.16b, v12.4b[3]\n"
1871        "ldr q12, [x21, #0x40]\n"
1872        "fmul v6.4s, v17.4s, v20.s[2]\n"
1873        "fmul v20.4s, v17.4s, v20.s[3]\n"
1874        ".inst 0x4f89e39a  // sdot v26.4s, v28.16b, v9.4b[0]\n"
1875        ".inst 0x4fa9e38a  // sdot v10.4s, v28.16b, v9.4b[1]\n"
1876        ".inst 0x4f89eb82  // sdot v2.4s, v28.16b, v9.4b[2]\n"
1877        ".inst 0x4fa9eb9d  // sdot v29.4s, v28.16b, v9.4b[3]\n"
1878        "ldr q9, [x21, #0x50]\n"
1879        ".inst 0x4f8ce07a  // sdot v26.4s, v3.16b, v12.4b[0]\n"
1880        ".inst 0x4face06a  // sdot v10.4s, v3.16b, v12.4b[1]\n"
1881        ".inst 0x4f8ce862  // sdot v2.4s, v3.16b, v12.4b[2]\n"
1882        ".inst 0x4face87d  // sdot v29.4s, v3.16b, v12.4b[3]\n"
1883        "ldr q12, [x21, #0x60]\n"
1884        ".inst 0x4f89e2da  // sdot v26.4s, v22.16b, v9.4b[0]\n"
1885        ".inst 0x4fa9e2ca  // sdot v10.4s, v22.16b, v9.4b[1]\n"
1886        ".inst 0x4f89eac2  // sdot v2.4s, v22.16b, v9.4b[2]\n"
1887        ".inst 0x4fa9eadd  // sdot v29.4s, v22.16b, v9.4b[3]\n"
1888        "ldr q17, [x21, #0x70]\n"
1889        "add x21, x21, #0x88\n"
1890        ".inst 0x4f8ce37a  // sdot v26.4s, v27.16b, v12.4b[0]\n"
1891        ".inst 0x4face36a  // sdot v10.4s, v27.16b, v12.4b[1]\n"
1892        ".inst 0x4f8ceb62  // sdot v2.4s, v27.16b, v12.4b[2]\n"
1893        ".inst 0x4faceb7d  // sdot v29.4s, v27.16b, v12.4b[3]\n"
1894        ".inst 0x4f91e3da  // sdot v26.4s, v30.16b, v17.4b[0]\n"
1895        ".inst 0x4fb1e3ca  // sdot v10.4s, v30.16b, v17.4b[1]\n"
1896        ".inst 0x4f91ebc2  // sdot v2.4s, v30.16b, v17.4b[2]\n"
1897        ".inst 0x4fb1ebdd  // sdot v29.4s, v30.16b, v17.4b[3]\n"
1898        "scvtf v26.4s, v26.4s, #0x4\n"
1899        "scvtf v10.4s, v10.4s, #0x4\n"
1900        "fmla v5.4s, v26.4s, v24.4s\n"
1901        "scvtf v2.4s, v2.4s, #0x4\n"
1902        "scvtf v29.4s, v29.4s, #0x4\n"
1903        "fmla v21.4s, v10.4s, v31.4s\n"
1904        "fmla v8.4s, v2.4s, v6.4s\n"
1905        "fmla v1.4s, v29.4s, v20.4s\n"
1906        "bgt 3b\n"
1907        "mov x20, %x[res_ptr]\n"
1908        "subs x27, x27, #0x4\n"
1909        "add %x[res_ptr], %x[res_ptr], #0x10\n"
1910        "str q15, [x20, #0x0]\n"
1911        "add x20, x20, %x[res_stride]\n"
1912        "str q19, [x20, #0x0]\n"
1913        "add x20, x20, %x[res_stride]\n"
1914        "str q18, [x20, #0x0]\n"
1915        "add x20, x20, %x[res_stride]\n"
1916        "str q14, [x20, #0x0]\n"
1917        "add x20, x20, %x[res_stride]\n"
1918        "str q11, [x20, #0x0]\n"
1919        "add x20, x20, %x[res_stride]\n"
1920        "str q13, [x20, #0x0]\n"
1921        "add x20, x20, %x[res_stride]\n"
1922        "str q23, [x20, #0x0]\n"
1923        "add x20, x20, %x[res_stride]\n"
1924        "str q16, [x20, #0x0]\n"
1925        "add x20, x20, %x[res_stride]\n"
1926        "str q25, [x20, #0x0]\n"
1927        "add x20, x20, %x[res_stride]\n"
1928        "str q7, [x20, #0x0]\n"
1929        "add x20, x20, %x[res_stride]\n"
1930        "str q0, [x20, #0x0]\n"
1931        "add x20, x20, %x[res_stride]\n"
1932        "str q4, [x20, #0x0]\n"
1933        "add x20, x20, %x[res_stride]\n"
1934        "str q5, [x20, #0x0]\n"
1935        "add x20, x20, %x[res_stride]\n"
1936        "str q21, [x20, #0x0]\n"
1937        "add x20, x20, %x[res_stride]\n"
1938        "str q8, [x20, #0x0]\n"
1939        "add x20, x20, %x[res_stride]\n"
1940        "str q1, [x20, #0x0]\n"
1941        "bne 2b\n"
1942        "mov x20, #0x4\n"
1943        "sub x10, x10, #0x10\n"
1944        "cmp x10, #0x10\n"
1945        "mov %x[res_ptr], x26\n"
1946        "madd %x[a_ptr], x20, x9, %x[a_ptr]\n"
1947        "bge 1b\n"
1948        "4:"  // Row loop skip
1949        "cbz x10, 9f\n"
1950        "5:"  // Row tail: Row loop
1951        "add x24, %x[b_ptr], #0x8\n"
1952        "mov x23, %x[nc]\n"
1953        "add x22, %x[res_ptr], %x[res_stride], LSL #2\n"
1954        "6:"  // Row tail: Column loop
1955        "movi v15.16b, #0x0\n"
1956        "movi v19.16b, #0x0\n"
1957        "add x25, %x[a_ptr], #0x8\n"
1958        "mov x21, %x[nb]\n"
1959        "movi v18.16b, #0x0\n"
1960        "movi v14.16b, #0x0\n"
1961        "7:"  // Row tail: Block loop
1962        "ldr q7, [x24, #0x0]\n"
1963        "ldr q5, [x25, #0x0]\n"
1964        "movi v9.16b, #0x4\n"
1965        "movi v4.4s, #0x0\n"
1966        "ldr q3, [x24, #0x10]\n"
1967        "ldr q2, [x25, #0x10]\n"
1968        "movi v1.4s, #0x0\n"
1969        "movi v0.4s, #0x0\n"
1970        "ldr q13, [x24, #0x20]\n"
1971        "ldr q31, [x25, #0x20]\n"
1972        "movi v30.4s, #0x0\n"
1973        "movi v29.16b, #0xf0\n"
1974        "ldr q28, [x24, #0x30]\n"
1975        "ldr q27, [x25, #0x30]\n"
1976        "sshl v20.16b, v7.16b, v9.16b\n"
1977        "sub x20, x24, #0x8\n"
1978        "ldr q26, [x25, #0x40]\n"
1979        "ldr q25, [x25, #0x50]\n"
1980        "sshl v17.16b, v3.16b, v9.16b\n"
1981        "and v7.16b, v7.16b, v29.16b\n"
1982        "ldr q24, [x25, #0x60]\n"
1983        "ldr q16, [x25, #0x70]\n"
1984        "sshl v22.16b, v13.16b, v9.16b\n"
1985        "and v3.16b, v3.16b, v29.16b\n"
1986        "ldr d21, [x20, #0x0]\n"
1987        "ldr d12, [x25, #-0x8]\n"
1988        ".inst 0x4f85e284  // sdot v4.4s, v20.16b, v5.4b[0]\n"
1989        ".inst 0x4fa5e281  // sdot v1.4s, v20.16b, v5.4b[1]\n"
1990        ".inst 0x4f85ea80  // sdot v0.4s, v20.16b, v5.4b[2]\n"
1991        ".inst 0x4fa5ea9e  // sdot v30.4s, v20.16b, v5.4b[3]\n"
1992        "sshl v9.16b, v28.16b, v9.16b\n"
1993        "subs x21, x21, #0x1\n"
1994        "and v13.16b, v13.16b, v29.16b\n"
1995        "and v28.16b, v28.16b, v29.16b\n"
1996        "add x25, x25, #0x88\n"
1997        "add x24, x24, #0x48\n"
1998        "fcvtl v21.4s, v21.4h\n"
1999        "fcvtl v12.4s, v12.4h\n"
2000        ".inst 0x4f82e224  // sdot v4.4s, v17.16b, v2.4b[0]\n"
2001        ".inst 0x4fa2e221  // sdot v1.4s, v17.16b, v2.4b[1]\n"
2002        ".inst 0x4f82ea20  // sdot v0.4s, v17.16b, v2.4b[2]\n"
2003        ".inst 0x4fa2ea3e  // sdot v30.4s, v17.16b, v2.4b[3]\n"
2004        "fmul v11.4s, v21.4s, v12.s[0]\n"
2005        "fmul v23.4s, v21.4s, v12.s[1]\n"
2006        "fmul v17.4s, v21.4s, v12.s[2]\n"
2007        ".inst 0x4f9fe2c4  // sdot v4.4s, v22.16b, v31.4b[0]\n"
2008        "fmul v6.4s, v21.4s, v12.s[3]\n"
2009        ".inst 0x4fbfe2c1  // sdot v1.4s, v22.16b, v31.4b[1]\n"
2010        ".inst 0x4f9feac0  // sdot v0.4s, v22.16b, v31.4b[2]\n"
2011        ".inst 0x4fbfeade  // sdot v30.4s, v22.16b, v31.4b[3]\n"
2012        ".inst 0x4f9be124  // sdot v4.4s, v9.16b, v27.4b[0]\n"
2013        ".inst 0x4fbbe121  // sdot v1.4s, v9.16b, v27.4b[1]\n"
2014        ".inst 0x4f9be920  // sdot v0.4s, v9.16b, v27.4b[2]\n"
2015        ".inst 0x4fbbe93e  // sdot v30.4s, v9.16b, v27.4b[3]\n"
2016        ".inst 0x4f9ae0e4  // sdot v4.4s, v7.16b, v26.4b[0]\n"
2017        ".inst 0x4fbae0e1  // sdot v1.4s, v7.16b, v26.4b[1]\n"
2018        ".inst 0x4f9ae8e0  // sdot v0.4s, v7.16b, v26.4b[2]\n"
2019        ".inst 0x4fbae8fe  // sdot v30.4s, v7.16b, v26.4b[3]\n"
2020        ".inst 0x4f99e064  // sdot v4.4s, v3.16b, v25.4b[0]\n"
2021        ".inst 0x4fb9e061  // sdot v1.4s, v3.16b, v25.4b[1]\n"
2022        ".inst 0x4f99e860  // sdot v0.4s, v3.16b, v25.4b[2]\n"
2023        ".inst 0x4fb9e87e  // sdot v30.4s, v3.16b, v25.4b[3]\n"
2024        ".inst 0x4f98e1a4  // sdot v4.4s, v13.16b, v24.4b[0]\n"
2025        ".inst 0x4fb8e1a1  // sdot v1.4s, v13.16b, v24.4b[1]\n"
2026        ".inst 0x4f98e9a0  // sdot v0.4s, v13.16b, v24.4b[2]\n"
2027        ".inst 0x4fb8e9be  // sdot v30.4s, v13.16b, v24.4b[3]\n"
2028        ".inst 0x4f90e384  // sdot v4.4s, v28.16b, v16.4b[0]\n"
2029        ".inst 0x4fb0e381  // sdot v1.4s, v28.16b, v16.4b[1]\n"
2030        ".inst 0x4f90eb80  // sdot v0.4s, v28.16b, v16.4b[2]\n"
2031        ".inst 0x4fb0eb9e  // sdot v30.4s, v28.16b, v16.4b[3]\n"
2032        "scvtf v4.4s, v4.4s, #0x4\n"
2033        "scvtf v1.4s, v1.4s, #0x4\n"
2034        "scvtf v0.4s, v0.4s, #0x4\n"
2035        "fmla v15.4s, v4.4s, v11.4s\n"
2036        "scvtf v30.4s, v30.4s, #0x4\n"
2037        "fmla v19.4s, v1.4s, v23.4s\n"
2038        "fmla v18.4s, v0.4s, v17.4s\n"
2039        "fmla v14.4s, v30.4s, v6.4s\n"
2040        "bgt 7b\n"
2041        "mov x20, %x[res_ptr]\n"
2042        "cmp x10, #0x1\n"
2043        "str q15, [x20, #0x0]\n"
2044        "add x20, x20, %x[res_stride]\n"
2045        "ble 8f\n"
2046        "cmp x10, #0x2\n"
2047        "str q19, [x20, #0x0]\n"
2048        "add x20, x20, %x[res_stride]\n"
2049        "ble 8f\n"
2050        "cmp x10, #0x3\n"
2051        "str q18, [x20, #0x0]\n"
2052        "add x20, x20, %x[res_stride]\n"
2053        "ble 8f\n"
2054        "str q14, [x20, #0x0]\n"
2055        "8:"  // Row tail: Accumulator store skip
2056        "subs x23, x23, #0x4\n"
2057        "add %x[res_ptr], %x[res_ptr], #0x10\n"
2058        "bne 6b\n"
2059        "subs x10, x10, #0x4\n"
2060        "add %x[a_ptr], %x[a_ptr], x9\n"
2061        "mov %x[res_ptr], x22\n"
2062        "bgt 5b\n"
2063        "9:"  // Row tail: Row loop skip
2064        : [a_ptr] "+&r" (a_ptr), [res_ptr] "+&r" (res_ptr)
2065        : [b_ptr] "r" (b_ptr), [nr] "r" (nr), [nb] "r" (nb), [res_stride] "r" (res_stride), [nc] "r" (nc)
2066        : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"
2067    );
2068    return;
2069#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
2070    ggml_gemm_q4_0_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
2071}
2072
2073void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
2074    const int qk = QK8_0;
2075    const int nb = n / qk;
2076    const int ncols_interleaved = 4;
2077    const int blocklen = 8;
2078
2079    assert (n % qk == 0);
2080    assert (nr % 4 == 0);
2081    assert (nc % ncols_interleaved == 0);
2082
2083    UNUSED(s);
2084    UNUSED(bs);
2085    UNUSED(vx);
2086    UNUSED(vy);
2087    UNUSED(nr);
2088    UNUSED(nc);
2089    UNUSED(nb);
2090    UNUSED(ncols_interleaved);
2091    UNUSED(blocklen);
2092
2093#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
2094    const void * b_ptr = vx;
2095    const void * a_ptr = vy;
2096    float * res_ptr = s;
2097    size_t res_stride = bs * sizeof(float);
2098
2099    __asm__ __volatile__(
2100        "mov x10, %x[nr]\n"
2101        "mov x9, #0x88\n"
2102        "cmp x10, #0x10\n"
2103        "mul x9, %x[nb], x9\n"
2104        "blt 4f\n"
2105        "1:"  // Row loop
2106        "add x28, %x[b_ptr], #0x8\n"
2107        "mov x27, %x[nc]\n"
2108        "add x26, %x[res_ptr], %x[res_stride], LSL #4\n"
2109        "2:"  // Column loop
2110        "add x25, %x[a_ptr], #0x8\n"
2111        "movi v2.16b, #0x0\n"
2112        "movi v10.16b, #0x0\n"
2113        "mov x24, %x[nb]\n"
2114        "add x23, x25, x9\n"
2115        "movi v12.16b, #0x0\n"
2116        "movi v28.16b, #0x0\n"
2117        "add x22, x23, x9\n"
2118        "movi v11.16b, #0x0\n"
2119        "movi v13.16b, #0x0\n"
2120        "add x21, x22, x9\n"
2121        "movi v22.16b, #0x0\n"
2122        "movi v23.16b, #0x0\n"
2123        "movi v25.16b, #0x0\n"
2124        "movi v5.16b, #0x0\n"
2125        "movi v7.16b, #0x0\n"
2126        "movi v4.16b, #0x0\n"
2127        "movi v6.16b, #0x0\n"
2128        "movi v30.16b, #0x0\n"
2129        "movi v24.16b, #0x0\n"
2130        "movi v14.16b, #0x0\n"
2131        "3:"  // Block loop
2132        "ldr q21, [x28, #0x0]\n"
2133        "ldr q16, [x28, #0x10]\n"
2134        "movi v1.16b, #0x4\n"
2135        "movi v19.4s, #0x0\n"
2136        "ldr q27, [x25, #0x0]\n"
2137        "ldr q15, [x25, #0x10]\n"
2138        "movi v26.4s, #0x0\n"
2139        "movi v18.4s, #0x0\n"
2140        "ldr q29, [x28, #0x20]\n"
2141        "ldr q3, [x28, #0x30]\n"
2142        "movi v17.4s, #0x0\n"
2143        "movi v0.16b, #0xf0\n"
2144        "ldr d20, [x25, #-0x8]\n"
2145        "ldr d9, [x23, #-0x8]\n"
2146        "sshl v8.16b, v21.16b, v1.16b\n"
2147        "sshl v31.16b, v16.16b, v1.16b\n"
2148        "and v21.16b, v21.16b, v0.16b\n"
2149        "and v16.16b, v16.16b, v0.16b\n"
2150        "sub x20, x28, #0x8\n"
2151        "subs x24, x24, #0x1\n"
2152        "add x28, x28, #0x48\n"
2153        ".inst 0x4e88a773  // smmla v19.4s, v27.16b, v8.16b\n"
2154        ".inst 0x4e9fa77a  // smmla v26.4s, v27.16b, v31.16b\n"
2155        "ldr q27, [x25, #0x20]\n"
2156        ".inst 0x4e88a5f2  // smmla v18.4s, v15.16b, v8.16b\n"
2157        ".inst 0x4e9fa5f1  // smmla v17.4s, v15.16b, v31.16b\n"
2158        "sshl v15.16b, v29.16b, v1.16b\n"
2159        "sshl v1.16b, v3.16b, v1.16b\n"
2160        "and v29.16b, v29.16b, v0.16b\n"
2161        "and v3.16b, v3.16b, v0.16b\n"
2162        "ldr q0, [x25, #0x30]\n"
2163        "fcvtl v20.4s, v20.4h\n"
2164        ".inst 0x4e8fa773  // smmla v19.4s, v27.16b, v15.16b\n"
2165        "fcvtl v9.4s, v9.4h\n"
2166        ".inst 0x4e81a77a  // smmla v26.4s, v27.16b, v1.16b\n"
2167        "ldr q27, [x25, #0x40]\n"
2168        ".inst 0x4e8fa412  // smmla v18.4s, v0.16b, v15.16b\n"
2169        ".inst 0x4e81a411  // smmla v17.4s, v0.16b, v1.16b\n"
2170        "ldr q0, [x25, #0x50]\n"
2171        ".inst 0x4e95a773  // smmla v19.4s, v27.16b, v21.16b\n"
2172        ".inst 0x4e90a77a  // smmla v26.4s, v27.16b, v16.16b\n"
2173        "ldr q27, [x25, #0x60]\n"
2174        ".inst 0x4e95a412  // smmla v18.4s, v0.16b, v21.16b\n"
2175        ".inst 0x4e90a411  // smmla v17.4s, v0.16b, v16.16b\n"
2176        "ldr q0, [x25, #0x70]\n"
2177        "add x25, x25, #0x88\n"
2178        ".inst 0x4e9da773  // smmla v19.4s, v27.16b, v29.16b\n"
2179        ".inst 0x4e83a77a  // smmla v26.4s, v27.16b, v3.16b\n"
2180        "ldr d27, [x20, #0x0]\n"
2181        ".inst 0x4e9da412  // smmla v18.4s, v0.16b, v29.16b\n"
2182        ".inst 0x4e83a411  // smmla v17.4s, v0.16b, v3.16b\n"
2183        "fcvtl v27.4s, v27.4h\n"
2184        "uzp1 v0.2d, v19.2d, v26.2d\n"
2185        "uzp2 v26.2d, v19.2d, v26.2d\n"
2186        "fmul v19.4s, v27.4s, v20.s[0]\n"
2187        "scvtf v0.4s, v0.4s, #0x4\n"
2188        "scvtf v26.4s, v26.4s, #0x4\n"
2189        "fmla v2.4s, v0.4s, v19.4s\n"
2190        "ldr q19, [x23, #0x0]\n"
2191        "uzp1 v0.2d, v18.2d, v17.2d\n"
2192        "uzp2 v18.2d, v18.2d, v17.2d\n"
2193        "fmul v17.4s, v27.4s, v20.s[1]\n"
2194        "scvtf v0.4s, v0.4s, #0x4\n"
2195        "scvtf v18.4s, v18.4s, #0x4\n"
2196        "fmla v10.4s, v26.4s, v17.4s\n"
2197        "ldr q17, [x23, #0x10]\n"
2198        "fmul v26.4s, v27.4s, v20.s[2]\n"
2199        "fmul v20.4s, v27.4s, v20.s[3]\n"
2200        "fmla v12.4s, v0.4s, v26.4s\n"
2201        "ldr d0, [x22, #-0x8]\n"
2202        "ldr d26, [x21, #-0x8]\n"
2203        "fcvtl v0.4s, v0.4h\n"
2204        "fmla v28.4s, v18.4s, v20.4s\n"
2205        "movi v20.4s, #0x0\n"
2206        "movi v18.4s, #0x0\n"
2207        ".inst 0x4e88a674  // smmla v20.4s, v19.16b, v8.16b\n"
2208        ".inst 0x4e9fa672  // smmla v18.4s, v19.16b, v31.16b\n"
2209        "ldr q19, [x23, #0x20]\n"
2210        "fcvtl v26.4s, v26.4h\n"
2211        ".inst 0x4e8fa674  // smmla v20.4s, v19.16b, v15.16b\n"
2212        ".inst 0x4e81a672  // smmla v18.4s, v19.16b, v1.16b\n"
2213        "ldr q19, [x23, #0x40]\n"
2214        ".inst 0x4e95a674  // smmla v20.4s, v19.16b, v21.16b\n"
2215        ".inst 0x4e90a672  // smmla v18.4s, v19.16b, v16.16b\n"
2216        "ldr q19, [x23, #0x60]\n"
2217        ".inst 0x4e9da674  // smmla v20.4s, v19.16b, v29.16b\n"
2218        ".inst 0x4e83a672  // smmla v18.4s, v19.16b, v3.16b\n"
2219        "uzp1 v19.2d, v20.2d, v18.2d\n"
2220        "scvtf v19.4s, v19.4s, #0x4\n"
2221        "uzp2 v20.2d, v20.2d, v18.2d\n"
2222        "fmul v18.4s, v27.4s, v9.s[0]\n"
2223        "scvtf v20.4s, v20.4s, #0x4\n"
2224        "fmla v11.4s, v19.4s, v18.4s\n"
2225        "ldr q18, [x22, #0x0]\n"
2226        "fmul v19.4s, v27.4s, v9.s[1]\n"
2227        "fmla v13.4s, v20.4s, v19.4s\n"
2228        "movi v19.4s, #0x0\n"
2229        "movi v20.4s, #0x0\n"
2230        ".inst 0x4e88a633  // smmla v19.4s, v17.16b, v8.16b\n"
2231        ".inst 0x4e9fa634  // smmla v20.4s, v17.16b, v31.16b\n"
2232        "ldr q17, [x23, #0x30]\n"
2233        ".inst 0x4e8fa633  // smmla v19.4s, v17.16b, v15.16b\n"
2234        ".inst 0x4e81a634  // smmla v20.4s, v17.16b, v1.16b\n"
2235        "ldr q17, [x23, #0x50]\n"
2236        ".inst 0x4e95a633  // smmla v19.4s, v17.16b, v21.16b\n"
2237        ".inst 0x4e90a634  // smmla v20.4s, v17.16b, v16.16b\n"
2238        "ldr q17, [x23, #0x70]\n"
2239        "add x23, x23, #0x88\n"
2240        ".inst 0x4e9da633  // smmla v19.4s, v17.16b, v29.16b\n"
2241        ".inst 0x4e83a634  // smmla v20.4s, v17.16b, v3.16b\n"
2242        "uzp1 v17.2d, v19.2d, v20.2d\n"
2243        "scvtf v17.4s, v17.4s, #0x4\n"
2244        "uzp2 v20.2d, v19.2d, v20.2d\n"
2245        "fmul v19.4s, v27.4s, v9.s[2]\n"
2246        "fmul v9.4s, v27.4s, v9.s[3]\n"
2247        "scvtf v20.4s, v20.4s, #0x4\n"
2248        "fmla v22.4s, v17.4s, v19.4s\n"
2249        "ldr q17, [x22, #0x10]\n"
2250        "movi v19.4s, #0x0\n"
2251        ".inst 0x4e88a653  // smmla v19.4s, v18.16b, v8.16b\n"
2252        "fmla v23.4s, v20.4s, v9.4s\n"
2253        "movi v20.4s, #0x0\n"
2254        "movi v9.4s, #0x0\n"
2255        ".inst 0x4e9fa654  // smmla v20.4s, v18.16b, v31.16b\n"
2256        "ldr q18, [x22, #0x20]\n"
2257        ".inst 0x4e88a629  // smmla v9.4s, v17.16b, v8.16b\n"
2258        ".inst 0x4e8fa653  // smmla v19.4s, v18.16b, v15.16b\n"
2259        ".inst 0x4e81a654  // smmla v20.4s, v18.16b, v1.16b\n"
2260        "ldr q18, [x22, #0x40]\n"
2261        ".inst 0x4e95a653  // smmla v19.4s, v18.16b, v21.16b\n"
2262        ".inst 0x4e90a654  // smmla v20.4s, v18.16b, v16.16b\n"
2263        "ldr q18, [x22, #0x60]\n"
2264        ".inst 0x4e9da653  // smmla v19.4s, v18.16b, v29.16b\n"
2265        ".inst 0x4e83a654  // smmla v20.4s, v18.16b, v3.16b\n"
2266        "movi v18.4s, #0x0\n"
2267        ".inst 0x4e9fa632  // smmla v18.4s, v17.16b, v31.16b\n"
2268        "ldr q17, [x22, #0x30]\n"
2269        ".inst 0x4e8fa629  // smmla v9.4s, v17.16b, v15.16b\n"
2270        ".inst 0x4e81a632  // smmla v18.4s, v17.16b, v1.16b\n"
2271        "ldr q17, [x22, #0x50]\n"
2272        ".inst 0x4e95a629  // smmla v9.4s, v17.16b, v21.16b\n"
2273        ".inst 0x4e90a632  // smmla v18.4s, v17.16b, v16.16b\n"
2274        "ldr q17, [x22, #0x70]\n"
2275        "add x22, x22, #0x88\n"
2276        ".inst 0x4e9da629  // smmla v9.4s, v17.16b, v29.16b\n"
2277        ".inst 0x4e83a632  // smmla v18.4s, v17.16b, v3.16b\n"
2278        "uzp1 v17.2d, v19.2d, v20.2d\n"
2279        "uzp2 v20.2d, v19.2d, v20.2d\n"
2280        "fmul v19.4s, v27.4s, v0.s[0]\n"
2281        "scvtf v17.4s, v17.4s, #0x4\n"
2282        "scvtf v20.4s, v20.4s, #0x4\n"
2283        "fmla v25.4s, v17.4s, v19.4s\n"
2284        "ldr q19, [x21, #0x0]\n"
2285        "fmul v17.4s, v27.4s, v0.s[1]\n"
2286        "fmla v5.4s, v20.4s, v17.4s\n"
2287        "ldr q17, [x21, #0x10]\n"
2288        "uzp1 v20.2d, v9.2d, v18.2d\n"
2289        "uzp2 v9.2d, v9.2d, v18.2d\n"
2290        "fmul v18.4s, v27.4s, v0.s[2]\n"
2291        "fmul v0.4s, v27.4s, v0.s[3]\n"
2292        "scvtf v20.4s, v20.4s, #0x4\n"
2293        "scvtf v9.4s, v9.4s, #0x4\n"
2294        "fmla v7.4s, v20.4s, v18.4s\n"
2295        "movi v20.4s, #0x0\n"
2296        "movi v18.4s, #0x0\n"
2297        ".inst 0x4e88a674  // smmla v20.4s, v19.16b, v8.16b\n"
2298        ".inst 0x4e9fa672  // smmla v18.4s, v19.16b, v31.16b\n"
2299        "ldr q19, [x21, #0x20]\n"
2300        "fmla v4.4s, v9.4s, v0.4s\n"
2301        "movi v9.4s, #0x0\n"
2302        "movi v0.4s, #0x0\n"
2303        ".inst 0x4e88a629  // smmla v9.4s, v17.16b, v8.16b\n"
2304        "fmul v8.4s, v27.4s, v26.s[0]\n"
2305        ".inst 0x4e9fa620  // smmla v0.4s, v17.16b, v31.16b\n"
2306        "ldr q17, [x21, #0x30]\n"
2307        ".inst 0x4e8fa674  // smmla v20.4s, v19.16b, v15.16b\n"
2308        "fmul v31.4s, v27.4s, v26.s[1]\n"
2309        ".inst 0x4e81a672  // smmla v18.4s, v19.16b, v1.16b\n"
2310        "ldr q19, [x21, #0x40]\n"
2311        ".inst 0x4e8fa629  // smmla v9.4s, v17.16b, v15.16b\n"
2312        "fmul v15.4s, v27.4s, v26.s[2]\n"
2313        "fmul v27.4s, v27.4s, v26.s[3]\n"
2314        ".inst 0x4e81a620  // smmla v0.4s, v17.16b, v1.16b\n"
2315        "ldr q1, [x21, #0x50]\n"
2316        ".inst 0x4e95a674  // smmla v20.4s, v19.16b, v21.16b\n"
2317        ".inst 0x4e90a672  // smmla v18.4s, v19.16b, v16.16b\n"
2318        "ldr q26, [x21, #0x60]\n"
2319        ".inst 0x4e95a429  // smmla v9.4s, v1.16b, v21.16b\n"
2320        ".inst 0x4e90a420  // smmla v0.4s, v1.16b, v16.16b\n"
2321        "ldr q21, [x21, #0x70]\n"
2322        "add x21, x21, #0x88\n"
2323        ".inst 0x4e9da754  // smmla v20.4s, v26.16b, v29.16b\n"
2324        ".inst 0x4e83a752  // smmla v18.4s, v26.16b, v3.16b\n"
2325        ".inst 0x4e9da6a9  // smmla v9.4s, v21.16b, v29.16b\n"
2326        ".inst 0x4e83a6a0  // smmla v0.4s, v21.16b, v3.16b\n"
2327        "uzp1 v29.2d, v20.2d, v18.2d\n"
2328        "uzp2 v21.2d, v20.2d, v18.2d\n"
2329        "scvtf v29.4s, v29.4s, #0x4\n"
2330        "uzp1 v18.2d, v9.2d, v0.2d\n"
2331        "uzp2 v16.2d, v9.2d, v0.2d\n"
2332        "scvtf v21.4s, v21.4s, #0x4\n"
2333        "fmla v6.4s, v29.4s, v8.4s\n"
2334        "scvtf v18.4s, v18.4s, #0x4\n"
2335        "scvtf v16.4s, v16.4s, #0x4\n"
2336        "fmla v30.4s, v21.4s, v31.4s\n"
2337        "fmla v24.4s, v18.4s, v15.4s\n"
2338        "fmla v14.4s, v16.4s, v27.4s\n"
2339        "bgt 3b\n"
2340        "mov x20, %x[res_ptr]\n"
2341        "subs x27, x27, #0x4\n"
2342        "add %x[res_ptr], %x[res_ptr], #0x10\n"
2343        "str q2, [x20, #0x0]\n"
2344        "add x20, x20, %x[res_stride]\n"
2345        "str q10, [x20, #0x0]\n"
2346        "add x20, x20, %x[res_stride]\n"
2347        "str q12, [x20, #0x0]\n"
2348        "add x20, x20, %x[res_stride]\n"
2349        "str q28, [x20, #0x0]\n"
2350        "add x20, x20, %x[res_stride]\n"
2351        "str q11, [x20, #0x0]\n"
2352        "add x20, x20, %x[res_stride]\n"
2353        "str q13, [x20, #0x0]\n"
2354        "add x20, x20, %x[res_stride]\n"
2355        "str q22, [x20, #0x0]\n"
2356        "add x20, x20, %x[res_stride]\n"
2357        "str q23, [x20, #0x0]\n"
2358        "add x20, x20, %x[res_stride]\n"
2359        "str q25, [x20, #0x0]\n"
2360        "add x20, x20, %x[res_stride]\n"
2361        "str q5, [x20, #0x0]\n"
2362        "add x20, x20, %x[res_stride]\n"
2363        "str q7, [x20, #0x0]\n"
2364        "add x20, x20, %x[res_stride]\n"
2365        "str q4, [x20, #0x0]\n"
2366        "add x20, x20, %x[res_stride]\n"
2367        "str q6, [x20, #0x0]\n"
2368        "add x20, x20, %x[res_stride]\n"
2369        "str q30, [x20, #0x0]\n"
2370        "add x20, x20, %x[res_stride]\n"
2371        "str q24, [x20, #0x0]\n"
2372        "add x20, x20, %x[res_stride]\n"
2373        "str q14, [x20, #0x0]\n"
2374        "bne 2b\n"
2375        "mov x20, #0x4\n"
2376        "sub x10, x10, #0x10\n"
2377        "cmp x10, #0x10\n"
2378        "mov %x[res_ptr], x26\n"
2379        "madd %x[a_ptr], x20, x9, %x[a_ptr]\n"
2380        "bge 1b\n"
2381        "4:"  // Row loop skip
2382        "cbz x10, 9f\n"
2383        "5:"  // Row tail: Row loop
2384        "add x24, %x[b_ptr], #0x8\n"
2385        "mov x23, %x[nc]\n"
2386        "add x22, %x[res_ptr], %x[res_stride], LSL #2\n"
2387        "6:"  // Row tail: Column loop
2388        "movi v2.16b, #0x0\n"
2389        "movi v10.16b, #0x0\n"
2390        "add x25, %x[a_ptr], #0x8\n"
2391        "mov x21, %x[nb]\n"
2392        "movi v12.16b, #0x0\n"
2393        "movi v28.16b, #0x0\n"
2394        "7:"  // Row tail: Block loop
2395        "ldr q6, [x24, #0x0]\n"
2396        "ldr q5, [x24, #0x10]\n"
2397        "movi v17.16b, #0x4\n"
2398        "movi v8.4s, #0x0\n"
2399        "ldr q4, [x25, #0x0]\n"
2400        "ldr q13, [x25, #0x10]\n"
2401        "movi v27.4s, #0x0\n"
2402        "movi v0.4s, #0x0\n"
2403        "ldr q31, [x24, #0x20]\n"
2404        "ldr q14, [x24, #0x30]\n"
2405        "movi v29.4s, #0x0\n"
2406        "movi v22.16b, #0xf0\n"
2407        "ldr q11, [x25, #0x20]\n"
2408        "ldr q23, [x25, #0x30]\n"
2409        "sshl v21.16b, v6.16b, v17.16b\n"
2410        "sshl v16.16b, v5.16b, v17.16b\n"
2411        "ldr q20, [x25, #0x40]\n"
2412        "ldr q26, [x25, #0x50]\n"
2413        "and v6.16b, v6.16b, v22.16b\n"
2414        "and v5.16b, v5.16b, v22.16b\n"
2415        "ldr q25, [x25, #0x60]\n"
2416        "ldr q3, [x25, #0x70]\n"
2417        "sshl v19.16b, v31.16b, v17.16b\n"
2418        "sshl v18.16b, v14.16b, v17.16b\n"
2419        "ldr d17, [x25, #-0x8]\n"
2420        ".inst 0x4e95a488  // smmla v8.4s, v4.16b, v21.16b\n"
2421        ".inst 0x4e90a49b  // smmla v27.4s, v4.16b, v16.16b\n"
2422        "and v31.16b, v31.16b, v22.16b\n"
2423        ".inst 0x4e95a5a0  // smmla v0.4s, v13.16b, v21.16b\n"
2424        ".inst 0x4e90a5bd  // smmla v29.4s, v13.16b, v16.16b\n"
2425        "and v14.16b, v14.16b, v22.16b\n"
2426        "sub x20, x24, #0x8\n"
2427        "ldr d16, [x20, #0x0]\n"
2428        "subs x21, x21, #0x1\n"
2429        "add x25, x25, #0x88\n"
2430        "fcvtl v17.4s, v17.4h\n"
2431        "add x24, x24, #0x48\n"
2432        ".inst 0x4e93a568  // smmla v8.4s, v11.16b, v19.16b\n"
2433        ".inst 0x4e92a57b  // smmla v27.4s, v11.16b, v18.16b\n"
2434        ".inst 0x4e93a6e0  // smmla v0.4s, v23.16b, v19.16b\n"
2435        ".inst 0x4e92a6fd  // smmla v29.4s, v23.16b, v18.16b\n"
2436        "fcvtl v16.4s, v16.4h\n"
2437        ".inst 0x4e86a688  // smmla v8.4s, v20.16b, v6.16b\n"
2438        ".inst 0x4e85a69b  // smmla v27.4s, v20.16b, v5.16b\n"
2439        "fmul v23.4s, v16.4s, v17.s[0]\n"
2440        "fmul v21.4s, v16.4s, v17.s[1]\n"
2441        "fmul v1.4s, v16.4s, v17.s[2]\n"
2442        "fmul v20.4s, v16.4s, v17.s[3]\n"
2443        ".inst 0x4e86a740  // smmla v0.4s, v26.16b, v6.16b\n"
2444        ".inst 0x4e85a75d  // smmla v29.4s, v26.16b, v5.16b\n"
2445        ".inst 0x4e9fa728  // smmla v8.4s, v25.16b, v31.16b\n"
2446        ".inst 0x4e8ea73b  // smmla v27.4s, v25.16b, v14.16b\n"
2447        ".inst 0x4e9fa460  // smmla v0.4s, v3.16b, v31.16b\n"
2448        ".inst 0x4e8ea47d  // smmla v29.4s, v3.16b, v14.16b\n"
2449        "uzp1 v19.2d, v8.2d, v27.2d\n"
2450        "uzp2 v18.2d, v8.2d, v27.2d\n"
2451        "scvtf v19.4s, v19.4s, #0x4\n"
2452        "uzp1 v17.2d, v0.2d, v29.2d\n"
2453        "uzp2 v16.2d, v0.2d, v29.2d\n"
2454        "scvtf v18.4s, v18.4s, #0x4\n"
2455        "fmla v2.4s, v19.4s, v23.4s\n"
2456        "scvtf v17.4s, v17.4s, #0x4\n"
2457        "scvtf v16.4s, v16.4s, #0x4\n"
2458        "fmla v10.4s, v18.4s, v21.4s\n"
2459        "fmla v12.4s, v17.4s, v1.4s\n"
2460        "fmla v28.4s, v16.4s, v20.4s\n"
2461        "bgt 7b\n"
2462        "mov x20, %x[res_ptr]\n"
2463        "cmp x10, #0x1\n"
2464        "str q2, [x20, #0x0]\n"
2465        "add x20, x20, %x[res_stride]\n"
2466        "ble 8f\n"
2467        "cmp x10, #0x2\n"
2468        "str q10, [x20, #0x0]\n"
2469        "add x20, x20, %x[res_stride]\n"
2470        "ble 8f\n"
2471        "cmp x10, #0x3\n"
2472        "str q12, [x20, #0x0]\n"
2473        "add x20, x20, %x[res_stride]\n"
2474        "ble 8f\n"
2475        "str q28, [x20, #0x0]\n"
2476        "8:"  // Row tail: Accumulator store skip
2477        "subs x23, x23, #0x4\n"
2478        "add %x[res_ptr], %x[res_ptr], #0x10\n"
2479        "bne 6b\n"
2480        "subs x10, x10, #0x4\n"
2481        "add %x[a_ptr], %x[a_ptr], x9\n"
2482        "mov %x[res_ptr], x22\n"
2483        "bgt 5b\n"
2484        "9:"  // Row tail: Row loop skip
2485        : [a_ptr] "+&r" (a_ptr), [res_ptr] "+&r" (res_ptr)
2486        : [b_ptr] "r" (b_ptr), [nr] "r" (nr), [nb] "r" (nb), [res_stride] "r" (res_stride), [nc] "r" (nc)
2487        : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"
2488    );
2489    return;
2490#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
2491    ggml_gemm_q4_0_4x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
2492}
2493
2494void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
2495    const int qk = QK8_0;
2496    const int nb = n / qk;
2497    const int ncols_interleaved = 8;
2498    const int blocklen = 8;
2499
2500    assert (n % qk == 0);
2501    assert (nr % 4 == 0);
2502    assert (nc % ncols_interleaved == 0);
2503
2504    UNUSED(s);
2505    UNUSED(bs);
2506    UNUSED(vx);
2507    UNUSED(vy);
2508    UNUSED(nr);
2509    UNUSED(nc);
2510    UNUSED(nb);
2511    UNUSED(ncols_interleaved);
2512    UNUSED(blocklen);
2513
2514#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__)
2515#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
2516    if (ggml_cpu_get_sve_cnt() == QK8_0) {
2517        const void * b_ptr = vx;
2518        const void * a_ptr = vy;
2519        float * res_ptr = s;
2520        size_t res_stride = bs * sizeof(float);
2521
2522        __asm__ __volatile__(
2523            "mov x20, #0x4\n"
2524            "mov x13, %x[nr]\n"
2525            "mov z28.s, #-0x4\n"
2526            "mov x12, #0x88\n"
2527            "ptrue p1.b\n"
2528            "whilelt p0.s, XZR, x20\n"
2529            "cmp x13, #0x10\n"
2530            "mul x12, %x[nb], x12\n"
2531            "blt 4f\n"
2532            "1:"  // Row loop
2533            "add x11, %x[b_ptr], #0x10\n"
2534            "mov x10, %x[nc]\n"
2535            "add x9, %x[res_ptr], %x[res_stride], LSL #4\n"
2536            "2:"  // Column loop
2537            "add x28, %x[a_ptr], #0x8\n"
2538            "mov z24.b, #0x0\n"
2539            "mov z15.b, #0x0\n"
2540            "mov x27, %x[nb]\n"
2541            "add x26, x28, x12\n"
2542            "mov z12.b, #0x0\n"
2543            "mov z0.b, #0x0\n"
2544            "add x25, x26, x12\n"
2545            "mov z13.b, #0x0\n"
2546            "mov z1.b, #0x0\n"
2547            "add x24, x25, x12\n"
2548            "mov z20.b, #0x0\n"
2549            "mov z25.b, #0x0\n"
2550            "mov z11.b, #0x0\n"
2551            "mov z16.b, #0x0\n"
2552            "mov z19.b, #0x0\n"
2553            "mov z26.b, #0x0\n"
2554            "mov z8.b, #0x0\n"
2555            "mov z29.b, #0x0\n"
2556            "mov z27.b, #0x0\n"
2557            "mov z10.b, #0x0\n"
2558            "3:"  // Block loop
2559            "ld1b { z30.b }, p1/Z, [x11]\n"
2560            "ld1b { z21.b }, p1/Z, [x11, #1, MUL VL]\n"
2561            "mov z18.s, #0x0\n"
2562            "mov z7.s, #0x0\n"
2563            "ld1rqb { z3.b }, p1/Z, [x28]\n"
2564            "ld1rqb { z5.b }, p1/Z, [x28, #16]\n"
2565            "mov z9.s, #0x0\n"
2566            "mov z22.s, #0x0\n"
2567            "ld1b { z4.b }, p1/Z, [x11, #2, MUL VL]\n"
2568            "ld1b { z17.b }, p1/Z, [x11, #3, MUL VL]\n"
2569            "sub x20, x11, #0x10\n"
2570            "sub x23, x28, #0x8\n"
2571            "lsl z31.b, z30.b, #0x4\n"
2572            "lsl z6.b, z21.b, #0x4\n"
2573            "ld1h { z23.s }, p1/Z, [x20]\n"
2574            "sub x22, x26, #0x8\n"
2575            "and z30.b, z30.b, #0xf0\n"
2576            "and z21.b, z21.b, #0xf0\n"
2577            "sub x21, x25, #0x8\n"
2578            "sub x20, x24, #0x8\n"
2579            "lsl z14.b, z4.b, #0x4\n"
2580            "lsl z2.b, z17.b, #0x4\n"
2581            "subs x27, x27, #0x1\n"
2582            "add x11, x11, #0x90\n"
2583            ".inst 0x451f9872  // smmla z18.s, z3.b, z31.b\n"
2584            ".inst 0x45069867  // smmla z7.s, z3.b, z6.b\n"
2585            "ld1rqb { z3.b }, p1/Z, [x28, #32]\n"
2586            "and z4.b, z4.b, #0xf0\n"
2587            ".inst 0x451f98a9  // smmla z9.s, z5.b, z31.b\n"
2588            ".inst 0x450698b6  // smmla z22.s, z5.b, z6.b\n"
2589            "ld1rqb { z5.b }, p1/Z, [x28, #48]\n"
2590            "and z17.b, z17.b, #0xf0\n"
2591            "fcvt z23.s, p1/m, z23.h\n"
2592            ".inst 0x450e9872  // smmla z18.s, z3.b, z14.b\n"
2593            ".inst 0x45029867  // smmla z7.s, z3.b, z2.b\n"
2594            "ld1rqb { z3.b }, p1/Z, [x28, #64]\n"
2595            ".inst 0x450e98a9  // smmla z9.s, z5.b, z14.b\n"
2596            ".inst 0x450298b6  // smmla z22.s, z5.b, z2.b\n"
2597            "ld1rqb { z5.b }, p1/Z, [x28, #80]\n"
2598            "fscale z23.s, p1/m, z23.s, z28.s\n"
2599            ".inst 0x451e9872  // smmla z18.s, z3.b, z30.b\n"
2600            ".inst 0x45159867  // smmla z7.s, z3.b, z21.b\n"
2601            "ld1rqb { z3.b }, p1/Z, [x28, #96]\n"
2602            ".inst 0x451e98a9  // smmla z9.s, z5.b, z30.b\n"
2603            ".inst 0x451598b6  // smmla z22.s, z5.b, z21.b\n"
2604            "ld1rqb { z5.b }, p1/Z, [x28, #112]\n"
2605            "add x28, x28, #0x88\n"
2606            ".inst 0x45049872  // smmla z18.s, z3.b, z4.b\n"
2607            ".inst 0x45119867  // smmla z7.s, z3.b, z17.b\n"
2608            "ld1h { z3.s }, p0/Z, [x23]\n"
2609            ".inst 0x450498a9  // smmla z9.s, z5.b, z4.b\n"
2610            ".inst 0x451198b6  // smmla z22.s, z5.b, z17.b\n"
2611            "fcvt z3.s, p1/m, z3.h\n"
2612            "uzp1 z5.d, z18.d, z7.d\n"
2613            "uzp2 z18.d, z18.d, z7.d\n"
2614            "mov z3.q, z3.q[0]\n"
2615            "uzp1 z7.d, z9.d, z22.d\n"
2616            "uzp2 z22.d, z9.d, z22.d\n"
2617            "fmul z9.s, z23.s, z3.s[0]\n"
2618            "scvtf z5.s, p1/m, z5.s\n"
2619            "scvtf z18.s, p1/m, z18.s\n"
2620            "scvtf z7.s, p1/m, z7.s\n"
2621            "scvtf z22.s, p1/m, z22.s\n"
2622            "fmla z24.s, p1/M, z5.s, z9.s\n"
2623            "ld1rqb { z5.b }, p1/Z, [x26]\n"
2624            "fmul z9.s, z23.s, z3.s[1]\n"
2625            "fmla z15.s, p1/M, z18.s, z9.s\n"
2626            "ld1rqb { z18.b }, p1/Z, [x26, #16]\n"
2627            "fmul z9.s, z23.s, z3.s[2]\n"
2628            "fmul z3.s, z23.s, z3.s[3]\n"
2629            "fmla z12.s, p1/M, z7.s, z9.s\n"
2630            "mov z9.s, #0x0\n"
2631            "ld1h { z7.s }, p0/Z, [x22]\n"
2632            ".inst 0x451f98a9  // smmla z9.s, z5.b, z31.b\n"
2633            "fmla z0.s, p1/M, z22.s, z3.s\n"
2634            "mov z22.s, #0x0\n"
2635            "ld1h { z3.s }, p0/Z, [x21]\n"
2636            ".inst 0x450698b6  // smmla z22.s, z5.b, z6.b\n"
2637            "ld1rqb { z5.b }, p1/Z, [x26, #32]\n"
2638            "fcvt z7.s, p1/m, z7.h\n"
2639            "fcvt z3.s, p1/m, z3.h\n"
2640            ".inst 0x450e98a9  // smmla z9.s, z5.b, z14.b\n"
2641            ".inst 0x450298b6  // smmla z22.s, z5.b, z2.b\n"
2642            "ld1rqb { z5.b }, p1/Z, [x26, #64]\n"
2643            "mov z7.q, z7.q[0]\n"
2644            "mov z3.q, z3.q[0]\n"
2645            ".inst 0x451e98a9  // smmla z9.s, z5.b, z30.b\n"
2646            ".inst 0x451598b6  // smmla z22.s, z5.b, z21.b\n"
2647            "ld1rqb { z5.b }, p1/Z, [x26, #96]\n"
2648            ".inst 0x450498a9  // smmla z9.s, z5.b, z4.b\n"
2649            ".inst 0x451198b6  // smmla z22.s, z5.b, z17.b\n"
2650            "uzp1 z5.d, z9.d, z22.d\n"
2651            "scvtf z5.s, p1/m, z5.s\n"
2652            "uzp2 z22.d, z9.d, z22.d\n"
2653            "fmul z9.s, z23.s, z7.s[0]\n"
2654            "scvtf z22.s, p1/m, z22.s\n"
2655            "fmla z13.s, p1/M, z5.s, z9.s\n"
2656            "ld1rqb { z9.b }, p1/Z, [x25]\n"
2657            "fmul z5.s, z23.s, z7.s[1]\n"
2658            "fmla z1.s, p1/M, z22.s, z5.s\n"
2659            "mov z5.s, #0x0\n"
2660            "mov z22.s, #0x0\n"
2661            ".inst 0x451f9a45  // smmla z5.s, z18.b, z31.b\n"
2662            ".inst 0x45069a56  // smmla z22.s, z18.b, z6.b\n"
2663            "ld1rqb { z18.b }, p1/Z, [x26, #48]\n"
2664            ".inst 0x450e9a45  // smmla z5.s, z18.b, z14.b\n"
2665            ".inst 0x45029a56  // smmla z22.s, z18.b, z2.b\n"
2666            "ld1rqb { z18.b }, p1/Z, [x26, #80]\n"
2667            ".inst 0x451e9a45  // smmla z5.s, z18.b, z30.b\n"
2668            ".inst 0x45159a56  // smmla z22.s, z18.b, z21.b\n"
2669            "ld1rqb { z18.b }, p1/Z, [x26, #112]\n"
2670            "add x26, x26, #0x88\n"
2671            ".inst 0x45049a45  // smmla z5.s, z18.b, z4.b\n"
2672            ".inst 0x45119a56  // smmla z22.s, z18.b, z17.b\n"
2673            "uzp1 z18.d, z5.d, z22.d\n"
2674            "scvtf z18.s, p1/m, z18.s\n"
2675            "uzp2 z22.d, z5.d, z22.d\n"
2676            "fmul z5.s, z23.s, z7.s[2]\n"
2677            "fmul z7.s, z23.s, z7.s[3]\n"
2678            "scvtf z22.s, p1/m, z22.s\n"
2679            "fmla z20.s, p1/M, z18.s, z5.s\n"
2680            "ld1rqb { z18.b }, p1/Z, [x25, #16]\n"
2681            "ld1h { z5.s }, p0/Z, [x20]\n"
2682            "fcvt z5.s, p1/m, z5.h\n"
2683            "fmla z25.s, p1/M, z22.s, z7.s\n"
2684            "mov z22.s, #0x0\n"
2685            "mov z7.s, #0x0\n"
2686            ".inst 0x451f9936  // smmla z22.s, z9.b, z31.b\n"
2687            ".inst 0x45069927  // smmla z7.s, z9.b, z6.b\n"
2688            "ld1rqb { z9.b }, p1/Z, [x25, #32]\n"
2689            "mov z5.q, z5.q[0]\n"
2690            ".inst 0x450e9936  // smmla z22.s, z9.b, z14.b\n"
2691            ".inst 0x45029927  // smmla z7.s, z9.b, z2.b\n"
2692            "ld1rqb { z9.b }, p1/Z, [x25, #64]\n"
2693            ".inst 0x451e9936  // smmla z22.s, z9.b, z30.b\n"
2694            ".inst 0x45159927  // smmla z7.s, z9.b, z21.b\n"
2695            "ld1rqb { z9.b }, p1/Z, [x25, #96]\n"
2696            ".inst 0x45049936  // smmla z22.s, z9.b, z4.b\n"
2697            ".inst 0x45119927  // smmla z7.s, z9.b, z17.b\n"
2698            "uzp1 z9.d, z22.d, z7.d\n"
2699            "scvtf z9.s, p1/m, z9.s\n"
2700            "uzp2 z22.d, z22.d, z7.d\n"
2701            "fmul z7.s, z23.s, z3.s[0]\n"
2702            "scvtf z22.s, p1/m, z22.s\n"
2703            "fmla z11.s, p1/M, z9.s, z7.s\n"
2704            "ld1rqb { z9.b }, p1/Z, [x24]\n"
2705            "fmul z7.s, z23.s, z3.s[1]\n"
2706            "fmla z16.s, p1/M, z22.s, z7.s\n"
2707            "mov z22.s, #0x0\n"
2708            "mov z7.s, #0x0\n"
2709            ".inst 0x451f9a56  // smmla z22.s, z18.b, z31.b\n"
2710            ".inst 0x45069a47  // smmla z7.s, z18.b, z6.b\n"
2711            "ld1rqb { z18.b }, p1/Z, [x25, #48]\n"
2712            ".inst 0x450e9a56  // smmla z22.s, z18.b, z14.b\n"
2713            ".inst 0x45029a47  // smmla z7.s, z18.b, z2.b\n"
2714            "ld1rqb { z18.b }, p1/Z, [x25, #80]\n"
2715            ".inst 0x451e9a56  // smmla z22.s, z18.b, z30.b\n"
2716            ".inst 0x45159a47  // smmla z7.s, z18.b, z21.b\n"
2717            "ld1rqb { z18.b }, p1/Z, [x25, #112]\n"
2718            "add x25, x25, #0x88\n"
2719            ".inst 0x45049a56  // smmla z22.s, z18.b, z4.b\n"
2720            ".inst 0x45119a47  // smmla z7.s, z18.b, z17.b\n"
2721            "uzp1 z18.d, z22.d, z7.d\n"
2722            "scvtf z18.s, p1/m, z18.s\n"
2723            "uzp2 z7.d, z22.d, z7.d\n"
2724            "fmul z22.s, z23.s, z3.s[2]\n"
2725            "fmul z3.s, z23.s, z3.s[3]\n"
2726            "scvtf z7.s, p1/m, z7.s\n"
2727            "fmla z19.s, p1/M, z18.s, z22.s\n"
2728            "ld1rqb { z18.b }, p1/Z, [x24, #16]\n"
2729            "fmul z22.s, z23.s, z5.s[0]\n"
2730            "fmla z26.s, p1/M, z7.s, z3.s\n"
2731            "mov z3.s, #0x0\n"
2732            "mov z7.s, #0x0\n"
2733            ".inst 0x451f9923  // smmla z3.s, z9.b, z31.b\n"
2734            ".inst 0x45069927  // smmla z7.s, z9.b, z6.b\n"
2735            "ld1rqb { z9.b }, p1/Z, [x24, #32]\n"
2736            ".inst 0x450e9923  // smmla z3.s, z9.b, z14.b\n"
2737            ".inst 0x45029927  // smmla z7.s, z9.b, z2.b\n"
2738            "mov z9.s, #0x0\n"
2739            ".inst 0x451f9a49  // smmla z9.s, z18.b, z31.b\n"
2740            "mov z31.s, #0x0\n"
2741            ".inst 0x45069a5f  // smmla z31.s, z18.b, z6.b\n"
2742            "ld1rqb { z6.b }, p1/Z, [x24, #48]\n"
2743            "ld1rqb { z18.b }, p1/Z, [x24, #64]\n"
2744            ".inst 0x450e98c9  // smmla z9.s, z6.b, z14.b\n"
2745            "fmul z14.s, z23.s, z5.s[1]\n"
2746            ".inst 0x450298df  // smmla z31.s, z6.b, z2.b\n"
2747            "ld1rqb { z6.b }, p1/Z, [x24, #80]\n"
2748            "fmul z2.s, z23.s, z5.s[2]\n"
2749            "fmul z23.s, z23.s, z5.s[3]\n"
2750            ".inst 0x451e9a43  // smmla z3.s, z18.b, z30.b\n"
2751            ".inst 0x45159a47  // smmla z7.s, z18.b, z21.b\n"
2752            "ld1rqb { z5.b }, p1/Z, [x24, #96]\n"
2753            ".inst 0x451e98c9  // smmla z9.s, z6.b, z30.b\n"
2754            ".inst 0x451598df  // smmla z31.s, z6.b, z21.b\n"
2755            "ld1rqb { z18.b }, p1/Z, [x24, #112]\n"
2756            "add x24, x24, #0x88\n"
2757            ".inst 0x450498a3  // smmla z3.s, z5.b, z4.b\n"
2758            ".inst 0x451198a7  // smmla z7.s, z5.b, z17.b\n"
2759            ".inst 0x45049a49  // smmla z9.s, z18.b, z4.b\n"
2760            ".inst 0x45119a5f  // smmla z31.s, z18.b, z17.b\n"
2761            "uzp1 z18.d, z3.d, z7.d\n"
2762            "uzp2 z5.d, z3.d, z7.d\n"
2763            "scvtf z18.s, p1/m, z18.s\n"
2764            "uzp1 z6.d, z9.d, z31.d\n"
2765            "uzp2 z9.d, z9.d, z31.d\n"
2766            "scvtf z5.s, p1/m, z5.s\n"
2767            "fmla z8.s, p1/M, z18.s, z22.s\n"
2768            "scvtf z6.s, p1/m, z6.s\n"
2769            "scvtf z9.s, p1/m, z9.s\n"
2770            "fmla z29.s, p1/M, z5.s, z14.s\n"
2771            "fmla z27.s, p1/M, z6.s, z2.s\n"
2772            "fmla z10.s, p1/M, z9.s, z23.s\n"
2773            "bgt 3b\n"
2774            "mov x20, %x[res_ptr]\n"
2775            "subs x10, x10, #0x8\n"
2776            "add %x[res_ptr], %x[res_ptr], #0x20\n"
2777            "st1w { z24.s }, p1, [x20]\n"
2778            "add x20, x20, %x[res_stride]\n"
2779            "st1w { z15.s }, p1, [x20]\n"
2780            "add x20, x20, %x[res_stride]\n"
2781            "st1w { z12.s }, p1, [x20]\n"
2782            "add x20, x20, %x[res_stride]\n"
2783            "st1w { z0.s }, p1, [x20]\n"
2784            "add x20, x20, %x[res_stride]\n"
2785            "st1w { z13.s }, p1, [x20]\n"
2786            "add x20, x20, %x[res_stride]\n"
2787            "st1w { z1.s }, p1, [x20]\n"
2788            "add x20, x20, %x[res_stride]\n"
2789            "st1w { z20.s }, p1, [x20]\n"
2790            "add x20, x20, %x[res_stride]\n"
2791            "st1w { z25.s }, p1, [x20]\n"
2792            "add x20, x20, %x[res_stride]\n"
2793            "st1w { z11.s }, p1, [x20]\n"
2794            "add x20, x20, %x[res_stride]\n"
2795            "st1w { z16.s }, p1, [x20]\n"
2796            "add x20, x20, %x[res_stride]\n"
2797            "st1w { z19.s }, p1, [x20]\n"
2798            "add x20, x20, %x[res_stride]\n"
2799            "st1w { z26.s }, p1, [x20]\n"
2800            "add x20, x20, %x[res_stride]\n"
2801            "st1w { z8.s }, p1, [x20]\n"
2802            "add x20, x20, %x[res_stride]\n"
2803            "st1w { z29.s }, p1, [x20]\n"
2804            "add x20, x20, %x[res_stride]\n"
2805            "st1w { z27.s }, p1, [x20]\n"
2806            "add x20, x20, %x[res_stride]\n"
2807            "st1w { z10.s }, p1, [x20]\n"
2808            "bne 2b\n"
2809            "mov x20, #0x4\n"
2810            "sub x13, x13, #0x10\n"
2811            "cmp x13, #0x10\n"
2812            "mov %x[res_ptr], x9\n"
2813            "madd %x[a_ptr], x20, x12, %x[a_ptr]\n"
2814            "bge 1b\n"
2815            "4:"  // Row loop skip
2816            "cbz x13, 9f\n"
2817            "5:"  // Row tail: Row loop
2818            "add x25, %x[b_ptr], #0x10\n"
2819            "mov x24, %x[nc]\n"
2820            "add x23, %x[res_ptr], %x[res_stride], LSL #2\n"
2821            "6:"  // Row tail: Column loop
2822            "mov z24.b, #0x0\n"
2823            "mov z15.b, #0x0\n"
2824            "add x28, %x[a_ptr], #0x8\n"
2825            "mov x22, %x[nb]\n"
2826            "mov z12.b, #0x0\n"
2827            "mov z0.b, #0x0\n"
2828            "7:"  // Row tail: Block loop
2829            "ld1b { z3.b }, p1/Z, [x25]\n"
2830            "ld1b { z6.b }, p1/Z, [x25, #1, MUL VL]\n"
2831            "mov z2.s, #0x0\n"
2832            "mov z25.s, #0x0\n"
2833            "ld1rqb { z26.b }, p1/Z, [x28]\n"
2834            "ld1rqb { z21.b }, p1/Z, [x28, #16]\n"
2835            "mov z27.s, #0x0\n"
2836            "mov z19.s, #0x0\n"
2837            "ld1b { z29.b }, p1/Z, [x25, #2, MUL VL]\n"
2838            "ld1b { z16.b }, p1/Z, [x25, #3, MUL VL]\n"
2839            "sub x21, x25, #0x10\n"
2840            "sub x20, x28, #0x8\n"
2841            "lsl z20.b, z3.b, #0x4\n"
2842            "lsl z4.b, z6.b, #0x4\n"
2843            "ld1rqb { z10.b }, p1/Z, [x28, #32]\n"
2844            "ld1rqb { z23.b }, p1/Z, [x28, #48]\n"
2845            "and z3.b, z3.b, #0xf0\n"
2846            "and z6.b, z6.b, #0xf0\n"
2847            "ld1rqb { z11.b }, p1/Z, [x28, #64]\n"
2848            "ld1rqb { z7.b }, p1/Z, [x28, #80]\n"
2849            "lsl z8.b, z29.b, #0x4\n"
2850            "lsl z14.b, z16.b, #0x4\n"
2851            "ld1rqb { z18.b }, p1/Z, [x28, #96]\n"
2852            "ld1rqb { z30.b }, p1/Z, [x28, #112]\n"
2853            ".inst 0x45149b42  // smmla z2.s, z26.b, z20.b\n"
2854            ".inst 0x45049b59  // smmla z25.s, z26.b, z4.b\n"
2855            "and z29.b, z29.b, #0xf0\n"
2856            "ld1h { z17.s }, p1/Z, [x21]\n"
2857            ".inst 0x45149abb  // smmla z27.s, z21.b, z20.b\n"
2858            ".inst 0x45049ab3  // smmla z19.s, z21.b, z4.b\n"
2859            "and z16.b, z16.b, #0xf0\n"
2860            "ld1h { z4.s }, p0/Z, [x20]\n"
2861            "subs x22, x22, #0x1\n"
2862            "add x28, x28, #0x88\n"
2863            "fcvt z17.s, p1/m, z17.h\n"
2864            "add x25, x25, #0x90\n"
2865            ".inst 0x45089942  // smmla z2.s, z10.b, z8.b\n"
2866            ".inst 0x450e9959  // smmla z25.s, z10.b, z14.b\n"
2867            "fcvt z4.s, p1/m, z4.h\n"
2868            ".inst 0x45089afb  // smmla z27.s, z23.b, z8.b\n"
2869            ".inst 0x450e9af3  // smmla z19.s, z23.b, z14.b\n"
2870            "fscale z17.s, p1/m, z17.s, z28.s\n"
2871            "mov z4.q, z4.q[0]\n"
2872            ".inst 0x45039962  // smmla z2.s, z11.b, z3.b\n"
2873            ".inst 0x45069979  // smmla z25.s, z11.b, z6.b\n"
2874            "fmul z23.s, z17.s, z4.s[0]\n"
2875            "fmul z9.s, z17.s, z4.s[1]\n"
2876            "fmul z21.s, z17.s, z4.s[2]\n"
2877            "fmul z4.s, z17.s, z4.s[3]\n"
2878            ".inst 0x450398fb  // smmla z27.s, z7.b, z3.b\n"
2879            ".inst 0x450698f3  // smmla z19.s, z7.b, z6.b\n"
2880            ".inst 0x451d9a42  // smmla z2.s, z18.b, z29.b\n"
2881            ".inst 0x45109a59  // smmla z25.s, z18.b, z16.b\n"
2882            ".inst 0x451d9bdb  // smmla z27.s, z30.b, z29.b\n"
2883            ".inst 0x45109bd3  // smmla z19.s, z30.b, z16.b\n"
2884            "uzp1 z31.d, z2.d, z25.d\n"
2885            "uzp2 z13.d, z2.d, z25.d\n"
2886            "scvtf z31.s, p1/m, z31.s\n"
2887            "uzp1 z17.d, z27.d, z19.d\n"
2888            "uzp2 z18.d, z27.d, z19.d\n"
2889            "scvtf z13.s, p1/m, z13.s\n"
2890            "fmla z24.s, p1/M, z31.s, z23.s\n"
2891            "scvtf z17.s, p1/m, z17.s\n"
2892            "scvtf z18.s, p1/m, z18.s\n"
2893            "fmla z15.s, p1/M, z13.s, z9.s\n"
2894            "fmla z12.s, p1/M, z17.s, z21.s\n"
2895            "fmla z0.s, p1/M, z18.s, z4.s\n"
2896            "bgt 7b\n"
2897            "mov x20, %x[res_ptr]\n"
2898            "cmp x13, #0x1\n"
2899            "st1w { z24.s }, p1, [x20]\n"
2900            "add x20, x20, %x[res_stride]\n"
2901            "ble 8f\n"
2902            "cmp x13, #0x2\n"
2903            "st1w { z15.s }, p1, [x20]\n"
2904            "add x20, x20, %x[res_stride]\n"
2905            "ble 8f\n"
2906            "cmp x13, #0x3\n"
2907            "st1w { z12.s }, p1, [x20]\n"
2908            "add x20, x20, %x[res_stride]\n"
2909            "ble 8f\n"
2910            "st1w { z0.s }, p1, [x20]\n"
2911            "8:"  // Row tail: Accumulator store skip
2912            "subs x24, x24, #0x8\n"
2913            "add %x[res_ptr], %x[res_ptr], #0x20\n"
2914            "bne 6b\n"
2915            "subs x13, x13, #0x4\n"
2916            "add %x[a_ptr], %x[a_ptr], x12\n"
2917            "mov %x[res_ptr], x23\n"
2918            "bgt 5b\n"
2919            "9:"  // Row tail: Row loop skip
2920            : [a_ptr] "+&r" (a_ptr), [res_ptr] "+&r" (res_ptr)
2921            : [b_ptr] "r" (b_ptr), [nr] "r" (nr), [nb] "r" (nb), [res_stride] "r" (res_stride), [nc] "r" (nc)
2922            : "cc", "memory", "p0", "p1", "x9", "x10", "x11", "x12", "x13", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31"
2923        );
2924        return;
2925    }
2926#endif // #if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
2927
2928#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__)
2929    ggml_gemm_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
2930}
2931
2932void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
2933    const int qk = QK8_0;
2934    const int nb = n / qk;
2935    const int ncols_interleaved = 4;
2936    const int blocklen = 4;
2937
2938    assert (n % qk == 0);
2939    assert (nr % 4 == 0);
2940    assert (nc % ncols_interleaved == 0);
2941
2942    UNUSED(s);
2943    UNUSED(bs);
2944    UNUSED(vx);
2945    UNUSED(vy);
2946    UNUSED(nr);
2947    UNUSED(nc);
2948    UNUSED(nb);
2949    UNUSED(ncols_interleaved);
2950    UNUSED(blocklen);
2951
2952#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
2953    const int8x16_t kvalues = vld1q_s8(kvalues_iq4nl);
2954
2955    for (int y = 0; y < nr / 4; y++) {
2956        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
2957        for (int x = 0; x < nc / ncols_interleaved; x++) {
2958            const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb);
2959
2960            float32x4_t sumf[4];
2961            for (int m = 0; m < 4; m++) {
2962                sumf[m] = vdupq_n_f32(0);
2963            }
2964
2965            for (int l = 0; l < nb; l++) {
2966                float32x4_t a_d = vcvt_f32_f16(vld1_f16((const float16_t *)a_ptr[l].d));
2967                float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *)b_ptr[l].d));
2968
2969                int32x4_t sumi_0 = vdupq_n_s32(0);
2970                int32x4_t sumi_1 = vdupq_n_s32(0);
2971                int32x4_t sumi_2 = vdupq_n_s32(0);
2972                int32x4_t sumi_3 = vdupq_n_s32(0);
2973
2974                for (int k = 0; k < 4; k++) {
2975                    int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 16 * k + 0);
2976                    int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16 * k + 64);
2977
2978                    uint8x16_t b = vld1q_u8(b_ptr[l].qs + 16 * k);
2979                    int8x16_t b_hi = vqtbl1q_s8(kvalues, b >> 4);
2980                    int8x16_t b_lo = vqtbl1q_s8(kvalues, b & 0xF);
2981
2982                    sumi_0 = vdotq_laneq_s32(sumi_0, b_lo, a_0, 0);
2983                    sumi_1 = vdotq_laneq_s32(sumi_1, b_lo, a_0, 1);
2984                    sumi_2 = vdotq_laneq_s32(sumi_2, b_lo, a_0, 2);
2985                    sumi_3 = vdotq_laneq_s32(sumi_3, b_lo, a_0, 3);
2986                    sumi_0 = vdotq_laneq_s32(sumi_0, b_hi, a_1, 0);
2987                    sumi_1 = vdotq_laneq_s32(sumi_1, b_hi, a_1, 1);
2988                    sumi_2 = vdotq_laneq_s32(sumi_2, b_hi, a_1, 2);
2989                    sumi_3 = vdotq_laneq_s32(sumi_3, b_hi, a_1, 3);
2990                }
2991
2992                sumf[0] = vmlaq_f32(sumf[0], vmulq_laneq_f32(b_d, a_d, 0), vcvtq_f32_s32(sumi_0));
2993                sumf[1] = vmlaq_f32(sumf[1], vmulq_laneq_f32(b_d, a_d, 1), vcvtq_f32_s32(sumi_1));
2994                sumf[2] = vmlaq_f32(sumf[2], vmulq_laneq_f32(b_d, a_d, 2), vcvtq_f32_s32(sumi_2));
2995                sumf[3] = vmlaq_f32(sumf[3], vmulq_laneq_f32(b_d, a_d, 3), vcvtq_f32_s32(sumi_3));
2996            }
2997
2998            for (int m = 0; m < 4; m++) {
2999                vst1q_f32(s + (y * 4 + m) * bs + x * 4, sumf[m]);
3000            }
3001        }
3002    }
3003    return;
3004#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
3005    ggml_gemm_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
3006}
3007
3008void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
3009    constexpr int qk = QK_K;
3010    const int     nb = n / qk;
3011
3012    constexpr int ncols_interleaved = 8;
3013    constexpr int blocklen          = 4;
3014
3015    assert(n % qk == 0);
3016    assert(nr % 4 == 0);
3017    assert(nc % ncols_interleaved == 0);
3018
3019    UNUSED(nb);
3020    UNUSED(ncols_interleaved);
3021    UNUSED(blocklen);
3022
3023#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
3024    constexpr int    q8_k_blocklen = 4;
3025    constexpr int    acc_size  = 2 * 4;  // 2 row pairs × 4 col pairs
3026    const uint8x16_t m4b       = vdupq_n_u8(0x0f);
3027
3028    // 8 accumulators: 2 row pairs × 4 col pairs
3029    float32x4_t acc_f32[acc_size];
3030
3031    for (int y = 0; y < nr / q8_k_blocklen; y++) {
3032        const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
3033
3034        for (int x = 0; x < nc / ncols_interleaved; x++) {
3035            const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);
3036
3037            for (int i = 0; i < acc_size; i++) {
3038                acc_f32[i] = vdupq_n_f32(0);
3039            }
3040
3041            for (int b = 0; b < nb; b++) {
3042                // d4 0 1 2 3, 4 5 6 7
3043                float32x4_t q4_d_0123    = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d));
3044                float32x4_t q4_d_4567    = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d + 4));
3045                // d8 0 1 2 3
3046                float32x4_t q8_d_0123    = vld1q_f32(q8_ptr[b].d);
3047                // mins
3048                float32x4_t q4_dmin_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin));
3049                float32x4_t q4_dmin_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin + 4));
3050
3051                // Precomputation of scales and mins
3052                float32x4_t sbd_scale_0123[q8_k_blocklen];
3053                float32x4_t sbd_scale_4567[q8_k_blocklen];
3054                float32x4_t sbd_min_0123[q8_k_blocklen];
3055                float32x4_t sbd_min_4567[q8_k_blocklen];
3056
3057                sbd_scale_0123[0] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 0);
3058                sbd_scale_4567[0] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 0);
3059                sbd_min_0123[0]   = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 0);
3060                sbd_min_4567[0]   = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 0);
3061
3062                sbd_scale_0123[1] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 1);
3063                sbd_scale_4567[1] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 1);
3064                sbd_min_0123[1]   = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 1);
3065                sbd_min_4567[1]   = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 1);
3066
3067                sbd_scale_0123[2] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 2);
3068                sbd_scale_4567[2] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 2);
3069                sbd_min_0123[2]   = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 2);
3070                sbd_min_4567[2]   = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 2);
3071
3072                sbd_scale_0123[3] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 3);
3073                sbd_scale_4567[3] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 3);
3074                sbd_min_0123[3]   = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 3);
3075                sbd_min_4567[3]   = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 3);
3076
3077                // Precomputation of bsums, each vpaddq calcs all the bsums for each row
3078                const int16x8_t bsums[q8_k_blocklen] = {
3079                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)),
3080                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)),
3081                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)),
3082                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)),
3083                };
3084                int16_t bsums_arr[QK_K / 64][8];
3085                for (int q8_row = 0; q8_row < 4; q8_row++) {
3086                    vst1q_s16(bsums_arr[q8_row], bsums[q8_row]);
3087                }
3088
3089                // interleaved bias_acc: [0]->r0 0123, [1]->r1 0123, .., [4]->r0 4567, [5]->r1 4567 ..
3090                int32x4_t bias_acc[acc_size];
3091                for (int i = 0; i < acc_size; i++) {
3092                    bias_acc[i] = vdupq_n_s32(0);
3093                }
3094
3095                for (int sb = 0; sb < QK_K / 64; sb++) {
3096                    // Int accumulators for qs vecdot (4 row x 2 col quartets)
3097                    int32x4_t acc_lo[acc_size];
3098                    int32x4_t acc_hi[acc_size];
3099                    for (int i = 0; i < acc_size; i++) {
3100                        acc_lo[i] = vdupq_n_s32(0);
3101                        acc_hi[i] = vdupq_n_s32(0);
3102                    }
3103                    // Need scales for the low and high nibbles
3104                    // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
3105                    int16x8_t q4sb_scales[2];
3106                    int16x8_t q4sb_mins[2];
3107                    for (int i = 0; i < 2; i++) {
3108                        int8_t    aux_q4sb[8];
3109                        const int offset = sb * 24 + i * 12;
3110                        decode_q_Kx8_6bit_scales(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
3111                        q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb));
3112                    }
3113
3114                    constexpr int reads_per_sb = 8;  // 8 * 16 bytes each => 32 qs * 4 rows
3115                    for (int k = 0; k < reads_per_sb; k++) {
3116                        const int8x16_t q8_blk0 = vld1q_s8(q8_ptr[b].qs + sb * 256 + 16 * k);
3117                        const int8x16_t q8_blk1 = vld1q_s8(q8_ptr[b].qs + sb * 256 + 16 * k + 128);
3118
3119                        // 0..3 & 32..35
3120                        const uint8x16_t q4_0123 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 32 * k);
3121                        const uint8x16_t q4_4567 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 32 * k + 16);
3122
3123                        const int8x16_t q4_0123_lo = vreinterpretq_s8_u8(vandq_u8(q4_0123, m4b));
3124                        const int8x16_t q4_0123_hi = vreinterpretq_s8_u8(vshrq_n_u8(q4_0123, 4));
3125
3126                        acc_lo[0] = vdotq_laneq_s32(acc_lo[0], q4_0123_lo, q8_blk0, 0);  //  0..3  r0 c0123
3127                        acc_lo[1] = vdotq_laneq_s32(acc_lo[1], q4_0123_lo, q8_blk0, 1);  //  0..3  r1 c0123
3128                        acc_lo[2] = vdotq_laneq_s32(acc_lo[2], q4_0123_lo, q8_blk0, 2);  //  0..3  r2 c0123
3129                        acc_lo[3] = vdotq_laneq_s32(acc_lo[3], q4_0123_lo, q8_blk0, 3);  //  0..3  r3 c0123
3130
3131                        acc_hi[0] = vdotq_laneq_s32(acc_hi[0], q4_0123_hi, q8_blk1, 0);  // 32..35 r0 c0123
3132                        acc_hi[1] = vdotq_laneq_s32(acc_hi[1], q4_0123_hi, q8_blk1, 1);  // 32..35 r1 c0123
3133                        acc_hi[2] = vdotq_laneq_s32(acc_hi[2], q4_0123_hi, q8_blk1, 2);  // 32..35 r2 c0123
3134                        acc_hi[3] = vdotq_laneq_s32(acc_hi[3], q4_0123_hi, q8_blk1, 3);  // 32..35 r3 c0123
3135
3136                        const int8x16_t q4_4567_lo = vreinterpretq_s8_u8(vandq_u8(q4_4567, m4b));
3137                        const int8x16_t q4_4567_hi = vreinterpretq_s8_u8(vshrq_n_u8(q4_4567, 4));
3138
3139                        acc_lo[4] = vdotq_laneq_s32(acc_lo[4], q4_4567_lo, q8_blk0, 0);  //  0..3  r0 c4567
3140                        acc_lo[5] = vdotq_laneq_s32(acc_lo[5], q4_4567_lo, q8_blk0, 1);  //  0..3  r1 c4567
3141                        acc_lo[6] = vdotq_laneq_s32(acc_lo[6], q4_4567_lo, q8_blk0, 2);  //  0..3  r2 c4567
3142                        acc_lo[7] = vdotq_laneq_s32(acc_lo[7], q4_4567_lo, q8_blk0, 3);  //  0..3  r3 c4567
3143
3144                        acc_hi[4] = vdotq_laneq_s32(acc_hi[4], q4_4567_hi, q8_blk1, 0);  // 32..35 r0 c4567
3145                        acc_hi[5] = vdotq_laneq_s32(acc_hi[5], q4_4567_hi, q8_blk1, 1);  // 32..35 r1 c4567
3146                        acc_hi[6] = vdotq_laneq_s32(acc_hi[6], q4_4567_hi, q8_blk1, 2);  // 32..35 r2 c4567
3147                        acc_hi[7] = vdotq_laneq_s32(acc_hi[7], q4_4567_hi, q8_blk1, 3);  // 32..35 r3 c4567
3148                    }
3149
3150                    // Scale and bias application
3151                    // acc is stored interleaved to match output layout
3152                    const int16x4_t sc_0123_lo = vget_low_s16(q4sb_scales[0]);
3153                    const int16x4_t sc_4567_lo = vget_high_s16(q4sb_scales[0]);
3154                    const int16x4_t sc_0123_hi = vget_low_s16(q4sb_scales[1]);
3155                    const int16x4_t sc_4567_hi = vget_high_s16(q4sb_scales[1]);
3156                    for (int row = 0; row < q8_k_blocklen; row++) {
3157                        // Bias correction
3158                        // row c0123 blk0 and blk1
3159                        const float32x4_t sumf_0123 =
3160                            vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_0123_lo), acc_lo[row]),
3161                                                    vmulq_s32(vmovl_s16(sc_0123_hi), acc_hi[row])));
3162                        acc_f32[2 * row] = vfmaq_f32(acc_f32[2 * row], sbd_scale_0123[row], sumf_0123);
3163
3164                        // row c4567 blk0 and blk1
3165                        const float32x4_t sumf_4567 =
3166                            vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_4567_lo), acc_lo[row + 4]),
3167                                                    vmulq_s32(vmovl_s16(sc_4567_hi), acc_hi[row + 4])));
3168                        acc_f32[2 * row + 1] = vfmaq_f32(acc_f32[2 * row + 1], sbd_scale_4567[row], sumf_4567);
3169
3170                        // Bias
3171                        const int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][row * 2]);
3172                        const int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][row * 2 + 1]);
3173
3174                        // row c0123 blk0 and blk1
3175                        bias_acc[2 * row] = vmlal_s16(bias_acc[2 * row], bsums_vec_lo, vget_low_s16(q4sb_mins[0]));
3176                        bias_acc[2 * row] = vmlal_s16(bias_acc[2 * row], bsums_vec_hi, vget_low_s16(q4sb_mins[1]));
3177
3178                        // row c4567 blk0 and blk1
3179                        bias_acc[2 * row + 1] =
3180                            vmlal_s16(bias_acc[2 * row + 1], bsums_vec_lo, vget_high_s16(q4sb_mins[0]));
3181                        bias_acc[2 * row + 1] =
3182                            vmlal_s16(bias_acc[2 * row + 1], bsums_vec_hi, vget_high_s16(q4sb_mins[1]));
3183                    }
3184                }  // for sb
3185
3186                for (int row = 0; row < q8_k_blocklen; row++) {
3187                    acc_f32[2 * row] = vmlsq_f32(acc_f32[2 * row], vcvtq_f32_s32(bias_acc[2 * row]), sbd_min_0123[row]);
3188                    acc_f32[2 * row + 1] =
3189                        vmlsq_f32(acc_f32[2 * row + 1], vcvtq_f32_s32(bias_acc[2 * row + 1]), sbd_min_4567[row]);
3190                }
3191            }  // for b
3192
3193            for (int i = 0; i < q8_k_blocklen; i++) {
3194                int row = y * q8_k_blocklen + i;
3195                for (int j = 0; j < 2; j++) {
3196                    int col    = x * ncols_interleaved + j * 4;
3197                    int offset = row * bs + col;
3198                    vst1q_f32(s + offset, acc_f32[2 * i + j]);
3199                }
3200            }
3201        }  // for x
3202    }  // for y
3203    return;
3204#endif  // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
3205    ggml_gemm_q4_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);
3206}
3207
3208void ggml_gemm_q4_K_8x8_q8_K(int                        n,
3209                             float * GGML_RESTRICT      s,
3210                             size_t                     bs,
3211                             const void * GGML_RESTRICT vx,
3212                             const void * GGML_RESTRICT vy,
3213                             int                        nr,
3214                             int                        nc) {
3215    constexpr int qk = QK_K;
3216    const int     nb = n / qk;
3217
3218    constexpr int ncols_interleaved = 8;
3219    constexpr int blocklen          = 8;
3220
3221    assert(n % qk == 0);
3222    assert(nr % 4 == 0);
3223    assert(nc % ncols_interleaved == 0);
3224
3225    UNUSED(nb);
3226    UNUSED(ncols_interleaved);
3227    UNUSED(blocklen);
3228
3229#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
3230    constexpr int    q8_k_blocklen = 4;
3231    const uint8x16_t m4b           = vdupq_n_u8(0x0f);
3232
3233    // 8 accumulators: 2 row pairs × 4 col pairs
3234    float32x4_t acc_f32[blocklen];
3235
3236    for (int y = 0; y < nr / q8_k_blocklen; y++) {
3237        const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
3238
3239        for (int x = 0; x < nc / ncols_interleaved; x++) {
3240            const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);
3241
3242            for (int i = 0; i < blocklen; i++) {
3243                acc_f32[i] = vdupq_n_f32(0);
3244            }
3245
3246            for (int b = 0; b < nb; b++) {
3247                // bsums pairs belongs to the same q8_k subblock
3248                const int16x8_t bsums[4]{
3249                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)),
3250                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)),
3251                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)),
3252                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)),
3253                };
3254                int16_t bsums_arr[4][8];
3255                for (int q8_row = 0; q8_row < 4; q8_row++) {
3256                    vst1q_s16(bsums_arr[q8_row], bsums[q8_row]);
3257                }
3258
3259                int32x4_t sb_acc[4];    // Aux accumulators to store subblock (partial) results
3260                int32x4_t acc[8];       // rows 01 stored in [0][1][2][3] rows 23 stored in [4][5][6][7]
3261                int32x4_t bias_acc[8];  // interleaved bias_acc: [0]->r0 0123, [1]->r0 4567, [2]->r1 0123 ...
3262                for (int i = 0; i < 8; i++) {
3263                    acc[i]      = vdupq_n_s32(0);
3264                    bias_acc[i] = vdupq_n_s32(0);
3265                }
3266
3267                for (int sb = 0; sb < QK_K / 64; sb++) {
3268                    // Need scales for the low and high nibbles
3269                    // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
3270                    int8_t    q4sb_scales[2][8];
3271                    int16x8_t q4sb_mins[2];  // int16 as its needed for bias_acc later
3272                    for (int i = 0; i < 2; i++) {
3273                        const int offset = sb * 24 + i * 12;
3274                        decode_q_Kx8_6bit_scales(&q4_ptr[b].scales[offset], &q4sb_mins[i], q4sb_scales[i]);
3275                    }
3276
3277                    // q8_ptr[b].qs has interleaved Q8 rows (01, 23)
3278                    const int8_t * q8_base = q8_ptr[b].qs + sb * 256;
3279
3280                    int8x16_t q8_qs_01[8];
3281                    int8x16_t q8_qs_23[8];
3282
3283                    // Load 32-byte per row pair, 1 subblock each time
3284                    for (int i = 0; i < 8; i++) {
3285                        const int offset = i * 32;  // 16 for row 01, 16 for row 23
3286                        q8_qs_01[i]      = vld1q_s8(q8_base + offset);
3287                        q8_qs_23[i]      = vld1q_s8(q8_base + offset + 16);
3288                    }
3289
3290                    const int8x16_t q8s[2][8] = {
3291                        { q8_qs_01[0], q8_qs_01[1], q8_qs_01[2], q8_qs_01[3],
3292                          q8_qs_01[4], q8_qs_01[5], q8_qs_01[6], q8_qs_01[7] },
3293                        { q8_qs_23[0], q8_qs_23[1], q8_qs_23[2], q8_qs_23[3],
3294                          q8_qs_23[4], q8_qs_23[5], q8_qs_23[6], q8_qs_23[7] },
3295                    };
3296
3297                    // Q4s columns iterated in pairs (01, 23, 45, 67)
3298                    for (int cp = 0; cp < ncols_interleaved / 2; cp++) {
3299                        for (int i = 0; i < 4; i++) {
3300                            sb_acc[i] = vdupq_n_s32(0);
3301                        }
3302
3303                        uint8x16_t q4_qs_cp_0 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 0);    // 0 .. 7 & 32..39
3304                        uint8x16_t q4_qs_cp_1 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 64);   // 8 ..15 & 40..47
3305                        uint8x16_t q4_qs_cp_2 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 128);  // 16..23 & 48..55
3306                        uint8x16_t q4_qs_cp_3 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 192);  // 24..31 & 56..63
3307                        const int8x16_t q4_nibbles[2][4] = {
3308                            {
3309                                vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_0, m4b)),
3310                                vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_1, m4b)),
3311                                vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_2, m4b)),
3312                                vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_3, m4b)),
3313                            },
3314                            {
3315                                vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_0, 4)),
3316                                vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_1, 4)),
3317                                vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_2, 4)),
3318                                vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_3, 4)),
3319                            }
3320                        };
3321
3322                        // Calculates the Qs muladd of every row pair (rp) rows 01 and 23 of q8
3323                        // for each of the internal 32 qs subblock (blk)
3324                        for (int rp = 0; rp < 2; rp++) {
3325                            for (int blk = 0; blk < 2; blk++) {
3326                                const int8x16_t * q8  = &q8s[rp][4 * blk];
3327                                const int8x16_t * q4  = q4_nibbles[blk];
3328                                int32x4_t         acc = sb_acc[2 * rp + blk];
3329                                // mul add for each qs in the same subblock
3330                                for (int qs_offset = 0; qs_offset < 4; qs_offset++) {
3331                                    acc = vmmlaq_s32(acc, q4[qs_offset], q8[qs_offset]);
3332                                }
3333                                sb_acc[2 * rp + blk] = acc;
3334                            }
3335                        }
3336
3337                        // Scales[i] corresponds to column i
3338                        const int scale_offset = cp * 2;
3339                        const int32_t scale_00 = q4sb_scales[0][scale_offset];
3340                        const int32_t scale_01 = q4sb_scales[0][scale_offset + 1];
3341                        const int32_t scale_10 = q4sb_scales[1][scale_offset];
3342                        const int32_t scale_11 = q4sb_scales[1][scale_offset + 1];
3343                        const int32x4_t block_scale_0 = vcombine_s32(vdup_n_s32(scale_00), vdup_n_s32(scale_01));
3344                        const int32x4_t block_scale_1 = vcombine_s32(vdup_n_s32(scale_10), vdup_n_s32(scale_11));
3345
3346                        acc[cp]     = vmlaq_s32(acc[cp], sb_acc[0], block_scale_0);
3347                        acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc[2], block_scale_0);
3348                        acc[cp]     = vmlaq_s32(acc[cp], sb_acc[1], block_scale_1);
3349                        acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc[3], block_scale_1);
3350                    }
3351
3352                    // Multiply Acc bsum + mins
3353                    for (int q8_row = 0; q8_row < 4; q8_row++) {
3354                        // Each pair of subblocks share the same bsums
3355                        // Load scalar bsum → broadcast to a vector (vdupq_n_s16(s)).
3356                        int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][q8_row * 2]);
3357                        int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][q8_row * 2 + 1]);
3358
3359                        bias_acc[2 * q8_row] =
3360                            vmlal_s16(bias_acc[2 * q8_row], bsums_vec_lo, vget_low_s16(q4sb_mins[0]));
3361                        bias_acc[2 * q8_row] =
3362                            vmlal_s16(bias_acc[2 * q8_row], bsums_vec_hi, vget_low_s16(q4sb_mins[1]));
3363                        bias_acc[2 * q8_row + 1] =
3364                            vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_lo, vget_high_s16(q4sb_mins[0]));
3365                        bias_acc[2 * q8_row + 1] =
3366                            vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_hi, vget_high_s16(q4sb_mins[1]));
3367                    }
3368                }  // for sb
3369
3370                // Reorder of i8mm output with bias and output layout
3371                for (int i = 0; i < 8; i++) {
3372                    int32x2x2_t aux = vzip_s32(vget_low_s32(acc[i]), vget_high_s32(acc[i]));
3373                    acc[i]          = vcombine_s32(aux.val[0], aux.val[1]);
3374                }
3375                int32x4_t reorder_acc[8] = {
3376                    vcombine_s32(vget_low_s32(acc[0]), vget_low_s32(acc[1])),
3377                    vcombine_s32(vget_low_s32(acc[2]), vget_low_s32(acc[3])),
3378                    vcombine_s32(vget_high_s32(acc[0]), vget_high_s32(acc[1])),
3379                    vcombine_s32(vget_high_s32(acc[2]), vget_high_s32(acc[3])),
3380                    vcombine_s32(vget_low_s32(acc[4]), vget_low_s32(acc[5])),
3381                    vcombine_s32(vget_low_s32(acc[6]), vget_low_s32(acc[7])),
3382                    vcombine_s32(vget_high_s32(acc[4]), vget_high_s32(acc[5])),
3383                    vcombine_s32(vget_high_s32(acc[6]), vget_high_s32(acc[7])),
3384                };
3385
3386                for (int i = 0; i < q8_k_blocklen; i++) {
3387                    for (int j = 0; j < 2; j++) {
3388                        float32x4_t       q8_d    = vdupq_n_f32(q8_ptr[b].d[i]);
3389                        float32x4_t       q4_dmin = vcvt_f32_f16(vld1_f16((const __fp16 *) (q4_ptr[b].dmin + j * 4)));
3390                        const float32x4_t dmins   = vmulq_f32(q4_dmin, q8_d);
3391
3392                        float32x4_t       q4_d  = vcvt_f32_f16(vld1_f16((const __fp16 *) (q4_ptr[b].d + j * 4)));
3393                        const float32x4_t scale = vmulq_f32(q4_d, q8_d);
3394
3395                        acc_f32[2 * i + j] = vmlsq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(bias_acc[2 * i + j]), dmins);
3396                        acc_f32[2 * i + j] =
3397                            vmlaq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(reorder_acc[2 * i + j]), scale);
3398                    }
3399                }
3400            }  // for b
3401
3402            // With the previous reorder, the tile is already in the correct memory layout.
3403            for (int i = 0; i < q8_k_blocklen; i++) {
3404                int row = y * q8_k_blocklen + i;
3405                for (int j = 0; j < 2; j++) {
3406                    int col    = x * ncols_interleaved + j * 4;
3407                    int offset = row * bs + col;
3408                    vst1q_f32(s + offset, acc_f32[2 * i + j]);
3409                }
3410            }
3411        }  // for x
3412    }  // for y
3413    return;
3414#endif  // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
3415    ggml_gemm_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
3416}
3417
3418void ggml_gemm_q5_K_8x8_q8_K(int                        n,
3419                             float * GGML_RESTRICT      s,
3420                             size_t                     bs,
3421                             const void * GGML_RESTRICT vx,
3422                             const void * GGML_RESTRICT vy,
3423                             int                        nr,
3424                             int                        nc) {
3425    constexpr int qk = QK_K;
3426    const int     nb = n / qk;
3427
3428    constexpr int ncols_interleaved = 8;
3429    constexpr int blocklen          = 8;
3430
3431    assert(n % qk == 0);
3432    assert(nr % 4 == 0);
3433    assert(nc % ncols_interleaved == 0);
3434
3435    UNUSED(nb);
3436    UNUSED(ncols_interleaved);
3437    UNUSED(blocklen);
3438
3439#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
3440    constexpr int    q8_k_blocklen = 4;
3441    constexpr int    col_pairs     = ncols_interleaved / 2;
3442    const uint8x16_t m4b           = vdupq_n_u8(0x0f);
3443    const uint8x16_t mone          = vdupq_n_u8(1);
3444    const uint8x16_t mtwo          = vdupq_n_u8(2);
3445
3446    // 8 accumulators: 2 row pairs × 4 col pairs
3447    float32x4_t acc_f32[blocklen];
3448
3449    for (int y = 0; y < nr / q8_k_blocklen; y++) {
3450        const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
3451
3452        for (int x = 0; x < nc / ncols_interleaved; x++) {
3453            const block_q5_Kx8 * GGML_RESTRICT q5_ptr = (const block_q5_Kx8 *) vx + (x * nb);
3454
3455            for (int i = 0; i < blocklen; i++) {
3456                acc_f32[i] = vdupq_n_f32(0);
3457            }
3458
3459            for (int b = 0; b < nb; b++) {
3460                // bsums pairs belongs to the same q8_k subblock
3461                const int16x8_t bsums[4]{
3462                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)),
3463                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)),
3464                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)),
3465                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)),
3466                };
3467                int16_t bsums_arr[4][8];
3468                for (int q8_row = 0; q8_row < 4; q8_row++) {
3469                    vst1q_s16(bsums_arr[q8_row], bsums[q8_row]);
3470                }
3471
3472                int32x4_t sb_acc[4];    // Aux accumulators to store subblock (partial) results
3473                int32x4_t acc[8];       // rows 01 stored in [0][1][2][3] rows 23 stored in [4][5][6][7]
3474                int32x4_t bias_acc[8];  // interleaved bias_acc: [0]->r0 0123, [1]->r0 4567, [2]->r1 0123 ...
3475                for (int i = 0; i < 8; i++) {
3476                    acc[i]      = vdupq_n_s32(0);
3477                    bias_acc[i] = vdupq_n_s32(0);
3478                }
3479
3480                // Load qh once per block and shift after each subblock
3481                const uint8_t * qh_base = q5_ptr[b].qh;
3482                uint8x16_t      qh[col_pairs][4];
3483                for (int cp = 0; cp < col_pairs; cp++) {
3484                    qh[cp][0] = vld1q_u8(qh_base + 16 * cp);
3485                    qh[cp][1] = vld1q_u8(qh_base + 16 * cp + 64);
3486                    qh[cp][2] = vld1q_u8(qh_base + 16 * cp + 128);
3487                    qh[cp][3] = vld1q_u8(qh_base + 16 * cp + 192);
3488                }
3489
3490                for (int sb = 0; sb < QK_K / 64; sb++) {
3491                    // Need scales for the low and high nibbles
3492                    // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
3493                    int8_t    q5sb_scales[2][8];
3494                    int16x8_t q5sb_mins[2];  // int16 as its needed for bias_acc later
3495                    for (int i = 0; i < 2; i++) {
3496                        const int offset = sb * 24 + i * 12;
3497                        decode_q_Kx8_6bit_scales(&q5_ptr[b].scales[offset], &q5sb_mins[i], q5sb_scales[i]);
3498                    }
3499
3500                    // q8_ptr[b].qs has interleaved Q8 rows (01, 23)
3501                    const int8_t * q8_base = q8_ptr[b].qs + sb * 256;
3502
3503                    int8x16_t q8_qs_01[8];
3504                    int8x16_t q8_qs_23[8];
3505
3506                    // Load 32-byte per row pair, 1 subblock each time
3507                    for (int i = 0; i < 8; i++) {
3508                        const int offset = i * 32;  // 16 for row 01, 16 for row 23
3509                        q8_qs_01[i]      = vld1q_s8(q8_base + offset);
3510                        q8_qs_23[i]      = vld1q_s8(q8_base + offset + 16);
3511                    }
3512
3513                    const int8x16_t q8s[2][8] = {
3514                        { q8_qs_01[0], q8_qs_01[1], q8_qs_01[2], q8_qs_01[3], q8_qs_01[4], q8_qs_01[5], q8_qs_01[6],
3515                         q8_qs_01[7] },
3516                        { q8_qs_23[0], q8_qs_23[1], q8_qs_23[2], q8_qs_23[3], q8_qs_23[4], q8_qs_23[5], q8_qs_23[6],
3517                         q8_qs_23[7] },
3518                    };
3519
3520                    // Q5s columns iterated in pairs (01, 23, 45, 67)
3521                    for (int cp = 0; cp < col_pairs; cp++) {
3522                        for (int i = 0; i < 4; i++) {
3523                            sb_acc[i] = vdupq_n_s32(0);
3524                        }
3525
3526                        uint8x16_t qs_cp_0 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 16 * cp + 0);    // 0 .. 7 & 32..39
3527                        uint8x16_t qs_cp_1 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 16 * cp + 64);   // 8 ..15 & 40..47
3528                        uint8x16_t qs_cp_2 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 16 * cp + 128);  // 16..23 & 48..55
3529                        uint8x16_t qs_cp_3 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 16 * cp + 192);  // 24..31 & 56..63
3530
3531                        // This is the only part of the algorithm that differs with Q4_K
3532                        // Extract High bits and pack into 5 bit weights
3533                        uint8x16_t hbit_lo_0    = vandq_u8(qh[cp][0], mone);
3534                        uint8x16_t hbit_hi_0    = vshlq_n_u8(vandq_u8(qh[cp][0], mtwo), 3);
3535                        qh[cp][0]               = vshrq_n_u8(qh[cp][0], 2);
3536                        // Same as Q4_K, i8mm to dequantize the weights.
3537                        const int8x16_t qs_lo_0 = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_cp_0, m4b), hbit_lo_0, 4));
3538                        int32x4_t       acc_0   = sb_acc[0];
3539                        acc_0                   = vmmlaq_s32(acc_0, qs_lo_0, q8s[0][0]);
3540                        int32x4_t acc_2         = sb_acc[2];
3541                        acc_2                   = vmmlaq_s32(acc_2, qs_lo_0, q8s[1][0]);
3542                        const int8x16_t qs_hi_0 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_0, 4), hbit_hi_0));
3543                        int32x4_t       acc_1   = sb_acc[1];
3544                        acc_1                   = vmmlaq_s32(acc_1, qs_hi_0, q8s[0][4]);
3545                        int32x4_t acc_3         = sb_acc[3];
3546                        acc_3                   = vmmlaq_s32(acc_3, qs_hi_0, q8s[1][4]);
3547
3548                        // Repeat for the other 3 columns (8..15, 16..23, 24..31)
3549                        uint8x16_t hbit_hi_1    = vshlq_n_u8(vandq_u8(qh[cp][1], mtwo), 3);
3550                        uint8x16_t hbit_lo_1    = vandq_u8(qh[cp][1], mone);
3551                        qh[cp][1]               = vshrq_n_u8(qh[cp][1], 2);
3552                        const int8x16_t qs_lo_1 = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_cp_1, m4b), hbit_lo_1, 4));
3553                        acc_0                   = vmmlaq_s32(acc_0, qs_lo_1, q8s[0][1]);
3554                        acc_2                   = vmmlaq_s32(acc_2, qs_lo_1, q8s[1][1]);
3555                        const int8x16_t qs_hi_1 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_1, 4), hbit_hi_1));
3556                        acc_1                   = vmmlaq_s32(acc_1, qs_hi_1, q8s[0][5]);
3557                        acc_3                   = vmmlaq_s32(acc_3, qs_hi_1, q8s[1][5]);
3558
3559                        uint8x16_t hbit_hi_2    = vshlq_n_u8(vandq_u8(qh[cp][2], mtwo), 3);
3560                        uint8x16_t hbit_lo_2    = vandq_u8(qh[cp][2], mone);
3561                        qh[cp][2]               = vshrq_n_u8(qh[cp][2], 2);
3562                        const int8x16_t qs_lo_2 = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_cp_2, m4b), hbit_lo_2, 4));
3563                        acc_0                   = vmmlaq_s32(acc_0, qs_lo_2, q8s[0][2]);
3564                        acc_2                   = vmmlaq_s32(acc_2, qs_lo_2, q8s[1][2]);
3565                        const int8x16_t qs_hi_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_2, 4), hbit_hi_2));
3566                        acc_1                   = vmmlaq_s32(acc_1, qs_hi_2, q8s[0][6]);
3567                        acc_3                   = vmmlaq_s32(acc_3, qs_hi_2, q8s[1][6]);
3568
3569                        uint8x16_t hbit_lo_3    = vandq_u8(qh[cp][3], mone);
3570                        uint8x16_t hbit_hi_3    = vshlq_n_u8(vandq_u8(qh[cp][3], mtwo), 3);
3571                        qh[cp][3]               = vshrq_n_u8(qh[cp][3], 2);
3572                        const int8x16_t qs_lo_3 = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_cp_3, m4b), hbit_lo_3, 4));
3573                        acc_0                   = vmmlaq_s32(acc_0, qs_lo_3, q8s[0][3]);
3574                        sb_acc[0]               = acc_0;
3575                        acc_2                   = vmmlaq_s32(acc_2, qs_lo_3, q8s[1][3]);
3576                        sb_acc[2]               = acc_2;
3577
3578                        // Scales[i] corresponds to column i
3579                        const int       scale_offset = cp * 2;
3580                        const int32_t   s0           = q5sb_scales[0][scale_offset];
3581                        const int32_t   s1           = q5sb_scales[0][scale_offset + 1];
3582                        const int32x4_t block_scale  = vcombine_s32(vdup_n_s32(s0), vdup_n_s32(s1));
3583                        acc[cp]                      = vmlaq_s32(acc[cp], sb_acc[0], block_scale);
3584                        acc[cp + 4]                  = vmlaq_s32(acc[cp + 4], sb_acc[2], block_scale);
3585
3586                        const int8x16_t qs_hi_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_3, 4), hbit_hi_3));
3587                        acc_1                   = vmmlaq_s32(acc_1, qs_hi_3, q8s[0][7]);
3588                        sb_acc[1]               = acc_1;
3589                        acc_3                   = vmmlaq_s32(acc_3, qs_hi_3, q8s[1][7]);
3590                        sb_acc[3]               = acc_3;
3591
3592                        const int32_t   s2           = q5sb_scales[1][scale_offset];
3593                        const int32_t   s3           = q5sb_scales[1][scale_offset + 1];
3594                        const int32x4_t block_scale2 = vcombine_s32(vdup_n_s32(s2), vdup_n_s32(s3));
3595                        acc[cp]                      = vmlaq_s32(acc[cp], sb_acc[1], block_scale2);
3596                        acc[cp + 4]                  = vmlaq_s32(acc[cp + 4], sb_acc[3], block_scale2);
3597                    }
3598
3599                    // Multiply Acc bsum + mins
3600                    for (int q8_row = 0; q8_row < 4; q8_row++) {
3601                        // Each pair of subblocks share the same bsums
3602                        // Load scalar bsum → broadcast to a vector (vdupq_n_s16(s)).
3603                        int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][q8_row * 2]);
3604                        int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][q8_row * 2 + 1]);
3605
3606                        bias_acc[2 * q8_row] =
3607                            vmlal_s16(bias_acc[2 * q8_row], bsums_vec_lo, vget_low_s16(q5sb_mins[0]));
3608                        bias_acc[2 * q8_row] =
3609                            vmlal_s16(bias_acc[2 * q8_row], bsums_vec_hi, vget_low_s16(q5sb_mins[1]));
3610                        bias_acc[2 * q8_row + 1] =
3611                            vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_lo, vget_high_s16(q5sb_mins[0]));
3612                        bias_acc[2 * q8_row + 1] =
3613                            vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_hi, vget_high_s16(q5sb_mins[1]));
3614                    }
3615                }  // for sb
3616
3617                // Reorder of i8mm output with bias and output layout
3618                for (int i = 0; i < 8; i++) {
3619                    int32x2x2_t aux = vzip_s32(vget_low_s32(acc[i]), vget_high_s32(acc[i]));
3620                    acc[i]          = vcombine_s32(aux.val[0], aux.val[1]);
3621                }
3622                int32x4_t reorder_acc[8] = {
3623                    vcombine_s32(vget_low_s32(acc[0]), vget_low_s32(acc[1])),
3624                    vcombine_s32(vget_low_s32(acc[2]), vget_low_s32(acc[3])),
3625                    vcombine_s32(vget_high_s32(acc[0]), vget_high_s32(acc[1])),
3626                    vcombine_s32(vget_high_s32(acc[2]), vget_high_s32(acc[3])),
3627                    vcombine_s32(vget_low_s32(acc[4]), vget_low_s32(acc[5])),
3628                    vcombine_s32(vget_low_s32(acc[6]), vget_low_s32(acc[7])),
3629                    vcombine_s32(vget_high_s32(acc[4]), vget_high_s32(acc[5])),
3630                    vcombine_s32(vget_high_s32(acc[6]), vget_high_s32(acc[7])),
3631                };
3632
3633                for (int i = 0; i < q8_k_blocklen; i++) {
3634                    for (int j = 0; j < 2; j++) {
3635                        float32x4_t       q8_d    = vdupq_n_f32(q8_ptr[b].d[i]);
3636                        float32x4_t       q5_dmin = vcvt_f32_f16(vld1_f16((const __fp16 *) (q5_ptr[b].dmin + j * 4)));
3637                        const float32x4_t dmins   = vmulq_f32(q5_dmin, q8_d);
3638
3639                        float32x4_t       q5_d  = vcvt_f32_f16(vld1_f16((const __fp16 *) (q5_ptr[b].d + j * 4)));
3640                        const float32x4_t scale = vmulq_f32(q5_d, q8_d);
3641
3642                        acc_f32[2 * i + j] = vmlsq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(bias_acc[2 * i + j]), dmins);
3643                        acc_f32[2 * i + j] =
3644                            vmlaq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(reorder_acc[2 * i + j]), scale);
3645                    }
3646                }
3647            }  // for b
3648
3649            // With the previous reorder, the tile is already in the correct memory layout.
3650            for (int i = 0; i < q8_k_blocklen; i++) {
3651                int row = y * q8_k_blocklen + i;
3652                for (int j = 0; j < 2; j++) {
3653                    int col    = x * ncols_interleaved + j * 4;
3654                    int offset = row * bs + col;
3655                    vst1q_f32(s + offset, acc_f32[2 * i + j]);
3656                }
3657            }
3658        }  // for x
3659    }  // for y
3660    return;
3661#endif  // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
3662    ggml_gemm_q5_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
3663}
3664
3665void ggml_gemm_q6_K_8x4_q8_K(int                        n,
3666                             float * GGML_RESTRICT      s,
3667                             size_t                     bs,
3668                             const void * GGML_RESTRICT vx,
3669                             const void * GGML_RESTRICT vy,
3670                             int                        nr,
3671                             int                        nc) {
3672    constexpr int qk = QK_K;
3673    const int     nb = n / qk;
3674
3675    constexpr int ncols_interleaved = 8;
3676    constexpr int blocklen          = 4;
3677
3678    assert(n % qk == 0);
3679    assert(nr % 4 == 0);
3680    assert(nc % ncols_interleaved == 0);
3681
3682    UNUSED(nb);
3683    UNUSED(ncols_interleaved);
3684    UNUSED(blocklen);
3685
3686#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
3687    constexpr int    q8_k_blocklen = 4;
3688    constexpr int    col_groups    = ncols_interleaved / 4;
3689    constexpr int    acc_size      = q8_k_blocklen * col_groups;  // 4 rows, 2 column groups
3690    const uint8x16_t m4b           = vdupq_n_u8(0x0f);
3691    const uint8x16_t mask_lo       = vdupq_n_u8(0x03);
3692    const uint8x16_t mask_hi       = vdupq_n_u8(0x30);
3693    const int8x16_t  m32s          = vdupq_n_s8(32);
3694
3695    float32x4_t acc_f32[acc_size];
3696
3697    for (int y = 0; y < nr / q8_k_blocklen; y++) {
3698        const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
3699
3700        for (int x = 0; x < nc / ncols_interleaved; x++) {
3701            const block_q6_Kx8 * GGML_RESTRICT q6_ptr = (const block_q6_Kx8 *) vx + (x * nb);
3702
3703            for (int i = 0; i < acc_size; i++) {
3704                acc_f32[i] = vdupq_n_f32(0);
3705            }
3706
3707            for (int b = 0; b < nb; b++) {
3708                float32x4_t q6_d_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d));
3709                float32x4_t q6_d_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d + 4));
3710                float32x4_t q8_d_0123 = vld1q_f32(q8_ptr[b].d);
3711
3712                float32x4_t sbd_scale_0123[q8_k_blocklen];
3713                float32x4_t sbd_scale_4567[q8_k_blocklen];
3714
3715                sbd_scale_0123[0] = vmulq_laneq_f32(q6_d_0123, q8_d_0123, 0);
3716                sbd_scale_4567[0] = vmulq_laneq_f32(q6_d_4567, q8_d_0123, 0);
3717                sbd_scale_0123[1] = vmulq_laneq_f32(q6_d_0123, q8_d_0123, 1);
3718                sbd_scale_4567[1] = vmulq_laneq_f32(q6_d_4567, q8_d_0123, 1);
3719                sbd_scale_0123[2] = vmulq_laneq_f32(q6_d_0123, q8_d_0123, 2);
3720                sbd_scale_4567[2] = vmulq_laneq_f32(q6_d_4567, q8_d_0123, 2);
3721                sbd_scale_0123[3] = vmulq_laneq_f32(q6_d_0123, q8_d_0123, 3);
3722                sbd_scale_4567[3] = vmulq_laneq_f32(q6_d_4567, q8_d_0123, 3);
3723
3724                int32x4_t acc_s32[acc_size];
3725                for (int i = 0; i < acc_size; i++) {
3726                    acc_s32[i] = vdupq_n_s32(0);
3727                }
3728
3729                int16_t q6_scales[8 * 16];
3730                for (int i = 0; i < 16; i++) {
3731                    int16x8_t scales = vmovl_s8(vld1_s8(q6_ptr[b].scales + i * 8));
3732                    vst1q_s16(q6_scales + i * 8, scales);
3733                }
3734
3735                for (int half = 0; half < 2; half++) {
3736                    const uint8_t * ql_base = q6_ptr[b].ql + half * 512;
3737                    const uint8_t * qh_base = q6_ptr[b].qh + half * 256;
3738
3739                    for (int sb = 0; sb < QK_K / 64; sb++) {
3740                        int32x4_t acc_lo[acc_size];
3741                        int32x4_t acc_hi[acc_size];
3742                        for (int i = 0; i < acc_size; i++) {
3743                            acc_lo[i] = vdupq_n_s32(0);
3744                            acc_hi[i] = vdupq_n_s32(0);
3745                        }
3746
3747                        const int8_t * q8_base_l = q8_ptr[b].qs + half * 512 + sb * 64;
3748                        const int8_t * q8_base_h = q8_ptr[b].qs + half * 512 + 256 + sb * 64;
3749
3750                        // 4 rows * 16 elements per scale
3751                        // 4 reads of 16 bytes each
3752                        constexpr int reads_per_sb = 4;
3753                        int8x16_t     q8_l[reads_per_sb];
3754                        int8x16_t     q8_h[reads_per_sb];
3755                        for (int k = 0; k < reads_per_sb; k++) {
3756                            q8_l[k] = vld1q_s8(q8_base_l + 16 * k);
3757                            q8_h[k] = vld1q_s8(q8_base_h + 16 * k);
3758                        }
3759
3760                        const int ql_off_base = sb * QK_K / 2;
3761                        const int qh_off_base = ql_off_base & 255;
3762
3763                        uint8x16_t q6_ql_0123[reads_per_sb];
3764                        uint8x16_t q6_ql_4567[reads_per_sb];
3765                        uint8x16_t q6_qh_0123[reads_per_sb];
3766                        uint8x16_t q6_qh_4567[reads_per_sb];
3767
3768                        for (int k = 0; k < reads_per_sb; k++) {
3769                            q6_ql_0123[k] = vld1q_u8(ql_base + ql_off_base + k * 32);
3770                            q6_ql_4567[k] = vld1q_u8(ql_base + ql_off_base + k * 32 + 16);
3771                            q6_qh_0123[k] = vld1q_u8(qh_base + qh_off_base + k * 32);
3772                            q6_qh_4567[k] = vld1q_u8(qh_base + qh_off_base + k * 32 + 16);
3773                        }
3774
3775                        if (sb > 1) {
3776                            for (int k = 0; k < reads_per_sb; k++) {
3777                                q6_qh_0123[k] = vshrq_n_u8(q6_qh_0123[k], 2);
3778                                q6_qh_4567[k] = vshrq_n_u8(q6_qh_4567[k], 2);
3779                            }
3780                        }
3781
3782                        for (int k = 0; k < reads_per_sb; k++) {
3783                            // q = (ql | qh) - 32
3784                            const uint8x16_t hbit_lo_0123 = vandq_u8(q6_qh_0123[k], mask_lo);
3785                            const uint8x16_t hbit_hi_0123 = vandq_u8(q6_qh_0123[k], mask_hi);
3786                            const uint8x16_t hbit_lo_4567 = vandq_u8(q6_qh_4567[k], mask_lo);
3787                            const uint8x16_t hbit_hi_4567 = vandq_u8(q6_qh_4567[k], mask_hi);
3788
3789                            const int8x16_t q6_0123_lo = vsubq_s8(
3790                                vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_ql_0123[k], m4b), hbit_lo_0123, 4)), m32s);
3791                            const int8x16_t q6_0123_hi = vsubq_s8(
3792                                vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_ql_0123[k], 4), hbit_hi_0123)), m32s);
3793
3794                            acc_lo[0] = vdotq_laneq_s32(acc_lo[0], q6_0123_lo, q8_l[k], 0);  //  0..3  r0 c0123
3795                            acc_lo[1] = vdotq_laneq_s32(acc_lo[1], q6_0123_lo, q8_l[k], 1);  //  0..3  r1 c0123
3796                            acc_lo[2] = vdotq_laneq_s32(acc_lo[2], q6_0123_lo, q8_l[k], 2);  //  0..3  r2 c0123
3797                            acc_lo[3] = vdotq_laneq_s32(acc_lo[3], q6_0123_lo, q8_l[k], 3);  //  0..3  r3 c0123
3798
3799                            acc_hi[0] = vdotq_laneq_s32(acc_hi[0], q6_0123_hi, q8_h[k], 0);  // 64..67 r0 c0123
3800                            acc_hi[1] = vdotq_laneq_s32(acc_hi[1], q6_0123_hi, q8_h[k], 1);  // 64..67 r1 c0123
3801                            acc_hi[2] = vdotq_laneq_s32(acc_hi[2], q6_0123_hi, q8_h[k], 2);  // 64..67 r2 c0123
3802                            acc_hi[3] = vdotq_laneq_s32(acc_hi[3], q6_0123_hi, q8_h[k], 3);  // 64..67 r3 c0123
3803
3804                            const int8x16_t q6_4567_lo = vsubq_s8(
3805                                vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_ql_4567[k], m4b), hbit_lo_4567, 4)), m32s);
3806                            const int8x16_t q6_4567_hi = vsubq_s8(
3807                                vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_ql_4567[k], 4), hbit_hi_4567)), m32s);
3808
3809                            acc_lo[4] = vdotq_laneq_s32(acc_lo[4], q6_4567_lo, q8_l[k], 0);  //  0..3  r0 c4567
3810                            acc_lo[5] = vdotq_laneq_s32(acc_lo[5], q6_4567_lo, q8_l[k], 1);  //  0..3  r1 c4567
3811                            acc_lo[6] = vdotq_laneq_s32(acc_lo[6], q6_4567_lo, q8_l[k], 2);  //  0..3  r2 c4567
3812                            acc_lo[7] = vdotq_laneq_s32(acc_lo[7], q6_4567_lo, q8_l[k], 3);  //  0..3  r3 c4567
3813
3814                            acc_hi[4] = vdotq_laneq_s32(acc_hi[4], q6_4567_hi, q8_h[k], 0);  // 64..67 r0 c4567
3815                            acc_hi[5] = vdotq_laneq_s32(acc_hi[5], q6_4567_hi, q8_h[k], 1);  // 64..67 r1 c4567
3816                            acc_hi[6] = vdotq_laneq_s32(acc_hi[6], q6_4567_hi, q8_h[k], 2);  // 64..67 r2 c4567
3817                            acc_hi[7] = vdotq_laneq_s32(acc_hi[7], q6_4567_hi, q8_h[k], 3);  // 64..67 r3 c4567
3818                        }
3819
3820                        // Scale and bias
3821                        const int scale_idx_l = half * 8 + sb;
3822                        const int scale_idx_h = half * 8 + sb + 4;
3823
3824                        for (int g = 0; g < col_groups; g++) {
3825                            const int16x4_t scales_l16  = vld1_s16(q6_scales + scale_idx_l * 8 + g * 4);
3826                            const int16x4_t scales_h16  = vld1_s16(q6_scales + scale_idx_h * 8 + g * 4);
3827                            const int32x4_t scale_vec_l = vmovl_s16(scales_l16);
3828                            const int32x4_t scale_vec_h = vmovl_s16(scales_h16);
3829                            const int       acc_offset  = g * q8_k_blocklen;
3830
3831                            for (int row = 0; row < q8_k_blocklen; row++) {
3832                                const int idx = row * 2 + g;
3833                                acc_s32[idx]  = vmlaq_s32(acc_s32[idx], acc_lo[acc_offset + row], scale_vec_l);
3834                                acc_s32[idx]  = vmlaq_s32(acc_s32[idx], acc_hi[acc_offset + row], scale_vec_h);
3835                            }
3836                        }
3837                    }
3838                }
3839
3840                // Finally we apply the superblock scales
3841                for (int row = 0; row < q8_k_blocklen; row++) {
3842                    const int       idx0     = 2 * row;
3843                    const int       idx1     = 2 * row + 1;
3844                    const int32x4_t acc_0123 = acc_s32[idx0];
3845                    const int32x4_t acc_4567 = acc_s32[idx1];
3846
3847                    acc_f32[idx0] = vmlaq_f32(acc_f32[idx0], vcvtq_f32_s32(acc_0123), sbd_scale_0123[row]);
3848                    acc_f32[idx1] = vmlaq_f32(acc_f32[idx1], vcvtq_f32_s32(acc_4567), sbd_scale_4567[row]);
3849                }
3850            }  // for b
3851
3852            for (int i = 0; i < q8_k_blocklen; i++) {
3853                int row = y * q8_k_blocklen + i;
3854                for (int j = 0; j < 2; j++) {
3855                    int col    = x * ncols_interleaved + j * 4;
3856                    int offset = row * bs + col;
3857                    vst1q_f32(s + offset, acc_f32[2 * i + j]);
3858                }
3859            }
3860        }  // for x
3861    }  // for y
3862    return;
3863#endif  // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
3864    ggml_gemm_q6_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);
3865}
3866
3867void ggml_gemm_q6_K_8x8_q8_K(int                        n,
3868                             float * GGML_RESTRICT      s,
3869                             size_t                     bs,
3870                             const void * GGML_RESTRICT vx,
3871                             const void * GGML_RESTRICT vy,
3872                             int                        nr,
3873                             int                        nc) {
3874    constexpr int qk = QK_K;
3875    const int     nb = n / qk;
3876
3877    constexpr int ncols_interleaved = 8;
3878    constexpr int blocklen          = 8;
3879
3880    assert(n % qk == 0);
3881    assert(nr % 4 == 0);
3882    assert(nc % ncols_interleaved == 0);
3883
3884    UNUSED(nb);
3885    UNUSED(ncols_interleaved);
3886    UNUSED(blocklen);
3887
3888#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
3889    constexpr int    q8_k_blocklen = 4;
3890    const uint8x16_t m4b           = vdupq_n_u8(0x0f);
3891    const uint8x16_t mask_lo       = vdupq_n_u8(0x03);
3892    const uint8x16_t mask_hi       = vdupq_n_u8(0x30);
3893    const int8x16_t  m32s          = vdupq_n_s8(32);
3894
3895    // 8 accumulators: 4 q8 rows × 2 col groups (0-3, 4-7)
3896    float32x4_t acc_f32[blocklen];
3897
3898    for (int y = 0; y < nr / q8_k_blocklen; y++) {
3899        const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
3900
3901        for (int x = 0; x < nc / ncols_interleaved; x++) {
3902            const block_q6_Kx8 * GGML_RESTRICT q6_ptr = (const block_q6_Kx8 *) vx + (x * nb);
3903
3904            for (int i = 0; i < blocklen; i++) {
3905                acc_f32[i] = vdupq_n_f32(0);
3906            }
3907
3908            for (int b = 0; b < nb; b++) {
3909                int32x4_t acc[8];  // rows 01 stored in [0][1][2][3], rows 23 stored in [4][5][6][7]
3910                for (int i = 0; i < 8; i++) {
3911                    acc[i] = vdupq_n_s32(0);
3912                }
3913
3914                // Q6_K has simple 8-bit scales, 16 per block (one per 16 values)
3915                // Reused for bias and dequantization later
3916                int16_t q6_scales[16 * 8];
3917                for (int i = 0; i < 16; ++i) {
3918                    int16x8_t s16 = vmovl_s8(vld1_s8(q6_ptr[b].scales + i * 8));
3919                    vst1q_s16(q6_scales + i * 8, s16);
3920                }
3921
3922                // Process two 128-value halves per superblock
3923                for (int half = 0; half < 2; half++) {
3924
3925                    const uint8_t * ql_base = q6_ptr[b].ql + half * 512;
3926                    const uint8_t * qh_base = q6_ptr[b].qh + half * 256;
3927
3928                    // A subblock (sb) is a set of weights that share the scale
3929                    // Since q6_K scales are per 16 elements
3930                    // num sbs -> 256 elements / (16 elements/scale * 2 elements/byte * 2 halves)
3931                    for (int sb = 0; sb < QK_K / 64; sb++) {
3932                        // Q6_K weight index increasing by 64 instead of 32 requires
3933                        // loading various q8 memory regions
3934                        const int8_t * q8_base_l = q8_ptr[b].qs + half * 512 + sb * 64;
3935                        const int8_t * q8_base_h = q8_ptr[b].qs + half * 512 + 256 + sb * 64;
3936
3937                        int8x16_t q8_l_01[2];
3938                        int8x16_t q8_l_23[2];
3939                        for (int i = 0; i < 2; i++) {
3940                            const int offset = i * 32;
3941                            q8_l_01[i]       = vld1q_s8(q8_base_l + offset);       // 0..7 & 8..15 (r01)
3942                            q8_l_23[i]       = vld1q_s8(q8_base_l + offset + 16);  // 0..7 & 8..15 (r23)
3943                        }
3944
3945                        int8x16_t q8_h_01[2];
3946                        int8x16_t q8_h_23[2];
3947                        for (int i = 0; i < 2; i++) {
3948                            const int offset = i * 32;
3949                            q8_h_01[i]       = vld1q_s8(q8_base_h + offset);
3950                            q8_h_23[i]       = vld1q_s8(q8_base_h + offset + 16);
3951                        }
3952
3953                        const int ql_off_base = sb * QK_K / 2;
3954
3955                        uint8x16_t q6_ql_0[4];
3956                        uint8x16_t q6_ql_1[4];
3957                        for (int k = 0; k < 4; k++) {
3958                            q6_ql_0[k] = vld1q_u8(ql_base + ql_off_base + 16 * k);
3959                            q6_ql_1[k] = vld1q_u8(ql_base + ql_off_base + 64 + 16 * k);
3960                        }
3961
3962                        const int  qh_off_base = (sb * QK_K / 2) & 255;  // wrap after 256 bytes
3963                        uint8x16_t q6_qh_0[4];
3964                        uint8x16_t q6_qh_1[4];
3965                        for (int k = 0; k < 4; k++) {
3966                            q6_qh_0[k] = vld1q_u8(qh_base + qh_off_base + 16 * k);
3967                            q6_qh_1[k] = vld1q_u8(qh_base + qh_off_base + 64 + 16 * k);
3968                        }
3969
3970                        // Adjust for the proper high bits (Sb 2 and 3)
3971                        if (sb > 1) {
3972                            for (int k = 0; k < 4; k++) {
3973                                q6_qh_0[k] = vshrq_n_u8(q6_qh_0[k], 2);
3974                                q6_qh_1[k] = vshrq_n_u8(q6_qh_1[k], 2);
3975                            }
3976                        }
3977
3978                        // Process column pairs (0-1, 2-3, 4-5, 6-7)
3979                        for (int cp = 0; cp < ncols_interleaved / 2; cp++) {
3980                            const uint8x16_t q6_qs_cp_0_l = q6_ql_0[cp];
3981                            const uint8x16_t q6_qs_cp_1_l = q6_ql_1[cp];
3982                            const uint8x16_t q6_qs_cp_0_h = q6_qh_0[cp];
3983                            const uint8x16_t q6_qs_cp_1_h = q6_qh_1[cp];
3984
3985                            // Extract high 2 bits for upper nibble reconstruction
3986                            const uint8x16_t q6_qs_cp_0_hh = vandq_u8(q6_qs_cp_0_h, mask_hi);
3987                            const uint8x16_t q6_qs_cp_1_hh = vandq_u8(q6_qs_cp_1_h, mask_hi);
3988
3989                            // q6 = (low4 | high2<<4) - 32
3990                            // Use vsliq_n_u8 to combine shift-left-insert in one instruction (like Q5_K)
3991                            const int8x16_t q6_l0 = vsubq_s8(
3992                                vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_qs_cp_0_l, m4b), vandq_u8(q6_qs_cp_0_h, mask_lo), 4)),
3993                                m32s);
3994                            const int8x16_t q6_l1 = vsubq_s8(
3995                                vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_qs_cp_1_l, m4b), vandq_u8(q6_qs_cp_1_h, mask_lo), 4)),
3996                                m32s);
3997                            const int8x16_t q6_h0 = vsubq_s8(
3998                                vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_qs_cp_0_l, 4), q6_qs_cp_0_hh)), m32s);
3999                            const int8x16_t q6_h1 = vsubq_s8(
4000                                vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_qs_cp_1_l, 4), q6_qs_cp_1_hh)), m32s);
4001
4002                            // row pair 0, base_l
4003                            int32x4_t sb_acc_0l = vmmlaq_s32(vdupq_n_s32(0), q6_l0, q8_l_01[0]);
4004                            sb_acc_0l           = vmmlaq_s32(sb_acc_0l, q6_l1, q8_l_01[1]);
4005                            // row pair 0, base_h
4006                            int32x4_t sb_acc_0h = vmmlaq_s32(vdupq_n_s32(0), q6_h0, q8_h_01[0]);
4007                            sb_acc_0h           = vmmlaq_s32(sb_acc_0h, q6_h1, q8_h_01[1]);
4008                            // row pair 1, base_l
4009                            int32x4_t sb_acc_1l = vmmlaq_s32(vdupq_n_s32(0), q6_l0, q8_l_23[0]);
4010                            sb_acc_1l           = vmmlaq_s32(sb_acc_1l, q6_l1, q8_l_23[1]);
4011                            // row pair 1, base_h
4012                            int32x4_t sb_acc_1h = vmmlaq_s32(vdupq_n_s32(0), q6_h0, q8_h_23[0]);
4013                            sb_acc_1h           = vmmlaq_s32(sb_acc_1h, q6_h1, q8_h_23[1]);
4014
4015                            const int scale_idx_l = half * 8 + sb;
4016                            const int scale_idx_h = half * 8 + sb + 4;
4017
4018                            const int32x4_t scale_vec_l = {
4019                                q6_scales[scale_idx_l * 8 + cp * 2 + 0],
4020                                q6_scales[scale_idx_l * 8 + cp * 2 + 0],
4021                                q6_scales[scale_idx_l * 8 + cp * 2 + 1],
4022                                q6_scales[scale_idx_l * 8 + cp * 2 + 1],
4023                            };
4024                            const int32x4_t scale_vec_h = {
4025                                q6_scales[scale_idx_h * 8 + cp * 2 + 0],
4026                                q6_scales[scale_idx_h * 8 + cp * 2 + 0],
4027                                q6_scales[scale_idx_h * 8 + cp * 2 + 1],
4028                                q6_scales[scale_idx_h * 8 + cp * 2 + 1],
4029                            };
4030
4031                            acc[cp]     = vmlaq_s32(acc[cp], sb_acc_0l, scale_vec_l);
4032                            acc[cp]     = vmlaq_s32(acc[cp], sb_acc_0h, scale_vec_h);
4033                            acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc_1l, scale_vec_l);
4034                            acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc_1h, scale_vec_h);
4035                        }
4036                    }
4037                }  // for half
4038
4039                // Reorder i8mm output to match memory layout
4040                for (int i = 0; i < 8; i++) {
4041                    int32x2x2_t aux = vzip_s32(vget_low_s32(acc[i]), vget_high_s32(acc[i]));
4042                    acc[i]          = vcombine_s32(aux.val[0], aux.val[1]);
4043                }
4044                int32x4_t reorder_acc[8] = {
4045                    vcombine_s32(vget_low_s32(acc[0]), vget_low_s32(acc[1])),
4046                    vcombine_s32(vget_low_s32(acc[2]), vget_low_s32(acc[3])),
4047                    vcombine_s32(vget_high_s32(acc[0]), vget_high_s32(acc[1])),
4048                    vcombine_s32(vget_high_s32(acc[2]), vget_high_s32(acc[3])),
4049                    vcombine_s32(vget_low_s32(acc[4]), vget_low_s32(acc[5])),
4050                    vcombine_s32(vget_low_s32(acc[6]), vget_low_s32(acc[7])),
4051                    vcombine_s32(vget_high_s32(acc[4]), vget_high_s32(acc[5])),
4052                    vcombine_s32(vget_high_s32(acc[6]), vget_high_s32(acc[7])),
4053                };
4054
4055                // Apply superblock scale (no mins for q6_K)
4056                for (int i = 0; i < q8_k_blocklen; i++) {
4057                    for (int j = 0; j < 2; j++) {
4058                        float32x4_t       q8_d  = vdupq_n_f32(q8_ptr[b].d[i]);
4059                        float32x4_t       q6_d  = vcvt_f32_f16(vld1_f16((const __fp16 *) (q6_ptr[b].d + j * 4)));
4060                        const float32x4_t scale = vmulq_f32(q6_d, q8_d);
4061
4062                        acc_f32[2 * i + j] =
4063                            vmlaq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(reorder_acc[2 * i + j]), scale);
4064                    }
4065                }
4066            }  // for b
4067
4068            // Store results
4069            for (int i = 0; i < q8_k_blocklen; i++) {
4070                int row = y * q8_k_blocklen + i;
4071                for (int j = 0; j < 2; j++) {
4072                    int col    = x * ncols_interleaved + j * 4;
4073                    int offset = row * bs + col;
4074                    vst1q_f32(s + offset, acc_f32[2 * i + j]);
4075                }
4076            }
4077        }  // for x
4078    }  // for y
4079    return;
4080#endif  // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
4081    ggml_gemm_q6_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
4082}
4083
4084void ggml_gemm_q8_0_4x4_q8_0(int                        n,
4085                             float * GGML_RESTRICT      s,
4086                             size_t                     bs,
4087                             const void * GGML_RESTRICT vx,
4088                             const void * GGML_RESTRICT vy,
4089                             int                        nr,
4090                             int                        nc) {
4091    const int qk                = QK8_0;
4092    const int nb                = n / qk;
4093    const int ncols_interleaved = 4;
4094    const int blocklen          = 4;
4095
4096    assert(n % qk == 0);
4097    assert(nr % 4 == 0);
4098    assert(nc % ncols_interleaved == 0);
4099
4100    UNUSED(nb);
4101    UNUSED(ncols_interleaved);
4102    UNUSED(blocklen);
4103
4104#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
4105    for (int y = 0; y < nr / 4; y++) {
4106        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
4107        for (int x = 0; x < nc / ncols_interleaved; x++) {
4108            const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx + (x * nb);
4109
4110            float32x4_t sumf[4];
4111            for (int m = 0; m < 4; m++) {
4112                sumf[m] = vdupq_n_f32(0);
4113            }
4114
4115            for (int l = 0; l < nb; l++) {
4116                float32x4_t a_d = vcvt_f32_f16(vld1_f16((const float16_t *) a_ptr[l].d));
4117                float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *) b_ptr[l].d));
4118
4119                int32x4_t sumi_0 = vdupq_n_s32(0);
4120                int32x4_t sumi_1 = vdupq_n_s32(0);
4121                int32x4_t sumi_2 = vdupq_n_s32(0);
4122                int32x4_t sumi_3 = vdupq_n_s32(0);
4123
4124                for (int k_group = 0; k_group < 8; k_group += 4) {
4125                    int8x16x4_t a = vld1q_s8_x4(a_ptr[l].qs + 16 * k_group);
4126                    int8x16x4_t b = vld1q_s8_x4(b_ptr[l].qs + 16 * k_group);
4127
4128                    for (int k = 0; k < 4; k++) {
4129                        sumi_0 = vdotq_laneq_s32(sumi_0, b.val[k], a.val[k], 0);
4130                        sumi_1 = vdotq_laneq_s32(sumi_1, b.val[k], a.val[k], 1);
4131                        sumi_2 = vdotq_laneq_s32(sumi_2, b.val[k], a.val[k], 2);
4132                        sumi_3 = vdotq_laneq_s32(sumi_3, b.val[k], a.val[k], 3);
4133                    }
4134                }
4135
4136                sumf[0] = vmlaq_f32(sumf[0], vmulq_laneq_f32(b_d, a_d, 0), vcvtq_f32_s32(sumi_0));
4137                sumf[1] = vmlaq_f32(sumf[1], vmulq_laneq_f32(b_d, a_d, 1), vcvtq_f32_s32(sumi_1));
4138                sumf[2] = vmlaq_f32(sumf[2], vmulq_laneq_f32(b_d, a_d, 2), vcvtq_f32_s32(sumi_2));
4139                sumf[3] = vmlaq_f32(sumf[3], vmulq_laneq_f32(b_d, a_d, 3), vcvtq_f32_s32(sumi_3));
4140            }
4141
4142            for (int m = 0; m < 4; m++) {
4143                vst1q_f32(s + (y * 4 + m) * bs + x * 4, sumf[m]);
4144            }
4145        }
4146    }
4147    return;
4148#endif  // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
4149    ggml_gemm_q8_0_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
4150}
4151
4152void ggml_gemm_q8_0_4x8_q8_0(int                        n,
4153                             float * GGML_RESTRICT      s,
4154                             size_t                     bs,
4155                             const void * GGML_RESTRICT vx,
4156                             const void * GGML_RESTRICT vy,
4157                             int                        nr,
4158                             int                        nc) {
4159    const int qk                = QK8_0;
4160    const int nb                = n / qk;
4161    const int ncols_interleaved = 4;
4162    const int blocklen          = 8;
4163
4164    assert(n % qk == 0);
4165    assert(nr % 4 == 0);
4166    assert(nc % ncols_interleaved == 0);
4167
4168    UNUSED(nb);
4169    UNUSED(ncols_interleaved);
4170    UNUSED(blocklen);
4171
4172#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
4173    const block_q8_0x4 * b_ptr_base = (const block_q8_0x4 *) vx;
4174
4175    for (int y = 0; y < nr; y += 4) {
4176        const block_q8_0x4 * a_ptr_base = (const block_q8_0x4 *) vy + (y / 4) * nb;
4177
4178        for (int x = 0; x < nc; x += ncols_interleaved) {
4179            const block_q8_0x4 * b_ptr = b_ptr_base + (x / 4) * nb;
4180            const block_q8_0x4 * a_ptr = a_ptr_base;
4181
4182            float32x4_t acc_f32[4];
4183            for (int i = 0; i < 4; i++) {
4184                acc_f32[i] = vdupq_n_f32(0);
4185            }
4186
4187            for (int b = 0; b < nb; b++) {
4188                int32x4_t acc[4];
4189                for (int i = 0; i < 4; i++) {
4190                    acc[i] = vdupq_n_s32(0);
4191                }
4192
4193                // Process 4 chunks of 8 positions each
4194                for (int chunk = 0; chunk < 4; chunk++) {
4195                    int8x16_t a01 = vld1q_s8(a_ptr->qs + chunk * 32);
4196                    int8x16_t a23 = vld1q_s8(a_ptr->qs + chunk * 32 + 16);
4197                    int8x16_t b01 = vld1q_s8(b_ptr->qs + chunk * 32);
4198                    int8x16_t b23 = vld1q_s8(b_ptr->qs + chunk * 32 + 16);
4199
4200                    acc[0] = vmmlaq_s32(acc[0], a01, b01);
4201                    acc[1] = vmmlaq_s32(acc[1], a01, b23);
4202                    acc[2] = vmmlaq_s32(acc[2], a23, b01);
4203                    acc[3] = vmmlaq_s32(acc[3], a23, b23);
4204                }
4205
4206                // Reorder outputs from 2×2 tiles to row-major
4207                // acc[0] = [r0c0, r0c1, r1c0, r1c1]
4208                // acc[1] = [r0c2, r0c3, r1c2, r1c3]
4209                // acc[2] = [r2c0, r2c1, r3c0, r3c1]
4210                // acc[3] = [r2c2, r2c3, r3c2, r3c3]
4211                int32x4_t row0 = vcombine_s32(vget_low_s32(acc[0]), vget_low_s32(acc[1]));
4212                int32x4_t row1 = vcombine_s32(vget_high_s32(acc[0]), vget_high_s32(acc[1]));
4213                int32x4_t row2 = vcombine_s32(vget_low_s32(acc[2]), vget_low_s32(acc[3]));
4214                int32x4_t row3 = vcombine_s32(vget_high_s32(acc[2]), vget_high_s32(acc[3]));
4215
4216                // Scales
4217                float32x4_t a_d = vcvt_f32_f16(vld1_f16((const __fp16 *) a_ptr->d));
4218                float32x4_t b_d = vcvt_f32_f16(vld1_f16((const __fp16 *) b_ptr->d));
4219
4220                acc_f32[0] = vfmaq_f32(acc_f32[0], vcvtq_f32_s32(row0), vmulq_laneq_f32(b_d, a_d, 0));
4221                acc_f32[1] = vfmaq_f32(acc_f32[1], vcvtq_f32_s32(row1), vmulq_laneq_f32(b_d, a_d, 1));
4222                acc_f32[2] = vfmaq_f32(acc_f32[2], vcvtq_f32_s32(row2), vmulq_laneq_f32(b_d, a_d, 2));
4223                acc_f32[3] = vfmaq_f32(acc_f32[3], vcvtq_f32_s32(row3), vmulq_laneq_f32(b_d, a_d, 3));
4224
4225                a_ptr++;
4226                b_ptr++;
4227            }
4228
4229            for (int row = 0; row < 4; row++) {
4230                vst1q_f32(s + (y + row) * bs + x, acc_f32[row]);
4231            }
4232        }
4233    }
4234    return;
4235#endif  // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
4236    ggml_gemm_q8_0_4x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
4237}