// Copyright 2024 Mozilla Foundation
//
// Permission is hereby granted, free of charge, to any person obtaining
// a copy of this software and associated documentation files (the
// "Software"), to deal in the Software without restriction, including
// without limitation the rights to use, copy, modify, merge, publish,
// distribute, sublicense, and/or sell copies of the Software, and to
// permit persons to whom the Software is furnished to do so, subject to
// the following conditions:
//
// The above copyright notice and this permission notice shall be
// included in all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
// BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
// ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.

//
//                   _   _          ___ _      _   ___
//                  | |_(_)_ _ _  _| _ ) |    /_\ / __|
//                  |  _| | ' \ || | _ \ |__ / _ \\__ \.
//                   \__|_|_||_\_, |___/____/_/ \_\___/
//                             |__/
//
//                    BASIC LINEAR ALGEBRA SUBPROGRAMS
//
//
// This file implements multithreaded CPU matrix multiplication for the
// common contiguous use case C = Aᵀ * B. These kernels are designed to
// have excellent performance[1] for matrices that fit in the CPU cache
// without imposing any overhead such as cache filling or malloc calls.
//
// This implementation does not guarantee any upper bound on rounding
// error, which grows with k. Our goal is to maximally exploit the
// hardware for performance, and then use whatever resources remain for
// improving numerical accuracy.
//
// [1] J. Tunney, ‘LLaMA Now Goes Faster on CPUs’, Mar. 2024. [Online].
//     Available: https://justine.lol/matmul/. [Accessed: 29-Mar-2024].
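//
// For orientation, in exact arithmetic every floating point kernel below
// computes the same result as this scalar reference (A, B and C use leading
// dimensions lda, ldb and ldc; both A and B are traversed along k, which is
// why the operation reads C = Aᵀ * B):
//
//     for (int64_t j = 0; j < n; ++j)
//         for (int64_t i = 0; i < m; ++i) {
//             float sum = 0;
//             for (int64_t l = 0; l < k; ++l)
//                 sum += A[lda * i + l] * B[ldb * j + l];
//             C[ldc * j + i] = sum;
//         }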

#if defined(__GNUC__)
#pragma GCC diagnostic ignored "-Wpedantic"
#pragma GCC diagnostic ignored "-Wignored-attributes"
#endif

#include "sgemm.h"
#include "ggml-impl.h"
#include "ggml-cpu-impl.h"
#include "ggml-quants.h"
#include "simd-mappings.h"

#include <array>
#include <type_traits>

#ifdef _MSC_VER
#define NOINLINE __declspec(noinline)
#else
#define NOINLINE __attribute__((__noinline__))
#endif

#if defined(__ARM_NEON) || defined(__AVX512F__) || defined(__VXE__) || defined(__VXE2__)
#define VECTOR_REGISTERS 32
#else
#define VECTOR_REGISTERS 16
#endif

#if defined(__riscv_v_intrinsic)
#define LMUL 4
#endif

#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
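
// MM256_SET_M128I(a, b) assembles one __m256i from two __m128i halves, with b
// in the low 128 bits and a in the high 128 bits.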

namespace {

inline float unhalf(ggml_fp16_t d) {
    return GGML_CPU_FP16_TO_FP32(d);
}

////////////////////////////////////////////////////////////////////////////////////////////////////
// VECTORIZED ARITHMETIC OPERATIONS

#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
inline __m128 add(__m128 x, __m128 y) { return _mm_add_ps(x, y); }
inline __m128 sub(__m128 x, __m128 y) { return _mm_sub_ps(x, y); }
inline __m128 mul(__m128 x, __m128 y) { return _mm_mul_ps(x, y); }
#endif  // __SSE__

#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
inline __m256 add(__m256 x, __m256 y) { return _mm256_add_ps(x, y); }
inline __m256 sub(__m256 x, __m256 y) { return _mm256_sub_ps(x, y); }
inline __m256 mul(__m256 x, __m256 y) { return _mm256_mul_ps(x, y); }
#endif // __AVX__

#if defined(__AVX512F__)
inline __m512 add(__m512 x, __m512 y) { return _mm512_add_ps(x, y); }
inline __m512 sub(__m512 x, __m512 y) { return _mm512_sub_ps(x, y); }
inline __m512 mul(__m512 x, __m512 y) { return _mm512_mul_ps(x, y); }
#endif // __AVX512F__

#if defined(__ARM_NEON)
inline float32x4_t add(float32x4_t x, float32x4_t y) { return vaddq_f32(x, y); }
inline float32x4_t sub(float32x4_t x, float32x4_t y) { return vsubq_f32(x, y); }
inline float32x4_t mul(float32x4_t x, float32x4_t y) { return vmulq_f32(x, y); }
#endif // __ARM_NEON

#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
inline float16x8_t add(float16x8_t x, float16x8_t y) { return vaddq_f16(x, y); }
inline float16x8_t sub(float16x8_t x, float16x8_t y) { return vsubq_f16(x, y); }
inline float16x8_t mul(float16x8_t x, float16x8_t y) { return vmulq_f16(x, y); }
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC

#if defined(__VXE__) || defined(__VXE2__)
inline float32x4_t add(float32x4_t x, float32x4_t y) { return vec_add(x, y); }
inline float32x4_t sub(float32x4_t x, float32x4_t y) { return vec_sub(x, y); }
inline float32x4_t mul(float32x4_t x, float32x4_t y) { return vec_mul(x, y); }
#endif

#if defined(__MMA__)
#include "sgemm-ppc.h"
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////
// VECTORIZED FUSED MULTIPLY ADD

/**
 * Computes a * b + c.
 */
template <typename T, typename U>
inline U madd(T a, T b, U c) {
    return add(mul(a, b), c);
}
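
// The generic madd above rounds twice (once after the multiply, once after
// the add). The specializations below use fused multiply-add instructions
// where available, which are faster and round only once.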

#if defined(__FMA__)
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
template <>
inline __m256 madd(__m256 a, __m256 b, __m256 c) {
    return _mm256_fmadd_ps(a, b, c);
}
#endif
#if defined(__AVX512F__)
template <>
inline __m512 madd(__m512 a, __m512 b, __m512 c) {
    return _mm512_fmadd_ps(a, b, c);
}
#endif
#if defined(__AVX512BF16__)
template <>
inline __m512 madd(__m512bh a, __m512bh b, __m512 c) {
    return _mm512_dpbf16_ps(c, a, b);
}
template <>
inline __m256 madd(__m256bh a, __m256bh b, __m256 c) {
    return _mm256_dpbf16_ps(c, a, b);
}
#endif
#endif

#if defined(__ARM_FEATURE_FMA)
template <>
inline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) {
    return vfmaq_f32(c, b, a);
}
#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
template <>
inline float16x8_t madd(float16x8_t a, float16x8_t b, float16x8_t c) {
    return vfmaq_f16(c, b, a);
}
#endif
#endif

#if defined(__VXE__) || defined(__VXE2__)
template <>
inline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) {
    return vec_madd(a, b, c);
}
#endif

#if defined(__riscv_zvfh)
template <>
inline vfloat32m1_t madd(vfloat16mf2_t a, vfloat16mf2_t b, vfloat32m1_t c) {
    return __riscv_vfwmacc_vv_f32m1(c, a, b, __riscv_vsetvlmax_e32m1());
}
inline vfloat32m2_t madd(vfloat16m1_t a, vfloat16m1_t b, vfloat32m2_t c) {
    return __riscv_vfwmacc_vv_f32m2(c, a, b, __riscv_vsetvlmax_e32m2());
}
inline vfloat32m4_t madd(vfloat16m2_t a, vfloat16m2_t b, vfloat32m4_t c) {
    return __riscv_vfwmacc_vv_f32m4(c, a, b, __riscv_vsetvlmax_e32m4());
}
inline vfloat32m8_t madd(vfloat16m4_t a, vfloat16m4_t b, vfloat32m8_t c) {
    return __riscv_vfwmacc_vv_f32m8(c, a, b, __riscv_vsetvlmax_e32m8());
}
inline vfloat32m1_t madd(vfloat32m1_t a, vfloat32m1_t b, vfloat32m1_t c) {
    return __riscv_vfmacc_vv_f32m1(c, a, b, __riscv_vsetvlmax_e32m1());
}
inline vfloat32m2_t madd(vfloat32m2_t a, vfloat32m2_t b, vfloat32m2_t c) {
    return __riscv_vfmacc_vv_f32m2(c, a, b, __riscv_vsetvlmax_e32m2());
}
inline vfloat32m4_t madd(vfloat32m4_t a, vfloat32m4_t b, vfloat32m4_t c) {
    return __riscv_vfmacc_vv_f32m4(c, a, b, __riscv_vsetvlmax_e32m4());
}
inline vfloat32m8_t madd(vfloat32m8_t a, vfloat32m8_t b, vfloat32m8_t c) {
    return __riscv_vfmacc_vv_f32m8(c, a, b, __riscv_vsetvlmax_e32m8());
}
#endif

#if defined(__riscv_zvfbfwma)
inline vfloat32m1_t madd(vbfloat16mf2_t a, vbfloat16mf2_t b, vfloat32m1_t c) {
    return __riscv_vfwmaccbf16_vv_f32m1(c, a, b, __riscv_vsetvlmax_e32m1());
}
inline vfloat32m2_t madd(vbfloat16m1_t a, vbfloat16m1_t b, vfloat32m2_t c) {
    return __riscv_vfwmaccbf16_vv_f32m2(c, a, b, __riscv_vsetvlmax_e32m2());
}
inline vfloat32m4_t madd(vbfloat16m2_t a, vbfloat16m2_t b, vfloat32m4_t c) {
    return __riscv_vfwmaccbf16_vv_f32m4(c, a, b, __riscv_vsetvlmax_e32m4());
}
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////
// VECTORIZED HORIZONTAL SUM

#if defined(__ARM_NEON)
inline float hsum(float32x4_t x) {
    return vaddvq_f32(x);
}
#endif // __ARM_NEON

#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
inline float hsum(float16x8_t x) {
    return vaddvq_f32(vaddq_f32(vcvt_f32_f16(vget_low_f16(x)),
                                vcvt_f32_f16(vget_high_f16(x))));
}
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC

#if defined(__VXE__) || defined(__VXE2__)
inline float hsum(float32x4_t x) {
    float32x4_t tmp = x + vec_reve(x);
    return tmp[0] + tmp[1];
}
#endif

#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
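// Reduce the four f32 lanes of an SSE register to one scalar. On AVX, adding
// the movehl copy [x2 x3 x2 x3] gives [x0+x2 x1+x3 . .], and adding the
// odd-lane duplicate (movehdup) then leaves the full sum in lane 0; the plain
// SSE path below does the same reduction with shuffles.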
inline float hsum(__m128 x) {
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
    x = _mm_add_ps(x, _mm_movehl_ps(x, x));
    x = _mm_add_ss(x, _mm_movehdup_ps(x));
#else
    __m128 t;
    t = _mm_shuffle_ps(x, x, _MM_SHUFFLE(2, 3, 0, 1));
    x = _mm_add_ps(x, t);
    t = _mm_movehl_ps(t, x);
    x = _mm_add_ss(x, t);
#endif
    return _mm_cvtss_f32(x);
}
#endif

#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
inline float hsum(__m256 x) {
    return hsum(_mm_add_ps(_mm256_extractf128_ps(x, 1),
                           _mm256_castps256_ps128(x)));
}
#endif // __AVX__

#if defined(__AVX512F__)
inline float hsum(__m512 x) {
    return _mm512_reduce_add_ps(x);
}
#endif // __AVX512F__

#if defined(__riscv_zvfh)
inline float hsum(vfloat32m1_t x) {
    return __riscv_vfmv_f_s_f32m1_f32(
        __riscv_vfredusum_vs_f32m1_f32m1(x, __riscv_vfmv_v_f_f32m1(0, 1), __riscv_vsetvlmax_e32m1()));
}
inline float hsum(vfloat32m2_t x) {
    return __riscv_vfmv_f_s_f32m1_f32(
        __riscv_vfredusum_vs_f32m2_f32m1(x, __riscv_vfmv_v_f_f32m1(0, 1), __riscv_vsetvlmax_e32m2()));
}
inline float hsum(vfloat32m4_t x) {
    return __riscv_vfmv_f_s_f32m1_f32(
        __riscv_vfredusum_vs_f32m4_f32m1(x, __riscv_vfmv_v_f_f32m1(0, 1), __riscv_vsetvlmax_e32m4()));
}
inline float hsum(vfloat32m8_t x) {
    return __riscv_vfmv_f_s_f32m1_f32(
        __riscv_vfredusum_vs_f32m8_f32m1(x, __riscv_vfmv_v_f_f32m1(0, 1), __riscv_vsetvlmax_e32m8()));
}
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////
// VECTORIZED MEMORY LOADING

template <typename T, typename U> T load(const U *);
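
// load<T>(p) reads one vector register's worth of elements from p and returns
// them as vector type T; the f16/bf16 overloads widen to f32 lanes on targets
// whose kernels accumulate in f32.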

#if defined(__ARM_NEON)
template <> inline float32x4_t load(const float *p) {
    return vld1q_f32(p);
}
#if !defined(_MSC_VER)
// FIXME: this should check for __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
template <> inline float16x8_t load(const ggml_fp16_t *p) {
    return vld1q_f16((const float16_t *)p);
}
template <> inline float32x4_t load(const ggml_fp16_t *p) {
    return vcvt_f32_f16(vld1_f16((const float16_t *)p));
}
#endif // _MSC_VER
#endif // __ARM_NEON

#if defined(__VXE__) || defined(__VXE2__)
template <> inline float32x4_t load(const ggml_fp16_t * p) {
    float tmp[4];

    for (int i = 0; i < 4; i++) {
        tmp[i] = GGML_CPU_FP16_TO_FP32(p[i]);
    }

    return vec_xl(0, (const float *)(tmp));
}
template <> inline float32x4_t load(const float * p) {
    return vec_xl(0, p);
}
#endif

#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
template <> inline __m128 load(const float *p) {
    return _mm_loadu_ps(p);
}
#endif  // __SSE__

#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
template <> inline __m256 load(const float *p) {
    return _mm256_loadu_ps(p);
}
#endif // __AVX__

#if defined(__AVX2__) || defined(__AVX512F__)
template <> inline __m256 load(const ggml_bf16_t *p) {
    return _mm256_castsi256_ps(
        _mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)p)), 16));
}
#endif // __AVX2__

#if defined(__F16C__)
template <> inline __m256 load(const ggml_fp16_t *p) {
    return _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)p));
}
#endif // __F16C__

#if defined(__AVX512F__)
template <> inline __m512 load(const float *p) {
    return _mm512_loadu_ps(p);
}
template <> inline __m512 load(const ggml_fp16_t *p) {
    return _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)p));
}
template <> inline __m512 load(const ggml_bf16_t *p) {
    return _mm512_castsi512_ps(
        _mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i *)p)), 16));
}
#endif // __AVX512F__

#if defined(__AVX512BF16__)
template <> inline __m512bh load(const ggml_bf16_t *p) {
    return (__m512bh)_mm512_loadu_ps((const float *)p);
}
template <> inline __m256bh load(const ggml_bf16_t *p) {
    return (__m256bh)_mm256_loadu_ps((const float *)p);
}
template <> inline __m512bh load(const float *p) {
    return _mm512_cvtne2ps_pbh(_mm512_loadu_ps(p + 16), _mm512_loadu_ps(p));
}
template <> inline __m256bh load(const float *p) {
    return _mm512_cvtneps_pbh(_mm512_loadu_ps(p));
}
#endif

#if defined(__riscv_zvfh)
template <> inline vfloat16mf2_t load(const ggml_fp16_t *p) {
    return __riscv_vle16_v_f16mf2(reinterpret_cast<const _Float16 *>(p), __riscv_vsetvlmax_e16mf2());
}
template <> inline vfloat16m1_t load(const ggml_fp16_t *p) {
    return __riscv_vle16_v_f16m1(reinterpret_cast<const _Float16 *>(p), __riscv_vsetvlmax_e16m1());
}
template <> inline vfloat16m2_t load(const ggml_fp16_t *p) {
    return __riscv_vle16_v_f16m2(reinterpret_cast<const _Float16 *>(p), __riscv_vsetvlmax_e16m2());
}
template <> inline vfloat16m4_t load(const ggml_fp16_t *p) {
    return __riscv_vle16_v_f16m4(reinterpret_cast<const _Float16 *>(p), __riscv_vsetvlmax_e16m4());
}
template <> inline vfloat32m1_t load(const float *p) {
    return __riscv_vle32_v_f32m1(p, __riscv_vsetvlmax_e32m1());
}
template <> inline vfloat32m2_t load(const float *p) {
    return __riscv_vle32_v_f32m2(p, __riscv_vsetvlmax_e32m2());
}
template <> inline vfloat32m4_t load(const float *p) {
    return __riscv_vle32_v_f32m4(p, __riscv_vsetvlmax_e32m4());
}
template <> inline vfloat32m8_t load(const float *p) {
    return __riscv_vle32_v_f32m8(p, __riscv_vsetvlmax_e32m8());
}
#endif

#if defined(__riscv_zvfbfwma)
template <> inline vbfloat16mf2_t load(const ggml_bf16_t *p) {
    return __riscv_vle16_v_bf16mf2(reinterpret_cast<const __bf16*>(p), __riscv_vsetvlmax_e16mf2());
}
template <> inline vbfloat16m1_t load(const ggml_bf16_t *p) {
    return __riscv_vle16_v_bf16m1(reinterpret_cast<const __bf16*>(p), __riscv_vsetvlmax_e16m1());
}
template <> inline vbfloat16m2_t load(const ggml_bf16_t *p) {
    return __riscv_vle16_v_bf16m2(reinterpret_cast<const __bf16*>(p), __riscv_vsetvlmax_e16m2());
}
#endif

#if defined(__riscv_zvfh)
template <typename T> T set_zero();

template <> inline vfloat16mf2_t set_zero() {
    return __riscv_vfmv_v_f_f16mf2(0, __riscv_vsetvlmax_e16mf2());
}
template <> inline vfloat16m1_t set_zero() {
    return __riscv_vfmv_v_f_f16m1(0, __riscv_vsetvlmax_e16m1());
}
template <> inline vfloat16m2_t set_zero() {
    return __riscv_vfmv_v_f_f16m2(0, __riscv_vsetvlmax_e16m2());
}
template <> inline vfloat16m4_t set_zero() {
    return __riscv_vfmv_v_f_f16m4(0, __riscv_vsetvlmax_e16m4());
}
template <> inline vfloat32m1_t set_zero() {
    return __riscv_vfmv_v_f_f32m1(0.0f, __riscv_vsetvlmax_e32m1());
}
template <> inline vfloat32m2_t set_zero() {
    return __riscv_vfmv_v_f_f32m2(0, __riscv_vsetvlmax_e32m2());
}
template <> inline vfloat32m4_t set_zero() {
    return __riscv_vfmv_v_f_f32m4(0, __riscv_vsetvlmax_e32m4());
}
template <> inline vfloat32m8_t set_zero() {
    return __riscv_vfmv_v_f_f32m8(0, __riscv_vsetvlmax_e32m8());
}
#endif

#if defined(__riscv_v_intrinsic)
template <typename T> size_t vlmax() {
    if constexpr (std::is_same_v<T, vfloat16mf2_t>) { return  __riscv_vsetvlmax_e16mf2(); }
    else if constexpr (std::is_same_v<T, vfloat16m1_t>) { return  __riscv_vsetvlmax_e16m1(); }
    else if constexpr (std::is_same_v<T, vfloat16m2_t>) { return  __riscv_vsetvlmax_e16m2(); }
    else if constexpr (std::is_same_v<T, vfloat16m4_t>) { return  __riscv_vsetvlmax_e16m4(); }
    else if constexpr (std::is_same_v<T, vfloat32m1_t>) { return  __riscv_vsetvlmax_e32m1(); }
    else if constexpr (std::is_same_v<T, vfloat32m2_t>) { return  __riscv_vsetvlmax_e32m2(); }
    else if constexpr (std::is_same_v<T, vfloat32m4_t>) { return  __riscv_vsetvlmax_e32m4(); }
    else if constexpr (std::is_same_v<T, vfloat32m8_t>) { return  __riscv_vsetvlmax_e32m8(); }
    return 0;
}
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////
// FLOATING POINT MATRIX MULTIPLICATION

template <int M>
static inline int64_t BLOCK_SIZE(size_t m) {
    const int64_t NB_BLOC_M = (m + M - 1) / M;
    return (m % NB_BLOC_M == 0) ? m / NB_BLOC_M : (m / NB_BLOC_M) + 1;
}
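// e.g. BLOCK_SIZE<6>(16): 16 columns split into ceil(16/6) = 3 blocks, so the
// rounded-up block size is ceil(16/3) = 6 (the trailing blocks get size 5);
// BLOCK_SIZE<6>(15) = 5, i.e. three even blocks of 5 rather than 6+6+3.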

static constexpr inline int64_t BLOC_POS(int64_t ib, int64_t ibN, int64_t bloc_size) {
    return ib < ibN ? ib * bloc_size : ibN * bloc_size + (ib - ibN) * (bloc_size - 1);
}
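// BLOC_POS maps a block index to its starting offset when the first ibN
// blocks have size bloc_size and the rest size bloc_size - 1. e.g. with
// ibN = 2 and bloc_size = 6, indices 0,1,2,3,4 map to offsets 0,6,12,17,22.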

template <int KN, typename D, typename V, typename TA, typename TB, typename TC>
class tinyBLAS {
  public:
    tinyBLAS(const ggml_compute_params * params, int64_t k,
             const TA *A, int64_t lda,
             const TB *B, int64_t ldb,
             TC *C, int64_t ldc)
        : params(params), A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc) {
    }

    bool matmul(int64_t m, int64_t n) {
        if (k % KN != 0)
            return false;
        // choose tile shapes so that only tiles of size RN and RN-1 are ever needed
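        // mnpack<RM, RN_max, BM>(m, n, SIZE_N, BN): RM rows of C per
        // micro-kernel tile, up to RN_max columns per tile (SIZE_N selects
        // the actual RN via the recursion in mnpack), and BM row-tiles
        // grouped into one unit of work; BN is the target number of column
        // tiles per job.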
#if VECTOR_REGISTERS == 32
        if (m % 16 == 0 && (m/16 >= params->nth)) {
            const int64_t SIZE_N = BLOCK_SIZE<6>(n);
            mnpack<4, 6, 4>(m, n, SIZE_N, 12);
            return true;
        }
        if (m % 8 == 0) {
            const int64_t SIZE_N = BLOCK_SIZE<6>(n);
            mnpack<4, 6, 2>(m, n, SIZE_N, 12);
            return true;
        }
        if (m % 4 == 0) {
            const int64_t SIZE_N = BLOCK_SIZE<6>(n);
            mnpack<4, 6, 1>(m, n, SIZE_N, 12);
            return true;
        }
#else  // VECTOR_REGISTERS == 16
        if (m % 16 == 0 && (m/16 >= params->nth)) {
            const int64_t SIZE_N = BLOCK_SIZE<3>(n);
            mnpack<4, 3, 4>(m, n, SIZE_N, 24);
            return true;
        }
        if (m % 8 == 0) {
            const int64_t SIZE_N = BLOCK_SIZE<3>(n);
            mnpack<4, 3, 2>(m, n, SIZE_N, 24);
            return true;
        }
        if (m % 4 == 0) {
            const int64_t SIZE_N = BLOCK_SIZE<3>(n);
            mnpack<4, 3, 1>(m, n, SIZE_N, 24);
            return true;
        }
#endif
        return false;
    }

  private:
    template <int RM, int RN, int BM>
    inline void mnpack(int64_t m, int64_t n, int64_t SIZE_N, int64_t BN) {
        if (SIZE_N == RN) {
            return gemm<RM, RN, BM>(m, n, BN);
        }
        if constexpr (RN > 1) {
            return mnpack<RM, RN-1, BM>(m, n, SIZE_N, BN);
        } else {
            GGML_LOG_ERROR("mnpack<%d, %d> bloc size not supported\n", RM, (int)SIZE_N);
            GGML_ASSERT(false); // we have missed something.
        }
    }

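    // Computes one RM x RN tile of C. All RM * RN accumulators stay resident
    // in vector registers for the entire k loop; each is reduced to a scalar
    // with hsum() only once, when the tile is stored.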
    template <int RM, int RN>
    inline void gemm_bloc(int64_t ii, int64_t jj) {
        D Cv[RN][RM] = {};
        for (int64_t l = 0; l < k; l += KN) {
            // help the compiler with operation ordering.
            if constexpr (RM <= RN) {
                V Av[RM];
                for (int64_t i = 0; i < RM; ++i) {
                    Av[i] = load<V>(A + lda * (ii + i) + l);
                }
                for (int64_t j = 0; j < RN; ++j) {
                    V Bv = load<V>(B + ldb * (jj + j) + l);
                    for (int64_t i = 0; i < RM; ++i) {
                        Cv[j][i] = madd(Av[i], Bv, Cv[j][i]);
                    }
                }
            } else {
                V Bv[RN];
                for (int64_t j = 0; j < RN; ++j) {
                    Bv[j] = load<V>(B + ldb * (jj + j) + l);
                }
                for (int64_t i = 0; i < RM; ++i) {
                    V Av = load<V>(A + lda * (ii + i) + l);
                    for (int64_t j = 0; j < RN; ++j) {
                        Cv[j][i] = madd(Av, Bv[j], Cv[j][i]);
                    }
                }
            }
        }
        for (int64_t j = 0; j < RN; ++j)
            for (int64_t i = 0; i < RM; ++i)
                C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
    }

    template <int RM, int RN, int BM>
    NOINLINE void gemm(int64_t m, int64_t n, int64_t BN) {
        GGML_ASSERT(m % (RM * BM) == 0);
        const int64_t ytiles = m / (RM * BM);
        const int64_t xtiles = (n + RN - 1) / RN;
        const int64_t jj_RN = (xtiles - (xtiles * RN - n));

        // "round" bloc_size to "nearest" BN
        const int64_t NB_BN = xtiles < BN ? 1 : (xtiles + BN / 2) / BN;
        const int64_t SIZE_BN = xtiles % NB_BN == 0 ? xtiles / NB_BN : xtiles / NB_BN + 1;
        const int64_t jj_BN = (NB_BN - (NB_BN * SIZE_BN - xtiles));
        const int64_t nb_job = ytiles * NB_BN;

        if (params->ith == 0) {
            GGML_ASSERT(jj_BN * SIZE_BN + (NB_BN - jj_BN) * (SIZE_BN - 1) == xtiles);
            // Every thread starts at ith, so the first unprocessed chunk is nth. This saves a bit of coordination right at the start.
            ggml_threadpool_chunk_set(params->threadpool, params->nth);
        }

        ggml_barrier(params->threadpool);

        int64_t job = params->ith;
        while (job < nb_job) {
            const int64_t ii = (job % ytiles) * RM * BM;
            const int64_t jb =  job / ytiles;
            const int64_t jr0 = BLOC_POS(jb  , jj_BN, SIZE_BN);
            const int64_t jrN = BLOC_POS(jb+1, jj_BN, SIZE_BN);

            const int64_t jj0 = BLOC_POS(jr0, jj_RN, RN);
            const int64_t jj2 = BLOC_POS(jrN, jj_RN, RN);
            const int64_t jj1 = jj2 < jj_RN * RN ? jj2 : jj_RN * RN;
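            // [jj0, jj1) is covered with full RN-wide tiles; [jj1, jj2) holds
            // the leftover columns and is covered with RN-1-wide tiles.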

            for (int64_t bi = 0; bi < BM * RM; bi += RM) {
                int64_t jj = jj0;
                for (; jj < jj1; jj += RN) {
                    gemm_bloc<RM, RN>(ii + bi, jj);
                }
                if constexpr (RN > 1) {
                    for (; jj < jj2; jj += RN - 1) {
                        gemm_bloc<RM, RN-1>(ii + bi, jj);
                    }
                }
                GGML_ASSERT(jj == jj2);
            }

            job = ggml_threadpool_chunk_add(params->threadpool, 1);
        }

        ggml_barrier(params->threadpool);
        return;
    }

    const ggml_compute_params * params;
    const TA *const A;
    const TB *const B;
    TC *const C;
    const int64_t k;
    const int64_t lda;
    const int64_t ldb;
    const int64_t ldc;
};
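
// A minimal usage sketch (illustrative only, not the dispatch logic itself):
// on an AVX2 build, an f32 kernel instantiation could look like the
// following, with llamafile_sgemm (declared in sgemm.h) doing the real
// type/ISA dispatch:
//
//     tinyBLAS<8, __m256, __m256, float, float, float> tb{
//         params, k, A, lda, B, ldb, C, ldc};
//     if (!tb.matmul(m, n))
//         /* fall back to ggml's generic matmul */;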

#if defined(__riscv_v_intrinsic)
template <typename D, typename V, typename TA, typename TB, typename TC>
class tinyBLAS_RVV {
  public:
    tinyBLAS_RVV(const ggml_compute_params * params, int64_t k,
             const TA *A, int64_t lda,
             const TB *B, int64_t ldb,
             TC *C, int64_t ldc)
        : params(params), A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc) {
    }

    bool matmul(int64_t m, int64_t n) {
        if (k % vlmax<V>() != 0) {
            return false;
        }

#if LMUL == 1
        if (m % 16 == 0 && (m/16 >= params->nth)) {
            const int64_t SIZE_N = BLOCK_SIZE<6>(n);
            mnpack<4, 6, 4>(m, n, SIZE_N, 12);
            return true;
        }
        if (m % 8 == 0) {
            const int64_t SIZE_N = BLOCK_SIZE<6>(n);
            mnpack<4, 6, 2>(m, n, SIZE_N, 12);
            return true;
        }
        if (m % 4 == 0) {
            const int64_t SIZE_N = BLOCK_SIZE<6>(n);
            mnpack<4, 6, 1>(m, n, SIZE_N, 12);
            return true;
        }
#elif LMUL == 2
        if (m % 16 == 0 && (m/16 >= params->nth)) {
            const int64_t SIZE_N = BLOCK_SIZE<3>(n);
            mnpack<4, 3, 4>(m, n, SIZE_N, 24);
            return true;
        }
        if (m % 8 == 0) {
            const int64_t SIZE_N = BLOCK_SIZE<3>(n);
            mnpack<4, 3, 2>(m, n, SIZE_N, 24);
            return true;
        }
        if (m % 4 == 0) {
            const int64_t SIZE_N = BLOCK_SIZE<3>(n);
            mnpack<4, 3, 1>(m, n, SIZE_N, 24);
            return true;
        }
#else // LMUL == 4
        if (m % 16 == 0 && (m/16 >= params->nth)) {
            const int64_t SIZE_N = BLOCK_SIZE<2>(n);
            mnpack<2, 2, 8>(m, n, SIZE_N, 36);
            return true;
        }
        if (m % 8 == 0) {
            const int64_t SIZE_N = BLOCK_SIZE<2>(n);
            mnpack<2, 2, 4>(m, n, SIZE_N, 36);
            return true;
        }
        if (m % 4 == 0) {
            const int64_t SIZE_N = BLOCK_SIZE<2>(n);
            mnpack<2, 2, 2>(m, n, SIZE_N, 36);
            return true;
        }
#endif
        return false;
    }

  private:
    template<int RM, int RN, int BM>
    inline void mnpack(int64_t m, int64_t n, int64_t SIZE_N, int64_t BN) {
        if (SIZE_N == RN) {
            return gemm<RM, RN, BM>(m, n, BN);
        }
        if constexpr (RN > 1) {
            return mnpack<RM, RN-1, BM>(m, n, SIZE_N, BN);
        } else {
            GGML_LOG_ERROR("mnpack<%d, %d> bloc size not supported\n", RM, (int)SIZE_N);
            GGML_ASSERT(false); // we have missed something.
        }
    }

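    // Unlike the generic tinyBLAS above, these kernels are unrolled by hand:
    // RVV intrinsic vector types are sizeless, so a Cv[RN][RM] array of
    // accumulators is not expressible, and each tile shape gets its own
    // function instead.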
    inline void gemm_bloc_4x6(int64_t ii, int64_t jj) {
        size_t vl = vlmax<V>();
        D Cv00 = set_zero<D>();
        D Cv01 = set_zero<D>();
        D Cv02 = set_zero<D>();
        D Cv03 = set_zero<D>();
        D Cv10 = set_zero<D>();
        D Cv11 = set_zero<D>();
        D Cv12 = set_zero<D>();
        D Cv13 = set_zero<D>();
        D Cv20 = set_zero<D>();
        D Cv21 = set_zero<D>();
        D Cv22 = set_zero<D>();
        D Cv23 = set_zero<D>();
        D Cv30 = set_zero<D>();
        D Cv31 = set_zero<D>();
        D Cv32 = set_zero<D>();
        D Cv33 = set_zero<D>();
        D Cv40 = set_zero<D>();
        D Cv41 = set_zero<D>();
        D Cv42 = set_zero<D>();
        D Cv43 = set_zero<D>();
        D Cv50 = set_zero<D>();
        D Cv51 = set_zero<D>();
        D Cv52 = set_zero<D>();
        D Cv53 = set_zero<D>();

        for (int64_t l = 0; l < k; l += vl) {
            V Bv0 = load<V>(B + ldb * (jj + 0) + l);
            V Bv1 = load<V>(B + ldb * (jj + 1) + l);
            V Bv2 = load<V>(B + ldb * (jj + 2) + l);
            V Bv3 = load<V>(B + ldb * (jj + 3) + l);
            V Bv4 = load<V>(B + ldb * (jj + 4) + l);
            V Bv5 = load<V>(B + ldb * (jj + 5) + l);

            V Av0 = load<V>(A + lda * (ii + 0) + l);
            Cv00 = madd(Av0, Bv0, Cv00);
            Cv10 = madd(Av0, Bv1, Cv10);
            Cv20 = madd(Av0, Bv2, Cv20);
            Cv30 = madd(Av0, Bv3, Cv30);
            Cv40 = madd(Av0, Bv4, Cv40);
            Cv50 = madd(Av0, Bv5, Cv50);

            V Av1 = load<V>(A + lda * (ii + 1) + l);
            Cv01 = madd(Av1, Bv0, Cv01);
            Cv11 = madd(Av1, Bv1, Cv11);
            Cv21 = madd(Av1, Bv2, Cv21);
            Cv31 = madd(Av1, Bv3, Cv31);
            Cv41 = madd(Av1, Bv4, Cv41);
            Cv51 = madd(Av1, Bv5, Cv51);

            V Av2 = load<V>(A + lda * (ii + 2) + l);
            Cv02 = madd(Av2, Bv0, Cv02);
            Cv12 = madd(Av2, Bv1, Cv12);
            Cv22 = madd(Av2, Bv2, Cv22);
            Cv32 = madd(Av2, Bv3, Cv32);
            Cv42 = madd(Av2, Bv4, Cv42);
            Cv52 = madd(Av2, Bv5, Cv52);

            V Av3 = load<V>(A + lda * (ii + 3) + l);
            Cv03 = madd(Av3, Bv0, Cv03);
            Cv13 = madd(Av3, Bv1, Cv13);
            Cv23 = madd(Av3, Bv2, Cv23);
            Cv33 = madd(Av3, Bv3, Cv33);
            Cv43 = madd(Av3, Bv4, Cv43);
            Cv53 = madd(Av3, Bv5, Cv53);
        }

        C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
        C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
        C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02);
        C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03);
        C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10);
        C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11);
        C[ldc * (jj + 1) + (ii + 2)] = hsum(Cv12);
        C[ldc * (jj + 1) + (ii + 3)] = hsum(Cv13);
        C[ldc * (jj + 2) + (ii + 0)] = hsum(Cv20);
        C[ldc * (jj + 2) + (ii + 1)] = hsum(Cv21);
        C[ldc * (jj + 2) + (ii + 2)] = hsum(Cv22);
        C[ldc * (jj + 2) + (ii + 3)] = hsum(Cv23);
        C[ldc * (jj + 3) + (ii + 0)] = hsum(Cv30);
        C[ldc * (jj + 3) + (ii + 1)] = hsum(Cv31);
        C[ldc * (jj + 3) + (ii + 2)] = hsum(Cv32);
        C[ldc * (jj + 3) + (ii + 3)] = hsum(Cv33);
        C[ldc * (jj + 4) + (ii + 0)] = hsum(Cv40);
        C[ldc * (jj + 4) + (ii + 1)] = hsum(Cv41);
        C[ldc * (jj + 4) + (ii + 2)] = hsum(Cv42);
        C[ldc * (jj + 4) + (ii + 3)] = hsum(Cv43);
        C[ldc * (jj + 5) + (ii + 0)] = hsum(Cv50);
        C[ldc * (jj + 5) + (ii + 1)] = hsum(Cv51);
        C[ldc * (jj + 5) + (ii + 2)] = hsum(Cv52);
        C[ldc * (jj + 5) + (ii + 3)] = hsum(Cv53);
    }

    inline void gemm_bloc_4x5(int64_t ii, int64_t jj) {
        size_t vl = vlmax<V>();
        D Cv00 = set_zero<D>();
        D Cv01 = set_zero<D>();
        D Cv02 = set_zero<D>();
        D Cv03 = set_zero<D>();
        D Cv10 = set_zero<D>();
        D Cv11 = set_zero<D>();
        D Cv12 = set_zero<D>();
        D Cv13 = set_zero<D>();
        D Cv20 = set_zero<D>();
        D Cv21 = set_zero<D>();
        D Cv22 = set_zero<D>();
        D Cv23 = set_zero<D>();
        D Cv30 = set_zero<D>();
        D Cv31 = set_zero<D>();
        D Cv32 = set_zero<D>();
        D Cv33 = set_zero<D>();
        D Cv40 = set_zero<D>();
        D Cv41 = set_zero<D>();
        D Cv42 = set_zero<D>();
        D Cv43 = set_zero<D>();

        for (int64_t l = 0; l < k; l += vl) {
            V Bv0 = load<V>(B + ldb * (jj + 0) + l);
            V Bv1 = load<V>(B + ldb * (jj + 1) + l);
            V Bv2 = load<V>(B + ldb * (jj + 2) + l);
            V Bv3 = load<V>(B + ldb * (jj + 3) + l);
            V Bv4 = load<V>(B + ldb * (jj + 4) + l);

            V Av0 = load<V>(A + lda * (ii + 0) + l);
            Cv00 = madd(Av0, Bv0, Cv00);
            Cv10 = madd(Av0, Bv1, Cv10);
            Cv20 = madd(Av0, Bv2, Cv20);
            Cv30 = madd(Av0, Bv3, Cv30);
            Cv40 = madd(Av0, Bv4, Cv40);

            V Av1 = load<V>(A + lda * (ii + 1) + l);
            Cv01 = madd(Av1, Bv0, Cv01);
            Cv11 = madd(Av1, Bv1, Cv11);
            Cv21 = madd(Av1, Bv2, Cv21);
            Cv31 = madd(Av1, Bv3, Cv31);
            Cv41 = madd(Av1, Bv4, Cv41);

            V Av2 = load<V>(A + lda * (ii + 2) + l);
            Cv02 = madd(Av2, Bv0, Cv02);
            Cv12 = madd(Av2, Bv1, Cv12);
            Cv22 = madd(Av2, Bv2, Cv22);
            Cv32 = madd(Av2, Bv3, Cv32);
            Cv42 = madd(Av2, Bv4, Cv42);

            V Av3 = load<V>(A + lda * (ii + 3) + l);
            Cv03 = madd(Av3, Bv0, Cv03);
            Cv13 = madd(Av3, Bv1, Cv13);
            Cv23 = madd(Av3, Bv2, Cv23);
            Cv33 = madd(Av3, Bv3, Cv33);
            Cv43 = madd(Av3, Bv4, Cv43);
        }

        C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
        C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
        C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02);
        C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03);
        C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10);
        C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11);
        C[ldc * (jj + 1) + (ii + 2)] = hsum(Cv12);
        C[ldc * (jj + 1) + (ii + 3)] = hsum(Cv13);
        C[ldc * (jj + 2) + (ii + 0)] = hsum(Cv20);
        C[ldc * (jj + 2) + (ii + 1)] = hsum(Cv21);
        C[ldc * (jj + 2) + (ii + 2)] = hsum(Cv22);
        C[ldc * (jj + 2) + (ii + 3)] = hsum(Cv23);
        C[ldc * (jj + 3) + (ii + 0)] = hsum(Cv30);
        C[ldc * (jj + 3) + (ii + 1)] = hsum(Cv31);
        C[ldc * (jj + 3) + (ii + 2)] = hsum(Cv32);
        C[ldc * (jj + 3) + (ii + 3)] = hsum(Cv33);
        C[ldc * (jj + 4) + (ii + 0)] = hsum(Cv40);
        C[ldc * (jj + 4) + (ii + 1)] = hsum(Cv41);
        C[ldc * (jj + 4) + (ii + 2)] = hsum(Cv42);
        C[ldc * (jj + 4) + (ii + 3)] = hsum(Cv43);
    }

    inline void gemm_bloc_4x4(int64_t ii, int64_t jj) {
        size_t vl = vlmax<V>();
        D Cv00 = set_zero<D>();
        D Cv01 = set_zero<D>();
        D Cv02 = set_zero<D>();
        D Cv03 = set_zero<D>();
        D Cv10 = set_zero<D>();
        D Cv11 = set_zero<D>();
        D Cv12 = set_zero<D>();
        D Cv13 = set_zero<D>();
        D Cv20 = set_zero<D>();
        D Cv21 = set_zero<D>();
        D Cv22 = set_zero<D>();
        D Cv23 = set_zero<D>();
        D Cv30 = set_zero<D>();
        D Cv31 = set_zero<D>();
        D Cv32 = set_zero<D>();
        D Cv33 = set_zero<D>();

        for (int64_t l = 0; l < k; l += vl) {
            V Av0 = load<V>(A + lda * (ii + 0) + l);
            V Av1 = load<V>(A + lda * (ii + 1) + l);
            V Av2 = load<V>(A + lda * (ii + 2) + l);
            V Av3 = load<V>(A + lda * (ii + 3) + l);

            V Bv0 = load<V>(B + ldb * (jj + 0) + l);
            Cv00 = madd(Av0, Bv0, Cv00);
            Cv01 = madd(Av1, Bv0, Cv01);
            Cv02 = madd(Av2, Bv0, Cv02);
            Cv03 = madd(Av3, Bv0, Cv03);

            V Bv1 = load<V>(B + ldb * (jj + 1) + l);
            Cv10 = madd(Av0, Bv1, Cv10);
            Cv11 = madd(Av1, Bv1, Cv11);
            Cv12 = madd(Av2, Bv1, Cv12);
            Cv13 = madd(Av3, Bv1, Cv13);

            V Bv2 = load<V>(B + ldb * (jj + 2) + l);
            Cv20 = madd(Av0, Bv2, Cv20);
            Cv21 = madd(Av1, Bv2, Cv21);
            Cv22 = madd(Av2, Bv2, Cv22);
            Cv23 = madd(Av3, Bv2, Cv23);

            V Bv3 = load<V>(B + ldb * (jj + 3) + l);
            Cv30 = madd(Av0, Bv3, Cv30);
            Cv31 = madd(Av1, Bv3, Cv31);
            Cv32 = madd(Av2, Bv3, Cv32);
            Cv33 = madd(Av3, Bv3, Cv33);
        }

        C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
        C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
        C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02);
        C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03);
        C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10);
        C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11);
        C[ldc * (jj + 1) + (ii + 2)] = hsum(Cv12);
        C[ldc * (jj + 1) + (ii + 3)] = hsum(Cv13);
        C[ldc * (jj + 2) + (ii + 0)] = hsum(Cv20);
        C[ldc * (jj + 2) + (ii + 1)] = hsum(Cv21);
        C[ldc * (jj + 2) + (ii + 2)] = hsum(Cv22);
        C[ldc * (jj + 2) + (ii + 3)] = hsum(Cv23);
        C[ldc * (jj + 3) + (ii + 0)] = hsum(Cv30);
        C[ldc * (jj + 3) + (ii + 1)] = hsum(Cv31);
        C[ldc * (jj + 3) + (ii + 2)] = hsum(Cv32);
        C[ldc * (jj + 3) + (ii + 3)] = hsum(Cv33);
    }

    inline void gemm_bloc_4x3(int64_t ii, int64_t jj) {
        size_t vl = vlmax<V>();
        D Cv00 = set_zero<D>();
        D Cv01 = set_zero<D>();
        D Cv02 = set_zero<D>();
        D Cv03 = set_zero<D>();
        D Cv10 = set_zero<D>();
        D Cv11 = set_zero<D>();
        D Cv12 = set_zero<D>();
        D Cv13 = set_zero<D>();
        D Cv20 = set_zero<D>();
        D Cv21 = set_zero<D>();
        D Cv22 = set_zero<D>();
        D Cv23 = set_zero<D>();

        for (int64_t l = 0; l < k; l += vl) {
            V Av0 = load<V>(A + lda * (ii + 0) + l);
            V Av1 = load<V>(A + lda * (ii + 1) + l);
            V Av2 = load<V>(A + lda * (ii + 2) + l);
            V Av3 = load<V>(A + lda * (ii + 3) + l);

            V Bv0 = load<V>(B + ldb * (jj + 0) + l);
            Cv00 = madd(Av0, Bv0, Cv00);
            Cv01 = madd(Av1, Bv0, Cv01);
            Cv02 = madd(Av2, Bv0, Cv02);
            Cv03 = madd(Av3, Bv0, Cv03);

            V Bv1 = load<V>(B + ldb * (jj + 1) + l);
            Cv10 = madd(Av0, Bv1, Cv10);
            Cv11 = madd(Av1, Bv1, Cv11);
            Cv12 = madd(Av2, Bv1, Cv12);
            Cv13 = madd(Av3, Bv1, Cv13);

            V Bv2 = load<V>(B + ldb * (jj + 2) + l);
            Cv20 = madd(Av0, Bv2, Cv20);
            Cv21 = madd(Av1, Bv2, Cv21);
            Cv22 = madd(Av2, Bv2, Cv22);
            Cv23 = madd(Av3, Bv2, Cv23);
        }

        C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
        C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
        C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02);
        C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03);
        C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10);
        C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11);
        C[ldc * (jj + 1) + (ii + 2)] = hsum(Cv12);
        C[ldc * (jj + 1) + (ii + 3)] = hsum(Cv13);
        C[ldc * (jj + 2) + (ii + 0)] = hsum(Cv20);
        C[ldc * (jj + 2) + (ii + 1)] = hsum(Cv21);
        C[ldc * (jj + 2) + (ii + 2)] = hsum(Cv22);
        C[ldc * (jj + 2) + (ii + 3)] = hsum(Cv23);
    }

    inline void gemm_bloc_4x2(int64_t ii, int64_t jj) {
        size_t vl = vlmax<V>();
        D Cv00 = set_zero<D>();
        D Cv01 = set_zero<D>();
        D Cv02 = set_zero<D>();
        D Cv03 = set_zero<D>();
        D Cv10 = set_zero<D>();
        D Cv11 = set_zero<D>();
        D Cv12 = set_zero<D>();
        D Cv13 = set_zero<D>();

        for (int64_t l = 0; l < k; l += vl) {
            V Av0 = load<V>(A + lda * (ii + 0) + l);
            V Av1 = load<V>(A + lda * (ii + 1) + l);
            V Av2 = load<V>(A + lda * (ii + 2) + l);
            V Av3 = load<V>(A + lda * (ii + 3) + l);

            V Bv0 = load<V>(B + ldb * (jj + 0) + l);
            Cv00 = madd(Av0, Bv0, Cv00);
            Cv01 = madd(Av1, Bv0, Cv01);
            Cv02 = madd(Av2, Bv0, Cv02);
            Cv03 = madd(Av3, Bv0, Cv03);

            V Bv1 = load<V>(B + ldb * (jj + 1) + l);
            Cv10 = madd(Av0, Bv1, Cv10);
            Cv11 = madd(Av1, Bv1, Cv11);
            Cv12 = madd(Av2, Bv1, Cv12);
            Cv13 = madd(Av3, Bv1, Cv13);
        }

        C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
        C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
        C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02);
        C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03);
        C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10);
        C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11);
        C[ldc * (jj + 1) + (ii + 2)] = hsum(Cv12);
        C[ldc * (jj + 1) + (ii + 3)] = hsum(Cv13);
    }

    inline void gemm_bloc_4x1(int64_t ii, int64_t jj) {
        size_t vl = vlmax<V>();
        D Cv00 = set_zero<D>();
        D Cv01 = set_zero<D>();
        D Cv02 = set_zero<D>();
        D Cv03 = set_zero<D>();

        for (int64_t l = 0; l < k; l += vl) {
            V Av0 = load<V>(A + lda * (ii + 0) + l);
            V Av1 = load<V>(A + lda * (ii + 1) + l);
            V Av2 = load<V>(A + lda * (ii + 2) + l);
            V Av3 = load<V>(A + lda * (ii + 3) + l);

            V Bv0 = load<V>(B + ldb * (jj + 0) + l);
            Cv00 = madd(Av0, Bv0, Cv00);
            Cv01 = madd(Av1, Bv0, Cv01);
            Cv02 = madd(Av2, Bv0, Cv02);
            Cv03 = madd(Av3, Bv0, Cv03);
        }

        C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
        C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
        C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02);
        C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03);
    }

    inline void gemm_bloc_2x2(int64_t ii, int64_t jj) {
        size_t vl = vlmax<V>();
        D Cv00 = set_zero<D>();
        D Cv01 = set_zero<D>();
        D Cv10 = set_zero<D>();
        D Cv11 = set_zero<D>();

        for (int64_t l = 0; l < k; l += vl) {
            V Av0 = load<V>(A + lda * (ii + 0) + l);
            V Av1 = load<V>(A + lda * (ii + 1) + l);

            V Bv0 = load<V>(B + ldb * (jj + 0) + l);
            Cv00 = madd(Av0, Bv0, Cv00);
            Cv01 = madd(Av1, Bv0, Cv01);

            V Bv1 = load<V>(B + ldb * (jj + 1) + l);
            Cv10 = madd(Av0, Bv1, Cv10);
            Cv11 = madd(Av1, Bv1, Cv11);
        }

        C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
        C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
        C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10);
        C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11);
    }

    inline void gemm_bloc_2x1(int64_t ii, int64_t jj) {
        size_t vl = vlmax<V>();
        D Cv00 = set_zero<D>();
        D Cv01 = set_zero<D>();

        for (int64_t l = 0; l < k; l += vl) {
            V Av0 = load<V>(A + lda * (ii + 0) + l);
            V Av1 = load<V>(A + lda * (ii + 1) + l);

            V Bv0 = load<V>(B + ldb * (jj + 0) + l);
            Cv00 = madd(Av0, Bv0, Cv00);
            Cv01 = madd(Av1, Bv0, Cv01);
        }

        C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
        C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
    }

    template <int RM, int RN>
    inline void gemm_bloc(int64_t ii, int64_t jj) {
        if constexpr (RM == 4) {
            if constexpr (RN == 6) { return gemm_bloc_4x6(ii, jj); }
            if constexpr (RN == 5) { return gemm_bloc_4x5(ii, jj); }
            if constexpr (RN == 4) { return gemm_bloc_4x4(ii, jj); }
            if constexpr (RN == 3) { return gemm_bloc_4x3(ii, jj); }
            if constexpr (RN == 2) { return gemm_bloc_4x2(ii, jj); }
            if constexpr (RN == 1) { return gemm_bloc_4x1(ii, jj); }
        } else if constexpr (RM == 2) {
            if constexpr (RN == 2) { return gemm_bloc_2x2(ii, jj); }
            if constexpr (RN == 1) { return gemm_bloc_2x1(ii, jj); }
        }
    }

    template <int RM, int RN, int BM>
    NOINLINE void gemm(int64_t m, int64_t n, int64_t BN) {
        GGML_ASSERT(m % (RM * BM) == 0);
        const int64_t ytiles = m / (RM * BM);
        const int64_t xtiles = (n + RN - 1) / RN;
        const int64_t jj_RN = (xtiles - (xtiles * RN - n));

        // "round" bloc_size to "nearest" BN
        const int64_t NB_BN = xtiles < BN ? 1 : (xtiles + BN / 2) / BN;
        const int64_t SIZE_BN = xtiles % NB_BN == 0 ? xtiles / NB_BN : xtiles / NB_BN + 1;
        const int64_t jj_BN = (NB_BN - (NB_BN * SIZE_BN - xtiles));
        const int64_t nb_job = ytiles * NB_BN;

        if (params->ith == 0) {
            GGML_ASSERT(jj_BN * SIZE_BN + (NB_BN - jj_BN) * (SIZE_BN - 1) == xtiles);
            // Every thread starts at ith, so the first unprocessed chunk is nth. This saves a bit of coordination right at the start.
            ggml_threadpool_chunk_set(params->threadpool, params->nth);
        }

        ggml_barrier(params->threadpool);

        int64_t job = params->ith;
        while (job < nb_job) {
            const int64_t ii = (job % ytiles) * RM * BM;
            const int64_t jb =  job / ytiles;
            const int64_t jr0 = BLOC_POS(jb  , jj_BN, SIZE_BN);
            const int64_t jrN = BLOC_POS(jb+1, jj_BN, SIZE_BN);

            const int64_t jj0 = BLOC_POS(jr0, jj_RN, RN);
            const int64_t jj2 = BLOC_POS(jrN, jj_RN, RN);
            const int64_t jj1 = jj2 < jj_RN * RN ? jj2 : jj_RN * RN;

            for (int64_t bi = 0; bi < BM * RM; bi += RM) {
                int64_t jj = jj0;
                for (; jj < jj1; jj += RN) {
                    gemm_bloc<RM, RN>(ii + bi, jj);
                }
                if constexpr (RN > 1) {
                    for (; jj < jj2; jj += RN - 1) {
                        gemm_bloc<RM, RN-1>(ii + bi, jj);
                    }
                }
                GGML_ASSERT(jj == jj2);
            }

            job = ggml_threadpool_chunk_add(params->threadpool, 1);
        }

        ggml_barrier(params->threadpool);
        return;
    }

    const ggml_compute_params * params;
    const TA *const A;
    const TB *const B;
    TC *const C;
    const int64_t k;
    const int64_t lda;
    const int64_t ldb;
    const int64_t ldc;
};
#endif

//////////////////////////////////////////////////////////////////////////////////////////
// QUANT ZERO MATRIX MULTIPLICATION

#if defined(__ARM_FEATURE_DOTPROD)
template <typename TA>
class tinyBLAS_Q0_ARM {
  public:
    tinyBLAS_Q0_ARM(int64_t k,
                    const TA *A, int64_t lda,
                    const block_q8_0 *B, int64_t ldb,
                    float *C, int64_t ldc,
                    int ith, int nth)
        : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
    }

    void matmul(int64_t m, int64_t n) {
        mnpack(0, m, 0, n);
    }

  private:
    NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
        int64_t mc, nc, mp, np;
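        // The remaining row and column counts (each clamped to 3) are packed
        // into the high and low nibble of one byte, so e.g. 3 rows and 2
        // columns left dispatches to case 0x32.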
        switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 3ll)) {
        case 0x33:
            mc = 3;
            nc = 3;
            gemm<3, 3>(m0, m, n0, n);
            break;
        case 0x32:
            mc = 3;
            nc = 2;
            gemm<3, 2>(m0, m, n0, n);
            break;
        case 0x23:
            mc = 2;
            nc = 3;
            gemm<2, 3>(m0, m, n0, n);
            break;
        case 0x22:
            mc = 2;
            nc = 2;
            gemm<2, 2>(m0, m, n0, n);
            break;
        case 0x31:
            mc = 3;
            nc = 1;
            gemm<3, 1>(m0, m, n0, n);
            break;
        case 0x13:
            mc = 1;
            nc = 3;
            gemm<1, 3>(m0, m, n0, n);
            break;
        case 0x21:
            mc = 2;
            nc = 1;
            gemm<2, 1>(m0, m, n0, n);
            break;
        case 0x12:
            mc = 1;
            nc = 2;
            gemm<1, 2>(m0, m, n0, n);
            break;
        case 0x11:
            mc = 1;
            nc = 1;
            gemm<1, 1>(m0, m, n0, n);
            break;
        default:
            return;
        }
        mp = m0 + (m - m0) / mc * mc;
        np = n0 + (n - n0) / nc * nc;
        mnpack(mp, m, n0, np);
        mnpack(m0, m, np, n);
    }

    template <int RM, int RN>
    NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
        int64_t ytiles = (m - m0) / RM;
        int64_t xtiles = (n - n0) / RN;
        int64_t tiles = xtiles * ytiles;
        int64_t duty = (tiles + nth - 1) / nth;
        int64_t start = duty * ith;
        int64_t end = start + duty;
        if (end > tiles)
            end = tiles;
        for (int64_t job = start; job < end; ++job) {
            int64_t ii = m0 + job / xtiles * RM;
            int64_t jj = n0 + job % xtiles * RN;
            float32x4_t Cv[RN][RM] = {};
            for (int64_t l = 0; l < k; ++l)
                for (int64_t j = 0; j < RN; ++j)
                    for (int64_t i = 0; i < RM; ++i)
                        Cv[j][i] = vmlaq_n_f32(Cv[j][i],
                                               vcvtq_f32_s32(vdotq_s32(
                                                   vdotq_s32(vdupq_n_s32(0),
                                                             load_lo(A + lda * (ii + i) + l),
                                                             load_lo(B + ldb * (jj + j) + l)),
                                                   load_hi(A + lda * (ii + i) + l),
                                                   load_hi(B + ldb * (jj + j) + l))),
                                               unhalf(A[lda * (ii + i) + l].d) *
                                               unhalf(B[ldb * (jj + j) + l].d));
            for (int64_t j = 0; j < RN; ++j)
                for (int64_t i = 0; i < RM; ++i)
                    C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
        }
    }

    inline int8x16_t load_lo(const block_q8_0 *b) {
        return vld1q_s8(b->qs);
    }

    inline int8x16_t load_hi(const block_q8_0 *b) {
        return vld1q_s8(b->qs + 16);
    }

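    // q4_0 packs 32 4-bit quants into 16 bytes: the low nibbles hold the
    // first 16 values and the high nibbles the last 16, each stored with a
    // bias of 8 that is subtracted here to recover signed values.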
    inline int8x16_t load_lo(const block_q4_0 *b) {
        return vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vld1q_u8(b->qs),
                                                     vdupq_n_u8(0x0f))),
                        vdupq_n_s8(0x8));
    }

    inline int8x16_t load_hi(const block_q4_0 *b) {
        return vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(vld1q_u8(b->qs), 4)),
                        vdupq_n_s8(0x8));
    }

    const TA *const A;
    const block_q8_0 *const B;
    float *const C;
    const int64_t k;
    const int64_t lda;
    const int64_t ldb;
    const int64_t ldc;
    const int ith;
    const int nth;
};
#endif // __ARM_FEATURE_DOTPROD

#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
template <typename TA, typename TB, typename TC>
class tinyBLAS_Q0_AVX {
  public:
    tinyBLAS_Q0_AVX(int64_t k,
                    const TA *A, int64_t lda,
                    const TB *B, int64_t ldb,
                    TC *C, int64_t ldc,
                    int ith, int nth)
        : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
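        // 16-entry non-linear codebook used by IQ4_NL, kept in an SSE
        // register so quantized nibble indices can be dequantized with a
        // byte-shuffle lookup.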
1353        const int8_t kvalues_iq4nl[16] = {
1354            -127, -104, -83, -65,
1355            -49,  -35,  -22, -10,
1356              1,   13,   25,  38,
1357             53,   69,   89, 113
1358        };
1359
1360        iq4nlt = _mm_loadu_si128((const __m128i *)kvalues_iq4nl);
1361    }
1362
1363    void matmul(int64_t m, int64_t n) {
1364        mnpack(0, m, 0, n);
1365    }
1366
1367  private:
    void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
        int64_t mc, nc, mp, np;
        switch ((MIN(m - m0, 4) << 4) | MIN(n - n0, 4)) {
#if VECTOR_REGISTERS == 32
        case 0x44:
            mc = 4;
            nc = 4;
#if defined(__AVX2__) && defined(__F16C__)
            gemm4xN<4>(m0, m, n0, n);
#else
            gemm<4, 4>(m0, m, n0, n);
#endif
            break;
        case 0x43:
            mc = 4;
            nc = 3;
#if defined(__AVX2__) && defined(__F16C__)
            gemm4xN<3>(m0, m, n0, n);
#else
            gemm<4, 3>(m0, m, n0, n);
#endif
            break;
        case 0x34:
            mc = 3;
            nc = 4;
#if defined(__AVX2__) && defined(__F16C__)
            gemmMx4<3>(m0, m, n0, n);
#else
            gemm<3, 4>(m0, m, n0, n);
#endif
            break;
        case 0x33:
            mc = 3;
            nc = 3;
            gemm<3, 3>(m0, m, n0, n);
            break;
        case 0x42:
            mc = 4;
            nc = 2;
#if defined(__AVX2__) && defined(__F16C__)
            gemm4xN<2>(m0, m, n0, n);
#else
            gemm<4, 2>(m0, m, n0, n);
#endif
            break;
        case 0x24:
            mc = 2;
            nc = 4;
#if defined(__AVX2__) && defined(__F16C__)
            gemmMx4<2>(m0, m, n0, n);
#else
            gemm<2, 4>(m0, m, n0, n);
#endif
            break;
#else
        case 0x44:
        case 0x43:
        case 0x42:
            mc = 4;
            nc = 2;
#if defined(__AVX2__) && defined(__F16C__)
            gemm4xN<2>(m0, m, n0, n);
#else
            gemm<4, 2>(m0, m, n0, n);
#endif
            break;
        case 0x34:
        case 0x24:
            mc = 2;
            nc = 4;
#if defined(__AVX2__) && defined(__F16C__)
            gemmMx4<2>(m0, m, n0, n);
#else
            gemm<2, 4>(m0, m, n0, n);
#endif
            break;
        case 0x33:
#endif
        case 0x32:
            mc = 3;
            nc = 2;
            gemm<3, 2>(m0, m, n0, n);
            break;
        case 0x23:
            mc = 2;
            nc = 3;
            gemm<2, 3>(m0, m, n0, n);
            break;
        case 0x41:
            mc = 4;
            nc = 1;
#if defined(__AVX2__) && defined(__F16C__)
            gemm4xN<1>(m0, m, n0, n);
#else
            gemm<4, 1>(m0, m, n0, n);
#endif
            break;
        case 0x22:
            mc = 2;
            nc = 2;
            gemm<2, 2>(m0, m, n0, n);
            break;
        case 0x14:
            mc = 1;
            nc = 4;
#if defined(__AVX2__) && defined(__F16C__)
            gemmMx4<1>(m0, m, n0, n);
#else
            gemm<1, 4>(m0, m, n0, n);
#endif
            break;
        case 0x31:
            mc = 3;
            nc = 1;
            gemm<3, 1>(m0, m, n0, n);
            break;
        case 0x13:
            mc = 1;
            nc = 3;
            gemm<1, 3>(m0, m, n0, n);
            break;
        case 0x21:
            mc = 2;
            nc = 1;
            gemm<2, 1>(m0, m, n0, n);
            break;
        case 0x12:
            mc = 1;
            nc = 2;
            gemm<1, 2>(m0, m, n0, n);
            break;
        case 0x11:
            mc = 1;
            nc = 1;
            gemm<1, 1>(m0, m, n0, n);
            break;
        default:
            return;
        }
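        // Handle the fringes: mp/np mark how far the chosen tile size evenly
        // divides the region, and the two recursive calls sweep the remaining
        // bottom rows and right columns with smaller tiles.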
        mp = m0 + (m - m0) / mc * mc;
        np = n0 + (n - n0) / nc * nc;
        mnpack(mp, m, n0, np);
        mnpack(m0, m, np, n);
    }

#if defined(__AVX2__) && defined(__F16C__)
    // Templated kernels for GEMM tiles with M fixed at 4 and N given by RN.
    template <int RN>
    NOINLINE void gemm4xN(int64_t m0, int64_t m, int64_t n0, int64_t n) {
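        // Split the tile grid evenly across threads: each of the nth threads
        // claims a contiguous range of ceil(tiles / nth) jobs, indexed by its
        // thread id ith.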
        int64_t ytiles = (m - m0) / 4;
        int64_t xtiles = (n - n0) / RN;
        int64_t tiles = xtiles * ytiles;
        int64_t duty = (tiles + nth - 1) / nth;
        int64_t start = duty * ith;
        int64_t end = start + duty;
        if (end > tiles)
            end = tiles;
        for (int64_t job = start; job < end; ++job) {
            int64_t ii = m0 + job / xtiles * 4;
            int64_t jj = n0 + job % xtiles * RN;
            __m256 Cv[RN][4] = {};
            for (int64_t l = 0; l < k; ++l) {
                uint64_t a_delta = ((uint64_t)A[lda * (ii + 3) + l].d << 48) | ((uint64_t)A[lda * (ii + 2) + l].d << 32) | ((uint64_t)A[lda * (ii + 1) + l].d << 16) | (A[lda * (ii + 0) + l].d);
                // Convert the fp16 deltas (scales) of four A blocks to float.
                __m128 da = _mm_cvtph_ps(_mm_set_epi64x(0, a_delta));
                __m256i avec0 = load(A + lda * (ii + 0) + l);
                __m256i avec1 = load(A + lda * (ii + 1) + l);
                __m256i avec2 = load(A + lda * (ii + 2) + l);
                __m256i avec3 = load(A + lda * (ii + 3) + l);
                for (int64_t j = 0; j < RN; ++j) {
                    __m128 db = _mm_set1_ps(unhalf(B[ldb * (jj + j) + l].d));
                    // Multiply the four A deltas by B's delta and replicate the
                    // products across both 128-bit lanes.
                    __m256 dvec = _mm256_castps128_ps256(_mm_mul_ps(da, db));
                    dvec = _mm256_permute2f128_ps(dvec, dvec, 0);
                    // Compute each row's int8 dot product and scale it by the
                    // matching delta product; the shuffle immediates 0, 85, 170
                    // and 255 broadcast lanes 0-3 of dvec respectively.
                    Cv[j][0] = madd(_mm256_shuffle_ps(dvec, dvec, 0),
                                    updot(_mm256_sign_epi8(avec0, avec0),
                                          _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec0)),
                                    Cv[j][0]);
                    Cv[j][1] = madd(_mm256_shuffle_ps(dvec, dvec, 85),
                                    updot(_mm256_sign_epi8(avec1, avec1),
                                          _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec1)),
                                    Cv[j][1]);
                    Cv[j][2] = madd(_mm256_shuffle_ps(dvec, dvec, 170),
                                    updot(_mm256_sign_epi8(avec2, avec2),
                                          _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec2)),
                                    Cv[j][2]);
                    Cv[j][3] = madd(_mm256_shuffle_ps(dvec, dvec, 255),
                                    updot(_mm256_sign_epi8(avec3, avec3),
                                          _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec3)),
                                    Cv[j][3]);
                }
            }

            for (int64_t j = 0; j < RN; ++j)
                for (int64_t i = 0; i < 4; ++i)
                    C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
        }
    }

    // Templated kernels for GEMM tiles with N fixed at 4 and M given by RM.
    template <int RM>
    NOINLINE void gemmMx4(int64_t m0, int64_t m, int64_t n0, int64_t n) {
        int64_t ytiles = (m - m0) / RM;
        int64_t xtiles = (n - n0) / 4;
        int64_t tiles = xtiles * ytiles;
        int64_t duty = (tiles + nth - 1) / nth;
        int64_t start = duty * ith;
        int64_t end = start + duty;
        if (end > tiles)
            end = tiles;
        for (int64_t job = start; job < end; ++job) {
            int64_t ii = m0 + job / xtiles * RM;
            int64_t jj = n0 + job % xtiles * 4;
            __m256 Cv[4][RM] = {};
            for (int64_t l = 0; l < k; ++l) {
                uint64_t b_delta = ((uint64_t)B[ldb * (jj + 3) + l].d << 48) | ((uint64_t)B[ldb * (jj + 2) + l].d << 32) | ((uint64_t)B[ldb * (jj + 1) + l].d << 16) | (B[ldb * (jj + 0) + l].d);
                // Convert the fp16 deltas (scales) of four B blocks to float.
                __m128 db = _mm_cvtph_ps(_mm_set_epi64x(0, b_delta));
                __m256i bvec0 = load(B + ldb * (jj + 0) + l);
                __m256i bvec1 = load(B + ldb * (jj + 1) + l);
                __m256i bvec2 = load(B + ldb * (jj + 2) + l);
                __m256i bvec3 = load(B + ldb * (jj + 3) + l);
                for (int64_t i = 0; i < RM; ++i) {
                    __m128 da = _mm_set1_ps(unhalf(A[lda * (ii + i) + l].d));
                    // Multiply the four B deltas by A's delta and replicate the
                    // products across both 128-bit lanes.
                    __m256 dvec = _mm256_castps128_ps256(_mm_mul_ps(da, db));
                    dvec = _mm256_permute2f128_ps(dvec, dvec, 0);
                    // Compute each column's int8 dot product and scale it by the
                    // matching delta product (shuffle immediates broadcast lanes
                    // 0-3 of dvec).
                    Cv[0][i] = madd(_mm256_shuffle_ps(dvec, dvec, 0),
                                    updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
                                                           load(A + lda * (ii + i) + l)),
                                          _mm256_sign_epi8(bvec0, load(A + lda * (ii + i) + l))),
                                    Cv[0][i]);
                    Cv[1][i] = madd(_mm256_shuffle_ps(dvec, dvec, 85),
                                    updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
                                                           load(A + lda * (ii + i) + l)),
                                          _mm256_sign_epi8(bvec1, load(A + lda * (ii + i) + l))),
                                    Cv[1][i]);
                    Cv[2][i] = madd(_mm256_shuffle_ps(dvec, dvec, 170),
                                    updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
                                                           load(A + lda * (ii + i) + l)),
                                          _mm256_sign_epi8(bvec2, load(A + lda * (ii + i) + l))),
                                    Cv[2][i]);
                    Cv[3][i] = madd(_mm256_shuffle_ps(dvec, dvec, 255),
                                    updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
                                                           load(A + lda * (ii + i) + l)),
                                          _mm256_sign_epi8(bvec3, load(A + lda * (ii + i) + l))),
                                    Cv[3][i]);
                }
            }
            for (int64_t j = 0; j < 4; ++j)
                for (int64_t i = 0; i < RM; ++i)
                    C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
        }
    }
#endif

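    // Generic RM x RN kernel. With AVX2 the whole 32-byte quant block is
    // handled by a single updot; on plain AVX the block is processed as two
    // 128-bit halves with SSE integer ops and recombined at the end.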
    template <int RM, int RN>
    NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
        int64_t ytiles = (m - m0) / RM;
        int64_t xtiles = (n - n0) / RN;
        int64_t tiles = xtiles * ytiles;
        int64_t duty = (tiles + nth - 1) / nth;
        int64_t start = duty * ith;
        int64_t end = start + duty;
        if (end > tiles)
            end = tiles;
        for (int64_t job = start; job < end; ++job) {
            int64_t ii = m0 + job / xtiles * RM;
            int64_t jj = n0 + job % xtiles * RN;
            __m256 Cv[RN][RM] = {};
            for (int64_t l = 0; l < k; ++l)
                for (int64_t j = 0; j < RN; ++j)
                    for (int64_t i = 0; i < RM; ++i) {
#if defined(__AVX2__)
                        __m256 udTmp = updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
                                                              load(A + lda * (ii + i) + l)),
                                             _mm256_sign_epi8(load(B + ldb * (jj + j) + l),
                                                              load(A + lda * (ii + i) + l)));
#else
                        __m128i ali0 = load0(A + lda * (ii + i) + l);
                        __m128i ali1 = load1(A + lda * (ii + i) + l);
                        __m128i blj0 = load0(B + ldb * (jj + j) + l);
                        __m128i blj1 = load1(B + ldb * (jj + j) + l);

                        __m128i sepAA0 = _mm_sign_epi8(ali0, ali0);
                        __m128i sepAA1 = _mm_sign_epi8(ali1, ali1);
                        __m128i sepBA0 = _mm_sign_epi8(blj0, ali0);
                        __m128i sepBA1 = _mm_sign_epi8(blj1, ali1);

                        // updot
                        const __m128i oneFill = _mm_set1_epi16(1);
                        __m128i mad0 = _mm_maddubs_epi16(sepAA0, sepBA0);
                        __m128i mad1 = _mm_maddubs_epi16(sepAA1, sepBA1);
                        __m256 udTmp = _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_madd_epi16(oneFill, mad1), _mm_madd_epi16(oneFill, mad0)));
#endif
                        Cv[j][i] = madd(_mm256_set1_ps(unhalf(A[lda * (ii + i) + l].d) *
                                                       unhalf(B[ldb * (jj + j) + l].d)),
                                        udTmp,
                                        Cv[j][i]);
                    }
            for (int64_t j = 0; j < RN; ++j)
                for (int64_t i = 0; i < RM; ++i)
                    C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
        }
    }

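    // Loaders that expand one 32-element quant block into 32 signed bytes.
    // q8_0 already stores int8 quants; q4_0 stores two nibbles per byte with
    // a bias of 8, so the load subtracts 8 to recenter around zero. The
    // load0/load1 variants produce the two 128-bit halves for the AVX path.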
    inline __m256i load(const block_q8_0 *b) {
        return _mm256_loadu_si256((const __m256i *)b->qs);
    }

    inline __m128i load0(const block_q8_0 *b) {
        return _mm_loadu_si128((const __m128i *)b->qs);
    }

    inline __m128i load1(const block_q8_0 *b) {
        return _mm_loadu_si128(((const __m128i *)b->qs) + 1);
    }

    inline __m256i load(const block_q4_0 *b) {
        return _mm256_sub_epi8(denibble(b->qs), _mm256_set1_epi8(8));
    }

    inline __m128i load0(const block_q4_0 *b) {
        const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
        return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), x), _mm_set1_epi8(8));
    }

    inline __m128i load1(const block_q4_0 *b) {
        const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
        return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)), _mm_set1_epi8(8));
    }

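    // q5_0 keeps the fifth (high) bit of each quant in the separate qh
    // bitfield. A quant dequantizes to (q | bit << 4) - 16, and filling the
    // high nibble with 0xF0 when the bit is clear yields exactly that value
    // in two's complement, so a plain OR with the low nibbles finishes the
    // load.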
    inline __m256i load(const block_q5_0 *b) {
        return _mm256_or_si256(denibble(b->qs), bittobyte(b->qh));
    }

    inline __m128i load0(const block_q5_0* b) {
        const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
        uint32_t x32;
        memcpy(&x32, b->qh, sizeof(uint32_t));
        __m128i qxl = _mm_and_si128(_mm_set1_epi8(15), x);
        __m128i bytesl = _mm_cmpeq_epi8(_mm_set1_epi64x(-1),
                                        _mm_or_si128(_mm_set1_epi64x(0x7fbfdfeff7fbfdfe),
                                                     _mm_shuffle_epi8(_mm_set1_epi32(x32),
                                                                      _mm_set_epi64x(0x0101010101010101, 0x0000000000000000))));
        bytesl = _mm_andnot_si128(bytesl, _mm_set1_epi8((char)0xF0));
        return _mm_or_si128(qxl, bytesl);
    }

    inline __m128i load1(const block_q5_0* b) {
        const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
        uint32_t x32;
        memcpy(&x32, b->qh, sizeof(uint32_t));
        __m128i qxh = _mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4));
        __m128i bytesh = _mm_cmpeq_epi8(_mm_set1_epi64x(-1),
                                        _mm_or_si128(_mm_set1_epi64x(0x7fbfdfeff7fbfdfe),
                                                     _mm_shuffle_epi8(_mm_set1_epi32(x32),
                                                                      _mm_set_epi64x(0x0303030303030303, 0x0202020202020202))));
        bytesh = _mm_andnot_si128(bytesh, _mm_set1_epi8((char)0xF0));
        return _mm_or_si128(qxh, bytesh);
    }

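    // iq4_nl uses a non-linear 16-entry codebook: each 4-bit index is mapped
    // through the kvalues_iq4nl table (preloaded into iq4nlt by the
    // constructor) with a single pshufb lookup.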
    inline __m256i load(const block_iq4_nl *b) {
        return MM256_SET_M128I(load1(b), load0(b));
    }

    inline __m128i load0(const block_iq4_nl *b) {
        const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
        return _mm_shuffle_epi8(iq4nlt, _mm_and_si128(_mm_set1_epi8(15), x));
    }

    inline __m128i load1(const block_iq4_nl *b) {
        const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
        return _mm_shuffle_epi8(iq4nlt, _mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)));
    }

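    // Unsigned-times-signed dot product. maddubs requires its first operand
    // to be unsigned, so callers pass u = |a| and s = b with a's signs folded
    // in via sign_epi8, which preserves the products: |a| * (b * sign(a)) ==
    // a * b. With VNNI the multiply-accumulate is a single dpbusd instead.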
    inline __m256 updot(__m256i u, __m256i s) {
        __m256i res;
#if defined(__AVX512VNNI__) && defined(__AVX512VL__)
        res = _mm256_dpbusd_epi32(_mm256_setzero_si256(), u, s);
#elif defined(__AVXVNNI__)
        res = _mm256_dpbusd_avx_epi32(_mm256_setzero_si256(), u, s);
#else
        res = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(u, s));
#endif
        return _mm256_cvtepi32_ps(res);
    }

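    // Unpack 32 nibbles into 32 bytes: the low nibbles of the 16 input bytes
    // land in the low 128-bit lane and the high nibbles in the high lane,
    // matching q4_0's quant ordering.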
    static inline __m256i denibble(const uint8_t *p) {
        __m128i x = _mm_loadu_si128((const __m128i *)p);
        return _mm256_and_si256(_mm256_set1_epi8(15),
                                _mm256_insertf128_si256(_mm256_castsi128_si256(x),
                                                        _mm_srli_epi16(x, 4), 1));
    }

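    // Expand a 32-bit mask into 32 bytes that are 0xF0 where the bit is clear
    // and 0x00 where it is set. The pshufb broadcasts each mask byte across
    // eight lanes, and the constant 0x7fbfdfeff7fbfdfe has exactly one zero
    // bit per byte, so the OR saturates to all-ones only where the
    // corresponding mask bit is 1.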
    static inline __m256i bittobyte(const uint8_t *p) {
        uint32_t x32;
        memcpy(&x32, p, sizeof(uint32_t));
        __m256i bytes = _mm256_cmpeq_epi8(_mm256_set1_epi64x(-1),
                                          _mm256_or_si256(_mm256_set1_epi64x(0x7fbfdfeff7fbfdfe),
                                                          _mm256_shuffle_epi8(_mm256_set1_epi32(x32),
                                                                              _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202,
                                                                                                0x0101010101010101, 0x0000000000000000))));
        return _mm256_andnot_si256(bytes, _mm256_set1_epi8((char)0xF0));
    }

    const TA *const A;
    const TB *const B;
    TC *const C;
    const int64_t k;
    const int64_t lda;
    const int64_t ldb;
    const int64_t ldc;
    const int ith;
    const int nth;
    __m128i iq4nlt;
};
#endif // defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)

// PPC implementation
#if defined(__MMA__)

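// Disassemble a 4x4 float MMA accumulator into the caller's vec_C array and
// scatter it into the column-major output tile whose top-left corner is
// C[ii + jj*ldc].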
#define SAVE_ACC(ACC, ii, jj) \
   __builtin_mma_disassemble_acc(vec_C, ACC); \
   for (int I = 0; I < 4; I++) { \
      for (int J = 0; J < 4; J++) { \
         *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&vec_C[I]+J); \
      } \
   }

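// Select the MMA outer-product instruction matching the element type:
// xvbf16ger2pp for bf16 operands and xvf16ger2pp for fp16 operands.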
template<typename T>
struct mma_instr;

template<>
struct mma_instr<ggml_bf16_t> {
    static inline void outer_product(acc_t *acc, vec_t a, vec_t b) {
        __builtin_mma_xvbf16ger2pp(acc, a, b);
    }
};

template<>
struct mma_instr<ggml_fp16_t> {
    static inline void outer_product(acc_t *acc, vec_t a, vec_t b) {
        __builtin_mma_xvf16ger2pp(acc, a, b);
    }
};

template <typename TA, typename TB, typename TC>
class tinyBLAS_HP16_PPC {
  public:
    tinyBLAS_HP16_PPC(int64_t k,
                const TA *A, int64_t lda,
                const TB *B, int64_t ldb,
                TC *C, int64_t ldc,
                int ith, int nth)
        : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
    }

    void matmul(int64_t m, int64_t n) {
        mnpack(0, m, 0, n);
    }

  private:
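    // Interleave 2, 4 or 8 row vectors into the operand layout the MMA rank-2
    // update instructions expect, and store the result at vecOffset.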
    void vector_permute_store(vec_t *c, int numVec, unsigned char *vecOffset) {
        vec_t t[8], s[8];
        vec_t swiz1 = {0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23};
        vec_t swiz2 = {8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31};
        vec_t swiz3 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
        vec_t swiz4 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};

        if (numVec == 2) {
            t[0] = vec_perm(c[0], c[1], swiz1);
            t[1] = vec_perm(c[2], c[3], swiz1);
            s[0] = vec_perm(t[0], t[1], swiz3);
            s[1] = vec_perm(t[0], t[1], swiz4);
            vec_xst(s[0], 0, (vec_t*)vecOffset);
            vec_xst(s[1], 0, (vec_t*)(vecOffset + 16));
        } else if (numVec == 4) {
            t[0] = vec_perm(c[0], c[1], swiz1);
            t[1] = vec_perm(c[0], c[1], swiz2);
            t[2] = vec_perm(c[2], c[3], swiz1);
            t[3] = vec_perm(c[2], c[3], swiz2);
            s[0] = vec_perm(t[0], t[2], swiz3);
            s[1] = vec_perm(t[0], t[2], swiz4);
            s[2] = vec_perm(t[1], t[3], swiz3);
            s[3] = vec_perm(t[1], t[3], swiz4);
            for (int i = 0; i < 4; ++i)
                vec_xst(s[i], 0, (vec_t*)(vecOffset + i * 16));
        } else if (numVec == 8) {
            for (int i = 0; i < 4; i += 2) {
                t[i+0] = vec_perm(c[i+0], c[i+1], swiz1);
                t[i+1] = vec_perm(c[i+0], c[i+1], swiz2);
            }
            for (int i = 4; i < 8; i += 2) {
                t[i+0] = vec_perm(c[i+0], c[i+1], swiz1);
                t[i+1] = vec_perm(c[i+0], c[i+1], swiz2);
            }
            s[0] = vec_perm(t[0], t[2], swiz3);
            s[1] = vec_perm(t[0], t[2], swiz4);
            s[2] = vec_perm(t[1], t[3], swiz3);
            s[3] = vec_perm(t[1], t[3], swiz4);
            s[4] = vec_perm(t[4], t[6], swiz3);
            s[5] = vec_perm(t[4], t[6], swiz4);
            s[6] = vec_perm(t[5], t[7], swiz3);
            s[7] = vec_perm(t[5], t[7], swiz4);
            for (int i = 0; i < 8; ++i)
                vec_xst(s[i], 0, (vec_t*)(vecOffset + i * 16));
        }
    }

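    // Pack a rows x cols panel of fp16/bf16 elements into contiguous
    // MMA-ready vectors, processing 8 rows at a time with 4-row and 1-3 row
    // tails (the switch cases below fall through on purpose).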
    void packNormal(const TA* a, int64_t lda, int rows, int cols, unsigned char* vec) {
        int64_t i, j;
        TA *aoffset = NULL;
        unsigned char *vecOffset = NULL;
        TA *aoffsets[8];
        vector unsigned char c_arr[8];
        aoffset = const_cast<TA*>(a);
        vecOffset = vec;
        j = (rows >> 3);
        if (j > 0) {
            do {
                if (cols == 4) {
                    aoffsets[0] = aoffset;
                    for (int it = 1; it < 4; ++it)
                        aoffsets[it] = aoffsets[it-1] + lda;
                    aoffset += 4 * lda;
                    for (int it = 0; it < 4; ++it)
                        c_arr[it] = vec_xl(0, (vector unsigned char*)aoffsets[it]);
                    vector_permute_store(c_arr, 4, vecOffset);
                    for (int it = 0; it < 4; ++it)
                        aoffsets[it] = aoffsets[it] + lda;
                    vecOffset += 64;
                }
                i = (cols >> 3);
                if (i > 0) {
                    aoffsets[0] = aoffset;
                    for (int it = 1; it < 8; ++it) {
                        aoffsets[it] = aoffsets[it-1] + lda;
                    }
                    aoffset += 8 * lda;
                    do {
                        for (int it = 0; it < 8; ++it)
                            c_arr[it] = vec_xl(0, (vector unsigned char*)aoffsets[it]);
                        vector_permute_store(c_arr, 8, vecOffset);
                        for (int it = 0; it < 8; ++it)
                            aoffsets[it] = aoffsets[it] + 8*lda;
                        vecOffset += 128;
                        i--;
                    } while(i > 0);
                }
                j--;
            } while(j > 0);
        }
        if (rows & 4) {
            aoffsets[0] = aoffset;
            for (int it = 1; it < 4; ++it)
                aoffsets[it] = aoffsets[it-1] + lda;
            aoffset += 4 * lda;
            if (cols == 4) {
                for (int it = 0; it < 4; ++it)
                    c_arr[it] = vec_xl(0, (vector unsigned char*)aoffsets[it]);
                vector_permute_store(c_arr, 2, vecOffset);
                for (int it = 0; it < 4; ++it)
                    aoffsets[it] = aoffsets[it] + lda;
                vecOffset += 32;
            }
            i = (cols >> 3);
            if (i > 0) {
                do {
                    for (int it = 0; it < 4; ++it)
                        c_arr[it] = vec_xl(0, (vector unsigned char*)aoffsets[it]);
                    vector_permute_store(c_arr, 4, vecOffset);
                    for (int it = 0; it < 4; ++it)
                        aoffsets[it] = aoffsets[it] + 8*lda;
                    vecOffset += 64;
                    i--;
                } while(i > 0);
            }
        }
        if (rows & 3) {
            aoffsets[0] = aoffset;
            for (int it = 1; it < 4; ++it)
                aoffsets[it] = aoffsets[it-1] + lda;
            if (cols == 4) {
                switch(rows) {
                    case 3: c_arr[2] = vec_xl(0, (vector unsigned char*)aoffsets[2]);
                    case 2: c_arr[1] = vec_xl(0, (vector unsigned char*)aoffsets[1]);
                    case 1: c_arr[0] = vec_xl(0, (vector unsigned char*)aoffsets[0]);
                        break;
                }
                vector_permute_store(c_arr, 2, vecOffset);
                for (int it = 0; it < 4; ++it)
                    aoffsets[it] = aoffsets[it] + lda;
                vecOffset += 32;
            }
            i = (cols >> 3);
            if (i > 0) {
                do {
                    switch(rows) {
                        case 3: c_arr[2] = vec_xl(0, (vector unsigned char*)aoffsets[2]);
                        case 2: c_arr[1] = vec_xl(0, (vector unsigned char*)aoffsets[1]);
                        case 1: c_arr[0] = vec_xl(0, (vector unsigned char*)aoffsets[0]);
                            break;
                    }
                    vector_permute_store(c_arr, 4, vecOffset);
                    for (int it = 0; it < 4; ++it)
                        aoffsets[it] = aoffsets[it] + 8*lda;
                    vecOffset += 64;
                    i--;
                } while(i > 0);
            }
        }
    }

    void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
        int64_t mc, nc, mp, np;
        int m_rem = MIN(m - m0, 8);
        int n_rem = MIN(n - n0, 8);

        if (m_rem >= 8 && n_rem >= 8) {
            mc = 8;
            nc = 8;
            gemm<8, 8>(m0, m, n0, n);
        } else if (m_rem >= 4 && n_rem >= 8) {
            mc = 4;
            nc = 8;
            gemm<4, 8>(m0, m, n0, n);
        } else if (m_rem >= 8 && n_rem >= 4) {
            mc = 8;
            nc = 4;
            gemm<8, 4>(m0, m, n0, n);
        } else if ((m_rem < 4) && (n_rem >= 8)) {
            nc = 8;
            switch(m_rem) {
                case 1:
                    mc = 1;
                    gemm_Mx8<1>(m0, m, n0, n);
                    break;
                case 2:
                    mc = 2;
                    gemm_Mx8<2>(m0, m, n0, n);
                    break;
                case 3:
                    mc = 3;
                    gemm_Mx8<3>(m0, m, n0, n);
                    break;
                default:
                    return;
            }
        } else if (m_rem >= 4 && n_rem >= 4) {
            mc = 4;
            nc = 4;
            gemm_small<4, 4>(m0, m, n0, n);
        } else if ((m_rem > 4) && (n_rem < 4)) {
            mc = 4;
            switch(n_rem) {
                case 1:
                    nc = 1;
                    gemm_small<4, 1>(m0, m, n0, n);
                    break;
                case 2:
                    nc = 2;
                    gemm_small<4, 2>(m0, m, n0, n);
                    break;
                case 3:
                    nc = 3;
                    gemm_small<4, 3>(m0, m, n0, n);
                    break;
                default:
                    return;
            }
        } else {
            switch((m_rem << 4) | n_rem) {
                case 0x43:
                    mc = 4;
                    nc = 3;
                    gemm_small<4, 3>(m0, m, n0, n);
                    break;
                case 0x42:
                    mc = 4;
                    nc = 2;
                    gemm_small<4, 2>(m0, m, n0, n);
                    break;
                case 0x41:
                    mc = 4;
                    nc = 1;
                    gemm_small<4, 1>(m0, m, n0, n);
                    break;
                case 0x34:
                    mc = 3;
                    nc = 4;
                    gemm_small<3, 4>(m0, m, n0, n);
                    break;
                case 0x33:
                    mc = 3;
                    nc = 3;
                    gemm_small<3, 3>(m0, m, n0, n);
                    break;
                case 0x32:
                    mc = 3;
                    nc = 2;
                    gemm_small<3, 2>(m0, m, n0, n);
                    break;
                case 0x31:
                    mc = 3;
                    nc = 1;
                    gemm_small<3, 1>(m0, m, n0, n);
                    break;
                case 0x24:
                    mc = 2;
                    nc = 4;
                    gemm_small<2, 4>(m0, m, n0, n);
                    break;
                case 0x23:
                    mc = 2;
                    nc = 3;
                    gemm_small<2, 3>(m0, m, n0, n);
                    break;
                case 0x22:
                    mc = 2;
                    nc = 2;
                    gemm_small<2, 2>(m0, m, n0, n);
                    break;
                case 0x21:
                    mc = 2;
                    nc = 1;
                    gemm_small<2, 1>(m0, m, n0, n);
                    break;
                case 0x14:
                    mc = 1;
                    nc = 4;
                    gemm_small<1, 4>(m0, m, n0, n);
                    break;
                case 0x13:
                    mc = 1;
                    nc = 3;
                    gemm_small<1, 3>(m0, m, n0, n);
                    break;
                case 0x12:
                    mc = 1;
                    nc = 2;
                    gemm_small<1, 2>(m0, m, n0, n);
                    break;
                case 0x11:
                    mc = 1;
                    nc = 1;
                    gemm_small<1, 1>(m0, m, n0, n);
                    break;
                default:
                    return;
            }
        }
        mp = m0 + (m - m0) / mc * mc;
        np = n0 + (n - n0) / nc * nc;
        mnpack(mp, m, n0, np);
        mnpack(m0, m, np, n);
    }

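    // MMA kernels. Each accumulator holds a 4x4 float tile, so the 8x8
    // kernel drives four accumulators from the same packed K panels, while
    // the 4x8 and 8x4 kernels drive two.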
    void KERNEL_4x8(int64_t ii, int64_t jj) {
        vec_t vec_A[4], vec_B[8], vec_C[4];
        acc_t acc_0, acc_1;
        __builtin_mma_xxsetaccz(&acc_0);
        __builtin_mma_xxsetaccz(&acc_1);
        for (int l = 0; l < k; l+=8) {
            packNormal((A+(ii*lda)+l), lda, 4, 8, (uint8_t*)vec_A);
            packNormal((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B);
            for (int x = 0; x < 4; x++) {
                mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
                mma_instr<TA>::outer_product(&acc_1, vec_A[x], vec_B[x+4]);
            }
        }
        SAVE_ACC(&acc_0, ii, jj);
        SAVE_ACC(&acc_1, ii, jj+4);
    }

    void KERNEL_8x4(int64_t ii, int64_t jj) {
        vec_t vec_A[8], vec_B[4], vec_C[4];
        acc_t acc_0, acc_1;
        __builtin_mma_xxsetaccz(&acc_0);
        __builtin_mma_xxsetaccz(&acc_1);
        for (int l = 0; l < k; l+=8) {
            packNormal((A+(ii*lda)+l), lda, 8, 8, (uint8_t*)vec_A);
            packNormal((B+(jj*ldb)+l), ldb, 8, 4, (uint8_t*)vec_B);
            for (int x = 0; x < 4; x++) {
                mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
                // acc_1 accumulates rows ii+4..ii+7, so it must consume the
                // second half of the packed A panel; indexing vec_B[x+4]
                // instead would read past the 4-vector B panel.
                mma_instr<TA>::outer_product(&acc_1, vec_A[x+4], vec_B[x]);
            }
        }
        SAVE_ACC(&acc_0, ii, jj);
        SAVE_ACC(&acc_1, ii+4, jj);
    }

    void KERNEL_8x8(int64_t ii, int64_t jj) {
        vec_t vec_A[8], vec_B[8], vec_C[4];
        acc_t acc_0, acc_1, acc_2, acc_3;
        __builtin_mma_xxsetaccz(&acc_0);
        __builtin_mma_xxsetaccz(&acc_1);
        __builtin_mma_xxsetaccz(&acc_2);
        __builtin_mma_xxsetaccz(&acc_3);
        for (int l = 0; l < k; l+=8) {
            packNormal(A+(ii*lda)+l, lda, 8, 8, (uint8_t*)vec_A);
            packNormal(B+(jj*ldb)+l, ldb, 8, 8, (uint8_t*)vec_B);
            for (int x = 0; x < 4; x++) {
                mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
                mma_instr<TA>::outer_product(&acc_1, vec_A[x], vec_B[x+4]);
                mma_instr<TA>::outer_product(&acc_2, vec_A[x+4], vec_B[x]);
                mma_instr<TA>::outer_product(&acc_3, vec_A[x+4], vec_B[x+4]);
            }
        }

        SAVE_ACC(&acc_0, ii, jj);
        SAVE_ACC(&acc_1, ii, jj+4);
        SAVE_ACC(&acc_2, ii+4, jj);
        SAVE_ACC(&acc_3, ii+4, jj+4);
    }

    template<int RM, int RN>
    void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n) {
        int64_t ytiles = (m - m0) / RM;
        int64_t xtiles = (n - n0) / RN;
        int64_t tiles = xtiles * ytiles;
        int64_t duty = (tiles + nth - 1) / nth;
        int64_t start = duty * ith;
        int64_t end = start + duty;
        if (end > tiles)
            end = tiles;
        for (int64_t job = start; job < end; ++job) {
            int64_t ii = m0 + job / xtiles * RM;
            int64_t jj = n0 + job % xtiles * RN;
            vec_t vec_C[4];
            acc_t acc_0;
            __builtin_mma_xxsetaccz(&acc_0);
            vec_t vec_A[2], vec_B[2];
            for (int l = 0; l < k; l += 4) {
                packNormal(A+(ii*lda)+l, lda, RM, 4, (uint8_t*)vec_A);
                packNormal(B+(jj*ldb)+l, ldb, RN, 4, (uint8_t*)vec_B);
                for (int x = 0; x < 2; x++) {
                    mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
                }
            }
            __builtin_mma_disassemble_acc(vec_C, &acc_0);
            for (int I = 0; I < RM; I++) {
                for (int J = 0; J < RN; J++) {
                    *((TC*)(C+ii+((jj+J)*ldc)+I)) = *((TC*)&vec_C[I]+J);
                }
            }
        }
    }

    template<int RM>
    void gemm_Mx8(int64_t m0, int64_t m, int64_t n0, int64_t n) {
        int RN = 8;
        int64_t ytiles = (m - m0) / RM;
        int64_t xtiles = (n - n0) / RN;
        int64_t tiles = xtiles * ytiles;
        int64_t duty = (tiles + nth - 1) / nth;
        int64_t start = duty * ith;
        int64_t end = start + duty;
        if (end > tiles)
            end = tiles;
        for (int64_t job = start; job < end; ++job) {
            int64_t ii = m0 + job / xtiles * RM;
            int64_t jj = n0 + job % xtiles * RN;
            vec_t vec_C[4];
            acc_t acc_0, acc_1;
            __builtin_mma_xxsetaccz(&acc_0);
            __builtin_mma_xxsetaccz(&acc_1);
            vec_t vec_A[4], vec_B[8];
            for (int l = 0; l < k; l += 8) {
                packNormal(A+(ii*lda)+l, lda, RM, 8, (uint8_t*)vec_A);
                packNormal(B+(jj*ldb)+l, ldb, RN, 8, (uint8_t*)vec_B);
                for (int x = 0; x < 4; x++) {
                    mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
                    mma_instr<TA>::outer_product(&acc_1, vec_A[x], vec_B[x+4]);
                }
            }
            __builtin_mma_disassemble_acc(vec_C, &acc_0);
            for (int I = 0; I < RM; I++) {
                for (int J = 0; J < 4; J++) {
                    *((TC*)(C+ii+((jj+J)*ldc)+I)) = *((TC*)&vec_C[I]+J);
                }
            }
            __builtin_mma_disassemble_acc(vec_C, &acc_1);
            for (int I = 0; I < RM; I++) {
                for (int J = 0; J < 4; J++) {
                    *((TC*)(C+ii+((jj+4+J)*ldc)+I)) = *((TC*)&vec_C[I]+J);
                }
            }
        }
    }

    template<int RM, int RN>
    inline void kernel(int64_t ii, int64_t jj) {
       if constexpr(RM == 4 && RN == 8) {
          KERNEL_4x8(ii,jj);
       } else if constexpr(RM == 8 && RN == 8) {
          KERNEL_8x8(ii,jj);
       } else if constexpr(RM == 8 && RN == 4) {
          KERNEL_8x4(ii,jj);
       } else {
          assert(false && "RN/RM values not supported");
       }
    }

    template <int RM, int RN>
    NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
        int64_t ytiles = (m - m0) / RM;
        int64_t xtiles = (n - n0) / RN;
        int64_t tiles = xtiles * ytiles;
        int64_t duty = (tiles + nth - 1) / nth;
        int64_t start = duty * ith;
        int64_t end = start + duty;
        if (end > tiles)
            end = tiles;
        for (int64_t job = start; job < end; ++job) {
            int64_t ii = m0 + job / xtiles * RM;
            int64_t jj = n0 + job % xtiles * RN;
            kernel<RM, RN>(ii, jj);
        }
    }

    const TA *const A;
    const TB *const B;
    TC *C;
    const int64_t k;
    const int64_t lda;
    const int64_t ldb;
    const int64_t ldc;
    const int ith;
    const int nth;
};

    template <typename TA>
    tinyBLAS_Q0_PPC<TA>::tinyBLAS_Q0_PPC(int64_t k,
        const TA *A, int64_t lda,
        const block_q8_0 *B, int64_t ldb,
        float *C, int64_t ldc,
        int ith, int nth)
        : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
        kc = 64;
    }

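    // Choose between the cache-blocked path and the generic tile dispatch:
    // when m, n and k are exact multiples of the block sizes, matmul_tiled_q0
    // runs a fixed mc x nc x kc blocking; otherwise mnpack handles the
    // irregular fringes.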
    template<typename TA>
    void tinyBLAS_Q0_PPC<TA>::matmul(int64_t m, int64_t n) {
        int mc = 64; int nc = 64;
        if (n % 8 == 0 && n < nc) {
            nc = n;
            mc = 32;
            kc = 32;
        }
        const bool is_aligned = ((m & (mc - 1)) == 0) && ((n & (nc - 1)) == 0) && ((k & (kc - 1)) == 0);
        if (is_aligned) {
            this->matmul_tiled_q0(m, n, mc, nc, kc);
        } else {
            mnpack(0, m, 0, n);
        }
    }


    template<typename TA>
    template<int size>
    void tinyBLAS_Q0_PPC<TA>::packNormalInt4(const TA* a, int64_t lda, int rows, int cols, int8_t* vec, std::array<int, size>& comparray) {
        int64_t i, j;
        TA *aoffset = NULL;
        int8_t *vecOffset = NULL;
        TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
        TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
        vector signed char c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0};
        vector signed char c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0};
        aoffset = const_cast<TA*>(a);
        vecOffset = vec;
        j = (rows >> 3);
        if (j > 0) {
            do {
                aoffset1 = aoffset;
                aoffset2 = aoffset1 + lda;
                aoffset3 = aoffset2 + lda;
                aoffset4 = aoffset3 + lda;
                aoffset5 = aoffset4 + lda;
                aoffset6 = aoffset5 + lda;
                aoffset7 = aoffset6 + lda;
                aoffset8 = aoffset7 + lda;
                aoffset += 8 * lda;
                i = (cols >> 2);
                if (i > 0) {
                    do {
                        c1[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset1->qs));
                        c2[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset2->qs));
                        c3[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset3->qs));
                        c4[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset4->qs));
                        c5[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset5->qs));
                        c6[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset6->qs));
                        c7[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset7->qs));
                        c8[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset8->qs));

                        process_q4_elements(c1, &comparray[0]);
                        process_q4_elements(c2, &comparray[1]);
                        process_q4_elements(c3, &comparray[2]);
                        process_q4_elements(c4, &comparray[3]);
                        process_q4_elements(c5, &comparray[4]);
                        process_q4_elements(c6, &comparray[5]);
                        process_q4_elements(c7, &comparray[6]);
                        process_q4_elements(c8, &comparray[7]);
                        vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
                        vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
                        vector_permute_store<int8_t, vector signed char>(c5[0], c6[0], c7[0], c8[0], vecOffset+128, false);
                        vector_permute_store<int8_t, vector signed char>(c5[1], c6[1], c7[1], c8[1], vecOffset+192, false);
                        aoffset1 += lda;
                        aoffset2 += lda;
                        aoffset3 += lda;
                        aoffset4 += lda;
                        aoffset5 += lda;
                        aoffset6 += lda;
                        aoffset7 += lda;
                        aoffset8 += lda;
                        vecOffset += 256;
                        i--;
                    } while (i > 0);
                }
                j--;
            } while (j > 0);
        }

        if (rows & 4) {
            aoffset1 = aoffset;
            aoffset2 = aoffset1 + lda;
            aoffset3 = aoffset2 + lda;
            aoffset4 = aoffset3 + lda;
            aoffset += 4 * lda;
            i = (cols >> 2);
            if (i > 0) {
                do {
                    c1[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset1->qs));
                    c2[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset2->qs));
                    c3[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset3->qs));
                    c4[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset4->qs));

                    process_q4_elements(c1, &comparray[0]);
                    process_q4_elements(c2, &comparray[1]);
                    process_q4_elements(c3, &comparray[2]);
                    process_q4_elements(c4, &comparray[3]);
                    vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
                    vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
                    aoffset1 += lda;
                    aoffset2 += lda;
                    aoffset3 += lda;
                    aoffset4 += lda;
                    vecOffset += 128;
                    i--;
                } while (i > 0);
            }
        }

        if (rows & 3) {
            aoffset1 = aoffset;
            aoffset2 = aoffset1 + lda;
            aoffset3 = aoffset2 + lda;
            i = (cols >> 2);
            if (i > 0) {
                do {
                    switch(rows) {
                        case 3: c3[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset3->qs));
                        case 2: c2[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset2->qs));
                        case 1: c1[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset1->qs));
                            break;
                    }
                    process_q4_elements(c1, &comparray[0]);
                    process_q4_elements(c2, &comparray[1]);
                    process_q4_elements(c3, &comparray[2]);
                    process_q4_elements(c4, &comparray[3]);
                    vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
                    vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
                    aoffset1 += lda;
                    aoffset2 += lda;
                    aoffset3 += lda;
                    vecOffset += 128;
                    i--;
                } while(i > 0);
            }
        }
    }

    template<typename TA>
    template<typename VA, typename VB>
    void tinyBLAS_Q0_PPC<TA>::packNormal(const block_q8_0* a, int64_t lda, int rows, int cols, VA* vec, bool flip) {
        int64_t i, j;
        block_q8_0 *aoffset = NULL;
        VA *vecOffset = NULL;
        block_q8_0* aoffsets[8];
        __vector_pair arr[8];
        VB c[8][2] = {0};
        VB c1[8] = {0}; VB c2[8] = {0};
        aoffset = const_cast<block_q8_0*>(a);
        vecOffset = vec;
        j = (rows >> 3);
        if (j > 0) {
            do {
                aoffsets[0] = aoffset;
                for (int it = 1; it < 8; it++)
                    aoffsets[it] = aoffsets[it-1] + lda;
                aoffset += 8 * lda;

                i = (cols >> 3);
                if (i > 0) {
                    do {
                        for (int it = 0; it < 8; it++) {
                            arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]->qs);
                            __builtin_vsx_disassemble_pair(c[it], &arr[it]);
                            c1[it] = c[it][0];
                            c2[it] = c[it][1];
                        }
                        vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
                        vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
                        vector_permute_store<VA, VB>(c1[4], c1[5], c1[6], c1[7], vecOffset+128, flip);
                        vector_permute_store<VA, VB>(c2[4], c2[5], c2[6], c2[7], vecOffset+192, flip);
                        for (int it = 0; it < 8; it++)
                            aoffsets[it] += lda;
                        vecOffset += 256;
                        i--;
                    } while(i > 0);
                }
                j--;
            } while(j > 0);
        }
        if (rows & 4) {
            aoffsets[0] = aoffset;
            for (int it = 1; it < 4; it++)
                aoffsets[it] = aoffsets[it-1] + lda;
            aoffset += 4 * lda;
            i = (cols >> 3);
            if (i > 0) {
                do {
                    for (int it = 0; it < 4; it++) {
                        arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]->qs);
                        __builtin_vsx_disassemble_pair(c[it], &arr[it]);
                        c1[it] = c[it][0];
                        c2[it] = c[it][1];
                    }
                    vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
                    vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
                    for (int it = 0; it < 4; it++) {
                        aoffsets[it] += lda;
                    }
                    vecOffset += 128;
                    i--;
                } while(i > 0);
            }
        }

        if (rows & 3) {
            aoffsets[0] = aoffset;
            for (int it = 1; it < 3; it++)
                aoffsets[it] = aoffsets[it-1] + lda;
            i = (cols >> 3);
            if (i > 0) {
                do {
                    switch(rows) {
                        case 3: arr[2] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[2]->qs);
                                __builtin_vsx_disassemble_pair(c[2], &arr[2]);
                                c1[2] = c[2][0]; c2[2] = c[2][1];
                        case 2: arr[1] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[1]->qs);
                                __builtin_vsx_disassemble_pair(c[1], &arr[1]);
                                c1[1] = c[1][0]; c2[1] = c[1][1];
                        case 1: arr[0] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[0]->qs);
                                __builtin_vsx_disassemble_pair(c[0], &arr[0]);
                                c1[0] = c[0][0]; c2[0] = c[0][1];
                                break;
                    }
                    vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
                    vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
                    for (int it = 0; it < 3; it++)
                        aoffsets[it] += lda;
                    vecOffset += 128;
                    i--;
                } while(i > 0);
            }
        }
    }

    template<typename TA>
    void tinyBLAS_Q0_PPC<TA>::mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
        int m_rem = MIN(m - m0, 16);
        int n_rem = MIN(n - n0, 16);

        int mc = 0, nc = 0;

        if (m_rem >= 8 && n_rem >= 8) {
            mc = 8;
            nc = 8;
            gemm<8, 8>(m0, m, n0, n);
        } else if (m_rem >= 4 && n_rem >= 8) {
            mc = 4;
            nc = 8;
            gemm<4, 8>(m0, m, n0, n);
        } else if (m_rem >= 8 && n_rem >= 4) {
            mc = 8;
            nc = 4;
            gemm<8, 4>(m0, m, n0, n);
        } else if (m_rem >= 4 && n_rem >= 4) {
            mc = 4;
            nc = 4;
            gemm_small(m0, m, n0, n, mc, nc);
        } else {
            mc = (m_rem >= 4) ? 4 : m_rem;
            nc = (n_rem >= 4) ? 4 : n_rem;
            if (mc == 0 || nc == 0)
               return;
            gemm_small(m0, m, n0, n, mc, nc);
        }

        int64_t mp = m0 + ((m - m0) / mc) * mc;
        int64_t np = n0 + ((n - n0) / nc) * nc;
        mnpack(mp, m, n0, np);
        mnpack(m0, m, np, n);
    }

    template<typename TA>
    void tinyBLAS_Q0_PPC<TA>::KERNEL_4x8(int64_t ii, int64_t jj) {
        vec_t vec_A[8], vec_B[16] = {0};
        acc_t acc_0, acc_1;
        std::array<int, 4> comparray {};
        vector float fin_res[8] = {0};
        vector float vs[8] = {0};
        bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
        for (int l = 0; l < k; l++) {
            __builtin_mma_xxsetaccz(&acc_0);
            __builtin_mma_xxsetaccz(&acc_1);
            if (isAblock_q4) {
               packNormalInt4<4>((A+(ii*lda)+l), lda, 4, 4, (int8_t*)vec_A, comparray);
            } else {
               packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, 4, 8, (int8_t*)vec_A, false);
            }
            packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
            for(int x = 0; x < 8; x++) {
                __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
                __builtin_mma_xvi8ger4pp(&acc_1, vec_A[x], vec_B[x+8]);
            }
            for (int I = 0; I<4; I++) {
                for (int J = 0; J<4; J++) {
                    *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
                    *((float*)&vs[I+4]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d));
                }
            }
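            // For q8_0 A blocks, comparray is filled here with the per-row
            // sums of the 32 quants (packNormalInt4 already produced them
            // for q4_0); compute() uses these sums to cancel the unsigned
            // bias applied to B's quants by the flipped pack above.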
            if (!isAblock_q4) {
                auto aoffset = A+(ii*lda)+l;
                for (int i = 0; i < 4; i++) {
                    comparray[i] = 0;
                    int ca = 0;
                    auto *at = aoffset->qs;
                    for (int j = 0; j < 32; j++)
                        ca += (int)*at++;
                    comparray[i] = ca;
                    aoffset += lda;
                }
            }
            compute(&acc_0, 0, 0, comparray, vs, fin_res);
            compute(&acc_1, 0, 4, comparray, vs, fin_res);
        }
        save_res(ii, jj, 0, fin_res);
        save_res(ii, jj+4, 4, fin_res);
    }

    template<typename TA>
    void tinyBLAS_Q0_PPC<TA>::KERNEL_8x4(int64_t ii, int64_t jj) {
        vec_t vec_A[16], vec_B[8] = {0};
        acc_t acc_0, acc_1;
        std::array<int, 8> comparray {};
        vector float fin_res[8] = {0};
        vector float vs[8] = {0};
        bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
        for (int l = 0; l < k; l++) {
            __builtin_mma_xxsetaccz(&acc_0);
            __builtin_mma_xxsetaccz(&acc_1);
            if (isAblock_q4) {
               packNormalInt4<8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
            } else {
               packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
            }
            packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 4, 8, (uint8_t*)vec_B, true);
            for(int x = 0; x < 8; x++) {
                __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
                __builtin_mma_xvi8ger4pp(&acc_1, vec_A[x+8], vec_B[x]);
            }
            for (int I = 0; I<8; I++) {
                for (int J = 0; J<4; J++) {
                    *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
                }
            }
            if (!isAblock_q4) {
                auto aoffset = A+(ii*lda)+l;
                for (int i = 0; i < 8; i++) {
                    comparray[i] = 0;
                    int ca = 0;
                    auto *at = aoffset->qs;
                    for (int j = 0; j < 32; j++)
                        ca += (int)*at++;
                    comparray[i] = ca;
                    aoffset += lda;
                }
            }
            compute(&acc_0, 0, 0, comparray, vs, fin_res);
            compute(&acc_1, 4, 4, comparray, vs, fin_res);
        }
        save_res(ii, jj, 0, fin_res);
        save_res(ii+4, jj, 4, fin_res);
    }

    template<typename TA>
    void tinyBLAS_Q0_PPC<TA>::KERNEL_8x8(int64_t ii, int64_t jj) {
        vec_t vec_A[16], vec_B[16] = {0};
        acc_t acc_0, acc_1, acc_2, acc_3;
        acc_t acc_4, acc_5, acc_6, acc_7;
        std::array<int, 8> comparray {};
        vector float fin_res[16] = {0};
        vector float vs[16] = {0};
        bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
        for (int l = 0; l < k; l++) {
            __builtin_mma_xxsetaccz(&acc_0);
            __builtin_mma_xxsetaccz(&acc_1);
            __builtin_mma_xxsetaccz(&acc_2);
            __builtin_mma_xxsetaccz(&acc_3);
            if (isAblock_q4) {
               packNormalInt4<8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
            } else {
               packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
            }
            packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
            for(int x = 0; x < 8; x++) {
                __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
                __builtin_mma_xvi8ger4pp(&acc_1, vec_A[x+8], vec_B[x]);
                __builtin_mma_xvi8ger4pp(&acc_2, vec_A[x], vec_B[x+8]);
                __builtin_mma_xvi8ger4pp(&acc_3, vec_A[x+8], vec_B[x+8]);
            }
            for (int I = 0; I<8; I++) {
                for (int J = 0; J<4; J++) {
                    *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
                    *((float*)&vs[I+8]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d));
                }
            }
            if (!isAblock_q4) {
                auto aoffset = A+(ii*lda)+l;
                for (int i = 0; i < 8; i++) {
                    comparray[i] = 0;
                    int ca = 0;
                    auto *at = aoffset->qs;
                    for (int j = 0; j < 32; j++)
                        ca += (int)*at++;
                    comparray[i] = ca;
                    aoffset += lda;
                }
            }
            compute(&acc_0, 0, 0, comparray, vs, fin_res);
            compute(&acc_1, 4, 4, comparray, vs, fin_res);
            compute(&acc_2, 0, 8, comparray, vs, fin_res);
            compute(&acc_3, 4, 12, comparray, vs, fin_res);
        }
        save_res(ii, jj, 0, fin_res);
        save_res(ii+4, jj, 4, fin_res);
        save_res(ii, jj+4, 8, fin_res);
        save_res(ii+4, jj+4, 12, fin_res);
    }
2733
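    // Handles leftover tiles whose shape (RM x RN) is smaller than the full
    // 8x8 kernel. The tile grid is divided evenly across threads; each
    // thread processes the half-open job range [start, end).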
    template<typename TA>
    void tinyBLAS_Q0_PPC<TA>::gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
        int64_t ytiles = (m - m0) / RM;
        int64_t xtiles = (n - n0) / RN;
        int64_t tiles = xtiles * ytiles;
        int64_t duty = (tiles + nth - 1) / nth;
        int64_t start = duty * ith;
        int64_t end = start + duty;
        vec_t vec_A[8] = {0}, vec_B[8] = {0};
        vector signed int vec_C[4];
        acc_t acc_0;
        bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;

        if (end > tiles)
            end = tiles;
        for (int64_t job = start; job < end; ++job) {
            int64_t ii = m0 + job / xtiles * RM;
            int64_t jj = n0 + job % xtiles * RN;
            std::array<int, 4> comparray{};
            vector float res[4] = {0};
            vector float fin_res[4] = {0};
            vector float vs[4] = {0};
            vector float CA[4] = {0};
            __builtin_prefetch((A+(ii*lda)+0)->qs, 0, 1); // prefetch first value
            __builtin_prefetch((B+(jj*ldb)+0)->qs, 0, 1); // prefetch first value
            for (int l = 0; l < k; l++) {
                __builtin_prefetch((A+(ii*lda)+(l+1))->qs, 0, 1); // prefetch one loop ahead
                __builtin_prefetch((B+(jj*ldb)+(l+1))->qs, 0, 1); // prefetch one loop ahead
                __builtin_mma_xxsetaccz(&acc_0);
                if (isAblock_q4) {
                   packNormalInt4<4>((A+(ii*lda)+l), lda, RM, 4, (int8_t*)vec_A, comparray);
                } else {
                   packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, RM, 8, (int8_t*)vec_A, false);
                }
                packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, RN, 8, (uint8_t*)vec_B, true);
                for (int x = 0; x < 8; x += 4) {
                    __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
                    __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+1], vec_B[x+1]);
                    __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+2], vec_B[x+2]);
                    __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+3], vec_B[x+3]);
                }
                for (int I = 0; I < RM; I++) {
                    for (int J = 0; J < RN; J++) {
                        *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
                    }
                }
                __builtin_mma_disassemble_acc(vec_C, &acc_0);
                if (!isAblock_q4) {
                    auto aoffset = A+(ii*lda)+l;
                    for (int i = 0; i < RM; i++) {
                        int ca = 0;
                        auto *at = aoffset->qs;
                        for (int j = 0; j < 32; j++)
                            ca += (int)*at++;
                        comparray[i] = ca;
                        aoffset += lda;
                    }
                }
                for (int i = 0; i < RM; i++) {
                    CA[i] = vec_splats((float)(((double)comparray[i]) * -128.0));
                    res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]);
                    fin_res[i] = vec_madd(res[i], vs[i], fin_res[i]);
                }
            }
            save_res(ii, jj, 0, fin_res, RM, RN);
        }
    }

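    // Splits the (m - m0) x (n - n0) region into RM x RN tiles, spreads the
    // tiles across nth threads, and dispatches each one to the fixed-size
    // kernel selected at compile time.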
    template<typename TA>
    template <int RM, int RN>
    NOINLINE void tinyBLAS_Q0_PPC<TA>::gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
        int64_t ytiles = (m - m0) / RM;
        int64_t xtiles = (n - n0) / RN;
        int64_t tiles = xtiles * ytiles;
        int64_t duty = (tiles + nth - 1) / nth;
        int64_t start = duty * ith;
        int64_t end = start + duty;
        if (end > tiles)
            end = tiles;
        for (int64_t job = start; job < end; ++job) {
            int64_t ii = m0 + job / xtiles * RM;
            int64_t jj = n0 + job % xtiles * RN;
            this->kernel<RM, RN>(ii, jj);
        }
    }

template class tinyBLAS_Q0_PPC<block_q4_0>;
template class tinyBLAS_Q0_PPC<block_q8_0>;

class tinyBLAS_PPC {
  public:
    tinyBLAS_PPC(int64_t k,
                const float * A, int64_t lda,
                const float * B, int64_t ldb,
                float * C, int64_t ldc,
                int ith, int nth)
        : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
    }

    void matmul(int64_t m, int64_t n) {
        int64_t mc = 256, nc = 256, kc = 256;
        if (m % mc == 0 && n % nc == 0 && k % kc == 0) {
            matmul_tiled(m, n, mc, nc, kc);
        } else {
            mnpack(0, m, 0, n);
        }
    }

  private:

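    // Unpack a 4x4 MMA accumulator and store it to C, overwriting whatever
    // the tile previously held.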
    inline void save_acc(acc_t * ACC, int64_t ii, int64_t jj) {
        vec_t vec_C[4];
        __builtin_mma_disassemble_acc(vec_C, ACC);
        for (int I = 0; I < 4; I++) {
            for (int J = 0; J < 4; J++) {
                *((float *)(C+ii+((jj+J)*ldc)+I)) = *((float *)&vec_C[I]+J);
            }
        }
    }

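    // Unpack a 4x4 MMA accumulator and add it into C; used by the tiled
    // path when a later k-panel accumulates onto an earlier panel's result.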
    inline void add_save_acc(acc_t * ACC, int64_t ii, int64_t jj) {
        vec_t vec_C[4];
        __builtin_mma_disassemble_acc(vec_C, ACC);
        for (int I = 0; I < 4; I++) {
            for (int J = 0; J < 4; J++) {
                float * c_ptr = (float *)(C+ii+((jj+J)*ldc)+I);
                *c_ptr += *((float *)&vec_C[I]+J);
            }
        }
    }

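    // Transpose a 4x4 float block (four 4-float vectors) via merge/permute
    // and store the 16 results contiguously at vecOffset.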
    inline void vector_permute_store_4(vector float * src, float * vecOffset) {
        vector float t1, t2, t3, t4, t5, t6, t7, t8;
        t1 = vec_mergeh(src[0], src[1]);
        t2 = vec_mergeh(src[2], src[3]);
        t3 = vec_mergel(src[0], src[1]);
        t4 = vec_mergel(src[2], src[3]);

        t5 = vec_xxpermdi(t1, t2, 0);
        t6 = vec_xxpermdi(t1, t2, 3);
        t7 = vec_xxpermdi(t3, t4, 0);
        t8 = vec_xxpermdi(t3, t4, 3);

        vec_xst(t5, 0, vecOffset);
        vec_xst(t6, 0, vecOffset + 4);
        vec_xst(t7, 0, vecOffset + 8);
        vec_xst(t8, 0, vecOffset + 12);
    }

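    // Transpose an 8x4 float block (eight 4-float vectors) and store the 32
    // results contiguously at vecOffset: the mergeh halves first, then the
    // mergel halves.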
    inline void vector_permute_store_8(vector float * src, float * vecOffset) {
        vector float t1, t2, t3, t4, t5, t6, t7, t8;
        t1 = vec_mergeh(src[0], src[1]);
        t2 = vec_mergeh(src[2], src[3]);
        t3 = vec_mergeh(src[4], src[5]);
        t4 = vec_mergeh(src[6], src[7]);

        t5 = vec_xxpermdi(t1, t2, 0);
        t6 = vec_xxpermdi(t3, t4, 0);
        t7 = vec_xxpermdi(t1, t2, 3);
        t8 = vec_xxpermdi(t3, t4, 3);

        vec_xst(t5, 0, vecOffset);
        vec_xst(t6, 0, vecOffset + 4);
        vec_xst(t7, 0, vecOffset + 8);
        vec_xst(t8, 0, vecOffset + 12);

        t1 = vec_mergel(src[0], src[1]);
        t2 = vec_mergel(src[2], src[3]);
        t3 = vec_mergel(src[4], src[5]);
        t4 = vec_mergel(src[6], src[7]);

        t5 = vec_xxpermdi(t1, t2, 0);
        t6 = vec_xxpermdi(t3, t4, 0);
        t7 = vec_xxpermdi(t1, t2, 3);
        t8 = vec_xxpermdi(t3, t4, 3);

        vec_xst(t5, 0, vecOffset + 16);
        vec_xst(t6, 0, vecOffset + 20);
        vec_xst(t7, 0, vecOffset + 24);
        vec_xst(t8, 0, vecOffset + 28);
    }

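    // Pack a rows x cols block of the row-major source into the transposed,
    // contiguous layout consumed by the MMA kernels. Rows are handled in
    // groups of 8, then 4, then a 1-3 row tail; columns in groups of 8 with
    // a 4-column tail.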
    void packTranspose(const float * a, int64_t lda, int rows, int cols, float * vec) {
        int64_t i, j;
        float * aoffsets[8];
        float * aoffset = NULL, * boffset = NULL;
        __vector_pair arr[8];
        vector float c[8][2] = {0};
        vector float c1[8] = {0};
        vector float c2[8] = {0};
        aoffset = const_cast<float *>(a);
        boffset = vec;
        j = (rows >> 3);
        if (j > 0) {
            do {
                aoffsets[0] = aoffset;
                for (int it = 1; it < 8; it++)
                    aoffsets[it] = aoffsets[it-1] + lda;
                aoffset += 8 * lda;
                i = (cols >> 3);
                if (i > 0) {
                    do {
                        for (int it = 0; it < 8; it++) {
                            arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]);
                            __builtin_vsx_disassemble_pair(c[it], &arr[it]);
                            c1[it] = c[it][0];
                            c2[it] = c[it][1];
                        }

                        vector_permute_store_8(c1, boffset);
                        vector_permute_store_8(c2, boffset + 32);
                        boffset += 64;
                        i--;
                        if (i > 0) {
                            for (int it = 0; it < 8; it++) {
                                aoffsets[it] = aoffsets[it] + 8;
                            }
                        }
                    } while (i > 0);
                }
                if (cols & 4) {
                    for (int it = 0; it < 8; it++)
                        c1[it] = vec_xl(0, aoffsets[it]);
                    vector_permute_store_8(c1, boffset);
                }
                j--;
            } while (j > 0);
        }

        if (rows & 4) {
            aoffsets[0] = aoffset;
            for (int it = 1; it < 4; it++)
                aoffsets[it] = aoffsets[it-1] + lda;
            aoffset += 4 * lda;
            i = (cols >> 3);
            if (i > 0) {
                do {
                    for (int it = 0; it < 4; it++) {
                        arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]);
                        __builtin_vsx_disassemble_pair(c[it], &arr[it]);
                        c1[it] = c[it][0];
                        c2[it] = c[it][1];
                    }
                    vector_permute_store_4(c1, boffset);
                    vector_permute_store_4(c2, boffset + 16);
                    for (int it = 0; it < 4; it++)
                        aoffsets[it] += 8 * lda;
                    boffset += 32;
                    i--;
                } while (i > 0);
            }

            if (cols & 4) {
                for (int it = 0; it < 4; it++)
                    c1[it] = vec_xl(0, aoffsets[it]);
                vector_permute_store_4(c1, boffset);
            }
        }
        if (rows & 3) {
            aoffsets[0] = aoffset;
            for (int it = 1; it < 3; it++)
                aoffsets[it] = aoffsets[it-1] + lda;
            if (cols & 4) {
                for (int it = 0; it < 3; it++)
                    c1[it] = vec_xl(0, aoffsets[it]);
                vector_permute_store_4(c1, boffset);
            }
        }
    }

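    // The fixed-size kernels below (4x4, 4x8, 8x4, 8x8) each compute one
    // tile of C = Aᵀ * B: every iteration packs a slice of A and B with
    // packTranspose and issues rank-4 fp32 updates (xvf32gerpp) into the
    // accumulators.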
    void KERNEL_4x4(int64_t ii, int64_t jj) {
        vec_t vec_A[4], vec_B[4];
        acc_t acc_0;
        __builtin_mma_xxsetaccz(&acc_0);
        for (int l = 0; l < k; l += 4) {
            packTranspose(A + (ii * lda) + l, lda, 4, 4, (float *)vec_A);
            packTranspose(B + (jj * ldb) + l, ldb, 4, 4, (float *)vec_B);
            __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
            __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
            __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]);
            __builtin_mma_xvf32gerpp(&acc_0, vec_A[3], vec_B[3]);
        }
        save_acc(&acc_0, ii, jj);
    }

    void KERNEL_4x8(int64_t ii, int64_t jj) {
        vec_t vec_A[4], vec_B[8];
        acc_t acc_0, acc_1;
        __builtin_mma_xxsetaccz(&acc_0);
        __builtin_mma_xxsetaccz(&acc_1);
        for (int64_t l = 0; l < k; l += 4) {
            packTranspose(A + (ii * lda) + l, lda, 4, 4, (float *)vec_A);
            packTranspose(B + (jj * ldb) + l, ldb, 8, 4, (float *)vec_B);
            __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], (vec_t)vec_B[0]);
            __builtin_mma_xvf32gerpp(&acc_1, vec_A[0], (vec_t)vec_B[1]);
            __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], (vec_t)vec_B[2]);
            __builtin_mma_xvf32gerpp(&acc_1, vec_A[1], (vec_t)vec_B[3]);
            __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], (vec_t)vec_B[4]);
            __builtin_mma_xvf32gerpp(&acc_1, vec_A[2], (vec_t)vec_B[5]);
            __builtin_mma_xvf32gerpp(&acc_0, vec_A[3], (vec_t)vec_B[6]);
            __builtin_mma_xvf32gerpp(&acc_1, vec_A[3], (vec_t)vec_B[7]);
        }
        save_acc(&acc_0, ii, jj);
        save_acc(&acc_1, ii, jj + 4);
    }

    void KERNEL_8x4(int64_t ii, int64_t jj) {
        vec_t vec_A[8], vec_B[4];
        acc_t acc_0, acc_1;
        __builtin_mma_xxsetaccz(&acc_0);
        __builtin_mma_xxsetaccz(&acc_1);
        for (int64_t l = 0; l < k; l += 4) {
            packTranspose(A + (ii * lda) + l, lda, 8, 4, (float *)vec_A);
            packTranspose(B + (jj * ldb) + l, ldb, 4, 4, (float *)vec_B);
            __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[0], vec_B[0]);
            __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[1], vec_B[0]);
            __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[2], vec_B[1]);
            __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[3], vec_B[1]);
            __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[4], vec_B[2]);
            __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[5], vec_B[2]);
            __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[6], vec_B[3]);
            __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[7], vec_B[3]);
        }
        save_acc(&acc_0, ii, jj);
        save_acc(&acc_1, ii + 4, jj);
    }

    void KERNEL_8x8(int64_t ii, int64_t jj) {
        vec_t vec_A[16], vec_B[16];
        acc_t acc_0, acc_1, acc_2, acc_3;
        __builtin_mma_xxsetaccz(&acc_0);
        __builtin_mma_xxsetaccz(&acc_1);
        __builtin_mma_xxsetaccz(&acc_2);
        __builtin_mma_xxsetaccz(&acc_3);
        for (int l = 0; l < k; l += 8) {
            packTranspose(A + (ii * lda) + l, lda, 8, 8, (float *)vec_A);
            packTranspose(B + (jj * ldb) + l, ldb, 8, 8, (float *)vec_B);
            for (int x = 0; x < 16; x += 2) {
                __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[x], vec_B[x]);
                __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[x], vec_B[x + 1]);
                __builtin_mma_xvf32gerpp(&acc_2, (vec_t)vec_A[x + 1], vec_B[x]);
                __builtin_mma_xvf32gerpp(&acc_3, (vec_t)vec_A[x + 1], vec_B[x + 1]);
            }
        }
        save_acc(&acc_0, ii, jj);
        save_acc(&acc_1, ii, jj + 4);
        save_acc(&acc_2, ii + 4, jj);
        save_acc(&acc_3, ii + 4, jj + 4);
    }

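    // Issue the rank-4 updates for one 16x8 output tile: two 8-row halves
    // of A (vec_A0, vec_A1) against an 8-column panel of B, spread over
    // eight accumulators.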
    inline void MMA_16x8(vec_t * vec_A0, vec_t * vec_A1, vec_t * vec_B, acc_t * acc) {
        for (int x = 0; x < 16; x += 2) {
            __builtin_mma_xvf32gerpp(&acc[0], vec_A0[x + 0], vec_B[x]);
            __builtin_mma_xvf32gerpp(&acc[1], vec_A0[x + 0], vec_B[x + 1]);
            __builtin_mma_xvf32gerpp(&acc[2], vec_A0[x + 1], vec_B[x]);
            __builtin_mma_xvf32gerpp(&acc[3], vec_A0[x + 1], vec_B[x + 1]);
            __builtin_mma_xvf32gerpp(&acc[4], vec_A1[x + 0], vec_B[x]);
            __builtin_mma_xvf32gerpp(&acc[5], vec_A1[x + 0], vec_B[x + 1]);
            __builtin_mma_xvf32gerpp(&acc[6], vec_A1[x + 1], vec_B[x]);
            __builtin_mma_xvf32gerpp(&acc[7], vec_A1[x + 1], vec_B[x + 1]);
        }
    }

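    // Cache-blocked kernel: walks an mc x nc tile in 16x8 sub-tiles over
    // the packed A and B panels. For the first k-panel (kk == 0) results
    // are stored to C; later panels are added on top.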
    void KERNEL(int64_t ii, int64_t jj, int64_t mc, int64_t nc, int64_t kc, vec_t * vec_A, vec_t * vec_B, int64_t kk) {
        for (int64_t i = 0; i < mc; i += 16) {
            int A_base_addr = (mc / 8) * (i / 8) * 16;
            for (int64_t j = 0; j < nc; j += 8) {
                int B_base_addr = (nc / 8) * (j / 8) * 16;
                acc_t acc[8];
                for (int x = 0; x < 8; x++)
                    __builtin_mma_xxsetaccz(&acc[x]);
                for (int64_t l = 0; l < kc; l += 8) {
                    int A0_block_idx = A_base_addr + (l / 8) * 16;
                    int A1_block_idx = A0_block_idx + (mc / 8) * 16;
                    int B_block_idx = B_base_addr + (l / 8) * 16;
                    vec_t * A0_block = &vec_A[A0_block_idx];
                    vec_t * A1_block = &vec_A[A1_block_idx];
                    vec_t * B_block = &vec_B[B_block_idx];
                    MMA_16x8(A0_block, A1_block, B_block, acc);
                }
                if (kk == 0) {
                    save_acc(&acc[0], ii + i, jj + j);
                    save_acc(&acc[1], ii + i, jj + j + 4);
                    save_acc(&acc[2], ii + i + 4, jj + j);
                    save_acc(&acc[3], ii + i + 4, jj + j + 4);
                    save_acc(&acc[4], ii + i + 8, jj + j);
                    save_acc(&acc[5], ii + i + 8, jj + j + 4);
                    save_acc(&acc[6], ii + i + 12, jj + j);
                    save_acc(&acc[7], ii + i + 12, jj + j + 4);
                } else {
                    add_save_acc(&acc[0], ii + i, jj + j);
                    add_save_acc(&acc[1], ii + i, jj + j + 4);
                    add_save_acc(&acc[2], ii + i + 4, jj + j);
                    add_save_acc(&acc[3], ii + i + 4, jj + j + 4);
                    add_save_acc(&acc[4], ii + i + 8, jj + j);
                    add_save_acc(&acc[5], ii + i + 8, jj + j + 4);
                    add_save_acc(&acc[6], ii + i + 12, jj + j);
                    add_save_acc(&acc[7], ii + i + 12, jj + j + 4);
                }
            }
        }
    }

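    // Blocks the whole problem into mc x nc x kc tiles; used only when m,
    // n and k divide evenly (see matmul). Each job packs its own A and B
    // panels into variable-length stack buffers (a GCC extension), then
    // runs KERNEL over every k-panel.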
    void matmul_tiled(int64_t m, int64_t n, int64_t mc, int64_t nc, int64_t kc) {
        int64_t ytiles = m / mc;
        int64_t xtiles = n / nc;
        int64_t tiles = xtiles * ytiles;
        int64_t duty = (tiles + nth - 1) / nth;
        int64_t start = duty * ith;
        int64_t end = start + duty;
        if (end > tiles) {
            end = tiles;
        }
        for (int64_t job = start; job < end; ++job) {
            int64_t ii = (job / xtiles) * mc;
            int64_t jj = (job % xtiles) * nc;
            for (int64_t kk = 0; kk < k; kk += kc) {
                vec_t A_pack[kc * mc / 4];
                vec_t B_pack[kc * nc / 4];
                packTranspose(A + (ii * lda) + kk, lda, kc, mc, (float *)A_pack);
                packTranspose(B + (jj * ldb) + kk, ldb, kc, nc, (float *)B_pack);
                KERNEL(ii, jj, mc, nc, kc, A_pack, B_pack, kk);
            }
        }
    }

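    // Recursively partitions the output: serves the largest region the
    // 8x8/8x4/4x8/4x4 kernels can cover, hands smaller edges to
    // gemm_small, then recurses on the row and column remainders.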
    void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
        int m_rem = MIN(m - m0, 8);
        int n_rem = MIN(n - n0, 8);
        int mc = 0, nc = 0;
        if (m_rem >= 8 && n_rem >= 8) {
            mc = 8;
            nc = 8;
            gemm<8, 8>(m0, m, n0, n);
        } else if (m_rem >= 4 && n_rem >= 8) {
            mc = 4;
            nc = 8;
            gemm<4, 8>(m0, m, n0, n);
        } else if (m_rem >= 8 && n_rem >= 4) {
            mc = 8;
            nc = 4;
            gemm<8, 4>(m0, m, n0, n);
        } else if (m_rem >= 4 && n_rem >= 4) {
            mc = 4;
            nc = 4;
            gemm<4, 4>(m0, m, n0, n);
        } else {
            mc = (m_rem >= 4) ? 4 : m_rem;
            nc = (n_rem >= 4) ? 4 : n_rem;
            if (mc == 0 || nc == 0)
                return;
            gemm_small(m0, m, n0, n, mc, nc);
        }
        int64_t mp = m0 + ((m - m0) / mc) * mc;
        int64_t np = n0 + ((n - n0) / nc) * nc;
        mnpack(mp, m, n0, np);
        mnpack(m0, m, np, n);
    }

    void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
        int64_t ytiles = (m - m0) / RM;
        int64_t xtiles = (n - n0) / RN;
        int64_t tiles = xtiles * ytiles;
        int64_t duty = (tiles + nth - 1) / nth;
        int64_t start = duty * ith;
        int64_t end = start + duty;
        if (end > tiles)
            end = tiles;
        for (int64_t job = start; job < end; ++job) {
            int64_t ii = m0 + job / xtiles * RM;
            int64_t jj = n0 + job % xtiles * RN;
            vec_t vec_C[4];
            acc_t acc_0;
            __builtin_mma_xxsetaccz(&acc_0);
            vec_t vec_A[4] = {0}, vec_B[4] = {0};
            for (int l = 0; l < k; l += 4) {
                /* The 'GEMV forwarding' technique is used in the first two
                 * branches below: when the tile spans a single row or column,
                 * its elements are broadcast directly instead of being
                 * prepacked with the packing routine.
                 */
                if (RM == 1) {
                    float * a = const_cast<float *>(A + (ii) * lda + l);
                    packTranspose(B + (jj * ldb) + l, ldb, RN, 4, (float *)vec_B);
                    vec_A[0] = (vec_t)vec_xl(0, a);
                    vec_A[1] = (vec_t)vec_splats(*((float *)&vec_A+1));
                    vec_A[2] = (vec_t)vec_splats(*((float *)&vec_A+2));
                    vec_A[3] = (vec_t)vec_splats(*((float *)&vec_A+3));
                } else if (RN == 1) {
                    packTranspose(A + (ii * lda) + l, lda, RM, 4, (float *)vec_A);
                    float * b = const_cast<float *>(B + (jj) * ldb + l);
                    vec_B[0] = (vec_t)vec_xl(0, b);
                    vec_B[1] = (vec_t)vec_splats(*((float *)&vec_B+1));
                    vec_B[2] = (vec_t)vec_splats(*((float *)&vec_B+2));
                    vec_B[3] = (vec_t)vec_splats(*((float *)&vec_B+3));
                } else {
                    packTranspose(A + (ii * lda) + l, lda, RM, 4, (float *)vec_A);
                    packTranspose(B + (jj * ldb) + l, ldb, RN, 4, (float *)vec_B);
                }
                __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
                __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
                __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]);
                __builtin_mma_xvf32gerpp(&acc_0, vec_A[3], vec_B[3]);
            }
            __builtin_mma_disassemble_acc(vec_C, &acc_0);
            for (int I = 0; I < RM; I++) {
                for (int J = 0; J < RN; J++) {
                    *((float *)(C+ii+((jj+J)*ldc)+I)) = *((float *)&vec_C[I]+J);
                }
            }
        }
    }

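    // Compile-time dispatch from tile shape to the fixed-size kernels.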
    template<int RM, int RN>
    inline void kernel(int64_t ii, int64_t jj) {
        if constexpr (RM == 4 && RN == 4) {
            KERNEL_4x4(ii, jj);
        } else if constexpr (RM == 4 && RN == 8) {
            KERNEL_4x8(ii, jj);
        } else if constexpr (RM == 8 && RN == 4) {
            KERNEL_8x4(ii, jj);
        } else if constexpr (RM == 8 && RN == 8) {
            KERNEL_8x8(ii, jj);
        } else {
            static_assert(false, "RN/RM values not supported");
        }
    }

    template <int RM, int RN>
    NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
        int64_t ytiles = (m - m0) / RM;
        int64_t xtiles = (n - n0) / RN;
        int64_t tiles = xtiles * ytiles;
        int64_t duty = (tiles + nth - 1) / nth;
        int64_t start = duty * ith;
        int64_t end = start + duty;
        if (end > tiles)
            end = tiles;
        for (int64_t job = start; job < end; ++job) {
            int64_t ii = m0 + job / xtiles * RM;
            int64_t jj = n0 + job % xtiles * RN;
            kernel<RM, RN>(ii, jj);
        }
    }

    const float * const A;
    const float * const B;
    float * C;
    const int64_t k;
    const int64_t lda;
    const int64_t ldb;
    const int64_t ldc;
    const int ith;
    const int nth;
};
#endif
} // namespace

/**
 * Performs optimized matrix multiplication on CPU.
 *
 * This subroutine may compute C = Aᵀ * B with column major ordering.
 * Despite its name, this isn't a generalized implementation. Work is
 * only performed when a handwritten kernel is available for the given
 * types and hardware. Otherwise the caller should fall back to a
 * general matmul routine.
 *
 * For example, for single-threaded single-precision GEMM you can say
 *
 *     struct ggml_compute_params params = {};
 *     params.ith = 0;
 *     params.nth = 1;
 *     llamafile_sgemm(&params, m, n, k, A, lda, B, ldb, C, ldc,
 *                     GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32);
 *
 * @param params is the ggml compute parameters, carrying the thread id
 *        `ith` (which must be less than `nth`) and the thread count
 *        `nth` (which must be greater than zero)
 * @param m is rows in `A` and `C`
 * @param n is cols in `B` and `C`
 * @param k is cols in `A` and rows in `B`
 * @param A is first input matrix (always transposed)
 * @param lda is row stride of `A`
 * @param B is second input matrix (never transposed)
 * @param ldb is row stride of `B`
 * @param C is input/output array of output matrices
 * @param ldc is row stride of `C`
 * @param Atype is GGML data type of `A`
 * @param Btype is GGML data type of `B`
 * @param Ctype is GGML data type of `C`
 * @return true if this function was able to service the matmul request
 */
bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64_t n, int64_t k,
                     const void *A, int64_t lda, const void *B, int64_t ldb, void *C,
                     int64_t ldc, int Atype, int Btype, int Ctype) {

    assert(m >= 0);
    assert(n >= 0);
    assert(k >= 0);
    assert(lda >= k);
    assert(ldb >= k);
    assert(ldc >= m);
    assert(params->nth > 0);
    assert(params->ith < params->nth);

    // only enable sgemm for prompt processing
#if !defined(__MMA__)
    if (n < 2)
        return false;
#endif

    if (Ctype != GGML_TYPE_F32)
        return false;

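    // Dispatch on the operand types: each case instantiates a tinyBLAS
    // specialization for the SIMD ISA this file was compiled against, or
    // returns false so the caller falls back to a general matmul routine.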
    switch (Atype) {

    case GGML_TYPE_F32: {
        if (Btype != GGML_TYPE_F32)
            return false;
#if defined(__AVX512F__)
        tinyBLAS<16, __m512, __m512, float, float, float> tb{ params,
            k, (const float *)A, lda,
            (const float *)B, ldb,
            (float *)C, ldc};
        return tb.matmul(m, n);
#elif defined(__AVX__) || defined(__AVX2__)
        tinyBLAS<8, __m256, __m256, float, float, float> tb{ params,
            k, (const float *)A, lda,
            (const float *)B, ldb,
            (float *)C, ldc};
        return tb.matmul(m, n);
#elif defined(__ARM_NEON)
        if (n < 4)
            return false;
        tinyBLAS<4, float32x4_t, float32x4_t, float, float, float> tb{ params,
            k, (const float *)A, lda,
            (const float *)B, ldb,
            (float *)C, ldc};
        return tb.matmul(m, n);
#elif defined(__VXE__) || defined(__VXE2__)
        if (n < 4)
            return false;
        tinyBLAS<4, float32x4_t, float32x4_t, float, float, float> tb{ params,
            k, (const float *)A, lda,
            (const float *)B, ldb,
            (float *)C, ldc};
        return tb.matmul(m, n);
#elif defined(__MMA__)
        if (k % 8)
            return false;
        tinyBLAS_PPC tb{
            k, (const float *)A, lda,
            (const float *)B, ldb,
            (float *)C, ldc,
            params->ith, params->nth};
        tb.matmul(m, n);
        return true;
#elif defined(__riscv_zvfh)
    #if LMUL == 1
        tinyBLAS_RVV<vfloat32m1_t, vfloat32m1_t, float, float, float> tb{ params,
            k, (const float *)A, lda,
            (const float *)B, ldb,
            (float *)C, ldc};
    #elif LMUL == 2
        tinyBLAS_RVV<vfloat32m2_t, vfloat32m2_t, float, float, float> tb{ params,
            k, (const float *)A, lda,
            (const float *)B, ldb,
            (float *)C, ldc};
    #else // LMUL == 4
        tinyBLAS_RVV<vfloat32m4_t, vfloat32m4_t, float, float, float> tb{ params,
            k, (const float *)A, lda,
            (const float *)B, ldb,
            (float *)C, ldc};
    #endif
        return tb.matmul(m, n);
#else
        return false;
#endif
    }

    case GGML_TYPE_BF16: {
#if defined(__AVX512BF16__)
        if (Btype == GGML_TYPE_BF16) {
            tinyBLAS<32, __m512, __m512bh, ggml_bf16_t, ggml_bf16_t, float> tb{ params, k,
                (const ggml_bf16_t *)A, lda,
                (const ggml_bf16_t *)B, ldb,
                (float *)C, ldc};
            return tb.matmul(m, n);
        }
#elif defined(__AVX512F__)
        if (Btype == GGML_TYPE_BF16) {
            tinyBLAS<16, __m512, __m512, ggml_bf16_t, ggml_bf16_t, float> tb{ params, k,
                (const ggml_bf16_t *)A, lda,
                (const ggml_bf16_t *)B, ldb,
                (float *)C, ldc};
            return tb.matmul(m, n);
        }
#elif defined(__AVX2__)
        if (Btype == GGML_TYPE_BF16) {
            tinyBLAS<8, __m256, __m256, ggml_bf16_t, ggml_bf16_t, float> tb{ params, k,
                (const ggml_bf16_t *)A, lda,
                (const ggml_bf16_t *)B, ldb,
                (float *)C, ldc};
            return tb.matmul(m, n);
        }
#elif defined(__MMA__)
        if (k % 8) {
            return false;
        }

        if (Btype == GGML_TYPE_BF16) {
            tinyBLAS_HP16_PPC<ggml_bf16_t, ggml_bf16_t, float> tb{ k,
                (const ggml_bf16_t *)A, lda,
                (const ggml_bf16_t *)B, ldb,
                (float *)C, ldc,
                params->ith, params->nth };

            tb.matmul(m, n);
            return true;
        }
#elif defined(__riscv_zvfbfwma)
        #if LMUL == 1
            tinyBLAS_RVV<vfloat32m1_t, vbfloat16mf2_t, ggml_bf16_t, ggml_bf16_t, float> tb{ params,
                k, (const ggml_bf16_t *)A, lda,
                (const ggml_bf16_t *)B, ldb,
                (float *)C, ldc};
        #elif LMUL == 2
            tinyBLAS_RVV<vfloat32m2_t, vbfloat16m1_t, ggml_bf16_t, ggml_bf16_t, float> tb{ params,
                k, (const ggml_bf16_t *)A, lda,
                (const ggml_bf16_t *)B, ldb,
                (float *)C, ldc};
        #else // LMUL == 4
            tinyBLAS_RVV<vfloat32m4_t, vbfloat16m2_t, ggml_bf16_t, ggml_bf16_t, float> tb{ params,
                k, (const ggml_bf16_t *)A, lda,
                (const ggml_bf16_t *)B, ldb,
                (float *)C, ldc};
        #endif
            return tb.matmul(m, n);
#endif
        return false;
    }

    case GGML_TYPE_F16: {
#if defined(__AVX512F__)
        if (Btype == GGML_TYPE_F16) {
            tinyBLAS<16, __m512, __m512, ggml_fp16_t, ggml_fp16_t, float> tb{ params, k,
                (const ggml_fp16_t *)A, lda,
                (const ggml_fp16_t *)B, ldb,
                (float *)C, ldc};
            return tb.matmul(m, n);
        }
#elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__)
        if (Btype == GGML_TYPE_F16) {
            tinyBLAS<8, __m256, __m256, ggml_fp16_t, ggml_fp16_t, float> tb{ params, k,
                (const ggml_fp16_t *)A, lda,
                (const ggml_fp16_t *)B, ldb,
                (float *)C, ldc};
            return tb.matmul(m, n);
        }
#elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
        if (n < 8)
            return false;
        if (Btype == GGML_TYPE_F16) {
            tinyBLAS<8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, float> tb{ params,
                k, (const ggml_fp16_t *)A, lda,
                (const ggml_fp16_t *)B, ldb,
                (float *)C, ldc};
            return tb.matmul(m, n);
        }
#elif defined(__ARM_NEON) && !defined(_MSC_VER)
        if (Btype == GGML_TYPE_F32) {
            tinyBLAS<4, float32x4_t, float32x4_t, ggml_fp16_t, float, float> tb{ params,
                k, (const ggml_fp16_t *)A, lda,
                (const float *)B, ldb,
                (float *)C, ldc};
            return tb.matmul(m, n);
        }
#elif defined(__VXE__) || defined(__VXE2__)
        if (n < 4)
            return false;
        if (Btype == GGML_TYPE_F16) {
            tinyBLAS<4, float32x4_t, float32x4_t, ggml_fp16_t, ggml_fp16_t, float> tb{ params,
                k, (const ggml_fp16_t *)A, lda,
                (const ggml_fp16_t *)B, ldb,
                (float *)C, ldc};
            return tb.matmul(m, n);
        }
#elif defined(__riscv_zvfh)
        if (Btype == GGML_TYPE_F16) {
        #if LMUL == 1
            tinyBLAS_RVV<vfloat32m1_t, vfloat16mf2_t, ggml_fp16_t, ggml_fp16_t, float> tb{ params,
                k, (const ggml_fp16_t *)A, lda,
                (const ggml_fp16_t *)B, ldb,
                (float *)C, ldc};
        #elif LMUL == 2
            tinyBLAS_RVV<vfloat32m2_t, vfloat16m1_t, ggml_fp16_t, ggml_fp16_t, float> tb{ params,
                k, (const ggml_fp16_t *)A, lda,
                (const ggml_fp16_t *)B, ldb,
                (float *)C, ldc};
        #else // LMUL == 4
            tinyBLAS_RVV<vfloat32m4_t, vfloat16m2_t, ggml_fp16_t, ggml_fp16_t, float> tb{ params,
                k, (const ggml_fp16_t *)A, lda,
                (const ggml_fp16_t *)B, ldb,
                (float *)C, ldc};
        #endif
            return tb.matmul(m, n);
        }
#elif defined(__MMA__)
        if (k % 8) {
            return false;
        }

        if (Btype == GGML_TYPE_F16) {
            tinyBLAS_HP16_PPC<ggml_fp16_t, ggml_fp16_t, float> tb{ k,
                (const ggml_fp16_t *)A, lda,
                (const ggml_fp16_t *)B, ldb,
                (float *)C, ldc,
                params->ith, params->nth };

            tb.matmul(m, n);
            return true;
        }
#endif
        return false;
    }

    case GGML_TYPE_Q8_0: {
        if (Btype != GGML_TYPE_Q8_0)
            return false;
#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
        tinyBLAS_Q0_AVX<block_q8_0, block_q8_0, float> tb{
            k, (const block_q8_0 *)A, lda,
            (const block_q8_0 *)B, ldb,
            (float *)C, ldc,
            params->ith, params->nth};
        tb.matmul(m, n);
        return true;
#elif defined(__ARM_FEATURE_DOTPROD)
        tinyBLAS_Q0_ARM<block_q8_0> tb{
            k, (const block_q8_0 *)A, lda,
            (const block_q8_0 *)B, ldb,
            (float *)C, ldc,
            params->ith, params->nth};
        tb.matmul(m, n);
        return true;
#elif defined(__MMA__)
        // TODO: remove this condition once GEMV forwarding is enabled.
        if (n < 8 && n != 4)
            return false;
        if (m < 8 && m != 4)
            return false;
        tinyBLAS_Q0_PPC<block_q8_0> tb{
            k, (const block_q8_0 *)A, lda,
            (const block_q8_0 *)B, ldb,
            (float *)C, ldc,
            params->ith, params->nth};
        tb.matmul(m, n);
        return true;
#else
        return false;
#endif
    }

    case GGML_TYPE_Q4_0: {
        if (Btype != GGML_TYPE_Q8_0)
            return false;
#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
        tinyBLAS_Q0_AVX<block_q4_0, block_q8_0, float> tb{
            k, (const block_q4_0 *)A, lda,
            (const block_q8_0 *)B, ldb,
            (float *)C, ldc,
            params->ith, params->nth};
        tb.matmul(m, n);
        return true;
#elif defined(__ARM_FEATURE_DOTPROD)
        tinyBLAS_Q0_ARM<block_q4_0> tb{
            k, (const block_q4_0 *)A, lda,
            (const block_q8_0 *)B, ldb,
            (float *)C, ldc,
            params->ith, params->nth};
        tb.matmul(m, n);
        return true;
#elif defined(__MMA__)
        // TODO: remove this condition once GEMV forwarding is enabled.
        if (n < 8 && n != 4)
            return false;
        if (m < 8 && m != 4)
            return false;
        tinyBLAS_Q0_PPC<block_q4_0> tb{
            k, (const block_q4_0 *)A, lda,
            (const block_q8_0 *)B, ldb,
            (float *)C, ldc,
            params->ith, params->nth};
        tb.matmul(m, n);
        return true;
#else
        return false;
#endif
    }

    case GGML_TYPE_Q5_0: {
        if (Btype != GGML_TYPE_Q8_0)
            return false;
#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
        tinyBLAS_Q0_AVX<block_q5_0, block_q8_0, float> tb{
            k, (const block_q5_0 *)A, lda,
            (const block_q8_0 *)B, ldb,
            (float *)C, ldc,
            params->ith, params->nth};
        tb.matmul(m, n);
        return true;
#else
        return false;
#endif
    }

    case GGML_TYPE_IQ4_NL: {
        if (Btype != GGML_TYPE_Q8_0)
            return false;
#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
        tinyBLAS_Q0_AVX<block_iq4_nl, block_q8_0, float> tb{
            k, (const block_iq4_nl *)A, lda,
            (const block_q8_0 *)B, ldb,
            (float *)C, ldc,
            params->ith, params->nth};
        tb.matmul(m, n);
        return true;
#else
        return false;
#endif
    }

    default:
        return false;
    }

    (void)params;
    (void)m;
    (void)n;
    (void)k;
    (void)A;
    (void)lda;
    (void)B;
    (void)ldb;
    (void)C;
    (void)ldc;
    (void)Atype;
    (void)Btype;
    (void)Ctype;
}